Added LICENSE
[nepi.git] / src / nepi / resources / ns3 / ns3wrapper.py
1 """
2     NEPI, a framework to manage network experiments
3     Copyright (C) 2013 INRIA
4
5     This program is free software: you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation, either version 3 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License
16     along with this program.  If not, see <http://www.gnu.org/licenses/>.
17
18 """
19
20 import logging
21 import os
22 import sys
23 import threading
24 import uuid
25
26 class NS3Wrapper(object):
27     def __init__(self, homedir = None):
28         super(NS3Wrapper, self).__init__()
29         self._ns3 = None
30         self._uuid = self.make_uuid()
31         self._homedir = homedir or os.path.join("/tmp", self._uuid)
32         self._simulation_thread = None
33         self._condition = None
34
35         self._started = False
36         self._stopped = False
37
38         # holds reference to all ns-3 objects in the simulation
39         self._resources = dict()
40
41         # create home dir (where all simulation related files will end up)
42         home = os.path.normpath(self.homedir)
43         if not os.path.exists(home):
44             os.makedirs(home, 0755)
45
46         # Logging
47         loglevel = os.environ.get("NS3LOGLEVEL", "debug")
48         self._logger = logging.getLogger("ns3wrapper.%s" % self.uuid)
49         self._logger.setLevel(getattr(logging, loglevel.upper()))
50         hdlr = logging.FileHandler(os.path.join(self.homedir, "ns3wrapper.log"))
51         formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
52         hdlr.setFormatter(formatter)
53         self._logger.addHandler(hdlr) 
54
55         # Load ns-3 shared libraries and import modules
56         self._load_ns3_module()
57         
58     @property
59     def ns3(self):
60         return self._ns3
61
62     @property
63     def homedir(self):
64         return self._homedir
65
66     @property
67     def uuid(self):
68         return self._uuid
69
70     @property
71     def logger(self):
72         return self._logger
73
74     def make_uuid(self):
75         return "uuid%s" % uuid.uuid4()
76
77     def singleton(self, clazzname):
78         uuid = "uuid%s"%clazzname
79
80         if not uuid in self._resources:
81             if not hasattr(self.ns3, clazzname):
82                 msg = "Type %s not supported" % (typeid) 
83                 self.logger.error(msg)
84
85             clazz = getattr(self.ns3, clazzname)
86             typeid = "ns3::%s" % clazzname
87             self._resources[uuid] = (clazz, typeid)
88
89         return uuid
90
91     def get_trace(self, trace, offset = None, nbytes = None ):
92         pass
93
94     def is_running(self):
95         return self._started and not self._stopped
96
97     def get_resource(self, uuid):
98         (resource, typeid) =  self._resources.get(uuid)
99         return resource
100     
101     def get_typeid(self, uuid):
102         (resource, typeid) =  self._resources.get(uuid)
103         return typeid
104
105     def create(self, clazzname, *args):
106         if not hasattr(self.ns3, clazzname):
107             msg = "Type %s not supported" % (clazzname) 
108             self.logger.error(msg)
109
110         clazz = getattr(self.ns3, clazzname)
111         #typeid = clazz.GetInstanceTypeId().GetName()
112         typeid = "ns3::%s" % clazzname
113
114         realargs = [self.get_resource(arg) if \
115                 str(arg).startswith("uuid") else arg for arg in args]
116       
117         resource = clazz(*realargs)
118         
119         uuid = self.make_uuid()
120         self._resources[uuid] = (resource, typeid)
121         return uuid
122
123     def set(self, uuid, name, value):
124         resource = self.get_resource(uuid)
125
126         if hasattr(resource, name):
127             setattr(resource, name, value)
128         else:
129             self._set_ns3_attr(uuid, name, value)
130
131     def get(self, name, uuid = None):
132         resource = self.get_resource(uuid)
133
134         value = None
135         if hasattr(resource, name):
136             value = getattr(resource, name)
137         else:
138             value = self._get_ns3_attr(uuid, name)
139
140         return value
141
142     def invoke(self, uuid, operation, *args):
143         resource = self.get_resource(uuid)
144         typeid = self.get_typeid(uuid)
145         method = getattr(resource, operation)
146
147         realargs = [self.get_resource(arg) if \
148                 str(arg).startswith("uuid") else arg for arg in args]
149
150         result = method(*realargs)
151
152         if not result:
153             return None
154         
155         uuid = self.make_uuid()
156         self._resources[uuid] = (result, typeid)
157
158         return uuid
159
160     def start(self):
161         self._condition = threading.Condition()
162         self._simulator_thread = threading.Thread(
163                 target = self._simulator_run,
164                 args = [self._condition])
165         self._simulator_thread.setDaemon(True)
166         self._simulator_thread.start()
167         self._started = True
168
169     def stop(self, time = None):
170         if not self.ns3:
171             return
172
173         if time is None:
174             self.ns3.Simulator.Stop()
175         else:
176             self.ns3.Simulator.Stop(self.ns3.Time(time))
177         self._stopped = True
178
179     def shutdown(self):
180         if self.ns3:
181             if not self.ns3.Simulator.IsFinished():
182                 self.stop()
183             
184             # TODO!!!! SHOULD WAIT UNTIL THE THREAD FINISHES
185             if self._simulator_thread:
186                 self._simulator_thread.join()
187             
188             self.ns3.Simulator.Destroy()
189         
190         self._resources.clear()
191         
192         self._ns3 = None
193         sys.stdout.flush()
194         sys.stderr.flush()
195
196     def _simulator_run(self, condition):
197         # Run simulation
198         self.ns3.Simulator.Run()
199         # Signal condition on simulation end to notify waiting threads
200         condition.acquire()
201         condition.notifyAll()
202         condition.release()
203
204     def _schedule_event(self, condition, func, *args):
205         """ Schedules event on running simulation, and wait until
206             event is executed"""
207
208         def execute_event(contextId, condition, has_event_occurred, func, *args):
209             try:
210                 func(*args)
211             finally:
212                 # flag event occured
213                 has_event_occurred[0] = True
214                 # notify condition indicating attribute was set
215                 condition.acquire()
216                 condition.notifyAll()
217                 condition.release()
218
219         # contextId is defined as general context
220         contextId = long(0xffffffff)
221
222         # delay 0 means that the event is expected to execute inmediately
223         delay = self.ns3.Seconds(0)
224
225         # flag to indicate that the event occured
226         # because bool is an inmutable object in python, in order to create a
227         # bool flag, a list is used as wrapper
228         has_event_occurred = [False]
229         condition.acquire()
230         try:
231             if not self.ns3.Simulator.IsFinished():
232                 self.ns3.Simulator.ScheduleWithContext(contextId, delay, execute_event,
233                      condition, has_event_occurred, func, *args)
234                 while not has_event_occurred[0] and not self.ns3.Simulator.IsFinished():
235                     condition.wait()
236         finally:
237             condition.release()
238
239     def _set_ns3_attr(self, uuid, name, value):
240         resource = self.get_resource(uuid)
241         ns3_value = self._to_ns3_value(uuid, name, value)
242
243         def set_attr(resource, name, ns3_value):
244             resource.SetAttribute(name, ns3_value)
245
246         if self._is_running:
247             # schedule the event in the Simulator
248             self._schedule_event(self._condition, set_attr, resource,
249                     name, ns3_value)
250         else:
251             set_attr(resource, name, ns3_value)
252
253     def _get_ns3_attr(self, uuid, name):
254         resource = self.get_resource(uuid)
255         ns3_value = self._create_ns3_value(uuid, name)
256
257         def get_attr(resource, name, ns3_value):
258             resource.GetAttribute(name, ns3_value)
259
260         if self._is_running:
261             # schedule the event in the Simulator
262             self._schedule_event(self._condition, get_attr, resource,
263                     name, ns3_value)
264         else:
265             get_attr(resource, name, ns3_value)
266
267         return self._from_ns3_value(uuid, name, ns3_value)
268
269     def _create_ns3_value(self, uuid, name):
270         typeid = get_typeid(uuid)
271         TypeId = self.ns3.TypeId()
272         tid = TypeId.LookupByName(typeid)
273         info = TypeId.AttributeInformation()
274         if not tid.LookupAttributeByName(name, info):
275             msg = "TypeId %s has no attribute %s" % (typeid, name) 
276             self.logger.error(msg)
277
278         checker = info.checker
279         ns3_value = checker.Create() 
280         return ns3_value
281
282     def _from_ns3_value(self, uuid, name, ns3_value):
283         typeid = get_typeid(uuid)
284         TypeId = self.ns3.TypeId()
285         tid = TypeId.LookupByName(typeid)
286         info = TypeId.AttributeInformation()
287         if not tid.LookupAttributeByName(name, info):
288             msg = "TypeId %s has no attribute %s" % (typeid, name) 
289             self.logger.error(msg)
290
291         checker = info.checker
292         value = ns3_value.SerializeToString(checker)
293
294         type_name = checker.GetValueTypeName()
295         if type_name in ["ns3::UintegerValue", "ns3::IntegerValue"]:
296             return int(value)
297         if type_name == "ns3::DoubleValue":
298             return float(value)
299         if type_name == "ns3::BooleanValue":
300             return value == "true"
301
302         return value
303
304     def _to_ns3_value(self, uuid, name, value):
305         typeid = get_typeid(uuid)
306         TypeId = self.ns3.TypeId()
307         typeid = TypeId.LookupByName(typeid)
308         info = TypeId.AttributeInformation()
309         if not tid.LookupAttributeByName(name, info):
310             msg = "TypeId %s has no attribute %s" % (typeid, name) 
311             self.logger.error(msg)
312
313         str_value = str(value)
314         if isinstance(value, bool):
315             str_value = str_value.lower()
316
317         checker = info.checker
318         ns3_value = checker.Create()
319         ns3_value.DeserializeFromString(str_value, checker)
320         return ns3_value
321
322     def _load_ns3_module(self):
323         if self.ns3:
324             return 
325
326         import ctypes
327         import imp
328         import re
329         import pkgutil
330
331         bindings = os.environ.get("NS3BINDINGS")
332         libdir = os.environ.get("NS3LIBRARIES")
333
334         # Load the ns-3 modules shared libraries
335         if libdir:
336             files = os.listdir(libdir)
337             regex = re.compile("(.*\.so)$")
338             libs = [m.group(1) for filename in files for m in [regex.search(filename)] if m]
339
340             libscp = list(libs)
341             while len(libs) > 0:
342                 for lib in libs:
343                     libfile = os.path.join(libdir, lib)
344                     try:
345                         ctypes.CDLL(libfile, ctypes.RTLD_GLOBAL)
346                         libs.remove(lib)
347                     except:
348                         pass
349
350                 # if did not load any libraries in the last iteration break
351                 # to prevent infinit loop
352                 if len(libscp) == len(libs):
353                     raise RuntimeError("Imposible to load shared libraries %s" % str(libs))
354                 libscp = list(libs)
355
356         # import the python bindings for the ns-3 modules
357         if bindings:
358             sys.path.append(bindings)
359
360         # create a module to add all ns3 classes
361         ns3mod = imp.new_module("ns3")
362         sys.modules["ns3"] = ns3mod
363
364         # retrieve all ns3 classes and add them to the ns3 module
365         import ns
366         for importer, modname, ispkg in pkgutil.iter_modules(ns.__path__):
367             fullmodname = "ns.%s" % modname
368             module = __import__(fullmodname, globals(), locals(), ['*'])
369
370             # netanim.Config singleton overrides ns3::Config
371             if modname in ['netanim']:
372                 continue
373
374             for sattr in dir(module):
375                 if not sattr.startswith("_"):
376                     attr = getattr(module, sattr)
377                     setattr(ns3mod, sattr, attr)
378
379         self._ns3 = ns3mod
380