Adding authors and correcting licence information
[nepi.git] / src / nepi / util / environ.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License as published by
7 #    the Free Software Foundation, either version 3 of the License, or
8 #    (at your option) any later version.
9 #
10 #    This program is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #    GNU General Public License for more details.
14 #
15 #    You should have received a copy of the GNU General Public License
16 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18
19 import ctypes
20 import imp
21 import sys
22
23 import os, os.path, re, signal, shutil, socket, subprocess, tempfile
24
25 __all__ =  ["python", "ssh_path"]
26 __all__ += ["rsh", "tcpdump_path", "sshd_path"]
27 __all__ += ["execute", "backticks"]
28
29
30 # Unittest from Python 2.6 doesn't have these decorators
31 def _bannerwrap(f, text):
32     name = f.__name__
33     def banner(*args, **kwargs):
34         sys.stderr.write("*** WARNING: Skipping test %s: `%s'\n" %
35                 (name, text))
36         return None
37     return banner
38
39 def skip(text):
40     return lambda f: _bannerwrap(f, text)
41
42 def skipUnless(cond, text):
43     return (lambda f: _bannerwrap(f, text)) if not cond else lambda f: f
44
45 def skipIf(cond, text):
46     return (lambda f: _bannerwrap(f, text)) if cond else lambda f: f
47
48 def find_bin(name, extra_path = None):
49     search = []
50     if "PATH" in os.environ:
51         search += os.environ["PATH"].split(":")
52     for pref in ("/", "/usr/", "/usr/local/"):
53         for d in ("bin", "sbin"):
54             search.append(pref + d)
55     if extra_path:
56         search += extra_path
57
58     for d in search:
59             try:
60                 os.stat(d + "/" + name)
61                 return d + "/" + name
62             except OSError, e:
63                 if e.errno != os.errno.ENOENT:
64                     raise
65     return None
66
67 def find_bin_or_die(name, extra_path = None):
68     r = find_bin(name)
69     if not r:
70         raise RuntimeError(("Cannot find `%s' command, impossible to " +
71                 "continue.") % name)
72     return r
73
74 def find_bin(name, extra_path = None):
75     search = []
76     if "PATH" in os.environ:
77         search += os.environ["PATH"].split(":")
78     for pref in ("/", "/usr/", "/usr/local/"):
79         for d in ("bin", "sbin"):
80             search.append(pref + d)
81     if extra_path:
82         search += extra_path
83
84     for d in search:
85             try:
86                 os.stat(d + "/" + name)
87                 return d + "/" + name
88             except OSError, e:
89                 if e.errno != os.errno.ENOENT:
90                     raise
91     return None
92
93 ssh_path = find_bin_or_die("ssh")
94 python_path = find_bin_or_die("python")
95
96 # Optional tools
97 rsh_path = find_bin("rsh")
98 tcpdump_path = find_bin("tcpdump")
99 sshd_path = find_bin("sshd")
100
101 def execute(cmd):
102     # FIXME: create a global debug variable
103     #print "[pid %d]" % os.getpid(), " ".join(cmd)
104     null = open("/dev/null", "r+")
105     p = subprocess.Popen(cmd, stdout = null, stderr = subprocess.PIPE)
106     out, err = p.communicate()
107     if p.returncode != 0:
108         raise RuntimeError("Error executing `%s': %s" % (" ".join(cmd), err))
109
110 def backticks(cmd):
111     p = subprocess.Popen(cmd, stdout = subprocess.PIPE,
112             stderr = subprocess.PIPE)
113     out, err = p.communicate()
114     if p.returncode != 0:
115         raise RuntimeError("Error executing `%s': %s" % (" ".join(cmd), err))
116     return out
117
118
119 # SSH stuff
120
121 def gen_ssh_keypair(filename):
122     ssh_keygen = nepi.util.environ.find_bin_or_die("ssh-keygen")
123     args = [ssh_keygen, '-q', '-N', '', '-f', filename]
124     assert subprocess.Popen(args).wait() == 0
125     return filename, "%s.pub" % filename
126
127 def add_key_to_agent(filename):
128     ssh_add = nepi.util.environ.find_bin_or_die("ssh-add")
129     args = [ssh_add, filename]
130     null = file("/dev/null", "w")
131     assert subprocess.Popen(args, stderr = null).wait() == 0
132     null.close()
133
134 def get_free_port():
135     s = socket.socket()
136     s.bind(("127.0.0.1", 0))
137     port = s.getsockname()[1]
138     return port
139
140 _SSH_CONF = """ListenAddress 127.0.0.1:%d
141 Protocol 2
142 HostKey %s
143 UsePrivilegeSeparation no
144 PubkeyAuthentication yes
145 PasswordAuthentication no
146 AuthorizedKeysFile %s
147 UsePAM no
148 AllowAgentForwarding yes
149 PermitRootLogin yes
150 StrictModes no
151 PermitUserEnvironment yes
152 """
153
154 def gen_sshd_config(filename, port, server_key, auth_keys):
155     conf = open(filename, "w")
156     text = _SSH_CONF % (port, server_key, auth_keys)
157     conf.write(text)
158     conf.close()
159     return filename
160
161 def gen_auth_keys(pubkey, output, environ):
162     #opts = ['from="127.0.0.1/32"'] # fails in stupid yans setup
163     opts = []
164     for k, v in environ.items():
165         opts.append('environment="%s=%s"' % (k, v))
166
167     lines = file(pubkey).readlines()
168     pubkey = lines[0].split()[0:2]
169     out = file(output, "w")
170     out.write("%s %s %s\n" % (",".join(opts), pubkey[0], pubkey[1]))
171     out.close()
172     return output
173
174 def start_ssh_agent():
175     ssh_agent = nepi.util.environ.find_bin_or_die("ssh-agent")
176     proc = subprocess.Popen([ssh_agent], stdout = subprocess.PIPE)
177     (out, foo) = proc.communicate()
178     assert proc.returncode == 0
179     d = {}
180     for l in out.split("\n"):
181         match = re.search("^(\w+)=([^ ;]+);.*", l)
182         if not match:
183             continue
184         k, v = match.groups()
185         os.environ[k] = v
186         d[k] = v
187     return d
188
189 def stop_ssh_agent(data):
190     # No need to gather the pid, ssh-agent knows how to kill itself; after we
191     # had set up the environment
192     ssh_agent = nepi.util.environ.find_bin_or_die("ssh-agent")
193     null = file("/dev/null", "w")
194     proc = subprocess.Popen([ssh_agent, "-k"], stdout = null)
195     null.close()
196     assert proc.wait() == 0
197     for k in data:
198         del os.environ[k]
199