Fix #126 Routes configuration
[nepi.git] / src / nepi / resources / ns3 / ns3wrapper.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License as published by
7 #    the Free Software Foundation, either version 3 of the License, or
8 #    (at your option) any later version.
9 #
10 #    This program is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #    GNU General Public License for more details.
14 #
15 #    You should have received a copy of the GNU General Public License
16 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
19
20 import logging
21 import os
22 import sys
23 import threading
24 import time
25 import uuid
26
27 SINGLETON = "singleton::"
28 SIMULATOR_UUID = "singleton::Simulator"
29 CONFIG_UUID = "singleton::Config"
30 GLOBAL_VALUE_UUID = "singleton::GlobalValue"
31
32 def load_ns3_module():
33     import ctypes
34     import re
35
36     bindings = os.environ.get("NS3BINDINGS")
37     libdir = os.environ.get("NS3LIBRARIES")
38
39     # Load the ns-3 modules shared libraries
40     if libdir:
41         files = os.listdir(libdir)
42         regex = re.compile("(.*\.so)$")
43         libs = [m.group(1) for filename in files for m in [regex.search(filename)] if m]
44
45         initial_size = len(libs)
46         # Try to load the libraries in the right order by trial and error.
47         # Loop until all libraries are loaded.
48         while len(libs) > 0:
49             for lib in libs:
50                 libfile = os.path.join(libdir, lib)
51                 try:
52                     ctypes.CDLL(libfile, ctypes.RTLD_GLOBAL)
53                     libs.remove(lib)
54                 except:
55                     #import traceback
56                     #err = traceback.format_exc()
57                     #print err
58                     pass
59
60             # if did not load any libraries in the last iteration break
61             # to prevent infinit loop
62             if initial_size == len(libs):
63                 raise RuntimeError("Imposible to load shared libraries %s" % str(libs))
64             initial_size = list(libs)
65
66     # import the python bindings for the ns-3 modules
67     if bindings:
68         sys.path.append(bindings)
69
70     import pkgutil
71     import imp
72     import ns
73
74     # create a Python module to add all ns3 classes
75     ns3mod = imp.new_module("ns3")
76     sys.modules["ns3"] = ns3mod
77
78     for importer, modname, ispkg in pkgutil.iter_modules(ns.__path__):
79         if modname in [ "visualizer" ]:
80             continue
81
82         fullmodname = "ns.%s" % modname
83         module = __import__(fullmodname, globals(), locals(), ['*'])
84
85         for sattr in dir(module):
86             if sattr.startswith("_"):
87                 continue
88
89             attr = getattr(module, sattr)
90
91             # netanim.Config and lte.Config singleton overrides ns3::Config
92             if sattr == "Config" and modname in ['netanim', 'lte']:
93                 sattr = "%s.%s" % (modname, sattr)
94
95             setattr(ns3mod, sattr, attr)
96
97     return ns3mod
98
99 class NS3Wrapper(object):
100     def __init__(self, loglevel = logging.INFO):
101         super(NS3Wrapper, self).__init__()
102         # Thread used to run the simulation
103         self._simulation_thread = None
104         self._condition = None
105
106         # True if Simulator::Run was invoked
107         self._started = False
108
109         # holds reference to all C++ objects and variables in the simulation
110         self._objects = dict()
111
112         # Logging
113         self._logger = logging.getLogger("ns3wrapper")
114         self._logger.setLevel(loglevel)
115
116         ## NOTE that the reason to create a handler to the ns3 module,
117         # that is re-loaded each time a ns-3 wrapper is instantiated,
118         # is that else each unit test for the ns3wrapper class would need
119         # a separate file. Several ns3wrappers would be created in the 
120         # same unit test (single process), leading to inchorences in the 
121         # state of ns-3 global objects
122         #
123         # Handler to ns3 classes
124         self._ns3 = None
125
126         # Collection of allowed ns3 classes
127         self._allowed_types = None
128
129     @property
130     def ns3(self):
131         if not self._ns3:
132             # load ns-3 libraries and bindings
133             self._ns3 = load_ns3_module()
134
135         return self._ns3
136
137     @property
138     def allowed_types(self):
139         if not self._allowed_types:
140             self._allowed_types = set()
141             type_id = self.ns3.TypeId()
142             
143             tid_count = type_id.GetRegisteredN()
144             base = type_id.LookupByName("ns3::Object")
145
146             for i in xrange(tid_count):
147                 tid = type_id.GetRegistered(i)
148                 
149                 if tid.MustHideFromDocumentation() or \
150                         not tid.HasConstructor() or \
151                         not tid.IsChildOf(base): 
152                     continue
153
154                 type_name = tid.GetName()
155                 self._allowed_types.add(type_name)
156         
157         return self._allowed_types
158
159     @property
160     def logger(self):
161         return self._logger
162
163     @property
164     def is_running(self):
165         return self._started and self.ns3.Simulator.IsFinished()
166
167     def make_uuid(self):
168         return "uuid%s" % uuid.uuid4()
169
170     def get_object(self, uuid):
171         return self._objects.get(uuid)
172
173     def factory(self, type_name, **kwargs):
174         if type_name not in self.allowed_types:
175             msg = "Type %s not supported" % (type_name) 
176             self.logger.error(msg)
177  
178         factory = self.ns3.ObjectFactory()
179         factory.SetTypeId(type_name)
180
181         for name, value in kwargs.iteritems():
182             ns3_value = self._attr_from_string_to_ns3_value(type_name, name, value)
183             factory.Set(name, ns3_value)
184
185         obj = factory.Create()
186
187         uuid = self.make_uuid()
188         self._objects[uuid] = obj
189
190         return uuid
191
192     def create(self, clazzname, *args):
193         if not hasattr(self.ns3, clazzname):
194             msg = "Type %s not supported" % (clazzname) 
195             self.logger.error(msg)
196      
197         clazz = getattr(self.ns3, clazzname)
198  
199         # arguments starting with 'uuid' identify ns-3 C++
200         # objects and must be replaced by the actual object
201         realargs = self.replace_args(args)
202        
203         obj = clazz(*realargs)
204         
205         uuid = self.make_uuid()
206         self._objects[uuid] = obj
207
208         return uuid
209
210     def invoke(self, uuid, operation, *args, **kwargs):
211         if operation == "isAppRunning":
212             return self._is_app_running(uuid)
213         if operation == "addStaticRoute":
214             return self._add_static_route(uuid, *args)
215
216         if uuid.startswith(SINGLETON):
217             obj = self._singleton(uuid)
218         else:
219             obj = self.get_object(uuid)
220         
221         method = getattr(obj, operation)
222
223         # arguments starting with 'uuid' identify ns-3 C++
224         # objects and must be replaced by the actual object
225         realargs = self.replace_args(args)
226         realkwargs = self.replace_kwargs(kwargs)
227
228         result = method(*realargs, **realkwargs)
229
230         if result is None or \
231                 isinstance(result, bool):
232             return result
233       
234         newuuid = self.make_uuid()
235         self._objects[newuuid] = result
236
237         return newuuid
238
239     def _set_attr(self, obj, name, ns3_value):
240         obj.SetAttribute(name, ns3_value)
241
242     def set(self, uuid, name, value):
243         obj = self.get_object(uuid)
244         type_name = obj.GetInstanceTypeId().GetName()
245         ns3_value = self._attr_from_string_to_ns3_value(type_name, name, value)
246
247         # If the Simulation thread is not running,
248         # then there will be no thread-safety problems
249         # in changing the value of an attribute directly.
250         # However, if the simulation is running we need
251         # to set the value by scheduling an event, else
252         # we risk to corrupt the state of the
253         # simulation.
254         
255         event_executed = [False]
256
257         if self.is_running:
258             # schedule the event in the Simulator
259             self._schedule_event(self._condition, event_executed, 
260                     self._set_attr, obj, name, ns3_value)
261
262         if not event_executed[0]:
263             self._set_attr(obj, name, ns3_value)
264
265         return value
266
267     def _get_attr(self, obj, name, ns3_value):
268         obj.GetAttribute(name, ns3_value)
269
270     def get(self, uuid, name):
271         obj = self.get_object(uuid)
272         type_name = obj.GetInstanceTypeId().GetName()
273         ns3_value = self._create_attr_ns3_value(type_name, name)
274
275         event_executed = [False]
276
277         if self.is_running:
278             # schedule the event in the Simulator
279             self._schedule_event(self._condition, event_executed,
280                     self._get_attr, obj, name, ns3_value)
281
282         if not event_executed[0]:
283             self._get_attr(obj, name, ns3_value)
284
285         return self._attr_from_ns3_value_to_string(type_name, name, ns3_value)
286
287     def start(self):
288         # Launch the simulator thread and Start the
289         # simulator in that thread
290         self._condition = threading.Condition()
291         self._simulator_thread = threading.Thread(
292                 target = self._simulator_run,
293                 args = [self._condition])
294         self._simulator_thread.setDaemon(True)
295         self._simulator_thread.start()
296         self._started = True
297
298     def stop(self, time = None):
299         if time is None:
300             self.ns3.Simulator.Stop()
301         else:
302             self.ns3.Simulator.Stop(self.ns3.Time(time))
303
304     def shutdown(self):
305         while not self.ns3.Simulator.IsFinished():
306             #self.logger.debug("Waiting for simulation to finish")
307             time.sleep(0.5)
308         
309         if self._simulator_thread:
310             self._simulator_thread.join()
311        
312         self.ns3.Simulator.Destroy()
313         
314         # Remove all references to ns-3 objects
315         self._objects.clear()
316         
317         sys.stdout.flush()
318         sys.stderr.flush()
319
320     def _simulator_run(self, condition):
321         # Run simulation
322         self.ns3.Simulator.Run()
323         # Signal condition to indicate simulation ended and
324         # notify waiting threads
325         condition.acquire()
326         condition.notifyAll()
327         condition.release()
328
329     def _schedule_event(self, condition, event_executed, func, *args):
330         """ Schedules event on running simulation, and wait until
331             event is executed"""
332
333         def execute_event(contextId, condition, event_executed, func, *args):
334             try:
335                 func(*args)
336                 event_executed[0] = True
337             finally:
338                 # notify condition indicating event was executed
339                 condition.acquire()
340                 condition.notifyAll()
341                 condition.release()
342
343         # contextId is defined as general context
344         contextId = long(0xffffffff)
345
346         # delay 0 means that the event is expected to execute inmediately
347         delay = self.ns3.Seconds(0)
348     
349         # Mark event as not executed
350         event_executed[0] = False
351
352         condition.acquire()
353         try:
354             self.ns3.Simulator.ScheduleWithContext(contextId, delay, execute_event, 
355                     condition, event_executed, func, *args)
356             if not self.ns3.Simulator.IsFinished():
357                 condition.wait()
358         finally:
359             condition.release()
360
361     def _create_attr_ns3_value(self, type_name, name):
362         TypeId = self.ns3.TypeId()
363         tid = TypeId.LookupByName(type_name)
364         info = TypeId.AttributeInformation()
365         if not tid.LookupAttributeByName(name, info):
366             msg = "TypeId %s has no attribute %s" % (type_name, name) 
367             self.logger.error(msg)
368
369         checker = info.checker
370         ns3_value = checker.Create() 
371         return ns3_value
372
373     def _attr_from_ns3_value_to_string(self, type_name, name, ns3_value):
374         TypeId = self.ns3.TypeId()
375         tid = TypeId.LookupByName(type_name)
376         info = TypeId.AttributeInformation()
377         if not tid.LookupAttributeByName(name, info):
378             msg = "TypeId %s has no attribute %s" % (type_name, name) 
379             self.logger.error(msg)
380
381         checker = info.checker
382         value = ns3_value.SerializeToString(checker)
383
384         type_name = checker.GetValueTypeName()
385         if type_name in ["ns3::UintegerValue", "ns3::IntegerValue"]:
386             return int(value)
387         if type_name == "ns3::DoubleValue":
388             return float(value)
389         if type_name == "ns3::BooleanValue":
390             return value == "true"
391
392         return value
393
394     def _attr_from_string_to_ns3_value(self, type_name, name, value):
395         TypeId = self.ns3.TypeId()
396         tid = TypeId.LookupByName(type_name)
397         info = TypeId.AttributeInformation()
398         if not tid.LookupAttributeByName(name, info):
399             msg = "TypeId %s has no attribute %s" % (type_name, name) 
400             self.logger.error(msg)
401
402         str_value = str(value)
403         if isinstance(value, bool):
404             str_value = str_value.lower()
405
406         checker = info.checker
407         ns3_value = checker.Create()
408         ns3_value.DeserializeFromString(str_value, checker)
409         return ns3_value
410
411     # singletons are identified as "ns3::ClassName"
412     def _singleton(self, ident):
413         if not ident.startswith(SINGLETON):
414             return None
415
416         clazzname = ident[ident.find("::")+2:]
417         if not hasattr(self.ns3, clazzname):
418             msg = "Type %s not supported" % (clazzname)
419             self.logger.error(msg)
420
421         return getattr(self.ns3, clazzname)
422
423     # replace uuids and singleton references for the real objects
424     def replace_args(self, args):
425         realargs = [self.get_object(arg) if \
426                 str(arg).startswith("uuid") else arg for arg in args]
427  
428         realargs = [self._singleton(arg) if \
429                 str(arg).startswith(SINGLETON) else arg for arg in realargs]
430
431         return realargs
432
433     # replace uuids and singleton references for the real objects
434     def replace_kwargs(self, kwargs):
435         realkwargs = dict([(k, self.get_object(v) \
436                 if str(v).startswith("uuid") else v) \
437                 for k,v in kwargs.iteritems()])
438  
439         realkwargs = dict([(k, self._singleton(v) \
440                 if str(v).startswith(SINGLETON) else v )\
441                 for k, v in realkwargs.iteritems()])
442
443         return realkwargs
444
445     def _is_app_running(self, uuid): 
446         now = self.ns3.Simulator.Now()
447         if now.IsZero():
448             return False
449
450         app = self.get_object(uuid)
451         stop_time_value = self.ns3.TimeValue()
452         app.GetAttribute("StopTime", stop_time_value)
453         stop_time = stop_time_value.Get()
454
455         start_time_value = self.ns3.TimeValue()
456         app.GetAttribute("StartTime", start_time_value)
457         start_time = start_time_value.Get()
458         
459         if now.Compare(start_time) >= 0 and now.Compare(stop_time) < 0:
460             return True
461
462         return False
463
464     def _add_static_route(self, ipv4_uuid, network, prefix, nexthop):
465         ipv4 = self.get_object(ipv4_uuid)
466
467         list_routing = ipv4.GetRoutingProtocol()
468         (static_routing, priority) = list_routing.GetRoutingProtocol(0)
469
470         ifindex = self._find_ifindex(ipv4, nexthop)
471         if ifindex == -1:
472             return False
473         
474         nexthop = self.ns3.Ipv4Address(nexthop)
475
476         if network in ["0.0.0.0", "0", None]:
477             # Default route: 0.0.0.0/0
478             static_routing.SetDefaultRoute(nexthop, ifindex)
479         else:
480             mask = self.ns3.Ipv4Mask("/%s" % prefix) 
481             network = self.ns3.Ipv4Address(network)
482
483             if prefix == 32:
484                 # Host route: x.y.z.w/32
485                 static_routing.AddHostRouteTo(network, nexthop, ifindex)
486             else:
487                 # Network route: x.y.z.w/n
488                 static_routing.AddNetworkRouteTo(network, mask, nexthop, 
489                         ifindex) 
490         return True
491
492     def _find_ifindex(self, ipv4, nexthop):
493         ifindex = -1
494
495         nexthop = self.ns3.Ipv4Address(nexthop)
496
497         # For all the interfaces registered with the ipv4 object, find
498         # the one that matches the network of the nexthop
499         nifaces = ipv4.GetNInterfaces()
500         for ifidx in xrange(nifaces):
501             iface = ipv4.GetInterface(ifidx)
502             naddress = iface.GetNAddresses()
503             for addridx in xrange(naddress):
504                 ifaddr = iface.GetAddress(addridx)
505                 ifmask = ifaddr.GetMask()
506                 
507                 ifindex = ipv4.GetInterfaceForPrefix(nexthop, ifmask)
508
509                 if ifindex == ifidx:
510                     return ifindex
511         return ifindex
512