99f44b0eb18ed2a11781df1428cd826c7d7ba0f8
[nepi.git] / test / util / sshfuncs.py
1 #!/usr/bin/env python
2 #
3 #    NEPI, a framework to manage network experiments
4 #    Copyright (C) 2013 INRIA
5 #
6 #    This program is free software: you can redistribute it and/or modify
7 #    it under the terms of the GNU General Public License version 2 as
8 #    published by the Free Software Foundation;
9 #
10 #    This program is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #    GNU General Public License for more details.
14 #
15 #    You should have received a copy of the GNU General Public License
16 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
19
20
21 from nepi.util.sshfuncs import rexec, rcopy, rspawn, rgetpid, rstatus, rkill,\
22         ProcStatus
23
24 import getpass
25 import unittest
26 import os
27 import subprocess
28 import re
29 import signal
30 import shutil
31 import socket
32 import subprocess
33 import tempfile
34 import time
35
36 def find_bin(name, extra_path = None):
37     search = []
38     if "PATH" in os.environ:
39         search += os.environ["PATH"].split(":")
40     for pref in ("/", "/usr/", "/usr/local/"):
41         for d in ("bin", "sbin"):
42             search.append(pref + d)
43     if extra_path:
44         search += extra_path
45
46     for d in search:
47             try:
48                 os.stat(d + "/" + name)
49                 return d + "/" + name
50             except OSError, e:
51                 if e.errno != os.errno.ENOENT:
52                     raise
53     return None
54
55 def find_bin_or_die(name, extra_path = None):
56     r = find_bin(name)
57     if not r:
58         raise RuntimeError(("Cannot find `%s' command, impossible to " +
59                 "continue.") % name)
60     return r
61
62 def gen_ssh_keypair(filename):
63     ssh_keygen = find_bin_or_die("ssh-keygen")
64     args = [ssh_keygen, '-q', '-N', '', '-f', filename]
65     assert subprocess.Popen(args).wait() == 0
66     return filename, "%s.pub" % filename
67
68 def add_key_to_agent(filename):
69     ssh_add = find_bin_or_die("ssh-add")
70     args = [ssh_add, filename]
71     null = file("/dev/null", "w")
72     assert subprocess.Popen(args, stderr = null).wait() == 0
73     null.close()
74
75 def get_free_port():
76     s = socket.socket()
77     s.bind(("127.0.0.1", 0))
78     port = s.getsockname()[1]
79     return port
80
81 _SSH_CONF = """ListenAddress 127.0.0.1:%d
82 Protocol 2
83 HostKey %s
84 UsePrivilegeSeparation no
85 PubkeyAuthentication yes
86 PasswordAuthentication no
87 AuthorizedKeysFile %s
88 UsePAM no
89 AllowAgentForwarding yes
90 PermitRootLogin yes
91 StrictModes no
92 PermitUserEnvironment yes
93 """
94
95 def gen_sshd_config(filename, port, server_key, auth_keys):
96     conf = open(filename, "w")
97     text = _SSH_CONF % (port, server_key, auth_keys)
98     conf.write(text)
99     conf.close()
100     return filename
101
102 def gen_auth_keys(pubkey, output, environ):
103     #opts = ['from="127.0.0.1/32"'] # fails in stupid yans setup
104     opts = []
105     for k, v in environ.items():
106         opts.append('environment="%s=%s"' % (k, v))
107
108     lines = file(pubkey).readlines()
109     pubkey = lines[0].split()[0:2]
110     out = file(output, "w")
111     out.write("%s %s %s\n" % (",".join(opts), pubkey[0], pubkey[1]))
112     out.close()
113     return output
114
115 def start_ssh_agent():
116     ssh_agent = find_bin_or_die("ssh-agent")
117     proc = subprocess.Popen([ssh_agent], stdout = subprocess.PIPE)
118     (out, foo) = proc.communicate()
119     assert proc.returncode == 0
120     d = {}
121     for l in out.split("\n"):
122         match = re.search("^(\w+)=([^ ;]+);.*", l)
123         if not match:
124             continue
125         k, v = match.groups()
126         os.environ[k] = v
127         d[k] = v
128     return d
129
130 def stop_ssh_agent(data):
131     # No need to gather the pid, ssh-agent knows how to kill itself; after we
132     # had set up the environment
133     ssh_agent = find_bin_or_die("ssh-agent")
134     null = file("/dev/null", "w")
135     proc = subprocess.Popen([ssh_agent, "-k"], stdout = null)
136     null.close()
137     assert proc.wait() == 0
138     for k in data:
139         del os.environ[k]
140
141 class test_environment(object):
142     def __init__(self):
143         sshd = find_bin_or_die("sshd")
144         environ = {}
145         self.dir = tempfile.mkdtemp()
146         self.server_keypair = gen_ssh_keypair(
147                 os.path.join(self.dir, "server_key"))
148         self.client_keypair = gen_ssh_keypair(
149                 os.path.join(self.dir, "client_key"))
150         self.authorized_keys = gen_auth_keys(self.client_keypair[1],
151                 os.path.join(self.dir, "authorized_keys"), environ)
152         self.port = get_free_port()
153         self.sshd_conf = gen_sshd_config(
154                 os.path.join(self.dir, "sshd_config"),
155                 self.port, self.server_keypair[0], self.authorized_keys)
156
157         self.sshd = subprocess.Popen([sshd, '-q', '-D', '-f', self.sshd_conf])
158         self.ssh_agent_vars = start_ssh_agent()
159         add_key_to_agent(self.client_keypair[0])
160
161     def __del__(self):
162         if self.sshd:
163             os.kill(self.sshd.pid, signal.SIGTERM)
164             self.sshd.wait()
165         if self.ssh_agent_vars:
166             stop_ssh_agent(self.ssh_agent_vars)
167         shutil.rmtree(self.dir)
168
169 class SSHfuncsTestCase(unittest.TestCase):
170     def test_rexec(self):
171         env = test_environment()
172         user = getpass.getuser()
173         host = "localhost" 
174
175         command = "hostname"
176
177         plocal = subprocess.Popen(command, stdout=subprocess.PIPE, 
178                 stdin=subprocess.PIPE)
179         outlocal, errlocal = plocal.communicate()
180
181         (outremote, errrmote), premote = rexec(command, host, user, 
182                 port = env.port, agent = True)
183
184         self.assertEquals(outlocal, outremote)
185
186     def test_rcopy_list(self):
187         env = test_environment()
188         user = getpass.getuser()
189         host = "localhost"
190
191         # create some temp files and directories to copy
192         dirpath = tempfile.mkdtemp()
193         f = tempfile.NamedTemporaryFile(dir=dirpath, delete=False)
194         f.close()
195       
196         f1 = tempfile.NamedTemporaryFile(delete=False)
197         f1.close()
198         f1.name
199
200         source = [dirpath, f1.name]
201         destdir = tempfile.mkdtemp()
202         dest = "%s@%s:%s" % (user, host, destdir)
203         rcopy(source, dest, port = env.port, agent = True, recursive = True)
204
205         files = []
206         def recls(files, dirname, names):
207             files.extend(names)
208         os.path.walk(destdir, recls, files)
209         
210         origfiles = map(lambda s: os.path.basename(s), [dirpath, f.name, f1.name])
211
212         self.assertEquals(sorted(origfiles), sorted(files))
213
214         os.remove(f1.name)
215         shutil.rmtree(dirpath)
216
217     def test_rcopy_list(self):
218         env = test_environment()
219         user = getpass.getuser()
220         host = "localhost"
221
222         # create some temp files and directories to copy
223         dirpath = tempfile.mkdtemp()
224         f = tempfile.NamedTemporaryFile(dir=dirpath, delete=False)
225         f.close()
226       
227         f1 = tempfile.NamedTemporaryFile(delete=False)
228         f1.close()
229         f1.name
230
231         # Copy a list of files
232         source = [dirpath, f1.name]
233         destdir = tempfile.mkdtemp()
234         dest = "%s@%s:%s" % (user, host, destdir)
235         ((out, err), proc) = rcopy(source, dest, port = env.port, agent = True, recursive = True)
236
237         files = []
238         def recls(files, dirname, names):
239             files.extend(names)
240         os.path.walk(destdir, recls, files)
241        
242         origfiles = map(lambda s: os.path.basename(s), [dirpath, f.name, f1.name])
243
244         self.assertEquals(sorted(origfiles), sorted(files))
245
246     def test_rproc_manage(self):
247         env = test_environment()
248         user = getpass.getuser()
249         host = "localhost" 
250         command = "ping localhost"
251         
252         f = tempfile.NamedTemporaryFile(delete=False)
253         pidfile = f.name 
254
255         (out,err), proc = rspawn(
256                 command, 
257                 pidfile,
258                 host = host,
259                 user = user,
260                 port = env.port,
261                 agent = True)
262
263         time.sleep(2)
264
265         (pid, ppid) = rgetpid(pidfile,
266                 host = host,
267                 user = user,
268                 port = env.port,
269                 agent = True)
270
271         status = rstatus(pid, ppid,
272                 host = host,
273                 user = user, 
274                 port = env.port, 
275                 agent = True)
276
277         self.assertEquals(status, ProcStatus.RUNNING)
278
279         rkill(pid, ppid,
280                 host = host,
281                 user = user, 
282                 port = env.port, 
283                 agent = True)
284
285         status = rstatus(pid, ppid,
286                 host = host,
287                 user = user, 
288                 port = env.port, 
289                 agent = True)
290         
291         self.assertEquals(status, ProcStatus.FINISHED)
292
293
294 if __name__ == '__main__':
295     unittest.main()
296