Merging ns-3 into nepi-3-dev
[nepi.git] / src / nepi / resources / ns3 / ns3wrapper.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License as published by
7 #    the Free Software Foundation, either version 3 of the License, or
8 #    (at your option) any later version.
9 #
10 #    This program is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #    GNU General Public License for more details.
14 #
15 #    You should have received a copy of the GNU General Public License
16 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
19
20 import logging
21 import os
22 import sys
23 import threading
24 import time
25 import uuid
26
27 SINGLETON = "singleton::"
28 SIMULATOR_UUID = "singleton::Simulator"
29 CONFIG_UUID = "singleton::Config"
30 GLOBAL_VALUE_UUID = "singleton::GlobalValue"
31 IPV4_GLOBAL_ROUTING_HELPER_UUID = "singleton::Ipv4GlobalRoutingHelper"
32
33 def load_ns3_module():
34     import ctypes
35     import re
36
37     bindings = os.environ.get("NS3BINDINGS")
38     libdir = os.environ.get("NS3LIBRARIES")
39
40     # Load the ns-3 modules shared libraries
41     if libdir:
42         files = os.listdir(libdir)
43         regex = re.compile("(.*\.so)$")
44         libs = [m.group(1) for filename in files for m in [regex.search(filename)] if m]
45
46         initial_size = len(libs)
47         # Try to load the libraries in the right order by trial and error.
48         # Loop until all libraries are loaded.
49         while len(libs) > 0:
50             for lib in libs:
51                 libfile = os.path.join(libdir, lib)
52                 try:
53                     ctypes.CDLL(libfile, ctypes.RTLD_GLOBAL)
54                     libs.remove(lib)
55                 except:
56                     #import traceback
57                     #err = traceback.format_exc()
58                     #print err
59                     pass
60
61             # if did not load any libraries in the last iteration break
62             # to prevent infinit loop
63             if initial_size == len(libs):
64                 raise RuntimeError("Imposible to load shared libraries %s" % str(libs))
65             initial_size = list(libs)
66
67     # import the python bindings for the ns-3 modules
68     if bindings:
69         sys.path.append(bindings)
70
71     import pkgutil
72     import imp
73     import ns
74
75     # create a Python module to add all ns3 classes
76     ns3mod = imp.new_module("ns3")
77     sys.modules["ns3"] = ns3mod
78
79     for importer, modname, ispkg in pkgutil.iter_modules(ns.__path__):
80         if modname in [ "visualizer" ]:
81             continue
82
83         fullmodname = "ns.%s" % modname
84         module = __import__(fullmodname, globals(), locals(), ['*'])
85
86         for sattr in dir(module):
87             if sattr.startswith("_"):
88                 continue
89
90             attr = getattr(module, sattr)
91
92             # netanim.Config and lte.Config singleton overrides ns3::Config
93             if sattr == "Config" and modname in ['netanim', 'lte']:
94                 sattr = "%s.%s" % (modname, sattr)
95
96             setattr(ns3mod, sattr, attr)
97
98     return ns3mod
99
100 class NS3Wrapper(object):
101     def __init__(self, loglevel = logging.INFO):
102         super(NS3Wrapper, self).__init__()
103         # Thread used to run the simulation
104         self._simulation_thread = None
105         self._condition = None
106
107         # True if Simulator::Run was invoked
108         self._started = False
109
110         # holds reference to all C++ objects and variables in the simulation
111         self._objects = dict()
112
113         # Logging
114         self._logger = logging.getLogger("ns3wrapper")
115         self._logger.setLevel(loglevel)
116
117         ## NOTE that the reason to create a handler to the ns3 module,
118         # that is re-loaded each time a ns-3 wrapper is instantiated,
119         # is that else each unit test for the ns3wrapper class would need
120         # a separate file. Several ns3wrappers would be created in the 
121         # same unit test (single process), leading to inchorences in the 
122         # state of ns-3 global objects
123         #
124         # Handler to ns3 classes
125         self._ns3 = None
126
127         # Collection of allowed ns3 classes
128         self._allowed_types = None
129
130     @property
131     def ns3(self):
132         if not self._ns3:
133             # load ns-3 libraries and bindings
134             self._ns3 = load_ns3_module()
135
136         return self._ns3
137
138     @property
139     def allowed_types(self):
140         if not self._allowed_types:
141             self._allowed_types = set()
142             type_id = self.ns3.TypeId()
143             
144             tid_count = type_id.GetRegisteredN()
145             base = type_id.LookupByName("ns3::Object")
146
147             for i in xrange(tid_count):
148                 tid = type_id.GetRegistered(i)
149                 
150                 if tid.MustHideFromDocumentation() or \
151                         not tid.HasConstructor() or \
152                         not tid.IsChildOf(base): 
153                     continue
154
155                 type_name = tid.GetName()
156                 self._allowed_types.add(type_name)
157         
158         return self._allowed_types
159
160     @property
161     def logger(self):
162         return self._logger
163
164     @property
165     def is_running(self):
166         return self._started and self.ns3.Simulator.IsFinished()
167
168     def make_uuid(self):
169         return "uuid%s" % uuid.uuid4()
170
171     def get_object(self, uuid):
172         return self._objects.get(uuid)
173
174     def factory(self, type_name, **kwargs):
175         if type_name not in self.allowed_types:
176             msg = "Type %s not supported" % (type_name) 
177             self.logger.error(msg)
178  
179         factory = self.ns3.ObjectFactory()
180         factory.SetTypeId(type_name)
181
182         for name, value in kwargs.iteritems():
183             ns3_value = self._attr_from_string_to_ns3_value(type_name, name, value)
184             factory.Set(name, ns3_value)
185
186         obj = factory.Create()
187
188         uuid = self.make_uuid()
189         self._objects[uuid] = obj
190
191         return uuid
192
193     def create(self, clazzname, *args):
194         if not hasattr(self.ns3, clazzname):
195             msg = "Type %s not supported" % (clazzname) 
196             self.logger.error(msg)
197      
198         clazz = getattr(self.ns3, clazzname)
199  
200         # arguments starting with 'uuid' identify ns-3 C++
201         # objects and must be replaced by the actual object
202         realargs = self.replace_args(args)
203        
204         obj = clazz(*realargs)
205         
206         uuid = self.make_uuid()
207         self._objects[uuid] = obj
208
209         return uuid
210
211     def invoke(self, uuid, operation, *args, **kwargs):
212         if operation == "isAppRunning":
213             return self._is_app_running(uuid)
214         if operation == "addStaticRoute":
215             return self._add_static_route(uuid, *args)
216
217         if uuid.startswith(SINGLETON):
218             obj = self._singleton(uuid)
219         else:
220             obj = self.get_object(uuid)
221         
222         method = getattr(obj, operation)
223
224         # arguments starting with 'uuid' identify ns-3 C++
225         # objects and must be replaced by the actual object
226         realargs = self.replace_args(args)
227         realkwargs = self.replace_kwargs(kwargs)
228
229         result = method(*realargs, **realkwargs)
230
231         if result is None or \
232                 isinstance(result, bool):
233             return result
234       
235         newuuid = self.make_uuid()
236         self._objects[newuuid] = result
237
238         return newuuid
239
240     def _set_attr(self, obj, name, ns3_value):
241         obj.SetAttribute(name, ns3_value)
242
243     def set(self, uuid, name, value):
244         obj = self.get_object(uuid)
245         type_name = obj.GetInstanceTypeId().GetName()
246         ns3_value = self._attr_from_string_to_ns3_value(type_name, name, value)
247
248         # If the Simulation thread is not running,
249         # then there will be no thread-safety problems
250         # in changing the value of an attribute directly.
251         # However, if the simulation is running we need
252         # to set the value by scheduling an event, else
253         # we risk to corrupt the state of the
254         # simulation.
255         
256         event_executed = [False]
257
258         if self.is_running:
259             # schedule the event in the Simulator
260             self._schedule_event(self._condition, event_executed, 
261                     self._set_attr, obj, name, ns3_value)
262
263         if not event_executed[0]:
264             self._set_attr(obj, name, ns3_value)
265
266         return value
267
268     def _get_attr(self, obj, name, ns3_value):
269         obj.GetAttribute(name, ns3_value)
270
271     def get(self, uuid, name):
272         obj = self.get_object(uuid)
273         type_name = obj.GetInstanceTypeId().GetName()
274         ns3_value = self._create_attr_ns3_value(type_name, name)
275
276         event_executed = [False]
277
278         if self.is_running:
279             # schedule the event in the Simulator
280             self._schedule_event(self._condition, event_executed,
281                     self._get_attr, obj, name, ns3_value)
282
283         if not event_executed[0]:
284             self._get_attr(obj, name, ns3_value)
285
286         return self._attr_from_ns3_value_to_string(type_name, name, ns3_value)
287
288     def start(self):
289         # Launch the simulator thread and Start the
290         # simulator in that thread
291         self._condition = threading.Condition()
292         self._simulator_thread = threading.Thread(
293                 target = self._simulator_run,
294                 args = [self._condition])
295         self._simulator_thread.setDaemon(True)
296         self._simulator_thread.start()
297         self._started = True
298
299     def stop(self, time = None):
300         if time is None:
301             self.ns3.Simulator.Stop()
302         else:
303             self.ns3.Simulator.Stop(self.ns3.Time(time))
304
305     def shutdown(self):
306         while not self.ns3.Simulator.IsFinished():
307             #self.logger.debug("Waiting for simulation to finish")
308             time.sleep(0.5)
309         
310         if self._simulator_thread:
311             self._simulator_thread.join()
312        
313         self.ns3.Simulator.Destroy()
314         
315         # Remove all references to ns-3 objects
316         self._objects.clear()
317         
318         sys.stdout.flush()
319         sys.stderr.flush()
320
321     def _simulator_run(self, condition):
322         # Run simulation
323         self.ns3.Simulator.Run()
324         # Signal condition to indicate simulation ended and
325         # notify waiting threads
326         condition.acquire()
327         condition.notifyAll()
328         condition.release()
329
330     def _schedule_event(self, condition, event_executed, func, *args):
331         """ Schedules event on running simulation, and wait until
332             event is executed"""
333
334         def execute_event(contextId, condition, event_executed, func, *args):
335             try:
336                 func(*args)
337                 event_executed[0] = True
338             finally:
339                 # notify condition indicating event was executed
340                 condition.acquire()
341                 condition.notifyAll()
342                 condition.release()
343
344         # contextId is defined as general context
345         contextId = long(0xffffffff)
346
347         # delay 0 means that the event is expected to execute inmediately
348         delay = self.ns3.Seconds(0)
349     
350         # Mark event as not executed
351         event_executed[0] = False
352
353         condition.acquire()
354         try:
355             self.ns3.Simulator.ScheduleWithContext(contextId, delay, execute_event, 
356                     condition, event_executed, func, *args)
357             if not self.ns3.Simulator.IsFinished():
358                 condition.wait()
359         finally:
360             condition.release()
361
362     def _create_attr_ns3_value(self, type_name, name):
363         TypeId = self.ns3.TypeId()
364         tid = TypeId.LookupByName(type_name)
365         info = TypeId.AttributeInformation()
366         if not tid.LookupAttributeByName(name, info):
367             msg = "TypeId %s has no attribute %s" % (type_name, name) 
368             self.logger.error(msg)
369
370         checker = info.checker
371         ns3_value = checker.Create() 
372         return ns3_value
373
374     def _attr_from_ns3_value_to_string(self, type_name, name, ns3_value):
375         TypeId = self.ns3.TypeId()
376         tid = TypeId.LookupByName(type_name)
377         info = TypeId.AttributeInformation()
378         if not tid.LookupAttributeByName(name, info):
379             msg = "TypeId %s has no attribute %s" % (type_name, name) 
380             self.logger.error(msg)
381
382         checker = info.checker
383         value = ns3_value.SerializeToString(checker)
384
385         type_name = checker.GetValueTypeName()
386         if type_name in ["ns3::UintegerValue", "ns3::IntegerValue"]:
387             return int(value)
388         if type_name == "ns3::DoubleValue":
389             return float(value)
390         if type_name == "ns3::BooleanValue":
391             return value == "true"
392
393         return value
394
395     def _attr_from_string_to_ns3_value(self, type_name, name, value):
396         TypeId = self.ns3.TypeId()
397         tid = TypeId.LookupByName(type_name)
398         info = TypeId.AttributeInformation()
399         if not tid.LookupAttributeByName(name, info):
400             msg = "TypeId %s has no attribute %s" % (type_name, name) 
401             self.logger.error(msg)
402
403         str_value = str(value)
404         if isinstance(value, bool):
405             str_value = str_value.lower()
406
407         checker = info.checker
408         ns3_value = checker.Create()
409         ns3_value.DeserializeFromString(str_value, checker)
410         return ns3_value
411
412     # singletons are identified as "ns3::ClassName"
413     def _singleton(self, ident):
414         if not ident.startswith(SINGLETON):
415             return None
416
417         clazzname = ident[ident.find("::")+2:]
418         if not hasattr(self.ns3, clazzname):
419             msg = "Type %s not supported" % (clazzname)
420             self.logger.error(msg)
421
422         return getattr(self.ns3, clazzname)
423
424     # replace uuids and singleton references for the real objects
425     def replace_args(self, args):
426         realargs = [self.get_object(arg) if \
427                 str(arg).startswith("uuid") else arg for arg in args]
428  
429         realargs = [self._singleton(arg) if \
430                 str(arg).startswith(SINGLETON) else arg for arg in realargs]
431
432         return realargs
433
434     # replace uuids and singleton references for the real objects
435     def replace_kwargs(self, kwargs):
436         realkwargs = dict([(k, self.get_object(v) \
437                 if str(v).startswith("uuid") else v) \
438                 for k,v in kwargs.iteritems()])
439  
440         realkwargs = dict([(k, self._singleton(v) \
441                 if str(v).startswith(SINGLETON) else v )\
442                 for k, v in realkwargs.iteritems()])
443
444         return realkwargs
445
446     def _is_app_running(self, uuid): 
447         now = self.ns3.Simulator.Now()
448         if now.IsZero():
449             return False
450
451         app = self.get_object(uuid)
452         stop_time_value = self.ns3.TimeValue()
453         app.GetAttribute("StopTime", stop_time_value)
454         stop_time = stop_time_value.Get()
455
456         start_time_value = self.ns3.TimeValue()
457         app.GetAttribute("StartTime", start_time_value)
458         start_time = start_time_value.Get()
459         
460         if now.Compare(start_time) >= 0 and now.Compare(stop_time) < 0:
461             return True
462
463         return False
464
465     def _add_static_route(self, ipv4_uuid, network, prefix, nexthop):
466         ipv4 = self.get_object(ipv4_uuid)
467
468         list_routing = ipv4.GetRoutingProtocol()
469         (static_routing, priority) = list_routing.GetRoutingProtocol(0)
470
471         ifindex = self._find_ifindex(ipv4, nexthop)
472         if ifindex == -1:
473             return False
474         
475         nexthop = self.ns3.Ipv4Address(nexthop)
476
477         if network in ["0.0.0.0", "0", None]:
478             # Default route: 0.0.0.0/0
479             static_routing.SetDefaultRoute(nexthop, ifindex)
480         else:
481             mask = self.ns3.Ipv4Mask("/%s" % prefix) 
482             network = self.ns3.Ipv4Address(network)
483
484             if prefix == 32:
485                 # Host route: x.y.z.w/32
486                 static_routing.AddHostRouteTo(network, nexthop, ifindex)
487             else:
488                 # Network route: x.y.z.w/n
489                 static_routing.AddNetworkRouteTo(network, mask, nexthop, 
490                         ifindex) 
491         return True
492
493     def _find_ifindex(self, ipv4, nexthop):
494         ifindex = -1
495
496         nexthop = self.ns3.Ipv4Address(nexthop)
497
498         # For all the interfaces registered with the ipv4 object, find
499         # the one that matches the network of the nexthop
500         nifaces = ipv4.GetNInterfaces()
501         for ifidx in xrange(nifaces):
502             iface = ipv4.GetInterface(ifidx)
503             naddress = iface.GetNAddresses()
504             for addridx in xrange(naddress):
505                 ifaddr = iface.GetAddress(addridx)
506                 ifmask = ifaddr.GetMask()
507                 
508                 ifindex = ipv4.GetInterfaceForPrefix(nexthop, ifmask)
509
510                 if ifindex == ifidx:
511                     return ifindex
512         return ifindex
513