7c26ac21fac4175bd88327b0fd13b7bd3e86187b
[nepi.git] / src / nepi / resources / ns3 / ns3wrapper.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License as published by
7 #    the Free Software Foundation, either version 3 of the License, or
8 #    (at your option) any later version.
9 #
10 #    This program is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #    GNU General Public License for more details.
14 #
15 #    You should have received a copy of the GNU General Public License
16 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
19
20 import logging
21 import os
22 import sys
23 import threading
24 import time
25 import uuid
26
27 SINGLETON = "singleton::"
28 SIMULATOR_UUID = "singleton::Simulator"
29 CONFIG_UUID = "singleton::Config"
30 GLOBAL_VALUE_UUID = "singleton::GlobalValue"
31 IPV4_GLOBAL_ROUTING_HELPER_UUID = "singleton::Ipv4GlobalRoutingHelper"
32
33 def load_ns3_libraries():
34     import ctypes
35     import re
36
37     libdir = os.environ.get("NS3LIBRARIES")
38
39     # Load the ns-3 modules shared libraries
40     if libdir:
41         files = os.listdir(libdir)
42         regex = re.compile("(.*\.so)$")
43         libs = [m.group(1) for filename in files for m in [regex.search(filename)] if m]
44
45         initial_size = len(libs)
46         # Try to load the libraries in the right order by trial and error.
47         # Loop until all libraries are loaded.
48         while len(libs) > 0:
49             for lib in libs:
50                 libfile = os.path.join(libdir, lib)
51                 try:
52                     ctypes.CDLL(libfile, ctypes.RTLD_GLOBAL)
53                     libs.remove(lib)
54                 except:
55                     #import traceback
56                     #err = traceback.format_exc()
57                     #print err
58                     pass
59
60             # if did not load any libraries in the last iteration break
61             # to prevent infinit loop
62             if initial_size == len(libs):
63                 raise RuntimeError("Imposible to load shared libraries %s" % str(libs))
64             initial_size = len(libs)
65
66 def load_ns3_module():
67     load_ns3_libraries()
68
69     # import the python bindings for the ns-3 modules
70     bindings = os.environ.get("NS3BINDINGS")
71     if bindings:
72         sys.path.append(bindings)
73
74     import pkgutil
75     import imp
76     import ns
77
78     # create a Python module to add all ns3 classes
79     ns3mod = imp.new_module("ns3")
80     sys.modules["ns3"] = ns3mod
81
82     for importer, modname, ispkg in pkgutil.iter_modules(ns.__path__):
83         if modname in [ "visualizer" ]:
84             continue
85
86         fullmodname = "ns.%s" % modname
87         module = __import__(fullmodname, globals(), locals(), ['*'])
88
89         for sattr in dir(module):
90             if sattr.startswith("_"):
91                 continue
92
93             attr = getattr(module, sattr)
94
95             # netanim.Config and lte.Config singleton overrides ns3::Config
96             if sattr == "Config" and modname in ['netanim', 'lte']:
97                 sattr = "%s.%s" % (modname, sattr)
98
99             setattr(ns3mod, sattr, attr)
100
101     return ns3mod
102
103 class NS3Wrapper(object):
104     def __init__(self, loglevel = logging.INFO, enable_dump = False):
105         super(NS3Wrapper, self).__init__()
106         # Thread used to run the simulation
107         self._simulation_thread = None
108         self._condition = None
109
110         # True if Simulator::Run was invoked
111         self._started = False
112
113         # holds reference to all C++ objects and variables in the simulation
114         self._objects = dict()
115
116         # Logging
117         self._logger = logging.getLogger("ns3wrapper")
118         self._logger.setLevel(loglevel)
119
120         ## NOTE that the reason to create a handler to the ns3 module,
121         # that is re-loaded each time a ns-3 wrapper is instantiated,
122         # is that else each unit test for the ns3wrapper class would need
123         # a separate file. Several ns3wrappers would be created in the 
124         # same unit test (single process), leading to inchorences in the 
125         # state of ns-3 global objects
126         #
127         # Handler to ns3 classes
128         self._ns3 = None
129
130         # Collection of allowed ns3 classes
131         self._allowed_types = None
132
133         # Object to dump instructions to reproduce and debug experiment
134         from ns3wrapper_debug import NS3WrapperDebuger
135         self._debuger = NS3WrapperDebuger(enabled = enable_dump)
136
137     @property
138     def debuger(self):
139         return self._debuger
140
141     @property
142     def ns3(self):
143         if not self._ns3:
144             # load ns-3 libraries and bindings
145             self._ns3 = load_ns3_module()
146
147         return self._ns3
148
149     @property
150     def allowed_types(self):
151         if not self._allowed_types:
152             self._allowed_types = set()
153             type_id = self.ns3.TypeId()
154             
155             tid_count = type_id.GetRegisteredN()
156             base = type_id.LookupByName("ns3::Object")
157
158             for i in xrange(tid_count):
159                 tid = type_id.GetRegistered(i)
160                 
161                 if tid.MustHideFromDocumentation() or \
162                         not tid.HasConstructor() or \
163                         not tid.IsChildOf(base): 
164                     continue
165
166                 type_name = tid.GetName()
167                 self._allowed_types.add(type_name)
168         
169         return self._allowed_types
170
171     @property
172     def logger(self):
173         return self._logger
174
175     @property
176     def is_running(self):
177         return self.is_started and not self.ns3.Simulator.IsFinished()
178
179     @property
180     def is_started(self):
181         if not self._started:
182             now = self.ns3.Simulator.Now()
183             if not now.IsZero():
184                 self._started = True
185
186         return self._started
187
188     @property
189     def is_finished(self):
190         return self.ns3.Simulator.IsFinished()
191
192     def make_uuid(self):
193         return "uuid%s" % uuid.uuid4()
194
195     def get_object(self, uuid):
196         return self._objects.get(uuid)
197
198     def factory(self, type_name, **kwargs):
199         """ This method should be used to construct ns-3 objects
200         that have a TypeId and related introspection information """
201
202         if type_name not in self.allowed_types:
203             msg = "Type %s not supported" % (type_name) 
204             self.logger.error(msg)
205
206         uuid = self.make_uuid()
207         
208         ### DEBUG
209         self.logger.debug("FACTORY %s( %s )" % (type_name, str(kwargs)))
210         
211         ### DUMP
212         self.debuger.dump_factory(uuid, type_name, kwargs)
213
214         factory = self.ns3.ObjectFactory()
215         factory.SetTypeId(type_name)
216
217         for name, value in kwargs.iteritems():
218             ns3_value = self._attr_from_string_to_ns3_value(type_name, name, value)
219             factory.Set(name, ns3_value)
220
221         obj = factory.Create()
222
223         self._objects[uuid] = obj
224
225         ### DEBUG
226         self.logger.debug("RET FACTORY ( uuid %s ) %s = %s( %s )" % (
227             str(uuid), str(obj), type_name, str(kwargs)))
228  
229         return uuid
230
231     def create(self, clazzname, *args):
232         """ This method should be used to construct ns-3 objects that
233         do not have a TypeId (e.g. Values) """
234
235         if not hasattr(self.ns3, clazzname):
236             msg = "Type %s not supported" % (clazzname) 
237             self.logger.error(msg)
238
239         uuid = self.make_uuid()
240         
241         ### DEBUG
242         self.logger.debug("CREATE %s( %s )" % (clazzname, str(args)))
243     
244         ### DUMP
245         self.debuger.dump_create(uuid, clazzname, args)
246
247         clazz = getattr(self.ns3, clazzname)
248  
249         # arguments starting with 'uuid' identify ns-3 C++
250         # objects and must be replaced by the actual object
251         realargs = self.replace_args(args)
252        
253         obj = clazz(*realargs)
254         
255         self._objects[uuid] = obj
256
257         ### DEBUG
258         self.logger.debug("RET CREATE ( uuid %s ) %s = %s( %s )" % (str(uuid), 
259             str(obj), clazzname, str(args)))
260
261         return uuid
262
263     def invoke(self, uuid, operation, *args, **kwargs):
264         ### DEBUG
265         self.logger.debug("INVOKE %s -> %s( %s, %s ) " % (
266             uuid, operation, str(args), str(kwargs)))
267         ########
268
269         result = None
270         newuuid = None
271
272         if operation == "isRunning":
273             result = self.is_running
274
275         elif operation == "isStarted":
276             result = self.is_started
277
278         elif operation == "isFinished":
279             result = self.is_finished
280
281         elif operation == "isAppRunning":
282             result = self._is_app_running(uuid)
283
284         elif operation == "recvFD":
285             ### passFD operation binds to a different random socket 
286             ### en every execution, so the socket name that could be
287             ### dumped to the debug script using dump_invoke is
288             ### not be valid accross debug executions.
289             result = self._recv_fd(uuid, *args, **kwargs)
290
291         elif operation == "addStaticRoute":
292             result = self._add_static_route(uuid, *args)
293             
294             ### DUMP - result is static, so will be dumped as plain text
295             self.debuger.dump_invoke(result, uuid, operation, args, kwargs)
296
297         elif operation == "retrieveObject":
298             result = self._retrieve_object(uuid, *args, **kwargs)
299        
300             ### DUMP - result is static, so will be dumped as plain text
301             self.debuger.dump_invoke(result, uuid, operation, args, kwargs)
302        
303         else:
304             newuuid = self.make_uuid()
305
306             ### DUMP - result is a uuid that encoded an dynamically generated 
307             ### object
308             self.debuger.dump_invoke(newuuid, uuid, operation, args, kwargs)
309
310             if uuid.startswith(SINGLETON):
311                 obj = self._singleton(uuid)
312             else:
313                 obj = self.get_object(uuid)
314             
315             method = getattr(obj, operation)
316
317             # arguments starting with 'uuid' identify ns-3 C++
318             # objects and must be replaced by the actual object
319             realargs = self.replace_args(args)
320             realkwargs = self.replace_kwargs(kwargs)
321
322             result = method(*realargs, **realkwargs)
323
324             # If the result is an object (not a base value),
325             # then keep track of the object a return the object
326             # reference (newuuid)
327             if not (result is None or type(result) in [
328                     bool, float, long, str, int]):
329                 self._objects[newuuid] = result
330                 result = newuuid
331
332         ### DEBUG
333         self.logger.debug("RET INVOKE %s%s = %s -> %s(%s, %s) " % (
334             "(uuid %s) " % str(newuuid) if newuuid else "", str(result), uuid, 
335             operation, str(args), str(kwargs)))
336         ########
337
338         return result
339
340     def _set_attr(self, obj, name, ns3_value):
341         obj.SetAttribute(name, ns3_value)
342
343     def set(self, uuid, name, value):
344         ### DEBUG
345         self.logger.debug("SET %s %s %s" % (uuid, name, str(value)))
346     
347         ### DUMP
348         self.debuger.dump_set(uuid, name, value)
349
350         obj = self.get_object(uuid)
351         type_name = obj.GetInstanceTypeId().GetName()
352         ns3_value = self._attr_from_string_to_ns3_value(type_name, name, value)
353
354         # If the Simulation thread is not running,
355         # then there will be no thread-safety problems
356         # in changing the value of an attribute directly.
357         # However, if the simulation is running we need
358         # to set the value by scheduling an event, else
359         # we risk to corrupt the state of the
360         # simulation.
361         
362         event_executed = [False]
363
364         if self.is_running:
365             # schedule the event in the Simulator
366             self._schedule_event(self._condition, event_executed, 
367                     self._set_attr, obj, name, ns3_value)
368
369         if not event_executed[0]:
370             self._set_attr(obj, name, ns3_value)
371
372         ### DEBUG
373         self.logger.debug("RET SET %s = %s -> set(%s, %s)" % (str(value), uuid, name, 
374             str(value)))
375
376         return value
377
378     def _get_attr(self, obj, name, ns3_value):
379         obj.GetAttribute(name, ns3_value)
380
381     def get(self, uuid, name):
382         ### DEBUG
383         self.logger.debug("GET %s %s" % (uuid, name))
384         
385         ### DUMP
386         self.debuger.dump_get(uuid, name)
387
388         obj = self.get_object(uuid)
389         type_name = obj.GetInstanceTypeId().GetName()
390         ns3_value = self._create_attr_ns3_value(type_name, name)
391
392         event_executed = [False]
393
394         if self.is_running:
395             # schedule the event in the Simulator
396             self._schedule_event(self._condition, event_executed,
397                     self._get_attr, obj, name, ns3_value)
398
399         if not event_executed[0]:
400             self._get_attr(obj, name, ns3_value)
401
402         result = self._attr_from_ns3_value_to_string(type_name, name, ns3_value)
403
404         ### DEBUG
405         self.logger.debug("RET GET %s = %s -> get(%s)" % (str(result), uuid, name))
406
407         return result
408
409     def start(self):
410         ### DUMP
411         self.debuger.dump_start()
412
413         # Launch the simulator thread and Start the
414         # simulator in that thread
415         self._condition = threading.Condition()
416         self._simulator_thread = threading.Thread(
417                 target = self._simulator_run,
418                 args = [self._condition])
419         self._simulator_thread.setDaemon(True)
420         self._simulator_thread.start()
421         
422         ### DEBUG
423         self.logger.debug("START")
424
425     def stop(self, time = None):
426         ### DUMP
427         self.debuger.dump_stop(time=time)
428         
429         if time is None:
430             self.ns3.Simulator.Stop()
431         else:
432             self.ns3.Simulator.Stop(self.ns3.Time(time))
433
434         ### DEBUG
435         self.logger.debug("STOP time=%s" % str(time))
436
437     def shutdown(self):
438         ### DUMP
439         self.debuger.dump_shutdown()
440
441         while not self.ns3.Simulator.IsFinished():
442             #self.logger.debug("Waiting for simulation to finish")
443             time.sleep(0.5)
444         
445         if self._simulator_thread:
446             self._simulator_thread.join()
447        
448         self.ns3.Simulator.Destroy()
449         
450         # Remove all references to ns-3 objects
451         self._objects.clear()
452         
453         sys.stdout.flush()
454         sys.stderr.flush()
455
456         ### DEBUG
457         self.logger.debug("SHUTDOWN")
458
459     def _simulator_run(self, condition):
460         # Run simulation
461         self.ns3.Simulator.Run()
462         # Signal condition to indicate simulation ended and
463         # notify waiting threads
464         condition.acquire()
465         condition.notifyAll()
466         condition.release()
467
468     def _schedule_event(self, condition, event_executed, func, *args):
469         """ Schedules event on running simulation, and wait until
470             event is executed"""
471
472         def execute_event(contextId, condition, event_executed, func, *args):
473             try:
474                 func(*args)
475                 event_executed[0] = True
476             finally:
477                 # notify condition indicating event was executed
478                 condition.acquire()
479                 condition.notifyAll()
480                 condition.release()
481
482         # contextId is defined as general context
483         contextId = long(0xffffffff)
484
485         # delay 0 means that the event is expected to execute inmediately
486         delay = self.ns3.Seconds(0)
487     
488         # Mark event as not executed
489         event_executed[0] = False
490
491         condition.acquire()
492         try:
493             self.ns3.Simulator.ScheduleWithContext(contextId, delay, execute_event, 
494                     condition, event_executed, func, *args)
495             if not self.ns3.Simulator.IsFinished():
496                 condition.wait()
497         finally:
498             condition.release()
499
500     def _create_attr_ns3_value(self, type_name, name):
501         TypeId = self.ns3.TypeId()
502         tid = TypeId.LookupByName(type_name)
503         info = TypeId.AttributeInformation()
504         if not tid.LookupAttributeByName(name, info):
505             msg = "TypeId %s has no attribute %s" % (type_name, name) 
506             self.logger.error(msg)
507
508         checker = info.checker
509         ns3_value = checker.Create() 
510         return ns3_value
511
512     def _attr_from_ns3_value_to_string(self, type_name, name, ns3_value):
513         TypeId = self.ns3.TypeId()
514         tid = TypeId.LookupByName(type_name)
515         info = TypeId.AttributeInformation()
516         if not tid.LookupAttributeByName(name, info):
517             msg = "TypeId %s has no attribute %s" % (type_name, name) 
518             self.logger.error(msg)
519
520         checker = info.checker
521         value = ns3_value.SerializeToString(checker)
522
523         type_name = checker.GetValueTypeName()
524         if type_name in ["ns3::UintegerValue", "ns3::IntegerValue"]:
525             return int(value)
526         if type_name == "ns3::DoubleValue":
527             return float(value)
528         if type_name == "ns3::BooleanValue":
529             return value == "true"
530
531         return value
532
533     def _attr_from_string_to_ns3_value(self, type_name, name, value):
534         TypeId = self.ns3.TypeId()
535         tid = TypeId.LookupByName(type_name)
536         info = TypeId.AttributeInformation()
537         if not tid.LookupAttributeByName(name, info):
538             msg = "TypeId %s has no attribute %s" % (type_name, name) 
539             self.logger.error(msg)
540
541         str_value = str(value)
542         if isinstance(value, bool):
543             str_value = str_value.lower()
544
545         checker = info.checker
546         ns3_value = checker.Create()
547         ns3_value.DeserializeFromString(str_value, checker)
548         return ns3_value
549
550     # singletons are identified as "ns3::ClassName"
551     def _singleton(self, ident):
552         if not ident.startswith(SINGLETON):
553             return None
554
555         clazzname = ident[ident.find("::")+2:]
556         if not hasattr(self.ns3, clazzname):
557             msg = "Type %s not supported" % (clazzname)
558             self.logger.error(msg)
559
560         return getattr(self.ns3, clazzname)
561
562     # replace uuids and singleton references for the real objects
563     def replace_args(self, args):
564         realargs = [self.get_object(arg) if \
565                 str(arg).startswith("uuid") else arg for arg in args]
566  
567         realargs = [self._singleton(arg) if \
568                 str(arg).startswith(SINGLETON) else arg for arg in realargs]
569
570         return realargs
571
572     # replace uuids and singleton references for the real objects
573     def replace_kwargs(self, kwargs):
574         realkwargs = dict([(k, self.get_object(v) \
575                 if str(v).startswith("uuid") else v) \
576                 for k,v in kwargs.iteritems()])
577  
578         realkwargs = dict([(k, self._singleton(v) \
579                 if str(v).startswith(SINGLETON) else v )\
580                 for k, v in realkwargs.iteritems()])
581
582         return realkwargs
583
584     def _is_app_running(self, uuid):
585         now = self.ns3.Simulator.Now()
586         if now.IsZero():
587             return False
588
589         if self.ns3.Simulator.IsFinished():
590             return False
591
592         app = self.get_object(uuid)
593         stop_time_value = self.ns3.TimeValue()
594         app.GetAttribute("StopTime", stop_time_value)
595         stop_time = stop_time_value.Get()
596
597         start_time_value = self.ns3.TimeValue()
598         app.GetAttribute("StartTime", start_time_value)
599         start_time = start_time_value.Get()
600         
601         if now.Compare(start_time) >= 0:
602             if stop_time.IsZero() or now.Compare(stop_time) < 0:
603                 return True
604
605         return False
606
607     def _add_static_route(self, ipv4_uuid, network, prefix, nexthop):
608         ipv4 = self.get_object(ipv4_uuid)
609
610         list_routing = ipv4.GetRoutingProtocol()
611         (static_routing, priority) = list_routing.GetRoutingProtocol(0)
612
613         ifindex = self._find_ifindex(ipv4, nexthop)
614         if ifindex == -1:
615             return False
616         
617         nexthop = self.ns3.Ipv4Address(nexthop)
618
619         if network in ["0.0.0.0", "0", None]:
620             # Default route: 0.0.0.0/0
621             static_routing.SetDefaultRoute(nexthop, ifindex)
622         else:
623             mask = self.ns3.Ipv4Mask("/%s" % prefix) 
624             network = self.ns3.Ipv4Address(network)
625
626             if prefix == 32:
627                 # Host route: x.y.z.w/32
628                 static_routing.AddHostRouteTo(network, nexthop, ifindex)
629             else:
630                 # Network route: x.y.z.w/n
631                 static_routing.AddNetworkRouteTo(network, mask, nexthop, 
632                         ifindex) 
633         return True
634
635     def _find_ifindex(self, ipv4, nexthop):
636         ifindex = -1
637
638         nexthop = self.ns3.Ipv4Address(nexthop)
639
640         # For all the interfaces registered with the ipv4 object, find
641         # the one that matches the network of the nexthop
642         nifaces = ipv4.GetNInterfaces()
643         for ifidx in xrange(nifaces):
644             iface = ipv4.GetInterface(ifidx)
645             naddress = iface.GetNAddresses()
646             for addridx in xrange(naddress):
647                 ifaddr = iface.GetAddress(addridx)
648                 ifmask = ifaddr.GetMask()
649                 
650                 ifindex = ipv4.GetInterfaceForPrefix(nexthop, ifmask)
651
652                 if ifindex == ifidx:
653                     return ifindex
654         return ifindex
655
656     def _retrieve_object(self, uuid, typeid, search = False):
657         obj = self.get_object(uuid)
658
659         type_id = self.ns3.TypeId()
660         tid = type_id.LookupByName(typeid)
661         nobj = obj.GetObject(tid)
662
663         newuuid = None
664         if search:
665             # search object
666             for ouuid, oobj in self._objects.iteritems():
667                 if nobj == oobj:
668                     newuuid = ouuid
669                     break
670         else: 
671             newuuid = self.make_uuid()
672             self._objects[newuuid] = nobj
673
674         return newuuid
675
676     def _recv_fd(self, uuid):
677         """ Waits on a local address to receive a file descriptor
678         from a local process. The file descriptor is associated
679         to a FdNetDevice to stablish communication between the
680         simulation and what ever process writes on that file descriptor
681         """
682
683         def recvfd(sock, fdnd):
684             (fd, msg) = passfd.recvfd(sock)
685             # Store a reference to the endpoint to keep the socket alive
686             fdnd.SetFileDescriptor(fd)
687         
688         import passfd
689         import socket
690         sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
691         sock.bind("")
692         address = sock.getsockname()
693         
694         fdnd = self.get_object(uuid)
695         t = threading.Thread(target=recvfd, args=(sock,fdnd))
696         t.start()
697
698         return address
699
700