Adding tests for ssh_api
authorAlina Quereilhac <alina.quereilhac@inria.fr>
Mon, 1 Apr 2013 20:48:25 +0000 (22:48 +0200)
committerAlina Quereilhac <alina.quereilhac@inria.fr>
Mon, 1 Apr 2013 20:48:25 +0000 (22:48 +0200)
src/neco/resources/linux/ssh_api.py
src/neco/util/sshfuncs.py
test/resources/linux/ssh_api.py [new file with mode: 0644]
test/util/sshfuncs.py

index 93e895e..7f5009e 100644 (file)
@@ -1,9 +1,15 @@
+
+from neco.util.sshfuncs import eintr_retry, rexec, rcopy, rspawn, \
+        rcheckpid, rstatus, rkill, RUNNING, FINISHED 
+
 import hashlib
+import logging
 import os
 import re
+import tempfile
 
-class SSHAPI(object):
-    def __init__(self, host, user, identity, port, agent, forward_x11):
+class SSHApi(object):
+    def __init__(self, host, user, port, identity, agent, forward_x11):
         self.host = host
         self.user = user
         # ssh identity file
@@ -14,6 +20,10 @@ class SSHAPI(object):
         # forward X11 
         self.forward_x11 = forward_x11
 
+        self._pm = None
+        
+        self._logger = logging.getLogger("neco.linux.SSHApi")
+
     # TODO: Investigate using http://nixos.org/nix/
     @property
     def pm(self):
@@ -25,7 +35,7 @@ class SSHAPI(object):
             self._logger.error(msg)
             raise RuntimeError(msg)
 
-        out = self.execute("cat /etc/issue")
+        out, err = self.execute("cat /etc/issue")
 
         if out.find("Fedora") == 0:
             self._pm = "yum"
@@ -40,7 +50,7 @@ class SSHAPI(object):
 
     @property
     def is_localhost(self):
-        return ( self.host or self.ip ) in ['localhost', '127.0.0.7', '::1']
+        return self.host in ['localhost', '127.0.0.7', '::1']
 
     # TODO: Investigate using http://nixos.org/nix/
     def install(self, packages):
@@ -67,20 +77,27 @@ class SSHAPI(object):
 
            dst  destination path on the remote host (remote is always self.host)
         """
-        # If src is a string input 
-        if not os.path.isfile(src) and not isdir:
-            # src is text input that should be uploaded as file           
-            src = cStringIO.StringIO(src)
+        # If source is a string input 
+        if not os.path.isfile(src):
+            # src is text input that should be uploaded as file
+            # create a temporal file with the content to upload
+            f = tempfile.NamedTemporaryFile(delete=False)
+            f.write(src)
+            f.close()
+            src = f.name
 
         if not self.is_localhost:
             # Build destination as <user>@<server>:<path>
             dst = "%s@%s:%s" % (self.user, self.host, dst)
-        return self.copy(src, dst)
+
+        ret = self.copy(src, dst)
+
+        return ret
 
     def download(self, src, dst):
         if not self.is_localhost:
             # Build destination as <user>@<server>:<path>
-            src = "%s@%s:%s" % (self.user, self.host or self.ip, src)
+            src = "%s@%s:%s" % (self.user, self.host, src)
         return self.copy(src, dst)
         
     def is_alive(self, verbose = False):
@@ -88,7 +105,7 @@ class SSHAPI(object):
             return True
 
         try:
-            out = self.execute("echo 'ALIVE'",
+            (out, err) = self.execute("echo 'ALIVE'",
                 timeout = 60,
                 err_on_timeout = False,
                 persistent = False)
@@ -132,11 +149,11 @@ class SSHAPI(object):
                 src, dst, 
                 port = self.port,
                 agent = self.agent,
-                identity_file = self.identity_file)
+                identity = self.identity)
 
             if proc.wait():
                 msg = "Error uploading to %s got:\n%s%s" %\
-                        (self.host or self.ip, out, err)
+                        (self.host, out, err)
                 self._logger.error(msg)
                 raise RuntimeError(msg)
 
@@ -172,15 +189,15 @@ class SSHAPI(object):
         else:
             (out, err), proc = eintr_retry(rexec)(
                     command, 
-                    self.host or self.ip
+                    self.host, 
                     self.user,
                     port = self.port, 
-                    agent = self.forward_agent,
+                    agent = self.agent,
                     sudo = sudo,
                     stdin = stdin, 
-                    identity_file = self.identity_file,
+                    identity = self.identity,
                     tty = tty,
-                    x11 = self.enable_x11,
+                    x11 = self.forward_x11,
                     env = env,
                     timeout = timeout,
                     retry = retry,
@@ -190,10 +207,9 @@ class SSHAPI(object):
 
             if proc.wait():
                 msg = "Failed to execute command %s at node %s: %s %s" % \
-                        (command, self.host or self.ip, out, err,)
+                        (command, self.host, out, err,)
                 self._logger.warn(msg)
                 raise RuntimeError(msg)
-
         return (out, err)
 
     def run(self, command, home, 
@@ -203,7 +219,7 @@ class SSHAPI(object):
             sudo = False):
         self._logger.info("Running %s", command)
         
-        pidfile = './pid',
+        pidfile = './pid'
 
         if self.is_localhost:
             if stderr == stdout:
@@ -247,7 +263,7 @@ class SSHAPI(object):
                 user = self.user,
                 port = self.port,
                 agent = self.agent,
-                identity_file = self.file
+                identity = self.identity
                 )
             
             if proc.wait():
@@ -258,13 +274,13 @@ class SSHAPI(object):
     def checkpid(self, path):            
         # Get PID/PPID
         # NOTE: wait a bit for the pidfile to be created
-        pidtuple = rcheck_pid(
+        pidtuple = rcheckpid(
             os.path.join(path, 'pid'),
             host = self.host,
             user = self.user,
             port = self.port,
             agent = self.agent,
-            identity_file = self.identity
+            identity = self.identity
             )
         
         return pidtuple
@@ -276,7 +292,7 @@ class SSHAPI(object):
                 user = self.user,
                 port = self.port,
                 agent = self.agent,
-                identity_file = self.identity
+                identity = self.identity
                 )
            
         return status
@@ -292,27 +308,22 @@ class SSHAPI(object):
                 port = self.port,
                 agent = self.agent,
                 sudo = sudo,
-                identity_file = self.identity
+                identity = self.identity
                 )
 
-class SSHAPIFactory(object):
+class SSHApiFactory(object):
     _apis = dict()
 
     @classmethod 
-    def get_api(cls, attributes):
-        host = attributes.get("hostname")
-        user = attributes.get("username")
-        identity = attributes.get("identity", "%s/.ssh/id_rsa" % os.environ['HOME'])
-        port = attributes.get("port", 22)
-        agent = attributes.get("agent", True)
-        forward_X11 = attributes.get("forwardX11", False)
-
-        key = cls.make_key(host, user, identity, port, agent, forward_X11)
-        api = self._apis.get(key)
-
-        if no api:
-            api = SSHAPI(host, user, identity, port, agent, forward_X11)
-            self._apis[key] = api
+    def get_api(cls, host, user, port = 22, identity = None, 
+            agent = True, forward_X11 = False):
+        key = cls.make_key(host, user, port, agent, forward_X11)
+        api = cls._apis.get(key)
+
+        if not api:
+            api = SSHApi(host, user, port, identity, agent, forward_X11)
+            cls._apis[key] = api
 
         return api
 
index f8d1cfc..dd04ba6 100644 (file)
@@ -105,14 +105,23 @@ def eintr_retry(func):
             return func(*p, **kw)
     return rv
 
-def make_connkey(user, host, port):
-    connkey = repr((user,host,port)).encode("base64").strip().replace('/','.')
+def make_connkey(user, host, port, x11, agent):
+    # It is important to consider the x11 and agent forwarding
+    # parameters when creating the connection key since the parameters
+    # used for the first ssh connection will determine the
+    # parameters of all subsequent connections using the same key
+    x11 = 1 if x11 else 0
+    agent = 1 if agent else 0
+
+    connkey = repr((user, host, port, x11, agent)
+            ).encode("base64").strip().replace('/','.')
+
     if len(connkey) > 60:
         connkey = hashlib.sha1(connkey).hexdigest()
     return connkey
 
-def make_control_path(user, host, port):
-    connkey = make_connkey(user, host, port)
+def make_control_path(user, host, port, x11, agent):
+    connkey = make_connkey(user, host, port, x11, agent)
     return '/tmp/%s_%s' % ( CONTROL_PATH, connkey, )
 
 def rexec(command, host, user, 
@@ -120,7 +129,7 @@ def rexec(command, host, user,
         agent = True,
         sudo = False,
         stdin = None,
-        identity_file = None,
+        identity = None,
         env = None,
         tty = False,
         x11 = False,
@@ -145,7 +154,7 @@ def rexec(command, host, user,
             '-l', user, host]
 
     if persistent and openssh_has_persist():
-        control_path = make_control_path(user, host, port)
+        control_path = make_control_path(user, host, port, x11, agent)
         args.extend([
             '-o', 'ControlMaster=auto',
             '-o', 'ControlPath=%s' % control_path,
@@ -154,8 +163,8 @@ def rexec(command, host, user,
         args.append('-A')
     if port:
         args.append('-p%d' % port)
-    if identity_file:
-        args.extend(('-i', identity_file))
+    if identity:
+        args.extend(('-i', identity))
     if tty:
         args.append('-t')
         if sudo:
@@ -201,7 +210,7 @@ def rexec(command, host, user,
 def rcopy(source, dest,
         port = None, 
         agent = True, 
-        identity_file = None):
+        identity = None):
     """
     Copies file from/to remote sites.
     
@@ -228,15 +237,18 @@ def rcopy(source, dest,
     raw_string += r''' -o ConnectionAttempts=3 '''
  
     if openssh_has_persist():
-        control_path = make_control_path(user, host, port)
+        control_path = make_control_path(user, host, port, False, agent)
         raw_string += r''' -o ControlMaster=auto '''
         raw_string += r''' -o ControlPath=%s ''' % control_path
-  
+    if agent:
+        raw_string += r''' -A '''
+
     if port:
         raw_string += r''' -p %d ''' % port
     
-    if identity_file:
-        raw_string += r''' -i "%s" ''' % identity_file
+    if identity:
+        raw_string += r''' -i "%s" ''' % identity
     
     # closing -e 'ssh...'
     raw_string += r''' ' '''
@@ -271,7 +283,7 @@ def rspawn(command, pidfile,
         user = None, 
         agent = None, 
         sudo = False,
-        identity_file = None, 
+        identity = None, 
         tty = False):
     """
     Spawn a remote command such that it will continue working asynchronously.
@@ -293,7 +305,7 @@ def rspawn(command, pidfile,
         
         sudo: whether the command needs to be executed as root
         
-        host/port/user/agent/identity_file: see rexec
+        host/port/user/agent/identity: see rexec
     
     Returns:
         (stdout, stderr), process
@@ -334,7 +346,7 @@ def rspawn(command, pidfile,
         port = port,
         user = user,
         agent = agent,
-        identity_file = identity_file,
+        identity = identity,
         tty = tty
         )
     
@@ -344,19 +356,19 @@ def rspawn(command, pidfile,
     return (out,err),proc
 
 @eintr_retry
-def rcheck_pid(pidfile,
+def rcheckpid(pidfile,
         host = None, 
         port = None, 
         user = None, 
         agent = None, 
-        identity_file = None):
+        identity = None):
     """
     Check the pidfile of a process spawned with remote_spawn.
     
     Parameters:
         pidfile: the pidfile passed to remote_span
         
-        host/port/user/agent/identity_file: see rexec
+        host/port/user/agent/identity: see rexec
     
     Returns:
         
@@ -372,7 +384,7 @@ def rcheck_pid(pidfile,
         port = port,
         user = user,
         agent = agent,
-        identity_file = identity_file
+        identity = identity
         )
         
     if proc.wait():
@@ -391,14 +403,14 @@ def rstatus(pid, ppid,
         port = None, 
         user = None, 
         agent = None, 
-        identity_file = None):
+        identity = None):
     """
     Check the status of a process spawned with remote_spawn.
     
     Parameters:
         pid/ppid: pid and parent-pid of the spawned process. See remote_check_pid
         
-        host/port/user/agent/identity_file: see rexec
+        host/port/user/agent/identity: see rexec
     
     Returns:
         
@@ -415,7 +427,7 @@ def rstatus(pid, ppid,
         port = port,
         user = user,
         agent = agent,
-        identity_file = identity_file
+        identity = identity
         )
     
     if proc.wait():
@@ -440,7 +452,7 @@ def rkill(pid, ppid,
         user = None, 
         agent = None, 
         sudo = False,
-        identity_file = None, 
+        identity = None, 
         nowait = False):
     """
     Kill a process spawned with remote_spawn.
@@ -453,7 +465,7 @@ def rkill(pid, ppid,
         
         sudo: whether the command was run with sudo - careful killing like this.
         
-        host/port/user/agent/identity_file: see rexec
+        host/port/user/agent/identity: see rexec
     
     Returns:
         
@@ -494,7 +506,7 @@ fi
         port = port,
         user = user,
         agent = agent,
-        identity_file = identity_file
+        identity = identity
         )
     
     # wait, don't leave zombies around
diff --git a/test/resources/linux/ssh_api.py b/test/resources/linux/ssh_api.py
new file mode 100644 (file)
index 0000000..88b0365
--- /dev/null
@@ -0,0 +1,174 @@
+#!/usr/bin/env python
+from neco.resources.linux.ssh_api import SSHApiFactory
+from neco.util.sshfuncs import RUNNING, FINISHED
+
+import os
+import time
+import tempfile
+import unittest
+
+def skipIfNotAlive(func):
+    name = func.__name__
+    def wrapped(*args, **kwargs):
+        host = args[1]
+        user = args[2]
+
+        api = SSHApiFactory.get_api(host, user)
+        if not api.is_alive():
+            print "*** WARNING: Skipping test %s: Node %s is not alive\n" % (name, host)
+            return
+
+        return func(*args, **kwargs)
+    
+    return wrapped
+
+def skipInteractive(func):
+    name = func.__name__
+    def wrapped(*args, **kwargs):
+        mode = os.environ.get("NEPI_INTERACTIVE", False) in ['True', 'true', 'yes', 'YES']
+        if not mode:
+            print "*** WARNING: Skipping test %s: Interactive mode off \n" % name
+            return
+
+        return func(*args, **kwargs)
+    
+    return wrapped
+
+class SSHApiTestCase(unittest.TestCase):
+    def setUp(self):
+        self.host_fedora = 'nepi2.pl.sophia.inria.fr'
+        self.user_fedora = 'inria_nepi'
+
+        self.host_ubuntu = 'roseval.pl.sophia.inria.fr'
+        self.user_ubuntu = 'alina'
+        
+        self.target = 'nepi5.pl.sophia.inria.fr'
+        self.home = '${HOME}/test-app'
+
+    @skipIfNotAlive
+    def t_xterm(self, host, user):
+        api = SSHApiFactory.get_api(host, user)
+
+        api.enable_x11 = True
+
+        api.install('xterm')
+
+        out = api.execute('xterm')
+
+        api.uninstall('xterm')
+
+        self.assertEquals(out, "")
+
+    @skipIfNotAlive
+    def t_execute(self, host, user):
+        api = SSHApiFactory.get_api(host, user)
+        
+        command = "ping -qc3 %s" % self.target
+        out, err = api.execute(command)
+
+        expected = """3 packets transmitted, 3 received, 0% packet loss"""
+
+        self.assertTrue(out.find(expected) > 0)
+
+    @skipIfNotAlive
+    def t_run(self, host, user):
+        api = SSHApiFactory.get_api(host, user)
+        
+        api.mkdir(self.home, clean = True)
+        
+        command = "ping %s" % self.target
+        dst = os.path.join(self.home, "app.sh")
+        api.upload(command, dst)
+        
+        cmd = "bash ./app.sh"
+        api.run(cmd, self.home)
+        pid, ppid = api.checkpid(self.home)
+
+        status = api.status(pid, ppid)
+        self.assertTrue(status, RUNNING)
+
+        api.kill(pid, ppid)
+        status = api.status(pid, ppid)
+        self.assertTrue(status, FINISHED)
+
+        api.rmdir(self.home)
+
+    @skipIfNotAlive
+    def t_install(self, host, user):
+        api = SSHApiFactory.get_api(host, user)
+        
+        api.mkdir(self.home, clean = True)
+
+        prog = """#include <stdio.h>
+
+int
+main (void)
+{
+    printf ("Hello, world!\\n");
+    return 0;
+}
+"""
+        # upload the test program
+        dst = os.path.join(self.home, "hello.c")
+        api.upload(prog, dst)
+
+        # install gcc
+        api.install('gcc')
+
+        # compile the program using gcc
+        command = "cd %s; gcc -Wall hello.c -o hello" % self.home
+        out = api.execute(command)
+
+        # execute the program and get the output from stout
+        command = "%s/hello" % self.home
+        out, err = api.execute(command)
+
+        # execute the program and get the output from a file
+        command = "%s/hello > %s/hello.out" % (self.home, self.home)
+        api.execute(command)
+
+        # retrieve the output file 
+        src = os.path.join(self.home, "hello.out")
+        f = tempfile.NamedTemporaryFile(delete=False)
+        dst = f.name
+        api.download(src, dst)
+        f.close()
+
+        api.uninstall('gcc')
+        api.rmdir(self.home)
+
+        self.assertEquals(out, "Hello, world!\n")
+
+        f = open(dst, "r")
+        out = f.read()
+        f.close()
+        
+        self.assertEquals(out, "Hello, world!\n")
+
+    def test_execute_fedora(self):
+        self.t_execute(self.host_fedora, self.user_fedora)
+
+    def test_execute_ubuntu(self):
+        self.t_execute(self.host_ubuntu, self.user_ubuntu)
+
+    def test_run_fedora(self):
+        self.t_run(self.host_fedora, self.user_fedora)
+
+    def test_run_ubuntu(self):
+        self.t_run(self.host_ubuntu, self.user_ubuntu)
+
+    def test_intall_fedora(self):
+        self.t_install(self.host_fedora, self.user_fedora)
+
+    def test_install_ubuntu(self):
+        self.t_install(self.host_ubuntu, self.user_ubuntu)
+    
+    @skipInteractive
+    def test_xterm_ubuntu(self):
+        """ Interactive test. Should not run automatically """
+        self.t_xterm(self.host_ubuntu, self.user_ubuntu)
+
+
+if __name__ == '__main__':
+    unittest.main()
+
index 756c79a..9b282ce 100644 (file)
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 
-from neco.util.sshfuncs import *
+from neco.util.sshfuncs import rexec, rcopy, rspawn, rcheckpid, rstatus, rkill,\
+        RUNNING, FINISHED 
 
 import getpass
 import unittest
@@ -211,7 +212,7 @@ class SSHfuncsTestCase(unittest.TestCase):
 
         time.sleep(2)
 
-        (pid, ppid) = rcheck_pid(pidfile,
+        (pid, ppid) = rcheckpid(pidfile,
                 host = host,
                 user = user,
                 port = env.port,