SSHApi functionality migrated to LinuxNode
authorAlina Quereilhac <alina.quereilhac@inria.fr>
Wed, 24 Apr 2013 20:13:51 +0000 (22:13 +0200)
committerAlina Quereilhac <alina.quereilhac@inria.fr>
Wed, 24 Apr 2013 20:13:51 +0000 (22:13 +0200)
src/neco/execution/attribute.py
src/neco/resources/linux/application.py
src/neco/resources/linux/node.py
src/neco/resources/linux/ssh_api.py [deleted file]
src/neco/util/sshfuncs.py
test/execution/resource.py
test/resources/linux/node.py [changed mode: 0755->0644]
test/resources/linux/ssh_api.py [deleted file]
test/util/sshfuncs.py

index a24e76d..48d7848 100644 (file)
@@ -18,12 +18,12 @@ class Flags:
 
 class Attribute(object):
     def __init__(self, name, help, type = Types.String,
-            flags = Flags.NoFlags, default_value = None):
+            flags = Flags.NoFlags, default = None):
         self._name = name
         self._help = help
         self._type = type
         self._flags = flags
-        self._default = self._value = default_value
+        self._default = self._value = default
 
     @property
     def name(self):
@@ -31,7 +31,7 @@ class Attribute(object):
 
     @property
     def default(self):
-        return self._default_value
+        return self._default
 
     @property
     def type(self):
index 2005617..0c2f0ec 100644 (file)
@@ -123,6 +123,7 @@ class LinuxApplication(ResourceManager):
         self._pid, self._ppid = self.api.checkpid(self.app_home)
 
     def stop(self):
+        # Kill
         self._state = ResourceState.STOPPED
 
     def release(self):
index 0424f1e..56eb6e2 100644 (file)
@@ -1,8 +1,18 @@
 from neco.execution.attribute import Attribute, Flags
 from neco.execution.resource import ResourceManager, clsinit, ResourceState
-from neco.resources.linux.ssh_api import SSHApiFactory
+from neco.resources.linux import rpmfuncs, debfuncs 
+from neco.util import sshfuncs, execfuncs 
 
+import collections
 import logging
+import os
+import random
+import re
+import tempfile
+import time
+import threading
+
+# TODO: Verify files and dirs exists already
 
 @clsinit
 class LinuxNode(ResourceManager):
@@ -11,95 +21,562 @@ class LinuxNode(ResourceManager):
     @classmethod
     def _register_attributes(cls):
         hostname = Attribute("hostname", "Hostname of the machine")
+
         username = Attribute("username", "Local account username", 
                 flags = Flags.Credential)
+
+        port = Attribute("port", "SSH port", flags = Flags.Credential)
+        
+        home = Attribute("home", 
+                "Experiment home directory to store all experiment related files")
+        
         identity = Attribute("identity", "SSH identity file",
                 flags = Flags.Credential)
-        clean_home = Attribute("cleanHome", "Remove all files and directories 
-                from home folder before starting experiment", 
+        
+        server_key = Attribute("serverKey", "Server public key", 
+                flags = Flags.Credential)
+        
+        clean_home = Attribute("cleanHome", "Remove all files and directories " + \
+                " from home folder before starting experiment", 
                 flags = Flags.ReadOnly)
+        
         clean_processes = Attribute("cleanProcesses", 
                 "Kill all running processes before starting experiment", 
                 flags = Flags.ReadOnly)
-        tear_down = Attribute("tearDown", "Bash script to be executed before
-                releasing the resource", flags = Flags.ReadOnly)
+        
+        tear_down = Attribute("tearDown", "Bash script to be executed before " + \
+                "releasing the resource", flags = Flags.ReadOnly)
 
         cls._register_attribute(hostname)
         cls._register_attribute(username)
+        cls._register_attribute(port)
+        cls._register_attribute(home)
         cls._register_attribute(identity)
+        cls._register_attribute(server_key)
         cls._register_attribute(clean_home)
         cls._register_attribute(clean_processes)
         cls._register_attribute(tear_down)
 
     def __init__(self, ec, guid):
         super(LinuxNode, self).__init__(ec, guid)
+        self._os = None
+        self._home = "nepi-exp-%s" % os.urandom(8).encode('hex')
+        
+        # lock to avoid concurrency issues on methods used by applications 
+        self._lock = threading.Lock()
 
-        self._logger = logging.getLogger("neco.linux.Node.%d" % guid)
+        self._logger = logging.getLogger("neco.linux.Node.%d " % self.guid)
+
+    @property
+    def home(self):
+        home = self.get("home")
+        if home and not home.startswith("nepi-"):
+            home = "nepi-" + home
+        return home or self._home
+
+    @property
+    def os(self):
+        if self._os:
+            return self._os
+
+        if (not self.get("hostname") or not self.get("username")):
+            msg = "Can't resolve OS for guid %d. Insufficient data." % self.guid
+            self.logger.error(msg)
+            raise RuntimeError, msg
+
+        (out, err), proc = self.execute("cat /etc/issue")
+
+        if err and proc.poll():
+            msg = "Error detecting OS for host %s. err: %s " % (self.get("hostname"), err)
+            self.logger.error(msg)
+            raise RuntimeError, msg
+
+        if out.find("Fedora release 12") == 0:
+            self._os = "f12"
+        elif out.find("Fedora release 14") == 0:
+            self._os = "f14"
+        elif out.find("Debian") == 0: 
+            self._os = "debian"
+        elif out.find("Ubuntu") ==0:
+            self._os = "ubuntu"
+        else:
+            msg = "Unsupported OS %s for host %s" % (out, self.get("hostname"))
+            self.logger.error(msg)
+            raise RuntimeError, msg
+
+        return self._os
+
+    @property
+    def localhost(self):
+        return self.get("hostname") in ['localhost', '127.0.0.7', '::1']
 
     def provision(self, filters = None):
-        if not self.api.is_alive():
+        if not self.is_alive():
             self._state = ResourceState.FAILED
             self.logger.error("Deploy failed. Unresponsive node")
             return
-        
+
+    def deploy(self):
+        self.provision()
+
         if self.get("cleanProcesses"):
-            self._clean_processes()
+            self.clean_processes()
 
         if self.get("cleanHome"):
-            # self._clean_home() -> this is dangerous
+            # self.clean_home() -> this is dangerous
             pass
 
-    def deploy(self):
-        self.provision()
+        self.mkdir(self.home)
+
         super(LinuxNode, self).deploy()
 
     def release(self):
         tear_down = self.get("tearDown")
         if tear_down:
-            self.api.execute(tear_down)
+            self.execute(tear_down)
 
         super(LinuxNode, self).release()
 
-    def _validate_connection(self, guid):
+    def validate_connection(self, guid):
         # TODO: Validate!
         return True
 
-    @property
-    def api(self):
-        host = self.get("host")
-        user = self.get("user")
-        identity = self.get("identity")
-        return SSHApiFactory.get_api(host, user, identity)
-
-    def _clean_processes(self):
-        hostname = self.get("hostname")
-        self.logger.info("Cleaning up processes on %s", hostname)
+    def clean_processes(self):
+        self.logger.info("Cleaning up processes")
         
-        cmds = [
-            "sudo -S killall python tcpdump || /bin/true ; "
-            "sudo -S killall python tcpdump || /bin/true ; "
-            "sudo -S kill $(ps -N -T -o pid --no-heading | grep -v $PPID | sort) || /bin/true ",
-            "sudo -S killall -u root || /bin/true ",
-            "sudo -S killall -u root || /bin/true ",
-        ]
-
-        api = self.api
-        for cmd in cmds:
-            out, err = api.execute(cmd)
-            if err:
-                self.logger.error(err)
+        cmd = ("sudo -S killall python tcpdump || /bin/true ; " +
+            "sudo -S killall python tcpdump || /bin/true ; " +
+            "sudo -S kill $(ps -N -T -o pid --no-heading | grep -v $PPID | sort) || /bin/true ; " +
+            "sudo -S killall -u root || /bin/true ; " +
+            "sudo -S killall -u root || /bin/true ; ")
+
+        out = err = ""
+        with self._lock:
+           (out, err), proc = self.run_and_wait(cmd, self.home, 
+                pidfile = "cppid",
+                stdout = "cplog", 
+                stderr = "cperr", 
+                raise_on_error = True)
+
+        return (out, err)   
             
-    def _clean_home(self):
-        hostname = self.get("hostname")
-        self.logger.info("Cleaning up home on %s", hostname)
+    def clean_home(self):
+        self.logger.info("Cleaning up home")
+
+        cmd = "find . -maxdepth 1  \( -name '.cache' -o -name '.local' -o -name '.config' -o -name 'nepi-*' \) -execdir rm -rf {} + "
+
+        out = err = ""
+        with self._lock:
+            (out, err), proc = self.run_and_wait(cmd, self.home,
+                pidfile = "chpid",
+                stdout = "chlog", 
+                stderr = "cherr", 
+                raise_on_error = True)
+        
+        return (out, err)   
+
+    def upload(self, src, dst):
+        """ Copy content to destination
+
+           src  content to copy. Can be a local file, directory or text input
+
+           dst  destination path on the remote host (remote is always self.host)
+        """
+        # If source is a string input 
+        if not os.path.isfile(src):
+            # src is text input that should be uploaded as file
+            # create a temporal file with the content to upload
+            f = tempfile.NamedTemporaryFile(delete=False)
+            f.write(src)
+            f.close()
+            src = f.name
+
+        if not self.localhost:
+            # Build destination as <user>@<server>:<path>
+            dst = "%s@%s:%s" % (self.get("username"), self.get("hostname"), dst)
+
+        return self.copy(src, dst)
+
+    def download(self, src, dst):
+        if not self.localhost:
+            # Build destination as <user>@<server>:<path>
+            src = "%s@%s:%s" % (self.get("username"), self.get("hostname"), src)
+        return self.copy(src, dst)
+
+    def install_packages(self, packages):
+        cmd = ""
+        if self.os in ["f12", "f14"]:
+            cmd = rpmfuncs.install_packages_command(self.os, packages)
+        elif self.os in ["debian", "ubuntu"]:
+            cmd = debfuncs.install_packages_command(self.os, packages)
+        else:
+            msg = "Error installing packages. OS not known for host %s " % (
+                    self.get("hostname"))
+            self.logger.error(msg)
+            raise RuntimeError, msg
 
-         cmds = [
-            "find . -maxdepth 1 ! -name '.bash*' ! -name '.' -execdir rm -rf {} + "
-        ]
+        out = err = ""
+        with self._lock:
+            (out, err), proc = self.run_and_wait(cmd, self.home, 
+                pidfile = "instpkgpid",
+                stdout = "instpkglog", 
+                stderr = "instpkgerr", 
+                raise_on_error = True)
 
-        api = self.api
-        for cmd in cmds:
-            out, err = api.execute(cmd)
+        return (out, err), proc 
+
+    def remove_packages(self, packages):
+        cmd = ""
+        if self.os in ["f12", "f14"]:
+            cmd = rpmfuncs.remove_packages_command(self.os, packages)
+        elif self.os in ["debian", "ubuntu"]:
+            cmd = debfuncs.remove_packages_command(self.os, packages)
+        else:
+            msg = "Error removing packages. OS not known for host %s " % (
+                    self.get("hostname"))
+            self.logger.error(msg)
+            raise RuntimeError, msg
+
+        out = err = ""
+        with self._lock:
+            (out, err), proc = self.run_and_wait(cmd, self.home, 
+                pidfile = "rmpkgpid",
+                stdout = "rmpkglog", 
+                stderr = "rmpkgerr", 
+                raise_on_error = True)
+         
+        return (out, err), proc 
+
+    def mkdir(self, path, clean = False):
+        if clean:
+            self.rmdir(path)
+
+        return self.execute("mkdir -p %s" % path)
+
+    def rmdir(self, path):
+        return self.execute("rm -rf %s" % path)
+
+    def run_and_wait(self, command, 
+            home = ".", 
+            pidfile = "pid", 
+            stdin = None, 
+            stdout = 'stdout', 
+            stderr = 'stderr', 
+            sudo = False,
+            raise_on_error = False):
+
+        (out, err), proc = self.run(command, home, 
+                pidfile = pidfile,
+                stdin = stdin, 
+                stdout = stdout, 
+                stderr = stderr, 
+                sudo = sudo)
+
+        if proc.poll() and err:
+            msg = " Failed to run command %s on host %s" % (
+                    command, self.get("hostname"))
+            self.logger.error(msg)
+            if raise_on_error:
+                raise RuntimeError, msg
+        
+        pid, ppid = self.wait_pid(
+                home = home, 
+                pidfile = pidfile, 
+                raise_on_error = raise_on_error)
+
+        self.wait_run(pid, ppid)
+        
+        (out, err), proc = self.check_run_error(home, stderr)
+
+        if err or out:
+            msg = "Error while running command %s on host %s. error output: %s" % (
+                    command, self.get("hostname"), out)
             if err:
-                self.logger.error(err)
+                msg += " . err: %s" % err
+
+            self.logger.error(msg)
+            if raise_on_error:
+                raise RuntimeError, msg
+        
+        return (out, err), proc
+    def wait_pid(self, home = ".", pidfile = "pid", raise_on_error = False):
+        pid = ppid = None
+        delay = 1.0
+        for i in xrange(5):
+            pidtuple = self.checkpid(home = home, pidfile = pidfile)
+            
+            if pidtuple:
+                pid, ppid = pidtuple
+                break
+            else:
+                time.sleep(delay)
+                delay = min(30,delay*1.2)
+        else:
+            msg = " Failed to get pid for pidfile %s/%s on host %s" % (
+                    home, pidfile, self.get("hostname"))
+            self.logger.error(msg)
+            if raise_on_error:
+                raise RuntimeError, msg
+
+        return pid, ppid
+
+    def wait_run(self, pid, ppid, trial = 0):
+        delay = 1.0
+        first = True
+        bustspin = 0
+
+        while True:
+            status = self.status(pid, ppid)
+            
+            if status is sshfuncs.FINISHED:
+                break
+            elif status is not sshfuncs.RUNNING:
+                bustspin += 1
+                time.sleep(delay*(5.5+random.random()))
+                if bustspin > 12:
+                    break
+            else:
+                if first:
+                    first = False
+
+                time.sleep(delay*(0.5+random.random()))
+                delay = min(30,delay*1.2)
+                bustspin = 0
+
+    def check_run_error(self, home, stderr = 'stderr'):
+        (out, err), proc = self.execute("cat %s" % 
+                os.path.join(home, stderr))
+        return (out, err), proc
+
+    def check_run_output(self, home, stdout = 'stdout'):
+        (out, err), proc = self.execute("cat %s" % 
+                os.path.join(home, stdout))
+        return (out, err), proc
+
+    def is_alive(self):
+        if self.localhost:
+            return True
+
+        out = err = ""
+        try:
+            (out, err), proc = self.execute("echo 'ALIVE'")
+        except:
+            import traceback
+            trace = traceback.format_exc()
+            self.logger.warn("Unresponsive host %s. got:\n out: %s err: %s\n traceback: %s", 
+                    self.get("hostname"), out, err, trace)
+            return False
+
+        if out.strip().startswith('ALIVE'):
+            return True
+        else:
+            self.logger.warn("Unresponsive host %s. got:\n%s%s", 
+                    self.get("hostname"), out, err)
+            return False
+
+            # TODO!
+            #if self.check_bad_host(out,err):
+            #    self.blacklist()
+
+    def copy(self, src, dst):
+        if self.localhost:
+            (out, err), proc =  execfuncs.lcopy(source, dest, 
+                    recursive = True)
+        else:
+            (out, err), proc = self.safe_retry(sshfuncs.rcopy)(
+                src, dst, 
+                port = self.get("port"),
+                identity = self.get("identity"),
+                server_key = self.get("serverKey"),
+                recursive = True)
+
+        return (out, err), proc
+
+    def execute(self, command,
+            sudo = False,
+            stdin = None, 
+            env = None,
+            tty = False,
+            forward_x11 = False,
+            timeout = None,
+            retry = 0,
+            err_on_timeout = True,
+            connect_timeout = 30,
+            persistent = True
+            ):
+        """ Notice that this invocation will block until the
+        execution finishes. If this is not the desired behavior,
+        use 'run' instead."""
+
+        if self.localhost:
+            (out, err), proc = execfuncs.lexec(command, 
+                    user = user,
+                    sudo = sudo,
+                    stdin = stdin,
+                    env = env)
+        else:
+            (out, err), proc = self.safe_retry(sshfuncs.rexec)(
+                    command, 
+                    host = self.get("hostname"),
+                    user = self.get("username"),
+                    port = self.get("port"),
+                    agent = True,
+                    sudo = sudo,
+                    stdin = stdin,
+                    identity = self.get("identity"),
+                    server_key = self.get("serverKey"),
+                    env = env,
+                    tty = tty,
+                    forward_x11 = forward_x11,
+                    timeout = timeout,
+                    retry = retry,
+                    err_on_timeout = err_on_timeout,
+                    connect_timeout = connect_timeout,
+                    persistent = persistent
+                    )
+
+        return (out, err), proc
+
+    def run(self, command, 
+            home = None,
+            create_home = True,
+            pidfile = "pid",
+            stdin = None, 
+            stdout = 'stdout', 
+            stderr = 'stderr', 
+            sudo = False):
+
+        self.logger.info("Running %s", command)
+        
+        if self.localhost:
+            (out, err), proc = execfuncs.lspawn(command, pidfile, 
+                    stdout = stdout, 
+                    stderr = stderr, 
+                    stdin = stdin, 
+                    home = home, 
+                    create_home = create_home, 
+                    sudo = sudo,
+                    user = user) 
+        else:
+            # Start process in a "daemonized" way, using nohup and heavy
+            # stdin/out redirection to avoid connection issues
+            (out,err), proc = self.safe_retry(sshfuncs.rspawn)(
+                command,
+                pidfile = pidfile,
+                home = home,
+                create_home = create_home,
+                stdin = stdin if stdin is not None else '/dev/null',
+                stdout = stdout if stdout else '/dev/null',
+                stderr = stderr if stderr else '/dev/null',
+                sudo = sudo,
+                host = self.get("hostname"),
+                user = self.get("username"),
+                port = self.get("port"),
+                agent = True,
+                identity = self.get("identity"),
+                server_key = self.get("serverKey")
+                )
+
+        return (out, err), proc
+
+    def checkpid(self, home = ".", pidfile = "pid"):
+        if self.localhost:
+            pidtuple =  execfuncs.lcheckpid(os.path.join(home, pidfile))
+        else:
+            pidtuple = sshfuncs.rcheckpid(
+                os.path.join(home, pidfile),
+                host = self.get("hostname"),
+                user = self.get("username"),
+                port = self.get("port"),
+                agent = True,
+                identity = self.get("identity"),
+                server_key = self.get("serverKey")
+                )
+        
+        return pidtuple
+    
+    def status(self, pid, ppid):
+        if self.localhost:
+            status = execfuncs.lstatus(pid, ppid)
+        else:
+            status = sshfuncs.rstatus(
+                    pid, ppid,
+                    host = self.get("hostname"),
+                    user = self.get("username"),
+                    port = self.get("port"),
+                    agent = True,
+                    identity = self.get("identity"),
+                    server_key = self.get("serverKey")
+                    )
+           
+        return status
+    
+    def kill(self, pid, ppid, sudo = False):
+        out = err = ""
+        proc = None
+        status = self.status(pid, ppid)
+
+        if status == sshfuncs.RUNNING:
+            if self.localhost:
+                (out, err), proc = execfuncs.lkill(pid, ppid, sudo)
+            else:
+                (out, err), proc = self.safe_retry(sshfuncs.rkill)(
+                    pid, ppid,
+                    host = self.get("hostname"),
+                    user = self.get("username"),
+                    port = self.get("port"),
+                    agent = True,
+                    sudo = sudo,
+                    identity = self.get("identity"),
+                    server_key = self.get("serverKey")
+                    )
+        return (out, err), proc
+
+    def check_bad_host(self, out, err):
+        badre = re.compile(r'(?:'
+                           r'|Error: disk I/O error'
+                           r')', 
+                           re.I)
+        return badre.search(out) or badre.search(err)
+
+    def blacklist(self):
+        # TODO!!!!
+        self.logger.warn("Blacklisting malfunctioning node %s", self.hostname)
+        #import util
+        #util.appendBlacklist(self.hostname)
+
+    def safe_retry(self, func):
+        """Retries a function invocation using a lock"""
+        import functools
+        @functools.wraps(func)
+        def rv(*p, **kw):
+            fail_msg = " Failed to execute function %s(%s, %s) at host %s" % (
+                func.__name__, p, kw, self.get("hostname"))
+            retry = kw.pop("_retry", False)
+            wlock = kw.pop("_with_lock", False)
+
+            out = err = ""
+            proc = None
+            for i in xrange(0 if retry else 4):
+                try:
+                    if wlock:
+                        with self._lock:
+                            (out, err), proc = func(*p, **kw)
+                    else:
+                        (out, err), proc = func(*p, **kw)
+                        
+                    if proc.poll():
+                        if retry:
+                            time.sleep(i*15)
+                            continue
+                        else:
+                            self.logger.error("%s. out: %s error: %s", fail_msg, out, err)
+                    break
+                except RuntimeError, e:
+                    if x >= 3:
+                        self.logger.error("%s. error: %s", fail_msg, e.args)
+            return (out, err), proc
+
+        return rv
 
diff --git a/src/neco/resources/linux/ssh_api.py b/src/neco/resources/linux/ssh_api.py
deleted file mode 100644 (file)
index 7f5009e..0000000
+++ /dev/null
@@ -1,334 +0,0 @@
-
-from neco.util.sshfuncs import eintr_retry, rexec, rcopy, rspawn, \
-        rcheckpid, rstatus, rkill, RUNNING, FINISHED 
-
-import hashlib
-import logging
-import os
-import re
-import tempfile
-
-class SSHApi(object):
-    def __init__(self, host, user, port, identity, agent, forward_x11):
-        self.host = host
-        self.user = user
-        # ssh identity file
-        self.identity = identity
-        self.port = port
-        # use ssh agent
-        self.agent = agent
-        # forward X11 
-        self.forward_x11 = forward_x11
-
-        self._pm = None
-        
-        self._logger = logging.getLogger("neco.linux.SSHApi")
-
-    # TODO: Investigate using http://nixos.org/nix/
-    @property
-    def pm(self):
-        if self._pm:
-            return self._pm
-
-        if (not self.host or not self.user):
-            msg = "Can't resolve package management system. Insufficient data."
-            self._logger.error(msg)
-            raise RuntimeError(msg)
-
-        out, err = self.execute("cat /etc/issue")
-
-        if out.find("Fedora") == 0:
-            self._pm = "yum"
-        elif out.find("Debian") == 0 or out.find("Ubuntu") ==0:
-            self._pm = "apt-get"
-        else:
-            msg = "Can't resolve package management system. Unknown OS."
-            self._logger.error(msg)
-            raise RuntimeError(msg)
-
-        return self._pm
-
-    @property
-    def is_localhost(self):
-        return self.host in ['localhost', '127.0.0.7', '::1']
-
-    # TODO: Investigate using http://nixos.org/nix/
-    def install(self, packages):
-        if not isinstance(packages, list):
-            packages = [packages]
-
-        for p in packages:
-            self.execute("%s -y install %s" % (self.pm, p), sudo = True, 
-                    tty = True)
-
-    # TODO: Investigate using http://nixos.org/nix/
-    def uninstall(self, packages):
-        if not isinstance(packages, list):
-            packages = [packages]
-
-        for p in packages:
-            self.execute("%s -y remove %s" % (self.pm, p), sudo = True, 
-                    tty = True)
-
-    def upload(self, src, dst):
-        """ Copy content to destination
-
-           src  content to copy. Can be a local file, directory or text input
-
-           dst  destination path on the remote host (remote is always self.host)
-        """
-        # If source is a string input 
-        if not os.path.isfile(src):
-            # src is text input that should be uploaded as file
-            # create a temporal file with the content to upload
-            f = tempfile.NamedTemporaryFile(delete=False)
-            f.write(src)
-            f.close()
-            src = f.name
-
-        if not self.is_localhost:
-            # Build destination as <user>@<server>:<path>
-            dst = "%s@%s:%s" % (self.user, self.host, dst)
-
-        ret = self.copy(src, dst)
-
-        return ret
-
-    def download(self, src, dst):
-        if not self.is_localhost:
-            # Build destination as <user>@<server>:<path>
-            src = "%s@%s:%s" % (self.user, self.host, src)
-        return self.copy(src, dst)
-        
-    def is_alive(self, verbose = False):
-        if self.is_localhost:
-            return True
-
-        try:
-            (out, err) = self.execute("echo 'ALIVE'",
-                timeout = 60,
-                err_on_timeout = False,
-                persistent = False)
-        except:
-            if verbose:
-                self._logger.warn("Unresponsive node %s got:\n%s%s", self.host, out, err)
-            return False
-
-        if out.strip().startswith('ALIVE'):
-            return True
-        else:
-            if verbose:
-                self._logger.warn("Unresponsive node %s got:\n%s%s", self.host, out, err)
-            return False
-
-    def mkdir(self, path, clean = True):
-        if clean:
-            self.rmdir(path)
-
-        return self.execute(
-            "mkdir -p %s" % path,
-            timeout = 120,
-            retry = 3
-            )
-
-    def rmdir(self, path):
-        return self.execute(
-            "rm -rf %s" % path,
-            timeout = 120,
-            retry = 3
-            )
-
-    def copy(self, src, dst):
-        if self.is_localhost:
-            command = ["cp", "-R", src, dst]
-            p = subprocess.Popen(command, stdout=subprocess.PIPE, 
-                    stderr=subprocess.PIPE)
-            out, err = p.communicate()
-        else:
-            (out, err), proc = eintr_retry(rcopy)(
-                src, dst, 
-                port = self.port,
-                agent = self.agent,
-                identity = self.identity)
-
-            if proc.wait():
-                msg = "Error uploading to %s got:\n%s%s" %\
-                        (self.host, out, err)
-                self._logger.error(msg)
-                raise RuntimeError(msg)
-
-        return (out, err)
-
-    def execute(self, command,
-            sudo = False,
-            stdin = None, 
-            tty = False,
-            env = None,
-            timeout = None,
-            retry = 0,
-            err_on_timeout = True,
-            connect_timeout = 30,
-            persistent = True):
-        """ Notice that this invocation will block until the
-        execution finishes. If this is not the desired behavior,
-        use 'run' instead."""
-
-        if self.is_localhost:
-            if env:
-                export = ''
-                for envkey, envval in env.iteritems():
-                    export += '%s=%s ' % (envkey, envval)
-                command = export + command
-
-            if sudo:
-                command = "sudo " + command
-
-            p = subprocess.Popen(command, stdout=subprocess.PIPE, 
-                    stderr=subprocess.PIPE)
-            out, err = p.communicate()
-        else:
-            (out, err), proc = eintr_retry(rexec)(
-                    command, 
-                    self.host, 
-                    self.user,
-                    port = self.port, 
-                    agent = self.agent,
-                    sudo = sudo,
-                    stdin = stdin, 
-                    identity = self.identity,
-                    tty = tty,
-                    x11 = self.forward_x11,
-                    env = env,
-                    timeout = timeout,
-                    retry = retry,
-                    err_on_timeout = err_on_timeout,
-                    connect_timeout = connect_timeout,
-                    persistent = persistent)
-
-            if proc.wait():
-                msg = "Failed to execute command %s at node %s: %s %s" % \
-                        (command, self.host, out, err,)
-                self._logger.warn(msg)
-                raise RuntimeError(msg)
-        return (out, err)
-
-    def run(self, command, home, 
-            stdin = None, 
-            stdout = 'stdout', 
-            stderr = 'stderr', 
-            sudo = False):
-        self._logger.info("Running %s", command)
-        
-        pidfile = './pid'
-
-        if self.is_localhost:
-            if stderr == stdout:
-                stderr = '&1'
-            else:
-                stderr = ' ' + stderr
-            
-            daemon_command = '{ { %(command)s  > %(stdout)s 2>%(stderr)s < %(stdin)s & } ; echo $! 1 > %(pidfile)s ; }' % {
-                'command' : command,
-                'pidfile' : pidfile,
-                
-                'stdout' : stdout,
-                'stderr' : stderr,
-                'stdin' : stdin,
-            }
-            
-            cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c '%(command)s' " % {
-                    'command' : daemon_command,
-                    
-                    'sudo' : 'sudo -S' if sudo else '',
-                    
-                    'pidfile' : pidfile,
-                    'gohome' : 'cd %s ; ' % home if home else '',
-                    'create' : 'mkdir -p %s ; ' % home if create_home else '',
-                }
-            p = subprocess.Popen(command, stdout=subprocess.PIPE, 
-                    stderr=subprocess.PIPE)
-            out, err = p.communicate()
-        else:
-            # Start process in a "daemonized" way, using nohup and heavy
-            # stdin/out redirection to avoid connection issues
-            (out,err), proc = rspawn(
-                command,
-                pidfile = pidfile,
-                home = home,
-                stdin = stdin if stdin is not None else '/dev/null',
-                stdout = stdout if stdout else '/dev/null',
-                stderr = stderr if stderr else '/dev/null',
-                sudo = sudo,
-                host = self.host,
-                user = self.user,
-                port = self.port,
-                agent = self.agent,
-                identity = self.identity
-                )
-            
-            if proc.wait():
-                raise RuntimeError, "Failed to set up application: %s %s" % (out,err,)
-
-        return (out, err)
-    
-    def checkpid(self, path):            
-        # Get PID/PPID
-        # NOTE: wait a bit for the pidfile to be created
-        pidtuple = rcheckpid(
-            os.path.join(path, 'pid'),
-            host = self.host,
-            user = self.user,
-            port = self.port,
-            agent = self.agent,
-            identity = self.identity
-            )
-        
-        return pidtuple
-    
-    def status(self, pid, ppid):
-        status = rstatus(
-                pid, ppid,
-                host = self.host,
-                user = self.user,
-                port = self.port,
-                agent = self.agent,
-                identity = self.identity
-                )
-           
-        return status
-    
-    def kill(self, pid, ppid, sudo = False):
-        status = self.status(pid, ppid)
-        if status == RUNNING:
-            # kill by ppid+pid - SIGTERM first, then try SIGKILL
-            rkill(
-                pid, ppid,
-                host = self.host,
-                user = self.user,
-                port = self.port,
-                agent = self.agent,
-                sudo = sudo,
-                identity = self.identity
-                )
-
-class SSHApiFactory(object):
-    _apis = dict()
-
-    @classmethod 
-    def get_api(cls, host, user, port = 22, identity = None, 
-            agent = True, forward_X11 = False):
-        key = cls.make_key(host, user, port, agent, forward_X11)
-        api = cls._apis.get(key)
-
-        if not api:
-            api = SSHApi(host, user, port, identity, agent, forward_X11)
-            cls._apis[key] = api
-
-        return api
-
-    @classmethod 
-    def make_key(cls, *args):
-        skey = "".join(map(str, args))
-        return hashlib.md5(skey).hexdigest()
-
index dd04ba6..e81c558 100644 (file)
@@ -12,8 +12,7 @@ import re
 import tempfile
 import hashlib
 
-OPENSSH_HAS_PERSIST = None
-CONTROL_PATH = "yyy_ssh_ctrl_path"
+TRACE = os.environ.get("NEPI_TRACE", "false").lower() in ("true", "1", "on")
 
 if hasattr(os, "devnull"):
     DEV_NULL = os.devnull
@@ -22,8 +21,6 @@ else:
 
 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
 
-hostbyname_cache = dict()
-
 class STDOUT: 
     """
     Special value that when given to rspawn in stderr causes stderr to 
@@ -45,6 +42,17 @@ class NOT_STARTED:
     Process hasn't started running yet (this should be very rare)
     """
 
+hostbyname_cache = dict()
+
+def gethostbyname(host):
+    hostbyname = hostbyname_cache.get(host)
+    if not hostbyname:
+        hostbyname = socket.gethostbyname(host)
+        hostbyname_cache[host] = hostbyname
+    return hostbyname
+
+OPENSSH_HAS_PERSIST = None
+
 def openssh_has_persist():
     """ The ssh_config options ControlMaster and ControlPersist allow to
     reuse a same network connection for multiple ssh sessions. In this 
@@ -66,6 +74,59 @@ def openssh_has_persist():
         OPENSSH_HAS_PERSIST = bool(vre.match(out))
     return OPENSSH_HAS_PERSIST
 
+def make_server_key_args(server_key, host, port):
+    """ Returns a reference to a temporary known_hosts file, to which 
+    the server key has been added. 
+    
+    Make sure to hold onto the temp file reference until the process is 
+    done with it
+
+    :param server_key: the server public key
+    :type server_key: str
+
+    :param host: the hostname
+    :type host: str
+
+    :param port: the ssh port
+    :type port: str
+
+    """
+    if port is not None:
+        host = '%s:%s' % (host, str(port))
+
+    # Create a temporary server key file
+    tmp_known_hosts = tempfile.NamedTemporaryFile()
+   
+    hostbyname = gethostbyname(host) 
+
+    # Add the intended host key
+    tmp_known_hosts.write('%s,%s %s\n' % (host, hostbyname, server_key))
+    
+    # If we're not in strict mode, add user-configured keys
+    if os.environ.get('NEPI_STRICT_AUTH_MODE',"").lower() not in ('1','true','on'):
+        user_hosts_path = '%s/.ssh/known_hosts' % (os.environ.get('HOME',""),)
+        if os.access(user_hosts_path, os.R_OK):
+            f = open(user_hosts_path, "r")
+            tmp_known_hosts.write(f.read())
+            f.close()
+        
+    tmp_known_hosts.flush()
+    
+    return tmp_known_hosts
+
+def make_control_path(agent, forward_x11):
+    ctrl_path = "/tmp/nepi_ssh"
+
+    if agent:
+        ctrl_path +="_a"
+
+    if forward_x11:
+        ctrl_path +="_x"
+
+    ctrl_path += "-%r@%h:%p"
+
+    return ctrl_path
+
 def shell_escape(s):
     """ Escapes strings so that they are safe to use as command-line 
     arguments """
@@ -105,81 +166,63 @@ def eintr_retry(func):
             return func(*p, **kw)
     return rv
 
-def make_connkey(user, host, port, x11, agent):
-    # It is important to consider the x11 and agent forwarding
-    # parameters when creating the connection key since the parameters
-    # used for the first ssh connection will determine the
-    # parameters of all subsequent connections using the same key
-    x11 = 1 if x11 else 0
-    agent = 1 if agent else 0
-
-    connkey = repr((user, host, port, x11, agent)
-            ).encode("base64").strip().replace('/','.')
-
-    if len(connkey) > 60:
-        connkey = hashlib.sha1(connkey).hexdigest()
-    return connkey
-
-def make_control_path(user, host, port, x11, agent):
-    connkey = make_connkey(user, host, port, x11, agent)
-    return '/tmp/%s_%s' % ( CONTROL_PATH, connkey, )
-
 def rexec(command, host, user, 
         port = None, 
         agent = True,
         sudo = False,
         stdin = None,
         identity = None,
+        server_key = None,
         env = None,
         tty = False,
-        x11 = False,
         timeout = None,
         retry = 0,
         err_on_timeout = True,
         connect_timeout = 30,
-        persistent = True):
+        persistent = True,
+        forward_x11 = False):
     """
     Executes a remote command, returns ((stdout,stderr),process)
     """
+    
+    tmp_known_hosts = None
+    hostip = gethostbyname(host)
+
     args = ['ssh', '-C',
             # Don't bother with localhost. Makes test easier
             '-o', 'NoHostAuthenticationForLocalhost=yes',
-            # XXX: Possible security issue
-            # Avoid interactive requests to accept new host keys
-            '-o', 'StrictHostKeyChecking=no',
             '-o', 'ConnectTimeout=%d' % (int(connect_timeout),),
             '-o', 'ConnectionAttempts=3',
             '-o', 'ServerAliveInterval=30',
             '-o', 'TCPKeepAlive=yes',
-            '-l', user, host]
+            '-l', user, hostip or host]
 
     if persistent and openssh_has_persist():
-        control_path = make_control_path(user, host, port, x11, agent)
         args.extend([
             '-o', 'ControlMaster=auto',
-            '-o', 'ControlPath=%s' % control_path,
+            '-o', 'ControlPath=%s' % (make_control_path(agent, forward_x11),),
             '-o', 'ControlPersist=60' ])
+
     if agent:
         args.append('-A')
+
     if port:
         args.append('-p%d' % port)
+
     if identity:
         args.extend(('-i', identity))
+
     if tty:
         args.append('-t')
-        if sudo:
-            args.append('-t')
-    if x11:
-        args.append('-X')
+        args.append('-t')
 
-    if env:
-        export = ''
-        for envkey, envval in env.iteritems():
-            export += '%s=%s ' % (envkey, envval)
-        command = export + command
+    if forward_x11:
+        args.append('-X')
 
-    if sudo:
-        command = "sudo " + command
+    if server_key:
+        # Create a temporary server key file
+        tmp_known_hosts = make_server_key_args(server_key, host, port)
+        args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
 
     args.append(command)
 
@@ -190,8 +233,15 @@ def rexec(command, host, user,
                 stdin = subprocess.PIPE, 
                 stderr = subprocess.PIPE)
         
+        # attach tempfile object to the process, to make sure the file stays
+        # alive until the process is finished with it
+        proc._known_hosts = tmp_known_hosts
+    
         try:
             out, err = _communicate(proc, stdin, timeout, err_on_timeout)
+            if TRACE:
+                print "COMMAND host %s, command %s, out %s, error %s" % (host, " ".join(args), out, err)
+
             if proc.poll():
                 if err.strip().startswith('ssh: ') or err.strip().startswith('mux_client_hello_exchange: '):
                     # SSH error, can safely retry
@@ -200,7 +250,11 @@ def rexec(command, host, user,
                     # Probably timed out or plain failed but can retry
                     continue
             break
-        except RuntimeError,e:
+        except RuntimeError, e:
+            if TRACE:
+                print "EXCEPTION host %s, command %s, out %s, error %s, exception TIMEOUT ->  %s" % (
+                        host, " ".join(args), out, err, e.args)
+
             if retry <= 0:
                 raise
             retry -= 1
@@ -210,67 +264,242 @@ def rexec(command, host, user,
 def rcopy(source, dest,
         port = None, 
         agent = True, 
-        identity = None):
+        recursive = False,
+        identity = None,
+        server_key = None):
     """
-    Copies file from/to remote sites.
+    Copies from/to remote sites.
     
     Source and destination should have the user and host encoded
     as per scp specs.
     
+    If source is a file object, a special mode will be used to
+    create the remote file with the same contents.
+    
+    If dest is a file object, the remote file (source) will be
+    read and written into dest.
+    
+    In these modes, recursive cannot be True.
+    
     Source can be a list of files to copy to a single destination,
     in which case it is advised that the destination be a folder.
     """
     
-    # Parse destination as <user>@<server>:<path>
-    if isinstance(dest, basestring) and ':' in dest:
-        remspec, path = dest.split(':',1)
-    elif isinstance(source, basestring) and ':' in source:
-        remspec, path = source.split(':',1)
-    else:
-        raise ValueError, "Both endpoints cannot be local"
-    user, host = remspec.rsplit('@',1)
-
-    raw_string = r'''rsync -rlpcSz --timeout=900 '''
-    raw_string += r''' -e 'ssh -o BatchMode=yes '''
-    raw_string += r''' -o NoHostAuthenticationForLocalhost=yes '''
-    raw_string += r''' -o StrictHostKeyChecking=no '''
-    raw_string += r''' -o ConnectionAttempts=3 '''
-    if openssh_has_persist():
-        control_path = make_control_path(user, host, port, False, agent)
-        raw_string += r''' -o ControlMaster=auto '''
-        raw_string += r''' -o ControlPath=%s ''' % control_path
-    if agent:
-        raw_string += r''' -A '''
-
-    if port:
-        raw_string += r''' -p %d ''' % port
+    if TRACE:
+        print "scp", source, dest
     
-    if identity:
-        raw_string += r''' -i "%s" ''' % identity
+    if isinstance(source, file) and source.tell() == 0:
+        source = source.name
+    elif hasattr(source, 'read'):
+        tmp = tempfile.NamedTemporaryFile()
+        while True:
+            buf = source.read(65536)
+            if buf:
+                tmp.write(buf)
+            else:
+                break
+        tmp.seek(0)
+        source = tmp.name
     
-    # closing -e 'ssh...'
-    raw_string += r''' ' '''
-
-    if isinstance(source,list):
-        source = ' '.join(source)
+    if isinstance(source, file) or isinstance(dest, file) \
+            or hasattr(source, 'read')  or hasattr(dest, 'write'):
+        assert not recursive
+        
+        # Parse source/destination as <user>@<server>:<path>
+        if isinstance(dest, basestring) and ':' in dest:
+            remspec, path = dest.split(':',1)
+        elif isinstance(source, basestring) and ':' in source:
+            remspec, path = source.split(':',1)
+        else:
+            raise ValueError, "Both endpoints cannot be local"
+        user,host = remspec.rsplit('@',1)
+        
+        tmp_known_hosts = None
+        hostip = gethostbyname(host)
+        
+        args = ['ssh', '-l', user, '-C',
+                # Don't bother with localhost. Makes test easier
+                '-o', 'NoHostAuthenticationForLocalhost=yes',
+                '-o', 'ConnectTimeout=60',
+                '-o', 'ConnectionAttempts=3',
+                '-o', 'ServerAliveInterval=30',
+                '-o', 'TCPKeepAlive=yes',
+                hostip or host ]
+
+        if openssh_has_persist():
+            args.extend([
+                '-o', 'ControlMaster=auto',
+                '-o', 'ControlPath=%s' % (make_control_path(agent, False),),
+                '-o', 'ControlPersist=60' ])
+
+        if port:
+            args.append('-P%d' % port)
+
+        if identity:
+            args.extend(('-i', identity))
+
+        if server_key:
+            # Create a temporary server key file
+            tmp_known_hosts = make_server_key_args(server_key, host, port)
+            args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
+        
+        if isinstance(source, file) or hasattr(source, 'read'):
+            args.append('cat > %s' % (shell_escape(path),))
+        elif isinstance(dest, file) or hasattr(dest, 'write'):
+            args.append('cat %s' % (shell_escape(path),))
+        else:
+            raise AssertionError, "Unreachable code reached! :-Q"
+        
+        # connects to the remote host and starts a remote connection
+        if isinstance(source, file):
+            proc = subprocess.Popen(args, 
+                    stdout = open('/dev/null','w'),
+                    stderr = subprocess.PIPE,
+                    stdin = source)
+            err = proc.stderr.read()
+            proc._known_hosts = tmp_known_hosts
+            eintr_retry(proc.wait)()
+            return ((None,err), proc)
+        elif isinstance(dest, file):
+            proc = subprocess.Popen(args, 
+                    stdout = open('/dev/null','w'),
+                    stderr = subprocess.PIPE,
+                    stdin = source)
+            err = proc.stderr.read()
+            proc._known_hosts = tmp_known_hosts
+            eintr_retry(proc.wait)()
+            return ((None,err), proc)
+        elif hasattr(source, 'read'):
+            # file-like (but not file) source
+            proc = subprocess.Popen(args, 
+                    stdout = open('/dev/null','w'),
+                    stderr = subprocess.PIPE,
+                    stdin = subprocess.PIPE)
+            
+            buf = None
+            err = []
+            while True:
+                if not buf:
+                    buf = source.read(4096)
+                if not buf:
+                    #EOF
+                    break
+                
+                rdrdy, wrdy, broken = select.select(
+                    [proc.stderr],
+                    [proc.stdin],
+                    [proc.stderr,proc.stdin])
+                
+                if proc.stderr in rdrdy:
+                    # use os.read for fully unbuffered behavior
+                    err.append(os.read(proc.stderr.fileno(), 4096))
+                
+                if proc.stdin in wrdy:
+                    proc.stdin.write(buf)
+                    buf = None
+                
+                if broken:
+                    break
+            proc.stdin.close()
+            err.append(proc.stderr.read())
+                
+            proc._known_hosts = tmp_known_hosts
+            eintr_retry(proc.wait)()
+            return ((None,''.join(err)), proc)
+        elif hasattr(dest, 'write'):
+            # file-like (but not file) dest
+            proc = subprocess.Popen(args, 
+                    stdout = subprocess.PIPE,
+                    stderr = subprocess.PIPE,
+                    stdin = open('/dev/null','w'))
+            
+            buf = None
+            err = []
+            while True:
+                rdrdy, wrdy, broken = select.select(
+                    [proc.stderr, proc.stdout],
+                    [],
+                    [proc.stderr, proc.stdout])
+                
+                if proc.stderr in rdrdy:
+                    # use os.read for fully unbuffered behavior
+                    err.append(os.read(proc.stderr.fileno(), 4096))
+                
+                if proc.stdout in rdrdy:
+                    # use os.read for fully unbuffered behavior
+                    buf = os.read(proc.stdout.fileno(), 4096)
+                    dest.write(buf)
+                    
+                    if not buf:
+                        #EOF
+                        break
+                
+                if broken:
+                    break
+            err.append(proc.stderr.read())
+                
+            proc._known_hosts = tmp_known_hosts
+            eintr_retry(proc.wait)()
+            return ((None,''.join(err)), proc)
+        else:
+            raise AssertionError, "Unreachable code reached! :-Q"
     else:
-        source = '"%s"' % source
+        # Parse destination as <user>@<server>:<path>
+        if isinstance(dest, basestring) and ':' in dest:
+            remspec, path = dest.split(':',1)
+        elif isinstance(source, basestring) and ':' in source:
+            remspec, path = source.split(':',1)
+        else:
+            raise ValueError, "Both endpoints cannot be local"
+        user,host = remspec.rsplit('@',1)
+        
+        # plain scp
+        tmp_known_hosts = None
+
+        args = ['scp', '-q', '-p', '-C',
+                # Don't bother with localhost. Makes test easier
+                '-o', 'NoHostAuthenticationForLocalhost=yes',
+                '-o', 'ConnectTimeout=60',
+                '-o', 'ConnectionAttempts=3',
+                '-o', 'ServerAliveInterval=30',
+                '-o', 'TCPKeepAlive=yes' ]
+                
+        if port:
+            args.append('-P%d' % port)
+
+        if recursive:
+            args.append('-r')
+
+        if identity:
+            args.extend(('-i', identity))
+
+        if server_key:
+            # Create a temporary server key file
+            tmp_known_hosts = make_server_key_args(server_key, host, port)
+            args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
+
+        if isinstance(source,list):
+            args.extend(source)
+        else:
+            if openssh_has_persist():
+                args.extend([
+                    '-o', 'ControlMaster=auto',
+                    '-o', 'ControlPath=%s' % (make_control_path(agent, False),)
+                    ])
+            args.append(source)
 
-    raw_string += r''' %s ''' % source
-    raw_string += r''' %s ''' % dest
+        args.append(dest)
 
-    # connects to the remote host and starts a remote connection
-    proc = subprocess.Popen(raw_string,
-            shell=True,
-            stdout = subprocess.PIPE,
-            stdin = subprocess.PIPE, 
-            stderr = subprocess.PIPE)
-  
-    comm = proc.communicate()
-    eintr_retry(proc.wait)()
-    return (comm, proc)
+        # connects to the remote host and starts a remote connection
+        proc = subprocess.Popen(args, 
+                stdout = subprocess.PIPE,
+                stdin = subprocess.PIPE, 
+                stderr = subprocess.PIPE)
+        proc._known_hosts = tmp_known_hosts
+        
+        (out, err) = proc.communicate()
+        eintr_retry(proc.wait)()
+        return ((out, err), proc)
 
 def rspawn(command, pidfile, 
         stdout = '/dev/null', 
@@ -278,12 +507,13 @@ def rspawn(command, pidfile,
         stdin = '/dev/null', 
         home = None, 
         create_home = False, 
+        sudo = False,
         host = None, 
         port = None, 
         user = None, 
         agent = None, 
-        sudo = False,
         identity = None, 
+        server_key = None,
         tty = False):
     """
     Spawn a remote command such that it will continue working asynchronously.
@@ -319,41 +549,38 @@ def rspawn(command, pidfile,
         stderr = '&1'
     else:
         stderr = ' ' + stderr
-   
-    #XXX: ppid is always 1!!!
+    
     daemon_command = '{ { %(command)s  > %(stdout)s 2>%(stderr)s < %(stdin)s & } ; echo $! 1 > %(pidfile)s ; }' % {
         'command' : command,
-        'pidfile' : pidfile,
-        
+        'pidfile' : shell_escape(pidfile),
         'stdout' : stdout,
         'stderr' : stderr,
         'stdin' : stdin,
     }
     
-    cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c '%(command)s' " % {
-            'command' : daemon_command,
-            
+    cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c %(command)s " % {
+            'command' : shell_escape(daemon_command),
             'sudo' : 'sudo -S' if sudo else '',
-            
-            'pidfile' : pidfile,
-            'gohome' : 'cd %s ; ' % home if home else '',
-            'create' : 'mkdir -p %s ; ' % home if create_home else '',
+            'pidfile' : shell_escape(pidfile),
+            'gohome' : 'cd %s ; ' % (shell_escape(home),) if home else '',
+            'create' : 'mkdir -p %s ; ' % (shell_escape(home),) if create_home and home else '',
         }
 
-    (out,err), proc = rexec(
+    (out,err),proc = rexec(
         cmd,
         host = host,
         port = port,
         user = user,
         agent = agent,
         identity = identity,
-        tty = tty
+        server_key = server_key,
+        tty = tty ,
         )
     
     if proc.wait():
-        raise RuntimeError, "Failed to set up application: %s %s" % (out,err,)
+        raise RuntimeError, "Failed to set up application on host %s: %s %s" % (host, out,err,)
 
-    return (out,err),proc
+    return ((out, err), proc)
 
 @eintr_retry
 def rcheckpid(pidfile,
@@ -361,7 +588,8 @@ def rcheckpid(pidfile,
         port = None, 
         user = None, 
         agent = None, 
-        identity = None):
+        identity = None,
+        server_key = None):
     """
     Check the pidfile of a process spawned with remote_spawn.
     
@@ -384,7 +612,8 @@ def rcheckpid(pidfile,
         port = port,
         user = user,
         agent = agent,
-        identity = identity
+        identity = identity,
+        server_key = server_key
         )
         
     if proc.wait():
@@ -403,7 +632,8 @@ def rstatus(pid, ppid,
         port = None, 
         user = None, 
         agent = None, 
-        identity = None):
+        identity = None,
+        server_key = None):
     """
     Check the status of a process spawned with remote_spawn.
     
@@ -417,9 +647,9 @@ def rstatus(pid, ppid,
         One of NOT_STARTED, RUNNING, FINISHED
     """
 
-    # XXX: ppid unused
     (out,err),proc = rexec(
-        "ps --pid %(pid)d -o pid | grep -c %(pid)d ; true" % {
+        # Check only by pid. pid+ppid does not always work (especially with sudo) 
+        " (( ps --pid %(pid)d -o pid | grep -c %(pid)d && echo 'wait')  || echo 'done' ) | tail -n 1" % {
             'ppid' : ppid,
             'pid' : pid,
         },
@@ -427,23 +657,22 @@ def rstatus(pid, ppid,
         port = port,
         user = user,
         agent = agent,
-        identity = identity
+        identity = identity,
+        server_key = server_key
         )
     
     if proc.wait():
         return NOT_STARTED
     
     status = False
-    if out:
-        try:
-            status = bool(int(out.strip()))
-        except:
-            if out or err:
-                logging.warn("Error checking remote status:\n%s%s\n", out, err)
-            # Ignore, many ways to fail that don't matter that much
-            return NOT_STARTED
+    if err:
+        if err.strip().find("Error, do this: mount -t proc none /proc") >= 0:
+            status = True
+    elif out:
+        status = (out.strip() == 'wait')
+    else:
+        return NOT_STARTED
     return RUNNING if status else FINISHED
-    
 
 @eintr_retry
 def rkill(pid, ppid,
@@ -453,6 +682,7 @@ def rkill(pid, ppid,
         agent = None, 
         sudo = False,
         identity = None, 
+        server_key = None, 
         nowait = False):
     """
     Kill a process spawned with remote_spawn.
@@ -506,12 +736,15 @@ fi
         port = port,
         user = user,
         agent = agent,
-        identity = identity
+        identity = identity,
+        server_key = server_key
         )
     
     # wait, don't leave zombies around
     proc.wait()
 
+    return (out, err), proc
+
 # POSIX
 def _communicate(self, input, timeout=None, err_on_timeout=True):
     read_set = []
index 15b6e22..87e0338 100755 (executable)
@@ -34,8 +34,8 @@ class ResourceFactoryTestCase(unittest.TestCase):
         self.assertEquals(MyResource.rtype(), "MyResource")
         self.assertEquals(len(MyResource._attributes), 1)
 
-        self.assertEquals(Resource.rtype(), "Resource")
-        self.assertEquals(len(Resource._attributes), 0)
+        self.assertEquals(ResourceManager.rtype(), "Resource")
+        self.assertEquals(len(ResourceManager._attributes), 0)
 
         self.assertEquals(AnotherResource.rtype(), "AnotherResource")
         self.assertEquals(len(AnotherResource._attributes), 0)
old mode 100755 (executable)
new mode 100644 (file)
index c3563e3..7133e1c
@@ -2,11 +2,36 @@
 from neco.resources.linux.node import LinuxNode
 from neco.util.sshfuncs import RUNNING, FINISHED
 
-import os.path
+import os
 import time
 import tempfile
 import unittest
 
+def skipIfNotAlive(func):
+    name = func.__name__
+    def wrapped(*args, **kwargs):
+        node = args[1]
+
+        if not node.is_alive():
+            print "*** WARNING: Skipping test %s: Node %s is not alive\n" % (
+                name, node.get("hostname"))
+            return
+
+        return func(*args, **kwargs)
+    
+    return wrapped
+
+def skipInteractive(func):
+    name = func.__name__
+    def wrapped(*args, **kwargs):
+        mode = os.environ.get("NEPI_INTERACTIVE", False) in ['True', 'true', 'yes', 'YES']
+        if not mode:
+            print "*** WARNING: Skipping test %s: Interactive mode off \n" % name
+            return
+
+        return func(*args, **kwargs)
+    
+    return wrapped
 
 class DummyEC(object):
     pass
@@ -22,52 +47,44 @@ class LinuxNodeTestCase(unittest.TestCase):
         self.node_ubuntu = self.create_node(host, user)
         
         self.target = 'nepi5.pl.sophia.inria.fr'
-        self.home = '${HOME}/test-app'
+        self.home = '/tmp/nepi-home/test-app'
 
     def create_node(self, host, user):
         ec = DummyEC()
 
         node = LinuxNode(ec, 1)
-        node.host = host
-        node.user = user
+        node.set("hostname", host)
+        node.set("username", user)
 
         return node
 
+    @skipIfNotAlive
     def t_xterm(self, node):
-        if not node.is_alive():
-            print "*** WARNING: Skipping test: Node %s is not alive\n" % (node.host)
-            return 
-
-        node.enable_x11 = True
-
-        node.install('xterm')
-
-        out = node.execute('xterm')
-
-        node.uninstall('xterm')
+        node.install_packages('xterm')
 
+        (out, err), proc = node.execute('xterm', forward_x11 = True)
+        
         self.assertEquals(out, "")
 
-    def t_execute(self, node, target):
-        if not node.is_alive():
-            print "*** WARNING: Skipping test: Node %s is not alive\n" % (node.host)
-            return 
+        (out, err), proc = node.remove_packages('xterm')
+        
+        self.assertEquals(out, "")
 
-        command = "ping -qc3 %s" % target
-        out = node.execute(command)
+    @skipIfNotAlive
+    def t_execute(self, node):
+        command = "ping -qc3 %s" % self.target
+        
+        (out, err), proc = node.execute(command)
 
         expected = """3 packets transmitted, 3 received, 0% packet loss"""
 
         self.assertTrue(out.find(expected) > 0)
 
-    def t_run(self, node, target):
-        if not node.is_alive():
-            print "*** WARNING: Skipping test: Node %s is not alive\n" % (node.host)
-            return
-
+    @skipIfNotAlive
+    def t_run(self, node):
         node.mkdir(self.home, clean = True)
         
-        command = "ping %s" % target
+        command = "ping %s" % self.target
         dst = os.path.join(self.home, "app.sh")
         node.upload(command, dst)
         
@@ -81,14 +98,17 @@ class LinuxNodeTestCase(unittest.TestCase):
         node.kill(pid, ppid)
         status = node.status(pid, ppid)
         self.assertTrue(status, FINISHED)
+        
+        (out, err), proc = node.check_run_output(self.home)
+
+        expected = """64 bytes from"""
+
+        self.assertTrue(out.find(expected) > 0)
 
         node.rmdir(self.home)
 
+    @skipIfNotAlive
     def t_install(self, node):
-        if not node.is_alive():
-            print "*** WARNING: Skipping test: Node %s is not alive\n" % (node.host)
-            return
-
         node.mkdir(self.home, clean = True)
 
         prog = """#include <stdio.h>
@@ -105,19 +125,21 @@ main (void)
         node.upload(prog, dst)
 
         # install gcc
-        node.install('gcc')
+        node.install_packages('gcc')
 
         # compile the program using gcc
         command = "cd %s; gcc -Wall hello.c -o hello" % self.home
-        out = node.execute(command)
+        (out, err), proc = node.execute(command)
 
-        # execute the program and get the output from stout
+        # execute the program and get the output from stdout
         command = "%s/hello" % self.home
-        out = node.execute(command)
+        (out, err), proc = node.execute(command)
+
+        self.assertEquals(out, "Hello, world!\n")
 
         # execute the program and get the output from a file
-        command = "%s/hello > %s/hello.out" % (self.home, self.home)
-        node.execute(command)
+        command = "%(home)s/hello > %(home)s/hello.out" % {'home':self.home}
+        (out, err), proc = node.execute(command)
 
         # retrieve the output file 
         src = os.path.join(self.home, "hello.out")
@@ -126,11 +148,9 @@ main (void)
         node.download(src, dst)
         f.close()
 
-        node.uninstall('gcc')
+        node.remove_packages('gcc')
         node.rmdir(self.home)
 
-        self.assertEquals(out, "Hello, world!\n")
-
         f = open(dst, "r")
         out = f.read()
         f.close()
@@ -138,29 +158,25 @@ main (void)
         self.assertEquals(out, "Hello, world!\n")
 
     def test_execute_fedora(self):
-        self.t_execute(self.node_fedora, self.target)
+        self.t_execute(self.node_fedora)
 
     def test_execute_ubuntu(self):
-        self.t_execute(self.node_ubuntu, self.target)
+        self.t_execute(self.node_ubuntu)
 
     def test_run_fedora(self):
-        self.t_run(self.node_fedora, self.target)
+        self.t_run(self.node_fedora)
 
     def test_run_ubuntu(self):
-        self.t_run(self.node_ubuntu, self.target)
+        self.t_run(self.node_ubuntu)
 
     def test_intall_fedora(self):
         self.t_install(self.node_fedora)
 
     def test_install_ubuntu(self):
         self.t_install(self.node_ubuntu)
-
-    def xtest_xterm_fedora(self):
-        """ PlanetLab doesn't currently support X11 forwarding.
-        Interactive test. Should not run automatically """
-        self.t_xterm(self.node_fedora)
-
-    def xtest_xterm_ubuntu(self):
+    
+    @skipInteractive
+    def test_xterm_ubuntu(self):
         """ Interactive test. Should not run automatically """
         self.t_xterm(self.node_ubuntu)
 
diff --git a/test/resources/linux/ssh_api.py b/test/resources/linux/ssh_api.py
deleted file mode 100644 (file)
index 88b0365..0000000
+++ /dev/null
@@ -1,174 +0,0 @@
-#!/usr/bin/env python
-from neco.resources.linux.ssh_api import SSHApiFactory
-from neco.util.sshfuncs import RUNNING, FINISHED
-
-import os
-import time
-import tempfile
-import unittest
-
-def skipIfNotAlive(func):
-    name = func.__name__
-    def wrapped(*args, **kwargs):
-        host = args[1]
-        user = args[2]
-
-        api = SSHApiFactory.get_api(host, user)
-        if not api.is_alive():
-            print "*** WARNING: Skipping test %s: Node %s is not alive\n" % (name, host)
-            return
-
-        return func(*args, **kwargs)
-    
-    return wrapped
-
-def skipInteractive(func):
-    name = func.__name__
-    def wrapped(*args, **kwargs):
-        mode = os.environ.get("NEPI_INTERACTIVE", False) in ['True', 'true', 'yes', 'YES']
-        if not mode:
-            print "*** WARNING: Skipping test %s: Interactive mode off \n" % name
-            return
-
-        return func(*args, **kwargs)
-    
-    return wrapped
-
-class SSHApiTestCase(unittest.TestCase):
-    def setUp(self):
-        self.host_fedora = 'nepi2.pl.sophia.inria.fr'
-        self.user_fedora = 'inria_nepi'
-
-        self.host_ubuntu = 'roseval.pl.sophia.inria.fr'
-        self.user_ubuntu = 'alina'
-        
-        self.target = 'nepi5.pl.sophia.inria.fr'
-        self.home = '${HOME}/test-app'
-
-    @skipIfNotAlive
-    def t_xterm(self, host, user):
-        api = SSHApiFactory.get_api(host, user)
-
-        api.enable_x11 = True
-
-        api.install('xterm')
-
-        out = api.execute('xterm')
-
-        api.uninstall('xterm')
-
-        self.assertEquals(out, "")
-
-    @skipIfNotAlive
-    def t_execute(self, host, user):
-        api = SSHApiFactory.get_api(host, user)
-        
-        command = "ping -qc3 %s" % self.target
-        out, err = api.execute(command)
-
-        expected = """3 packets transmitted, 3 received, 0% packet loss"""
-
-        self.assertTrue(out.find(expected) > 0)
-
-    @skipIfNotAlive
-    def t_run(self, host, user):
-        api = SSHApiFactory.get_api(host, user)
-        
-        api.mkdir(self.home, clean = True)
-        
-        command = "ping %s" % self.target
-        dst = os.path.join(self.home, "app.sh")
-        api.upload(command, dst)
-        
-        cmd = "bash ./app.sh"
-        api.run(cmd, self.home)
-        pid, ppid = api.checkpid(self.home)
-
-        status = api.status(pid, ppid)
-        self.assertTrue(status, RUNNING)
-
-        api.kill(pid, ppid)
-        status = api.status(pid, ppid)
-        self.assertTrue(status, FINISHED)
-
-        api.rmdir(self.home)
-
-    @skipIfNotAlive
-    def t_install(self, host, user):
-        api = SSHApiFactory.get_api(host, user)
-        
-        api.mkdir(self.home, clean = True)
-
-        prog = """#include <stdio.h>
-
-int
-main (void)
-{
-    printf ("Hello, world!\\n");
-    return 0;
-}
-"""
-        # upload the test program
-        dst = os.path.join(self.home, "hello.c")
-        api.upload(prog, dst)
-
-        # install gcc
-        api.install('gcc')
-
-        # compile the program using gcc
-        command = "cd %s; gcc -Wall hello.c -o hello" % self.home
-        out = api.execute(command)
-
-        # execute the program and get the output from stout
-        command = "%s/hello" % self.home
-        out, err = api.execute(command)
-
-        # execute the program and get the output from a file
-        command = "%s/hello > %s/hello.out" % (self.home, self.home)
-        api.execute(command)
-
-        # retrieve the output file 
-        src = os.path.join(self.home, "hello.out")
-        f = tempfile.NamedTemporaryFile(delete=False)
-        dst = f.name
-        api.download(src, dst)
-        f.close()
-
-        api.uninstall('gcc')
-        api.rmdir(self.home)
-
-        self.assertEquals(out, "Hello, world!\n")
-
-        f = open(dst, "r")
-        out = f.read()
-        f.close()
-        
-        self.assertEquals(out, "Hello, world!\n")
-
-    def test_execute_fedora(self):
-        self.t_execute(self.host_fedora, self.user_fedora)
-
-    def test_execute_ubuntu(self):
-        self.t_execute(self.host_ubuntu, self.user_ubuntu)
-
-    def test_run_fedora(self):
-        self.t_run(self.host_fedora, self.user_fedora)
-
-    def test_run_ubuntu(self):
-        self.t_run(self.host_ubuntu, self.user_ubuntu)
-
-    def test_intall_fedora(self):
-        self.t_install(self.host_fedora, self.user_fedora)
-
-    def test_install_ubuntu(self):
-        self.t_install(self.host_ubuntu, self.user_ubuntu)
-    
-    @skipInteractive
-    def test_xterm_ubuntu(self):
-        """ Interactive test. Should not run automatically """
-        self.t_xterm(self.host_ubuntu, self.user_ubuntu)
-
-
-if __name__ == '__main__':
-    unittest.main()
-
index 9b282ce..c9afb32 100644 (file)
@@ -182,7 +182,7 @@ class SSHfuncsTestCase(unittest.TestCase):
         source = [dirpath, f1.name]
         destdir = tempfile.mkdtemp()
         dest = "%s@%s:%s" % (user, host, destdir)
-        rcopy(source, dest, port = env.port, agent = True)
+        rcopy(source, dest, port = env.port, agent = True, recursive = True)
 
         files = []
         def recls(files, dirname, names):