9251006668f112b0121bb4c81d46fb50dbfa50c9
[nepi.git] / src / nepi / resources / ns3 / ns3wrapper.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License version 2 as
7 #    published by the Free Software Foundation;
8 #
9 #    This program is distributed in the hope that it will be useful,
10 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
11 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 #    GNU General Public License for more details.
13 #
14 #    You should have received a copy of the GNU General Public License
15 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 #
17 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
18
19 import logging
20 import os
21 import sys
22 import threading
23 import time
24 import uuid
25
26 SINGLETON = "singleton::"
27 SIMULATOR_UUID = "singleton::Simulator"
28 CONFIG_UUID = "singleton::Config"
29 GLOBAL_VALUE_UUID = "singleton::GlobalValue"
30 IPV4_GLOBAL_ROUTING_HELPER_UUID = "singleton::Ipv4GlobalRoutingHelper"
31
32 def load_ns3_libraries():
33     import ctypes
34     import re
35
36     libdir = os.environ.get("NS3LIBRARIES")
37
38     # Load the ns-3 modules shared libraries
39     if libdir:
40         files = os.listdir(libdir)
41         regex = re.compile("(.*\.so)$")
42         libs = [m.group(1) for filename in files for m in [regex.search(filename)] if m]
43
44         initial_size = len(libs)
45         # Try to load the libraries in the right order by trial and error.
46         # Loop until all libraries are loaded.
47         while len(libs) > 0:
48             for lib in libs:
49                 libfile = os.path.join(libdir, lib)
50                 try:
51                     ctypes.CDLL(libfile, ctypes.RTLD_GLOBAL)
52                     libs.remove(lib)
53                 except:
54                     #import traceback
55                     #err = traceback.format_exc()
56                     #print err
57                     pass
58
59             # if did not load any libraries in the last iteration break
60             # to prevent infinit loop
61             if initial_size == len(libs):
62                 raise RuntimeError("Imposible to load shared libraries %s" % str(libs))
63             initial_size = len(libs)
64
65 def load_ns3_module():
66     load_ns3_libraries()
67
68     # import the python bindings for the ns-3 modules
69     bindings = os.environ.get("NS3BINDINGS")
70     if bindings:
71         sys.path.append(bindings)
72
73     import pkgutil
74     import imp
75     import ns
76
77     # create a Python module to add all ns3 classes
78     ns3mod = imp.new_module("ns3")
79     sys.modules["ns3"] = ns3mod
80
81     for importer, modname, ispkg in pkgutil.iter_modules(ns.__path__):
82         if modname in [ "visualizer" ]:
83             continue
84
85         fullmodname = "ns.%s" % modname
86         module = __import__(fullmodname, globals(), locals(), ['*'])
87
88         for sattr in dir(module):
89             if sattr.startswith("_"):
90                 continue
91
92             attr = getattr(module, sattr)
93
94             # netanim.Config and lte.Config singleton overrides ns3::Config
95             if sattr == "Config" and modname in ['netanim', 'lte']:
96                 sattr = "%s.%s" % (modname, sattr)
97
98             setattr(ns3mod, sattr, attr)
99
100     return ns3mod
101
102 class NS3Wrapper(object):
103     def __init__(self, loglevel = logging.INFO, enable_dump = False):
104         super(NS3Wrapper, self).__init__()
105         # Thread used to run the simulation
106         self._simulation_thread = None
107         self._condition = None
108
109         # True if Simulator::Run was invoked
110         self._started = False
111
112         # holds reference to all C++ objects and variables in the simulation
113         self._objects = dict()
114
115         # Logging
116         self._logger = logging.getLogger("ns3wrapper")
117         self._logger.setLevel(loglevel)
118
119         ## NOTE that the reason to create a handler to the ns3 module,
120         # that is re-loaded each time a ns-3 wrapper is instantiated,
121         # is that else each unit test for the ns3wrapper class would need
122         # a separate file. Several ns3wrappers would be created in the 
123         # same unit test (single process), leading to inchorences in the 
124         # state of ns-3 global objects
125         #
126         # Handler to ns3 classes
127         self._ns3 = None
128
129         # Collection of allowed ns3 classes
130         self._allowed_types = None
131
132         # Object to dump instructions to reproduce and debug experiment
133         from ns3wrapper_debug import NS3WrapperDebuger
134         self._debuger = NS3WrapperDebuger(enabled = enable_dump)
135
136     @property
137     def debuger(self):
138         return self._debuger
139
140     @property
141     def ns3(self):
142         if not self._ns3:
143             # load ns-3 libraries and bindings
144             self._ns3 = load_ns3_module()
145
146         return self._ns3
147
148     @property
149     def allowed_types(self):
150         if not self._allowed_types:
151             self._allowed_types = set()
152             type_id = self.ns3.TypeId()
153             
154             tid_count = type_id.GetRegisteredN()
155             base = type_id.LookupByName("ns3::Object")
156
157             for i in range(tid_count):
158                 tid = type_id.GetRegistered(i)
159                 
160                 if tid.MustHideFromDocumentation() or \
161                         not tid.HasConstructor() or \
162                         not tid.IsChildOf(base): 
163                     continue
164
165                 type_name = tid.GetName()
166                 self._allowed_types.add(type_name)
167         
168         return self._allowed_types
169
170     @property
171     def logger(self):
172         return self._logger
173
174     @property
175     def is_running(self):
176         return self.is_started and not self.ns3.Simulator.IsFinished()
177
178     @property
179     def is_started(self):
180         if not self._started:
181             now = self.ns3.Simulator.Now()
182             if not now.IsZero():
183                 self._started = True
184
185         return self._started
186
187     @property
188     def is_finished(self):
189         return self.ns3.Simulator.IsFinished()
190
191     def make_uuid(self):
192         return "uuid%s" % uuid.uuid4()
193
194     def get_object(self, uuid):
195         return self._objects.get(uuid)
196
197     def factory(self, type_name, **kwargs):
198         """ This method should be used to construct ns-3 objects
199         that have a TypeId and related introspection information """
200
201         if type_name not in self.allowed_types:
202             msg = "Type %s not supported" % (type_name) 
203             self.logger.error(msg)
204
205         uuid = self.make_uuid()
206         
207         ### DEBUG
208         self.logger.debug("FACTORY %s( %s )" % (type_name, str(kwargs)))
209         
210         ### DUMP
211         self.debuger.dump_factory(uuid, type_name, kwargs)
212
213         factory = self.ns3.ObjectFactory()
214         factory.SetTypeId(type_name)
215
216         for name, value in kwargs.items():
217             ns3_value = self._attr_from_string_to_ns3_value(type_name, name, value)
218             factory.Set(name, ns3_value)
219
220         obj = factory.Create()
221
222         self._objects[uuid] = obj
223
224         ### DEBUG
225         self.logger.debug("RET FACTORY ( uuid %s ) %s = %s( %s )" % (
226             str(uuid), str(obj), type_name, str(kwargs)))
227  
228         return uuid
229
230     def create(self, clazzname, *args):
231         """ This method should be used to construct ns-3 objects that
232         do not have a TypeId (e.g. Values) """
233
234         if not hasattr(self.ns3, clazzname):
235             msg = "Type %s not supported" % (clazzname) 
236             self.logger.error(msg)
237
238         uuid = self.make_uuid()
239         
240         ### DEBUG
241         self.logger.debug("CREATE %s( %s )" % (clazzname, str(args)))
242     
243         ### DUMP
244         self.debuger.dump_create(uuid, clazzname, args)
245
246         clazz = getattr(self.ns3, clazzname)
247  
248         # arguments starting with 'uuid' identify ns-3 C++
249         # objects and must be replaced by the actual object
250         realargs = self.replace_args(args)
251        
252         obj = clazz(*realargs)
253         
254         self._objects[uuid] = obj
255
256         ### DEBUG
257         self.logger.debug("RET CREATE ( uuid %s ) %s = %s( %s )" % (str(uuid), 
258             str(obj), clazzname, str(args)))
259
260         return uuid
261
262     def invoke(self, uuid, operation, *args, **kwargs):
263         ### DEBUG
264         self.logger.debug("INVOKE %s -> %s( %s, %s ) " % (
265             uuid, operation, str(args), str(kwargs)))
266         ########
267
268         result = None
269         newuuid = None
270
271         if operation == "isRunning":
272             result = self.is_running
273
274         elif operation == "isStarted":
275             result = self.is_started
276
277         elif operation == "isFinished":
278             result = self.is_finished
279
280         elif operation == "isAppRunning":
281             result = self._is_app_running(uuid)
282
283         elif operation == "isAppStarted":
284             result = self._is_app_started(uuid)
285
286         elif operation == "recvFD":
287             ### passFD operation binds to a different random socket 
288             ### en every execution, so the socket name that could be
289             ### dumped to the debug script using dump_invoke is
290             ### not be valid accross debug executions.
291             result = self._recv_fd(uuid, *args, **kwargs)
292
293         elif operation == "addStaticRoute":
294             result = self._add_static_route(uuid, *args)
295             
296             ### DUMP - result is static, so will be dumped as plain text
297             self.debuger.dump_invoke(result, uuid, operation, args, kwargs)
298
299         elif operation == "retrieveObject":
300             result = self._retrieve_object(uuid, *args, **kwargs)
301        
302             ### DUMP - result is static, so will be dumped as plain text
303             self.debuger.dump_invoke(result, uuid, operation, args, kwargs)
304        
305         else:
306             newuuid = self.make_uuid()
307
308             ### DUMP - result is a uuid that encoded an dynamically generated 
309             ### object
310             self.debuger.dump_invoke(newuuid, uuid, operation, args, kwargs)
311
312             if uuid.startswith(SINGLETON):
313                 obj = self._singleton(uuid)
314             else:
315                 obj = self.get_object(uuid)
316             
317             method = getattr(obj, operation)
318
319             # arguments starting with 'uuid' identify ns-3 C++
320             # objects and must be replaced by the actual object
321             realargs = self.replace_args(args)
322             realkwargs = self.replace_kwargs(kwargs)
323
324             result = method(*realargs, **realkwargs)
325
326             # If the result is an object (not a base value),
327             # then keep track of the object a return the object
328             # reference (newuuid)
329             if not (result is None or type(result) in [
330                     bool, float, long, str, int]):
331                 self._objects[newuuid] = result
332                 result = newuuid
333
334         ### DEBUG
335         self.logger.debug("RET INVOKE %s%s = %s -> %s(%s, %s) " % (
336             "(uuid %s) " % str(newuuid) if newuuid else "", str(result), uuid, 
337             operation, str(args), str(kwargs)))
338         ########
339
340         return result
341
342     def _set_attr(self, obj, name, ns3_value):
343         obj.SetAttribute(name, ns3_value)
344
345     def set(self, uuid, name, value):
346         ### DEBUG
347         self.logger.debug("SET %s %s %s" % (uuid, name, str(value)))
348     
349         ### DUMP
350         self.debuger.dump_set(uuid, name, value)
351
352         obj = self.get_object(uuid)
353         type_name = obj.GetInstanceTypeId().GetName()
354         ns3_value = self._attr_from_string_to_ns3_value(type_name, name, value)
355
356         # If the Simulation thread is not running,
357         # then there will be no thread-safety problems
358         # in changing the value of an attribute directly.
359         # However, if the simulation is running we need
360         # to set the value by scheduling an event, else
361         # we risk to corrupt the state of the
362         # simulation.
363         
364         event_executed = [False]
365
366         if self.is_running:
367             # schedule the event in the Simulator
368             self._schedule_event(self._condition, event_executed, 
369                     self._set_attr, obj, name, ns3_value)
370
371         if not event_executed[0]:
372             self._set_attr(obj, name, ns3_value)
373
374         ### DEBUG
375         self.logger.debug("RET SET %s = %s -> set(%s, %s)" % (str(value), uuid, name, 
376             str(value)))
377
378         return value
379
380     def _get_attr(self, obj, name, ns3_value):
381         obj.GetAttribute(name, ns3_value)
382
383     def get(self, uuid, name):
384         ### DEBUG
385         self.logger.debug("GET %s %s" % (uuid, name))
386         
387         ### DUMP
388         self.debuger.dump_get(uuid, name)
389
390         obj = self.get_object(uuid)
391         type_name = obj.GetInstanceTypeId().GetName()
392         ns3_value = self._create_attr_ns3_value(type_name, name)
393
394         event_executed = [False]
395
396         if self.is_running:
397             # schedule the event in the Simulator
398             self._schedule_event(self._condition, event_executed,
399                     self._get_attr, obj, name, ns3_value)
400
401         if not event_executed[0]:
402             self._get_attr(obj, name, ns3_value)
403
404         result = self._attr_from_ns3_value_to_string(type_name, name, ns3_value)
405
406         ### DEBUG
407         self.logger.debug("RET GET %s = %s -> get(%s)" % (str(result), uuid, name))
408
409         return result
410
411     def start(self):
412         ### DUMP
413         self.debuger.dump_start()
414
415         # Launch the simulator thread and Start the
416         # simulator in that thread
417         self._condition = threading.Condition()
418         self._simulator_thread = threading.Thread(
419                 target = self._simulator_run,
420                 args = [self._condition])
421         self._simulator_thread.setDaemon(True)
422         self._simulator_thread.start()
423         
424         ### DEBUG
425         self.logger.debug("START")
426
427     def stop(self, time = None):
428         ### DUMP
429         self.debuger.dump_stop(time=time)
430         
431         if time is None:
432             self.ns3.Simulator.Stop()
433         else:
434             self.ns3.Simulator.Stop(self.ns3.Time(time))
435
436         ### DEBUG
437         self.logger.debug("STOP time=%s" % str(time))
438
439     def shutdown(self):
440         ### DUMP
441         self.debuger.dump_shutdown()
442
443         while not self.ns3.Simulator.IsFinished():
444             #self.logger.debug("Waiting for simulation to finish")
445             time.sleep(0.5)
446         
447         if self._simulator_thread:
448             self._simulator_thread.join()
449        
450         self.ns3.Simulator.Destroy()
451         
452         # Remove all references to ns-3 objects
453         self._objects.clear()
454         
455         sys.stdout.flush()
456         sys.stderr.flush()
457
458         ### DEBUG
459         self.logger.debug("SHUTDOWN")
460
461     def _simulator_run(self, condition):
462         # Run simulation
463         self.ns3.Simulator.Run()
464         # Signal condition to indicate simulation ended and
465         # notify waiting threads
466         condition.acquire()
467         condition.notifyAll()
468         condition.release()
469
470     def _schedule_event(self, condition, event_executed, func, *args):
471         """ Schedules event on running simulation, and wait until
472             event is executed"""
473
474         def execute_event(contextId, condition, event_executed, func, *args):
475             try:
476                 func(*args)
477                 event_executed[0] = True
478             finally:
479                 # notify condition indicating event was executed
480                 condition.acquire()
481                 condition.notifyAll()
482                 condition.release()
483
484         # contextId is defined as general context
485         contextId = long(0xffffffff)
486
487         # delay 0 means that the event is expected to execute inmediately
488         delay = self.ns3.Seconds(0)
489     
490         # Mark event as not executed
491         event_executed[0] = False
492
493         condition.acquire()
494         try:
495             self.ns3.Simulator.ScheduleWithContext(contextId, delay, execute_event, 
496                     condition, event_executed, func, *args)
497             if not self.ns3.Simulator.IsFinished():
498                 condition.wait()
499         finally:
500             condition.release()
501
502     def _create_attr_ns3_value(self, type_name, name):
503         TypeId = self.ns3.TypeId()
504         tid = TypeId.LookupByName(type_name)
505         info = TypeId.AttributeInformation()
506         if not tid.LookupAttributeByName(name, info):
507             msg = "TypeId %s has no attribute %s" % (type_name, name) 
508             self.logger.error(msg)
509
510         checker = info.checker
511         ns3_value = checker.Create() 
512         return ns3_value
513
514     def _attr_from_ns3_value_to_string(self, type_name, name, ns3_value):
515         TypeId = self.ns3.TypeId()
516         tid = TypeId.LookupByName(type_name)
517         info = TypeId.AttributeInformation()
518         if not tid.LookupAttributeByName(name, info):
519             msg = "TypeId %s has no attribute %s" % (type_name, name) 
520             self.logger.error(msg)
521
522         checker = info.checker
523         value = ns3_value.SerializeToString(checker)
524
525         type_name = checker.GetValueTypeName()
526         if type_name in ["ns3::UintegerValue", "ns3::IntegerValue"]:
527             return int(value)
528         if type_name == "ns3::DoubleValue":
529             return float(value)
530         if type_name == "ns3::BooleanValue":
531             return value == "true"
532
533         return value
534
535     def _attr_from_string_to_ns3_value(self, type_name, name, value):
536         TypeId = self.ns3.TypeId()
537         tid = TypeId.LookupByName(type_name)
538         info = TypeId.AttributeInformation()
539         if not tid.LookupAttributeByName(name, info):
540             msg = "TypeId %s has no attribute %s" % (type_name, name) 
541             self.logger.error(msg)
542
543         str_value = str(value)
544         if isinstance(value, bool):
545             str_value = str_value.lower()
546
547         checker = info.checker
548         ns3_value = checker.Create()
549         ns3_value.DeserializeFromString(str_value, checker)
550         return ns3_value
551
552     # singletons are identified as "ns3::ClassName"
553     def _singleton(self, ident):
554         if not ident.startswith(SINGLETON):
555             return None
556
557         clazzname = ident[ident.find("::")+2:]
558         if not hasattr(self.ns3, clazzname):
559             msg = "Type %s not supported" % (clazzname)
560             self.logger.error(msg)
561
562         return getattr(self.ns3, clazzname)
563
564     # replace uuids and singleton references for the real objects
565     def replace_args(self, args):
566         realargs = [self.get_object(arg) if \
567                 str(arg).startswith("uuid") else arg for arg in args]
568  
569         realargs = [self._singleton(arg) if \
570                 str(arg).startswith(SINGLETON) else arg for arg in realargs]
571
572         return realargs
573
574     # replace uuids and singleton references for the real objects
575     def replace_kwargs(self, kwargs):
576         realkwargs = dict([(k, self.get_object(v) \
577                 if str(v).startswith("uuid") else v) \
578                 for k,v in kwargs.items()])
579  
580         realkwargs = dict([(k, self._singleton(v) \
581                 if str(v).startswith(SINGLETON) else v )\
582                 for k, v in realkwargs.items()])
583
584         return realkwargs
585
586     def _is_app_running(self, uuid):
587         now = self.ns3.Simulator.Now()
588         if now.IsZero():
589             return False
590
591         if self.ns3.Simulator.IsFinished():
592             return False
593
594         app = self.get_object(uuid)
595         stop_time_value = self.ns3.TimeValue()
596         app.GetAttribute("StopTime", stop_time_value)
597         stop_time = stop_time_value.Get()
598
599         start_time_value = self.ns3.TimeValue()
600         app.GetAttribute("StartTime", start_time_value)
601         start_time = start_time_value.Get()
602         
603         if now.Compare(start_time) >= 0:
604             if stop_time.IsZero() or now.Compare(stop_time) < 0:
605                 return True
606
607         return False
608     
609     def _is_app_started(self, uuid):
610         return self._is_app_running(uuid) or self.is_finished
611
612     def _add_static_route(self, ipv4_uuid, network, prefix, nexthop):
613         ipv4 = self.get_object(ipv4_uuid)
614
615         list_routing = ipv4.GetRoutingProtocol()
616         (static_routing, priority) = list_routing.GetRoutingProtocol(0)
617
618         ifindex = self._find_ifindex(ipv4, nexthop)
619         if ifindex == -1:
620             return False
621         
622         nexthop = self.ns3.Ipv4Address(nexthop)
623
624         if network in ["0.0.0.0", "0", None]:
625             # Default route: 0.0.0.0/0
626             static_routing.SetDefaultRoute(nexthop, ifindex)
627         else:
628             mask = self.ns3.Ipv4Mask("/%s" % prefix) 
629             network = self.ns3.Ipv4Address(network)
630
631             if prefix == 32:
632                 # Host route: x.y.z.w/32
633                 static_routing.AddHostRouteTo(network, nexthop, ifindex)
634             else:
635                 # Network route: x.y.z.w/n
636                 static_routing.AddNetworkRouteTo(network, mask, nexthop, 
637                         ifindex) 
638         return True
639
640     def _find_ifindex(self, ipv4, nexthop):
641         ifindex = -1
642
643         nexthop = self.ns3.Ipv4Address(nexthop)
644
645         # For all the interfaces registered with the ipv4 object, find
646         # the one that matches the network of the nexthop
647         nifaces = ipv4.GetNInterfaces()
648         for ifidx in range(nifaces):
649             iface = ipv4.GetInterface(ifidx)
650             naddress = iface.GetNAddresses()
651             for addridx in range(naddress):
652                 ifaddr = iface.GetAddress(addridx)
653                 ifmask = ifaddr.GetMask()
654                 
655                 ifindex = ipv4.GetInterfaceForPrefix(nexthop, ifmask)
656
657                 if ifindex == ifidx:
658                     return ifindex
659         return ifindex
660
661     def _retrieve_object(self, uuid, typeid, search = False):
662         obj = self.get_object(uuid)
663
664         type_id = self.ns3.TypeId()
665         tid = type_id.LookupByName(typeid)
666         nobj = obj.GetObject(tid)
667
668         newuuid = None
669         if search:
670             # search object
671             for ouuid, oobj in self._objects.items():
672                 if nobj == oobj:
673                     newuuid = ouuid
674                     break
675         else: 
676             newuuid = self.make_uuid()
677             self._objects[newuuid] = nobj
678
679         return newuuid
680
681     def _recv_fd(self, uuid):
682         """ Waits on a local address to receive a file descriptor
683         from a local process. The file descriptor is associated
684         to a FdNetDevice to stablish communication between the
685         simulation and what ever process writes on that file descriptor
686         """
687
688         def recvfd(sock, fdnd):
689             (fd, msg) = passfd.recvfd(sock)
690             # Store a reference to the endpoint to keep the socket alive
691             fdnd.SetFileDescriptor(fd)
692         
693         import passfd
694         import socket
695         sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
696         sock.bind("")
697         address = sock.getsockname()
698         
699         fdnd = self.get_object(uuid)
700         t = threading.Thread(target=recvfd, args=(sock,fdnd))
701         t.start()
702
703         return address
704
705