Added factroy method to the ns-3 wrapper
authorAlina Quereilhac <alina.quereilhac@inria.fr>
Fri, 31 Jan 2014 19:07:49 +0000 (20:07 +0100)
committerAlina Quereilhac <alina.quereilhac@inria.fr>
Fri, 31 Jan 2014 19:07:49 +0000 (20:07 +0100)
setup.py
src/nepi/resources/linux/ns3/__init__.py [new file with mode: 0644]
src/nepi/resources/linux/ns3/ns3simulator.py
src/nepi/resources/ns3/ns3base.py
src/nepi/resources/ns3/ns3wrapper.py
src/nepi/resources/ns3/resource_manager_generator.py
test/resources/ns3/ns3wrapper.py

index aa9bfc6..27b0f9b 100755 (executable)
--- a/setup.py
+++ b/setup.py
@@ -20,6 +20,7 @@ setup(
             "nepi.resources.all",
             "nepi.resources.linux",
             "nepi.resources.linux.ccn",
+            "nepi.resources.linux.ns3",
             "nepi.resources.netns",
             "nepi.resources.ns3",
             "nepi.resources.omf",
diff --git a/src/nepi/resources/linux/ns3/__init__.py b/src/nepi/resources/linux/ns3/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
index 4c605f7..e4fda7e 100644 (file)
@@ -39,7 +39,7 @@ class LinuxNS3Simulator(LinuxApplication, NS3Simulator):
             "Sets the CCND_MAX_RTE_MICROSEC environmental variable. ",
             flags = Flags.ExecReadOnly)
 
-        cls._register_attribute(debug)
+        cls._register_attribute(max_rte)
 
     def __init__(self, ec, guid):
         super(LinuxApplication, self).__init__(ec, guid)
index c829a4f..3ccd138 100644 (file)
@@ -20,7 +20,7 @@
 from nepi.execution.resource import ResourceManager, clsinit_copy, \
         ResourceState, reschedule_delay
 
-from nepi.resources.ns3.simulator import NS3Simulator
+from nepi.resources.ns3.ns3simulator import NS3Simulator
 
 @clsinit_copy
 class NS3Base(ResourceManager):
index f6c063b..0c58df0 100644 (file)
@@ -24,17 +24,11 @@ import threading
 import time
 import uuid
 
-# TODO: 
-#       1. ns-3 classes should be identified as ns3::clazzname?
-# 
-
 SINGLETON = "singleton::"
 
 def load_ns3_module():
     import ctypes
-    import imp
     import re
-    import pkgutil
 
     bindings = os.environ.get("NS3BINDINGS")
     libdir = os.environ.get("NS3LIBRARIES")
@@ -65,13 +59,14 @@ def load_ns3_module():
     if bindings:
         sys.path.append(bindings)
 
+    import pkgutil
+    import imp
+    import ns
+
     # create a module to add all ns3 classes
     ns3mod = imp.new_module("ns3")
     sys.modules["ns3"] = ns3mod
 
-    # retrieve all ns3 classes and add them to the ns3 module
-    import ns
-
     for importer, modname, ispkg in pkgutil.iter_modules(ns.__path__):
         fullmodname = "ns.%s" % modname
         module = __import__(fullmodname, globals(), locals(), ['*'])
@@ -97,17 +92,13 @@ class NS3Wrapper(object):
         self._simulation_thread = None
         self._condition = None
 
+        # XXX: Started should be global. There is no support for more than
+        # one simulator per process
         self._started = False
 
         # holds reference to all C++ objects and variables in the simulation
         self._objects = dict()
 
-        # holds the class identifiers of uuid to be able to retrieve
-        # the corresponding ns3 TypeId to set/get attributes.
-        # This is necessary because the method GetInstanceTypeId is not
-        # exposed through the Python bindings
-        self._tids = dict()
-
         # create home dir (where all simulation related files will end up)
         self._homedir = homedir or os.path.join("/", "tmp", "ns3_wrapper" )
         
@@ -124,18 +115,52 @@ class NS3Wrapper(object):
         formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
         hdlr.setFormatter(formatter)
         
-        self._logger.addHandler(hdlr) 
-
-        # Python module to refernce all ns-3 classes and types
+        self._logger.addHandler(hdlr)
+
+        ## NOTE that the reason to create a handler to the ns3 module,
+        # that is re-loaded each time a ns-3 wrapper is instantiated,
+        # is that else each unit test for the ns3wrapper class would need
+        # a separate file. Several ns3wrappers would be created in the 
+        # same unit test (single process), leading to inchorences in the 
+        # state of ns-3 global objects
+        #
+        # Handler to ns3 classes
         self._ns3 = None
-        
+
+        # Collection of allowed ns3 classes
+        self._allowed_types = None
+
     @property
     def ns3(self):
         if not self._ns3:
+            # load ns-3 libraries and bindings
             self._ns3 = load_ns3_module()
 
         return self._ns3
 
+    @property
+    def allowed_types(self):
+        if not self._allowed_types:
+            self._allowed_types = set()
+            type_id = self.ns3.TypeId()
+            
+            tid_count = type_id.GetRegisteredN()
+            base = type_id.LookupByName("ns3::Object")
+
+            # Create a .py file using the ns-3 RM template for each ns-3 TypeId
+            for i in xrange(tid_count):
+                tid = type_id.GetRegistered(i)
+                
+                if tid.MustHideFromDocumentation() or \
+                        not tid.HasConstructor() or \
+                        not tid.IsChildOf(base): 
+                    continue
+
+                type_name = tid.GetName()
+                self._allowed_types.add(type_name)
+        
+        return self._allowed_types
+
     @property
     def homedir(self):
         return self._homedir
@@ -146,7 +171,7 @@ class NS3Wrapper(object):
 
     @property
     def is_running(self):
-        return self._started and self._ns3 and not self.ns3.Simulator.IsFinished()
+        return self._started and self.ns3.Simulator.IsFinished()
 
     def make_uuid(self):
         return "uuid%s" % uuid.uuid4()
@@ -154,8 +179,24 @@ class NS3Wrapper(object):
     def get_object(self, uuid):
         return self._objects.get(uuid)
 
-    def get_typeid(self, uuid):
-        return self._tids.get(uuid)
+    def factory(self, type_name, *kwargs):
+        if type_name not in allowed_types:
+            msg = "Type %s not supported" % (type_name) 
+            self.logger.error(msg)
+        factory = self.ns3.ObjectFactory()
+        factory.SetTypeId(type_name)
+
+        for name, value in kwargs.iteritems():
+            ns3_value = self._attr_from_string_to_ns3_value(type_name, name, value)
+            factory.Set(name, ns3_value)
+
+        obj = factory.Create()
+
+        uuid = self.make_uuid()
+        self._objects[uuid] = obj
+
+        return uuid
 
     def create(self, clazzname, *args):
         if not hasattr(self.ns3, clazzname):
@@ -173,10 +214,6 @@ class NS3Wrapper(object):
         uuid = self.make_uuid()
         self._objects[uuid] = obj
 
-        #typeid = clazz.GetInstanceTypeId().GetName()
-        typeid = "ns3::%s" % clazzname
-        self._tids[uuid] = typeid
-
         return uuid
 
     def invoke(self, uuid, operation, *args):
@@ -201,12 +238,13 @@ class NS3Wrapper(object):
 
         return newuuid
 
+    def _set_attr(self, obj, name, ns3_value):
+        obj.SetAttribute(name, ns3_value)
+
     def set(self, uuid, name, value):
         obj = self.get_object(uuid)
-        ns3_value = self._to_ns3_value(uuid, name, value)
-
-        def set_attr(obj, name, ns3_value):
-            obj.SetAttribute(name, ns3_value)
+        type_name = obj.GetInstanceTypeId().GetName()
+        ns3_value = self._attr_from_string_to_ns3_value(type_name, name, value)
 
         # If the Simulation thread is not running,
         # then there will be no thread-safety problems
@@ -217,28 +255,29 @@ class NS3Wrapper(object):
         # simulation.
         if self.is_running:
             # schedule the event in the Simulator
-            self._schedule_event(self._condition, set_attr, obj,
-                    name, ns3_value)
+            self._schedule_event(self._condition, self._set_attr, 
+                    obj, name, ns3_value)
         else:
-            set_attr(obj, name, ns3_value)
+            self._set_attr(obj, name, ns3_value)
 
         return value
 
+    def _get_attr(self, obj, name, ns3_value):
+        obj.GetAttribute(name, ns3_value)
+
     def get(self, uuid, name):
         obj = self.get_object(uuid)
-        ns3_value = self._create_ns3_value(uuid, name)
-
-        def get_attr(obj, name, ns3_value):
-            obj.GetAttribute(name, ns3_value)
+        type_name = obj.GetInstanceTypeId().GetName()
+        ns3_value = self._create_attr_ns3_value(type_name, name)
 
         if self.is_running:
             # schedule the event in the Simulator
-            self._schedule_event(self._condition, get_attr, obj,
+            self._schedule_event(self._condition, self._get_attr, obj,
                     name, ns3_value)
         else:
             get_attr(obj, name, ns3_value)
 
-        return self._from_ns3_value(uuid, name, ns3_value)
+        return self._attr_from_ns3_value_to_string(type_name, name, ns3_value)
 
     def start(self):
         # Launch the simulator thread and Start the
@@ -252,30 +291,25 @@ class NS3Wrapper(object):
         self._started = True
 
     def stop(self, time = None):
-        if not self.ns3:
-            return
-
         if time is None:
             self.ns3.Simulator.Stop()
         else:
             self.ns3.Simulator.Stop(self.ns3.Time(time))
 
     def shutdown(self):
-        if self.ns3:
-            while not self.ns3.Simulator.IsFinished():
-                #self.logger.debug("Waiting for simulation to finish")
-                time.sleep(0.5)
-            
-            # TODO!!!! SHOULD WAIT UNTIL THE THREAD FINISHES
-            if self._simulator_thread:
-                self._simulator_thread.join()
-            
-            self.ns3.Simulator.Destroy()
+        while not self.ns3.Simulator.IsFinished():
+            #self.logger.debug("Waiting for simulation to finish")
+            time.sleep(0.5)
+        
+        # TODO!!!! SHOULD WAIT UNTIL THE THREAD FINISHES
+        if self._simulator_thread:
+            self._simulator_thread.join()
+        
+        self.ns3.Simulator.Destroy()
         
         # Remove all references to ns-3 objects
         self._objects.clear()
         
-        self._ns3 = None
         sys.stdout.flush()
         sys.stderr.flush()
 
@@ -314,35 +348,36 @@ class NS3Wrapper(object):
         # bool flag, a list is used as wrapper
         has_event_occurred = [False]
         condition.acquire()
+
+        simu = self.ns3.Simulator
+
         try:
-            if not self.ns3.Simulator.IsFinished():
-                self.ns3.Simulator.ScheduleWithContext(contextId, delay, execute_event,
+            if not simu.IsFinished():
+                simu.ScheduleWithContext(contextId, delay, execute_event,
                      condition, has_event_occurred, func, *args)
-                while not has_event_occurred[0] and not self.ns3.Simulator.IsFinished():
+                while not has_event_occurred[0] and not simu.IsFinished():
                     condition.wait()
         finally:
             condition.release()
 
-    def _create_ns3_value(self, uuid, name):
-        typeid = self.get_typeid(uuid)
+    def _create_attr_ns3_value(self, type_name, name):
         TypeId = self.ns3.TypeId()
-        tid = TypeId.LookupByName(typeid)
+        tid = TypeId.LookupByName(type_name)
         info = TypeId.AttributeInformation()
         if not tid.LookupAttributeByName(name, info):
-            msg = "TypeId %s has no attribute %s" % (typeid, name) 
+            msg = "TypeId %s has no attribute %s" % (type_name, name) 
             self.logger.error(msg)
 
         checker = info.checker
         ns3_value = checker.Create() 
         return ns3_value
 
-    def _from_ns3_value(self, uuid, name, ns3_value):
-        typeid = self.get_typeid(uuid)
+    def _attr_from_ns3_value_to_string(self, type_name, name, ns3_value):
         TypeId = self.ns3.TypeId()
-        tid = TypeId.LookupByName(typeid)
+        tid = TypeId.LookupByName(type_name)
         info = TypeId.AttributeInformation()
         if not tid.LookupAttributeByName(name, info):
-            msg = "TypeId %s has no attribute %s" % (typeid, name) 
+            msg = "TypeId %s has no attribute %s" % (type_name, name) 
             self.logger.error(msg)
 
         checker = info.checker
@@ -358,13 +393,12 @@ class NS3Wrapper(object):
 
         return value
 
-    def _to_ns3_value(self, uuid, name, value):
-        typeid = self.get_typeid(uuid)
+    def _attr_from_string_to_ns3_value(self, type_name, name, value):
         TypeId = self.ns3.TypeId()
-        tid = TypeId.LookupByName(typeid)
+        tid = TypeId.LookupByName(type_name)
         info = TypeId.AttributeInformation()
         if not tid.LookupAttributeByName(name, info):
-            msg = "TypeId %s has no attribute %s" % (typeid, name) 
+            msg = "TypeId %s has no attribute %s" % (type_name, name) 
             self.logger.error(msg)
 
         str_value = str(value)
index c01180e..f5d4acc 100644 (file)
@@ -17,6 +17,7 @@
 #
 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
 
+# Force the load of ns3 libraries
 from nepi.resources.ns3.ns3wrapper import load_ns3_module
 
 import os
index 70bf970..3843864 100755 (executable)
@@ -360,7 +360,7 @@ class NS3WrapperTest(unittest.TestCase):
         wrapper.invoke(ipv41, "SetUp", ifindex1)
 
         # Enable collection of Ascii format to a specific file
-        filepath1 = "trace-p2p-1.tr"
+        filepath1 = "/tmp/trace-p2p-1.tr"
         stream1 = wrapper.invoke(asciiHelper, "CreateFileStream", filepath1)
         wrapper.invoke(p2pHelper, "EnableAscii", stream1, p1)
        
@@ -380,7 +380,7 @@ class NS3WrapperTest(unittest.TestCase):
         wrapper.invoke(ipv42, "SetUp", ifindex2)
 
         # Enable collection of Ascii format to a specific file
-        filepath2 = "trace-p2p-2.tr"
+        filepath2 = "/tmp/trace-p2p-2.tr"
         stream2 = wrapper.invoke(asciiHelper, "CreateFileStream", filepath2)
         wrapper.invoke(p2pHelper, "EnableAscii", stream2, p2)
 
@@ -414,6 +414,9 @@ class NS3WrapperTest(unittest.TestCase):
         # wait until simulation is over
         wrapper.shutdown()
 
+        p = subprocess.Popen("rm /tmp/trace-p2p-*",  shell = True)
+        p.communicate()
+
 if __name__ == '__main__':
     unittest.main()