Detect alternate NS3 installations (_ns3 module instead of ns3)
[nepi.git] / test / lib / test_util.py
1 #!/usr/bin/env python
2
3 import nepi.util.environ
4 import ctypes
5 import imp
6 import sys
7
8 # Unittest from Python 2.6 doesn't have these decorators
9 def _bannerwrap(f, text):
10     name = f.__name__
11     def banner(*args, **kwargs):
12         sys.stderr.write("*** WARNING: Skipping test %s: `%s'\n" %
13                 (name, text))
14         return None
15     return banner
16 def skip(text):
17     return lambda f: _bannerwrap(f, text)
18 def skipUnless(cond, text):
19     return (lambda f: _bannerwrap(f, text)) if not cond else lambda f: f
20 def skipIf(cond, text):
21     return (lambda f: _bannerwrap(f, text)) if cond else lambda f: f
22
23 def ns3_bindings_path():
24     if "NEPI_NS3BINDINGS" in os.environ:
25         return os.environ["NEPI_NS3BINDINGS"]
26     return None
27
28 def ns3_library_path():
29     if "NEPI_NS3LIBRARY" in os.environ:
30         return os.environ["NEPI_NS3LIBRARY"]
31     return None
32
33 def ns3_usable():
34     if ns3_library_path():
35         try:
36             ctypes.CDLL(ns3_library_path(), ctypes.RTLD_GLOBAL)
37         except:
38             return False
39
40     if ns3_bindings_path():
41         sys.path.insert(0, ns3_bindings_path())
42
43     try:
44         found = imp.find_module('ns3')
45         module = imp.load_module('ns3', *found)
46     except ImportError:
47         try:
48             found = imp.find_module('_ns3')
49             module = imp.load_module('_ns3', *found)
50         except ImportError:
51             return False
52     finally:
53         if ns3_bindings_path():
54             del sys.path[0]
55
56     return True
57
58 def pl_auth():
59     user = os.environ.get('PL_USER')
60     pwd = os.environ.get('PL_PASS')
61      
62     if user and pwd:
63         return (user,pwd)
64     else:
65         return None
66
67 def find_bin(name, extra_path = None):
68     search = []
69     if "PATH" in os.environ:
70         search += os.environ["PATH"].split(":")
71     for pref in ("/", "/usr/", "/usr/local/"):
72         for d in ("bin", "sbin"):
73             search.append(pref + d)
74     if extra_path:
75         search += extra_path
76
77     for d in search:
78             try:
79                 os.stat(d + "/" + name)
80                 return d + "/" + name
81             except OSError, e:
82                 if e.errno != os.errno.ENOENT:
83                     raise
84     return None
85
86 def find_bin_or_die(name, extra_path = None):
87     r = find_bin(name)
88     if not r:
89         raise RuntimeError(("Cannot find `%s' command, impossible to " +
90                 "continue.") % name)
91     return r
92
93 # SSH stuff
94
95 import os, os.path, re, signal, shutil, socket, subprocess, tempfile
96 def gen_ssh_keypair(filename):
97     ssh_keygen = nepi.util.environ.find_bin_or_die("ssh-keygen")
98     args = [ssh_keygen, '-q', '-N', '', '-f', filename]
99     assert subprocess.Popen(args).wait() == 0
100     return filename, "%s.pub" % filename
101
102 def add_key_to_agent(filename):
103     ssh_add = nepi.util.environ.find_bin_or_die("ssh-add")
104     args = [ssh_add, filename]
105     null = file("/dev/null", "w")
106     assert subprocess.Popen(args, stderr = null).wait() == 0
107     null.close()
108
109 def get_free_port():
110     s = socket.socket()
111     s.bind(("127.0.0.1", 0))
112     port = s.getsockname()[1]
113     return port
114
115 _SSH_CONF = """ListenAddress 127.0.0.1:%d
116 Protocol 2
117 HostKey %s
118 UsePrivilegeSeparation no
119 PubkeyAuthentication yes
120 PasswordAuthentication no
121 AuthorizedKeysFile %s
122 UsePAM no
123 AllowAgentForwarding yes
124 PermitRootLogin yes
125 StrictModes no
126 PermitUserEnvironment yes
127 """
128
129 def gen_sshd_config(filename, port, server_key, auth_keys):
130     conf = open(filename, "w")
131     text = _SSH_CONF % (port, server_key, auth_keys)
132     conf.write(text)
133     conf.close()
134     return filename
135
136 def gen_auth_keys(pubkey, output, environ):
137     #opts = ['from="127.0.0.1/32"'] # fails in stupid yans setup
138     opts = []
139     for k, v in environ.items():
140         opts.append('environment="%s=%s"' % (k, v))
141
142     lines = file(pubkey).readlines()
143     pubkey = lines[0].split()[0:2]
144     out = file(output, "w")
145     out.write("%s %s %s\n" % (",".join(opts), pubkey[0], pubkey[1]))
146     out.close()
147     return output
148
149 def start_ssh_agent():
150     ssh_agent = nepi.util.environ.find_bin_or_die("ssh-agent")
151     proc = subprocess.Popen([ssh_agent], stdout = subprocess.PIPE)
152     (out, foo) = proc.communicate()
153     assert proc.returncode == 0
154     d = {}
155     for l in out.split("\n"):
156         match = re.search("^(\w+)=([^ ;]+);.*", l)
157         if not match:
158             continue
159         k, v = match.groups()
160         os.environ[k] = v
161         d[k] = v
162     return d
163
164 def stop_ssh_agent(data):
165     # No need to gather the pid, ssh-agent knows how to kill itself; after we
166     # had set up the environment
167     ssh_agent = nepi.util.environ.find_bin_or_die("ssh-agent")
168     null = file("/dev/null", "w")
169     proc = subprocess.Popen([ssh_agent, "-k"], stdout = null)
170     null.close()
171     assert proc.wait() == 0
172     for k in data:
173         del os.environ[k]
174
175 class test_environment(object):
176     def __init__(self):
177         sshd = find_bin_or_die("sshd")
178         environ = {}
179         if 'PYTHONPATH' in os.environ:
180             environ['PYTHONPATH'] = ":".join(map(os.path.realpath, 
181                 os.environ['PYTHONPATH'].split(":")))
182         if 'NEPI_NS3BINDINGS' in os.environ:
183             environ['NEPI_NS3BINDINGS'] = \
184                     os.path.realpath(os.environ['NEPI_NS3BINDINGS'])
185         if 'NEPI_NS3LIBRARY' in os.environ:
186             environ['NEPI_NS3LIBRARY'] = \
187                     os.path.realpath(os.environ['NEPI_NS3LIBRARY'])
188
189         self.dir = tempfile.mkdtemp()
190         self.server_keypair = gen_ssh_keypair(
191                 os.path.join(self.dir, "server_key"))
192         self.client_keypair = gen_ssh_keypair(
193                 os.path.join(self.dir, "client_key"))
194         self.authorized_keys = gen_auth_keys(self.client_keypair[1],
195                 os.path.join(self.dir, "authorized_keys"), environ)
196         self.port = get_free_port()
197         self.sshd_conf = gen_sshd_config(
198                 os.path.join(self.dir, "sshd_config"),
199                 self.port, self.server_keypair[0], self.authorized_keys)
200
201         self.sshd = subprocess.Popen([sshd, '-q', '-D', '-f', self.sshd_conf])
202         self.ssh_agent_vars = start_ssh_agent()
203         add_key_to_agent(self.client_keypair[0])
204
205     def __del__(self):
206         if self.sshd:
207             os.kill(self.sshd.pid, signal.SIGTERM)
208             self.sshd.wait()
209         if self.ssh_agent_vars:
210             stop_ssh_agent(self.ssh_agent_vars)
211         shutil.rmtree(self.dir)
212