Merging ns-3 into nepi-3-dev
[nepi.git] / src / nepi / resources / linux / node.py
index 49f5342..f9ae091 100644 (file)
@@ -17,9 +17,9 @@
 #
 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
 
-from nepi.execution.attribute import Attribute, Flags
-from nepi.execution.resource import ResourceManager, clsinit, ResourceState, \
-        reschedule_delay
+from nepi.execution.attribute import Attribute, Flags, Types
+from nepi.execution.resource import ResourceManager, clsinit_copy, \
+        ResourceState, reschedule_delay
 from nepi.resources.linux import rpmfuncs, debfuncs 
 from nepi.util import sshfuncs, execfuncs
 from nepi.util.sshfuncs import ProcStatus
@@ -57,7 +57,7 @@ class OSType:
     UBUNTU = "ubuntu"
     DEBIAN = "debian"
 
-@clsinit
+@clsinit_copy
 class LinuxNode(ResourceManager):
     """
     .. class:: Class Args :
@@ -142,42 +142,57 @@ class LinuxNode(ResourceManager):
 
     """
     _rtype = "LinuxNode"
+    _help = "Controls Linux host machines ( either localhost or a host " \
+            "that can be accessed using a SSH key)"
+    _backend_type = "linux"
 
     @classmethod
     def _register_attributes(cls):
         hostname = Attribute("hostname", "Hostname of the machine",
-                flags = Flags.ExecReadOnly)
+                flags = Flags.Design)
 
         username = Attribute("username", "Local account username", 
                 flags = Flags.Credential)
 
-        port = Attribute("port", "SSH port", flags = Flags.ExecReadOnly)
+        port = Attribute("port", "SSH port", flags = Flags.Design)
         
         home = Attribute("home",
                 "Experiment home directory to store all experiment related files",
-                flags = Flags.ExecReadOnly)
+                flags = Flags.Design)
         
         identity = Attribute("identity", "SSH identity file",
                 flags = Flags.Credential)
         
         server_key = Attribute("serverKey", "Server public key", 
-                flags = Flags.ExecReadOnly)
+                flags = Flags.Design)
         
         clean_home = Attribute("cleanHome", "Remove all nepi files and directories "
                 " from node home folder before starting experiment", 
-                flags = Flags.ExecReadOnly)
+                type = Types.Bool,
+                default = False,
+                flags = Flags.Design)
 
         clean_experiment = Attribute("cleanExperiment", "Remove all files and directories " 
                 " from a previous same experiment, before the new experiment starts", 
-                flags = Flags.ExecReadOnly)
+                type = Types.Bool,
+                default = False,
+                flags = Flags.Design)
         
         clean_processes = Attribute("cleanProcesses", 
                 "Kill all running processes before starting experiment",
-                flags = Flags.ExecReadOnly)
+                type = Types.Bool,
+                default = False,
+                flags = Flags.Design)
         
         tear_down = Attribute("tearDown", "Bash script to be executed before " + \
                 "releasing the resource",
-                flags = Flags.ExecReadOnly)
+                flags = Flags.Design)
+
+        gateway_user = Attribute("gatewayUser", "Gateway account username",
+                flags = Flags.Design)
+
+        gateway = Attribute("gateway", "Hostname of the gateway machine",
+                flags = Flags.Design)
 
         cls._register_attribute(hostname)
         cls._register_attribute(username)
@@ -189,15 +204,23 @@ class LinuxNode(ResourceManager):
         cls._register_attribute(clean_experiment)
         cls._register_attribute(clean_processes)
         cls._register_attribute(tear_down)
+        cls._register_attribute(gateway_user)
+        cls._register_attribute(gateway)
 
     def __init__(self, ec, guid):
         super(LinuxNode, self).__init__(ec, guid)
         self._os = None
         # home directory at Linux host
         self._home_dir = ""
-        
-        # lock to avoid concurrency issues on methods used by applications 
-        self._lock = threading.Lock()
+
+        # lock to prevent concurrent applications on the same node,
+        # to execute commands at the same time. There are potential
+        # concurrency issues when using SSH to a same host from 
+        # multiple threads. There are also possible operational 
+        # issues, e.g. an application querying the existence 
+        # of a file or folder prior to its creation, and another 
+        # application creating the same file or folder in between.
+        self._node_lock = threading.Lock()
     
     def log_message(self, msg):
         return " guid %d - host %s - %s " % (self.guid, 
@@ -256,12 +279,7 @@ class LinuxNode(ResourceManager):
             self.error(msg)
             raise RuntimeError, msg
 
-        (out, err), proc = self.execute("cat /etc/issue", with_lock = True)
-
-        if err and proc.poll():
-            msg = "Error detecting OS "
-            self.error(msg, out, err)
-            raise RuntimeError, "%s - %s - %s" %( msg, out, err )
+        out = self.get_os()
 
         if out.find("Fedora release 8") == 0:
             self._os = OSType.FEDORA_8
@@ -269,6 +287,8 @@ class LinuxNode(ResourceManager):
             self._os = OSType.FEDORA_12
         elif out.find("Fedora release 14") == 0:
             self._os = OSType.FEDORA_14
+        elif out.find("Fedora release") == 0:
+            self._os = OSType.FEDORA
         elif out.find("Debian") == 0: 
             self._os = OSType.DEBIAN
         elif out.find("Ubuntu") ==0:
@@ -280,6 +300,23 @@ class LinuxNode(ResourceManager):
 
         return self._os
 
+    def get_os(self):
+        # The underlying SSH layer will sometimes return an empty
+        # output (even if the command was executed without errors).
+        # To work arround this, repeat the operation N times or
+        # until the result is not empty string
+        out = ""
+        try:
+            (out, err), proc = self.execute("cat /etc/issue", 
+                    with_lock = True,
+                    blocking = True)
+        except:
+            trace = traceback.format_exc()
+            msg = "Error detecting OS: %s " % trace
+            self.error(msg, out, err)
+        
+        return out
+
     @property
     def use_deb(self):
         return self.os in [OSType.DEBIAN, OSType.UBUNTU]
@@ -293,11 +330,9 @@ class LinuxNode(ResourceManager):
     def localhost(self):
         return self.get("hostname") in ['localhost', '127.0.0.7', '::1']
 
-    def provision(self):
+    def do_provision(self):
         # check if host is alive
         if not self.is_alive():
-            self.fail()
-            
             msg = "Deploy failed. Unresponsive node %s" % self.get("hostname")
             self.error(msg)
             raise RuntimeError, msg
@@ -322,61 +357,76 @@ class LinuxNode(ResourceManager):
         # Create experiment node home directory
         self.mkdir(self.node_home)
 
-        super(LinuxNode, self).provision()
+        super(LinuxNode, self).do_provision()
 
-    def deploy(self):
+    def do_deploy(self):
         if self.state == ResourceState.NEW:
-            try:
-                self.discover()
-                self.provision()
-            except:
-                self._state = ResourceState.FAILED
-                raise
+            self.info("Deploying node")
+            self.do_discover()
+            self.do_provision()
 
         # Node needs to wait until all associated interfaces are 
         # ready before it can finalize deployment
         from nepi.resources.linux.interface import LinuxInterface
-        ifaces = self.get_connected(LinuxInterface.rtype())
+        ifaces = self.get_connected(LinuxInterface.get_rtype())
         for iface in ifaces:
             if iface.state < ResourceState.READY:
                 self.ec.schedule(reschedule_delay, self.deploy)
                 return 
 
-        super(LinuxNode, self).deploy()
+        super(LinuxNode, self).do_deploy()
+
+    def do_release(self):
+        rms = self.get_connected()
+        for rm in rms:
+            # Node needs to wait until all associated RMs are released
+            # before it can be released
+            if rm.state != ResourceState.RELEASED:
+                self.ec.schedule(reschedule_delay, self.release)
+                return 
 
-    def release(self):
         tear_down = self.get("tearDown")
         if tear_down:
             self.execute(tear_down)
 
         self.clean_processes()
 
-        super(LinuxNode, self).release()
+        super(LinuxNode, self).do_release()
 
     def valid_connection(self, guid):
         # TODO: Validate!
         return True
 
-    def clean_processes(self, killer = False):
+    def clean_processes(self):
         self.info("Cleaning up processes")
-        
-        if killer:
-            # Hardcore kill
-            cmd = ("sudo -S killall python tcpdump || /bin/true ; " +
-                "sudo -S killall python tcpdump || /bin/true ; " +
-                "sudo -S kill $(ps -N -T -o pid --no-heading | grep -v $PPID | sort) || /bin/true ; " +
-                "sudo -S killall -u root || /bin/true ; " +
-                "sudo -S killall -u root || /bin/true ; ")
-        else:
-            # Be gentler...
+        if self.get("username") != 'root':
             cmd = ("sudo -S killall tcpdump || /bin/true ; " +
-                "sudo -S killall tcpdump || /bin/true ; " +
-                "sudo -S killall -u %s || /bin/true ; " % self.get("username") +
+                "sudo -S kill $(ps aux | grep '[n]epi' | awk '{print $2}') || /bin/true ; " +
                 "sudo -S killall -u %s || /bin/true ; " % self.get("username"))
+        else:
+            if self.state >= ResourceState.READY:
+                import pickle
+                pids = pickle.load(open("/tmp/save.proc", "rb"))
+                pids_temp = dict()
+                ps_aux = "ps aux |awk '{print $2,$11}'"
+                (out, err), proc = self.execute(ps_aux)
+                for line in out.strip().split("\n"):
+                    parts = line.strip().split(" ")
+                    pids_temp[parts[0]] = parts[1]
+                kill_pids = set(pids_temp.items()) - set(pids.items())
+                kill_pids = ' '.join(dict(kill_pids).keys())
+
+                cmd = ("killall tcpdump || /bin/true ; " +
+                    "kill $(ps aux | grep '[n]epi' | awk '{print $2}') || /bin/true ; " +
+                    "kill %s || /bin/true ; " % kill_pids)
+            else:
+                cmd = ("killall tcpdump || /bin/true ; " +
+                    "kill $(ps aux | grep '[n]epi' | awk '{print $2}') || /bin/true ; ")
 
         out = err = ""
-        (out, err), proc = self.execute(cmd, retry = 1, with_lock = True) 
-            
+        (out, err), proc = self.execute(cmd, retry = 1, with_lock = True)
+
     def clean_home(self):
         """ Cleans all NEPI related folders in the Linux host
         """
@@ -402,13 +452,10 @@ class LinuxNode(ResourceManager):
 
     def execute(self, command,
             sudo = False,
-            stdin = None, 
             env = None,
             tty = False,
             forward_x11 = False,
-            timeout = None,
             retry = 3,
-            err_on_timeout = True,
             connect_timeout = 30,
             strict_host_checking = False,
             persistent = True,
@@ -421,29 +468,31 @@ class LinuxNode(ResourceManager):
 
         if self.localhost:
             (out, err), proc = execfuncs.lexec(command, 
-                    user = user,
+                    user = self.get("username"), # still problem with localhost
                     sudo = sudo,
-                    stdin = stdin,
                     env = env)
         else:
             if with_lock:
-                with self._lock:
+                # If the execute command is blocking, we don't want to keep
+                # the node lock. This lock is used to avoid race conditions
+                # when creating the ControlMaster sockets. A more elegant
+                # solution is needed.
+                with self._node_lock:
                     (out, err), proc = sshfuncs.rexec(
                         command, 
                         host = self.get("hostname"),
                         user = self.get("username"),
                         port = self.get("port"),
+                        gwuser = self.get("gatewayUser"),
+                        gw = self.get("gateway"),
                         agent = True,
                         sudo = sudo,
-                        stdin = stdin,
                         identity = self.get("identity"),
                         server_key = self.get("serverKey"),
                         env = env,
                         tty = tty,
                         forward_x11 = forward_x11,
-                        timeout = timeout,
                         retry = retry,
-                        err_on_timeout = err_on_timeout,
                         connect_timeout = connect_timeout,
                         persistent = persistent,
                         blocking = blocking, 
@@ -455,17 +504,16 @@ class LinuxNode(ResourceManager):
                     host = self.get("hostname"),
                     user = self.get("username"),
                     port = self.get("port"),
+                    gwuser = self.get("gatewayUser"),
+                    gw = self.get("gateway"),
                     agent = True,
                     sudo = sudo,
-                    stdin = stdin,
                     identity = self.get("identity"),
                     server_key = self.get("serverKey"),
                     env = env,
                     tty = tty,
                     forward_x11 = forward_x11,
-                    timeout = timeout,
                     retry = retry,
-                    err_on_timeout = err_on_timeout,
                     connect_timeout = connect_timeout,
                     persistent = persistent,
                     blocking = blocking, 
@@ -495,7 +543,7 @@ class LinuxNode(ResourceManager):
                     sudo = sudo,
                     user = user) 
         else:
-            with self._lock:
+            with self._node_lock:
                 (out, err), proc = sshfuncs.rspawn(
                     command,
                     pidfile = pidfile,
@@ -508,6 +556,8 @@ class LinuxNode(ResourceManager):
                     host = self.get("hostname"),
                     user = self.get("username"),
                     port = self.get("port"),
+                    gwuser = self.get("gatewayUser"),
+                    gw = self.get("gateway"),
                     agent = True,
                     identity = self.get("identity"),
                     server_key = self.get("serverKey"),
@@ -520,12 +570,14 @@ class LinuxNode(ResourceManager):
         if self.localhost:
             pidtuple =  execfuncs.lgetpid(os.path.join(home, pidfile))
         else:
-            with self._lock:
+            with self._node_lock:
                 pidtuple = sshfuncs.rgetpid(
                     os.path.join(home, pidfile),
                     host = self.get("hostname"),
                     user = self.get("username"),
                     port = self.get("port"),
+                    gwuser = self.get("gatewayUser"),
+                    gw = self.get("gateway"),
                     agent = True,
                     identity = self.get("identity"),
                     server_key = self.get("serverKey")
@@ -537,12 +589,14 @@ class LinuxNode(ResourceManager):
         if self.localhost:
             status = execfuncs.lstatus(pid, ppid)
         else:
-            with self._lock:
+            with self._node_lock:
                 status = sshfuncs.rstatus(
                         pid, ppid,
                         host = self.get("hostname"),
                         user = self.get("username"),
                         port = self.get("port"),
+                        gwuser = self.get("gatewayUser"),
+                        gw = self.get("gateway"),
                         agent = True,
                         identity = self.get("identity"),
                         server_key = self.get("serverKey")
@@ -559,12 +613,14 @@ class LinuxNode(ResourceManager):
             if self.localhost:
                 (out, err), proc = execfuncs.lkill(pid, ppid, sudo)
             else:
-                with self._lock:
+                with self._node_lock:
                     (out, err), proc = sshfuncs.rkill(
                         pid, ppid,
                         host = self.get("hostname"),
                         user = self.get("username"),
                         port = self.get("port"),
+                        gwuser = self.get("gatewayUser"),
+                        gw = self.get("gateway"),
                         agent = True,
                         sudo = sudo,
                         identity = self.get("identity"),
@@ -579,10 +635,12 @@ class LinuxNode(ResourceManager):
                     recursive = True,
                     strict_host_checking = False)
         else:
-            with self._lock:
+            with self._node_lock:
                 (out, err), proc = sshfuncs.rcopy(
                     src, dst, 
                     port = self.get("port"),
+                    gwuser = self.get("gatewayUser"),
+                    gw = self.get("gateway"),
                     identity = self.get("identity"),
                     server_key = self.get("serverKey"),
                     recursive = True,
@@ -590,15 +648,20 @@ class LinuxNode(ResourceManager):
 
         return (out, err), proc
 
-
     def upload(self, src, dst, text = False, overwrite = True):
         """ Copy content to destination
 
-           src  content to copy. Can be a local file, directory or a list of files
+        src  string with the content to copy. Can be:
+            - plain text
+            - a string with the path to a local file
+            - a string with a semi-colon separeted list of local files
+            - a string with a local directory
 
-           dst  destination path on the remote host (remote is always self.host)
+        dst  string with destination path on the remote host (remote is 
+            always self.host)
 
-           text src is text input, it must be stored into a temp file before uploading
+        text src is text input, it must be stored into a temp file before 
+        uploading
         """
         # If source is a string input 
         f = None
@@ -611,11 +674,14 @@ class LinuxNode(ResourceManager):
             src = f.name
 
         # If dst files should not be overwritten, check that the files do not
-        # exits already 
+        # exits already
+        if isinstance(src, str):
+            src = map(str.strip, src.split(";"))
+    
         if overwrite == False:
             src = self.filter_existing_files(src, dst)
             if not src:
-                return ("", ""), None 
+                return ("", ""), None
 
         if not self.localhost:
             # Build destination as <user>@<server>:<path>
@@ -635,12 +701,7 @@ class LinuxNode(ResourceManager):
             src = "%s@%s:%s" % (self.get("username"), self.get("hostname"), src)
         return self.copy(src, dst)
 
-    def install_packages(self, packages, home, run_home = None):
-        """ Install packages in the Linux host.
-
-        'home' is the directory to upload the package installation script.
-        'run_home' is the directory from where to execute the script.
-        """
+    def install_packages_command(self, packages):
         command = ""
         if self.use_rpm:
             command = rpmfuncs.install_packages_command(self.os, packages)
@@ -651,6 +712,16 @@ class LinuxNode(ResourceManager):
             self.error(msg, self.os)
             raise RuntimeError, msg
 
+        return command
+
+    def install_packages(self, packages, home, run_home = None):
+        """ Install packages in the Linux host.
+
+        'home' is the directory to upload the package installation script.
+        'run_home' is the directory from where to execute the script.
+        """
+        command = self.install_packages_command(packages)
+
         run_home = run_home or home
 
         (out, err), proc = self.run_and_wait(command, run_home, 
@@ -819,8 +890,8 @@ class LinuxNode(ResourceManager):
         return self.upload(command, shfile, text = True, overwrite = overwrite)
 
     def format_environment(self, env, inline = False):
-        """Format environmental variables for command to be executed either
-        as an inline command
+        """ Formats the environment variables for a command to be executed
+        either as an inline command
         (i.e. export PYTHONPATH=src/..; export LALAL= ..;python script.py) or 
         as a bash script (i.e. export PYTHONPATH=src/.. \n export LALA=.. \n)
         """
@@ -835,8 +906,7 @@ class LinuxNode(ResourceManager):
     def check_errors(self, home, 
             ecodefile = "exitcode", 
             stderr = "stderr"):
-        """
-        Checks whether errors occurred while running a command.
+        """ Checks whether errors occurred while running a command.
         It first checks the exit code for the command, and only if the
         exit code is an error one it returns the error output.
 
@@ -868,7 +938,7 @@ class LinuxNode(ResourceManager):
         pid = ppid = None
         delay = 1.0
 
-        for i in xrange(4):
+        for i in xrange(2):
             pidtuple = self.getpid(home = home, pidfile = pidfile)
             
             if pidtuple:
@@ -889,7 +959,7 @@ class LinuxNode(ResourceManager):
 
     def wait_run(self, pid, ppid, trial = 0):
         """ wait for a remote process to finish execution """
-        start_delay = 1.0
+        delay = 1.0
 
         while True:
             status = self.status(pid, ppid)
@@ -920,42 +990,56 @@ class LinuxNode(ResourceManager):
             return True
 
         out = err = ""
+        msg = "Unresponsive host. Wrong answer. "
+
+        # The underlying SSH layer will sometimes return an empty
+        # output (even if the command was executed without errors).
+        # To work arround this, repeat the operation N times or
+        # until the result is not empty string
         try:
             (out, err), proc = self.execute("echo 'ALIVE'",
-                    retry = 5, 
+                    blocking = True,
                     with_lock = True)
+    
+            if out.find("ALIVE") > -1:
+                return True
         except:
             trace = traceback.format_exc()
-            msg = "Unresponsive host  %s " % err
-            self.error(msg, out, trace)
-            return False
+            msg = "Unresponsive host. Error reaching host: %s " % trace
 
-        if out.strip() == "ALIVE":
-            return True
-        else:
-            msg = "Unresponsive host "
-            self.error(msg, out, err)
-            return False
+        self.error(msg, out, err)
+        return False
 
     def find_home(self):
         """ Retrieves host home directory
         """
-        (out, err), proc = self.execute("echo ${HOME}", retry = 5, 
+        # The underlying SSH layer will sometimes return an empty
+        # output (even if the command was executed without errors).
+        # To work arround this, repeat the operation N times or
+        # until the result is not empty string
+        msg = "Impossible to retrieve HOME directory"
+        try:
+            (out, err), proc = self.execute("echo ${HOME}",
+                    blocking = True,
                     with_lock = True)
+    
+            if out.strip() != "":
+                self._home_dir =  out.strip()
+        except:
+            trace = traceback.format_exc()
+            msg = "Impossible to retrieve HOME directory %s" % trace
 
-        if proc.poll():
-            msg = "Imposible to retrieve HOME directory"
-            self.error(msg, out, err)
+        if not self._home_dir:
+            self.error(msg)
             raise RuntimeError, msg
 
-        self._home_dir =  out.strip()
-
     def filter_existing_files(self, src, dst):
         """ Removes files that already exist in the Linux host from src list
         """
         # construct a dictionary with { dst: src }
-        dests = dict(map(lambda x: ( os.path.join(dst, os.path.basename(x) ),  x ), 
-            src.strip().split(" ") ) ) if src.strip().find(" ") != -1 else dict({dst: src})
+        dests = dict(map(
+            lambda s: (os.path.join(dst, os.path.basename(s)), s ), s)) \
+                    if len(src) > 1 else dict({dst: src[0]})
 
         command = []
         for d in dests.keys():
@@ -970,7 +1054,7 @@ class LinuxNode(ResourceManager):
                 del dests[d]
 
         if not dests:
-            return ""
+            return []
 
-        return " ".join(dests.values())
+        return dests.values()