44c272a2d95be682989aede56ea5503ea6958a6c
[nepi.git] / test / util / sshfuncs.py
1 #!/usr/bin/env python
2 #
3 #    NEPI, a framework to manage network experiments
4 #    Copyright (C) 2013 INRIA
5 #
6 #    This program is free software: you can redistribute it and/or modify
7 #    it under the terms of the GNU General Public License version 2 as
8 #    published by the Free Software Foundation;
9 #
10 #    This program is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #    GNU General Public License for more details.
14 #
15 #    You should have received a copy of the GNU General Public License
16 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
19
20
21 from nepi.util.sshfuncs import rexec, rcopy, rspawn, rgetpid, rstatus, rkill,\
22         ProcStatus
23
24 import getpass
25 import unittest
26 import os
27 import subprocess
28 import re
29 import signal
30 import shutil
31 import socket
32 import subprocess
33 import tempfile
34 import time
35
36 def find_bin(name, extra_path = None):
37     search = []
38     if "PATH" in os.environ:
39         search += os.environ["PATH"].split(":")
40     for pref in ("/", "/usr/", "/usr/local/"):
41         for d in ("bin", "sbin"):
42             search.append(pref + d)
43     if extra_path:
44         search += extra_path
45
46     for d in search:
47             try:
48                 os.stat(d + "/" + name)
49                 return d + "/" + name
50             except OSError as e:
51                 if e.errno != os.errno.ENOENT:
52                     raise
53     return None
54
55 def find_bin_or_die(name, extra_path = None):
56     r = find_bin(name)
57     if not r:
58         raise RuntimeError(("Cannot find `%s' command, impossible to " +
59                 "continue.") % name)
60     return r
61
62 def gen_ssh_keypair(filename):
63     ssh_keygen = find_bin_or_die("ssh-keygen")
64     args = [ssh_keygen, '-q', '-N', '', '-f', filename]
65     assert subprocess.Popen(args).wait() == 0
66     return filename, "%s.pub" % filename
67
68 def add_key_to_agent(filename):
69     ssh_add = find_bin_or_die("ssh-add")
70     args = [ssh_add, filename]
71     with open("/dev/null", "w") as null:
72         assert subprocess.Popen(args, stderr = null).wait() == 0
73
74 def get_free_port():
75     s = socket.socket()
76     s.bind(("127.0.0.1", 0))
77     port = s.getsockname()[1]
78     return port
79
80 _SSH_CONF = """ListenAddress 127.0.0.1:%d
81 Protocol 2
82 HostKey %s
83 UsePrivilegeSeparation no
84 PubkeyAuthentication yes
85 PasswordAuthentication no
86 AuthorizedKeysFile %s
87 UsePAM no
88 AllowAgentForwarding yes
89 PermitRootLogin yes
90 StrictModes no
91 PermitUserEnvironment yes
92 """
93
94 def gen_sshd_config(filename, port, server_key, auth_keys):
95     with open(filename, "w") as conf:
96         text = _SSH_CONF % (port, server_key, auth_keys)
97         conf.write(text)
98     return filename
99
100 def gen_auth_keys(pubkey, output, environ):
101     #opts = ['from="127.0.0.1/32"'] # fails in stupid yans setup
102     opts = []
103     for k, v in environ.items():
104         opts.append('environment="%s=%s"' % (k, v))
105
106     with open(pubkey) as f:
107         lines = f.readlines()
108     pubkey = lines[0].split()[0:2]
109     with open(output, "w") as out:
110         out.write("%s %s %s\n" % (",".join(opts), pubkey[0], pubkey[1]))
111     return output
112
113 def start_ssh_agent():
114     ssh_agent = find_bin_or_die("ssh-agent")
115     proc = subprocess.Popen([ssh_agent], stdout = subprocess.PIPE)
116     (out, foo) = proc.communicate()
117     assert proc.returncode == 0
118     d = {}
119     for l in out.split("\n"):
120         match = re.search("^(\w+)=([^ ;]+);.*", l)
121         if not match:
122             continue
123         k, v = match.groups()
124         os.environ[k] = v
125         d[k] = v
126     return d
127
128 def stop_ssh_agent(data):
129     # No need to gather the pid, ssh-agent knows how to kill itself; after we
130     # had set up the environment
131     ssh_agent = find_bin_or_die("ssh-agent")
132     with open("/dev/null", "w") as null:
133         proc = subprocess.Popen([ssh_agent, "-k"], stdout = null)
134     assert proc.wait() == 0
135     for k in data:
136         del os.environ[k]
137
138 class test_environment(object):
139     def __init__(self):
140         sshd = find_bin_or_die("sshd")
141         environ = {}
142         self.dir = tempfile.mkdtemp()
143         self.server_keypair = gen_ssh_keypair(
144                 os.path.join(self.dir, "server_key"))
145         self.client_keypair = gen_ssh_keypair(
146                 os.path.join(self.dir, "client_key"))
147         self.authorized_keys = gen_auth_keys(self.client_keypair[1],
148                 os.path.join(self.dir, "authorized_keys"), environ)
149         self.port = get_free_port()
150         self.sshd_conf = gen_sshd_config(
151                 os.path.join(self.dir, "sshd_config"),
152                 self.port, self.server_keypair[0], self.authorized_keys)
153
154         self.sshd = subprocess.Popen([sshd, '-q', '-D', '-f', self.sshd_conf])
155         self.ssh_agent_vars = start_ssh_agent()
156         add_key_to_agent(self.client_keypair[0])
157
158     def __del__(self):
159         if self.sshd:
160             os.kill(self.sshd.pid, signal.SIGTERM)
161             self.sshd.wait()
162         if self.ssh_agent_vars:
163             stop_ssh_agent(self.ssh_agent_vars)
164         shutil.rmtree(self.dir)
165
166 class SSHfuncsTestCase(unittest.TestCase):
167     def test_rexec(self):
168         env = test_environment()
169         user = getpass.getuser()
170         host = "localhost" 
171
172         command = "hostname"
173
174         plocal = subprocess.Popen(command, stdout=subprocess.PIPE, 
175                 stdin=subprocess.PIPE)
176         outlocal, errlocal = plocal.communicate()
177
178         (outremote, errrmote), premote = rexec(command, host, user, 
179                 port = env.port, agent = True)
180
181         self.assertEquals(outlocal, outremote)
182
183     def test_rcopy_list(self):
184         env = test_environment()
185         user = getpass.getuser()
186         host = "localhost"
187
188         # create some temp files and directories to copy
189         dirpath = tempfile.mkdtemp()
190         f = tempfile.NamedTemporaryFile(dir=dirpath, delete=False)
191         f.close()
192       
193         f1 = tempfile.NamedTemporaryFile(delete=False)
194         f1.close()
195         f1.name
196
197         source = [dirpath, f1.name]
198         destdir = tempfile.mkdtemp()
199         dest = "%s@%s:%s" % (user, host, destdir)
200         rcopy(source, dest, port = env.port, agent = True, recursive = True)
201
202         files = []
203         def recls(files, dirname, names):
204             files.extend(names)
205         os.path.walk(destdir, recls, files)
206         
207         origfiles = [os.path.basename(s) for s in [dirpath, f.name, f1.name]]
208
209         self.assertEqual(sorted(origfiles), sorted(files))
210
211         os.remove(f1.name)
212         shutil.rmtree(dirpath)
213
214     def test_rcopy_list(self):
215         env = test_environment()
216         user = getpass.getuser()
217         host = "localhost"
218
219         # create some temp files and directories to copy
220         dirpath = tempfile.mkdtemp()
221         f = tempfile.NamedTemporaryFile(dir=dirpath, delete=False)
222         f.close()
223       
224         f1 = tempfile.NamedTemporaryFile(delete=False)
225         f1.close()
226         f1.name
227
228         # Copy a list of files
229         source = [dirpath, f1.name]
230         destdir = tempfile.mkdtemp()
231         dest = "%s@%s:%s" % (user, host, destdir)
232         ((out, err), proc) = rcopy(source, dest, port = env.port, agent = True, recursive = True)
233
234         files = []
235         def recls(files, dirname, names):
236             files.extend(names)
237         os.path.walk(destdir, recls, files)
238        
239         origfiles = [os.path.basename(s) for s in [dirpath, f.name, f1.name]]
240
241         self.assertEqual(sorted(origfiles), sorted(files))
242
243     def test_rproc_manage(self):
244         env = test_environment()
245         user = getpass.getuser()
246         host = "localhost" 
247         command = "ping localhost"
248         
249         f = tempfile.NamedTemporaryFile(delete=False)
250         pidfile = f.name 
251
252         (out,err), proc = rspawn(
253                 command, 
254                 pidfile,
255                 host = host,
256                 user = user,
257                 port = env.port,
258                 agent = True)
259
260         time.sleep(2)
261
262         (pid, ppid) = rgetpid(pidfile,
263                 host = host,
264                 user = user,
265                 port = env.port,
266                 agent = True)
267
268         status = rstatus(pid, ppid,
269                 host = host,
270                 user = user, 
271                 port = env.port, 
272                 agent = True)
273
274         self.assertEquals(status, ProcStatus.RUNNING)
275
276         rkill(pid, ppid,
277                 host = host,
278                 user = user, 
279                 port = env.port, 
280                 agent = True)
281
282         status = rstatus(pid, ppid,
283                 host = host,
284                 user = user, 
285                 port = env.port, 
286                 agent = True)
287         
288         self.assertEquals(status, ProcStatus.FINISHED)
289
290
291 if __name__ == '__main__':
292     unittest.main()
293