88f2eb258110710e1d6a5c54b61fc86805faba23
[nepi.git] / test / util / sshfuncs.py
1 #!/usr/bin/env python
2 #
3 #    NEPI, a framework to manage network experiments
4 #    Copyright (C) 2013 INRIA
5 #
6 #    This program is free software: you can redistribute it and/or modify
7 #    it under the terms of the GNU General Public License as published by
8 #    the Free Software Foundation, either version 3 of the License, or
9 #    (at your option) any later version.
10 #
11 #    This program is distributed in the hope that it will be useful,
12 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
13 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 #    GNU General Public License for more details.
15 #
16 #    You should have received a copy of the GNU General Public License
17 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
20
21
22 from nepi.util.sshfuncs import rexec, rcopy, rspawn, rgetpid, rstatus, rkill,\
23         ProcStatus
24
25 import getpass
26 import unittest
27 import os
28 import subprocess
29 import re
30 import signal
31 import shutil
32 import socket
33 import subprocess
34 import tempfile
35 import time
36
37 def find_bin(name, extra_path = None):
38     search = []
39     if "PATH" in os.environ:
40         search += os.environ["PATH"].split(":")
41     for pref in ("/", "/usr/", "/usr/local/"):
42         for d in ("bin", "sbin"):
43             search.append(pref + d)
44     if extra_path:
45         search += extra_path
46
47     for d in search:
48             try:
49                 os.stat(d + "/" + name)
50                 return d + "/" + name
51             except OSError, e:
52                 if e.errno != os.errno.ENOENT:
53                     raise
54     return None
55
56 def find_bin_or_die(name, extra_path = None):
57     r = find_bin(name)
58     if not r:
59         raise RuntimeError(("Cannot find `%s' command, impossible to " +
60                 "continue.") % name)
61     return r
62
63 def gen_ssh_keypair(filename):
64     ssh_keygen = find_bin_or_die("ssh-keygen")
65     args = [ssh_keygen, '-q', '-N', '', '-f', filename]
66     assert subprocess.Popen(args).wait() == 0
67     return filename, "%s.pub" % filename
68
69 def add_key_to_agent(filename):
70     ssh_add = find_bin_or_die("ssh-add")
71     args = [ssh_add, filename]
72     null = file("/dev/null", "w")
73     assert subprocess.Popen(args, stderr = null).wait() == 0
74     null.close()
75
76 def get_free_port():
77     s = socket.socket()
78     s.bind(("127.0.0.1", 0))
79     port = s.getsockname()[1]
80     return port
81
82 _SSH_CONF = """ListenAddress 127.0.0.1:%d
83 Protocol 2
84 HostKey %s
85 UsePrivilegeSeparation no
86 PubkeyAuthentication yes
87 PasswordAuthentication no
88 AuthorizedKeysFile %s
89 UsePAM no
90 AllowAgentForwarding yes
91 PermitRootLogin yes
92 StrictModes no
93 PermitUserEnvironment yes
94 """
95
96 def gen_sshd_config(filename, port, server_key, auth_keys):
97     conf = open(filename, "w")
98     text = _SSH_CONF % (port, server_key, auth_keys)
99     conf.write(text)
100     conf.close()
101     return filename
102
103 def gen_auth_keys(pubkey, output, environ):
104     #opts = ['from="127.0.0.1/32"'] # fails in stupid yans setup
105     opts = []
106     for k, v in environ.items():
107         opts.append('environment="%s=%s"' % (k, v))
108
109     lines = file(pubkey).readlines()
110     pubkey = lines[0].split()[0:2]
111     out = file(output, "w")
112     out.write("%s %s %s\n" % (",".join(opts), pubkey[0], pubkey[1]))
113     out.close()
114     return output
115
116 def start_ssh_agent():
117     ssh_agent = find_bin_or_die("ssh-agent")
118     proc = subprocess.Popen([ssh_agent], stdout = subprocess.PIPE)
119     (out, foo) = proc.communicate()
120     assert proc.returncode == 0
121     d = {}
122     for l in out.split("\n"):
123         match = re.search("^(\w+)=([^ ;]+);.*", l)
124         if not match:
125             continue
126         k, v = match.groups()
127         os.environ[k] = v
128         d[k] = v
129     return d
130
131 def stop_ssh_agent(data):
132     # No need to gather the pid, ssh-agent knows how to kill itself; after we
133     # had set up the environment
134     ssh_agent = find_bin_or_die("ssh-agent")
135     null = file("/dev/null", "w")
136     proc = subprocess.Popen([ssh_agent, "-k"], stdout = null)
137     null.close()
138     assert proc.wait() == 0
139     for k in data:
140         del os.environ[k]
141
142 class test_environment(object):
143     def __init__(self):
144         sshd = find_bin_or_die("sshd")
145         environ = {}
146         self.dir = tempfile.mkdtemp()
147         self.server_keypair = gen_ssh_keypair(
148                 os.path.join(self.dir, "server_key"))
149         self.client_keypair = gen_ssh_keypair(
150                 os.path.join(self.dir, "client_key"))
151         self.authorized_keys = gen_auth_keys(self.client_keypair[1],
152                 os.path.join(self.dir, "authorized_keys"), environ)
153         self.port = get_free_port()
154         self.sshd_conf = gen_sshd_config(
155                 os.path.join(self.dir, "sshd_config"),
156                 self.port, self.server_keypair[0], self.authorized_keys)
157
158         self.sshd = subprocess.Popen([sshd, '-q', '-D', '-f', self.sshd_conf])
159         self.ssh_agent_vars = start_ssh_agent()
160         add_key_to_agent(self.client_keypair[0])
161
162     def __del__(self):
163         if self.sshd:
164             os.kill(self.sshd.pid, signal.SIGTERM)
165             self.sshd.wait()
166         if self.ssh_agent_vars:
167             stop_ssh_agent(self.ssh_agent_vars)
168         shutil.rmtree(self.dir)
169
170 class SSHfuncsTestCase(unittest.TestCase):
171     def test_rexec(self):
172         env = test_environment()
173         user = getpass.getuser()
174         host = "localhost" 
175
176         command = "hostname"
177
178         plocal = subprocess.Popen(command, stdout=subprocess.PIPE, 
179                 stdin=subprocess.PIPE)
180         outlocal, errlocal = plocal.communicate()
181
182         (outremote, errrmote), premote = rexec(command, host, user, 
183                 port = env.port, agent = True)
184
185         self.assertEquals(outlocal, outremote)
186
187     def test_rcopy_list(self):
188         env = test_environment()
189         user = getpass.getuser()
190         host = "localhost"
191
192         # create some temp files and directories to copy
193         dirpath = tempfile.mkdtemp()
194         f = tempfile.NamedTemporaryFile(dir=dirpath, delete=False)
195         f.close()
196       
197         f1 = tempfile.NamedTemporaryFile(delete=False)
198         f1.close()
199         f1.name
200
201         source = [dirpath, f1.name]
202         destdir = tempfile.mkdtemp()
203         dest = "%s@%s:%s" % (user, host, destdir)
204         rcopy(source, dest, port = env.port, agent = True, recursive = True)
205
206         files = []
207         def recls(files, dirname, names):
208             files.extend(names)
209         os.path.walk(destdir, recls, files)
210         
211         origfiles = map(lambda s: os.path.basename(s), [dirpath, f.name, f1.name])
212
213         self.assertEquals(sorted(origfiles), sorted(files))
214
215         os.remove(f1.name)
216         shutil.rmtree(dirpath)
217
218     def test_rcopy_list(self):
219         env = test_environment()
220         user = getpass.getuser()
221         host = "localhost"
222
223         # create some temp files and directories to copy
224         dirpath = tempfile.mkdtemp()
225         f = tempfile.NamedTemporaryFile(dir=dirpath, delete=False)
226         f.close()
227       
228         f1 = tempfile.NamedTemporaryFile(delete=False)
229         f1.close()
230         f1.name
231
232         # Copy a list of files
233         source = [dirpath, f1.name]
234         destdir = tempfile.mkdtemp()
235         dest = "%s@%s:%s" % (user, host, destdir)
236         ((out, err), proc) = rcopy(source, dest, port = env.port, agent = True, recursive = True)
237
238         files = []
239         def recls(files, dirname, names):
240             files.extend(names)
241         os.path.walk(destdir, recls, files)
242        
243         origfiles = map(lambda s: os.path.basename(s), [dirpath, f.name, f1.name])
244
245         self.assertEquals(sorted(origfiles), sorted(files))
246
247     def test_rproc_manage(self):
248         env = test_environment()
249         user = getpass.getuser()
250         host = "localhost" 
251         command = "ping localhost"
252         
253         f = tempfile.NamedTemporaryFile(delete=False)
254         pidfile = f.name 
255
256         (out,err), proc = rspawn(
257                 command, 
258                 pidfile,
259                 host = host,
260                 user = user,
261                 port = env.port,
262                 agent = True)
263
264         time.sleep(2)
265
266         (pid, ppid) = rgetpid(pidfile,
267                 host = host,
268                 user = user,
269                 port = env.port,
270                 agent = True)
271
272         status = rstatus(pid, ppid,
273                 host = host,
274                 user = user, 
275                 port = env.port, 
276                 agent = True)
277
278         self.assertEquals(status, ProcStatus.RUNNING)
279
280         rkill(pid, ppid,
281                 host = host,
282                 user = user, 
283                 port = env.port, 
284                 agent = True)
285
286         status = rstatus(pid, ppid,
287                 host = host,
288                 user = user, 
289                 port = env.port, 
290                 agent = True)
291         
292         self.assertEquals(status, ProcStatus.FINISHED)
293
294
295 if __name__ == '__main__':
296     unittest.main()
297