Adding linux ns3 server unit test
[nepi.git] / src / nepi / resources / ns3 / ns3wrapper.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License as published by
7 #    the Free Software Foundation, either version 3 of the License, or
8 #    (at your option) any later version.
9 #
10 #    This program is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #    GNU General Public License for more details.
14 #
15 #    You should have received a copy of the GNU General Public License
16 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
19
20 import logging
21 import os
22 import sys
23 import threading
24 import time
25 import uuid
26
27 SINGLETON = "singleton::"
28
29 def load_ns3_module():
30     import ctypes
31     import re
32
33     bindings = os.environ.get("NS3BINDINGS")
34     libdir = os.environ.get("NS3LIBRARIES")
35
36     # Load the ns-3 modules shared libraries
37     if libdir:
38         files = os.listdir(libdir)
39         regex = re.compile("(.*\.so)$")
40         libs = [m.group(1) for filename in files for m in [regex.search(filename)] if m]
41
42         libscp = list(libs)
43         while len(libs) > 0:
44             for lib in libs:
45                 libfile = os.path.join(libdir, lib)
46                 try:
47                     ctypes.CDLL(libfile, ctypes.RTLD_GLOBAL)
48                     libs.remove(lib)
49                 except:
50                     pass
51
52             # if did not load any libraries in the last iteration break
53             # to prevent infinit loop
54             if len(libscp) == len(libs):
55                 raise RuntimeError("Imposible to load shared libraries %s" % str(libs))
56             libscp = list(libs)
57
58     # import the python bindings for the ns-3 modules
59     if bindings:
60         sys.path.append(bindings)
61
62     import pkgutil
63     import imp
64     import ns
65
66     # create a module to add all ns3 classes
67     ns3mod = imp.new_module("ns3")
68     sys.modules["ns3"] = ns3mod
69
70     for importer, modname, ispkg in pkgutil.iter_modules(ns.__path__):
71         fullmodname = "ns.%s" % modname
72         module = __import__(fullmodname, globals(), locals(), ['*'])
73
74         for sattr in dir(module):
75             if sattr.startswith("_"):
76                 continue
77
78             attr = getattr(module, sattr)
79
80             # netanim.Config and lte.Config singleton overrides ns3::Config
81             if sattr == "Config" and modname in ['netanim', 'lte']:
82                 sattr = "%s.%s" % (modname, sattr)
83
84             setattr(ns3mod, sattr, attr)
85
86     return ns3mod
87
88 class NS3Wrapper(object):
89     def __init__(self, homedir = None, loglevel = logging.INFO):
90         super(NS3Wrapper, self).__init__()
91         # Thread used to run the simulation
92         self._simulation_thread = None
93         self._condition = None
94
95         # XXX: Started should be global. There is no support for more than
96         # one simulator per process
97         self._started = False
98
99         # holds reference to all C++ objects and variables in the simulation
100         self._objects = dict()
101
102         # create home dir (where all simulation related files will end up)
103         self._homedir = homedir or os.path.join("/", "tmp", "ns3_wrapper" )
104         
105         home = os.path.normpath(self.homedir)
106         if not os.path.exists(home):
107             os.makedirs(home, 0755)
108
109         # Logging
110         self._logger = logging.getLogger("ns3wrapper")
111         self._logger.setLevel(loglevel)
112         
113         hdlr = logging.FileHandler(os.path.join(self.homedir, "ns3wrapper.log"))
114         formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
115         hdlr.setFormatter(formatter)
116         
117         self._logger.addHandler(hdlr)
118
119         ## NOTE that the reason to create a handler to the ns3 module,
120         # that is re-loaded each time a ns-3 wrapper is instantiated,
121         # is that else each unit test for the ns3wrapper class would need
122         # a separate file. Several ns3wrappers would be created in the 
123         # same unit test (single process), leading to inchorences in the 
124         # state of ns-3 global objects
125         #
126         # Handler to ns3 classes
127         self._ns3 = None
128
129         # Collection of allowed ns3 classes
130         self._allowed_types = None
131
132     @property
133     def ns3(self):
134         if not self._ns3:
135             # load ns-3 libraries and bindings
136             self._ns3 = load_ns3_module()
137
138         return self._ns3
139
140     @property
141     def allowed_types(self):
142         if not self._allowed_types:
143             self._allowed_types = set()
144             type_id = self.ns3.TypeId()
145             
146             tid_count = type_id.GetRegisteredN()
147             base = type_id.LookupByName("ns3::Object")
148
149             # Create a .py file using the ns-3 RM template for each ns-3 TypeId
150             for i in xrange(tid_count):
151                 tid = type_id.GetRegistered(i)
152                 
153                 if tid.MustHideFromDocumentation() or \
154                         not tid.HasConstructor() or \
155                         not tid.IsChildOf(base): 
156                     continue
157
158                 type_name = tid.GetName()
159                 self._allowed_types.add(type_name)
160         
161         return self._allowed_types
162
163     @property
164     def homedir(self):
165         return self._homedir
166
167     @property
168     def logger(self):
169         return self._logger
170
171     @property
172     def is_running(self):
173         return self._started and self.ns3.Simulator.IsFinished()
174
175     def make_uuid(self):
176         return "uuid%s" % uuid.uuid4()
177
178     def get_object(self, uuid):
179         return self._objects.get(uuid)
180
181     def factory(self, type_name, **kwargs):
182         if type_name not in self.allowed_types:
183             msg = "Type %s not supported" % (type_name) 
184             self.logger.error(msg)
185  
186         factory = self.ns3.ObjectFactory()
187         factory.SetTypeId(type_name)
188
189         for name, value in kwargs.iteritems():
190             ns3_value = self._attr_from_string_to_ns3_value(type_name, name, value)
191             factory.Set(name, ns3_value)
192
193         obj = factory.Create()
194
195         uuid = self.make_uuid()
196         self._objects[uuid] = obj
197
198         return uuid
199
200     def create(self, clazzname, *args):
201         if not hasattr(self.ns3, clazzname):
202             msg = "Type %s not supported" % (clazzname) 
203             self.logger.error(msg)
204      
205         clazz = getattr(self.ns3, clazzname)
206  
207         # arguments starting with 'uuid' identify ns-3 C++
208         # objects and must be replaced by the actual object
209         realargs = self.replace_args(args)
210        
211         obj = clazz(*realargs)
212         
213         uuid = self.make_uuid()
214         self._objects[uuid] = obj
215
216         return uuid
217
218     def invoke(self, uuid, operation, *args):
219         if uuid.startswith(SINGLETON):
220             obj = self._singleton(uuid)
221         else:
222             obj = self.get_object(uuid)
223     
224         method = getattr(obj, operation)
225
226         # arguments starting with 'uuid' identify ns-3 C++
227         # objects and must be replaced by the actual object
228         realargs = self.replace_args(args)
229
230         result = method(*realargs)
231
232         if not result:
233             return None
234         
235         newuuid = self.make_uuid()
236         self._objects[newuuid] = result
237
238         return newuuid
239
240     def _set_attr(self, obj, name, ns3_value):
241         obj.SetAttribute(name, ns3_value)
242
243     def set(self, uuid, name, value):
244         obj = self.get_object(uuid)
245         type_name = obj.GetInstanceTypeId().GetName()
246         ns3_value = self._attr_from_string_to_ns3_value(type_name, name, value)
247
248         # If the Simulation thread is not running,
249         # then there will be no thread-safety problems
250         # in changing the value of an attribute directly.
251         # However, if the simulation is running we need
252         # to set the value by scheduling an event, else
253         # we risk to corrupt the state of the
254         # simulation.
255         if self.is_running:
256             # schedule the event in the Simulator
257             self._schedule_event(self._condition, self._set_attr, 
258                     obj, name, ns3_value)
259         else:
260             self._set_attr(obj, name, ns3_value)
261
262         return value
263
264     def _get_attr(self, obj, name, ns3_value):
265         obj.GetAttribute(name, ns3_value)
266
267     def get(self, uuid, name):
268         obj = self.get_object(uuid)
269         type_name = obj.GetInstanceTypeId().GetName()
270         ns3_value = self._create_attr_ns3_value(type_name, name)
271
272         if self.is_running:
273             # schedule the event in the Simulator
274             self._schedule_event(self._condition, self._get_attr, obj,
275                     name, ns3_value)
276         else:
277             get_attr(obj, name, ns3_value)
278
279         return self._attr_from_ns3_value_to_string(type_name, name, ns3_value)
280
281     def start(self):
282         # Launch the simulator thread and Start the
283         # simulator in that thread
284         self._condition = threading.Condition()
285         self._simulator_thread = threading.Thread(
286                 target = self._simulator_run,
287                 args = [self._condition])
288         self._simulator_thread.setDaemon(True)
289         self._simulator_thread.start()
290         self._started = True
291
292     def stop(self, time = None):
293         if time is None:
294             self.ns3.Simulator.Stop()
295         else:
296             self.ns3.Simulator.Stop(self.ns3.Time(time))
297
298     def shutdown(self):
299         while not self.ns3.Simulator.IsFinished():
300             #self.logger.debug("Waiting for simulation to finish")
301             time.sleep(0.5)
302         
303         # TODO!!!! SHOULD WAIT UNTIL THE THREAD FINISHES
304         if self._simulator_thread:
305             self._simulator_thread.join()
306         
307         self.ns3.Simulator.Destroy()
308         
309         # Remove all references to ns-3 objects
310         self._objects.clear()
311         
312         sys.stdout.flush()
313         sys.stderr.flush()
314
315     def _simulator_run(self, condition):
316         # Run simulation
317         self.ns3.Simulator.Run()
318         # Signal condition to indicate simulation ended and
319         # notify waiting threads
320         condition.acquire()
321         condition.notifyAll()
322         condition.release()
323
324     def _schedule_event(self, condition, func, *args):
325         """ Schedules event on running simulation, and wait until
326             event is executed"""
327
328         def execute_event(contextId, condition, has_event_occurred, func, *args):
329             try:
330                 func(*args)
331             finally:
332                 # flag event occured
333                 has_event_occurred[0] = True
334                 # notify condition indicating attribute was set
335                 condition.acquire()
336                 condition.notifyAll()
337                 condition.release()
338
339         # contextId is defined as general context
340         contextId = long(0xffffffff)
341
342         # delay 0 means that the event is expected to execute inmediately
343         delay = self.ns3.Seconds(0)
344
345         # flag to indicate that the event occured
346         # because bool is an inmutable object in python, in order to create a
347         # bool flag, a list is used as wrapper
348         has_event_occurred = [False]
349         condition.acquire()
350
351         simu = self.ns3.Simulator
352
353         try:
354             if not simu.IsFinished():
355                 simu.ScheduleWithContext(contextId, delay, execute_event,
356                      condition, has_event_occurred, func, *args)
357                 while not has_event_occurred[0] and not simu.IsFinished():
358                     condition.wait()
359         finally:
360             condition.release()
361
362     def _create_attr_ns3_value(self, type_name, name):
363         TypeId = self.ns3.TypeId()
364         tid = TypeId.LookupByName(type_name)
365         info = TypeId.AttributeInformation()
366         if not tid.LookupAttributeByName(name, info):
367             msg = "TypeId %s has no attribute %s" % (type_name, name) 
368             self.logger.error(msg)
369
370         checker = info.checker
371         ns3_value = checker.Create() 
372         return ns3_value
373
374     def _attr_from_ns3_value_to_string(self, type_name, name, ns3_value):
375         TypeId = self.ns3.TypeId()
376         tid = TypeId.LookupByName(type_name)
377         info = TypeId.AttributeInformation()
378         if not tid.LookupAttributeByName(name, info):
379             msg = "TypeId %s has no attribute %s" % (type_name, name) 
380             self.logger.error(msg)
381
382         checker = info.checker
383         value = ns3_value.SerializeToString(checker)
384
385         type_name = checker.GetValueTypeName()
386         if type_name in ["ns3::UintegerValue", "ns3::IntegerValue"]:
387             return int(value)
388         if type_name == "ns3::DoubleValue":
389             return float(value)
390         if type_name == "ns3::BooleanValue":
391             return value == "true"
392
393         return value
394
395     def _attr_from_string_to_ns3_value(self, type_name, name, value):
396         TypeId = self.ns3.TypeId()
397         tid = TypeId.LookupByName(type_name)
398         info = TypeId.AttributeInformation()
399         if not tid.LookupAttributeByName(name, info):
400             msg = "TypeId %s has no attribute %s" % (type_name, name) 
401             self.logger.error(msg)
402
403         str_value = str(value)
404         if isinstance(value, bool):
405             str_value = str_value.lower()
406
407         checker = info.checker
408         ns3_value = checker.Create()
409         ns3_value.DeserializeFromString(str_value, checker)
410         return ns3_value
411
412     # singletons are identified as "ns3::ClassName"
413     def _singleton(self, ident):
414         if not ident.startswith(SINGLETON):
415             return None
416
417         clazzname = ident[ident.find("::")+2:]
418         if not hasattr(self.ns3, clazzname):
419             msg = "Type %s not supported" % (clazzname)
420             self.logger.error(msg)
421
422         return getattr(self.ns3, clazzname)
423
424     # replace uuids and singleton references for the real objects
425     def replace_args(self, args):
426         realargs = [self.get_object(arg) if \
427                 str(arg).startswith("uuid") else arg for arg in args]
428  
429         realargs = [self._singleton(arg) if \
430                 str(arg).startswith(SINGLETON) else arg for arg in realargs]
431
432         return realargs
433