b40354689ba608f4469b66df3ed120c6e17d2774
[nepi.git] / src / nepi / util / environ.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License as published by
7 #    the Free Software Foundation, either version 3 of the License, or
8 #    (at your option) any later version.
9 #
10 #    This program is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #    GNU General Public License for more details.
14 #
15 #    You should have received a copy of the GNU General Public License
16 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
19 #         Martin Ferrari <martin.ferrari@inria.fr>
20
21
22
23 import ctypes
24 import imp
25 import sys
26
27 import os, os.path, re, signal, shutil, socket, subprocess, tempfile
28
29 __all__ =  ["python", "ssh_path"]
30 __all__ += ["rsh", "tcpdump_path", "sshd_path"]
31 __all__ += ["execute", "backticks"]
32
33
34 # Unittest from Python 2.6 doesn't have these decorators
35 def _bannerwrap(f, text):
36     name = f.__name__
37     def banner(*args, **kwargs):
38         sys.stderr.write("*** WARNING: Skipping test %s: `%s'\n" %
39                 (name, text))
40         return None
41     return banner
42
43 def skip(text):
44     return lambda f: _bannerwrap(f, text)
45
46 def skipUnless(cond, text):
47     return (lambda f: _bannerwrap(f, text)) if not cond else lambda f: f
48
49 def skipIf(cond, text):
50     return (lambda f: _bannerwrap(f, text)) if cond else lambda f: f
51
52 def find_bin(name, extra_path = None):
53     search = []
54     if "PATH" in os.environ:
55         search += os.environ["PATH"].split(":")
56     for pref in ("/", "/usr/", "/usr/local/"):
57         for d in ("bin", "sbin"):
58             search.append(pref + d)
59     if extra_path:
60         search += extra_path
61
62     for d in search:
63             try:
64                 os.stat(d + "/" + name)
65                 return d + "/" + name
66             except OSError, e:
67                 if e.errno != os.errno.ENOENT:
68                     raise
69     return None
70
71 def find_bin_or_die(name, extra_path = None):
72     r = find_bin(name)
73     if not r:
74         raise RuntimeError(("Cannot find `%s' command, impossible to " +
75                 "continue.") % name)
76     return r
77
78 def find_bin(name, extra_path = None):
79     search = []
80     if "PATH" in os.environ:
81         search += os.environ["PATH"].split(":")
82     for pref in ("/", "/usr/", "/usr/local/"):
83         for d in ("bin", "sbin"):
84             search.append(pref + d)
85     if extra_path:
86         search += extra_path
87
88     for d in search:
89             try:
90                 os.stat(d + "/" + name)
91                 return d + "/" + name
92             except OSError, e:
93                 if e.errno != os.errno.ENOENT:
94                     raise
95     return None
96
97 ssh_path = find_bin_or_die("ssh")
98 python_path = find_bin_or_die("python")
99
100 # Optional tools
101 rsh_path = find_bin("rsh")
102 tcpdump_path = find_bin("tcpdump")
103 sshd_path = find_bin("sshd")
104
105 def execute(cmd):
106     # FIXME: create a global debug variable
107     #print "[pid %d]" % os.getpid(), " ".join(cmd)
108     null = open("/dev/null", "r+")
109     p = subprocess.Popen(cmd, stdout = null, stderr = subprocess.PIPE)
110     out, err = p.communicate()
111     if p.returncode != 0:
112         raise RuntimeError("Error executing `%s': %s" % (" ".join(cmd), err))
113
114 def backticks(cmd):
115     p = subprocess.Popen(cmd, stdout = subprocess.PIPE,
116             stderr = subprocess.PIPE)
117     out, err = p.communicate()
118     if p.returncode != 0:
119         raise RuntimeError("Error executing `%s': %s" % (" ".join(cmd), err))
120     return out
121
122
123 # SSH stuff
124
125 def gen_ssh_keypair(filename):
126     ssh_keygen = nepi.util.environ.find_bin_or_die("ssh-keygen")
127     args = [ssh_keygen, '-q', '-N', '', '-f', filename]
128     assert subprocess.Popen(args).wait() == 0
129     return filename, "%s.pub" % filename
130
131 def add_key_to_agent(filename):
132     ssh_add = nepi.util.environ.find_bin_or_die("ssh-add")
133     args = [ssh_add, filename]
134     null = file("/dev/null", "w")
135     assert subprocess.Popen(args, stderr = null).wait() == 0
136     null.close()
137
138 def get_free_port():
139     s = socket.socket()
140     s.bind(("127.0.0.1", 0))
141     port = s.getsockname()[1]
142     return port
143
144 _SSH_CONF = """ListenAddress 127.0.0.1:%d
145 Protocol 2
146 HostKey %s
147 UsePrivilegeSeparation no
148 PubkeyAuthentication yes
149 PasswordAuthentication no
150 AuthorizedKeysFile %s
151 UsePAM no
152 AllowAgentForwarding yes
153 PermitRootLogin yes
154 StrictModes no
155 PermitUserEnvironment yes
156 """
157
158 def gen_sshd_config(filename, port, server_key, auth_keys):
159     conf = open(filename, "w")
160     text = _SSH_CONF % (port, server_key, auth_keys)
161     conf.write(text)
162     conf.close()
163     return filename
164
165 def gen_auth_keys(pubkey, output, environ):
166     #opts = ['from="127.0.0.1/32"'] # fails in stupid yans setup
167     opts = []
168     for k, v in environ.items():
169         opts.append('environment="%s=%s"' % (k, v))
170
171     lines = file(pubkey).readlines()
172     pubkey = lines[0].split()[0:2]
173     out = file(output, "w")
174     out.write("%s %s %s\n" % (",".join(opts), pubkey[0], pubkey[1]))
175     out.close()
176     return output
177
178 def start_ssh_agent():
179     ssh_agent = nepi.util.environ.find_bin_or_die("ssh-agent")
180     proc = subprocess.Popen([ssh_agent], stdout = subprocess.PIPE)
181     (out, foo) = proc.communicate()
182     assert proc.returncode == 0
183     d = {}
184     for l in out.split("\n"):
185         match = re.search("^(\w+)=([^ ;]+);.*", l)
186         if not match:
187             continue
188         k, v = match.groups()
189         os.environ[k] = v
190         d[k] = v
191     return d
192
193 def stop_ssh_agent(data):
194     # No need to gather the pid, ssh-agent knows how to kill itself; after we
195     # had set up the environment
196     ssh_agent = nepi.util.environ.find_bin_or_die("ssh-agent")
197     null = file("/dev/null", "w")
198     proc = subprocess.Popen([ssh_agent, "-k"], stdout = null)
199     null.close()
200     assert proc.wait() == 0
201     for k in data:
202         del os.environ[k]
203