Lots of cross-connection fixes, TUN synchronization, etc
[nepi.git] / src / nepi / core / execute.py
index de3e7e1..4018857 100644 (file)
@@ -3,7 +3,7 @@
 
 from nepi.core.attributes import Attribute, AttributesMap
 from nepi.core.connector import ConnectorTypeBase
-from nepi.util import proxy, validation
+from nepi.util import validation
 from nepi.util.constants import STATUS_FINISHED, TIME_NOW
 from nepi.util.parser._xml import XmlExperimentParser
 import sys
@@ -12,8 +12,9 @@ import threading
 import ConfigParser
 import os
 import collections
+import functools
 
-ATTRIBUTE_PATTERN_BASE = re.compile(r"\{#\[(?P<label>[-a-zA-Z0-9._]*)\](?P<expr>(?P<component>\.addr\[[0-9]+\]|\.route\[[0-9]+\]|\.trace\[[0-9]+\]|).\[(?P<attribute>[-a-zA-Z0-9._]*)\])#}")
+ATTRIBUTE_PATTERN_BASE = re.compile(r"\{#\[(?P<label>[-a-zA-Z0-9._]*)\](?P<expr>(?P<component>\.addr\[[0-9]+\]|\.route\[[0-9]+\]|\.trace\[[0-9]+\])?.\[(?P<attribute>[-a-zA-Z0-9._]*)\])#}")
 ATTRIBUTE_PATTERN_GUID_SUB = r"{#[%(guid)s]%(expr)s#}"
 COMPONENT_PATTERN = re.compile(r"(?P<kind>[a-z]*)\[(?P<index>.*)\]")
 
@@ -43,7 +44,7 @@ class ConnectorType(ConnectorTypeBase):
         self._to_connections[type_id] = (can_cross, init_code, compl_code)
 
     def can_connect(self, testbed_id, factory_id, name, count, 
-            must_cross = False):
+            must_cross):
         connector_type_id = self.make_connector_type_id(testbed_id, factory_id, name)
         for lookup_type_id in self._type_resolution_order(connector_type_id):
             if lookup_type_id in self._from_connections:
@@ -51,26 +52,28 @@ class ConnectorType(ConnectorTypeBase):
             elif lookup_type_id in self._to_connections:
                 (can_cross, init_code, compl_code) = self._to_connections[lookup_type_id]
             else:
-                # keey trying
+                # keep trying
                 continue
             return not must_cross or can_cross
         else:
             return False
 
-    def _connect_to_code(self, testbed_id, factory_id, name):
+    def _connect_to_code(self, testbed_id, factory_id, name,
+            must_cross):
         connector_type_id = self.make_connector_type_id(testbed_id, factory_id, name)
         for lookup_type_id in self._type_resolution_order(connector_type_id):
             if lookup_type_id in self._to_connections:
                 (can_cross, init_code, compl_code) = self._to_connections[lookup_type_id]
-                return (init_code, compl_code)
+                if not must_cross or can_cross:
+                    return (init_code, compl_code)
         else:
             return (False, False)
     
-    def connect_to_init_code(self, testbed_id, factory_id, name):
-        return self._connect_to_code(testbed_id, factory_id, name)[0]
+    def connect_to_init_code(self, testbed_id, factory_id, name, must_cross):
+        return self._connect_to_code(testbed_id, factory_id, name, must_cross)[0]
 
-    def connect_to_compl_code(self, testbed_id, factory_id, name):
-        return self._connect_to_code(testbed_id, factory_id, name)[1]
+    def connect_to_compl_code(self, testbed_id, factory_id, name, must_cross):
+        return self._connect_to_code(testbed_id, factory_id, name, must_cross)[1]
 
 class Factory(AttributesMap):
     def __init__(self, factory_id, create_function, start_function, 
@@ -198,8 +201,11 @@ class TestbedController(object):
         """Instructs creation of a connection between the given connectors"""
         raise NotImplementedError
 
-    def defer_cross_connect(self, guid, connector_type_name, cross_guid, 
-            cross_testbed_id, cross_factory_id, cross_connector_type_name):
+    def defer_cross_connect(self, 
+            guid, connector_type_name,
+            cross_guid, cross_testbed_guid,
+            cross_testbed_id, cross_factory_id,
+            cross_connector_type_name):
         """
         Instructs creation of a connection between the given connectors 
         of different testbed instances
@@ -243,6 +249,14 @@ class TestbedController(object):
         """
         raise NotImplementedError
 
+    def do_preconfigure(self):
+        """
+        Done just before resolving netrefs, after connection, before cross connections,
+        useful for early stages of configuration, for setting up stuff that might be
+        required for netref resolution.
+        """
+        raise NotImplementedError
+
     def do_configure(self):
         """After do_configure elements are configured"""
         raise NotImplementedError
@@ -318,6 +332,7 @@ class ExperimentController(object):
         self._cross_data = dict()
         self._root_dir = root_dir
         self._netreffed_testbeds = set()
+        self._guids_in_testbed_cache = dict()
 
         self.persist_experiment_xml()
 
@@ -336,11 +351,24 @@ class ExperimentController(object):
 
     @staticmethod
     def _parallel(callables):
-        threads = [ threading.Thread(target=callable) for callable in callables ]
+        excs = []
+        def wrap(callable):
+            @functools.wraps(callable)
+            def wrapped(*p, **kw):
+                try:
+                    callable(*p, **kw)
+                except Exception,e:
+                    import traceback
+                    traceback.print_exc(file=sys.stderr)
+                    excs.append(e)
+            return wrapped
+        threads = [ threading.Thread(target=wrap(callable)) for callable in callables ]
         for thread in threads:
             thread.start()
         for thread in threads:
             thread.join()
+        for exc in excs:
+            raise exc
 
     def start(self):
         parser = XmlExperimentParser()
@@ -396,11 +424,18 @@ class ExperimentController(object):
         # final netref step, fail if anything's left unresolved
         self.do_netrefs(data, fail_if_undefined=True)
         
+        self._program_testbed_cross_connections(data)
+        
         # perform do_configure in parallel for al testbeds
         # (it's internal configuration for each)
         self._parallel([testbed.do_configure
                         for testbed in self._testbeds.itervalues()])
 
+        
+        #print >>sys.stderr, "DO IT"
+        #import time
+        #time.sleep(60)
+        
         # cross-connect (cannot be done in parallel)
         for guid, testbed in self._testbeds.iteritems():
             cross_data = self._get_cross_data(guid)
@@ -439,6 +474,10 @@ class ExperimentController(object):
             BOOLEAN : 'getboolean',
         }
         
+        # deferred import because proxy needs
+        # our class definitions to define proxies
+        import nepi.util.proxy as proxy
+        
         conf = ConfigParser.RawConfigParser()
         conf.read(os.path.join(self._root_dir, 'deployment_config.ini'))
         for testbed_guid in conf.sections():
@@ -482,9 +521,8 @@ class ExperimentController(object):
 
     def is_finished(self, guid):
         for testbed in self._testbeds.values():
-            for guid_ in testbed.guids:
-                if guid_ == guid:
-                    return testbed.status(guid) == STATUS_FINISHED
+            if guid in testbed.guids:
+                return testbed.status(guid) == STATUS_FINISHED
         raise RuntimeError("No element exists with guid %d" % guid)    
 
     def set(self, testbed_guid, guid, name, value, time = TIME_NOW):
@@ -496,8 +534,16 @@ class ExperimentController(object):
         return testbed.get(guid, name, time)
 
     def shutdown(self):
-       for testbed in self._testbeds.values():
-           testbed.shutdown()
+        for testbed in self._testbeds.values():
+            testbed.shutdown()
+    
+    def _guids_in_testbed(self, testbed_guid):
+        if testbed_guid not in self._testbeds:
+            return set()
+        if testbed_guid not in self._guids_in_testbed_cache:
+            self._guids_in_testbed_cache[testbed_guid] = \
+                set(self._testbeds[testbed_guid].guids)
+        return self._guids_in_testbed_cache[testbed_guid]
 
     @staticmethod
     def _netref_component_split(component):
@@ -510,10 +556,10 @@ class ExperimentController(object):
     _NETREF_COMPONENT_GETTERS = {
         'addr':
             lambda testbed, guid, index, name: 
-                testbed.get_address(guid, index, name),
+                testbed.get_address(guid, int(index), name),
         'route' :
             lambda testbed, guid, index, name: 
-                testbed.get_route(guid, index, name),
+                testbed.get_route(guid, int(index), name),
         'trace' :
             lambda testbed, guid, index, name: 
                 testbed.trace(guid, index, name),
@@ -522,7 +568,7 @@ class ExperimentController(object):
                 testbed.get(guid, name),
     }
     
-    def resolve_netref_value(self, value):
+    def resolve_netref_value(self, value, failval = None):
         match = ATTRIBUTE_PATTERN_BASE.search(value)
         if match:
             label = match.group("label")
@@ -530,53 +576,60 @@ class ExperimentController(object):
                 ref_guid = int(label[5:])
                 if ref_guid:
                     expr = match.group("expr")
-                    component = match.group("component")[1:] # skip the dot
+                    component = (match.group("component") or "")[1:] # skip the dot
                     attribute = match.group("attribute")
                     
                     # split compound components into component kind and index
                     # eg: 'addr[0]' -> ('addr', '0')
                     component, component_index = self._netref_component_split(component)
-                    
+
                     # find object and resolve expression
-                    for ref_testbed in self._testbeds.itervalues():
+                    for ref_testbed_guid, ref_testbed in self._testbeds.iteritems():
                         if component not in self._NETREF_COMPONENT_GETTERS:
                             raise ValueError, "Malformed netref: %r - unknown component" % (expr,)
+                        elif ref_guid not in self._guids_in_testbed(ref_testbed_guid):
+                            pass
                         else:
                             ref_value = self._NETREF_COMPONENT_GETTERS[component](
                                 ref_testbed, ref_guid, component_index, attribute)
                             if ref_value:
                                 return value.replace(match.group(), ref_value)
         # couldn't find value
-        return None
+        return failval
     
     def do_netrefs(self, data, fail_if_undefined = False):
         # element netrefs
-        for (testbed_guid, guid), attrs in self._netrefs.iteritems():
-            testbed = self._testbeds[testbed_guid]
-            for name in attrs:
-                value = testbed.get(guid, name)
-                if isinstance(value, basestring):
-                    ref_value = self.resolve_netref_value(value)
-                    if ref_value is not None:
-                        testbed.set(guid, name, ref_value)
-                    elif fail_if_undefined:
-                        raise ValueError, "Unresolvable netref in: %r" % (value,)
+        for (testbed_guid, guid), attrs in self._netrefs.items():
+            testbed = self._testbeds.get(testbed_guid)
+            if testbed is not None:
+                for name in set(attrs):
+                    value = testbed.get(guid, name)
+                    if isinstance(value, basestring):
+                        ref_value = self.resolve_netref_value(value)
+                        if ref_value is not None:
+                            testbed.set(guid, name, ref_value)
+                            attrs.remove(name)
+                        elif fail_if_undefined:
+                            raise ValueError, "Unresolvable netref in: %r=%r" % (name,value,)
+                if not attrs:
+                    del self._netrefs[(testbed_guid, guid)]
         
         # testbed netrefs
-        for testbed_guid, attrs in self._testbed_netrefs.iteritems():
+        for testbed_guid, attrs in self._testbed_netrefs.items():
             tb_data = dict(data.get_attribute_data(testbed_guid))
             if data:
-                for name in attrs:
+                for name in set(attrs):
                     value = tb_data.get(name)
                     if isinstance(value, basestring):
                         ref_value = self.resolve_netref_value(value)
                         if ref_value is not None:
                             data.set_attribute_data(testbed_guid, name, ref_value)
+                            attrs.remove(name)
                         elif fail_if_undefined:
                             raise ValueError, "Unresolvable netref in: %r" % (value,)
+                if not attrs:
+                    del self._testbed_netrefs[testbed_guid]
         
-        self._netrefs.clear()
-        self._testbed_netrefs.clear()
 
     def _init_testbed_controllers(self, data, recover = False):
         blacklist_testbeds = set(self._testbeds)
@@ -641,6 +694,10 @@ class ExperimentController(object):
         (testbed_id, testbed_version) = data.get_testbed_data(guid)
         deployment_config = self._deployment_config.get(guid)
         
+        # deferred import because proxy needs
+        # our class definitions to define proxies
+        import nepi.util.proxy as proxy
+        
         if deployment_config is None:
             # need to create one
             deployment_config = proxy.AccessConfiguration()
@@ -675,40 +732,69 @@ class ExperimentController(object):
     def _program_testbed_controllers(self, element_guids, data):
         for guid in element_guids:
             (testbed_guid, factory_id) = data.get_box_data(guid)
-            testbed = self._testbeds[testbed_guid]
-            testbed.defer_create(guid, factory_id)
-            for (name, value) in data.get_attribute_data(guid):
-                testbed.defer_create_set(guid, name, value)
+            testbed = self._testbeds.get(testbed_guid)
+            if testbed:
+                testbed.defer_create(guid, factory_id)
+                for (name, value) in data.get_attribute_data(guid):
+                    # Try to resolve create-time netrefs, if possible
+                    if isinstance(value, basestring) and ATTRIBUTE_PATTERN_BASE.search(value):
+                        try:
+                            nuvalue = self.resolve_netref_value(value)
+                        except:
+                            # Any trouble means we're not in shape to resolve the netref yet
+                            nuvalue = None
+                        if nuvalue is not None:
+                            # Only if we succeed we remove the netref deferral entry
+                            value = nuvalue
+                            data.set_attribute_data(guid, name, value)
+                            if (testbed_guid, guid) in self._netrefs:
+                                self._netrefs[(testbed_guid, guid)].discard(name)
+                    testbed.defer_create_set(guid, name, value)
 
         for guid in element_guids: 
             (testbed_guid, factory_id) = data.get_box_data(guid)
-            testbed = self._testbeds[testbed_guid]
-            for (connector_type_name, cross_guid, cross_connector_type_name) \
-                    in data.get_connection_data(guid):
-                (testbed_guid, factory_id) = data.get_box_data(guid)
-                (cross_testbed_guid, cross_factory_id) = data.get_box_data(
-                        cross_guid)
-                if testbed_guid == cross_testbed_guid:
-                    testbed.defer_connect(guid, connector_type_name, 
-                            cross_guid, cross_connector_type_name)
-                else: 
-                    cross_testbed = self._testbeds[cross_testbed_guid]
-                    cross_testbed_id = cross_testbed.testbed_id
-                    testbed.defer_cross_connect(guid, connector_type_name, cross_guid, 
-                            cross_testbed_guid, cross_testbed_id, cross_factory_id, 
-                            cross_connector_type_name)
-                    # save cross data for later
-                    self._add_crossdata(testbed_guid, guid, cross_testbed_guid,
+            testbed = self._testbeds.get(testbed_guid)
+            if testbed:
+                for (connector_type_name, cross_guid, cross_connector_type_name) \
+                        in data.get_connection_data(guid):
+                    (testbed_guid, factory_id) = data.get_box_data(guid)
+                    (cross_testbed_guid, cross_factory_id) = data.get_box_data(
                             cross_guid)
-            for trace_id in data.get_trace_data(guid):
-                testbed.defer_add_trace(guid, trace_id)
-            for (autoconf, address, netprefix, broadcast) in \
-                    data.get_address_data(guid):
-                if address != None:
-                    testbed.defer_add_address(guid, address, netprefix, 
-                            broadcast)
-            for (destination, netprefix, nexthop) in data.get_route_data(guid):
-                testbed.defer_add_route(guid, destination, netprefix, nexthop)
+                    if testbed_guid == cross_testbed_guid:
+                        testbed.defer_connect(guid, connector_type_name, 
+                                cross_guid, cross_connector_type_name)
+                for trace_id in data.get_trace_data(guid):
+                    testbed.defer_add_trace(guid, trace_id)
+                for (autoconf, address, netprefix, broadcast) in \
+                        data.get_address_data(guid):
+                    if address != None:
+                        testbed.defer_add_address(guid, address, netprefix, 
+                                broadcast)
+                for (destination, netprefix, nexthop) in data.get_route_data(guid):
+                    testbed.defer_add_route(guid, destination, netprefix, nexthop)
+    
+    def _program_testbed_cross_connections(self, data):
+        data_guids = data.guids
+
+        for guid in data_guids: 
+            if not data.is_testbed_data(guid):
+                (testbed_guid, factory_id) = data.get_box_data(guid)
+                testbed = self._testbeds.get(testbed_guid)
+                if testbed:
+                    for (connector_type_name, cross_guid, cross_connector_type_name) \
+                            in data.get_connection_data(guid):
+                        (testbed_guid, factory_id) = data.get_box_data(guid)
+                        (cross_testbed_guid, cross_factory_id) = data.get_box_data(
+                                cross_guid)
+                        if testbed_guid != cross_testbed_guid:
+                            cross_testbed = self._testbeds[cross_testbed_guid]
+                            cross_testbed_id = cross_testbed.testbed_id
+                            testbed.defer_cross_connect(guid, connector_type_name, cross_guid, 
+                                    cross_testbed_guid, cross_testbed_id, cross_factory_id, 
+                                    cross_connector_type_name)
+                            # save cross data for later
+                            self._add_crossdata(testbed_guid, guid, cross_testbed_guid,
+                                    cross_guid)
                 
     def _add_crossdata(self, testbed_guid, guid, cross_testbed_guid, cross_guid):
         if testbed_guid not in self._cross_data:
@@ -726,7 +812,11 @@ class ExperimentController(object):
             cross_data[cross_testbed_guid] = dict()
             cross_testbed = self._testbeds[cross_testbed_guid]
             for cross_guid in guid_list:
-                elem_cross_data = dict()
+                elem_cross_data = dict(
+                    _guid = cross_guid,
+                    _testbed_guid = cross_testbed_guid,
+                    _testbed_id = cross_testbed.testbed_id,
+                    _testbed_version = cross_testbed.testbed_version)
                 cross_data[cross_testbed_guid][cross_guid] = elem_cross_data
                 attributes_list = cross_testbed.get_attribute_list(cross_guid)
                 for attr_name in attributes_list: