Adding base RMs for ns-3
[nepi.git] / src / nepi / resources / ns3 / ns3wrapper.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License as published by
7 #    the Free Software Foundation, either version 3 of the License, or
8 #    (at your option) any later version.
9 #
10 #    This program is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #    GNU General Public License for more details.
14 #
15 #    You should have received a copy of the GNU General Public License
16 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
19
20 import logging
21 import os
22 import sys
23 import threading
24 import time
25 import uuid
26
27 SINGLETON = "singleton::"
28
29 def load_ns3_module():
30     import ctypes
31     import re
32
33     bindings = os.environ.get("NS3BINDINGS")
34     libdir = os.environ.get("NS3LIBRARIES")
35
36     # Load the ns-3 modules shared libraries
37     if libdir:
38         files = os.listdir(libdir)
39         regex = re.compile("(.*\.so)$")
40         libs = [m.group(1) for filename in files for m in [regex.search(filename)] if m]
41
42         libscp = list(libs)
43         while len(libs) > 0:
44             for lib in libs:
45                 libfile = os.path.join(libdir, lib)
46                 try:
47                     ctypes.CDLL(libfile, ctypes.RTLD_GLOBAL)
48                     libs.remove(lib)
49                 except:
50                     pass
51
52             # if did not load any libraries in the last iteration break
53             # to prevent infinit loop
54             if len(libscp) == len(libs):
55                 raise RuntimeError("Imposible to load shared libraries %s" % str(libs))
56             libscp = list(libs)
57
58     # import the python bindings for the ns-3 modules
59     if bindings:
60         sys.path.append(bindings)
61
62     import pkgutil
63     import imp
64     import ns
65
66     # create a module to add all ns3 classes
67     ns3mod = imp.new_module("ns3")
68     sys.modules["ns3"] = ns3mod
69
70     for importer, modname, ispkg in pkgutil.iter_modules(ns.__path__):
71         fullmodname = "ns.%s" % modname
72         module = __import__(fullmodname, globals(), locals(), ['*'])
73
74         for sattr in dir(module):
75             if sattr.startswith("_"):
76                 continue
77
78             attr = getattr(module, sattr)
79
80             # netanim.Config and lte.Config singleton overrides ns3::Config
81             if sattr == "Config" and modname in ['netanim', 'lte']:
82                 sattr = "%s.%s" % (modname, sattr)
83
84             setattr(ns3mod, sattr, attr)
85
86     return ns3mod
87
88 class NS3Wrapper(object):
89     def __init__(self, homedir = None):
90         super(NS3Wrapper, self).__init__()
91         # Thread used to run the simulation
92         self._simulation_thread = None
93         self._condition = None
94
95         # XXX: Started should be global. There is no support for more than
96         # one simulator per process
97         self._started = False
98
99         # holds reference to all C++ objects and variables in the simulation
100         self._objects = dict()
101
102         # create home dir (where all simulation related files will end up)
103         self._homedir = homedir or os.path.join("/", "tmp", "ns3_wrapper" )
104         
105         home = os.path.normpath(self.homedir)
106         if not os.path.exists(home):
107             os.makedirs(home, 0755)
108
109         # Logging
110         loglevel = os.environ.get("NS3LOGLEVEL", "debug")
111         self._logger = logging.getLogger("ns3wrapper")
112         self._logger.setLevel(getattr(logging, loglevel.upper()))
113         
114         hdlr = logging.FileHandler(os.path.join(self.homedir, "ns3wrapper.log"))
115         formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
116         hdlr.setFormatter(formatter)
117         
118         self._logger.addHandler(hdlr)
119
120         ## NOTE that the reason to create a handler to the ns3 module,
121         # that is re-loaded each time a ns-3 wrapper is instantiated,
122         # is that else each unit test for the ns3wrapper class would need
123         # a separate file. Several ns3wrappers would be created in the 
124         # same unit test (single process), leading to inchorences in the 
125         # state of ns-3 global objects
126         #
127         # Handler to ns3 classes
128         self._ns3 = None
129
130         # Collection of allowed ns3 classes
131         self._allowed_types = None
132
133     @property
134     def ns3(self):
135         if not self._ns3:
136             # load ns-3 libraries and bindings
137             self._ns3 = load_ns3_module()
138
139         return self._ns3
140
141     @property
142     def allowed_types(self):
143         if not self._allowed_types:
144             self._allowed_types = set()
145             type_id = self.ns3.TypeId()
146             
147             tid_count = type_id.GetRegisteredN()
148             base = type_id.LookupByName("ns3::Object")
149
150             # Create a .py file using the ns-3 RM template for each ns-3 TypeId
151             for i in xrange(tid_count):
152                 tid = type_id.GetRegistered(i)
153                 
154                 if tid.MustHideFromDocumentation() or \
155                         not tid.HasConstructor() or \
156                         not tid.IsChildOf(base): 
157                     continue
158
159                 type_name = tid.GetName()
160                 self._allowed_types.add(type_name)
161         
162         return self._allowed_types
163
164     @property
165     def homedir(self):
166         return self._homedir
167
168     @property
169     def logger(self):
170         return self._logger
171
172     @property
173     def is_running(self):
174         return self._started and self.ns3.Simulator.IsFinished()
175
176     def make_uuid(self):
177         return "uuid%s" % uuid.uuid4()
178
179     def get_object(self, uuid):
180         return self._objects.get(uuid)
181
182     def factory(self, type_name, **kwargs):
183         if type_name not in allowed_types:
184             msg = "Type %s not supported" % (type_name) 
185             self.logger.error(msg)
186  
187         factory = self.ns3.ObjectFactory()
188         factory.SetTypeId(type_name)
189
190         for name, value in kwargs.iteritems():
191             ns3_value = self._attr_from_string_to_ns3_value(type_name, name, value)
192             factory.Set(name, ns3_value)
193
194         obj = factory.Create()
195
196         uuid = self.make_uuid()
197         self._objects[uuid] = obj
198
199         return uuid
200
201     def create(self, clazzname, *args):
202         if not hasattr(self.ns3, clazzname):
203             msg = "Type %s not supported" % (clazzname) 
204             self.logger.error(msg)
205      
206         clazz = getattr(self.ns3, clazzname)
207  
208         # arguments starting with 'uuid' identify ns-3 C++
209         # objects and must be replaced by the actual object
210         realargs = self.replace_args(args)
211        
212         obj = clazz(*realargs)
213         
214         uuid = self.make_uuid()
215         self._objects[uuid] = obj
216
217         return uuid
218
219     def invoke(self, uuid, operation, *args):
220         if uuid.startswith(SINGLETON):
221             obj = self._singleton(uuid)
222         else:
223             obj = self.get_object(uuid)
224     
225         method = getattr(obj, operation)
226
227         # arguments starting with 'uuid' identify ns-3 C++
228         # objects and must be replaced by the actual object
229         realargs = self.replace_args(args)
230
231         result = method(*realargs)
232
233         if not result:
234             return None
235         
236         newuuid = self.make_uuid()
237         self._objects[newuuid] = result
238
239         return newuuid
240
241     def _set_attr(self, obj, name, ns3_value):
242         obj.SetAttribute(name, ns3_value)
243
244     def set(self, uuid, name, value):
245         obj = self.get_object(uuid)
246         type_name = obj.GetInstanceTypeId().GetName()
247         ns3_value = self._attr_from_string_to_ns3_value(type_name, name, value)
248
249         # If the Simulation thread is not running,
250         # then there will be no thread-safety problems
251         # in changing the value of an attribute directly.
252         # However, if the simulation is running we need
253         # to set the value by scheduling an event, else
254         # we risk to corrupt the state of the
255         # simulation.
256         if self.is_running:
257             # schedule the event in the Simulator
258             self._schedule_event(self._condition, self._set_attr, 
259                     obj, name, ns3_value)
260         else:
261             self._set_attr(obj, name, ns3_value)
262
263         return value
264
265     def _get_attr(self, obj, name, ns3_value):
266         obj.GetAttribute(name, ns3_value)
267
268     def get(self, uuid, name):
269         obj = self.get_object(uuid)
270         type_name = obj.GetInstanceTypeId().GetName()
271         ns3_value = self._create_attr_ns3_value(type_name, name)
272
273         if self.is_running:
274             # schedule the event in the Simulator
275             self._schedule_event(self._condition, self._get_attr, obj,
276                     name, ns3_value)
277         else:
278             get_attr(obj, name, ns3_value)
279
280         return self._attr_from_ns3_value_to_string(type_name, name, ns3_value)
281
282     def start(self):
283         # Launch the simulator thread and Start the
284         # simulator in that thread
285         self._condition = threading.Condition()
286         self._simulator_thread = threading.Thread(
287                 target = self._simulator_run,
288                 args = [self._condition])
289         self._simulator_thread.setDaemon(True)
290         self._simulator_thread.start()
291         self._started = True
292
293     def stop(self, time = None):
294         if time is None:
295             self.ns3.Simulator.Stop()
296         else:
297             self.ns3.Simulator.Stop(self.ns3.Time(time))
298
299     def shutdown(self):
300         while not self.ns3.Simulator.IsFinished():
301             #self.logger.debug("Waiting for simulation to finish")
302             time.sleep(0.5)
303         
304         # TODO!!!! SHOULD WAIT UNTIL THE THREAD FINISHES
305         if self._simulator_thread:
306             self._simulator_thread.join()
307         
308         self.ns3.Simulator.Destroy()
309         
310         # Remove all references to ns-3 objects
311         self._objects.clear()
312         
313         sys.stdout.flush()
314         sys.stderr.flush()
315
316     def _simulator_run(self, condition):
317         # Run simulation
318         self.ns3.Simulator.Run()
319         # Signal condition to indicate simulation ended and
320         # notify waiting threads
321         condition.acquire()
322         condition.notifyAll()
323         condition.release()
324
325     def _schedule_event(self, condition, func, *args):
326         """ Schedules event on running simulation, and wait until
327             event is executed"""
328
329         def execute_event(contextId, condition, has_event_occurred, func, *args):
330             try:
331                 func(*args)
332             finally:
333                 # flag event occured
334                 has_event_occurred[0] = True
335                 # notify condition indicating attribute was set
336                 condition.acquire()
337                 condition.notifyAll()
338                 condition.release()
339
340         # contextId is defined as general context
341         contextId = long(0xffffffff)
342
343         # delay 0 means that the event is expected to execute inmediately
344         delay = self.ns3.Seconds(0)
345
346         # flag to indicate that the event occured
347         # because bool is an inmutable object in python, in order to create a
348         # bool flag, a list is used as wrapper
349         has_event_occurred = [False]
350         condition.acquire()
351
352         simu = self.ns3.Simulator
353
354         try:
355             if not simu.IsFinished():
356                 simu.ScheduleWithContext(contextId, delay, execute_event,
357                      condition, has_event_occurred, func, *args)
358                 while not has_event_occurred[0] and not simu.IsFinished():
359                     condition.wait()
360         finally:
361             condition.release()
362
363     def _create_attr_ns3_value(self, type_name, name):
364         TypeId = self.ns3.TypeId()
365         tid = TypeId.LookupByName(type_name)
366         info = TypeId.AttributeInformation()
367         if not tid.LookupAttributeByName(name, info):
368             msg = "TypeId %s has no attribute %s" % (type_name, name) 
369             self.logger.error(msg)
370
371         checker = info.checker
372         ns3_value = checker.Create() 
373         return ns3_value
374
375     def _attr_from_ns3_value_to_string(self, type_name, name, ns3_value):
376         TypeId = self.ns3.TypeId()
377         tid = TypeId.LookupByName(type_name)
378         info = TypeId.AttributeInformation()
379         if not tid.LookupAttributeByName(name, info):
380             msg = "TypeId %s has no attribute %s" % (type_name, name) 
381             self.logger.error(msg)
382
383         checker = info.checker
384         value = ns3_value.SerializeToString(checker)
385
386         type_name = checker.GetValueTypeName()
387         if type_name in ["ns3::UintegerValue", "ns3::IntegerValue"]:
388             return int(value)
389         if type_name == "ns3::DoubleValue":
390             return float(value)
391         if type_name == "ns3::BooleanValue":
392             return value == "true"
393
394         return value
395
396     def _attr_from_string_to_ns3_value(self, type_name, name, value):
397         TypeId = self.ns3.TypeId()
398         tid = TypeId.LookupByName(type_name)
399         info = TypeId.AttributeInformation()
400         if not tid.LookupAttributeByName(name, info):
401             msg = "TypeId %s has no attribute %s" % (type_name, name) 
402             self.logger.error(msg)
403
404         str_value = str(value)
405         if isinstance(value, bool):
406             str_value = str_value.lower()
407
408         checker = info.checker
409         ns3_value = checker.Create()
410         ns3_value.DeserializeFromString(str_value, checker)
411         return ns3_value
412
413     # singletons are identified as "ns3::ClassName"
414     def _singleton(self, ident):
415         if not ident.startswith(SINGLETON):
416             return None
417
418         clazzname = ident[ident.find("::")+2:]
419         if not hasattr(self.ns3, clazzname):
420             msg = "Type %s not supported" % (clazzname)
421             self.logger.error(msg)
422
423         return getattr(self.ns3, clazzname)
424
425     # replace uuids and singleton references for the real objects
426     def replace_args(self, args):
427         realargs = [self.get_object(arg) if \
428                 str(arg).startswith("uuid") else arg for arg in args]
429  
430         realargs = [self._singleton(arg) if \
431                 str(arg).startswith(SINGLETON) else arg for arg in realargs]
432
433         return realargs
434