Merging ns-3 into nepi-3-dev
[nepi.git] / test / util / sshfuncs.py
1 #!/usr/bin/env python
2 #
3 #    NEPI, a framework to manage network experiments
4 #    Copyright (C) 2013 INRIA
5 #
6 #    This program is free software: you can redistribute it and/or modify
7 #    it under the terms of the GNU General Public License as published by
8 #    the Free Software Foundation, either version 3 of the License, or
9 #    (at your option) any later version.
10 #
11 #    This program is distributed in the hope that it will be useful,
12 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
13 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 #    GNU General Public License for more details.
15 #
16 #    You should have received a copy of the GNU General Public License
17 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
20
21
22 from nepi.util.sshfuncs import rexec, rcopy, rspawn, rgetpid, rstatus, rkill,\
23         ProcStatus
24
25 import getpass
26 import unittest
27 import os
28 import subprocess
29 import re
30 import signal
31 import shutil
32 import socket
33 import subprocess
34 import tempfile
35 import time
36
37 def find_bin(name, extra_path = None):
38     search = []
39     if "PATH" in os.environ:
40         search += os.environ["PATH"].split(":")
41     for pref in ("/", "/usr/", "/usr/local/"):
42         for d in ("bin", "sbin"):
43             search.append(pref + d)
44     if extra_path:
45         search += extra_path
46
47     for d in search:
48             try:
49                 os.stat(d + "/" + name)
50                 return d + "/" + name
51             except OSError, e:
52                 if e.errno != os.errno.ENOENT:
53                     raise
54     return None
55
56 def find_bin_or_die(name, extra_path = None):
57     r = find_bin(name)
58     if not r:
59         raise RuntimeError(("Cannot find `%s' command, impossible to " +
60                 "continue.") % name)
61     return r
62
63 def gen_ssh_keypair(filename):
64     ssh_keygen = find_bin_or_die("ssh-keygen")
65     args = [ssh_keygen, '-q', '-N', '', '-f', filename]
66     assert subprocess.Popen(args).wait() == 0
67     return filename, "%s.pub" % filename
68
69 def add_key_to_agent(filename):
70     ssh_add = find_bin_or_die("ssh-add")
71     args = [ssh_add, filename]
72     null = file("/dev/null", "w")
73     assert subprocess.Popen(args, stderr = null).wait() == 0
74     null.close()
75
76 def get_free_port():
77     s = socket.socket()
78     s.bind(("127.0.0.1", 0))
79     port = s.getsockname()[1]
80     return port
81
82 _SSH_CONF = """ListenAddress 127.0.0.1:%d
83 Protocol 2
84 HostKey %s
85 UsePrivilegeSeparation no
86 PubkeyAuthentication yes
87 PasswordAuthentication no
88 AuthorizedKeysFile %s
89 UsePAM no
90 AllowAgentForwarding yes
91 PermitRootLogin yes
92 StrictModes no
93 PermitUserEnvironment yes
94 """
95
96 def gen_sshd_config(filename, port, server_key, auth_keys):
97     conf = open(filename, "w")
98     text = _SSH_CONF % (port, server_key, auth_keys)
99     conf.write(text)
100     conf.close()
101     return filename
102
103 def gen_auth_keys(pubkey, output, environ):
104     #opts = ['from="127.0.0.1/32"'] # fails in stupid yans setup
105     opts = []
106     for k, v in environ.items():
107         opts.append('environment="%s=%s"' % (k, v))
108
109     lines = file(pubkey).readlines()
110     pubkey = lines[0].split()[0:2]
111     out = file(output, "w")
112     out.write("%s %s %s\n" % (",".join(opts), pubkey[0], pubkey[1]))
113     out.close()
114     return output
115
116 def start_ssh_agent():
117     ssh_agent = find_bin_or_die("ssh-agent")
118     proc = subprocess.Popen([ssh_agent], stdout = subprocess.PIPE)
119     (out, foo) = proc.communicate()
120     assert proc.returncode == 0
121     d = {}
122     for l in out.split("\n"):
123         match = re.search("^(\w+)=([^ ;]+);.*", l)
124         if not match:
125             continue
126         k, v = match.groups()
127         os.environ[k] = v
128         d[k] = v
129     return d
130
131 def stop_ssh_agent(data):
132     # No need to gather the pid, ssh-agent knows how to kill itself; after we
133     # had set up the environment
134     ssh_agent = find_bin_or_die("ssh-agent")
135     null = file("/dev/null", "w")
136     proc = subprocess.Popen([ssh_agent, "-k"], stdout = null)
137     null.close()
138     assert proc.wait() == 0
139     for k in data:
140         del os.environ[k]
141
142 class test_environment(object):
143     def __init__(self):
144         sshd = find_bin_or_die("sshd")
145         environ = {}
146         self.dir = tempfile.mkdtemp()
147         self.server_keypair = gen_ssh_keypair(
148                 os.path.join(self.dir, "server_key"))
149         self.client_keypair = gen_ssh_keypair(
150                 os.path.join(self.dir, "client_key"))
151         self.authorized_keys = gen_auth_keys(self.client_keypair[1],
152                 os.path.join(self.dir, "authorized_keys"), environ)
153         self.port = get_free_port()
154         self.sshd_conf = gen_sshd_config(
155                 os.path.join(self.dir, "sshd_config"),
156                 self.port, self.server_keypair[0], self.authorized_keys)
157
158         self.sshd = subprocess.Popen([sshd, '-q', '-D', '-f', self.sshd_conf])
159         self.ssh_agent_vars = start_ssh_agent()
160         add_key_to_agent(self.client_keypair[0])
161
162     def __del__(self):
163         if self.sshd:
164             os.kill(self.sshd.pid, signal.SIGTERM)
165             self.sshd.wait()
166         if self.ssh_agent_vars:
167             stop_ssh_agent(self.ssh_agent_vars)
168         shutil.rmtree(self.dir)
169
170 class SSHfuncsTestCase(unittest.TestCase):
171     def test_rexec(self):
172         env = test_environment()
173         user = getpass.getuser()
174         host = "localhost" 
175
176         command = "hostname"
177
178         plocal = subprocess.Popen(command, stdout=subprocess.PIPE, 
179                 stdin=subprocess.PIPE)
180         outlocal, errlocal = plocal.communicate()
181
182         (outremote, errrmote), premote = rexec(command, host, user, 
183                 port = env.port, agent = True)
184
185         self.assertEquals(outlocal, outremote)
186
187     def test_rcopy_list(self):
188         env = test_environment()
189         user = getpass.getuser()
190         host = "localhost"
191
192         # create some temp files and directories to copy
193         dirpath = tempfile.mkdtemp()
194         f = tempfile.NamedTemporaryFile(dir=dirpath, delete=False)
195         f.close()
196       
197         f1 = tempfile.NamedTemporaryFile(delete=False)
198         f1.close()
199         f1.name
200
201         source = [dirpath, f1.name]
202         destdir = tempfile.mkdtemp()
203         dest = "%s@%s:%s" % (user, host, destdir)
204         rcopy(source, dest, port = env.port, agent = True, recursive = True)
205
206         files = []
207         def recls(files, dirname, names):
208             files.extend(names)
209         os.path.walk(destdir, recls, files)
210         
211         origfiles = map(lambda s: os.path.basename(s), [dirpath, f.name, f1.name])
212
213         self.assertEquals(sorted(origfiles), sorted(files))
214
215         os.remove(f1.name)
216         shutil.rmtree(dirpath)
217
218     def test_rcopy_slist(self):
219         env = test_environment()
220         user = getpass.getuser()
221         host = "localhost"
222
223         # create some temp files and directories to copy
224         dirpath = tempfile.mkdtemp()
225         f = tempfile.NamedTemporaryFile(dir=dirpath, delete=False)
226         f.close()
227       
228         f1 = tempfile.NamedTemporaryFile(delete=False)
229         f1.close()
230         f1.name
231
232         source = "%s;%s" % (dirpath, f1.name)
233         destdir = tempfile.mkdtemp()
234         dest = "%s@%s:%s" % (user, host, destdir)
235         rcopy(source, dest, port = env.port, agent = True, recursive = True)
236
237         files = []
238         def recls(files, dirname, names):
239             files.extend(names)
240         os.path.walk(destdir, recls, files)
241         
242         origfiles = map(lambda s: os.path.basename(s), [dirpath, f.name, f1.name])
243
244         self.assertEquals(sorted(origfiles), sorted(files))
245
246     def test_rproc_manage(self):
247         env = test_environment()
248         user = getpass.getuser()
249         host = "localhost" 
250         command = "ping localhost"
251         
252         f = tempfile.NamedTemporaryFile(delete=False)
253         pidfile = f.name 
254
255         (out,err), proc = rspawn(
256                 command, 
257                 pidfile,
258                 host = host,
259                 user = user,
260                 port = env.port,
261                 agent = True)
262
263         time.sleep(2)
264
265         (pid, ppid) = rgetpid(pidfile,
266                 host = host,
267                 user = user,
268                 port = env.port,
269                 agent = True)
270
271         status = rstatus(pid, ppid,
272                 host = host,
273                 user = user, 
274                 port = env.port, 
275                 agent = True)
276
277         self.assertEquals(status, ProcStatus.RUNNING)
278
279         rkill(pid, ppid,
280                 host = host,
281                 user = user, 
282                 port = env.port, 
283                 agent = True)
284
285         status = rstatus(pid, ppid,
286                 host = host,
287                 user = user, 
288                 port = env.port, 
289                 agent = True)
290         
291         self.assertEquals(status, ProcStatus.FINISHED)
292
293
294 if __name__ == '__main__':
295     unittest.main()
296