Initial untested implementation of ns3 TUN-compliant connections (TunChannel)
[nepi.git] / src / nepi / testbeds / ns3 / execute.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 from constants import TESTBED_ID
5 from nepi.core import testbed_impl
6 from nepi.core.attributes import Attribute
7 from nepi.util.constants import TIME_NOW
8 import os
9 import sys
10 import threading
11 import random
12 import socket
13 import weakref
14
class TunChannel(object):
    """Forward traffic between a TUN socket and a remote peer endpoint.

    Instances are configured through plain attributes (tun/peer addresses,
    ports and protocols, ``listen`` and ``tun_socket``).  Once launched, a
    background thread establishes the peer connection (UDP, TCP listen or
    TCP connect) and hands both ends to ``tunchannel.tun_fwd``.
    """

    def __init__(self):
        # These get initialized when the channel is configured
        self.external_addr = None

        # These get initialized when the channel is configured.
        # They're part of the TUN standard attribute set.
        self.tun_port = None
        self.tun_addr = None

        # These get initialized when the channel is connected to its peer
        self.peer_proto = None
        self.peer_addr = None
        self.peer_port = None

        # These get initialized when the channel is connected to its iface
        self.tun_socket = None

        # Same as peer_proto, but for execute-time standard attribute lookups
        self.tun_proto = None

        # Some state
        self.prepared = False
        self.listen = False
        self._terminate = []  # terminate signaller: Kill() appends, tun_fwd watches it
        self._connected = threading.Event()
        self._forwarder_thread = None

        # Generate an initial random cryptographic key to use for tunnelling.
        # Upon connection, both endpoints will agree on a common one based on
        # this one.  os.urandom is the entropy source random.SystemRandom
        # draws from, so 32 bytes from it give the same kind of key.
        self.tun_key = os.urandom(32).encode("base64").strip()

    def __str__(self):
        # Only attributes set in __init__ are used, so this is safe to call
        # while formatting configuration error messages.
        return "%s<%s %s:%s -> %s %s:%s%s>" % (
            self.__class__.__name__,
            self.tun_proto, self.tun_addr, self.tun_port,
            self.peer_proto, self.peer_addr, self.peer_port,
            " listen" if self.listen else "",
        )

    def Prepare(self):
        """Launch the forwarder early when acting as a TCP listener.

        Every other mode is launched later by Setup().  Presumably the
        listener is started first so it is accepting before the connecting
        peer exhausts its retries.
        """
        udp = self.tun_proto == 'udp'
        if not udp and self.listen and not self._forwarder_thread:
            self._launch()

    def Setup(self):
        """Launch the forwarder thread unless Prepare() already did."""
        if not self._forwarder_thread:
            self._launch()

    def Cleanup(self):
        """Stop the forwarder thread if one was launched."""
        if self._forwarder_thread:
            self.Kill()

    def Wait(self):
        """Block until the forwarder thread reports the peer connection is up."""
        if self._forwarder_thread:
            self._connected.wait()

    def Kill(self):
        """Signal the forwarder to terminate and join its thread."""
        if self._forwarder_thread:
            if not self._terminate:
                self._terminate.append(None)
            self._forwarder_thread.join()

    def _launch(self):
        # Launch forwarder thread with a weak reference
        # to self, so that we don't create any strong cycles
        # and automatic refcounting works as expected.
        # Note: Thread's first positional parameter is ``group``, so the
        # callable must be passed as ``target``.
        self._forwarder_thread = threading.Thread(
            target = self._forwarder,
            args = (weakref.ref(self),) )
        self._forwarder_thread.start()

    @staticmethod
    def _forwarder(weak_self):
        """Thread body: connect to the peer and forward until terminated.

        Receives a weak reference to the channel and drops its strong
        reference once connected, so the forwarding loop does not keep the
        channel alive.

        Raises RuntimeError (inside the forwarder thread) on protocol
        mismatch, incomplete configuration or a missing TUN socket.
        """
        import time
        import tunchannel

        # grab strong reference
        self = weak_self()
        if not self:
            return

        peer_port = self.peer_port
        peer_addr = self.peer_addr
        peer_proto = self.peer_proto

        local_port = self.tun_port
        local_addr = self.tun_addr
        local_proto = self.tun_proto

        if local_proto != peer_proto:
            raise RuntimeError("Peering protocol mismatch: %s != %s"
                               % (local_proto, peer_proto))

        udp = local_proto == 'udp'
        listen = self.listen

        # UDP and TCP-connect modes need the peer endpoint
        if (udp or not listen) and (not peer_port or not peer_addr):
            raise RuntimeError("Misconfigured peer for: %s" % (self,))

        # UDP and TCP-listen modes need a local endpoint to bind
        if (udp or listen) and (not local_port or not local_addr):
            raise RuntimeError("Misconfigured TUN: %s" % (self,))

        TERMINATE = self._terminate
        cipher_key = self.tun_key
        tun = self.tun_socket

        if not tun:
            raise RuntimeError("Unconnected TUN channel %s" % (self,))

        # NOTE(review): if anything in this method raises before the
        # _connected event is set below, Wait() blocks indefinitely.
        if udp:
            # bind the local endpoint and fix the peer address
            rsock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
            rsock.bind((local_addr, local_port))
            rsock.connect((peer_addr, peer_port))
        elif listen:
            # accept a single tcp connection
            lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
            try:
                lsock.bind((local_addr, local_port))
                lsock.listen(1)
                rsock, _ = lsock.accept()
            finally:
                # only one peer is ever accepted; release the listening port
                lsock.close()
        else:
            # connect to tcp server, retrying while it comes up
            rsock = None
            for _ in xrange(30):
                candidate = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
                try:
                    candidate.connect((peer_addr, peer_port))
                except socket.error:
                    # a socket whose connect() failed is in an unspecified
                    # state, so wait a while and retry with a fresh one
                    candidate.close()
                    time.sleep(1)
                else:
                    rsock = candidate
                    break
            if rsock is None:
                # last attempt: let socket.error propagate
                rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
                rsock.connect((peer_addr, peer_port))

        # Give the file object its own descriptor, so closing it and
        # closing rsock never close the same fd number twice.
        remote = os.fdopen(os.dup(rsock.fileno()), 'r+b', 0)
        rsock.close()

        # notify that we're ready
        self._connected.set()

        # drop strong reference
        del self

        devnull = open(os.devnull, "w")
        try:
            tunchannel.tun_fwd(tun, remote,
                with_pi = False,
                ether_mode = True,
                cipher_key = cipher_key,
                udp = udp,
                TERMINATE = TERMINATE,
                stderr = devnull # silence logging
            )
        finally:
            tun.close()
            remote.close()
            devnull.close()
174
175
176 class TestbedController(testbed_impl.TestbedController):
177     LOCAL_FACTORIES = {
178         'ns3::Nepi::TunChannel' : TunChannel,
179     }
180
181     def __init__(self, testbed_version):
182         super(TestbedController, self).__init__(TESTBED_ID, testbed_version)
183         self._ns3 = None
184         self._home_directory = None
185         self._traces = dict()
186         self._simulator_thread = None
187         self._condition = None
188         
189         # local factories
190         self.TunChannel = TunChannel
191
192     @property
193     def home_directory(self):
194         return self._home_directory
195
196     @property
197     def ns3(self):
198         return self._ns3
199
200     def do_setup(self):
201         self._home_directory = self._attributes.\
202             get_attribute_value("homeDirectory")
203         self._ns3 = self._load_ns3_module()
204         
205         # create home...
206         home = os.path.normpath(self.home_directory)
207         if not os.path.exists(home):
208             os.makedirs(home, 0755)
209         
210         super(TestbedController, self).do_setup()
211
212     def start(self):
213         super(TestbedController, self).start()
214         self._condition = threading.Condition()
215         self._simulator_thread = threading.Thread(target = self._simulator_run,
216                 args = [self._condition])
217         self._simulator_thread.start()
218
219     def set(self, guid, name, value, time = TIME_NOW):
220         super(TestbedController, self).set(guid, name, value, time)
221         # TODO: take on account schedule time for the task
222         factory_id = self._create[guid]
223         factory = self._factories[factory_id]
224         if factory.box_attributes.is_attribute_design_only(name) or \
225                 factory.box_attributes.is_attribute_invisible(name):
226             return
227         element = self._elements[guid]
228         if factory_id in self.LOCAL_FACTORIES:
229             setattr(element, name, value)
230         else:
231             ns3_value = self._to_ns3_value(guid, name, value) 
232             element.SetAttribute(name, ns3_value)
233
234     def get(self, guid, name, time = TIME_NOW):
235         value = super(TestbedController, self).get(guid, name, time)
236         # TODO: take on account schedule time for the task
237         factory_id = self._create[guid]
238         factory = self._factories[factory_id]
239         if factory.box_attributes.is_attribute_design_only(name) or \
240                 factory.box_attributes.is_attribute_invisible(name):
241             return value
242         if factory_id in self.LOCAL_FACTORIES:
243             return getattr(element, name)
244         TypeId = self.ns3.TypeId()
245         typeid = TypeId.LookupByName(factory_id)
246         info = TypeId.AttributeInfo()
247         if not typeid or not typeid.LookupAttributeByName(name, info):
248             raise AttributeError("Invalid attribute %s for element type %d" % \
249                 (name, guid))
250         checker = info.checker
251         ns3_value = checker.Create() 
252         element = self._elements[guid]
253         element.GetAttribute(name, ns3_value)
254         value = ns3_value.SerializeToString(checker)
255         attr_type = factory.box_attributes.get_attribute_type(name)
256         if attr_type == Attribute.INTEGER:
257             return int(value)
258         if attr_type == Attribute.DOUBLE:
259             return float(value)
260         if attr_type == Attribute.BOOL:
261             return value == "true"
262         return value
263
264     def action(self, time, guid, action):
265         raise NotImplementedError
266
267     def trace_filename(self, guid, trace_id):
268         # TODO: Need to be defined inside a home!!!! with and experiment id_code
269         filename = self._traces[guid][trace_id]
270         return os.path.join(self.home_directory, filename)
271
272     def follow_trace(self, guid, trace_id, filename):
273         if guid not in self._traces:
274             self._traces[guid] = dict()
275         self._traces[guid][trace_id] = filename
276
277     def shutdown(self):
278         for element in self._elements.values():
279             element = None
280
281     def _simulator_run(self, condition):
282         # Run simulation
283         self.ns3.Simulator.Run()
284         # Signal condition on simulation end to notify waiting threads
285         condition.acquire()
286         condition.notifyAll()
287         condition.release()
288
289     def _schedule_event(self, condition, func, *args):
290         """Schedules event on running experiment"""
291         def execute_event(condition, has_event_occurred, func, *args):
292             # exec func
293             func(*args)
294             # flag event occured
295             has_event_occurred[0] = True
296             # notify condition indicating attribute was set
297             condition.acquire()
298             condition.notifyAll()
299             condition.release()
300
301         # contextId is defined as general context
302         contextId = long(0xffffffff)
303         # delay 0 means that the event is expected to execute inmediately
304         delay = self.ns3.Seconds(0)
305         # flag to indicate that the event occured
306         # because bool is an inmutable object in python, in order to create a
307         # bool flag, a list is used as wrapper
308         has_event_occurred = [False]
309         condition.acquire()
310         if not self.ns3.Simulator.IsFinished():
311             self.ns3.Simulator.ScheduleWithContext(contextId, delay, execute_event,
312                  condition, has_event_occurred, func, *args)
313             while not has_event_occurred[0] and not self.ns3.Simulator.IsFinished():
314                 condition.wait()
315                 condition.release()
316                 if not has_event_occurred[0]:
317                     raise RuntimeError('Event could not be scheduled : %s %s ' \
318                     % (repr(func), repr(args)))
319
320     def _to_ns3_value(self, guid, name, value):
321         factory_id = self._create[guid]
322         TypeId = self.ns3.TypeId()
323         typeid = TypeId.LookupByName(factory_id)
324         info = TypeId.AttributeInfo()
325         if not typeid.LookupAttributeByName(name, info):
326             raise RuntimeError("Attribute %s doesn't belong to element %s" \
327                    % (name, factory_id))
328         str_value = str(value)
329         if isinstance(value, bool):
330             str_value = str_value.lower()
331         checker = info.checker
332         ns3_value = checker.Create()
333         ns3_value.DeserializeFromString(str_value, checker)
334         return ns3_value
335
336     def _load_ns3_module(self):
337         import ctypes
338         import imp
339
340         simu_impl_type = self._attributes.get_attribute_value(
341                 "SimulatorImplementationType")
342         checksum = self._attributes.get_attribute_value("ChecksumEnabled")
343
344         bindings = os.environ["NEPI_NS3BINDINGS"] \
345                 if "NEPI_NS3BINDINGS" in os.environ else None
346         libfile = os.environ["NEPI_NS3LIBRARY"] \
347                 if "NEPI_NS3LIBRARY" in os.environ else None
348
349         if libfile:
350             ctypes.CDLL(libfile, ctypes.RTLD_GLOBAL)
351
352         path = [ os.path.dirname(__file__) ] + sys.path
353         if bindings:
354             path = [ bindings ] + path
355
356         try:
357             module = imp.find_module ('ns3', path)
358             mod = imp.load_module ('ns3', *module)
359         except ImportError:
360             # In some environments, ns3 per-se does not exist,
361             # only the low-level _ns3
362             module = imp.find_module ('_ns3', path)
363             mod = imp.load_module ('_ns3', *module)
364             sys.modules["ns3"] = mod # install it as ns3 too
365             
366             # When using _ns3, we have to make sure we destroy
367             # the simulator when the process finishes
368             import atexit
369             atexit.register(mod.Simulator.Destroy)
370     
371         if simu_impl_type:
372             value = mod.StringValue(simu_impl_type)
373             mod.GlobalValue.Bind ("SimulatorImplementationType", value)
374         if checksum:
375             value = mod.BooleanValue(checksum)
376             mod.GlobalValue.Bind ("ChecksumEnabled", value)
377         return mod
378
379     def _get_construct_parameters(self, guid):
380         params = self._get_parameters(guid)
381         construct_params = dict()
382         factory_id = self._create[guid]
383         TypeId = self.ns3.TypeId()
384         typeid = TypeId.LookupByName(factory_id)
385         for name, value in params.iteritems():
386             info = self.ns3.TypeId.AttributeInfo()
387             found = typeid.LookupAttributeByName(name, info)
388             if found and \
389                 (info.flags & TypeId.ATTR_CONSTRUCT == TypeId.ATTR_CONSTRUCT):
390                 construct_params[name] = value
391         return construct_params
392
393
394