Allow loading the NS3 library manually (nef needs it to avoid threading issues)
[nepi.git] / src / nepi / testbeds / ns3 / execute.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 from nepi.core import testbed_impl
5 from nepi.core.attributes import Attribute
6 from constants import TESTBED_ID
7 from nepi.util.constants import TIME_NOW, \
8     TESTBED_STATUS_STARTED
9 import os
10 import sys
11 import threading
12 import random
13 import socket
14 import weakref
15
def _load_module_from_path(name, path):
    """Find module *name* in the directory list *path* and load it.

    imp.find_module() may return an open file object which the caller is
    responsible for closing; this helper closes it whether or not loading
    succeeds.

    Returns the loaded module. Raises ImportError if it cannot be found.
    """
    import imp

    handle, pathname, description = imp.find_module(name, path)
    try:
        return imp.load_module(name, handle, pathname, description)
    finally:
        # handle is None for packages; otherwise it must be closed explicitly
        if handle is not None:
            handle.close()


def init():
    """Load the ns3 python bindings into sys.modules, if not loaded yet.

    Does nothing when an 'ns3' module is already registered in sys.modules.

    Environment variables honoured:
      NEPI_NS3LIBRARY  -- path of a shared library to load with RTLD_GLOBAL
                          before the bindings are imported.
      NEPI_NS3BINDINGS -- directory searched first for the bindings.

    After NEPI_NS3BINDINGS (if set), the directory of this file and then
    sys.path are searched.

    Raises ImportError if neither 'ns3' nor '_ns3' can be loaded.
    """
    if 'ns3' in sys.modules:
        return

    import ctypes

    bindings = os.environ.get("NEPI_NS3BINDINGS")
    libfile = os.environ.get("NEPI_NS3LIBRARY")

    if libfile:
        ctypes.CDLL(libfile, ctypes.RTLD_GLOBAL)

    path = [ os.path.dirname(__file__) ] + sys.path
    if bindings:
        path = [ bindings ] + path

    try:
        _load_module_from_path('ns3', path)
    except ImportError:
        # In some environments, ns3 per-se does not exist,
        # only the low-level _ns3
        mod = _load_module_from_path('_ns3', path)
        sys.modules["ns3"] = mod # install it as ns3 too

        # When using _ns3, we have to make sure we destroy
        # the simulator when the process finishes
        import atexit
        atexit.register(mod.Simulator.Destroy)
49
50  
51
class TestbedController(testbed_impl.TestbedController):
    """Testbed controller for ns-3 experiments.

    Loads the ns3 python bindings during setup, runs the simulator in a
    daemon thread once started, and forwards attribute reads/writes to the
    ns-3 objects (scheduling them as simulator events while the testbed is
    in the started state).
    """
    from nepi.util.tunchannel_impl import TunChannel

    # Factories whose elements are plain python objects rather than ns-3
    # objects; their attributes are accessed with getattr/setattr.
    LOCAL_FACTORIES = {
        'ns3::Nepi::TunChannel' : TunChannel,
    }

    LOCAL_TYPES = tuple(LOCAL_FACTORIES.values())

    def __init__(self, testbed_version):
        super(TestbedController, self).__init__(TESTBED_ID, testbed_version)
        self._ns3 = None
        self._home_directory = None
        # guid -> { trace_id -> filename }
        self._traces = dict()
        self._simulator_thread = None
        # threading.Condition created in start(); notified when a scheduled
        # event has run and when the simulation ends
        self._condition = None

    @property
    def home_directory(self):
        """Value of the "homeDirectory" attribute (None before do_setup)."""
        return self._home_directory

    @property
    def ns3(self):
        """The loaded ns3 module, or None before do_setup/after shutdown."""
        return self._ns3

    def do_setup(self):
        """Read the home directory, load ns3 and create the home directory."""
        self._home_directory = self._attributes.\
            get_attribute_value("homeDirectory")
        self._ns3 = self._load_ns3_module()

        # create home...
        home = os.path.normpath(self.home_directory)
        if not os.path.exists(home):
            os.makedirs(home, 0o755)

        super(TestbedController, self).do_setup()

    def start(self):
        """Start the testbed and launch the simulator in a daemon thread."""
        super(TestbedController, self).start()
        self._condition = threading.Condition()
        self._simulator_thread = threading.Thread(target = self._simulator_run,
                args = [self._condition])
        self._simulator_thread.setDaemon(True)
        self._simulator_thread.start()

    def stop(self, time = TIME_NOW):
        """Stop the testbed and ask the simulator to stop at *time*."""
        super(TestbedController, self).stop(time)
        #self.ns3.Simulator.Stop()
        self._stop_simulation(time)

    def set(self, guid, name, value, time = TIME_NOW):
        """Set attribute *name* of element *guid* to *value*.

        Design-only attributes are only recorded by the base class. For
        locally implemented elements the value is set with setattr;
        invisible attributes of ns-3 elements are skipped; anything else is
        converted to an ns-3 attribute value and applied to the ns-3 object.
        """
        super(TestbedController, self).set(guid, name, value, time)
        # TODO: take on account schedule time for the task
        factory_id = self._create[guid]
        factory = self._factories[factory_id]
        if factory.box_attributes.is_attribute_design_only(name):
            return
        element = self._elements[guid]
        if factory_id in self.LOCAL_FACTORIES:
            setattr(element, name, value)
        elif factory.box_attributes.is_attribute_invisible(name):
            return
        else:
            ns3_value = self._to_ns3_value(guid, name, value)
            self._set_attribute(name, ns3_value, element)

    def get(self, guid, name, time = TIME_NOW):
        """Return the current value of attribute *name* of element *guid*.

        For locally implemented elements the python attribute is returned
        when present, otherwise the value kept by the base class. Design-only
        and invisible attributes return the value kept by the base class.
        Other attributes are read from the ns-3 object and converted to
        int, float or bool according to the declared attribute type
        (any other type is returned as the serialized string).

        Raises AttributeError if ns-3 does not know the element type or the
        attribute.
        """
        value = super(TestbedController, self).get(guid, name, time)
        # TODO: take on account schedule time for the task
        factory_id = self._create[guid]
        factory = self._factories[factory_id]
        element = self._elements[guid]
        if factory_id in self.LOCAL_FACTORIES:
            if hasattr(element, name):
                return getattr(element, name)
            else:
                return value
        if factory.box_attributes.is_attribute_design_only(name) or \
                factory.box_attributes.is_attribute_invisible(name):
            return value
        TypeId = self.ns3.TypeId()
        typeid = TypeId.LookupByName(factory_id)
        info = TypeId.AttributeInfo()
        if not typeid or not typeid.LookupAttributeByName(name, info):
            raise AttributeError("Invalid attribute %s for element type %d" % \
                (name, guid))
        checker = info.checker
        ns3_value = checker.Create()
        self._get_attribute(name, ns3_value, element)
        value = ns3_value.SerializeToString(checker)
        attr_type = factory.box_attributes.get_attribute_type(name)
        if attr_type == Attribute.INTEGER:
            return int(value)
        if attr_type == Attribute.DOUBLE:
            return float(value)
        if attr_type == Attribute.BOOL:
            return value == "true"
        return value

    def action(self, time, guid, action):
        """Not supported by this testbed."""
        raise NotImplementedError

    def trace_filepath(self, guid, trace_id):
        """Return the path, inside the home directory, of a followed trace.

        Raises KeyError if follow_trace was not called for (guid, trace_id).
        """
        filename = self._traces[guid][trace_id]
        return os.path.join(self.home_directory, filename)

    def follow_trace(self, guid, trace_id, filename):
        """Record *filename* as the file of trace *trace_id* of *guid*."""
        if not guid in self._traces:
            self._traces[guid] = dict()
        self._traces[guid][trace_id] = filename

    def shutdown(self):
        """Clean up local elements, stop and destroy the simulator."""
        # Imported locally: several methods of this class use "time" as a
        # parameter name and the module is not imported at file level.
        import time

        for element in self._elements.itervalues():
            if isinstance(element, self.LOCAL_TYPES):
                # graceful shutdown of locally-implemented objects
                element.Cleanup()
        if self.ns3:
            self.ns3.Simulator.Stop()

            # Wait for it to stop, with a 30s timeout
            for i in xrange(300):
                if self.ns3.Simulator.IsFinished():
                    break
                time.sleep(0.1)
            #self._stop_simulation("0s")

        self._elements.clear()

        if self.ns3:
            # TODO!!!! SHOULD WAIT UNTIL THE THREAD FINISHES
            #   if self._simulator_thread:
            #       self._simulator_thread.join()
            self.ns3.Simulator.Destroy()

        self._ns3 = None
        sys.stdout.flush()
        sys.stderr.flush()

    def _simulator_run(self, condition):
        """Thread target: run the simulation, then notify *condition*."""
        # Run simulation
        self.ns3.Simulator.Run()
        # Signal condition on simulation end to notify waiting threads
        condition.acquire()
        condition.notifyAll()
        condition.release()

    def _schedule_event(self, condition, func, *args):
        """Schedules event on running experiment

        Schedules func(*args) with zero delay in the simulator and blocks
        until it has been executed or the simulator reports it is finished.
        Nothing is scheduled if the simulator is already finished.
        """
        def execute_event(condition, has_event_occurred, func, *args):
            # exec func
            try:
                func(*args)
            finally:
                # flag event occured
                has_event_occurred[0] = True
                # notify condition indicating attribute was set
                condition.acquire()
                condition.notifyAll()
                condition.release()

        # contextId is defined as general context
        contextId = long(0xffffffff)
        # delay 0 means that the event is expected to execute inmediately
        delay = self.ns3.Seconds(0)
        # flag to indicate that the event occured
        # because bool is an inmutable object in python, in order to create a
        # bool flag, a list is used as wrapper
        has_event_occurred = [False]
        condition.acquire()
        try:
            if not self.ns3.Simulator.IsFinished():
                self.ns3.Simulator.ScheduleWithContext(contextId, delay,
                     execute_event, condition, has_event_occurred, func, *args)
                # wait() releases the lock while blocked and re-acquires it
                # before returning, so the lock is held on every iteration
                while not has_event_occurred[0] and \
                        not self.ns3.Simulator.IsFinished():
                    condition.wait()
        finally:
            # release exactly once on every path, otherwise execute_event
            # and _simulator_run would block forever on acquire()
            condition.release()

    def _set_attribute(self, name, ns3_value, element):
        """Set an ns-3 attribute, through the simulator if it is running."""
        if self.status() == TESTBED_STATUS_STARTED:
            # schedule the event in the Simulator
            self._schedule_event(self._condition, self._set_ns3_attribute,
                    name, ns3_value, element)
        else:
            self._set_ns3_attribute(name, ns3_value, element)

    def _get_attribute(self, name, ns3_value, element):
        """Read an ns-3 attribute into *ns3_value*, through the simulator
        if it is running."""
        if self.status() == TESTBED_STATUS_STARTED:
            # schedule the event in the Simulator
            self._schedule_event(self._condition, self._get_ns3_attribute,
                    name, ns3_value, element)
        else:
            self._get_ns3_attribute(name, ns3_value, element)

    def _set_ns3_attribute(self, name, ns3_value, element):
        element.SetAttribute(name, ns3_value)

    def _get_ns3_attribute(self, name, ns3_value, element):
        element.GetAttribute(name, ns3_value)

    def _stop_simulation(self, time):
        """Request the simulator to stop at *time*, through the simulator
        if it is running."""
        if self.status() == TESTBED_STATUS_STARTED:
            # schedule the event in the Simulator
            self._schedule_event(self._condition, self._stop_ns3_simulation,
                    time)
        else:
            self._stop_ns3_simulation(time)

    def _stop_ns3_simulation(self, time = TIME_NOW):
        """Call Simulator.Stop, immediately or at *time*; no-op without ns3."""
        if not self.ns3:
            return
        if time == TIME_NOW:
            self.ns3.Simulator.Stop()
        else:
            self.ns3.Simulator.Stop(self.ns3.Time(time))

    def _to_ns3_value(self, guid, name, value):
        """Convert python *value* to the ns-3 attribute value for *name*.

        The value is serialized with str() (booleans in lower case) and
        deserialized by the checker ns-3 reports for the attribute.

        Raises RuntimeError if the element type has no such attribute.
        """
        factory_id = self._create[guid]
        TypeId = self.ns3.TypeId()
        typeid = TypeId.LookupByName(factory_id)
        info = TypeId.AttributeInfo()
        if not typeid.LookupAttributeByName(name, info):
            raise RuntimeError("Attribute %s doesn't belong to element %s" \
                   % (name, factory_id))
        str_value = str(value)
        if isinstance(value, bool):
            str_value = str_value.lower()
        checker = info.checker
        ns3_value = checker.Create()
        ns3_value.DeserializeFromString(str_value, checker)
        return ns3_value

    def _load_ns3_module(self):
        """Load ns3 and apply the testbed-level attributes to it.

        Binds the "SimulatorImplementationType" and "ChecksumEnabled" global
        values and schedules the simulator stop at "StopTime", each only when
        the corresponding testbed attribute has a true value.

        Returns the ns3 module.
        """
        simu_impl_type = self._attributes.get_attribute_value(
                "SimulatorImplementationType")
        checksum = self._attributes.get_attribute_value("ChecksumEnabled")
        stop_time = self._attributes.get_attribute_value("StopTime")

        init()

        import ns3 as mod

        if simu_impl_type:
            value = mod.StringValue(simu_impl_type)
            mod.GlobalValue.Bind ("SimulatorImplementationType", value)
        if checksum:
            value = mod.BooleanValue(checksum)
            mod.GlobalValue.Bind ("ChecksumEnabled", value)
        if stop_time:
            value = mod.Time(stop_time)
            mod.Simulator.Stop (value)
        return mod

    def _get_construct_parameters(self, guid):
        """Return the parameters of *guid* that ns-3 flags ATTR_CONSTRUCT.

        Returns a dict name -> value restricted to the attributes ns-3 knows
        for the element type and whose flags include ATTR_CONSTRUCT.
        """
        params = self._get_parameters(guid)
        construct_params = dict()
        factory_id = self._create[guid]
        TypeId = self.ns3.TypeId()
        typeid = TypeId.LookupByName(factory_id)
        for name, value in params.iteritems():
            info = self.ns3.TypeId.AttributeInfo()
            found = typeid.LookupAttributeByName(name, info)
            if found and \
                (info.flags & TypeId.ATTR_CONSTRUCT == TypeId.ATTR_CONSTRUCT):
                construct_params[name] = value
        return construct_params
316
317
318