Creating ssh api (unfinished)
[nepi.git] / test / util / sshfuncs.py
1 #!/usr/bin/env python
2
3 from neco.util.sshfuncs import *
4
5 import getpass
6 import unittest
7 import os
8 import subprocess
9 import re
10 import signal
11 import shutil
12 import socket
13 import subprocess
14 import tempfile
15 import time
16
17 def find_bin(name, extra_path = None):
18     search = []
19     if "PATH" in os.environ:
20         search += os.environ["PATH"].split(":")
21     for pref in ("/", "/usr/", "/usr/local/"):
22         for d in ("bin", "sbin"):
23             search.append(pref + d)
24     if extra_path:
25         search += extra_path
26
27     for d in search:
28             try:
29                 os.stat(d + "/" + name)
30                 return d + "/" + name
31             except OSError, e:
32                 if e.errno != os.errno.ENOENT:
33                     raise
34     return None
35
36 def find_bin_or_die(name, extra_path = None):
37     r = find_bin(name)
38     if not r:
39         raise RuntimeError(("Cannot find `%s' command, impossible to " +
40                 "continue.") % name)
41     return r
42
43 def gen_ssh_keypair(filename):
44     ssh_keygen = find_bin_or_die("ssh-keygen")
45     args = [ssh_keygen, '-q', '-N', '', '-f', filename]
46     assert subprocess.Popen(args).wait() == 0
47     return filename, "%s.pub" % filename
48
49 def add_key_to_agent(filename):
50     ssh_add = find_bin_or_die("ssh-add")
51     args = [ssh_add, filename]
52     null = file("/dev/null", "w")
53     assert subprocess.Popen(args, stderr = null).wait() == 0
54     null.close()
55
56 def get_free_port():
57     s = socket.socket()
58     s.bind(("127.0.0.1", 0))
59     port = s.getsockname()[1]
60     return port
61
62 _SSH_CONF = """ListenAddress 127.0.0.1:%d
63 Protocol 2
64 HostKey %s
65 UsePrivilegeSeparation no
66 PubkeyAuthentication yes
67 PasswordAuthentication no
68 AuthorizedKeysFile %s
69 UsePAM no
70 AllowAgentForwarding yes
71 PermitRootLogin yes
72 StrictModes no
73 PermitUserEnvironment yes
74 """
75
76 def gen_sshd_config(filename, port, server_key, auth_keys):
77     conf = open(filename, "w")
78     text = _SSH_CONF % (port, server_key, auth_keys)
79     conf.write(text)
80     conf.close()
81     return filename
82
83 def gen_auth_keys(pubkey, output, environ):
84     #opts = ['from="127.0.0.1/32"'] # fails in stupid yans setup
85     opts = []
86     for k, v in environ.items():
87         opts.append('environment="%s=%s"' % (k, v))
88
89     lines = file(pubkey).readlines()
90     pubkey = lines[0].split()[0:2]
91     out = file(output, "w")
92     out.write("%s %s %s\n" % (",".join(opts), pubkey[0], pubkey[1]))
93     out.close()
94     return output
95
96 def start_ssh_agent():
97     ssh_agent = find_bin_or_die("ssh-agent")
98     proc = subprocess.Popen([ssh_agent], stdout = subprocess.PIPE)
99     (out, foo) = proc.communicate()
100     assert proc.returncode == 0
101     d = {}
102     for l in out.split("\n"):
103         match = re.search("^(\w+)=([^ ;]+);.*", l)
104         if not match:
105             continue
106         k, v = match.groups()
107         os.environ[k] = v
108         d[k] = v
109     return d
110
111 def stop_ssh_agent(data):
112     # No need to gather the pid, ssh-agent knows how to kill itself; after we
113     # had set up the environment
114     ssh_agent = find_bin_or_die("ssh-agent")
115     null = file("/dev/null", "w")
116     proc = subprocess.Popen([ssh_agent, "-k"], stdout = null)
117     null.close()
118     assert proc.wait() == 0
119     for k in data:
120         del os.environ[k]
121
122 class test_environment(object):
123     def __init__(self):
124         sshd = find_bin_or_die("sshd")
125         environ = {}
126         self.dir = tempfile.mkdtemp()
127         self.server_keypair = gen_ssh_keypair(
128                 os.path.join(self.dir, "server_key"))
129         self.client_keypair = gen_ssh_keypair(
130                 os.path.join(self.dir, "client_key"))
131         self.authorized_keys = gen_auth_keys(self.client_keypair[1],
132                 os.path.join(self.dir, "authorized_keys"), environ)
133         self.port = get_free_port()
134         self.sshd_conf = gen_sshd_config(
135                 os.path.join(self.dir, "sshd_config"),
136                 self.port, self.server_keypair[0], self.authorized_keys)
137
138         self.sshd = subprocess.Popen([sshd, '-q', '-D', '-f', self.sshd_conf])
139         self.ssh_agent_vars = start_ssh_agent()
140         add_key_to_agent(self.client_keypair[0])
141
142     def __del__(self):
143         if self.sshd:
144             os.kill(self.sshd.pid, signal.SIGTERM)
145             self.sshd.wait()
146         if self.ssh_agent_vars:
147             stop_ssh_agent(self.ssh_agent_vars)
148         shutil.rmtree(self.dir)
149
150 class SSHfuncsTestCase(unittest.TestCase):
151     def test_rexec(self):
152         env = test_environment()
153         user = getpass.getuser()
154         host = "localhost" 
155
156         command = "hostname"
157
158         plocal = subprocess.Popen(command, stdout=subprocess.PIPE, 
159                 stdin=subprocess.PIPE)
160         outlocal, errlocal = plocal.communicate()
161
162         (outremote, errrmote), premote = rexec(command, host, user, 
163                 port = env.port, agent = True)
164
165         self.assertEquals(outlocal, outremote)
166
167     def test_rcopy(self):
168         env = test_environment()
169         user = getpass.getuser()
170         host = "localhost"
171
172         # create some temp files and directories to copy
173         dirpath = tempfile.mkdtemp()
174         f = tempfile.NamedTemporaryFile(dir=dirpath, delete=False)
175         f.close()
176       
177         f1 = tempfile.NamedTemporaryFile(delete=False)
178         f1.close()
179         f1.name
180
181         source = [dirpath, f1.name]
182         destdir = tempfile.mkdtemp()
183         dest = "%s@%s:%s" % (user, host, destdir)
184         rcopy(source, dest, port = env.port, agent = True)
185
186         files = []
187         def recls(files, dirname, names):
188             files.extend(names)
189         os.path.walk(destdir, recls, files)
190         
191         origfiles = map(lambda s: os.path.basename(s), [dirpath, f.name, f1.name])
192
193         self.assertEquals(sorted(origfiles), sorted(files))
194
195     def test_rproc_manage(self):
196         env = test_environment()
197         user = getpass.getuser()
198         host = "localhost" 
199         command = "ping localhost"
200         
201         f = tempfile.NamedTemporaryFile(delete=False)
202         pidfile = f.name 
203
204         (out,err), proc = rspawn(
205                 command, 
206                 pidfile,
207                 host = host,
208                 user = user,
209                 port = env.port,
210                 agent = True)
211
212         time.sleep(2)
213
214         (pid, ppid) = rcheck_pid(pidfile,
215                 host = host,
216                 user = user,
217                 port = env.port,
218                 agent = True)
219
220         status = rstatus(pid, ppid,
221                 host = host,
222                 user = user, 
223                 port = env.port, 
224                 agent = True)
225
226         self.assertEquals(status, RUNNING)
227
228         rkill(pid, ppid,
229                 host = host,
230                 user = user, 
231                 port = env.port, 
232                 agent = True)
233
234         status = rstatus(pid, ppid,
235                 host = host,
236                 user = user, 
237                 port = env.port, 
238                 agent = True)
239         
240         self.assertEquals(status, FINISHED)
241
242
243 if __name__ == '__main__':
244     unittest.main()
245