From f5e3ad5b86d64c712ed0a49faf11d46be2209979 Mon Sep 17 00:00:00 2001 From: Alina Quereilhac Date: Mon, 1 Apr 2013 22:48:25 +0200 Subject: [PATCH] Adding tests for ssh_api --- src/neco/resources/linux/ssh_api.py | 91 ++++++++------- src/neco/util/sshfuncs.py | 64 +++++----- test/resources/linux/ssh_api.py | 174 ++++++++++++++++++++++++++++ test/util/sshfuncs.py | 5 +- 4 files changed, 266 insertions(+), 68 deletions(-) create mode 100644 test/resources/linux/ssh_api.py diff --git a/src/neco/resources/linux/ssh_api.py b/src/neco/resources/linux/ssh_api.py index 93e895ec..7f5009e3 100644 --- a/src/neco/resources/linux/ssh_api.py +++ b/src/neco/resources/linux/ssh_api.py @@ -1,9 +1,15 @@ + +from neco.util.sshfuncs import eintr_retry, rexec, rcopy, rspawn, \ + rcheckpid, rstatus, rkill, RUNNING, FINISHED + import hashlib +import logging import os import re +import tempfile -class SSHAPI(object): - def __init__(self, host, user, identity, port, agent, forward_x11): +class SSHApi(object): + def __init__(self, host, user, port, identity, agent, forward_x11): self.host = host self.user = user # ssh identity file @@ -14,6 +20,10 @@ class SSHAPI(object): # forward X11 self.forward_x11 = forward_x11 + self._pm = None + + self._logger = logging.getLogger("neco.linux.SSHApi") + # TODO: Investigate using http://nixos.org/nix/ @property def pm(self): @@ -25,7 +35,7 @@ class SSHAPI(object): self._logger.error(msg) raise RuntimeError(msg) - out = self.execute("cat /etc/issue") + out, err = self.execute("cat /etc/issue") if out.find("Fedora") == 0: self._pm = "yum" @@ -40,7 +50,7 @@ class SSHAPI(object): @property def is_localhost(self): - return ( self.host or self.ip ) in ['localhost', '127.0.0.7', '::1'] + return self.host in ['localhost', '127.0.0.7', '::1'] # TODO: Investigate using http://nixos.org/nix/ def install(self, packages): @@ -67,20 +77,27 @@ class SSHAPI(object): dst destination path on the remote host (remote is always self.host) """ - # If src is a string input - if not os.path.isfile(src) and not isdir: - # src is text input that should be uploaded as file - src = cStringIO.StringIO(src) + # If source is a string input + if not os.path.isfile(src): + # src is text input that should be uploaded as file + # create a temporal file with the content to upload + f = tempfile.NamedTemporaryFile(delete=False) + f.write(src) + f.close() + src = f.name if not self.is_localhost: # Build destination as @: dst = "%s@%s:%s" % (self.user, self.host, dst) - return self.copy(src, dst) + + ret = self.copy(src, dst) + + return ret def download(self, src, dst): if not self.is_localhost: # Build destination as @: - src = "%s@%s:%s" % (self.user, self.host or self.ip, src) + src = "%s@%s:%s" % (self.user, self.host, src) return self.copy(src, dst) def is_alive(self, verbose = False): @@ -88,7 +105,7 @@ class SSHAPI(object): return True try: - out = self.execute("echo 'ALIVE'", + (out, err) = self.execute("echo 'ALIVE'", timeout = 60, err_on_timeout = False, persistent = False) @@ -132,11 +149,11 @@ class SSHAPI(object): src, dst, port = self.port, agent = self.agent, - identity_file = self.identity_file) + identity = self.identity) if proc.wait(): msg = "Error uploading to %s got:\n%s%s" %\ - (self.host or self.ip, out, err) + (self.host, out, err) self._logger.error(msg) raise RuntimeError(msg) @@ -172,15 +189,15 @@ class SSHAPI(object): else: (out, err), proc = eintr_retry(rexec)( command, - self.host or self.ip, + self.host, self.user, port = self.port, - agent = self.forward_agent, + agent = self.agent, sudo = sudo, stdin = stdin, - identity_file = self.identity_file, + identity = self.identity, tty = tty, - x11 = self.enable_x11, + x11 = self.forward_x11, env = env, timeout = timeout, retry = retry, @@ -190,10 +207,9 @@ class SSHAPI(object): if proc.wait(): msg = "Failed to execute command %s at node %s: %s %s" % \ - (command, self.host or self.ip, out, err,) + (command, self.host, out, err,) self._logger.warn(msg) raise RuntimeError(msg) - return (out, err) def run(self, command, home, @@ -203,7 +219,7 @@ class SSHAPI(object): sudo = False): self._logger.info("Running %s", command) - pidfile = './pid', + pidfile = './pid' if self.is_localhost: if stderr == stdout: @@ -247,7 +263,7 @@ class SSHAPI(object): user = self.user, port = self.port, agent = self.agent, - identity_file = self.file + identity = self.identity ) if proc.wait(): @@ -258,13 +274,13 @@ class SSHAPI(object): def checkpid(self, path): # Get PID/PPID # NOTE: wait a bit for the pidfile to be created - pidtuple = rcheck_pid( + pidtuple = rcheckpid( os.path.join(path, 'pid'), host = self.host, user = self.user, port = self.port, agent = self.agent, - identity_file = self.identity + identity = self.identity ) return pidtuple @@ -276,7 +292,7 @@ class SSHAPI(object): user = self.user, port = self.port, agent = self.agent, - identity_file = self.identity + identity = self.identity ) return status @@ -292,27 +308,22 @@ class SSHAPI(object): port = self.port, agent = self.agent, sudo = sudo, - identity_file = self.identity + identity = self.identity ) -class SSHAPIFactory(object): +class SSHApiFactory(object): _apis = dict() @classmethod - def get_api(cls, attributes): - host = attributes.get("hostname") - user = attributes.get("username") - identity = attributes.get("identity", "%s/.ssh/id_rsa" % os.environ['HOME']) - port = attributes.get("port", 22) - agent = attributes.get("agent", True) - forward_X11 = attributes.get("forwardX11", False) - - key = cls.make_key(host, user, identity, port, agent, forward_X11) - api = self._apis.get(key) - - if no api: - api = SSHAPI(host, user, identity, port, agent, forward_X11) - self._apis[key] = api + def get_api(cls, host, user, port = 22, identity = None, + agent = True, forward_X11 = False): + + key = cls.make_key(host, user, port, agent, forward_X11) + api = cls._apis.get(key) + + if not api: + api = SSHApi(host, user, port, identity, agent, forward_X11) + cls._apis[key] = api return api diff --git a/src/neco/util/sshfuncs.py b/src/neco/util/sshfuncs.py index f8d1cfc9..dd04ba62 100644 --- a/src/neco/util/sshfuncs.py +++ b/src/neco/util/sshfuncs.py @@ -105,14 +105,23 @@ def eintr_retry(func): return func(*p, **kw) return rv -def make_connkey(user, host, port): - connkey = repr((user,host,port)).encode("base64").strip().replace('/','.') +def make_connkey(user, host, port, x11, agent): + # It is important to consider the x11 and agent forwarding + # parameters when creating the connection key since the parameters + # used for the first ssh connection will determine the + # parameters of all subsequent connections using the same key + x11 = 1 if x11 else 0 + agent = 1 if agent else 0 + + connkey = repr((user, host, port, x11, agent) + ).encode("base64").strip().replace('/','.') + if len(connkey) > 60: connkey = hashlib.sha1(connkey).hexdigest() return connkey -def make_control_path(user, host, port): - connkey = make_connkey(user, host, port) +def make_control_path(user, host, port, x11, agent): + connkey = make_connkey(user, host, port, x11, agent) return '/tmp/%s_%s' % ( CONTROL_PATH, connkey, ) def rexec(command, host, user, @@ -120,7 +129,7 @@ def rexec(command, host, user, agent = True, sudo = False, stdin = None, - identity_file = None, + identity = None, env = None, tty = False, x11 = False, @@ -145,7 +154,7 @@ def rexec(command, host, user, '-l', user, host] if persistent and openssh_has_persist(): - control_path = make_control_path(user, host, port) + control_path = make_control_path(user, host, port, x11, agent) args.extend([ '-o', 'ControlMaster=auto', '-o', 'ControlPath=%s' % control_path, @@ -154,8 +163,8 @@ def rexec(command, host, user, args.append('-A') if port: args.append('-p%d' % port) - if identity_file: - args.extend(('-i', identity_file)) + if identity: + args.extend(('-i', identity)) if tty: args.append('-t') if sudo: @@ -201,7 +210,7 @@ def rexec(command, host, user, def rcopy(source, dest, port = None, agent = True, - identity_file = None): + identity = None): """ Copies file from/to remote sites. @@ -228,15 +237,18 @@ def rcopy(source, dest, raw_string += r''' -o ConnectionAttempts=3 ''' if openssh_has_persist(): - control_path = make_control_path(user, host, port) + control_path = make_control_path(user, host, port, False, agent) raw_string += r''' -o ControlMaster=auto ''' raw_string += r''' -o ControlPath=%s ''' % control_path - + + if agent: + raw_string += r''' -A ''' + if port: raw_string += r''' -p %d ''' % port - if identity_file: - raw_string += r''' -i "%s" ''' % identity_file + if identity: + raw_string += r''' -i "%s" ''' % identity # closing -e 'ssh...' raw_string += r''' ' ''' @@ -271,7 +283,7 @@ def rspawn(command, pidfile, user = None, agent = None, sudo = False, - identity_file = None, + identity = None, tty = False): """ Spawn a remote command such that it will continue working asynchronously. @@ -293,7 +305,7 @@ def rspawn(command, pidfile, sudo: whether the command needs to be executed as root - host/port/user/agent/identity_file: see rexec + host/port/user/agent/identity: see rexec Returns: (stdout, stderr), process @@ -334,7 +346,7 @@ def rspawn(command, pidfile, port = port, user = user, agent = agent, - identity_file = identity_file, + identity = identity, tty = tty ) @@ -344,19 +356,19 @@ def rspawn(command, pidfile, return (out,err),proc @eintr_retry -def rcheck_pid(pidfile, +def rcheckpid(pidfile, host = None, port = None, user = None, agent = None, - identity_file = None): + identity = None): """ Check the pidfile of a process spawned with remote_spawn. Parameters: pidfile: the pidfile passed to remote_span - host/port/user/agent/identity_file: see rexec + host/port/user/agent/identity: see rexec Returns: @@ -372,7 +384,7 @@ def rcheck_pid(pidfile, port = port, user = user, agent = agent, - identity_file = identity_file + identity = identity ) if proc.wait(): @@ -391,14 +403,14 @@ def rstatus(pid, ppid, port = None, user = None, agent = None, - identity_file = None): + identity = None): """ Check the status of a process spawned with remote_spawn. Parameters: pid/ppid: pid and parent-pid of the spawned process. See remote_check_pid - host/port/user/agent/identity_file: see rexec + host/port/user/agent/identity: see rexec Returns: @@ -415,7 +427,7 @@ def rstatus(pid, ppid, port = port, user = user, agent = agent, - identity_file = identity_file + identity = identity ) if proc.wait(): @@ -440,7 +452,7 @@ def rkill(pid, ppid, user = None, agent = None, sudo = False, - identity_file = None, + identity = None, nowait = False): """ Kill a process spawned with remote_spawn. @@ -453,7 +465,7 @@ def rkill(pid, ppid, sudo: whether the command was run with sudo - careful killing like this. - host/port/user/agent/identity_file: see rexec + host/port/user/agent/identity: see rexec Returns: @@ -494,7 +506,7 @@ fi port = port, user = user, agent = agent, - identity_file = identity_file + identity = identity ) # wait, don't leave zombies around diff --git a/test/resources/linux/ssh_api.py b/test/resources/linux/ssh_api.py new file mode 100644 index 00000000..88b03650 --- /dev/null +++ b/test/resources/linux/ssh_api.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python +from neco.resources.linux.ssh_api import SSHApiFactory +from neco.util.sshfuncs import RUNNING, FINISHED + +import os +import time +import tempfile +import unittest + +def skipIfNotAlive(func): + name = func.__name__ + def wrapped(*args, **kwargs): + host = args[1] + user = args[2] + + api = SSHApiFactory.get_api(host, user) + if not api.is_alive(): + print "*** WARNING: Skipping test %s: Node %s is not alive\n" % (name, host) + return + + return func(*args, **kwargs) + + return wrapped + +def skipInteractive(func): + name = func.__name__ + def wrapped(*args, **kwargs): + mode = os.environ.get("NEPI_INTERACTIVE", False) in ['True', 'true', 'yes', 'YES'] + if not mode: + print "*** WARNING: Skipping test %s: Interactive mode off \n" % name + return + + return func(*args, **kwargs) + + return wrapped + +class SSHApiTestCase(unittest.TestCase): + def setUp(self): + self.host_fedora = 'nepi2.pl.sophia.inria.fr' + self.user_fedora = 'inria_nepi' + + self.host_ubuntu = 'roseval.pl.sophia.inria.fr' + self.user_ubuntu = 'alina' + + self.target = 'nepi5.pl.sophia.inria.fr' + self.home = '${HOME}/test-app' + + @skipIfNotAlive + def t_xterm(self, host, user): + api = SSHApiFactory.get_api(host, user) + + api.enable_x11 = True + + api.install('xterm') + + out = api.execute('xterm') + + api.uninstall('xterm') + + self.assertEquals(out, "") + + @skipIfNotAlive + def t_execute(self, host, user): + api = SSHApiFactory.get_api(host, user) + + command = "ping -qc3 %s" % self.target + out, err = api.execute(command) + + expected = """3 packets transmitted, 3 received, 0% packet loss""" + + self.assertTrue(out.find(expected) > 0) + + @skipIfNotAlive + def t_run(self, host, user): + api = SSHApiFactory.get_api(host, user) + + api.mkdir(self.home, clean = True) + + command = "ping %s" % self.target + dst = os.path.join(self.home, "app.sh") + api.upload(command, dst) + + cmd = "bash ./app.sh" + api.run(cmd, self.home) + pid, ppid = api.checkpid(self.home) + + status = api.status(pid, ppid) + self.assertTrue(status, RUNNING) + + api.kill(pid, ppid) + status = api.status(pid, ppid) + self.assertTrue(status, FINISHED) + + api.rmdir(self.home) + + @skipIfNotAlive + def t_install(self, host, user): + api = SSHApiFactory.get_api(host, user) + + api.mkdir(self.home, clean = True) + + prog = """#include + +int +main (void) +{ + printf ("Hello, world!\\n"); + return 0; +} +""" + # upload the test program + dst = os.path.join(self.home, "hello.c") + api.upload(prog, dst) + + # install gcc + api.install('gcc') + + # compile the program using gcc + command = "cd %s; gcc -Wall hello.c -o hello" % self.home + out = api.execute(command) + + # execute the program and get the output from stout + command = "%s/hello" % self.home + out, err = api.execute(command) + + # execute the program and get the output from a file + command = "%s/hello > %s/hello.out" % (self.home, self.home) + api.execute(command) + + # retrieve the output file + src = os.path.join(self.home, "hello.out") + f = tempfile.NamedTemporaryFile(delete=False) + dst = f.name + api.download(src, dst) + f.close() + + api.uninstall('gcc') + api.rmdir(self.home) + + self.assertEquals(out, "Hello, world!\n") + + f = open(dst, "r") + out = f.read() + f.close() + + self.assertEquals(out, "Hello, world!\n") + + def test_execute_fedora(self): + self.t_execute(self.host_fedora, self.user_fedora) + + def test_execute_ubuntu(self): + self.t_execute(self.host_ubuntu, self.user_ubuntu) + + def test_run_fedora(self): + self.t_run(self.host_fedora, self.user_fedora) + + def test_run_ubuntu(self): + self.t_run(self.host_ubuntu, self.user_ubuntu) + + def test_intall_fedora(self): + self.t_install(self.host_fedora, self.user_fedora) + + def test_install_ubuntu(self): + self.t_install(self.host_ubuntu, self.user_ubuntu) + + @skipInteractive + def test_xterm_ubuntu(self): + """ Interactive test. Should not run automatically """ + self.t_xterm(self.host_ubuntu, self.user_ubuntu) + + +if __name__ == '__main__': + unittest.main() + diff --git a/test/util/sshfuncs.py b/test/util/sshfuncs.py index 756c79a1..9b282ce1 100644 --- a/test/util/sshfuncs.py +++ b/test/util/sshfuncs.py @@ -1,6 +1,7 @@ #!/usr/bin/env python -from neco.util.sshfuncs import * +from neco.util.sshfuncs import rexec, rcopy, rspawn, rcheckpid, rstatus, rkill,\ + RUNNING, FINISHED import getpass import unittest @@ -211,7 +212,7 @@ class SSHfuncsTestCase(unittest.TestCase): time.sleep(2) - (pid, ppid) = rcheck_pid(pidfile, + (pid, ppid) = rcheckpid(pidfile, host = host, user = user, port = env.port, -- 2.43.0