Adding tests for ssh_api
[nepi.git] / test / util / sshfuncs.py
1 #!/usr/bin/env python
2
3 from neco.util.sshfuncs import rexec, rcopy, rspawn, rcheckpid, rstatus, rkill,\
4         RUNNING, FINISHED 
5
6 import getpass
7 import unittest
8 import os
9 import subprocess
10 import re
11 import signal
12 import shutil
13 import socket
14 import subprocess
15 import tempfile
16 import time
17
18 def find_bin(name, extra_path = None):
19     search = []
20     if "PATH" in os.environ:
21         search += os.environ["PATH"].split(":")
22     for pref in ("/", "/usr/", "/usr/local/"):
23         for d in ("bin", "sbin"):
24             search.append(pref + d)
25     if extra_path:
26         search += extra_path
27
28     for d in search:
29             try:
30                 os.stat(d + "/" + name)
31                 return d + "/" + name
32             except OSError, e:
33                 if e.errno != os.errno.ENOENT:
34                     raise
35     return None
36
37 def find_bin_or_die(name, extra_path = None):
38     r = find_bin(name)
39     if not r:
40         raise RuntimeError(("Cannot find `%s' command, impossible to " +
41                 "continue.") % name)
42     return r
43
44 def gen_ssh_keypair(filename):
45     ssh_keygen = find_bin_or_die("ssh-keygen")
46     args = [ssh_keygen, '-q', '-N', '', '-f', filename]
47     assert subprocess.Popen(args).wait() == 0
48     return filename, "%s.pub" % filename
49
50 def add_key_to_agent(filename):
51     ssh_add = find_bin_or_die("ssh-add")
52     args = [ssh_add, filename]
53     null = file("/dev/null", "w")
54     assert subprocess.Popen(args, stderr = null).wait() == 0
55     null.close()
56
57 def get_free_port():
58     s = socket.socket()
59     s.bind(("127.0.0.1", 0))
60     port = s.getsockname()[1]
61     return port
62
63 _SSH_CONF = """ListenAddress 127.0.0.1:%d
64 Protocol 2
65 HostKey %s
66 UsePrivilegeSeparation no
67 PubkeyAuthentication yes
68 PasswordAuthentication no
69 AuthorizedKeysFile %s
70 UsePAM no
71 AllowAgentForwarding yes
72 PermitRootLogin yes
73 StrictModes no
74 PermitUserEnvironment yes
75 """
76
77 def gen_sshd_config(filename, port, server_key, auth_keys):
78     conf = open(filename, "w")
79     text = _SSH_CONF % (port, server_key, auth_keys)
80     conf.write(text)
81     conf.close()
82     return filename
83
84 def gen_auth_keys(pubkey, output, environ):
85     #opts = ['from="127.0.0.1/32"'] # fails in stupid yans setup
86     opts = []
87     for k, v in environ.items():
88         opts.append('environment="%s=%s"' % (k, v))
89
90     lines = file(pubkey).readlines()
91     pubkey = lines[0].split()[0:2]
92     out = file(output, "w")
93     out.write("%s %s %s\n" % (",".join(opts), pubkey[0], pubkey[1]))
94     out.close()
95     return output
96
97 def start_ssh_agent():
98     ssh_agent = find_bin_or_die("ssh-agent")
99     proc = subprocess.Popen([ssh_agent], stdout = subprocess.PIPE)
100     (out, foo) = proc.communicate()
101     assert proc.returncode == 0
102     d = {}
103     for l in out.split("\n"):
104         match = re.search("^(\w+)=([^ ;]+);.*", l)
105         if not match:
106             continue
107         k, v = match.groups()
108         os.environ[k] = v
109         d[k] = v
110     return d
111
112 def stop_ssh_agent(data):
113     # No need to gather the pid, ssh-agent knows how to kill itself; after we
114     # had set up the environment
115     ssh_agent = find_bin_or_die("ssh-agent")
116     null = file("/dev/null", "w")
117     proc = subprocess.Popen([ssh_agent, "-k"], stdout = null)
118     null.close()
119     assert proc.wait() == 0
120     for k in data:
121         del os.environ[k]
122
123 class test_environment(object):
124     def __init__(self):
125         sshd = find_bin_or_die("sshd")
126         environ = {}
127         self.dir = tempfile.mkdtemp()
128         self.server_keypair = gen_ssh_keypair(
129                 os.path.join(self.dir, "server_key"))
130         self.client_keypair = gen_ssh_keypair(
131                 os.path.join(self.dir, "client_key"))
132         self.authorized_keys = gen_auth_keys(self.client_keypair[1],
133                 os.path.join(self.dir, "authorized_keys"), environ)
134         self.port = get_free_port()
135         self.sshd_conf = gen_sshd_config(
136                 os.path.join(self.dir, "sshd_config"),
137                 self.port, self.server_keypair[0], self.authorized_keys)
138
139         self.sshd = subprocess.Popen([sshd, '-q', '-D', '-f', self.sshd_conf])
140         self.ssh_agent_vars = start_ssh_agent()
141         add_key_to_agent(self.client_keypair[0])
142
143     def __del__(self):
144         if self.sshd:
145             os.kill(self.sshd.pid, signal.SIGTERM)
146             self.sshd.wait()
147         if self.ssh_agent_vars:
148             stop_ssh_agent(self.ssh_agent_vars)
149         shutil.rmtree(self.dir)
150
151 class SSHfuncsTestCase(unittest.TestCase):
152     def test_rexec(self):
153         env = test_environment()
154         user = getpass.getuser()
155         host = "localhost" 
156
157         command = "hostname"
158
159         plocal = subprocess.Popen(command, stdout=subprocess.PIPE, 
160                 stdin=subprocess.PIPE)
161         outlocal, errlocal = plocal.communicate()
162
163         (outremote, errrmote), premote = rexec(command, host, user, 
164                 port = env.port, agent = True)
165
166         self.assertEquals(outlocal, outremote)
167
168     def test_rcopy(self):
169         env = test_environment()
170         user = getpass.getuser()
171         host = "localhost"
172
173         # create some temp files and directories to copy
174         dirpath = tempfile.mkdtemp()
175         f = tempfile.NamedTemporaryFile(dir=dirpath, delete=False)
176         f.close()
177       
178         f1 = tempfile.NamedTemporaryFile(delete=False)
179         f1.close()
180         f1.name
181
182         source = [dirpath, f1.name]
183         destdir = tempfile.mkdtemp()
184         dest = "%s@%s:%s" % (user, host, destdir)
185         rcopy(source, dest, port = env.port, agent = True)
186
187         files = []
188         def recls(files, dirname, names):
189             files.extend(names)
190         os.path.walk(destdir, recls, files)
191         
192         origfiles = map(lambda s: os.path.basename(s), [dirpath, f.name, f1.name])
193
194         self.assertEquals(sorted(origfiles), sorted(files))
195
196     def test_rproc_manage(self):
197         env = test_environment()
198         user = getpass.getuser()
199         host = "localhost" 
200         command = "ping localhost"
201         
202         f = tempfile.NamedTemporaryFile(delete=False)
203         pidfile = f.name 
204
205         (out,err), proc = rspawn(
206                 command, 
207                 pidfile,
208                 host = host,
209                 user = user,
210                 port = env.port,
211                 agent = True)
212
213         time.sleep(2)
214
215         (pid, ppid) = rcheckpid(pidfile,
216                 host = host,
217                 user = user,
218                 port = env.port,
219                 agent = True)
220
221         status = rstatus(pid, ppid,
222                 host = host,
223                 user = user, 
224                 port = env.port, 
225                 agent = True)
226
227         self.assertEquals(status, RUNNING)
228
229         rkill(pid, ppid,
230                 host = host,
231                 user = user, 
232                 port = env.port, 
233                 agent = True)
234
235         status = rstatus(pid, ppid,
236                 host = host,
237                 user = user, 
238                 port = env.port, 
239                 agent = True)
240         
241         self.assertEquals(status, FINISHED)
242
243
244 if __name__ == '__main__':
245     unittest.main()
246