Adding authors and correcting licence information
[nepi.git] / test / util / sshfuncs.py
1 #!/usr/bin/env python
2 #
3 #    NEPI, a framework to manage network experiments
4 #    Copyright (C) 2013 INRIA
5 #
6 #    This program is free software: you can redistribute it and/or modify
7 #    it under the terms of the GNU General Public License as published by
8 #    the Free Software Foundation, either version 3 of the License, or
9 #    (at your option) any later version.
10 #
11 #    This program is distributed in the hope that it will be useful,
12 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
13 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 #    GNU General Public License for more details.
15 #
16 #    You should have received a copy of the GNU General Public License
17 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
20
21
22 from nepi.util.sshfuncs import rexec, rcopy, rspawn, rcheckpid, rstatus, rkill,\
23         RUNNING, FINISHED 
24
25 import getpass
26 import unittest
27 import os
28 import subprocess
29 import re
30 import signal
31 import shutil
32 import socket
33 import subprocess
34 import tempfile
35 import time
36
37 def find_bin(name, extra_path = None):
38     search = []
39     if "PATH" in os.environ:
40         search += os.environ["PATH"].split(":")
41     for pref in ("/", "/usr/", "/usr/local/"):
42         for d in ("bin", "sbin"):
43             search.append(pref + d)
44     if extra_path:
45         search += extra_path
46
47     for d in search:
48             try:
49                 os.stat(d + "/" + name)
50                 return d + "/" + name
51             except OSError, e:
52                 if e.errno != os.errno.ENOENT:
53                     raise
54     return None
55
56 def find_bin_or_die(name, extra_path = None):
57     r = find_bin(name)
58     if not r:
59         raise RuntimeError(("Cannot find `%s' command, impossible to " +
60                 "continue.") % name)
61     return r
62
63 def gen_ssh_keypair(filename):
64     ssh_keygen = find_bin_or_die("ssh-keygen")
65     args = [ssh_keygen, '-q', '-N', '', '-f', filename]
66     assert subprocess.Popen(args).wait() == 0
67     return filename, "%s.pub" % filename
68
69 def add_key_to_agent(filename):
70     ssh_add = find_bin_or_die("ssh-add")
71     args = [ssh_add, filename]
72     null = file("/dev/null", "w")
73     assert subprocess.Popen(args, stderr = null).wait() == 0
74     null.close()
75
76 def get_free_port():
77     s = socket.socket()
78     s.bind(("127.0.0.1", 0))
79     port = s.getsockname()[1]
80     return port
81
82 _SSH_CONF = """ListenAddress 127.0.0.1:%d
83 Protocol 2
84 HostKey %s
85 UsePrivilegeSeparation no
86 PubkeyAuthentication yes
87 PasswordAuthentication no
88 AuthorizedKeysFile %s
89 UsePAM no
90 AllowAgentForwarding yes
91 PermitRootLogin yes
92 StrictModes no
93 PermitUserEnvironment yes
94 """
95
96 def gen_sshd_config(filename, port, server_key, auth_keys):
97     conf = open(filename, "w")
98     text = _SSH_CONF % (port, server_key, auth_keys)
99     conf.write(text)
100     conf.close()
101     return filename
102
103 def gen_auth_keys(pubkey, output, environ):
104     #opts = ['from="127.0.0.1/32"'] # fails in stupid yans setup
105     opts = []
106     for k, v in environ.items():
107         opts.append('environment="%s=%s"' % (k, v))
108
109     lines = file(pubkey).readlines()
110     pubkey = lines[0].split()[0:2]
111     out = file(output, "w")
112     out.write("%s %s %s\n" % (",".join(opts), pubkey[0], pubkey[1]))
113     out.close()
114     return output
115
116 def start_ssh_agent():
117     ssh_agent = find_bin_or_die("ssh-agent")
118     proc = subprocess.Popen([ssh_agent], stdout = subprocess.PIPE)
119     (out, foo) = proc.communicate()
120     assert proc.returncode == 0
121     d = {}
122     for l in out.split("\n"):
123         match = re.search("^(\w+)=([^ ;]+);.*", l)
124         if not match:
125             continue
126         k, v = match.groups()
127         os.environ[k] = v
128         d[k] = v
129     return d
130
131 def stop_ssh_agent(data):
132     # No need to gather the pid, ssh-agent knows how to kill itself; after we
133     # had set up the environment
134     ssh_agent = find_bin_or_die("ssh-agent")
135     null = file("/dev/null", "w")
136     proc = subprocess.Popen([ssh_agent, "-k"], stdout = null)
137     null.close()
138     assert proc.wait() == 0
139     for k in data:
140         del os.environ[k]
141
142 class test_environment(object):
143     def __init__(self):
144         sshd = find_bin_or_die("sshd")
145         environ = {}
146         self.dir = tempfile.mkdtemp()
147         self.server_keypair = gen_ssh_keypair(
148                 os.path.join(self.dir, "server_key"))
149         self.client_keypair = gen_ssh_keypair(
150                 os.path.join(self.dir, "client_key"))
151         self.authorized_keys = gen_auth_keys(self.client_keypair[1],
152                 os.path.join(self.dir, "authorized_keys"), environ)
153         self.port = get_free_port()
154         self.sshd_conf = gen_sshd_config(
155                 os.path.join(self.dir, "sshd_config"),
156                 self.port, self.server_keypair[0], self.authorized_keys)
157
158         self.sshd = subprocess.Popen([sshd, '-q', '-D', '-f', self.sshd_conf])
159         self.ssh_agent_vars = start_ssh_agent()
160         add_key_to_agent(self.client_keypair[0])
161
162     def __del__(self):
163         if self.sshd:
164             os.kill(self.sshd.pid, signal.SIGTERM)
165             self.sshd.wait()
166         if self.ssh_agent_vars:
167             stop_ssh_agent(self.ssh_agent_vars)
168         shutil.rmtree(self.dir)
169
170 class SSHfuncsTestCase(unittest.TestCase):
171     def test_rexec(self):
172         env = test_environment()
173         user = getpass.getuser()
174         host = "localhost" 
175
176         command = "hostname"
177
178         plocal = subprocess.Popen(command, stdout=subprocess.PIPE, 
179                 stdin=subprocess.PIPE)
180         outlocal, errlocal = plocal.communicate()
181
182         (outremote, errrmote), premote = rexec(command, host, user, 
183                 port = env.port, agent = True)
184
185         self.assertEquals(outlocal, outremote)
186
187     def test_rcopy(self):
188         env = test_environment()
189         user = getpass.getuser()
190         host = "localhost"
191
192         # create some temp files and directories to copy
193         dirpath = tempfile.mkdtemp()
194         f = tempfile.NamedTemporaryFile(dir=dirpath, delete=False)
195         f.close()
196       
197         f1 = tempfile.NamedTemporaryFile(delete=False)
198         f1.close()
199         f1.name
200
201         source = [dirpath, f1.name]
202         destdir = tempfile.mkdtemp()
203         dest = "%s@%s:%s" % (user, host, destdir)
204         rcopy(source, dest, port = env.port, agent = True, recursive = True)
205
206         files = []
207         def recls(files, dirname, names):
208             files.extend(names)
209         os.path.walk(destdir, recls, files)
210         
211         origfiles = map(lambda s: os.path.basename(s), [dirpath, f.name, f1.name])
212
213         self.assertEquals(sorted(origfiles), sorted(files))
214
215     def test_rproc_manage(self):
216         env = test_environment()
217         user = getpass.getuser()
218         host = "localhost" 
219         command = "ping localhost"
220         
221         f = tempfile.NamedTemporaryFile(delete=False)
222         pidfile = f.name 
223
224         (out,err), proc = rspawn(
225                 command, 
226                 pidfile,
227                 host = host,
228                 user = user,
229                 port = env.port,
230                 agent = True)
231
232         time.sleep(2)
233
234         (pid, ppid) = rcheckpid(pidfile,
235                 host = host,
236                 user = user,
237                 port = env.port,
238                 agent = True)
239
240         status = rstatus(pid, ppid,
241                 host = host,
242                 user = user, 
243                 port = env.port, 
244                 agent = True)
245
246         self.assertEquals(status, RUNNING)
247
248         rkill(pid, ppid,
249                 host = host,
250                 user = user, 
251                 port = env.port, 
252                 agent = True)
253
254         status = rstatus(pid, ppid,
255                 host = host,
256                 user = user, 
257                 port = env.port, 
258                 agent = True)
259         
260         self.assertEquals(status, FINISHED)
261
262
263 if __name__ == '__main__':
264     unittest.main()
265