enable dce in simulation is now automatic
[nepi.git] / src / nepi / resources / ns3 / ns3wrapper.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License as published by
7 #    the Free Software Foundation, either version 3 of the License, or
8 #    (at your option) any later version.
9 #
10 #    This program is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #    GNU General Public License for more details.
14 #
15 #    You should have received a copy of the GNU General Public License
16 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
19
20 import logging
21 import os
22 import sys
23 import threading
24 import time
25 import uuid
26
27 SINGLETON = "singleton::"
28 SIMULATOR_UUID = "singleton::Simulator"
29 CONFIG_UUID = "singleton::Config"
30 GLOBAL_VALUE_UUID = "singleton::GlobalValue"
31 IPV4_GLOBAL_ROUTING_HELPER_UUID = "singleton::Ipv4GlobalRoutingHelper"
32
33 def load_ns3_libraries():
34     import ctypes
35     import re
36
37     libdir = os.environ.get("NS3LIBRARIES")
38
39     # Load the ns-3 modules shared libraries
40     if libdir:
41         files = os.listdir(libdir)
42         regex = re.compile("(.*\.so)$")
43         libs = [m.group(1) for filename in files for m in [regex.search(filename)] if m]
44
45         initial_size = len(libs)
46         # Try to load the libraries in the right order by trial and error.
47         # Loop until all libraries are loaded.
48         while len(libs) > 0:
49             for lib in libs:
50                 libfile = os.path.join(libdir, lib)
51                 try:
52                     ctypes.CDLL(libfile, ctypes.RTLD_GLOBAL)
53                     libs.remove(lib)
54                 except:
55                     #import traceback
56                     #err = traceback.format_exc()
57                     #print err
58                     pass
59
60             # if did not load any libraries in the last iteration break
61             # to prevent infinit loop
62             if initial_size == len(libs):
63                 raise RuntimeError("Imposible to load shared libraries %s" % str(libs))
64             initial_size = len(libs)
65
66 def load_ns3_module():
67     load_ns3_libraries()
68
69     # import the python bindings for the ns-3 modules
70     bindings = os.environ.get("NS3BINDINGS")
71     if bindings:
72         sys.path.append(bindings)
73
74     import pkgutil
75     import imp
76     import ns
77
78     # create a Python module to add all ns3 classes
79     ns3mod = imp.new_module("ns3")
80     sys.modules["ns3"] = ns3mod
81
82     for importer, modname, ispkg in pkgutil.iter_modules(ns.__path__):
83         if modname in [ "visualizer" ]:
84             continue
85
86         fullmodname = "ns.%s" % modname
87         module = __import__(fullmodname, globals(), locals(), ['*'])
88
89         for sattr in dir(module):
90             if sattr.startswith("_"):
91                 continue
92
93             attr = getattr(module, sattr)
94
95             # netanim.Config and lte.Config singleton overrides ns3::Config
96             if sattr == "Config" and modname in ['netanim', 'lte']:
97                 sattr = "%s.%s" % (modname, sattr)
98
99             setattr(ns3mod, sattr, attr)
100
101     return ns3mod
102
103 class NS3Wrapper(object):
104     def __init__(self, loglevel = logging.INFO):
105         super(NS3Wrapper, self).__init__()
106         # Thread used to run the simulation
107         self._simulation_thread = None
108         self._condition = None
109
110         # True if Simulator::Run was invoked
111         self._started = False
112
113         # holds reference to all C++ objects and variables in the simulation
114         self._objects = dict()
115
116         # Logging
117         self._logger = logging.getLogger("ns3wrapper")
118         self._logger.setLevel(loglevel)
119
120         ## NOTE that the reason to create a handler to the ns3 module,
121         # that is re-loaded each time a ns-3 wrapper is instantiated,
122         # is that else each unit test for the ns3wrapper class would need
123         # a separate file. Several ns3wrappers would be created in the 
124         # same unit test (single process), leading to inchorences in the 
125         # state of ns-3 global objects
126         #
127         # Handler to ns3 classes
128         self._ns3 = None
129
130         # Collection of allowed ns3 classes
131         self._allowed_types = None
132
133     @property
134     def ns3(self):
135         if not self._ns3:
136             # load ns-3 libraries and bindings
137             self._ns3 = load_ns3_module()
138
139         return self._ns3
140
141     @property
142     def allowed_types(self):
143         if not self._allowed_types:
144             self._allowed_types = set()
145             type_id = self.ns3.TypeId()
146             
147             tid_count = type_id.GetRegisteredN()
148             base = type_id.LookupByName("ns3::Object")
149
150             for i in xrange(tid_count):
151                 tid = type_id.GetRegistered(i)
152                 
153                 if tid.MustHideFromDocumentation() or \
154                         not tid.HasConstructor() or \
155                         not tid.IsChildOf(base): 
156                     continue
157
158                 type_name = tid.GetName()
159                 self._allowed_types.add(type_name)
160         
161         return self._allowed_types
162
163     @property
164     def logger(self):
165         return self._logger
166
167     @property
168     def is_running(self):
169         return self._started and self.ns3.Simulator.IsFinished()
170
171     def make_uuid(self):
172         return "uuid%s" % uuid.uuid4()
173
174     def get_object(self, uuid):
175         return self._objects.get(uuid)
176
177     def factory(self, type_name, **kwargs):
178         if type_name not in self.allowed_types:
179             msg = "Type %s not supported" % (type_name) 
180             self.logger.error(msg)
181  
182         factory = self.ns3.ObjectFactory()
183         factory.SetTypeId(type_name)
184
185         for name, value in kwargs.iteritems():
186             ns3_value = self._attr_from_string_to_ns3_value(type_name, name, value)
187             factory.Set(name, ns3_value)
188
189         obj = factory.Create()
190
191         uuid = self.make_uuid()
192         self._objects[uuid] = obj
193
194         return uuid
195
196     def create(self, clazzname, *args):
197         if not hasattr(self.ns3, clazzname):
198             msg = "Type %s not supported" % (clazzname) 
199             self.logger.error(msg)
200      
201         clazz = getattr(self.ns3, clazzname)
202  
203         # arguments starting with 'uuid' identify ns-3 C++
204         # objects and must be replaced by the actual object
205         realargs = self.replace_args(args)
206        
207         obj = clazz(*realargs)
208         
209         uuid = self.make_uuid()
210         self._objects[uuid] = obj
211
212         return uuid
213
214     def invoke(self, uuid, operation, *args, **kwargs):
215         if operation == "isRunning":
216             return self.is_running
217         if operation == "isAppRunning":
218             return self._is_app_running(uuid)
219         if operation == "addStaticRoute":
220             return self._add_static_route(uuid, *args)
221
222         if uuid.startswith(SINGLETON):
223             obj = self._singleton(uuid)
224         else:
225             obj = self.get_object(uuid)
226         
227         method = getattr(obj, operation)
228
229         # arguments starting with 'uuid' identify ns-3 C++
230         # objects and must be replaced by the actual object
231         realargs = self.replace_args(args)
232         realkwargs = self.replace_kwargs(kwargs)
233
234         result = method(*realargs, **realkwargs)
235
236         # If the result is not an object, no need to 
237         # keep a reference. Directly return value.
238         if result is None or type(result) in [bool, float, long, str, int]:
239             return result
240       
241         newuuid = self.make_uuid()
242         self._objects[newuuid] = result
243
244         return newuuid
245
246     def _set_attr(self, obj, name, ns3_value):
247         obj.SetAttribute(name, ns3_value)
248
249     def set(self, uuid, name, value):
250         obj = self.get_object(uuid)
251         type_name = obj.GetInstanceTypeId().GetName()
252         ns3_value = self._attr_from_string_to_ns3_value(type_name, name, value)
253
254         # If the Simulation thread is not running,
255         # then there will be no thread-safety problems
256         # in changing the value of an attribute directly.
257         # However, if the simulation is running we need
258         # to set the value by scheduling an event, else
259         # we risk to corrupt the state of the
260         # simulation.
261         
262         event_executed = [False]
263
264         if self.is_running:
265             # schedule the event in the Simulator
266             self._schedule_event(self._condition, event_executed, 
267                     self._set_attr, obj, name, ns3_value)
268
269         if not event_executed[0]:
270             self._set_attr(obj, name, ns3_value)
271
272         return value
273
274     def _get_attr(self, obj, name, ns3_value):
275         obj.GetAttribute(name, ns3_value)
276
277     def get(self, uuid, name):
278         obj = self.get_object(uuid)
279         type_name = obj.GetInstanceTypeId().GetName()
280         ns3_value = self._create_attr_ns3_value(type_name, name)
281
282         event_executed = [False]
283
284         if self.is_running:
285             # schedule the event in the Simulator
286             self._schedule_event(self._condition, event_executed,
287                     self._get_attr, obj, name, ns3_value)
288
289         if not event_executed[0]:
290             self._get_attr(obj, name, ns3_value)
291
292         return self._attr_from_ns3_value_to_string(type_name, name, ns3_value)
293
294     def start(self):
295         # Launch the simulator thread and Start the
296         # simulator in that thread
297         self._condition = threading.Condition()
298         self._simulator_thread = threading.Thread(
299                 target = self._simulator_run,
300                 args = [self._condition])
301         self._simulator_thread.setDaemon(True)
302         self._simulator_thread.start()
303         self._started = True
304
305     def stop(self, time = None):
306         if time is None:
307             self.ns3.Simulator.Stop()
308         else:
309             self.ns3.Simulator.Stop(self.ns3.Time(time))
310
311     def shutdown(self):
312         while not self.ns3.Simulator.IsFinished():
313             #self.logger.debug("Waiting for simulation to finish")
314             time.sleep(0.5)
315         
316         if self._simulator_thread:
317             self._simulator_thread.join()
318        
319         self.ns3.Simulator.Destroy()
320         
321         # Remove all references to ns-3 objects
322         self._objects.clear()
323         
324         sys.stdout.flush()
325         sys.stderr.flush()
326
327     def _simulator_run(self, condition):
328         # Run simulation
329         self.ns3.Simulator.Run()
330         # Signal condition to indicate simulation ended and
331         # notify waiting threads
332         condition.acquire()
333         condition.notifyAll()
334         condition.release()
335
336     def _schedule_event(self, condition, event_executed, func, *args):
337         """ Schedules event on running simulation, and wait until
338             event is executed"""
339
340         def execute_event(contextId, condition, event_executed, func, *args):
341             try:
342                 func(*args)
343                 event_executed[0] = True
344             finally:
345                 # notify condition indicating event was executed
346                 condition.acquire()
347                 condition.notifyAll()
348                 condition.release()
349
350         # contextId is defined as general context
351         contextId = long(0xffffffff)
352
353         # delay 0 means that the event is expected to execute inmediately
354         delay = self.ns3.Seconds(0)
355     
356         # Mark event as not executed
357         event_executed[0] = False
358
359         condition.acquire()
360         try:
361             self.ns3.Simulator.ScheduleWithContext(contextId, delay, execute_event, 
362                     condition, event_executed, func, *args)
363             if not self.ns3.Simulator.IsFinished():
364                 condition.wait()
365         finally:
366             condition.release()
367
368     def _create_attr_ns3_value(self, type_name, name):
369         TypeId = self.ns3.TypeId()
370         tid = TypeId.LookupByName(type_name)
371         info = TypeId.AttributeInformation()
372         if not tid.LookupAttributeByName(name, info):
373             msg = "TypeId %s has no attribute %s" % (type_name, name) 
374             self.logger.error(msg)
375
376         checker = info.checker
377         ns3_value = checker.Create() 
378         return ns3_value
379
380     def _attr_from_ns3_value_to_string(self, type_name, name, ns3_value):
381         TypeId = self.ns3.TypeId()
382         tid = TypeId.LookupByName(type_name)
383         info = TypeId.AttributeInformation()
384         if not tid.LookupAttributeByName(name, info):
385             msg = "TypeId %s has no attribute %s" % (type_name, name) 
386             self.logger.error(msg)
387
388         checker = info.checker
389         value = ns3_value.SerializeToString(checker)
390
391         type_name = checker.GetValueTypeName()
392         if type_name in ["ns3::UintegerValue", "ns3::IntegerValue"]:
393             return int(value)
394         if type_name == "ns3::DoubleValue":
395             return float(value)
396         if type_name == "ns3::BooleanValue":
397             return value == "true"
398
399         return value
400
401     def _attr_from_string_to_ns3_value(self, type_name, name, value):
402         TypeId = self.ns3.TypeId()
403         tid = TypeId.LookupByName(type_name)
404         info = TypeId.AttributeInformation()
405         if not tid.LookupAttributeByName(name, info):
406             msg = "TypeId %s has no attribute %s" % (type_name, name) 
407             self.logger.error(msg)
408
409         str_value = str(value)
410         if isinstance(value, bool):
411             str_value = str_value.lower()
412
413         checker = info.checker
414         ns3_value = checker.Create()
415         ns3_value.DeserializeFromString(str_value, checker)
416         return ns3_value
417
418     # singletons are identified as "ns3::ClassName"
419     def _singleton(self, ident):
420         if not ident.startswith(SINGLETON):
421             return None
422
423         clazzname = ident[ident.find("::")+2:]
424         if not hasattr(self.ns3, clazzname):
425             msg = "Type %s not supported" % (clazzname)
426             self.logger.error(msg)
427
428         return getattr(self.ns3, clazzname)
429
430     # replace uuids and singleton references for the real objects
431     def replace_args(self, args):
432         realargs = [self.get_object(arg) if \
433                 str(arg).startswith("uuid") else arg for arg in args]
434  
435         realargs = [self._singleton(arg) if \
436                 str(arg).startswith(SINGLETON) else arg for arg in realargs]
437
438         return realargs
439
440     # replace uuids and singleton references for the real objects
441     def replace_kwargs(self, kwargs):
442         realkwargs = dict([(k, self.get_object(v) \
443                 if str(v).startswith("uuid") else v) \
444                 for k,v in kwargs.iteritems()])
445  
446         realkwargs = dict([(k, self._singleton(v) \
447                 if str(v).startswith(SINGLETON) else v )\
448                 for k, v in realkwargs.iteritems()])
449
450         return realkwargs
451
452     def _is_app_running(self, uuid): 
453         now = self.ns3.Simulator.Now()
454         if now.IsZero():
455             return False
456
457         app = self.get_object(uuid)
458         stop_time_value = self.ns3.TimeValue()
459         app.GetAttribute("StopTime", stop_time_value)
460         stop_time = stop_time_value.Get()
461
462         start_time_value = self.ns3.TimeValue()
463         app.GetAttribute("StartTime", start_time_value)
464         start_time = start_time_value.Get()
465         
466         if now.Compare(start_time) >= 0 and now.Compare(stop_time) < 0:
467             return True
468
469         return False
470
471     def _add_static_route(self, ipv4_uuid, network, prefix, nexthop):
472         ipv4 = self.get_object(ipv4_uuid)
473
474         list_routing = ipv4.GetRoutingProtocol()
475         (static_routing, priority) = list_routing.GetRoutingProtocol(0)
476
477         ifindex = self._find_ifindex(ipv4, nexthop)
478         if ifindex == -1:
479             return False
480         
481         nexthop = self.ns3.Ipv4Address(nexthop)
482
483         if network in ["0.0.0.0", "0", None]:
484             # Default route: 0.0.0.0/0
485             static_routing.SetDefaultRoute(nexthop, ifindex)
486         else:
487             mask = self.ns3.Ipv4Mask("/%s" % prefix) 
488             network = self.ns3.Ipv4Address(network)
489
490             if prefix == 32:
491                 # Host route: x.y.z.w/32
492                 static_routing.AddHostRouteTo(network, nexthop, ifindex)
493             else:
494                 # Network route: x.y.z.w/n
495                 static_routing.AddNetworkRouteTo(network, mask, nexthop, 
496                         ifindex) 
497         return True
498
499     def _find_ifindex(self, ipv4, nexthop):
500         ifindex = -1
501
502         nexthop = self.ns3.Ipv4Address(nexthop)
503
504         # For all the interfaces registered with the ipv4 object, find
505         # the one that matches the network of the nexthop
506         nifaces = ipv4.GetNInterfaces()
507         for ifidx in xrange(nifaces):
508             iface = ipv4.GetInterface(ifidx)
509             naddress = iface.GetNAddresses()
510             for addridx in xrange(naddress):
511                 ifaddr = iface.GetAddress(addridx)
512                 ifmask = ifaddr.GetMask()
513                 
514                 ifindex = ipv4.GetInterfaceForPrefix(nexthop, ifmask)
515
516                 if ifindex == ifidx:
517                     return ifindex
518         return ifindex
519