Creating ssh api (unfinished)
authorAlina Quereilhac <alina.quereilhac@inria.fr>
Mon, 1 Apr 2013 11:40:30 +0000 (13:40 +0200)
committerAlina Quereilhac <alina.quereilhac@inria.fr>
Mon, 1 Apr 2013 11:40:30 +0000 (13:40 +0200)
src/neco/execution/ec.py
src/neco/execution/resource.py
src/neco/resources/linux/node.py
src/neco/resources/linux/ssh_api.py [new file with mode: 0644]
src/neco/resources/omf/omf_api.py
src/neco/resources/omf/omf_node.py
src/neco/util/sshfuncs.py
test/util/sshfuncs.py [new file with mode: 0644]

index 0b15c3a..076e4dd 100644 (file)
@@ -178,7 +178,8 @@ class ExperimentController(object):
            self._thread.join()
 
     def schedule(self, date, callback, track = False):
-        """
+        """ Schedule a callback to be executed at time date.
+
             date    string containing execution time for the task.
                     It can be expressed as an absolute time, using
                     timestamp format, or as a relative time matching
index ebc537c..62bf821 100644 (file)
@@ -110,7 +110,7 @@ class ResourceManager(object):
     def set_after(self, name, value, time, after_status, guid):
         pass
 
-    def next_step(self):
+    def start(self):
         pass
 
     def stop(self):
@@ -138,7 +138,7 @@ class ResourceFactory(object):
         cls._resource_types[rclass.rtype()] = rclass
 
     @classmethod
-    def create(cls, rtype, ec, guid, creds):
+    def create(cls, rtype, ec, guid):
         rclass = cls._resource_types[rtype]
-        return rclass(ec, guid, creds)
+        return rclass(ec, guid)
 
index feaad46..80bb103 100644 (file)
-from neco.execution.resource import ResourceManager
-from neco.util.sshfuncs import eintr_retry, rexec, rcopy, \
-        rspawn, rcheck_pid, rstatus, rkill, make_control_path, RUNNING 
-
-import cStringIO
-import logging
-import os.path
-import subprocess
+from neco.execution.resource import ResourceManager, clsinit
+from neco.execution.attribute import Attribute, Flags
 
+@clsinit
 class LinuxNode(ResourceManager):
-    def __init__(self, ec, guid):
-        super(LinuxNode, self).__init__(ec, guid)
-        self.ip = None
-        self.host = None
-        self.user = None
-        self.port = None
-        self.identity_file = None
-        self.enable_x11 = False
-        self.forward_agent = True
-
-        # packet management system - either yum or apt for now...
-        self._pm = None
-       
-        # Logging
-        loglevel = "debug"
-        self._logger = logging.getLogger("neco.resources.base.LinuxNode.%s" %\
-                self.guid)
-        self._logger.setLevel(getattr(logging, loglevel.upper()))
-
-        # For ssh connections we use the ControlMaster option which 
-        # allows us to decrease the number of open ssh network connections.
-        # Subsequent ssh connections will reuse a same master connection.
-        # This might pose a problem when using X11 and ssh-agent, since
-        # display and agent forwarded will be those of the first connection,
-        # which created the master. 
-        # To avoid reusing a master created by a previous LinuxNode instance,
-        # we explicitly erase the ControlPath socket.
-        control_path = make_control_path(self.user, self.host, self.port)
-        try:
-            os.remove(control_path)
-        except:
-            pass
-
-    @property
-    def pm(self):
-        if self._pm:
-            return self._pm
-
-        if (not (self.host or self.ip) or not self.user):
-            msg = "Can't resolve package management system. Insufficient data."
-            self._logger.error(msg)
-            raise RuntimeError(msg)
-
-        out = self.execute("cat /etc/issue")
-
-        if out.find("Fedora") == 0:
-            self._pm = "yum"
-        elif out.find("Debian") == 0 or out.find("Ubuntu") ==0:
-            self._pm = "apt-get"
-        else:
-            msg = "Can't resolve package management system. Unknown OS."
-            self._logger.error(msg)
-            raise RuntimeError(msg)
-
-        return self._pm
-
-    @property
-    def is_localhost(self):
-        return ( self.host or self.ip ) in ['localhost', '127.0.0.7', '::1']
-
-    def install(self, packages):
-        if not isinstance(packages, list):
-            packages = [packages]
+    _rtype = "LinuxNode"
 
-        for p in packages:
-            self.execute("%s -y install %s" % (self.pm, p), sudo = True, 
-                    tty = True)
+    @classmethod
+    def _register_attributes(cls):
+        hostname = Attribute("hostname", "Hostname of the machine")
+        username = Attribute("username", "Local account username", 
+                flags = Flags.Credential)
+        password = Attribute("pasword", "Local account password",
+                flags = Flags.Credential)
 
-    def uninstall(self, packages):
-        if not isinstance(packages, list):
-            packages = [packages]
+        cls._register_attribute(hostname)
+        cls._register_attribute(username)
+        cls._register_attribute(password)
 
-        for p in packages:
-            self.execute("%s -y remove %s" % (self.pm, p), sudo = True, 
-                    tty = True)
-
-    def upload(self, src, dst):
-        if not os.path.isfile(src):
-            src = cStringIO.StringIO(src)
-
-        if not self.is_localhost:
-            # Build destination as <user>@<server>:<path>
-            dst = "%s@%s:%s" % (self.user, self.host or self.ip, dst)
-        return self.copy(src, dst)
-
-    def download(self, src, dst):
-        if not self.is_localhost:
-            # Build destination as <user>@<server>:<path>
-            src = "%s@%s:%s" % (self.user, self.host or self.ip, src)
-        return self.copy(src, dst)
-        
-    def is_alive(self, verbose = False):
-        if self.is_localhost:
-            return True
-
-        try:
-            out = self.execute("echo 'ALIVE'",
-                timeout = 60,
-                err_on_timeout = False,
-                persistent = False)
-        except:
-            if verbose:
-                self._logger.warn("Unresponsive node %s got:\n%s%s", self.host, out, err)
-            return False
-
-        if out.strip().startswith('ALIVE'):
-            return True
-        else:
-            if verbose:
-                self._logger.warn("Unresponsive node %s got:\n%s%s", self.host, out, err)
-            return False
-
-    def mkdir(self, path, clean = True):
-        if clean:
-            self.rmdir(path)
-
-        return self.execute(
-            "mkdir -p %s" % path,
-            timeout = 120,
-            retry = 3
-            )
-
-    def rmdir(self, path):
-        return self.execute(
-            "rm -rf %s" % path,
-            timeout = 120,
-            retry = 3
-            )
-
-    def copy(self, src, dst):
-        if self.is_localhost:
-            command = ["cp", "-R", src, dst]
-            p = subprocess.Popen(command, stdout=subprocess.PIPE, 
-                    stderr=subprocess.PIPE)
-            out, err = p.communicate()
-        else:
-            (out, err), proc = eintr_retry(rcopy)(
-                src, dst, 
-                port = self.port,
-                agent = self.agent,
-                identity_file = self.identity_file)
-
-            if proc.wait():
-                msg = "Error uploading to %s got:\n%s%s" %\
-                        (self.host or self.ip, out, err)
-                self._logger.error(msg)
-                raise RuntimeError(msg)
+    def __init__(self, ec, guid):
+        super(LinuxNode, self).__init__(ec, guid)
 
-        return (out, err)
+        self._logger = logging.getLogger("neco.linux.Node.%d" % guid)
+        #elf._logger.setLevel(neco.LOGLEVEL)
 
-    def execute(self, command,
-            sudo = False,
-            stdin = None, 
-            tty = False,
-            env = None,
-            timeout = None,
-            retry = 0,
-            err_on_timeout = True,
-            connect_timeout = 30,
-            persistent = True):
-        """ Notice that this invocation will block until the
-        execution finishes. If this is not the desired behavior,
-        use 'run' instead."""
+    def deploy(self):
+        pass
 
-        if self.is_localhost:
-            if env:
-                export = ''
-                for envkey, envval in env.iteritems():
-                    export += '%s=%s ' % (envkey, envval)
-                command = export + command
+    def discover(self, filters):
+        pass
 
-            if sudo:
-                command = "sudo " + command
+    def provision(self, filters):
+        pass
 
-            p = subprocess.Popen(command, stdout=subprocess.PIPE, 
-                    stderr=subprocess.PIPE)
-            out, err = p.communicate()
-        else:
-            (out, err), proc = eintr_retry(rexec)(
-                    command, 
-                    self.host or self.ip, 
-                    self.user,
-                    port = self.port, 
-                    agent = self.forward_agent,
-                    sudo = sudo,
-                    stdin = stdin, 
-                    identity_file = self.identity_file,
-                    tty = tty,
-                    x11 = self.enable_x11,
-                    env = env,
-                    timeout = timeout,
-                    retry = retry,
-                    err_on_timeout = err_on_timeout,
-                    connect_timeout = connect_timeout,
-                    persistent = persistent)
+    def start(self):
+        pass
 
-            if proc.wait():
-                msg = "Failed to execute command %s at node %s: %s %s" % \
-                        (command, self.host or self.ip, out, err,)
-                self._logger.warn(msg)
-                raise RuntimeError(msg)
+    def stop(self):
+        pass
 
-        return (out, err)
+    def deploy(self, group = None):
+        pass
 
-    def run(self, command, home, 
-            stdin = None, 
-            stdout = 'stdout', 
-            stderr = 'stderr', 
-            sudo = False):
-        self._logger.info("Running %s", command)
-        
-        pidfile = './pid',
+    def release(self):
+        pass
 
-        if self.is_localhost:
-            if stderr == stdout:
-                stderr = '&1'
-            else:
-                stderr = ' ' + stderr
-            
-            daemon_command = '{ { %(command)s  > %(stdout)s 2>%(stderr)s < %(stdin)s & } ; echo $! 1 > %(pidfile)s ; }' % {
-                'command' : command,
-                'pidfile' : pidfile,
-                
-                'stdout' : stdout,
-                'stderr' : stderr,
-                'stdin' : stdin,
-            }
-            
-            cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c '%(command)s' " % {
-                    'command' : daemon_command,
-                    
-                    'sudo' : 'sudo -S' if sudo else '',
-                    
-                    'pidfile' : pidfile,
-                    'gohome' : 'cd %s ; ' % home if home else '',
-                    'create' : 'mkdir -p %s ; ' % home if create_home else '',
-                }
-            p = subprocess.Popen(command, stdout=subprocess.PIPE, 
-                    stderr=subprocess.PIPE)
-            out, err = p.communicate()
-        else:
-            # Start process in a "daemonized" way, using nohup and heavy
-            # stdin/out redirection to avoid connection issues
-            (out,err), proc = rspawn(
-                command,
-                pidfile = pidfile,
-                home = home,
-                stdin = stdin if stdin is not None else '/dev/null',
-                stdout = stdout if stdout else '/dev/null',
-                stderr = stderr if stderr else '/dev/null',
-                sudo = sudo,
-                host = self.host,
-                user = self.user,
-                port = self.port,
-                agent = self.forward_agent,
-                identity_file = self.identity_file
-                )
-            
-            if proc.wait():
-                raise RuntimeError, "Failed to set up application: %s %s" % (out,err,)
+    def _validate_connection(self, guid):
+        # TODO: Validate!
+        return True
 
-        return (out, err)
-    
-    def checkpid(self, path):            
-        # Get PID/PPID
-        # NOTE: wait a bit for the pidfile to be created
-        pidtuple = rcheck_pid(
-            os.path.join(path, 'pid'),
-            host = self.host,
-            user = self.user,
-            port = self.port,
-            agent = self.forward_agent,
-            identity_file = self.identity_file
-            )
-        
-        return pidtuple
-    
-    def status(self, pid, ppid):
-        status = rstatus(
-                pid, ppid,
-                host = self.host,
-                user = self.user,
-                port = self.port,
-                agent = self.forward_agent,
-                identity_file = self.identity_file
-                )
-           
-        return status
-    
-    def kill(self, pid, ppid, sudo = False):
-        status = self.status(pid, ppid)
-        if status == RUNNING:
-            # kill by ppid+pid - SIGTERM first, then try SIGKILL
-            rkill(
-                pid, ppid,
-                host = self.host,
-                user = self.user,
-                port = self.port,
-                agent = self.forward_agent,
-                sudo = sudo,
-                identity_file = self.identity_file
-                )
 
diff --git a/src/neco/resources/linux/ssh_api.py b/src/neco/resources/linux/ssh_api.py
new file mode 100644 (file)
index 0000000..93e895e
--- /dev/null
@@ -0,0 +1,323 @@
+import hashlib
+import os
+import re
+
+class SSHAPI(object):
+    def __init__(self, host, user, identity, port, agent, forward_x11):
+        self.host = host
+        self.user = user
+        # ssh identity file
+        self.identity = identity
+        self.port = port
+        # use ssh agent
+        self.agent = agent
+        # forward X11 
+        self.forward_x11 = forward_x11
+
+    # TODO: Investigate using http://nixos.org/nix/
+    @property
+    def pm(self):
+        if self._pm:
+            return self._pm
+
+        if (not self.host or not self.user):
+            msg = "Can't resolve package management system. Insufficient data."
+            self._logger.error(msg)
+            raise RuntimeError(msg)
+
+        out = self.execute("cat /etc/issue")
+
+        if out.find("Fedora") == 0:
+            self._pm = "yum"
+        elif out.find("Debian") == 0 or out.find("Ubuntu") ==0:
+            self._pm = "apt-get"
+        else:
+            msg = "Can't resolve package management system. Unknown OS."
+            self._logger.error(msg)
+            raise RuntimeError(msg)
+
+        return self._pm
+
+    @property
+    def is_localhost(self):
+        return ( self.host or self.ip ) in ['localhost', '127.0.0.7', '::1']
+
+    # TODO: Investigate using http://nixos.org/nix/
+    def install(self, packages):
+        if not isinstance(packages, list):
+            packages = [packages]
+
+        for p in packages:
+            self.execute("%s -y install %s" % (self.pm, p), sudo = True, 
+                    tty = True)
+
+    # TODO: Investigate using http://nixos.org/nix/
+    def uninstall(self, packages):
+        if not isinstance(packages, list):
+            packages = [packages]
+
+        for p in packages:
+            self.execute("%s -y remove %s" % (self.pm, p), sudo = True, 
+                    tty = True)
+
+    def upload(self, src, dst):
+        """ Copy content to destination
+
+           src  content to copy. Can be a local file, directory or text input
+
+           dst  destination path on the remote host (remote is always self.host)
+        """
+        # If src is a string input 
+        if not os.path.isfile(src) and not isdir:
+            # src is text input that should be uploaded as file           
+            src = cStringIO.StringIO(src)
+
+        if not self.is_localhost:
+            # Build destination as <user>@<server>:<path>
+            dst = "%s@%s:%s" % (self.user, self.host, dst)
+        return self.copy(src, dst)
+
+    def download(self, src, dst):
+        if not self.is_localhost:
+            # Build destination as <user>@<server>:<path>
+            src = "%s@%s:%s" % (self.user, self.host or self.ip, src)
+        return self.copy(src, dst)
+        
+    def is_alive(self, verbose = False):
+        if self.is_localhost:
+            return True
+
+        try:
+            out = self.execute("echo 'ALIVE'",
+                timeout = 60,
+                err_on_timeout = False,
+                persistent = False)
+        except:
+            if verbose:
+                self._logger.warn("Unresponsive node %s got:\n%s%s", self.host, out, err)
+            return False
+
+        if out.strip().startswith('ALIVE'):
+            return True
+        else:
+            if verbose:
+                self._logger.warn("Unresponsive node %s got:\n%s%s", self.host, out, err)
+            return False
+
+    def mkdir(self, path, clean = True):
+        if clean:
+            self.rmdir(path)
+
+        return self.execute(
+            "mkdir -p %s" % path,
+            timeout = 120,
+            retry = 3
+            )
+
+    def rmdir(self, path):
+        return self.execute(
+            "rm -rf %s" % path,
+            timeout = 120,
+            retry = 3
+            )
+
+    def copy(self, src, dst):
+        if self.is_localhost:
+            command = ["cp", "-R", src, dst]
+            p = subprocess.Popen(command, stdout=subprocess.PIPE, 
+                    stderr=subprocess.PIPE)
+            out, err = p.communicate()
+        else:
+            (out, err), proc = eintr_retry(rcopy)(
+                src, dst, 
+                port = self.port,
+                agent = self.agent,
+                identity_file = self.identity_file)
+
+            if proc.wait():
+                msg = "Error uploading to %s got:\n%s%s" %\
+                        (self.host or self.ip, out, err)
+                self._logger.error(msg)
+                raise RuntimeError(msg)
+
+        return (out, err)
+
+    def execute(self, command,
+            sudo = False,
+            stdin = None, 
+            tty = False,
+            env = None,
+            timeout = None,
+            retry = 0,
+            err_on_timeout = True,
+            connect_timeout = 30,
+            persistent = True):
+        """ Notice that this invocation will block until the
+        execution finishes. If this is not the desired behavior,
+        use 'run' instead."""
+
+        if self.is_localhost:
+            if env:
+                export = ''
+                for envkey, envval in env.iteritems():
+                    export += '%s=%s ' % (envkey, envval)
+                command = export + command
+
+            if sudo:
+                command = "sudo " + command
+
+            p = subprocess.Popen(command, stdout=subprocess.PIPE, 
+                    stderr=subprocess.PIPE)
+            out, err = p.communicate()
+        else:
+            (out, err), proc = eintr_retry(rexec)(
+                    command, 
+                    self.host or self.ip, 
+                    self.user,
+                    port = self.port, 
+                    agent = self.forward_agent,
+                    sudo = sudo,
+                    stdin = stdin, 
+                    identity_file = self.identity_file,
+                    tty = tty,
+                    x11 = self.enable_x11,
+                    env = env,
+                    timeout = timeout,
+                    retry = retry,
+                    err_on_timeout = err_on_timeout,
+                    connect_timeout = connect_timeout,
+                    persistent = persistent)
+
+            if proc.wait():
+                msg = "Failed to execute command %s at node %s: %s %s" % \
+                        (command, self.host or self.ip, out, err,)
+                self._logger.warn(msg)
+                raise RuntimeError(msg)
+
+        return (out, err)
+
+    def run(self, command, home, 
+            stdin = None, 
+            stdout = 'stdout', 
+            stderr = 'stderr', 
+            sudo = False):
+        self._logger.info("Running %s", command)
+        
+        pidfile = './pid',
+
+        if self.is_localhost:
+            if stderr == stdout:
+                stderr = '&1'
+            else:
+                stderr = ' ' + stderr
+            
+            daemon_command = '{ { %(command)s  > %(stdout)s 2>%(stderr)s < %(stdin)s & } ; echo $! 1 > %(pidfile)s ; }' % {
+                'command' : command,
+                'pidfile' : pidfile,
+                
+                'stdout' : stdout,
+                'stderr' : stderr,
+                'stdin' : stdin,
+            }
+            
+            cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c '%(command)s' " % {
+                    'command' : daemon_command,
+                    
+                    'sudo' : 'sudo -S' if sudo else '',
+                    
+                    'pidfile' : pidfile,
+                    'gohome' : 'cd %s ; ' % home if home else '',
+                    'create' : 'mkdir -p %s ; ' % home if create_home else '',
+                }
+            p = subprocess.Popen(command, stdout=subprocess.PIPE, 
+                    stderr=subprocess.PIPE)
+            out, err = p.communicate()
+        else:
+            # Start process in a "daemonized" way, using nohup and heavy
+            # stdin/out redirection to avoid connection issues
+            (out,err), proc = rspawn(
+                command,
+                pidfile = pidfile,
+                home = home,
+                stdin = stdin if stdin is not None else '/dev/null',
+                stdout = stdout if stdout else '/dev/null',
+                stderr = stderr if stderr else '/dev/null',
+                sudo = sudo,
+                host = self.host,
+                user = self.user,
+                port = self.port,
+                agent = self.agent,
+                identity_file = self.file
+                )
+            
+            if proc.wait():
+                raise RuntimeError, "Failed to set up application: %s %s" % (out,err,)
+
+        return (out, err)
+    
+    def checkpid(self, path):            
+        # Get PID/PPID
+        # NOTE: wait a bit for the pidfile to be created
+        pidtuple = rcheck_pid(
+            os.path.join(path, 'pid'),
+            host = self.host,
+            user = self.user,
+            port = self.port,
+            agent = self.agent,
+            identity_file = self.identity
+            )
+        
+        return pidtuple
+    
+    def status(self, pid, ppid):
+        status = rstatus(
+                pid, ppid,
+                host = self.host,
+                user = self.user,
+                port = self.port,
+                agent = self.agent,
+                identity_file = self.identity
+                )
+           
+        return status
+    
+    def kill(self, pid, ppid, sudo = False):
+        status = self.status(pid, ppid)
+        if status == RUNNING:
+            # kill by ppid+pid - SIGTERM first, then try SIGKILL
+            rkill(
+                pid, ppid,
+                host = self.host,
+                user = self.user,
+                port = self.port,
+                agent = self.agent,
+                sudo = sudo,
+                identity_file = self.identity
+                )
+
+class SSHAPIFactory(object):
+    _apis = dict()
+
+    @classmethod 
+    def get_api(cls, attributes):
+        host = attributes.get("hostname")
+        user = attributes.get("username")
+        identity = attributes.get("identity", "%s/.ssh/id_rsa" % os.environ['HOME'])
+        port = attributes.get("port", 22)
+        agent = attributes.get("agent", True)
+        forward_X11 = attributes.get("forwardX11", False)
+
+        key = cls.make_key(host, user, identity, port, agent, forward_X11)
+        api = self._apis.get(key)
+
+        if no api:
+            api = SSHAPI(host, user, identity, port, agent, forward_X11)
+            self._apis[key] = api
+
+        return api
+
+    @classmethod 
+    def make_key(cls, *args):
+        skey = "".join(map(str, args))
+        return hashlib.md5(skey).hexdigest()
+
index aca53e5..412e735 100644 (file)
@@ -150,6 +150,7 @@ class OMFAPI(object):
 
 
 class OMFAPIFactory(object):
+    # XXX: put '_apis' instead of '_Api'
     _Api = dict()
 
     @classmethod 
@@ -169,6 +170,16 @@ class OMFAPIFactory(object):
         cls._Api[key] = OmfApi
         return OmfApi
 
+    # XXX: this is not a hash :)
+    # From wikipedia: "A hash function is any algorithm or subroutine that maps large data 
+    # sets of variable length to smaller data sets of a fixed length."
+    # The idea is to apply a function to get a smaller string. Use hashlib instead.
+    # e.g:
+    # import hashlib
+    # res = slice + "_" + host + "_" + port
+    # hashlib.md5(res).hexdigest()
+    #
+    # XXX: change method name for 'make_key'
     @classmethod 
     def _hash_api(cls, slice, host, port):
         res = slice + "_" + host + "_" + port
index 9bcffe9..064da47 100644 (file)
@@ -17,6 +17,10 @@ class OMFNode(ResourceManager):
         hostname = Attribute("hostname", "Hostname of the machine")
         cpu = Attribute("cpu", "CPU of the node")
         ram = Attribute("ram", "RAM of the node")
+        # XXX: flags = "0x02" is not human readable.
+        # instead:
+        # from neco.execution.attribute import Attribute, Flags 
+        # xmppSlice = Attribute("xmppSlice","Name of the slice", flags = Flags.Credential)
         xmppSlice = Attribute("xmppSlice","Name of the slice", flags = "0x02")
         xmppHost = Attribute("xmppHost", "Xmpp Server",flags = "0x02")
         xmppPort = Attribute("xmppPort", "Xmpp Port",flags = "0x02")
@@ -27,7 +31,7 @@ class OMFNode(ResourceManager):
         cls._register_attribute(xmppSlice)
         cls._register_attribute(xmppHost)
         cls._register_attribute(xmppPort)
-        cls._register_attribute(xmppPassword)
+        ls._register_attribute(xmppPassword)
 
     @classmethod
     def _register_filters(cls):
@@ -40,6 +44,10 @@ class OMFNode(ResourceManager):
         cls._register_filter(granularity)
         cls._register_filter(hardware_type)
 
+    # XXX: We don't necessary need to have the credentials at the 
+    # moment we create the RM
+    # THE OMF API SHOULD BE CREATED ON THE DEPLOY METHOD, NOT NOW
+    # THIS FORCES MORE CONSTRAINES ON THE WAY WE WILL AUTHOMATE DEPLOYMENT!
     def __init__(self, ec, guid, creds):
         super(OMFNode, self).__init__(ec, guid)
         self.set('xmppSlice', creds['xmppSlice'])
@@ -47,9 +55,12 @@ class OMFNode(ResourceManager):
         self.set('xmppPort', creds['xmppPort'])
         self.set('xmppPassword', creds['xmppPassword'])
 
+        # XXX: Lines should not be more than 80 characters!
         self._omf_api = OMFAPIFactory.get_api(self.get('xmppSlice'), self.get('xmppHost'), self.get('xmppPort'), self.get('xmppPassword'))
 
         self._logger = logging.getLogger("neco.omf.omfNode   ")
+
+        # XXX: TO DISCUSS
         self._logger.setLevel(neco.LOGLEVEL)
 
     def _validate_connection(self, guid):
index aad6039..f8d1cfc 100644 (file)
@@ -201,7 +201,6 @@ def rexec(command, host, user,
 def rcopy(source, dest,
         port = None, 
         agent = True, 
-        recursive = False,
         identity_file = None):
     """
     Copies file from/to remote sites.
@@ -209,214 +208,57 @@ def rcopy(source, dest,
     Source and destination should have the user and host encoded
     as per scp specs.
     
-    If source is a file object, a special mode will be used to
-    create the remote file with the same contents.
-    
-    If dest is a file object, the remote file (source) will be
-    read and written into dest.
-    
-    In these modes, recursive cannot be True.
-    
     Source can be a list of files to copy to a single destination,
     in which case it is advised that the destination be a folder.
     """
     
-    if isinstance(source, file) and source.tell() == 0:
-        source = source.name
-    elif hasattr(source, 'read'):
-        tmp = tempfile.NamedTemporaryFile()
-        while True:
-            buf = source.read(65536)
-            if buf:
-                tmp.write(buf)
-            else:
-                break
-        tmp.seek(0)
-        source = tmp.name
+    # Parse destination as <user>@<server>:<path>
+    if isinstance(dest, basestring) and ':' in dest:
+        remspec, path = dest.split(':',1)
+    elif isinstance(source, basestring) and ':' in source:
+        remspec, path = source.split(':',1)
+    else:
+        raise ValueError, "Both endpoints cannot be local"
+    user, host = remspec.rsplit('@',1)
+
+    raw_string = r'''rsync -rlpcSz --timeout=900 '''
+    raw_string += r''' -e 'ssh -o BatchMode=yes '''
+    raw_string += r''' -o NoHostAuthenticationForLocalhost=yes '''
+    raw_string += r''' -o StrictHostKeyChecking=no '''
+    raw_string += r''' -o ConnectionAttempts=3 '''
+    if openssh_has_persist():
+        control_path = make_control_path(user, host, port)
+        raw_string += r''' -o ControlMaster=auto '''
+        raw_string += r''' -o ControlPath=%s ''' % control_path
+  
+    if port:
+        raw_string += r''' -p %d ''' % port
     
-    if isinstance(source, file) or isinstance(dest, file) \
-            or hasattr(source, 'read')  or hasattr(dest, 'write'):
-        assert not recursive
+    if identity_file:
+        raw_string += r''' -i "%s" ''' % identity_file
     
-        # Parse source/destination as <user>@<server>:<path>
-        if isinstance(dest, basestring) and ':' in dest:
-            remspec, path = dest.split(':',1)
-        elif isinstance(source, basestring) and ':' in source:
-            remspec, path = source.split(':',1)
-        else:
-            raise ValueError, "Both endpoints cannot be local"
-        user,host = remspec.rsplit('@',1)
-        tmp_known_hosts = None
-        
-        args = ['ssh', '-l', user, '-C',
-                # Don't bother with localhost. Makes test easier
-                '-o', 'NoHostAuthenticationForLocalhost=yes',
-                # XXX: Possible security issue
-                # Avoid interactive requests to accept new host keys
-                '-o', 'StrictHostKeyChecking=no',
-                '-o', 'ConnectTimeout=30',
-                '-o', 'ConnectionAttempts=3',
-                '-o', 'ServerAliveInterval=30',
-                '-o', 'TCPKeepAlive=yes',
-                host ]
-
-        if openssh_has_persist():
-            control_path = make_control_path(user, host, port)
-            args.extend([
-                '-o', 'ControlMaster=auto',
-                '-o', 'ControlPath=%s' % control_path,
-                '-o', 'ControlPersist=60' ])
-        if port:
-            args.append('-P%d' % port)
-        if identity_file:
-            args.extend(('-i', identity_file))
-        
-        if isinstance(source, file) or hasattr(source, 'read'):
-            args.append('cat > %s' % (shell_escape(path),))
-        elif isinstance(dest, file) or hasattr(dest, 'write'):
-            args.append('cat %s' % (shell_escape(path),))
-        else:
-            raise AssertionError, "Unreachable code reached! :-Q"
+    # closing -e 'ssh...'
+    raw_string += r''' ' '''
 
-        # connects to the remote host and starts a remote connection
-        if isinstance(source, file):
-            proc = subprocess.Popen(args, 
-                    stdout = open('/dev/null','w'),
-                    stderr = subprocess.PIPE,
-                    stdin = source)
-            err = proc.stderr.read()
-            eintr_retry(proc.wait)()
-            return ((None,err), proc)
-        elif isinstance(dest, file):
-            proc = subprocess.Popen(args, 
-                    stdout = open('/dev/null','w'),
-                    stderr = subprocess.PIPE,
-                    stdin = source)
-            err = proc.stderr.read()
-            eintr_retry(proc.wait)()
-            return ((None,err), proc)
-        elif hasattr(source, 'read'):
-            # file-like (but not file) source
-            proc = subprocess.Popen(args, 
-                    stdout = open('/dev/null','w'),
-                    stderr = subprocess.PIPE,
-                    stdin = subprocess.PIPE)
-            
-            buf = None
-            err = []
-            while True:
-                if not buf:
-                    buf = source.read(4096)
-                if not buf:
-                    #EOF
-                    break
-                
-                rdrdy, wrdy, broken = select.select(
-                    [proc.stderr],
-                    [proc.stdin],
-                    [proc.stderr,proc.stdin])
-                
-                if proc.stderr in rdrdy:
-                    # use os.read for fully unbuffered behavior
-                    err.append(os.read(proc.stderr.fileno(), 4096))
-                
-                if proc.stdin in wrdy:
-                    proc.stdin.write(buf)
-                    buf = None
-                
-                if broken:
-                    break
-            proc.stdin.close()
-            err.append(proc.stderr.read())
-                
-            eintr_retry(proc.wait)()
-            return ((None,''.join(err)), proc)
-        elif hasattr(dest, 'write'):
-            # file-like (but not file) dest
-            proc = subprocess.Popen(args, 
-                    stdout = subprocess.PIPE,
-                    stderr = subprocess.PIPE,
-                    stdin = open('/dev/null','w'))
-            
-            buf = None
-            err = []
-            while True:
-                rdrdy, wrdy, broken = select.select(
-                    [proc.stderr, proc.stdout],
-                    [],
-                    [proc.stderr, proc.stdout])
-                
-                if proc.stderr in rdrdy:
-                    # use os.read for fully unbuffered behavior
-                    err.append(os.read(proc.stderr.fileno(), 4096))
-                
-                if proc.stdout in rdrdy:
-                    # use os.read for fully unbuffered behavior
-                    buf = os.read(proc.stdout.fileno(), 4096)
-                    dest.write(buf)
-                    
-                    if not buf:
-                        #EOF
-                        break
-                
-                if broken:
-                    break
-            err.append(proc.stderr.read())
-                
-            eintr_retry(proc.wait)()
-            return ((None,''.join(err)), proc)
-        else:
-            raise AssertionError, "Unreachable code reached! :-Q"
+    if isinstance(source,list):
+        source = ' '.join(source)
     else:
-        # Parse destination as <user>@<server>:<path>
-        if isinstance(dest, basestring) and ':' in dest:
-            remspec, path = dest.split(':',1)
-        elif isinstance(source, basestring) and ':' in source:
-            remspec, path = source.split(':',1)
-        else:
-            raise ValueError, "Both endpoints cannot be local"
-        user,host = remspec.rsplit('@',1)
-
-        # plain scp
-        args = ['scp', '-q', '-p', '-C',
-                # Don't bother with localhost. Makes test easier
-                '-o', 'NoHostAuthenticationForLocalhost=yes',
-                # XXX: Possible security issue
-                # Avoid interactive requests to accept new host keys
-                '-o', 'StrictHostKeyChecking=no',
-                '-o', 'ConnectTimeout=30',
-                '-o', 'ConnectionAttempts=3',
-                '-o', 'ServerAliveInterval=30',
-                '-o', 'TCPKeepAlive=yes' ]
-                
-        if port:
-            args.append('-P%d' % port)
-        if recursive:
-            args.append('-r')
-        if identity_file:
-            args.extend(('-i', identity_file))
-
-        if isinstance(source,list):
-            args.extend(source)
-        else:
-            if openssh_has_persist():
-                control_path = make_control_path(user, host, port)
-                args.extend([
-                    '-o', 'ControlMaster=no',
-                    '-o', 'ControlPath=%s' % control_path ])
-            args.append(source)
+        source = '"%s"' % source
 
-        args.append(dest)
+    raw_string += r''' %s ''' % source
+    raw_string += r''' %s ''' % dest
 
-        # connects to the remote host and starts a remote connection
-        proc = subprocess.Popen(args, 
-                stdout = subprocess.PIPE,
-                stdin = subprocess.PIPE, 
-                stderr = subprocess.PIPE)
-        
-        comm = proc.communicate()
-        eintr_retry(proc.wait)()
-        return (comm, proc)
+    # connects to the remote host and starts a remote connection
+    proc = subprocess.Popen(raw_string,
+            shell=True,
+            stdout = subprocess.PIPE,
+            stdin = subprocess.PIPE, 
+            stderr = subprocess.PIPE)
+  
+    comm = proc.communicate()
+    eintr_retry(proc.wait)()
+    return (comm, proc)
 
 def rspawn(command, pidfile, 
         stdout = '/dev/null', 
@@ -465,7 +307,8 @@ def rspawn(command, pidfile,
         stderr = '&1'
     else:
         stderr = ' ' + stderr
-    
+   
+    #XXX: ppid is always 1!!!
     daemon_command = '{ { %(command)s  > %(stdout)s 2>%(stderr)s < %(stdin)s & } ; echo $! 1 > %(pidfile)s ; }' % {
         'command' : command,
         'pidfile' : pidfile,
@@ -562,6 +405,7 @@ def rstatus(pid, ppid,
         One of NOT_STARTED, RUNNING, FINISHED
     """
 
+    # XXX: ppid unused
     (out,err),proc = rexec(
         "ps --pid %(pid)d -o pid | grep -c %(pid)d ; true" % {
             'ppid' : ppid,
@@ -590,7 +434,7 @@ def rstatus(pid, ppid,
     
 
 @eintr_retry
-def rkill(pid, ppid, 
+def rkill(pid, ppid,
         host = None, 
         port = None, 
         user = None, 
diff --git a/test/util/sshfuncs.py b/test/util/sshfuncs.py
new file mode 100644 (file)
index 0000000..756c79a
--- /dev/null
@@ -0,0 +1,245 @@
+#!/usr/bin/env python
+
+from neco.util.sshfuncs import *
+
+import getpass
+import unittest
+import os
+import subprocess
+import re
+import signal
+import shutil
+import socket
+import subprocess
+import tempfile
+import time
+
+def find_bin(name, extra_path = None):
+    search = []
+    if "PATH" in os.environ:
+        search += os.environ["PATH"].split(":")
+    for pref in ("/", "/usr/", "/usr/local/"):
+        for d in ("bin", "sbin"):
+            search.append(pref + d)
+    if extra_path:
+        search += extra_path
+
+    for d in search:
+            try:
+                os.stat(d + "/" + name)
+                return d + "/" + name
+            except OSError, e:
+                if e.errno != os.errno.ENOENT:
+                    raise
+    return None
+
+def find_bin_or_die(name, extra_path = None):
+    r = find_bin(name)
+    if not r:
+        raise RuntimeError(("Cannot find `%s' command, impossible to " +
+                "continue.") % name)
+    return r
+
+def gen_ssh_keypair(filename):
+    ssh_keygen = find_bin_or_die("ssh-keygen")
+    args = [ssh_keygen, '-q', '-N', '', '-f', filename]
+    assert subprocess.Popen(args).wait() == 0
+    return filename, "%s.pub" % filename
+
+def add_key_to_agent(filename):
+    ssh_add = find_bin_or_die("ssh-add")
+    args = [ssh_add, filename]
+    null = file("/dev/null", "w")
+    assert subprocess.Popen(args, stderr = null).wait() == 0
+    null.close()
+
+def get_free_port():
+    s = socket.socket()
+    s.bind(("127.0.0.1", 0))
+    port = s.getsockname()[1]
+    return port
+
+_SSH_CONF = """ListenAddress 127.0.0.1:%d
+Protocol 2
+HostKey %s
+UsePrivilegeSeparation no
+PubkeyAuthentication yes
+PasswordAuthentication no
+AuthorizedKeysFile %s
+UsePAM no
+AllowAgentForwarding yes
+PermitRootLogin yes
+StrictModes no
+PermitUserEnvironment yes
+"""
+
+def gen_sshd_config(filename, port, server_key, auth_keys):
+    conf = open(filename, "w")
+    text = _SSH_CONF % (port, server_key, auth_keys)
+    conf.write(text)
+    conf.close()
+    return filename
+
+def gen_auth_keys(pubkey, output, environ):
+    #opts = ['from="127.0.0.1/32"'] # fails in stupid yans setup
+    opts = []
+    for k, v in environ.items():
+        opts.append('environment="%s=%s"' % (k, v))
+
+    lines = file(pubkey).readlines()
+    pubkey = lines[0].split()[0:2]
+    out = file(output, "w")
+    out.write("%s %s %s\n" % (",".join(opts), pubkey[0], pubkey[1]))
+    out.close()
+    return output
+
+def start_ssh_agent():
+    ssh_agent = find_bin_or_die("ssh-agent")
+    proc = subprocess.Popen([ssh_agent], stdout = subprocess.PIPE)
+    (out, foo) = proc.communicate()
+    assert proc.returncode == 0
+    d = {}
+    for l in out.split("\n"):
+        match = re.search("^(\w+)=([^ ;]+);.*", l)
+        if not match:
+            continue
+        k, v = match.groups()
+        os.environ[k] = v
+        d[k] = v
+    return d
+
+def stop_ssh_agent(data):
+    # No need to gather the pid, ssh-agent knows how to kill itself; after we
+    # had set up the environment
+    ssh_agent = find_bin_or_die("ssh-agent")
+    null = file("/dev/null", "w")
+    proc = subprocess.Popen([ssh_agent, "-k"], stdout = null)
+    null.close()
+    assert proc.wait() == 0
+    for k in data:
+        del os.environ[k]
+
+class test_environment(object):
+    def __init__(self):
+        sshd = find_bin_or_die("sshd")
+        environ = {}
+        self.dir = tempfile.mkdtemp()
+        self.server_keypair = gen_ssh_keypair(
+                os.path.join(self.dir, "server_key"))
+        self.client_keypair = gen_ssh_keypair(
+                os.path.join(self.dir, "client_key"))
+        self.authorized_keys = gen_auth_keys(self.client_keypair[1],
+                os.path.join(self.dir, "authorized_keys"), environ)
+        self.port = get_free_port()
+        self.sshd_conf = gen_sshd_config(
+                os.path.join(self.dir, "sshd_config"),
+                self.port, self.server_keypair[0], self.authorized_keys)
+
+        self.sshd = subprocess.Popen([sshd, '-q', '-D', '-f', self.sshd_conf])
+        self.ssh_agent_vars = start_ssh_agent()
+        add_key_to_agent(self.client_keypair[0])
+
+    def __del__(self):
+        if self.sshd:
+            os.kill(self.sshd.pid, signal.SIGTERM)
+            self.sshd.wait()
+        if self.ssh_agent_vars:
+            stop_ssh_agent(self.ssh_agent_vars)
+        shutil.rmtree(self.dir)
+
+class SSHfuncsTestCase(unittest.TestCase):
+    def test_rexec(self):
+        env = test_environment()
+        user = getpass.getuser()
+        host = "localhost" 
+
+        command = "hostname"
+
+        plocal = subprocess.Popen(command, stdout=subprocess.PIPE, 
+                stdin=subprocess.PIPE)
+        outlocal, errlocal = plocal.communicate()
+
+        (outremote, errrmote), premote = rexec(command, host, user, 
+                port = env.port, agent = True)
+
+        self.assertEquals(outlocal, outremote)
+
+    def test_rcopy(self):
+        env = test_environment()
+        user = getpass.getuser()
+        host = "localhost"
+
+        # create some temp files and directories to copy
+        dirpath = tempfile.mkdtemp()
+        f = tempfile.NamedTemporaryFile(dir=dirpath, delete=False)
+        f.close()
+      
+        f1 = tempfile.NamedTemporaryFile(delete=False)
+        f1.close()
+        f1.name
+
+        source = [dirpath, f1.name]
+        destdir = tempfile.mkdtemp()
+        dest = "%s@%s:%s" % (user, host, destdir)
+        rcopy(source, dest, port = env.port, agent = True)
+
+        files = []
+        def recls(files, dirname, names):
+            files.extend(names)
+        os.path.walk(destdir, recls, files)
+        
+        origfiles = map(lambda s: os.path.basename(s), [dirpath, f.name, f1.name])
+
+        self.assertEquals(sorted(origfiles), sorted(files))
+
+    def test_rproc_manage(self):
+        env = test_environment()
+        user = getpass.getuser()
+        host = "localhost" 
+        command = "ping localhost"
+        
+        f = tempfile.NamedTemporaryFile(delete=False)
+        pidfile = f.name 
+
+        (out,err), proc = rspawn(
+                command, 
+                pidfile,
+                host = host,
+                user = user,
+                port = env.port,
+                agent = True)
+
+        time.sleep(2)
+
+        (pid, ppid) = rcheck_pid(pidfile,
+                host = host,
+                user = user,
+                port = env.port,
+                agent = True)
+
+        status = rstatus(pid, ppid,
+                host = host,
+                user = user, 
+                port = env.port, 
+                agent = True)
+
+        self.assertEquals(status, RUNNING)
+
+        rkill(pid, ppid,
+                host = host,
+                user = user, 
+                port = env.port, 
+                agent = True)
+
+        status = rstatus(pid, ppid,
+                host = host,
+                user = user, 
+                port = env.port, 
+                agent = True)
+        
+        self.assertEquals(status, FINISHED)
+
+
+if __name__ == '__main__':
+    unittest.main()
+