1f321c4f7aa66d660f3a4be438c6ae883c7a304b
[nepi.git] / test / lib / test_util.py
1 #!/usr/bin/env python
2
3 import nepi.util.environ
4 import imp
5 import sys
6
7 # Unittest from Python 2.6 doesn't have these decorators
8 def _bannerwrap(f, text):
9     name = f.__name__
10     def banner(*args, **kwargs):
11         sys.stderr.write("*** WARNING: Skipping test %s: `%s'\n" %
12                 (name, text))
13         return None
14     return banner
15 def skip(text):
16     return lambda f: _bannerwrap(f, text)
17 def skipUnless(cond, text):
18     return (lambda f: _bannerwrap(f, text)) if not cond else lambda f: f
19 def skipIf(cond, text):
20     return (lambda f: _bannerwrap(f, text)) if cond else lambda f: f
21
22 def ns3_bindings_path():
23     if "NEPI_NS3BINDINGS" in os.environ:
24         return os.environ["NEPI_NS3BINDINGS"]
25     return None
26
27 def ns3_library_path():
28     if "NEPI_NS3LIBRARY" in os.environ:
29         return os.environ["NEPI_NS3LIBRARY"]
30     return None
31
32 def autoconfig_ns3_backend(conf):
33     if ns3_bindings_path():
34         conf.set_attribute_value("Ns3Bindings", ns3_bindings_path())
35     if ns3_library_path():
36         conf.set_attribute_value("Ns3Library", ns3_library_path())
37
38 def ns3_usable():
39     if ns3_library_path():
40         try:
41             ctypes.CDLL(ns3_library_path(), ctypes.RTLD_GLOBAL)
42         except:
43             return False
44     if ns3_bindings_path():
45         sys.path.insert(0, ns3_bindings_path())
46
47     try:
48         found = imp.find_module('ns3')
49         module = imp.load_module('ns3', *found)
50     except ImportError:
51         return False
52     finally:
53         if ns3_bindings_path():
54             del sys.path[0]
55
56     return True
57
58
59 def find_bin(name, extra_path = None):
60     search = []
61     if "PATH" in os.environ:
62         search += os.environ["PATH"].split(":")
63     for pref in ("/", "/usr/", "/usr/local/"):
64         for d in ("bin", "sbin"):
65             search.append(pref + d)
66     if extra_path:
67         search += extra_path
68
69     for d in search:
70             try:
71                 os.stat(d + "/" + name)
72                 return d + "/" + name
73             except OSError, e:
74                 if e.errno != os.errno.ENOENT:
75                     raise
76     return None
77
78 def find_bin_or_die(name, extra_path = None):
79     r = find_bin(name)
80     if not r:
81         raise RuntimeError(("Cannot find `%s' command, impossible to " +
82                 "continue.") % name)
83     return r
84
85 # SSH stuff
86
87 import os, os.path, re, signal, shutil, socket, subprocess, tempfile
88 def gen_ssh_keypair(filename):
89     ssh_keygen = nepi.util.environ.find_bin_or_die("ssh-keygen")
90     args = [ssh_keygen, '-q', '-N', '', '-f', filename]
91     assert subprocess.Popen(args).wait() == 0
92     return filename, "%s.pub" % filename
93
94 def add_key_to_agent(filename):
95     ssh_add = nepi.util.environ.find_bin_or_die("ssh-add")
96     args = [ssh_add, filename]
97     null = file("/dev/null", "w")
98     assert subprocess.Popen(args, stderr = null).wait() == 0
99     null.close()
100
101 def get_free_port():
102     s = socket.socket()
103     s.bind(("127.0.0.1", 0))
104     port = s.getsockname()[1]
105     return port
106
107 _SSH_CONF = """ListenAddress 127.0.0.1:%d
108 Protocol 2
109 HostKey %s
110 UsePrivilegeSeparation no
111 PubkeyAuthentication yes
112 PasswordAuthentication no
113 AuthorizedKeysFile %s
114 UsePAM no
115 AllowAgentForwarding yes
116 PermitRootLogin yes
117 StrictModes no
118 PermitUserEnvironment yes
119 """
120
121 def gen_sshd_config(filename, port, server_key, auth_keys):
122     conf = open(filename, "w")
123     text = _SSH_CONF % (port, server_key, auth_keys)
124     conf.write(text)
125     conf.close()
126     return filename
127
128 def gen_auth_keys(pubkey, output, environ):
129     #opts = ['from="127.0.0.1/32"'] # fails in stupid yans setup
130     opts = []
131     for k, v in environ.items():
132         opts.append('environment="%s=%s"' % (k, v))
133
134     lines = file(pubkey).readlines()
135     pubkey = lines[0].split()[0:2]
136     out = file(output, "w")
137     out.write("%s %s %s\n" % (",".join(opts), pubkey[0], pubkey[1]))
138     out.close()
139     return output
140
141 def start_ssh_agent():
142     ssh_agent = nepi.util.environ.find_bin_or_die("ssh-agent")
143     proc = subprocess.Popen([ssh_agent], stdout = subprocess.PIPE)
144     (out, foo) = proc.communicate()
145     assert proc.returncode == 0
146     d = {}
147     for l in out.split("\n"):
148         match = re.search("^(\w+)=([^ ;]+);.*", l)
149         if not match:
150             continue
151         k, v = match.groups()
152         os.environ[k] = v
153         d[k] = v
154     return d
155
156 def stop_ssh_agent(data):
157     # No need to gather the pid, ssh-agent knows how to kill itself; after we
158     # had set up the environment
159     ssh_agent = nepi.util.environ.find_bin_or_die("ssh-agent")
160     null = file("/dev/null", "w")
161     proc = subprocess.Popen([ssh_agent, "-k"], stdout = null)
162     null.close()
163     assert proc.wait() == 0
164     for k in data:
165         del os.environ[k]
166
167 class test_environment(object):
168     def __init__(self):
169         sshd = find_bin_or_die("sshd")
170         environ = {}
171         if 'PYTHONPATH' in os.environ:
172             environ['PYTHONPATH'] = ":".join(map(os.path.realpath, 
173                 os.environ['PYTHONPATH'].split(":")))
174         if 'NEPI_NS3BINDINGS' in os.environ:
175             environ['NEPI_NS3BINDINGS'] = \
176                     os.path.realpath(os.environ['NEPI_NS3BINDINGS'])
177         if 'NEPI_NS3LIBRARY' in os.environ:
178             environ['NEPI_NS3LIBRARY'] = \
179                     os.path.realpath(os.environ['NEPI_NS3LIBRARY'])
180
181         self.dir = tempfile.mkdtemp()
182         self.server_keypair = gen_ssh_keypair(
183                 os.path.join(self.dir, "server_key"))
184         self.client_keypair = gen_ssh_keypair(
185                 os.path.join(self.dir, "client_key"))
186         self.authorized_keys = gen_auth_keys(self.client_keypair[1],
187                 os.path.join(self.dir, "authorized_keys"), environ)
188         self.port = get_free_port()
189         self.sshd_conf = gen_sshd_config(
190                 os.path.join(self.dir, "sshd_config"),
191                 self.port, self.server_keypair[0], self.authorized_keys)
192
193         self.sshd = subprocess.Popen([sshd, '-q', '-D', '-f', self.sshd_conf])
194         self.ssh_agent_vars = start_ssh_agent()
195         add_key_to_agent(self.client_keypair[0])
196
197     def __del__(self):
198         if self.sshd:
199             os.kill(self.sshd.pid, signal.SIGTERM)
200             self.sshd.wait()
201         if self.ssh_agent_vars:
202             stop_ssh_agent(self.ssh_agent_vars)
203         shutil.rmtree(self.dir)
204