README moves to markdown
[nepi.git] / nepi / resources / ns3 / ns3wrapper.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License version 2 as
7 #    published by the Free Software Foundation;
8 #
9 #    This program is distributed in the hope that it will be useful,
10 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
11 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 #    GNU General Public License for more details.
13 #
14 #    You should have received a copy of the GNU General Public License
15 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 #
17 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
18
19 import logging
20 import os
21 import sys
22 import threading
23 import time
24 import uuid
25
26 from six import integer_types, string_types
27
28 SINGLETON = "singleton::"
29 SIMULATOR_UUID = "singleton::Simulator"
30 CONFIG_UUID = "singleton::Config"
31 GLOBAL_VALUE_UUID = "singleton::GlobalValue"
32 IPV4_GLOBAL_ROUTING_HELPER_UUID = "singleton::Ipv4GlobalRoutingHelper"
33
34 def load_ns3_libraries():
35     import ctypes
36     import re
37
38     libdir = os.environ.get("NS3LIBRARIES")
39
40     # Load the ns-3 modules shared libraries
41     if libdir:
42         files = os.listdir(libdir)
43         regex = re.compile("(.*\.so)$")
44         libs = [m.group(1) for filename in files for m in [regex.search(filename)] if m]
45
46         initial_size = len(libs)
47         # Try to load the libraries in the right order by trial and error.
48         # Loop until all libraries are loaded.
49         while len(libs) > 0:
50             for lib in libs:
51                 libfile = os.path.join(libdir, lib)
52                 try:
53                     ctypes.CDLL(libfile, ctypes.RTLD_GLOBAL)
54                     libs.remove(lib)
55                 except:
56                     #import traceback
57                     #err = traceback.format_exc()
58                     #print err
59                     pass
60
61             # if did not load any libraries in the last iteration break
62             # to prevent infinit loop
63             if initial_size == len(libs):
64                 raise RuntimeError("Imposible to load shared libraries %s" % str(libs))
65             initial_size = len(libs)
66
67 def load_ns3_module():
68     load_ns3_libraries()
69
70     # import the python bindings for the ns-3 modules
71     bindings = os.environ.get("NS3BINDINGS")
72     if bindings:
73         sys.path.append(bindings)
74
75     import pkgutil
76     import imp
77     import ns
78
79     # create a Python module to add all ns3 classes
80     ns3mod = imp.new_module("ns3")
81     sys.modules["ns3"] = ns3mod
82
83     for importer, modname, ispkg in pkgutil.iter_modules(ns.__path__):
84         if modname in [ "visualizer" ]:
85             continue
86
87         fullmodname = "ns.%s" % modname
88         module = __import__(fullmodname, globals(), locals(), ['*'])
89
90         for sattr in dir(module):
91             if sattr.startswith("_"):
92                 continue
93
94             attr = getattr(module, sattr)
95
96             # netanim.Config and lte.Config singleton overrides ns3::Config
97             if sattr == "Config" and modname in ['netanim', 'lte']:
98                 sattr = "%s.%s" % (modname, sattr)
99
100             setattr(ns3mod, sattr, attr)
101
102     return ns3mod
103
104 class NS3Wrapper(object):
105     def __init__(self, loglevel = logging.INFO, enable_dump = False):
106         super(NS3Wrapper, self).__init__()
107         # Thread used to run the simulation
108         self._simulation_thread = None
109         self._condition = None
110
111         # True if Simulator::Run was invoked
112         self._started = False
113
114         # holds reference to all C++ objects and variables in the simulation
115         self._objects = dict()
116
117         # Logging
118         self._logger = logging.getLogger("ns3wrapper")
119         self._logger.setLevel(loglevel)
120
121         ## NOTE that the reason to create a handler to the ns3 module,
122         # that is re-loaded each time a ns-3 wrapper is instantiated,
123         # is that else each unit test for the ns3wrapper class would need
124         # a separate file. Several ns3wrappers would be created in the 
125         # same unit test (single process), leading to inchorences in the 
126         # state of ns-3 global objects
127         #
128         # Handler to ns3 classes
129         self._ns3 = None
130
131         # Collection of allowed ns3 classes
132         self._allowed_types = None
133
134         # Object to dump instructions to reproduce and debug experiment
135         from nepi.resources.ns3.ns3wrapper_debug import NS3WrapperDebuger
136         self._debuger = NS3WrapperDebuger(enabled = enable_dump)
137
138     @property
139     def debuger(self):
140         return self._debuger
141
142     @property
143     def ns3(self):
144         if not self._ns3:
145             # load ns-3 libraries and bindings
146             self._ns3 = load_ns3_module()
147
148         return self._ns3
149
150     @property
151     def allowed_types(self):
152         if not self._allowed_types:
153             self._allowed_types = set()
154             type_id = self.ns3.TypeId()
155             
156             tid_count = type_id.GetRegisteredN()
157             base = type_id.LookupByName("ns3::Object")
158
159             for i in range(tid_count):
160                 tid = type_id.GetRegistered(i)
161                 
162                 if tid.MustHideFromDocumentation() or \
163                         not tid.HasConstructor() or \
164                         not tid.IsChildOf(base): 
165                     continue
166
167                 type_name = tid.GetName()
168                 self._allowed_types.add(type_name)
169         
170         return self._allowed_types
171
172     @property
173     def logger(self):
174         return self._logger
175
176     @property
177     def is_running(self):
178         return self.is_started and not self.ns3.Simulator.IsFinished()
179
180     @property
181     def is_started(self):
182         if not self._started:
183             now = self.ns3.Simulator.Now()
184             if not now.IsZero():
185                 self._started = True
186
187         return self._started
188
189     @property
190     def is_finished(self):
191         return self.ns3.Simulator.IsFinished()
192
193     def make_uuid(self):
194         return "uuid%s" % uuid.uuid4()
195
196     def get_object(self, uuid):
197         return self._objects.get(uuid)
198
199     def factory(self, type_name, **kwargs):
200         """ This method should be used to construct ns-3 objects
201         that have a TypeId and related introspection information """
202
203         if type_name not in self.allowed_types:
204             msg = "Type %s not supported" % (type_name) 
205             self.logger.error(msg)
206
207         uuid = self.make_uuid()
208         
209         ### DEBUG
210         self.logger.debug("FACTORY %s( %s )" % (type_name, str(kwargs)))
211         
212         ### DUMP
213         self.debuger.dump_factory(uuid, type_name, kwargs)
214
215         factory = self.ns3.ObjectFactory()
216         factory.SetTypeId(type_name)
217
218         for name, value in kwargs.items():
219             ns3_value = self._attr_from_string_to_ns3_value(type_name, name, value)
220             factory.Set(name, ns3_value)
221
222         obj = factory.Create()
223
224         self._objects[uuid] = obj
225
226         ### DEBUG
227         self.logger.debug("RET FACTORY ( uuid %s ) %s = %s( %s )" % (
228             str(uuid), str(obj), type_name, str(kwargs)))
229  
230         return uuid
231
232     def create(self, clazzname, *args):
233         """ This method should be used to construct ns-3 objects that
234         do not have a TypeId (e.g. Values) """
235
236         if not hasattr(self.ns3, clazzname):
237             msg = "Type %s not supported" % (clazzname) 
238             self.logger.error(msg)
239
240         uuid = self.make_uuid()
241         
242         ### DEBUG
243         self.logger.debug("CREATE %s( %s )" % (clazzname, str(args)))
244     
245         ### DUMP
246         self.debuger.dump_create(uuid, clazzname, args)
247
248         clazz = getattr(self.ns3, clazzname)
249  
250         # arguments starting with 'uuid' identify ns-3 C++
251         # objects and must be replaced by the actual object
252         realargs = self.replace_args(args)
253        
254         obj = clazz(*realargs)
255         
256         self._objects[uuid] = obj
257
258         ### DEBUG
259         self.logger.debug("RET CREATE ( uuid %s ) %s = %s( %s )" % (str(uuid), 
260             str(obj), clazzname, str(args)))
261
262         return uuid
263
264     def invoke(self, uuid, operation, *args, **kwargs):
265         ### DEBUG
266         self.logger.debug("INVOKE %s -> %s( %s, %s ) " % (
267             uuid, operation, str(args), str(kwargs)))
268         ########
269
270         result = None
271         newuuid = None
272
273         if operation == "isRunning":
274             result = self.is_running
275
276         elif operation == "isStarted":
277             result = self.is_started
278
279         elif operation == "isFinished":
280             result = self.is_finished
281
282         elif operation == "isAppRunning":
283             result = self._is_app_running(uuid)
284
285         elif operation == "isAppStarted":
286             result = self._is_app_started(uuid)
287
288         elif operation == "recvFD":
289             ### passFD operation binds to a different random socket 
290             ### en every execution, so the socket name that could be
291             ### dumped to the debug script using dump_invoke is
292             ### not be valid accross debug executions.
293             result = self._recv_fd(uuid, *args, **kwargs)
294
295         elif operation == "addStaticRoute":
296             result = self._add_static_route(uuid, *args)
297             
298             ### DUMP - result is static, so will be dumped as plain text
299             self.debuger.dump_invoke(result, uuid, operation, args, kwargs)
300
301         elif operation == "retrieveObject":
302             result = self._retrieve_object(uuid, *args, **kwargs)
303        
304             ### DUMP - result is static, so will be dumped as plain text
305             self.debuger.dump_invoke(result, uuid, operation, args, kwargs)
306        
307         else:
308             newuuid = self.make_uuid()
309
310             ### DUMP - result is a uuid that encoded an dynamically generated 
311             ### object
312             self.debuger.dump_invoke(newuuid, uuid, operation, args, kwargs)
313
314             if uuid.startswith(SINGLETON):
315                 obj = self._singleton(uuid)
316             else:
317                 obj = self.get_object(uuid)
318             
319             method = getattr(obj, operation)
320
321             # arguments starting with 'uuid' identify ns-3 C++
322             # objects and must be replaced by the actual object
323             realargs = self.replace_args(args)
324             realkwargs = self.replace_kwargs(kwargs)
325
326             result = method(*realargs, **realkwargs)
327
328             # If the result is an object (not a base value),
329             # then keep track of the object a return the object
330             # reference (newuuid)
331             if result is not None \
332               and not isinstance(result, (bool, float) + integer_types + string_types):
333                 self._objects[newuuid] = result
334                 result = newuuid
335
336         ### DEBUG
337         self.logger.debug("RET INVOKE %s%s = %s -> %s(%s, %s) " % (
338             "(uuid %s) " % str(newuuid) if newuuid else "", str(result), uuid, 
339             operation, str(args), str(kwargs)))
340         ########
341
342         return result
343
344     def _set_attr(self, obj, name, ns3_value):
345         obj.SetAttribute(name, ns3_value)
346
347     def set(self, uuid, name, value):
348         ### DEBUG
349         self.logger.debug("SET %s %s %s" % (uuid, name, str(value)))
350     
351         ### DUMP
352         self.debuger.dump_set(uuid, name, value)
353
354         obj = self.get_object(uuid)
355         type_name = obj.GetInstanceTypeId().GetName()
356         ns3_value = self._attr_from_string_to_ns3_value(type_name, name, value)
357
358         # If the Simulation thread is not running,
359         # then there will be no thread-safety problems
360         # in changing the value of an attribute directly.
361         # However, if the simulation is running we need
362         # to set the value by scheduling an event, else
363         # we risk to corrupt the state of the
364         # simulation.
365         
366         event_executed = [False]
367
368         if self.is_running:
369             # schedule the event in the Simulator
370             self._schedule_event(self._condition, event_executed, 
371                     self._set_attr, obj, name, ns3_value)
372
373         if not event_executed[0]:
374             self._set_attr(obj, name, ns3_value)
375
376         ### DEBUG
377         self.logger.debug("RET SET %s = %s -> set(%s, %s)" % (str(value), uuid, name, 
378             str(value)))
379
380         return value
381
382     def _get_attr(self, obj, name, ns3_value):
383         obj.GetAttribute(name, ns3_value)
384
385     def get(self, uuid, name):
386         ### DEBUG
387         self.logger.debug("GET %s %s" % (uuid, name))
388         
389         ### DUMP
390         self.debuger.dump_get(uuid, name)
391
392         obj = self.get_object(uuid)
393         type_name = obj.GetInstanceTypeId().GetName()
394         ns3_value = self._create_attr_ns3_value(type_name, name)
395
396         event_executed = [False]
397
398         if self.is_running:
399             # schedule the event in the Simulator
400             self._schedule_event(self._condition, event_executed,
401                     self._get_attr, obj, name, ns3_value)
402
403         if not event_executed[0]:
404             self._get_attr(obj, name, ns3_value)
405
406         result = self._attr_from_ns3_value_to_string(type_name, name, ns3_value)
407
408         ### DEBUG
409         self.logger.debug("RET GET %s = %s -> get(%s)" % (str(result), uuid, name))
410
411         return result
412
413     def start(self):
414         ### DUMP
415         self.debuger.dump_start()
416
417         # Launch the simulator thread and Start the
418         # simulator in that thread
419         self._condition = threading.Condition()
420         self._simulator_thread = threading.Thread(
421                 target = self._simulator_run,
422                 args = [self._condition])
423         self._simulator_thread.setDaemon(True)
424         self._simulator_thread.start()
425         
426         ### DEBUG
427         self.logger.debug("START")
428
429     def stop(self, time = None):
430         ### DUMP
431         self.debuger.dump_stop(time=time)
432         
433         if time is None:
434             self.ns3.Simulator.Stop()
435         else:
436             self.ns3.Simulator.Stop(self.ns3.Time(time))
437
438         ### DEBUG
439         self.logger.debug("STOP time=%s" % str(time))
440
441     def shutdown(self):
442         ### DUMP
443         self.debuger.dump_shutdown()
444
445         while not self.ns3.Simulator.IsFinished():
446             #self.logger.debug("Waiting for simulation to finish")
447             time.sleep(0.5)
448         
449         if self._simulator_thread:
450             self._simulator_thread.join()
451        
452         self.ns3.Simulator.Destroy()
453         
454         # Remove all references to ns-3 objects
455         self._objects.clear()
456         
457         sys.stdout.flush()
458         sys.stderr.flush()
459
460         ### DEBUG
461         self.logger.debug("SHUTDOWN")
462
463     def _simulator_run(self, condition):
464         # Run simulation
465         self.ns3.Simulator.Run()
466         # Signal condition to indicate simulation ended and
467         # notify waiting threads
468         condition.acquire()
469         condition.notifyAll()
470         condition.release()
471
472     def _schedule_event(self, condition, event_executed, func, *args):
473         """ Schedules event on running simulation, and wait until
474             event is executed"""
475
476         def execute_event(contextId, condition, event_executed, func, *args):
477             try:
478                 func(*args)
479                 event_executed[0] = True
480             finally:
481                 # notify condition indicating event was executed
482                 condition.acquire()
483                 condition.notifyAll()
484                 condition.release()
485
486         # contextId is defined as general context
487         # xxx possible distortion when upgrading to python3
488         # really not sure what's the point in this long() business..
489         #contextId = long(0xffffffff)
490         contextId = 0xffffffff
491
492         # delay 0 means that the event is expected to execute inmediately
493         delay = self.ns3.Seconds(0)
494     
495         # Mark event as not executed
496         event_executed[0] = False
497
498         condition.acquire()
499         try:
500             self.ns3.Simulator.ScheduleWithContext(contextId, delay, execute_event, 
501                     condition, event_executed, func, *args)
502             if not self.ns3.Simulator.IsFinished():
503                 condition.wait()
504         finally:
505             condition.release()
506
507     def _create_attr_ns3_value(self, type_name, name):
508         TypeId = self.ns3.TypeId()
509         tid = TypeId.LookupByName(type_name)
510         info = TypeId.AttributeInformation()
511         if not tid.LookupAttributeByName(name, info):
512             msg = "TypeId %s has no attribute %s" % (type_name, name) 
513             self.logger.error(msg)
514
515         checker = info.checker
516         ns3_value = checker.Create() 
517         return ns3_value
518
519     def _attr_from_ns3_value_to_string(self, type_name, name, ns3_value):
520         TypeId = self.ns3.TypeId()
521         tid = TypeId.LookupByName(type_name)
522         info = TypeId.AttributeInformation()
523         if not tid.LookupAttributeByName(name, info):
524             msg = "TypeId %s has no attribute %s" % (type_name, name) 
525             self.logger.error(msg)
526
527         checker = info.checker
528         value = ns3_value.SerializeToString(checker)
529
530         type_name = checker.GetValueTypeName()
531         if type_name in ["ns3::UintegerValue", "ns3::IntegerValue"]:
532             return int(value)
533         if type_name == "ns3::DoubleValue":
534             return float(value)
535         if type_name == "ns3::BooleanValue":
536             return value == "true"
537
538         return value
539
540     def _attr_from_string_to_ns3_value(self, type_name, name, value):
541         TypeId = self.ns3.TypeId()
542         tid = TypeId.LookupByName(type_name)
543         info = TypeId.AttributeInformation()
544         if not tid.LookupAttributeByName(name, info):
545             msg = "TypeId %s has no attribute %s" % (type_name, name) 
546             self.logger.error(msg)
547
548         str_value = str(value)
549         if isinstance(value, bool):
550             str_value = str_value.lower()
551
552         checker = info.checker
553         ns3_value = checker.Create()
554         ns3_value.DeserializeFromString(str_value, checker)
555         return ns3_value
556
557     # singletons are identified as "ns3::ClassName"
558     def _singleton(self, ident):
559         if not ident.startswith(SINGLETON):
560             return None
561
562         clazzname = ident[ident.find("::")+2:]
563         if not hasattr(self.ns3, clazzname):
564             msg = "Type %s not supported" % (clazzname)
565             self.logger.error(msg)
566
567         return getattr(self.ns3, clazzname)
568
569     # replace uuids and singleton references for the real objects
570     def replace_args(self, args):
571         realargs = [self.get_object(arg) if \
572                 str(arg).startswith("uuid") else arg for arg in args]
573  
574         realargs = [self._singleton(arg) if \
575                 str(arg).startswith(SINGLETON) else arg for arg in realargs]
576
577         return realargs
578
579     # replace uuids and singleton references for the real objects
580     def replace_kwargs(self, kwargs):
581         realkwargs = dict([(k, self.get_object(v) \
582                 if str(v).startswith("uuid") else v) \
583                 for k,v in kwargs.items()])
584  
585         realkwargs = dict([(k, self._singleton(v) \
586                 if str(v).startswith(SINGLETON) else v )\
587                 for k, v in realkwargs.items()])
588
589         return realkwargs
590
591     def _is_app_running(self, uuid):
592         now = self.ns3.Simulator.Now()
593         if now.IsZero():
594             return False
595
596         if self.ns3.Simulator.IsFinished():
597             return False
598
599         app = self.get_object(uuid)
600         stop_time_value = self.ns3.TimeValue()
601         app.GetAttribute("StopTime", stop_time_value)
602         stop_time = stop_time_value.Get()
603
604         start_time_value = self.ns3.TimeValue()
605         app.GetAttribute("StartTime", start_time_value)
606         start_time = start_time_value.Get()
607         
608         if now.Compare(start_time) >= 0:
609             if stop_time.IsZero() or now.Compare(stop_time) < 0:
610                 return True
611
612         return False
613     
614     def _is_app_started(self, uuid):
615         return self._is_app_running(uuid) or self.is_finished
616
617     def _add_static_route(self, ipv4_uuid, network, prefix, nexthop):
618         ipv4 = self.get_object(ipv4_uuid)
619
620         list_routing = ipv4.GetRoutingProtocol()
621         (static_routing, priority) = list_routing.GetRoutingProtocol(0)
622
623         ifindex = self._find_ifindex(ipv4, nexthop)
624         if ifindex == -1:
625             return False
626         
627         nexthop = self.ns3.Ipv4Address(nexthop)
628
629         if network in ["0.0.0.0", "0", None]:
630             # Default route: 0.0.0.0/0
631             static_routing.SetDefaultRoute(nexthop, ifindex)
632         else:
633             mask = self.ns3.Ipv4Mask("/%s" % prefix) 
634             network = self.ns3.Ipv4Address(network)
635
636             if prefix == 32:
637                 # Host route: x.y.z.w/32
638                 static_routing.AddHostRouteTo(network, nexthop, ifindex)
639             else:
640                 # Network route: x.y.z.w/n
641                 static_routing.AddNetworkRouteTo(network, mask, nexthop, 
642                         ifindex) 
643         return True
644
645     def _find_ifindex(self, ipv4, nexthop):
646         ifindex = -1
647
648         nexthop = self.ns3.Ipv4Address(nexthop)
649
650         # For all the interfaces registered with the ipv4 object, find
651         # the one that matches the network of the nexthop
652         nifaces = ipv4.GetNInterfaces()
653         for ifidx in range(nifaces):
654             iface = ipv4.GetInterface(ifidx)
655             naddress = iface.GetNAddresses()
656             for addridx in range(naddress):
657                 ifaddr = iface.GetAddress(addridx)
658                 ifmask = ifaddr.GetMask()
659                 
660                 ifindex = ipv4.GetInterfaceForPrefix(nexthop, ifmask)
661
662                 if ifindex == ifidx:
663                     return ifindex
664         return ifindex
665
666     def _retrieve_object(self, uuid, typeid, search = False):
667         obj = self.get_object(uuid)
668
669         type_id = self.ns3.TypeId()
670         tid = type_id.LookupByName(typeid)
671         nobj = obj.GetObject(tid)
672
673         newuuid = None
674         if search:
675             # search object
676             for ouuid, oobj in self._objects.items():
677                 if nobj == oobj:
678                     newuuid = ouuid
679                     break
680         else: 
681             newuuid = self.make_uuid()
682             self._objects[newuuid] = nobj
683
684         return newuuid
685
686     def _recv_fd(self, uuid):
687         """ Waits on a local address to receive a file descriptor
688         from a local process. The file descriptor is associated
689         to a FdNetDevice to stablish communication between the
690         simulation and what ever process writes on that file descriptor
691         """
692
693         def recvfd(sock, fdnd):
694             (fd, msg) = passfd.recvfd(sock)
695             # Store a reference to the endpoint to keep the socket alive
696             fdnd.SetFileDescriptor(fd)
697         
698         import passfd
699         import socket
700         sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
701         sock.bind("")
702         address = sock.getsockname()
703         
704         fdnd = self.get_object(uuid)
705         t = threading.Thread(target=recvfd, args=(sock,fdnd))
706         t.start()
707
708         return address
709
710