NS3 Wrapper test
[nepi.git] / src / nepi / resources / ns3 / ns3wrapper.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License as published by
7 #    the Free Software Foundation, either version 3 of the License, or
8 #    (at your option) any later version.
9 #
10 #    This program is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #    GNU General Public License for more details.
14 #
15 #    You should have received a copy of the GNU General Public License
16 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
19
20 import logging
21 import os
22 import sys
23 import threading
24 import uuid
25
26 class NS3Wrapper(object):
27     def __init__(self, homedir = None):
28         super(NS3Wrapper, self).__init__()
29         # Thread used to run the simulation
30         self._simulation_thread = None
31         self._condition = None
32
33         self._started = False
34         self._stopped = False
35
36         # holds reference to all C++ objects and variables in the simulation
37         self._objects = dict()
38
39         # holds the class identifiers of uuid to be able to retrieve
40         # the corresponding ns3 TypeId to set/get attributes.
41         # This is necessary because the method GetInstanceTypeId is not
42         # exposed through the Python bindings
43         self._tids = dict()
44
45         # Generate unique identifier for the simulation wrapper 
46         self._uuid = self.make_uuid()
47
48         # create home dir (where all simulation related files will end up)
49         self._homedir = homedir or os.path.join("/tmp", self.uuid)
50         
51         home = os.path.normpath(self.homedir)
52         if not os.path.exists(home):
53             os.makedirs(home, 0755)
54
55         # Logging
56         loglevel = os.environ.get("NS3LOGLEVEL", "debug")
57         self._logger = logging.getLogger("ns3wrapper.%s" % self.uuid)
58         self._logger.setLevel(getattr(logging, loglevel.upper()))
59         
60         hdlr = logging.FileHandler(os.path.join(self.homedir, "ns3wrapper.log"))
61         formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
62         hdlr.setFormatter(formatter)
63         
64         self._logger.addHandler(hdlr) 
65
66         # Python module to refernce all ns-3 classes and types
67         self._ns3 = None
68
69         # Load ns-3 shared libraries and import modules
70         self._load_ns3_module()
71         
72         # Add module as anoter object, so we can reference it later
73         self._objects[self.uuid] = self.ns3
74         
75     @property
76     def ns3(self):
77         return self._ns3
78
79     @property
80     def homedir(self):
81         return self._homedir
82
83     @property
84     def uuid(self):
85         return self._uuid
86
87     @property
88     def logger(self):
89         return self._logger
90
91     def make_uuid(self):
92         return "uuid%s" % uuid.uuid4()
93
94     def is_running(self):
95         return self._started and not self._stopped
96
97     def get_object(self, uuid):
98         return self._objects.get(uuid)
99
100     def get_typeid(self, uuid):
101         return self._tids.get(uuid)
102
103     def singleton(self, clazzname):
104         uuid = "uuid%s"%clazzname
105
106         if not uuid in self._objects:
107             if not hasattr(self.ns3, clazzname):
108                 msg = "Type %s not supported" % (typeid)
109                 self.logger.error(msg)
110
111             clazz = getattr(self.ns3, clazzname)
112             self._objects[uuid] = clazz
113
114             typeid = "ns3::%s" % clazzname
115             self._tids[uuid] = typeid
116
117         return uuid
118
119     def create(self, clazzname, *args):
120         if not hasattr(self.ns3, clazzname):
121             msg = "Type %s not supported" % (clazzname) 
122             self.logger.error(msg)
123
124         realargs = [self.get_object(arg) if \
125                 str(arg).startswith("uuid") else arg for arg in args]
126       
127         clazz = getattr(self.ns3, clazzname)
128         obj = clazz(*realargs)
129         
130         uuid = self.make_uuid()
131         self._objects[uuid] = obj
132
133         #typeid = clazz.GetInstanceTypeId().GetName()
134         typeid = "ns3::%s" % clazzname
135         self._tids[uuid] = typeid
136
137         return uuid
138
139     def invoke(self, uuid, operation, *args):
140         obj = self.get_object(uuid)
141         
142         method = getattr(obj, operation)
143
144         # arguments starting with 'uuid' identifie stored
145         # objects and must be replaced by the actual object
146         realargs = [self.get_object(arg) if \
147                 str(arg).startswith("uuid") else arg for arg in args]
148
149         result = method(*realargs)
150
151         if not result:
152             return None
153         
154         newuuid = self.make_uuid()
155         self._objects[newuuid] = result
156
157         return newuuid
158
159     def set(self, uuid, name, value):
160         obj = self.get_object(uuid)
161         ns3_value = self._to_ns3_value(uuid, name, value)
162
163         def set_attr(obj, name, ns3_value):
164             obj.SetAttribute(name, ns3_value)
165
166         # If the Simulation thread is not running,
167         # then there will be no thread-safety problems
168         # in changing the value of an attribute directly.
169         # However, if the simulation is running we need
170         # to set the value by scheduling an event, else
171         # we risk to corrupt the state of the
172         # simulation.
173         if self._is_running:
174             # schedule the event in the Simulator
175             self._schedule_event(self._condition, set_attr, obj,
176                     name, ns3_value)
177         else:
178             set_attr(obj, name, ns3_value)
179
180     def get(self, uuid, name):
181         obj = self.get_object(uuid)
182         ns3_value = self._create_ns3_value(uuid, name)
183
184         def get_attr(obj, name, ns3_value):
185             obj.GetAttribute(name, ns3_value)
186
187         if self._is_running:
188             # schedule the event in the Simulator
189             self._schedule_event(self._condition, get_attr, obj,
190                     name, ns3_value)
191         else:
192             get_attr(obj, name, ns3_value)
193
194         return self._from_ns3_value(uuid, name, ns3_value)
195
196     def start(self):
197         # Launch the simulator thread and Start the
198         # simulator in that thread
199         self._condition = threading.Condition()
200         self._simulator_thread = threading.Thread(
201                 target = self._simulator_run,
202                 args = [self._condition])
203         self._simulator_thread.setDaemon(True)
204         self._simulator_thread.start()
205         self._started = True
206
207     def stop(self, time = None):
208         if not self.ns3:
209             return
210
211         if time is None:
212             self.ns3.Simulator.Stop()
213         else:
214             self.ns3.Simulator.Stop(self.ns3.Time(time))
215         self._stopped = True
216
217     def shutdown(self):
218         if self.ns3:
219             if not self.ns3.Simulator.IsFinished():
220                 self.stop()
221             
222             # TODO!!!! SHOULD WAIT UNTIL THE THREAD FINISHES
223             if self._simulator_thread:
224                 self._simulator_thread.join()
225             
226             self.ns3.Simulator.Destroy()
227         
228         # Remove all references to ns-3 objects
229         self._objects.clear()
230         
231         self._ns3 = None
232         sys.stdout.flush()
233         sys.stderr.flush()
234
235     def _simulator_run(self, condition):
236         # Run simulation
237         self.ns3.Simulator.Run()
238         # Signal condition to indicate simulation ended and
239         # notify waiting threads
240         condition.acquire()
241         condition.notifyAll()
242         condition.release()
243
244     def _schedule_event(self, condition, func, *args):
245         """ Schedules event on running simulation, and wait until
246             event is executed"""
247
248         def execute_event(contextId, condition, has_event_occurred, func, *args):
249             try:
250                 func(*args)
251             finally:
252                 # flag event occured
253                 has_event_occurred[0] = True
254                 # notify condition indicating attribute was set
255                 condition.acquire()
256                 condition.notifyAll()
257                 condition.release()
258
259         # contextId is defined as general context
260         contextId = long(0xffffffff)
261
262         # delay 0 means that the event is expected to execute inmediately
263         delay = self.ns3.Seconds(0)
264
265         # flag to indicate that the event occured
266         # because bool is an inmutable object in python, in order to create a
267         # bool flag, a list is used as wrapper
268         has_event_occurred = [False]
269         condition.acquire()
270         try:
271             if not self.ns3.Simulator.IsFinished():
272                 self.ns3.Simulator.ScheduleWithContext(contextId, delay, execute_event,
273                      condition, has_event_occurred, func, *args)
274                 while not has_event_occurred[0] and not self.ns3.Simulator.IsFinished():
275                     condition.wait()
276         finally:
277             condition.release()
278
279     def _create_ns3_value(self, uuid, name):
280         typeid = self.get_typeid(uuid)
281         TypeId = self.ns3.TypeId()
282         tid = TypeId.LookupByName(typeid)
283         info = TypeId.AttributeInformation()
284         if not tid.LookupAttributeByName(name, info):
285             msg = "TypeId %s has no attribute %s" % (typeid, name) 
286             self.logger.error(msg)
287
288         checker = info.checker
289         ns3_value = checker.Create() 
290         return ns3_value
291
292     def _from_ns3_value(self, uuid, name, ns3_value):
293         typeid = self.get_typeid(uuid)
294         TypeId = self.ns3.TypeId()
295         tid = TypeId.LookupByName(typeid)
296         info = TypeId.AttributeInformation()
297         if not tid.LookupAttributeByName(name, info):
298             msg = "TypeId %s has no attribute %s" % (typeid, name) 
299             self.logger.error(msg)
300
301         checker = info.checker
302         value = ns3_value.SerializeToString(checker)
303
304         type_name = checker.GetValueTypeName()
305         if type_name in ["ns3::UintegerValue", "ns3::IntegerValue"]:
306             return int(value)
307         if type_name == "ns3::DoubleValue":
308             return float(value)
309         if type_name == "ns3::BooleanValue":
310             return value == "true"
311
312         return value
313
314     def _to_ns3_value(self, uuid, name, value):
315         typeid = self.get_typeid(uuid)
316         TypeId = self.ns3.TypeId()
317         tid = TypeId.LookupByName(typeid)
318         info = TypeId.AttributeInformation()
319         if not tid.LookupAttributeByName(name, info):
320             msg = "TypeId %s has no attribute %s" % (typeid, name) 
321             self.logger.error(msg)
322
323         str_value = str(value)
324         if isinstance(value, bool):
325             str_value = str_value.lower()
326
327         checker = info.checker
328         ns3_value = checker.Create()
329         ns3_value.DeserializeFromString(str_value, checker)
330         return ns3_value
331
332     def _load_ns3_module(self):
333         if self.ns3:
334             return 
335
336         import ctypes
337         import imp
338         import re
339         import pkgutil
340
341         bindings = os.environ.get("NS3BINDINGS")
342         libdir = os.environ.get("NS3LIBRARIES")
343
344         # Load the ns-3 modules shared libraries
345         if libdir:
346             files = os.listdir(libdir)
347             regex = re.compile("(.*\.so)$")
348             libs = [m.group(1) for filename in files for m in [regex.search(filename)] if m]
349
350             libscp = list(libs)
351             while len(libs) > 0:
352                 for lib in libs:
353                     libfile = os.path.join(libdir, lib)
354                     try:
355                         ctypes.CDLL(libfile, ctypes.RTLD_GLOBAL)
356                         libs.remove(lib)
357                     except:
358                         pass
359
360                 # if did not load any libraries in the last iteration break
361                 # to prevent infinit loop
362                 if len(libscp) == len(libs):
363                     raise RuntimeError("Imposible to load shared libraries %s" % str(libs))
364                 libscp = list(libs)
365
366         # import the python bindings for the ns-3 modules
367         if bindings:
368             sys.path.append(bindings)
369
370         # create a module to add all ns3 classes
371         ns3mod = imp.new_module("ns3")
372         sys.modules["ns3"] = ns3mod
373
374         # retrieve all ns3 classes and add them to the ns3 module
375         import ns
376
377         for importer, modname, ispkg in pkgutil.iter_modules(ns.__path__):
378             fullmodname = "ns.%s" % modname
379             module = __import__(fullmodname, globals(), locals(), ['*'])
380
381             for sattr in dir(module):
382                 if sattr.startswith("_"):
383                     continue
384
385                 attr = getattr(module, sattr)
386
387                 # netanim.Config and lte.Config singleton overrides ns3::Config
388                 if sattr == "Config" and modname in ['netanim', 'lte']:
389                     sattr = "%s.%s" % (modname, sattr)
390
391                 setattr(ns3mod, sattr, attr)
392
393         self._ns3 = ns3mod
394