#!/usr/bin/env python
#
# Integration tests for nepi.util.sshfuncs.
#
# Each test builds a throw-away environment: a private sshd listening on a
# free port on 127.0.0.1, a freshly generated client key pair listed in that
# sshd's authorized_keys file, and a private ssh-agent holding the client
# key.  The sshfuncs helpers are then exercised against localhost through
# that sshd.

import errno
import getpass
import os
import re
import shutil
import signal
import socket
import subprocess
import tempfile
import time
import unittest

from nepi.util.sshfuncs import (rexec, rcopy, rspawn, rcheckpid, rstatus,
        rkill, RUNNING, FINISHED)


def find_bin(name, extra_path = None):
    """Return the first existing path "<dir>/<name>", or None.

    Directories are searched in this order: every entry of $PATH, then
    /bin, /sbin, /usr/bin, /usr/sbin, /usr/local/bin, /usr/local/sbin,
    and finally the directories in the optional `extra_path` list.

    Raises OSError for stat() failures other than "does not exist".
    """
    search = []
    if "PATH" in os.environ:
        search += os.environ["PATH"].split(":")
    for pref in ("/", "/usr/", "/usr/local/"):
        for d in ("bin", "sbin"):
            search.append(pref + d)
    if extra_path:
        search += extra_path

    for d in search:
        path = d + "/" + name
        try:
            os.stat(path)
            return path
        except OSError as e:
            # A missing file -- or a PATH entry that is not a directory --
            # just means "not here"; anything else (e.g. EACCES) is
            # unexpected and is propagated.
            if e.errno not in (errno.ENOENT, errno.ENOTDIR):
                raise
    return None


def find_bin_or_die(name, extra_path = None):
    """Like find_bin(), but raise RuntimeError when `name` is not found."""
    r = find_bin(name, extra_path)
    if not r:
        raise RuntimeError(("Cannot find `%s' command, impossible to " +
                "continue.") % name)
    return r


def gen_ssh_keypair(filename):
    """Generate a passphrase-less key pair with ssh-keygen.

    Returns a (private_key_path, public_key_path) tuple; the public key is
    `filename` + ".pub".
    """
    ssh_keygen = find_bin_or_die("ssh-keygen")
    args = [ssh_keygen, '-q', '-N', '', '-f', filename]
    assert subprocess.Popen(args).wait() == 0
    return filename, "%s.pub" % filename


def add_key_to_agent(filename):
    """Load private key `filename` into the agent named in os.environ.

    start_ssh_agent() is what exports the agent variables into os.environ.
    """
    ssh_add = find_bin_or_die("ssh-add")
    args = [ssh_add, filename]
    with open(os.devnull, "w") as null:
        # Silence ssh-add's stderr chatter in the test output.
        assert subprocess.Popen(args, stderr = null).wait() == 0


def get_free_port():
    """Return a TCP port on 127.0.0.1 that was unused at the time of the call.

    The probe socket is closed before returning, so in principle another
    process could take the port before sshd binds it.
    """
    s = socket.socket()
    try:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]
    finally:
        s.close()


_SSH_CONF = """ListenAddress 127.0.0.1:%d
Protocol 2
HostKey %s
UsePrivilegeSeparation no
PubkeyAuthentication yes
PasswordAuthentication no
AuthorizedKeysFile %s
UsePAM no
AllowAgentForwarding yes
PermitRootLogin yes
StrictModes no
PermitUserEnvironment yes
"""


def gen_sshd_config(filename, port, server_key, auth_keys):
    """Write an sshd_config from _SSH_CONF to `filename` and return its path.

    `port` is the listening port on 127.0.0.1, `server_key` the host key
    path, and `auth_keys` the authorized_keys file path.
    """
    with open(filename, "w") as conf:
        conf.write(_SSH_CONF % (port, server_key, auth_keys))
    return filename


def gen_auth_keys(pubkey, output, environ):
    """Write an authorized_keys file at `output` that accepts key `pubkey`.

    Each (name, value) pair in `environ` becomes an environment="name=value"
    option on the key line.  sshd only honours those options because
    _SSH_CONF sets PermitUserEnvironment.  Returns `output`.
    """
    # A from="127.0.0.1/32" restriction would be desirable, but it fails in
    # the yans test setup, so no source restriction is applied.
    opts = ['environment="%s=%s"' % (k, v) for k, v in environ.items()]

    # Only the key type and the base64 blob of the first line are kept; a
    # trailing comment, if any, is dropped.
    with open(pubkey) as f:
        keytype, key = f.readline().split()[0:2]

    # Omit the options field entirely when there are no options, instead of
    # emitting a line with a leading blank.
    fields = [",".join(opts)] if opts else []
    fields += [keytype, key]
    with open(output, "w") as out:
        out.write(" ".join(fields) + "\n")
    return output


def start_ssh_agent():
    """Start an ssh-agent and export its variables into os.environ.

    Returns a dict holding every variable that was set, for stop_ssh_agent().
    """
    ssh_agent = find_bin_or_die("ssh-agent")
    proc = subprocess.Popen([ssh_agent], stdout = subprocess.PIPE,
            universal_newlines = True)
    out, _ = proc.communicate()
    assert proc.returncode == 0
    d = {}
    # Pick up shell assignments of the form "NAME=value; ..." (e.g. the
    # agent socket path) from ssh-agent's output.
    for line in out.split("\n"):
        match = re.search(r"^(\w+)=([^ ;]+);.*", line)
        if not match:
            continue
        k, v = match.groups()
        os.environ[k] = v
        d[k] = v
    return d


def stop_ssh_agent(data):
    """Kill the agent started by start_ssh_agent() and unset its variables.

    `data` is the dict that start_ssh_agent() returned.
    """
    # No need to gather the pid: "ssh-agent -k" locates the agent through
    # the variables start_ssh_agent() exported into os.environ.
    ssh_agent = find_bin_or_die("ssh-agent")
    with open(os.devnull, "w") as null:
        proc = subprocess.Popen([ssh_agent, "-k"], stdout = null)
    assert proc.wait() == 0
    for k in data:
        # Tolerate variables that someone else already removed.
        os.environ.pop(k, None)


def _wait_for_port(port, proc, timeout = 10.0):
    """Wait until something accepts TCP connections on 127.0.0.1:`port`.

    Raises RuntimeError if `proc` exits first or if `timeout` seconds elapse.
    """
    deadline = time.time() + timeout
    while True:
        if proc.poll() is not None:
            raise RuntimeError("process exited with status %d before port %d "
                    "accepted connections" % (proc.returncode, port))
        s = socket.socket()
        try:
            s.connect(("127.0.0.1", port))
            return
        except socket.error:
            if time.time() >= deadline:
                raise RuntimeError("nothing listening on port %d after %.1fs"
                        % (port, timeout))
            time.sleep(0.1)
        finally:
            s.close()


class test_environment(object):
    """A private sshd plus ssh-agent for exercising sshfuncs on localhost.

    Attributes:
      dir             -- temporary directory holding keys and configuration
      server_keypair  -- (private, public) host key paths used by sshd
      client_keypair  -- (private, public) client key paths; the private key
                         is loaded into the agent
      authorized_keys -- authorized_keys file path handed to sshd
      port            -- TCP port sshd listens on (127.0.0.1 only)
      sshd_conf       -- path of the generated sshd_config
      sshd            -- Popen object of the running sshd
      ssh_agent_vars  -- variables exported by start_ssh_agent()

    Everything is torn down in __del__, i.e. when the object is collected.
    """

    def __init__(self):
        sshd = find_bin_or_die("sshd")
        environ = {}

        self.dir = tempfile.mkdtemp()
        self.server_keypair = gen_ssh_keypair(
                os.path.join(self.dir, "server_key"))
        self.client_keypair = gen_ssh_keypair(
                os.path.join(self.dir, "client_key"))
        self.authorized_keys = gen_auth_keys(self.client_keypair[1],
                os.path.join(self.dir, "authorized_keys"), environ)
        self.port = get_free_port()
        self.sshd_conf = gen_sshd_config(
                os.path.join(self.dir, "sshd_config"),
                self.port, self.server_keypair[0], self.authorized_keys)

        # -D keeps sshd in the foreground so it stays our child process.
        self.sshd = subprocess.Popen([sshd, '-q', '-D', '-f', self.sshd_conf])
        self.ssh_agent_vars = start_ssh_agent()
        add_key_to_agent(self.client_keypair[0])

        # sshd binds its port asynchronously; make sure the first test does
        # not race it.
        _wait_for_port(self.port, self.sshd)

    def __del__(self):
        # getattr() because __init__ may have failed part-way, leaving some
        # attributes unset.
        sshd = getattr(self, "sshd", None)
        # Only signal an sshd that has not been reaped yet; once reaped, its
        # pid could belong to an unrelated process.
        if sshd is not None and sshd.poll() is None:
            os.kill(sshd.pid, signal.SIGTERM)
            sshd.wait()
        agent_vars = getattr(self, "ssh_agent_vars", None)
        if agent_vars:
            stop_ssh_agent(agent_vars)
        tmpdir = getattr(self, "dir", None)
        if tmpdir:
            shutil.rmtree(tmpdir, ignore_errors = True)


class SSHfuncsTestCase(unittest.TestCase):
    """Exercise sshfuncs against a private sshd on localhost."""

    def test_rexec(self):
        """`hostname` run through rexec matches the local output."""
        env = test_environment()
        user = getpass.getuser()
        host = "localhost"

        command = "hostname"
        plocal = subprocess.Popen(command, stdout = subprocess.PIPE,
                stdin = subprocess.PIPE)
        outlocal, _errlocal = plocal.communicate()

        (outremote, _errremote), _premote = rexec(command, host, user,
                port = env.port, agent = True)

        self.assertEqual(outlocal, outremote)

    def test_rcopy(self):
        """rcopy recursively copies a directory and a file to a remote path."""
        env = test_environment()
        user = getpass.getuser()
        host = "localhost"

        # Source: a directory containing one file, plus a standalone file.
        dirpath = tempfile.mkdtemp()
        destdir = tempfile.mkdtemp()
        f = tempfile.NamedTemporaryFile(dir = dirpath, delete = False)
        f.close()
        f1 = tempfile.NamedTemporaryFile(delete = False)
        f1.close()

        try:
            source = [dirpath, f1.name]
            dest = "%s@%s:%s" % (user, host, destdir)
            rcopy(source, dest, port = env.port, agent = True,
                    recursive = True)

            # Every entry name (files and directories) anywhere below destdir.
            files = []
            for _root, dirnames, filenames in os.walk(destdir):
                files.extend(dirnames)
                files.extend(filenames)

            origfiles = [os.path.basename(s)
                    for s in (dirpath, f.name, f1.name)]
            self.assertEqual(sorted(origfiles), sorted(files))
        finally:
            shutil.rmtree(dirpath, ignore_errors = True)
            shutil.rmtree(destdir, ignore_errors = True)
            if os.path.exists(f1.name):
                os.remove(f1.name)

    def test_rproc_manage(self):
        """A process started by rspawn is RUNNING, then FINISHED after rkill."""
        env = test_environment()
        user = getpass.getuser()
        host = "localhost"
        command = "ping localhost"
        ssh_args = dict(host = host, user = user, port = env.port,
                agent = True)

        f = tempfile.NamedTemporaryFile(delete = False)
        f.close()
        pidfile = f.name

        try:
            (_out, _err), _proc = rspawn(command, pidfile, **ssh_args)

            # Presumably gives the spawned process time to start and record
            # its pid before rcheckpid reads it.
            time.sleep(2)

            (pid, ppid) = rcheckpid(pidfile, **ssh_args)
            status = rstatus(pid, ppid, **ssh_args)
            self.assertEqual(status, RUNNING)

            rkill(pid, ppid, **ssh_args)
            status = rstatus(pid, ppid, **ssh_args)
            self.assertEqual(status, FINISHED)
        finally:
            if os.path.exists(pidfile):
                os.remove(pidfile)


if __name__ == '__main__':
    unittest.main()