server daemon launched over ssh connection.
[nepi.git] / test / lib / test_util.py
1 #!/usr/bin/env python
2 # vim:ts=4:sw=4:et:ai:sts=4
3
4 import sys
5 import nepi.util.environ
6
7 # Unittest from Python 2.6 doesn't have these decorators
8 def _bannerwrap(f, text):
9     name = f.__name__
10     def banner(*args, **kwargs):
11         sys.stderr.write("*** WARNING: Skipping test %s: `%s'\n" %
12                 (name, text))
13         return None
14     return banner
15 def skip(text):
16     return lambda f: _bannerwrap(f, text)
17 def skipUnless(cond, text):
18     return (lambda f: _bannerwrap(f, text)) if not cond else lambda f: f
19 def skipIf(cond, text):
20     return (lambda f: _bannerwrap(f, text)) if cond else lambda f: f
21
22 def find_bin(name, extra_path = None):
23     search = []
24     if "PATH" in os.environ:
25         search += os.environ["PATH"].split(":")
26     for pref in ("/", "/usr/", "/usr/local/"):
27         for d in ("bin", "sbin"):
28             search.append(pref + d)
29     if extra_path:
30         search += extra_path
31
32     for d in search:
33             try:
34                 os.stat(d + "/" + name)
35                 return d + "/" + name
36             except OSError, e:
37                 if e.errno != os.errno.ENOENT:
38                     raise
39     return None
40
41 def find_bin_or_die(name, extra_path = None):
42     r = find_bin(name)
43     if not r:
44         raise RuntimeError(("Cannot find `%s' command, impossible to " +
45                 "continue.") % name)
46     return r
47
48 # SSH stuff
49
50 import os, os.path, re, signal, shutil, socket, subprocess, tempfile
51 def gen_ssh_keypair(filename):
52     ssh_keygen = nepi.util.environ.find_bin_or_die("ssh-keygen")
53     args = [ssh_keygen, '-q', '-N', '', '-f', filename]
54     assert subprocess.Popen(args).wait() == 0
55     return filename, "%s.pub" % filename
56
57 def add_key_to_agent(filename):
58     ssh_add = nepi.util.environ.find_bin_or_die("ssh-add")
59     args = [ssh_add, filename]
60     null = file("/dev/null", "w")
61     assert subprocess.Popen(args, stderr = null).wait() == 0
62     null.close()
63
64 def get_free_port():
65     s = socket.socket()
66     s.bind(("127.0.0.1", 0))
67     port = s.getsockname()[1]
68     return port
69
70 _SSH_CONF = """ListenAddress 127.0.0.1:%d
71 Protocol 2
72 HostKey %s
73 UsePrivilegeSeparation no
74 PubkeyAuthentication yes
75 PasswordAuthentication no
76 AuthorizedKeysFile %s
77 UsePAM no
78 AllowAgentForwarding yes
79 PermitRootLogin yes
80 StrictModes no
81 PermitUserEnvironment yes
82 """
83
84 def gen_sshd_config(filename, port, server_key, auth_keys):
85     conf = open(filename, "w")
86     text = _SSH_CONF % (port, server_key, auth_keys)
87     conf.write(text)
88     conf.close()
89     return filename
90
91 def gen_auth_keys(pubkey, output, environ):
92     #opts = ['from="127.0.0.1/32"'] # fails in stupid yans setup
93     opts = []
94     for k, v in environ.items():
95         opts.append('environment="%s=%s"' % (k, v))
96
97     lines = file(pubkey).readlines()
98     pubkey = lines[0].split()[0:2]
99     out = file(output, "w")
100     out.write("%s %s %s\n" % (",".join(opts), pubkey[0], pubkey[1]))
101     out.close()
102     return output
103
104 def start_ssh_agent():
105     ssh_agent = nepi.util.environ.find_bin_or_die("ssh-agent")
106     proc = subprocess.Popen([ssh_agent], stdout = subprocess.PIPE)
107     (out, foo) = proc.communicate()
108     assert proc.returncode == 0
109     d = {}
110     for l in out.split("\n"):
111         match = re.search("^(\w+)=([^ ;]+);.*", l)
112         if not match:
113             continue
114         k, v = match.groups()
115         os.environ[k] = v
116         d[k] = v
117     return d
118
119 def stop_ssh_agent(data):
120     # No need to gather the pid, ssh-agent knows how to kill itself; after we
121     # had set up the environment
122     ssh_agent = nepi.util.environ.find_bin_or_die("ssh-agent")
123     null = file("/dev/null", "w")
124     proc = subprocess.Popen([ssh_agent, "-k"], stdout = null)
125     null.close()
126     assert proc.wait() == 0
127     for k in data:
128         del os.environ[k]
129
130 class test_environment(object):
131     def __init__(self):
132         sshd = find_bin_or_die("sshd")
133         environ = {}
134         if 'PYTHONPATH' in os.environ:
135             environ['PYTHONPATH'] = ":".join(map(os.path.realpath, 
136                 os.environ['PYTHONPATH'].split(":")))
137
138         self.dir = tempfile.mkdtemp()
139         self.server_keypair = gen_ssh_keypair(
140                 os.path.join(self.dir, "server_key"))
141         self.client_keypair = gen_ssh_keypair(
142                 os.path.join(self.dir, "client_key"))
143         self.authorized_keys = gen_auth_keys(self.client_keypair[1],
144                 os.path.join(self.dir, "authorized_keys"), environ)
145         self.port = get_free_port()
146         self.sshd_conf = gen_sshd_config(
147                 os.path.join(self.dir, "sshd_config"),
148                 self.port, self.server_keypair[0], self.authorized_keys)
149
150         self.sshd = subprocess.Popen([sshd, '-q', '-D', '-f', self.sshd_conf])
151         self.ssh_agent_vars = start_ssh_agent()
152         add_key_to_agent(self.client_keypair[0])
153
154     def __del__(self):
155         if self.sshd:
156             os.kill(self.sshd.pid, signal.SIGTERM)
157             self.sshd.wait()
158         if self.ssh_agent_vars:
159             stop_ssh_agent(self.ssh_agent_vars)
160         shutil.rmtree(self.dir)
161