Ticket #29: Phasing out AccessConfiguration.
[nepi.git] / src / nepi / util / proxy.py
index 1be195a..7af0543 100644 (file)
@@ -4,7 +4,7 @@
 import base64
 from nepi.core.attributes import AttributesMap, Attribute
 from nepi.util import server, validation
-from nepi.util.constants import TIME_NOW, ATTR_NEPI_TESTBED_ENVIRONMENT_SETUP
+from nepi.util.constants import TIME_NOW, ATTR_NEPI_TESTBED_ENVIRONMENT_SETUP, DeploymentConfiguration as DC
 import getpass
 import cPickle
 import sys
@@ -18,7 +18,6 @@ ERROR = 1
 
 # PROTOCOL INSTRUCTION MESSAGES
 XML = 2 
-ACCESS  = 3
 TRACE   = 4
 FINISHED    = 5
 START   = 6
@@ -64,7 +63,6 @@ FLOAT   = 103
 # EXPERIMENT CONTROLER PROTOCOL MESSAGES
 controller_messages = dict({
     XML:    "%d" % XML,
-    ACCESS: "%d|%s" % (ACCESS, "%d|%s|%s|%s|%s|%d|%s|%r|%s"),
     TRACE:  "%d|%s" % (TRACE, "%d|%d|%s|%s"),
     FINISHED:   "%d|%s" % (FINISHED, "%d"),
     START:  "%d" % START,
@@ -114,7 +112,6 @@ instruction_text = dict({
     OK:     "OK",
     ERROR:  "ERROR",
     XML:    "XML",
-    ACCESS: "ACCESS",
     TRACE:  "TRACE",
     FINISHED:   "FINISHED",
     START:  "START",
@@ -192,84 +189,44 @@ def log_reply(server, reply):
             code_txt, txt))
 
 def to_server_log_level(log_level):
-    return server.DEBUG_LEVEL \
-            if log_level == AccessConfiguration.DEBUG_LEVEL \
-                else server.ERROR_LEVEL
+    return (
+        server.DEBUG_LEVEL
+            if log_level == DC.DEBUG_LEVEL 
+        else server.ERROR_LEVEL
+    )
 
 def get_access_config_params(access_config):
-    root_dir = access_config.get_attribute_value("rootDirectory")
-    log_level = access_config.get_attribute_value("logLevel")
+    root_dir = access_config.get_attribute_value(DC.ROOT_DIRECTORY)
+    log_level = access_config.get_attribute_value(DC.LOG_LEVEL)
     log_level = to_server_log_level(log_level)
     user = host = port = agent = None
-    communication = access_config.get_attribute_value("communication")
-    if communication == AccessConfiguration.ACCESS_SSH:
-        user = access_config.get_attribute_value("user")
-        host = access_config.get_attribute_value("host")
-        port = access_config.get_attribute_value("port")
-        agent = access_config.get_attribute_value("useAgent")
-    return (root_dir, log_level, user, host, port, agent)
+    communication = access_config.get_attribute_value(DC.DEPLOYMENT_COMMUNICATION)
+    environment_setup = (
+        access_config.get_attribute_value(DC.DEPLOYMENT_ENVIRONMENT_SETUP)
+        if access_config.has_attribute(DC.DEPLOYMENT_ENVIRONMENT_SETUP)
+        else None
+    )
+    if communication == DC.ACCESS_SSH:
+        user = access_config.get_attribute_value(DC.DEPLOYMENT_USER)
+        host = access_config.get_attribute_value(DC.DEPLOYMENT_HOST)
+        port = access_config.get_attribute_value(DC.DEPLOYMENT_PORT)
+        agent = access_config.get_attribute_value(DC.USE_AGENT)
+    return (root_dir, log_level, user, host, port, agent, environment_setup)
 
 class AccessConfiguration(AttributesMap):
-    MODE_SINGLE_PROCESS = "SINGLE"
-    MODE_DAEMON = "DAEMON"
-    ACCESS_SSH = "SSH"
-    ACCESS_LOCAL = "LOCAL"
-    ERROR_LEVEL = "Error"
-    DEBUG_LEVEL = "Debug"
-
-    def __init__(self):
+    def __init__(self, params = None):
         super(AccessConfiguration, self).__init__()
-        self.add_attribute(name = "mode",
-                help = "Instance execution mode",
-                type = Attribute.ENUM,
-                value = AccessConfiguration.MODE_SINGLE_PROCESS,
-                allowed = [AccessConfiguration.MODE_DAEMON,
-                    AccessConfiguration.MODE_SINGLE_PROCESS],
-                validation_function = validation.is_enum)
-        self.add_attribute(name = "communication",
-                help = "Instance communication mode",
-                type = Attribute.ENUM,
-                value = AccessConfiguration.ACCESS_LOCAL,
-                allowed = [AccessConfiguration.ACCESS_LOCAL,
-                    AccessConfiguration.ACCESS_SSH],
-                validation_function = validation.is_enum)
-        self.add_attribute(name = "host",
-                help = "Host where the testbed will be executed",
-                type = Attribute.STRING,
-                value = "localhost",
-                validation_function = validation.is_string)
-        self.add_attribute(name = "user",
-                help = "User on the Host to execute the testbed",
-                type = Attribute.STRING,
-                value = getpass.getuser(),
-                validation_function = validation.is_string)
-        self.add_attribute(name = "port",
-                help = "Port on the Host",
-                type = Attribute.INTEGER,
-                value = 22,
-                validation_function = validation.is_integer)
-        self.add_attribute(name = "rootDirectory",
-                help = "Root directory for storing process files",
-                type = Attribute.STRING,
-                value = ".",
-                validation_function = validation.is_string) # TODO: validation.is_path
-        self.add_attribute(name = "useAgent",
-                help = "Use -A option for forwarding of the authentication agent, if ssh access is used", 
-                type = Attribute.BOOL,
-                value = False,
-                validation_function = validation.is_bool)
-        self.add_attribute(name = "logLevel",
-                help = "Log level for instance",
-                type = Attribute.ENUM,
-                value = AccessConfiguration.ERROR_LEVEL,
-                allowed = [AccessConfiguration.ERROR_LEVEL,
-                    AccessConfiguration.DEBUG_LEVEL],
-                validation_function = validation.is_enum)
-        self.add_attribute(name = "recover",
-                help = "Do not intantiate testbeds, rather, reconnect to already-running instances. Used to recover from a dead controller.", 
-                type = Attribute.BOOL,
-                value = False,
-                validation_function = validation.is_bool)
+        
+        from nepi.core.metadata import Metadata
+        
+        for _,attr_info in Metadata.DEPLOYMENT_ATTRIBUTES:
+            self.add_attribute(**attr_info)
+        
+        if params:
+            for attr_name, attr_value in params.iteritems():
+                parser = Attribute.type_parsers[self.get_attribute_type(attr_name)]
+                attr_value = parser(attr_value)
+                self.set_attribute_value(attr_name, attr_value)
 
 class TempDir(object):
     def __init__(self):
@@ -284,19 +241,19 @@ class PermDir(object):
 
 def create_controller(xml, access_config = None):
     mode = None if not access_config \
-            else access_config.get_attribute_value("mode")
+            else access_config.get_attribute_value(DC.DEPLOYMENT_MODE)
     launch = True if not access_config \
-            else not access_config.get_attribute_value("recover")
-    if not mode or mode == AccessConfiguration.MODE_SINGLE_PROCESS:
+            else not access_config.get_attribute_value(DC.RECOVER)
+    if not mode or mode == DC.MODE_SINGLE_PROCESS:
         if not launch:
             raise ValueError, "Unsupported instantiation mode: %s with lanch=False" % (mode,)
         
         from nepi.core.execute import ExperimentController
         
-        if not access_config or not access_config.has_attribute("rootDirectory"):
+        if not access_config or not access_config.has_attribute(DC.ROOT_DIRECTORY):
             root_dir = TempDir()
         else:
-            root_dir = PermDir(access_config.get_attribute_value("rootDirectory"))
+            root_dir = PermDir(access_config.get_attribute_value(DC.ROOT_DIRECTORY))
         controller = ExperimentController(xml, root_dir.path)
         
         # inject reference to temporary dir, so that it gets cleaned
@@ -304,29 +261,26 @@ def create_controller(xml, access_config = None):
         controller._tempdir = root_dir
         
         return controller
-    elif mode == AccessConfiguration.MODE_DAEMON:
-        (root_dir, log_level, user, host, port, agent) = \
+    elif mode == DC.MODE_DAEMON:
+        (root_dir, log_level, user, host, port, agent, environment_setup) = \
                 get_access_config_params(access_config)
         return ExperimentControllerProxy(root_dir, log_level,
                 experiment_xml = xml, host = host, port = port, user = user, 
-                agent = agent, launch = launch)
+                agent = agent, launch = launch,
+                environment_setup = environment_setup)
     raise RuntimeError("Unsupported access configuration '%s'" % mode)
 
 def create_testbed_controller(testbed_id, testbed_version, access_config):
     mode = None if not access_config \
-            else access_config.get_attribute_value("mode")
+            else access_config.get_attribute_value(DC.DEPLOYMENT_MODE)
     launch = True if not access_config \
-            else not access_config.get_attribute_value("recover")
-    environment_setup = access_config \
-            and access_config.has_attribute(ATTR_NEPI_TESTBED_ENVIRONMENT_SETUP) \
-            and access_config.get_attribute_value(ATTR_NEPI_TESTBED_ENVIRONMENT_SETUP) \
-            or ""
-    if not mode or mode == AccessConfiguration.MODE_SINGLE_PROCESS:
+            else not access_config.get_attribute_value(DC.RECOVER)
+    if not mode or mode == DC.MODE_SINGLE_PROCESS:
         if not launch:
             raise ValueError, "Unsupported instantiation mode: %s with lanch=False" % (mode,)
         return  _build_testbed_controller(testbed_id, testbed_version)
-    elif mode == AccessConfiguration.MODE_DAEMON:
-        (root_dir, log_level, user, host, port, agent) = \
+    elif mode == DC.MODE_DAEMON:
+        (root_dir, log_level, user, host, port, agent, environment_setup) = \
                 get_access_config_params(access_config)
         return TestbedControllerProxy(root_dir, log_level, testbed_id = testbed_id, 
                 testbed_version = testbed_version, host = host, port = port,
@@ -663,8 +617,6 @@ class ExperimentControllerServer(server.Server):
             try:
                 if instruction == XML:
                     reply = self.experiment_xml(params)
-                elif instruction == ACCESS:
-                    reply = self.set_access_configuration(params)
                 elif instruction == TRACE:
                     reply = self.trace(params)
                 elif instruction == FINISHED:
@@ -697,30 +649,7 @@ class ExperimentControllerServer(server.Server):
         xml = self._controller.experiment_xml
         result = base64.b64encode(xml)
         return "%d|%s" % (OK, result)
-
-    def set_access_configuration(self, params):
-        testbed_guid = int(params[1])
-        mode = params[2]
-        communication = params[3]
-        host = params[4]
-        user = params[5]
-        port = int(params[6])
-        root_dir = params[7]
-        use_agent = params[8] == "True"
-        log_level = params[9]
-        access_config = AccessConfiguration()
-        access_config.set_attribute_value("mode", mode)
-        access_config.set_attribute_value("communication", communication)
-        access_config.set_attribute_value("host", host)
-        access_config.set_attribute_value("user", user)
-        access_config.set_attribute_value("port", port)
-        access_config.set_attribute_value("rootDirectory", root_dir)
-        access_config.set_attribute_value("useAgent", use_agent)
-        access_config.set_attribute_value("logLevel", log_level)
-        self._controller.set_access_configuration(testbed_guid, 
-                access_config)
-        return "%d|%s" % (OK, "")
-
+        
     def trace(self, params):
         testbed_guid = int(params[1])
         guid = int(params[2])
@@ -1225,26 +1154,6 @@ class ExperimentControllerProxy(object):
             raise RuntimeError(text)
         return text
 
-    def set_access_configuration(self, testbed_guid, access_config):
-        mode = access_config.get_attribute_value("mode")
-        communication = access_config.get_attribute_value("communication")
-        host = access_config.get_attribute_value("host")
-        user = access_config.get_attribute_value("user")
-        port = access_config.get_attribute_value("port")
-        root_dir = access_config.get_attribute_value("rootDirectory")
-        use_agent = access_config.get_attribute_value("useAgent")
-        log_level = access_config.get_attribute_value("logLevel")
-        msg = controller_messages[ACCESS]
-        msg = msg % (testbed_guid, mode, communication, host, user, port, 
-                root_dir, use_agent, log_level)
-        self._client.send_msg(msg)
-        reply = self._client.read_reply()
-        result = reply.split("|")
-        code = int(result[0])
-        text =  base64.b64decode(result[1])
-        if code == ERROR:
-            raise RuntimeError(text)
-
     def trace(self, testbed_guid, guid, trace_id, attribute='value'):
         msg = controller_messages[TRACE]
         attribute = base64.b64encode(attribute)