test support added
[nepi.git] / src / nepi / testbeds / netns / execute.py
index 11c622f..ba1f28c 100644 (file)
@@ -3,10 +3,12 @@
 
 from constants import TESTBED_ID
 from nepi.core import execute
+from nepi.core import execute_impl
 from nepi.core.attributes import Attribute
 from nepi.core.metadata import Metadata
 from nepi.util import validation
 from nepi.util.constants import AF_INET, AF_INET6
+import os
 
 class TestbedConfiguration(execute.TestbedConfiguration):
     def __init__(self):
@@ -14,274 +16,66 @@ class TestbedConfiguration(execute.TestbedConfiguration):
         self.add_attribute("EnableDebug", "Enable netns debug output", 
                 Attribute.BOOL, False, None, None, False, validation.is_bool)
 
-class TestbedInstance(execute.TestbedInstance):
+class TestbedInstance(execute_impl.TestbedInstance):
     def __init__(self, testbed_version, configuration):
-        self._configuration = configuration
-        self._testbed_id = TESTBED_ID
-        self._testbed_version = testbed_version
-        self._factories = dict()
-        self._elements = dict()
-        self._create = dict()
-        self._set = dict()
-        self._connect = dict()
-        self._cross_connect = dict()
-        self._add_trace = dict()
-        self._add_address = dict()
-        self._add_route = dict()        
-
-        self._metadata = Metadata(self._testbed_id, self._testbed_version)
-        for factory in self._metadata.build_execute_factories():
-            self._factories[factory.factory_id] = factory
-
+        super(TestbedInstance, self).__init__(TESTBED_ID, testbed_version, 
+                configuration)
         self._netns = self._load_netns_module(configuration)
-
-    @property
-    def elements(self):
-        return self._elements
+        self._traces = dict()
 
     @property
     def netns(self):
         return self._netns
 
-    def create(self, guid, factory_id):
-        if factory_id not in self._factories:
-            raise RuntimeError("Invalid element type %s for Netns version %s" %
-                    (factory_id, self._testbed_version))
-        if guid in self._create:
-            raise RuntimeError("Cannot add elements with the same guid: %d" %
-                    guid)
-        self._create[guid] = factory_id
-
-    def create_set(self, guid, name, value):
-        if not guid in self._create:
-            raise RuntimeError("Element guid %d doesn't exist" % guid)
-        factory_id = self._create[guid]
-        factory = self._factories[factory_id]
-        if not factory.has_attribute(name):
-            raise RuntimeError("Invalid attribute %s for element type %s" %
-                    (name, factory_id))
-        factory.set_attribute_value(name, value)
-        if guid not in self._set:
-            self._set[guid] = dict()
-        self._set[guid][name] = value
-       
-    def connect(self, guid1, connector_type_name1, guid2, 
-            connector_type_name2):
-        factory_id1 = self._create[guid1]
-        factory_id2 = self._create[guid2]
-        count = self._get_connection_count(guid1, connector_type_name1)
-        factory1 = self._factories[factory_id1]
-        connector_type = factory1.connector_type(connector_type_name1)
-        connector_type.can_connect(self._testbed_id, factory_id2, 
-                connector_type_name2, count)
-        if not guid1 in self._connect:
-            self._connect[guid1] = dict()
-        if not connector_type_name1 in self._connect[guid1]:
-             self._connect[guid1][connector_type_name1] = dict()
-        self._connect[guid1][connector_type_name1][guid2] = \
-               connector_type_name2
-        if not guid2 in self._connect:
-            self._connect[guid2] = dict()
-        if not connector_type_name2 in self._connect[guid2]:
-             self._connect[guid2][connector_type_name2] = dict()
-        self._connect[guid2][connector_type_name2][guid1] = \
-                connector_type_name1
-
-    def cross_connect(self, guid, connector_type_name, cross_guid, 
-            cross_testbed_id, cross_factory_id, cross_connector_type_name):
-        factory_id = self._create[guid]
-        count = self._get_connection_count(guid, connector_type_name)
-        factory = self._factories[factory_id]
-        connector_type = factory.connector_type(connector_type_name)
-        connector_type.can_connect(cross_testbed_id, cross_factory_id, 
-                cross_connector_type_name, count, must_cross = True)
-        if not guid in self._connect:
-            self._cross_connect[guid] = dict()
-        if not connector_type_name in self._cross_connect[guid]:
-             self._cross_connect[guid][connector_type_name] = dict()
-        self._cross_connect[guid][connector_type_name] = \
-                (cross_guid, cross_testbed_id, cross_factory_id, 
-                        cross_connector_type_name)
-
-    def add_trace(self, guid, trace_id):
-        if not guid in self._create:
-            raise RuntimeError("Element guid %d doesn't exist" % guid)
-        factory_id = self._create[guid]
-        factory = self._factories[factory_id]
-        if not trace_id in factory_traces:
-            raise RuntimeError("Element type %s doesn't support trace %s" %
-                    (factory_id, trace_id))
-        if not guid in self._add_trace:
-            self._add_trace[guid] = list()
-        self._add_trace[guid].append(trace_id)
-
-    def add_adddress(self, guid, family, address, netprefix, broadcast):
-        if not guid in self._create:
-            raise RuntimeError("Element guid %d doesn't exist" % guid)
-        factory_id = self._create[guid]
-        factory = self._factories[factory_id]
-        if not factory.allow_addresses:
-            raise RuntimeError("Element type %s doesn't support addresses" %
-                    factory_id)
-        max_addresses = factory.get_attribute_value("MaxAddresses")
-        if guid in self._add_address:
-            count_addresses = len(self._add_address[guid])
-            if max_addresses == count_addresses:
-                raise RuntimeError("Element guid %d of type %s can't accept \
-                        more addresses" % (guid, family_id))
-        else:
-            self._add_address[guid] = list()
-        self._add_address[guid].append((family, address, netprefix, broadcast))
-
-    def add_route(self, guid, family, destination, netprefix, nexthop, 
-            interface):
-        if not guid in self._create:
-            raise RuntimeError("Element guid %d doesn't exist" % guid)
-        factory_id = self._create[guid]
-        factory = self._factories[factory_id]
-        if not factory.allow_routes:
-            raise RuntimeError("Element type %s doesn't support routes" %
-                    factory_id)
-        if not guid in self._add_route:
-            self._add_route[guid] = list()
-        self._add_route[guid].append((family, destination, netprefix, nexthop,
-            interface)) 
-
-    def do_create(self):
-        guids = dict()
-        for guid, factory_id in self._create.iteritems():
-            if not factory_id in guids:
-               guids[factory_id] = list()
-            guids[factory_id].append(guid)
-        for factory_id in self._metadata.factories_order:
-            if factory_id not in guids:
-                continue
-            factory = self._factories[factory_id]
-            for guid in guids[factory_id]:
-                parameters = dict() if guid not in self._set else \
-                        self._set[guid]
-                factory.create_function(self, guid, parameters)
-                element = self._elements[guid]
-                if element:
-                    for name, value in parameters.iteritems():
-                        setattr(element, name, value)
-
-    def do_connect(self):
-        for guid1, connections in self._connect.iteritems():
-            element1 = self._elements[guid1]
-            factory_id1 = self._create[guid1]
-            factory1 = self._factories[factory_id1]
-            for connector_type_name1, connections2 in connections.iteritems():
-                connector_type1 = factory1.connector_type(connector_type_name1)
-                for guid2, connector_type_name2 in connections2.iteritems():
-                    element2 = self._elements[guid2]
-                    factory_id2 = self._create[guid2]
-                    # Connections are executed in a "From -> To" direction only
-                    # This explicitly ignores the "To -> From" (mirror) 
-                    # connections of every connection pair. 
-                    code_to_connect = connector_type1.code_to_connect(
-                            self._testbed_id, factory_id2, 
-                            connector_type_name2)
-                    if code_to_connect:
-                        code_to_connect(element1, element2)
-
     def do_configure(self):
         # TODO: add traces!
-        # add addressess
+        # configure addressess
         for guid, addresses in self._add_address.iteritems():
             element = self._elements[guid]
             for address in addresses:
                 (family, address, netprefix, broadcast) = address
                 if family == AF_INET:
                     element.add_v4_address(address, netprefix)
-        # add routes
+        # configure routes
         for guid, routes in self._add_route.iteritems():
             element = self._elements[guid]
             for route in routes:
-                # TODO: family and interface not used!!!!!
-                (family, destination, netprefix, nexthop, interfaces) = routes
+                (destination, netprefix, nexthop) = route
                 element.add_route(prefix = destination, prefix_len = netprefix,
                         nexthop = nexthop)
 
-    def do_cross_connect(self):
-        for guid, cross_connections in self._cross_connect.iteritems():
-            element = self._elements[guid]
-            factory_id = self._create[guid]
-            factory = self._factories[factory_id]
-            for connector_type_name, cross_connection in \
-                    cross_connections.iteritems():
-                connector_type = factory.connector_type(connector_type_name)
-                (cross_testbed_id, cross_factory_id, 
-                        cross_connector_type_name) = cross_connection
-                code_to_connect = connector_type.code_to_connect(
-                    cross_guid, cross_testbed_id, cross_factory_id, 
-                    cross_conector_type_name)
-                if code_to_connect:
-                    code_to_connect(element, cross_guid)       
-
     def set(self, time, guid, name, value):
         # TODO: take on account schedule time for the task 
         element = self._elements[guid]
-        setattr(element, name, value)
+        if element:
+            setattr(element, name, value)
 
     def get(self, time, guid, name):
-        # TODO: take on account schedule time for the task 
+        # TODO: take on account schedule time for the task
         element = self._elements[guid]
         return getattr(element, name)
 
-    def start(self, time = "0s"):
-        for guid, factory_id in self._create.iteritems():
-            factory = self._factories[factory_id]
-            start_function = factory.start_function
-            if start_function:
-                traces = [] if guid not in self._add_trace else \
-                        self._add_trace[guid]
-                parameters = dict() if guid not in self._set else \
-                        self._set[guid]
-                start_function(self, guid, parameters, traces)
-
     def action(self, time, guid, action):
         raise NotImplementedError
 
-    def stop(self, time = "0s"):
-        for guid, factory_id in self._create.iteritems():
-            factory = self._factories[factory_id]
-            stop_function = factory.stop_function
-            if stop_function:
-                traces = [] if guid not in self._add_trace else \
-                        self._add_trace[guid]
-                stop_function(self, guid, traces)
-
-    def status(self, guid):
-          for guid, factory_id in self._create.iteritems():
-            factory = self._factories[factory_id]
-            status_function = factory.status_function
-            if status_function:
-                result = status_function(self, guid)
-
     def trace(self, guid, trace_id):
-        raise NotImplementedError
+        f = open(self.trace_filename(guid, trace_id), "r")
+        content = f.read()
+        f.close()
+        return content
 
     def shutdown(self):
+        for trace in self._traces.values():
+            trace.close()
         for element in self._elements.values():
             element.destroy()
 
-    def get_connected(self, guid, connector_type_name, 
-            other_connector_type_name):
-        """searchs the connected elements for the specific connector_type_name 
-        pair"""
-        if guid not in self._connect:
-            return []
-        # all connections for all connectors for guid
-        all_connections = self._connect[guid]
-        if connector_type_name not in all_connections:
-            return []
-        # all connections for the specific connector
-        connections = all_connections[connector_type_name]
-        specific_connections = [otr_guid for otr_guid, otr_connector_type_name \
-                in connections.iteritems() if \
-                otr_connector_type_name == other_connector_type_name]
-        return specific_connections
+    def trace_filename(self, guid, trace_id):
+        # TODO: Need to be defined inside a home!!!! with and experiment id_code
+        return os.path.join(self.home_directory, "%d_%s" % (guid, trace_id))
+
+    def follow_trace(self, trace_id, trace):
+        self._traces[trace_id] = trace
 
     def _load_netns_module(self, configuration):
         # TODO: Do something with the configuration!!!
@@ -294,15 +88,3 @@ class TestbedInstance(execute.TestbedInstance):
             netns_mod.environ.set_log_level(netns_mod.environ.LOG_DEBUG)
         return netns_mod
 
-    def _get_connection_count(self, guid, connection_type_name):
-        count = 0
-        cross_count = 0
-        if guid in self._connect and connection_type_name in \
-                self._connect[guid]:
-            count = len(self._connect[guid][connection_type_name])
-        if guid in self._cross_connect and connection_type_name in \
-                self._cross_connect[guid]:
-            cross_count = len(self._cross_connect[guid][connection_type_name])
-        return count + cross_count
-
-