ns-3 simulator synchronizing start
[nepi.git] / src / nepi / resources / ns3 / ns3wrapper.py
index 00245bb..f6075b0 100644 (file)
@@ -25,12 +25,15 @@ import time
 import uuid
 
 SINGLETON = "singleton::"
+SIMULATOR_UUID = "singleton::Simulator"
+CONFIG_UUID = "singleton::Config"
+GLOBAL_VALUE_UUID = "singleton::GlobalValue"
+IPV4_GLOBAL_ROUTING_HELPER_UUID = "singleton::Ipv4GlobalRoutingHelper"
 
-def load_ns3_module():
+def load_ns3_libraries():
     import ctypes
     import re
 
-    bindings = os.environ.get("NS3BINDINGS")
     libdir = os.environ.get("NS3LIBRARIES")
 
     # Load the ns-3 modules shared libraries
@@ -58,9 +61,13 @@ def load_ns3_module():
             # to prevent infinit loop
             if initial_size == len(libs):
                 raise RuntimeError("Imposible to load shared libraries %s" % str(libs))
-            initial_size = list(libs)
+            initial_size = len(libs)
+
+def load_ns3_module():
+    load_ns3_libraries()
 
     # import the python bindings for the ns-3 modules
+    bindings = os.environ.get("NS3BINDINGS")
     if bindings:
         sys.path.append(bindings)
 
@@ -204,9 +211,13 @@ class NS3Wrapper(object):
 
         return uuid
 
-    def invoke(self, uuid, operation, *args):
+    def invoke(self, uuid, operation, *args, **kwargs):
+        if operation == "isRunning":
+            return self._is_running()
         if operation == "isAppRunning":
             return self._is_app_running(uuid)
+        if operation == "addStaticRoute":
+            return self._add_static_route(uuid, *args)
 
         if uuid.startswith(SINGLETON):
             obj = self._singleton(uuid)
@@ -218,12 +229,15 @@ class NS3Wrapper(object):
         # arguments starting with 'uuid' identify ns-3 C++
         # objects and must be replaced by the actual object
         realargs = self.replace_args(args)
+        realkwargs = self.replace_kwargs(kwargs)
 
-        result = method(*realargs)
+        result = method(*realargs, **realkwargs)
 
-        if not result:
+        # If the result is not an object, no need to 
+        # keep a reference. Directly return value.
+        if result is None or type(result) in [bool, float, long, str, int]:
             return result
-       
+      
         newuuid = self.make_uuid()
         self._objects[newuuid] = result
 
@@ -301,7 +315,7 @@ class NS3Wrapper(object):
         
         if self._simulator_thread:
             self._simulator_thread.join()
-        
+       
         self.ns3.Simulator.Destroy()
         
         # Remove all references to ns-3 objects
@@ -423,19 +437,93 @@ class NS3Wrapper(object):
 
         return realargs
 
+    # replace uuids and singleton references for the real objects
+    def replace_kwargs(self, kwargs):
+        realkwargs = dict([(k, self.get_object(v) \
+                if str(v).startswith("uuid") else v) \
+                for k,v in kwargs.iteritems()])
+        realkwargs = dict([(k, self._singleton(v) \
+                if str(v).startswith(SINGLETON) else v )\
+                for k, v in realkwargs.iteritems()])
+
+        return realkwargs
+
+    def _is_running(self): 
+        if self.ns3.Simulator.IsFinished():
+            return False
+
+        now = self.ns3.Simulator.Now()
+        if now.IsZero():
+            return False
+        
+        return True
+
     def _is_app_running(self, uuid): 
         now = self.ns3.Simulator.Now()
         if now.IsZero():
             return False
 
-        stop_value = self.get(uuid, "StopTime")
-        stop_time = self.ns3.Time(stop_value)
-        
-        start_value = self.get(uuid, "StartTime")
-        start_time = self.ns3.Time(start_value)
+        app = self.get_object(uuid)
+        stop_time_value = self.ns3.TimeValue()
+        app.GetAttribute("StopTime", stop_time_value)
+        stop_time = stop_time_value.Get()
+
+        start_time_value = self.ns3.TimeValue()
+        app.GetAttribute("StartTime", start_time_value)
+        start_time = start_time_value.Get()
         
-        if now.Compare(start_time) >= 0 and now.Compare(stop_time) <= 0:
+        if now.Compare(start_time) >= 0 and now.Compare(stop_time) < 0:
             return True
 
         return False
 
+    def _add_static_route(self, ipv4_uuid, network, prefix, nexthop):
+        ipv4 = self.get_object(ipv4_uuid)
+
+        list_routing = ipv4.GetRoutingProtocol()
+        (static_routing, priority) = list_routing.GetRoutingProtocol(0)
+
+        ifindex = self._find_ifindex(ipv4, nexthop)
+        if ifindex == -1:
+            return False
+        
+        nexthop = self.ns3.Ipv4Address(nexthop)
+
+        if network in ["0.0.0.0", "0", None]:
+            # Default route: 0.0.0.0/0
+            static_routing.SetDefaultRoute(nexthop, ifindex)
+        else:
+            mask = self.ns3.Ipv4Mask("/%s" % prefix) 
+            network = self.ns3.Ipv4Address(network)
+
+            if prefix == 32:
+                # Host route: x.y.z.w/32
+                static_routing.AddHostRouteTo(network, nexthop, ifindex)
+            else:
+                # Network route: x.y.z.w/n
+                static_routing.AddNetworkRouteTo(network, mask, nexthop, 
+                        ifindex) 
+        return True
+
+    def _find_ifindex(self, ipv4, nexthop):
+        ifindex = -1
+
+        nexthop = self.ns3.Ipv4Address(nexthop)
+
+        # For all the interfaces registered with the ipv4 object, find
+        # the one that matches the network of the nexthop
+        nifaces = ipv4.GetNInterfaces()
+        for ifidx in xrange(nifaces):
+            iface = ipv4.GetInterface(ifidx)
+            naddress = iface.GetNAddresses()
+            for addridx in xrange(naddress):
+                ifaddr = iface.GetAddress(addridx)
+                ifmask = ifaddr.GetMask()
+                
+                ifindex = ipv4.GetInterfaceForPrefix(nexthop, ifmask)
+
+                if ifindex == ifidx:
+                    return ifindex
+        return ifindex
+