Added TCP-handshake for TunChannel and tun_connect.py
[nepi.git] / src / nepi / testbeds / netns / execute.py
index d1904b0..020b97f 100644 (file)
@@ -1,16 +1,68 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-from constants import TESTBED_ID
+from constants import TESTBED_ID, TESTBED_VERSION
 from nepi.core import testbed_impl
+from nepi.util.constants import TIME_NOW
 import os
-
-class TestbedInstance(testbed_impl.TestbedInstance):
-    def __init__(self, testbed_version):
-        super(TestbedInstance, self).__init__(TESTBED_ID, testbed_version)
+import fcntl
+import threading
+
+class TestbedController(testbed_impl.TestbedController):
+    from nepi.util.tunchannel_impl import TunChannel
+    
+    LOCAL_FACTORIES = {
+        'TunChannel' : TunChannel,
+    }
+    
+    LOCAL_TYPES = tuple(LOCAL_FACTORIES.values())
+
+    class HostLock(object):
+        # This class is used as a lock to prevent concurrency issues with more
+        # than one instance of netns running in the same machine. Both in 
+        # different processes or different threads.
+        taken = False
+        processcond = threading.Condition()
+        
+        def __init__(self, lockfile):
+            processcond = self.__class__.processcond
+            
+            processcond.acquire()
+            try:
+                # It's not reentrant
+                while self.__class__.taken:
+                    processcond.wait()
+                self.__class__.taken = True
+            finally:
+                processcond.release()
+            
+            self.lockfile = lockfile
+            fcntl.flock(self.lockfile, fcntl.LOCK_EX)
+        
+        def __del__(self):
+            processcond = self.__class__.processcond
+            
+            processcond.acquire()
+            try:
+                assert self.__class__.taken, "HostLock unlocked without being locked!"
+
+                fcntl.flock(self.lockfile, fcntl.LOCK_UN)
+                
+                # It's not reentrant
+                self.__class__.taken = False
+                processcond.notify()
+            finally:
+                processcond.release()
+    
+    def __init__(self):
+        super(TestbedController, self).__init__(TESTBED_ID, TESTBED_VERSION)
         self._netns = None
         self._home_directory = None
         self._traces = dict()
+        self._netns_lock = open("/tmp/nepi-netns-lock","a")
+    
+    def _lock(self):
+        return self.HostLock(self._netns_lock)
 
     @property
     def home_directory(self):
@@ -23,60 +75,75 @@ class TestbedInstance(testbed_impl.TestbedInstance):
     def do_setup(self):
         self._home_directory = self._attributes.\
             get_attribute_value("homeDirectory")
-        self._netns = self._load_netns_module()
+        # create home...
+        home = os.path.normpath(self.home_directory)
+        if not os.path.exists(home):
+            os.makedirs(home, 0755)
 
-    def set(self, time, guid, name, value):
-        super(TestbedInstance, self).set(time, guid, name, value)
-        
+        self._netns = self._load_netns_module()
+        super(TestbedController, self).do_setup()
+    
+    def do_create(self):
+        lock = self._lock()
+        super(TestbedController, self).do_create()    
+
+    def set(self, guid, name, value, time = TIME_NOW):
+        super(TestbedController, self).set(guid, name, value, time)
         # TODO: take on account schedule time for the task 
+        factory_id = self._create[guid]
+        factory = self._factories[factory_id]
+        if factory_id not in self.LOCAL_FACTORIES and \
+                factory.box_attributes.is_attribute_metadata(name):
+            return
         element = self._elements.get(guid)
         if element:
             setattr(element, name, value)
 
-    def get(self, time, guid, name):
+    def get(self, guid, name, time = TIME_NOW):
+        value = super(TestbedController, self).get(guid, name, time)
         # TODO: take on account schedule time for the task
+        factory_id = self._create[guid]
+        factory = self._factories[factory_id]
+        if factory_id not in self.LOCAL_FACTORIES and \
+                factory.box_attributes.is_attribute_metadata(name):
+            return value
         element = self._elements.get(guid)
-        if element:
-            try:
-                if hasattr(element, name):
-                    # Runtime attribute
-                    return getattr(element, name)
-                else:
-                    # Try design-time attributes
-                    return self.box_get(time, guid, name)
-            except KeyError, AttributeError:
-                return None
-
-    def get_route(self, guid, index, attribute):
-        # TODO: fetch real data from netns
         try:
-            return self.box_get_route(guid, int(index), attribute)
-        except KeyError, AttributeError:
-            return None
-
-    def get_address(self, guid, index, attribute='Address'):
-        # TODO: fetch real data from netns
-        try:
-            return self.box_get_address(guid, int(index), attribute)
-        except KeyError, AttributeError:
-            return None
-
+            return getattr(element, name)
+        except (KeyError, AttributeError):
+            return value
 
     def action(self, time, guid, action):
         raise NotImplementedError
 
     def shutdown(self):
-        for trace in self._traces.values():
-            trace.close()
-        for element in self._elements.values():
-            element.destroy()
+        for guid, traces in self._traces.iteritems():
+            for trace_id, (trace, filename) in traces.iteritems():
+                if hasattr(trace, "close"):
+                    trace.close()
+        for guid, element in self._elements.iteritems():
+            if isinstance(element, self.TunChannel):
+                element.cleanup()
+            else:
+                factory_id = self._create[guid]
+                if factory_id == "Node":
+                    element.destroy()
+        self._elements.clear()
+
+    def trace_filepath(self, guid, trace_id, filename = None):
+        if not filename:
+            (trace, filename) = self._traces[guid][trace_id]
+        return os.path.join(self.home_directory, filename)
 
     def trace_filename(self, guid, trace_id):
-        # TODO: Need to be defined inside a home!!!! with and experiment id_code
-        return os.path.join(self.home_directory, "%d_%s" % (guid, trace_id))
+        (trace, filename) = self._traces[guid][trace_id]
+        return filename
+
 
-    def follow_trace(self, trace_id, trace):
-        self._traces[trace_id] = trace
+    def follow_trace(self, guid, trace_id, trace, filename):
+        if not guid in self._traces:
+            self._traces[guid] = dict()
+        self._traces[guid][trace_id] = (trace, filename)
 
     def _load_netns_module(self):
         # TODO: Do something with the configuration!!!