Fix invalid message handlign in testbed proxies
[nepi.git] / src / nepi / util / proxy.py
index dd2d40b..db301ad 100644 (file)
@@ -8,6 +8,8 @@ from nepi.util.constants import TIME_NOW
 import getpass
 import sys
 import time
+import tempfile
+import shutil
 
 # PROTOCOL REPLIES
 OK = 0
@@ -40,6 +42,9 @@ SET = 24
 ACTION  = 25
 STATUS  = 26
 GUIDS  = 27
+GET_ROUTE = 28
+GET_ADDRESS = 29
+RECOVER = 30
 
 # PARAMETER TYPE
 STRING  =  100
@@ -51,16 +56,17 @@ FLOAT   = 103
 controller_messages = dict({
     XML:    "%d" % XML,
     ACCESS: "%d|%s" % (ACCESS, "%d|%s|%s|%s|%s|%d|%s|%r|%s"),
-    TRACE:  "%d|%s" % (TRACE, "%d|%d|%s"),
+    TRACE:  "%d|%s" % (TRACE, "%d|%d|%s|%s"),
     FINISHED:   "%d|%s" % (FINISHED, "%d"),
     START:  "%d" % START,
     STOP:   "%d" % STOP,
+    RECOVER : "%d" % RECOVER,
     SHUTDOWN:   "%d" % SHUTDOWN,
     })
 
 # TESTBED INSTANCE PROTOCOL MESSAGES
 testbed_messages = dict({
-    TRACE:  "%d|%s" % (TRACE, "%d|%s"),
+    TRACE:  "%d|%s" % (TRACE, "%d|%s|%s"),
     START:  "%d" % START,
     STOP:   "%d" % STOP,
     SHUTDOWN:   "%d" % SHUTDOWN,
@@ -80,6 +86,8 @@ testbed_messages = dict({
     DO_CROSS_CONNECT:   "%d" % DO_CROSS_CONNECT,
     GET:    "%d|%s" % (GET, "%s|%d|%s"),
     SET:    "%d|%s" % (SET, "%s|%d|%s|%s|%d"),
+    GET_ROUTE: "%d|%s" % (GET, "%d|%d|%s"),
+    GET_ADDRESS: "%d|%s" % (GET, "%d|%d|%s"),
     ACTION: "%d|%s" % (ACTION, "%s|%d|%s"),
     STATUS: "%d|%s" % (STATUS, "%d"),
     GUIDS:  "%d" % GUIDS,
@@ -94,6 +102,7 @@ instruction_text = dict({
     FINISHED:   "FINISHED",
     START:  "START",
     STOP:   "STOP",
+    RECOVER: "RECOVER",
     SHUTDOWN:   "SHUTDOWN",
     CONFIGURE:  "CONFIGURE",
     CREATE: "CREATE",
@@ -111,6 +120,8 @@ instruction_text = dict({
     DO_CROSS_CONNECT:   "DO_CROSS_CONNECT",
     GET:    "GET",
     SET:    "SET",
+    GET_ROUTE: "GET_ROUTE",
+    GET_ADDRESS: "GET_ADDRESS",
     ACTION: "ACTION",
     STATUS: "STATUS",
     GUIDS:  "GUIDS",
@@ -242,32 +253,69 @@ class AccessConfiguration(AttributesMap):
                 allowed = [AccessConfiguration.ERROR_LEVEL,
                     AccessConfiguration.DEBUG_LEVEL],
                 validation_function = validation.is_enum)
+        self.add_attribute(name = "recover",
+                help = "Do not intantiate testbeds, rather, reconnect to already-running instances. Used to recover from a dead controller.", 
+                type = Attribute.BOOL,
+                value = False,
+                validation_function = validation.is_bool)
+
+class TempDir(object):
+    def __init__(self):
+        self.path = tempfile.mkdtemp()
+    
+    def __del__(self):
+        shutil.rmtree(self.path)
+
+class PermDir(object):
+    def __init__(self, path):
+        self.path = path
 
 def create_controller(xml, access_config = None):
-    mode = None if not access_config else \
-            access_config.get_attribute_value("mode")
+    mode = None if not access_config \
+            else access_config.get_attribute_value("mode")
+    launch = True if not access_config \
+            else not access_config.get_attribute_value("recover")
     if not mode or mode == AccessConfiguration.MODE_SINGLE_PROCESS:
+        if not launch:
+            raise ValueError, "Unsupported instantiation mode: %s with lanch=False" % (mode,)
+        
         from nepi.core.execute import ExperimentController
-        return ExperimentController(xml)
+        
+        if not access_config or not access_config.has_attribute("rootDirectory"):
+            root_dir = TempDir()
+        else:
+            root_dir = PermDir(access_config.get_attribute_value("rootDirectory"))
+        controller = ExperimentController(xml, root_dir.path)
+        
+        # inject reference to temporary dir, so that it gets cleaned
+        # up at destruction time.
+        controller._tempdir = root_dir
+        
+        return controller
     elif mode == AccessConfiguration.MODE_DAEMON:
         (root_dir, log_level, user, host, port, agent) = \
                 get_access_config_params(access_config)
         return ExperimentControllerProxy(root_dir, log_level,
                 experiment_xml = xml, host = host, port = port, user = user, 
-                agent = agent)
-    raise RuntimeError("Unsupported access configuration 'mode'" % mode)
+                agent = agent, launch = launch)
+    raise RuntimeError("Unsupported access configuration '%s'" % mode)
 
 def create_testbed_instance(testbed_id, testbed_version, access_config):
-    mode = None if not access_config else access_config.get_attribute_value("mode")
+    mode = None if not access_config \
+            else access_config.get_attribute_value("mode")
+    launch = True if not access_config \
+            else not access_config.get_attribute_value("recover")
     if not mode or mode == AccessConfiguration.MODE_SINGLE_PROCESS:
+        if not launch:
+            raise ValueError, "Unsupported instantiation mode: %s with lanch=False" % (mode,)
         return  _build_testbed_instance(testbed_id, testbed_version)
     elif mode == AccessConfiguration.MODE_DAEMON:
         (root_dir, log_level, user, host, port, agent) = \
                 get_access_config_params(access_config)
         return TestbedInstanceProxy(root_dir, log_level, testbed_id = testbed_id, 
                 testbed_version = testbed_version, host = host, port = port,
-                user = user, agent = agent)
-    raise RuntimeError("Unsupported access configuration 'mode'" % mode)
+                user = user, agent = agent, launch = launch)
+    raise RuntimeError("Unsupported access configuration '%s'" % mode)
 
 def _build_testbed_instance(testbed_id, testbed_version):
     mod_name = "nepi.testbeds.%s" % (testbed_id.lower())
@@ -288,65 +336,73 @@ class TestbedInstanceServer(server.Server):
                 self._testbed_version)
 
     def reply_action(self, msg):
-        params = msg.split("|")
-        instruction = int(params[0])
-        log_msg(self, params)
-        try:
-            if instruction == TRACE:
-                reply = self.trace(params)
-            elif instruction == START:
-                reply = self.start(params)
-            elif instruction == STOP:
-                reply = self.stop(params)
-            elif instruction == SHUTDOWN:
-                reply = self.shutdown(params)
-            elif instruction == CONFIGURE:
-                reply = self.defer_configure(params)
-            elif instruction == CREATE:
-                reply = self.defer_create(params)
-            elif instruction == CREATE_SET:
-                reply = self.defer_create_set(params)
-            elif instruction == FACTORY_SET:
-                reply = self.defer_factory_set(params)
-            elif instruction == CONNECT:
-                reply = self.defer_connect(params)
-            elif instruction == CROSS_CONNECT:
-                reply = self.defer_cross_connect(params)
-            elif instruction == ADD_TRACE:
-                reply = self.defer_add_trace(params)
-            elif instruction == ADD_ADDRESS:
-                reply = self.defer_add_address(params)
-            elif instruction == ADD_ROUTE:
-                reply = self.defer_add_route(params)
-            elif instruction == DO_SETUP:
-                reply = self.do_setup(params)
-            elif instruction == DO_CREATE:
-                reply = self.do_create(params)
-            elif instruction == DO_CONNECT:
-                reply = self.do_connect(params)
-            elif instruction == DO_CONFIGURE:
-                reply = self.do_configure(params)
-            elif instruction == DO_CROSS_CONNECT:
-                reply = self.do_cross_connect(params)
-            elif instruction == GET:
-                reply = self.get(params)
-            elif instruction == SET:
-                reply = self.set(params)
-            elif instruction == ACTION:
-                reply = self.action(params)
-            elif instruction == STATUS:
-                reply = self.status(params)
-            elif instruction == GUIDS:
-                reply = self.guids(params)
-            else:
-                error = "Invalid instruction %s" % instruction
-                self.log_error(error)
+        if not msg:
+            result = base64.b64encode("Invalid command line")
+            reply = "%d|%s" % (ERROR, result)
+        else:
+            params = msg.split("|")
+            instruction = int(params[0])
+            log_msg(self, params)
+            try:
+                if instruction == TRACE:
+                    reply = self.trace(params)
+                elif instruction == START:
+                    reply = self.start(params)
+                elif instruction == STOP:
+                    reply = self.stop(params)
+                elif instruction == SHUTDOWN:
+                    reply = self.shutdown(params)
+                elif instruction == CONFIGURE:
+                    reply = self.defer_configure(params)
+                elif instruction == CREATE:
+                    reply = self.defer_create(params)
+                elif instruction == CREATE_SET:
+                    reply = self.defer_create_set(params)
+                elif instruction == FACTORY_SET:
+                    reply = self.defer_factory_set(params)
+                elif instruction == CONNECT:
+                    reply = self.defer_connect(params)
+                elif instruction == CROSS_CONNECT:
+                    reply = self.defer_cross_connect(params)
+                elif instruction == ADD_TRACE:
+                    reply = self.defer_add_trace(params)
+                elif instruction == ADD_ADDRESS:
+                    reply = self.defer_add_address(params)
+                elif instruction == ADD_ROUTE:
+                    reply = self.defer_add_route(params)
+                elif instruction == DO_SETUP:
+                    reply = self.do_setup(params)
+                elif instruction == DO_CREATE:
+                    reply = self.do_create(params)
+                elif instruction == DO_CONNECT:
+                    reply = self.do_connect(params)
+                elif instruction == DO_CONFIGURE:
+                    reply = self.do_configure(params)
+                elif instruction == DO_CROSS_CONNECT:
+                    reply = self.do_cross_connect(params)
+                elif instruction == GET:
+                    reply = self.get(params)
+                elif instruction == SET:
+                    reply = self.set(params)
+                elif instruction == GET_ADDRESS:
+                    reply = self.get_address(params)
+                elif instruction == GET_ROUTE:
+                    reply = self.get_route(params)
+                elif instruction == ACTION:
+                    reply = self.action(params)
+                elif instruction == STATUS:
+                    reply = self.status(params)
+                elif instruction == GUIDS:
+                    reply = self.guids(params)
+                else:
+                    error = "Invalid instruction %s" % instruction
+                    self.log_error(error)
+                    result = base64.b64encode(error)
+                    reply = "%d|%s" % (ERROR, result)
+            except:
+                error = self.log_error()
                 result = base64.b64encode(error)
                 reply = "%d|%s" % (ERROR, result)
-        except:
-            error = self.log_error()
-            result = base64.b64encode(error)
-            reply = "%d|%s" % (ERROR, result)
         log_reply(self, reply)
         return reply
 
@@ -365,7 +421,8 @@ class TestbedInstanceServer(server.Server):
     def trace(self, params):
         guid = int(params[1])
         trace_id = params[2]
-        trace = self._testbed.trace(guid, trace_id)
+        attribute = base64.b64decode(params[3])
+        trace = self._testbed.trace(guid, trace_id, attribute)
         result = base64.b64encode(trace)
         return "%d|%s" % (OK, result)
 
@@ -489,6 +546,22 @@ class TestbedInstanceServer(server.Server):
         self._testbed.set(time, guid, name, value)
         return "%d|%s" % (OK, "")
 
+    def get_address(self, params):
+        guid = int(param[1])
+        index = int(param[2])
+        attribute = base64.b64decode(param[3])
+        value = self._testbed.get_address(guid, index, attribute)
+        result = base64.b64encode(str(value))
+        return "%d|%s" % (OK, result)
+
+    def get_route(self, params):
+        guid = int(param[1])
+        index = int(param[2])
+        attribute = base64.b64decode(param[3])
+        value = self._testbed.get_route(guid, index, attribute)
+        result = base64.b64encode(str(value))
+        return "%d|%s" % (OK, result)
+
     def action(self, params):
         time = params[1]
         guid = int(params[2])
@@ -510,36 +583,43 @@ class ExperimentControllerServer(server.Server):
 
     def post_daemonize(self):
         from nepi.core.execute import ExperimentController
-        self._controller = ExperimentController(self._experiment_xml)
+        self._controller = ExperimentController(self._experiment_xml, 
+            root_dir = self._root_dir)
 
     def reply_action(self, msg):
-        params = msg.split("|")
-        instruction = int(params[0])
-        log_msg(self, params)
-        try:
-            if instruction == XML:
-                reply = self.experiment_xml(params)
-            elif instruction == ACCESS:
-                reply = self.set_access_configuration(params)
-            elif instruction == TRACE:
-                reply = self.trace(params)
-            elif instruction == FINISHED:
-                reply = self.is_finished(params)
-            elif instruction == START:
-                reply = self.start(params)
-            elif instruction == STOP:
-                reply = self.stop(params)
-            elif instruction == SHUTDOWN:
-                reply = self.shutdown(params)
-            else:
-                error = "Invalid instruction %s" % instruction
-                self.log_error(error)
+        if not msg:
+            result = base64.b64encode("Invalid command line")
+            reply = "%d|%s" % (ERROR, result)
+        else:
+            params = msg.split("|")
+            instruction = int(params[0])
+            log_msg(self, params)
+            try:
+                if instruction == XML:
+                    reply = self.experiment_xml(params)
+                elif instruction == ACCESS:
+                    reply = self.set_access_configuration(params)
+                elif instruction == TRACE:
+                    reply = self.trace(params)
+                elif instruction == FINISHED:
+                    reply = self.is_finished(params)
+                elif instruction == START:
+                    reply = self.start(params)
+                elif instruction == STOP:
+                    reply = self.stop(params)
+                elif instruction == RECOVER:
+                    reply = self.recover(params)
+                elif instruction == SHUTDOWN:
+                    reply = self.shutdown(params)
+                else:
+                    error = "Invalid instruction %s" % instruction
+                    self.log_error(error)
+                    result = base64.b64encode(error)
+                    reply = "%d|%s" % (ERROR, result)
+            except:
+                error = self.log_error()
                 result = base64.b64encode(error)
                 reply = "%d|%s" % (ERROR, result)
-        except:
-            error = self.log_error()
-            result = base64.b64encode(error)
-            reply = "%d|%s" % (ERROR, result)
         log_reply(self, reply)
         return reply
 
@@ -575,7 +655,8 @@ class ExperimentControllerServer(server.Server):
         testbed_guid = int(params[1])
         guid = int(params[2])
         trace_id = params[3]
-        trace = self._controller.trace(testbed_guid, guid, trace_id)
+        attribute = base64.b64decode(params[4])
+        trace = self._controller.trace(testbed_guid, guid, trace_id, attribute)
         result = base64.b64encode(trace)
         return "%d|%s" % (OK, result)
 
@@ -593,6 +674,10 @@ class ExperimentControllerServer(server.Server):
         self._controller.stop()
         return "%d|%s" % (OK, "")
 
+    def recover(self, params):
+        self._controller.recover()
+        return "%d|%s" % (OK, "")
+
     def shutdown(self, params):
         self._controller.shutdown()
         return "%d|%s" % (OK, "")
@@ -855,6 +940,34 @@ class TestbedInstanceProxy(object):
             raise RuntimeError(text)
         return text
 
+    def get_address(self, guid, index, attribute):
+        msg = testbed_messages[GET_ADDRESS]
+        # avoid having "|" in this parameters
+        attribute = base64.b64encode(attribute)
+        msg = msg % (guid, index, attribute)
+        self._client.send_msg(msg)
+        reply = self._client.read_reply()
+        result = reply.split("|")
+        code = int(result[0])
+        text = base64.b64decode(result[1])
+        if code == ERROR:
+            raise RuntimeError(text)
+        return text
+
+    def get_route(self, guid, index, attribute):
+        msg = testbed_messages[GET_ROUTE]
+        # avoid having "|" in this parameters
+        attribute = base64.b64encode(attribute)
+        msg = msg % (guid, index, attribute)
+        self._client.send_msg(msg)
+        reply = self._client.read_reply()
+        result = reply.split("|")
+        code = int(result[0])
+        text = base64.b64decode(result[1])
+        if code == ERROR:
+            raise RuntimeError(text)
+        return text
+
     def action(self, time, guid, action):
         msg = testbed_messages[ACTION]
         msg = msg % (time, guid, action)
@@ -878,9 +991,10 @@ class TestbedInstanceProxy(object):
             raise RuntimeError(text)
         return int(text)
 
-    def trace(self, guid, trace_id):
+    def trace(self, guid, trace_id, attribute='value'):
         msg = testbed_messages[TRACE]
-        msg = msg % (guid, trace_id)
+        attribute = base64.b64encode(attribute)
+        msg = msg % (guid, trace_id, attribute)
         self._client.send_msg(msg)
         reply = self._client.read_reply()
         result = reply.split("|")
@@ -900,6 +1014,7 @@ class TestbedInstanceProxy(object):
         if code == ERROR:
             raise RuntimeError(text)
         self._client.send_stop()
+        self._client.read_reply() # wait for it
 
 class ExperimentControllerProxy(object):
     def __init__(self, root_dir, log_level, experiment_xml = None, 
@@ -967,9 +1082,10 @@ class ExperimentControllerProxy(object):
         if code == ERROR:
             raise RuntimeError(text)
 
-    def trace(self, testbed_guid, guid, trace_id):
+    def trace(self, testbed_guid, guid, trace_id, attribute='value'):
         msg = controller_messages[TRACE]
-        msg = msg % (testbed_guid, guid, trace_id)
+        attribute = base64.b64encode(attribute)
+        msg = msg % (testbed_guid, guid, trace_id, attribute)
         self._client.send_msg(msg)
         reply = self._client.read_reply()
         result = reply.split("|")
@@ -999,6 +1115,16 @@ class ExperimentControllerProxy(object):
         if code == ERROR:
             raise RuntimeError(text)
 
+    def recover(self):
+        msg = controller_messages[RECOVER]
+        self._client.send_msg(msg)
+        reply = self._client.read_reply()
+        result = reply.split("|")
+        code = int(result[0])
+        text =  base64.b64decode(result[1])
+        if code == ERROR:
+            raise RuntimeError(text)
+
     def is_finished(self, guid):
         msg = controller_messages[FINISHED]
         msg = msg % guid
@@ -1021,4 +1147,5 @@ class ExperimentControllerProxy(object):
         if code == ERROR:
             raise RuntimeError(text)
         self._client.send_stop()
+        self._client.read_reply() # wait for it