3 # NEPI, a framework to manage network experiments
4 # Copyright (C) 2013 INRIA
6 # This program is free software: you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License version 2 as
8 # published by the Free Software Foundation;
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
18 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
21 from nepi.util.sshfuncs import rexec, rcopy, rspawn, rgetpid, rstatus, rkill,\
36 def find_bin(name, extra_path = None):
38 if "PATH" in os.environ:
39 search += os.environ["PATH"].split(":")
40 for pref in ("/", "/usr/", "/usr/local/"):
41 for d in ("bin", "sbin"):
42 search.append(pref + d)
48 os.stat(d + "/" + name)
51 if e.errno != os.errno.ENOENT:
55 def find_bin_or_die(name, extra_path = None):
58 raise RuntimeError(("Cannot find `%s' command, impossible to " +
def gen_ssh_keypair(filename):
    """Generate a passphrase-less SSH key pair with ``ssh-keygen``.

    :param filename: path of the private key file to create.
    :returns: tuple ``(private_key_path, public_key_path)`` where the public
        key path is ``filename`` with a ``.pub`` suffix (ssh-keygen's naming
        convention).
    :raises RuntimeError: if the ``ssh-keygen`` binary cannot be located
        (raised by ``find_bin_or_die``).
    :raises subprocess.CalledProcessError: if ``ssh-keygen`` exits with a
        non-zero status.
    """
    ssh_keygen = find_bin_or_die("ssh-keygen")
    # -q: quiet, -N '': empty passphrase, -f: output key file.
    args = [ssh_keygen, '-q', '-N', '', '-f', filename]
    # The command must not run inside an ``assert``: asserts are stripped
    # under ``python -O``, which would silently skip key generation.
    subprocess.check_call(args)
    return filename, "%s.pub" % filename
def add_key_to_agent(filename):
    """Load the private key ``filename`` into an ssh-agent via ``ssh-add``.

    ``ssh-add`` talks to whichever agent the current process environment
    points at (presumably the one started by ``start_ssh_agent`` in this
    module -- confirm). Anything ``ssh-add`` prints on stderr is discarded.

    :param filename: path of the private key to add.
    :raises RuntimeError: if the ``ssh-add`` binary cannot be located
        (raised by ``find_bin_or_die``).
    :raises subprocess.CalledProcessError: if ``ssh-add`` exits with a
        non-zero status.
    """
    ssh_add = find_bin_or_die("ssh-add")
    args = [ssh_add, filename]
    with open("/dev/null", "w") as null:
        # Not wrapped in an ``assert``: under ``python -O`` the assert (and
        # with it the ssh-add call) would be removed entirely.
        subprocess.check_call(args, stderr = null)
76 s.bind(("127.0.0.1", 0))
77 port = s.getsockname()[1]
80 _SSH_CONF = """ListenAddress 127.0.0.1:%d
83 UsePrivilegeSeparation no
84 PubkeyAuthentication yes
85 PasswordAuthentication no
88 AllowAgentForwarding yes
91 PermitUserEnvironment yes
94 def gen_sshd_config(filename, port, server_key, auth_keys):
95 with open(filename, "w") as conf:
96 text = _SSH_CONF % (port, server_key, auth_keys)
100 def gen_auth_keys(pubkey, output, environ):
101 #opts = ['from="127.0.0.1/32"'] # fails in stupid yans setup
103 for k, v in environ.items():
104 opts.append('environment="%s=%s"' % (k, v))
106 with open(pubkey) as f:
107 lines = f.readlines()
108 pubkey = lines[0].split()[0:2]
109 with open(output, "w") as out:
110 out.write("%s %s %s\n" % (",".join(opts), pubkey[0], pubkey[1]))
113 def start_ssh_agent():
114 ssh_agent = find_bin_or_die("ssh-agent")
115 proc = subprocess.Popen([ssh_agent], stdout = subprocess.PIPE)
116 (out, foo) = proc.communicate()
117 assert proc.returncode == 0
119 for l in out.split("\n"):
120 match = re.search("^(\w+)=([^ ;]+);.*", l)
123 k, v = match.groups()
128 def stop_ssh_agent(data):
129 # No need to gather the pid, ssh-agent knows how to kill itself; after we
130 # had set up the environment
131 ssh_agent = find_bin_or_die("ssh-agent")
132 with open("/dev/null", "w") as null:
133 proc = subprocess.Popen([ssh_agent, "-k"], stdout = null)
134 assert proc.wait() == 0
138 class test_environment(object):
140 sshd = find_bin_or_die("sshd")
142 self.dir = tempfile.mkdtemp()
143 self.server_keypair = gen_ssh_keypair(
144 os.path.join(self.dir, "server_key"))
145 self.client_keypair = gen_ssh_keypair(
146 os.path.join(self.dir, "client_key"))
147 self.authorized_keys = gen_auth_keys(self.client_keypair[1],
148 os.path.join(self.dir, "authorized_keys"), environ)
149 self.port = get_free_port()
150 self.sshd_conf = gen_sshd_config(
151 os.path.join(self.dir, "sshd_config"),
152 self.port, self.server_keypair[0], self.authorized_keys)
154 self.sshd = subprocess.Popen([sshd, '-q', '-D', '-f', self.sshd_conf])
155 self.ssh_agent_vars = start_ssh_agent()
156 add_key_to_agent(self.client_keypair[0])
160 os.kill(self.sshd.pid, signal.SIGTERM)
162 if self.ssh_agent_vars:
163 stop_ssh_agent(self.ssh_agent_vars)
164 shutil.rmtree(self.dir)
166 class SSHfuncsTestCase(unittest.TestCase):
167 def test_rexec(self):
168 env = test_environment()
169 user = getpass.getuser()
174 plocal = subprocess.Popen(command, stdout=subprocess.PIPE,
175 stdin=subprocess.PIPE)
176 outlocal, errlocal = plocal.communicate()
178 (outremote, errrmote), premote = rexec(command, host, user,
179 port = env.port, agent = True)
181 self.assertEqual(outlocal, outremote)
183 def test_rcopy_list(self):
184 env = test_environment()
185 user = getpass.getuser()
188 # create some temp files and directories to copy
189 dirpath = tempfile.mkdtemp()
190 f = tempfile.NamedTemporaryFile(dir=dirpath, delete=False)
193 f1 = tempfile.NamedTemporaryFile(delete=False)
197 source = [dirpath, f1.name]
198 destdir = tempfile.mkdtemp()
199 dest = "%s@%s:%s" % (user, host, destdir)
200 rcopy(source, dest, port = env.port, agent = True, recursive = True)
203 def recls(files, dirname, names):
205 os.path.walk(destdir, recls, files)
207 origfiles = [os.path.basename(s) for s in [dirpath, f.name, f1.name]]
209 self.assertEqual(sorted(origfiles), sorted(files))
212 shutil.rmtree(dirpath)
214 def test_rcopy_list(self):
215 env = test_environment()
216 user = getpass.getuser()
219 # create some temp files and directories to copy
220 dirpath = tempfile.mkdtemp()
221 f = tempfile.NamedTemporaryFile(dir=dirpath, delete=False)
224 f1 = tempfile.NamedTemporaryFile(delete=False)
228 # Copy a list of files
229 source = [dirpath, f1.name]
230 destdir = tempfile.mkdtemp()
231 dest = "%s@%s:%s" % (user, host, destdir)
232 ((out, err), proc) = rcopy(source, dest, port = env.port, agent = True, recursive = True)
235 def recls(files, dirname, names):
237 os.path.walk(destdir, recls, files)
239 origfiles = [os.path.basename(s) for s in [dirpath, f.name, f1.name]]
241 self.assertEqual(sorted(origfiles), sorted(files))
243 def test_rproc_manage(self):
244 env = test_environment()
245 user = getpass.getuser()
247 command = "ping localhost"
249 f = tempfile.NamedTemporaryFile(delete=False)
252 (out,err), proc = rspawn(
262 (pid, ppid) = rgetpid(pidfile,
268 status = rstatus(pid, ppid,
274 self.assertEqual(status, ProcStatus.RUNNING)
282 status = rstatus(pid, ppid,
288 self.assertEqual(status, ProcStatus.FINISHED)
291 if __name__ == '__main__':