minor changes in tests
[nepi.git] / test / lib / test_util.py
1 #!/usr/bin/env python
2
3 import nepi.util.environ
4 import imp
5 import sys
6
7 # Unittest from Python 2.6 doesn't have these decorators
8 def _bannerwrap(f, text):
9     name = f.__name__
10     def banner(*args, **kwargs):
11         sys.stderr.write("*** WARNING: Skipping test %s: `%s'\n" %
12                 (name, text))
13         return None
14     return banner
15 def skip(text):
16     return lambda f: _bannerwrap(f, text)
17 def skipUnless(cond, text):
18     return (lambda f: _bannerwrap(f, text)) if not cond else lambda f: f
19 def skipIf(cond, text):
20     return (lambda f: _bannerwrap(f, text)) if cond else lambda f: f
21
22 def ns3_bindings_path():
23     if "NEPI_NS3BINDINGS" in os.environ:
24         return os.environ["NEPI_NS3BINDINGS"]
25     return None
26
27 def ns3_library_path():
28     if "NEPI_NS3LIBRARY" in os.environ:
29         return os.environ["NEPI_NS3LIBRARY"]
30     return None
31
32 def ns3_usable():
33     if ns3_library_path():
34         try:
35             ctypes.CDLL(ns3_library_path(), ctypes.RTLD_GLOBAL)
36         except:
37             return False
38
39     if ns3_bindings_path():
40         sys.path.insert(0, ns3_bindings_path())
41
42     try:
43         found = imp.find_module('ns3')
44         module = imp.load_module('ns3', *found)
45     except ImportError:
46         return False
47     finally:
48         if ns3_bindings_path():
49             del sys.path[0]
50
51     return True
52
53
54 def find_bin(name, extra_path = None):
55     search = []
56     if "PATH" in os.environ:
57         search += os.environ["PATH"].split(":")
58     for pref in ("/", "/usr/", "/usr/local/"):
59         for d in ("bin", "sbin"):
60             search.append(pref + d)
61     if extra_path:
62         search += extra_path
63
64     for d in search:
65             try:
66                 os.stat(d + "/" + name)
67                 return d + "/" + name
68             except OSError, e:
69                 if e.errno != os.errno.ENOENT:
70                     raise
71     return None
72
73 def find_bin_or_die(name, extra_path = None):
74     r = find_bin(name)
75     if not r:
76         raise RuntimeError(("Cannot find `%s' command, impossible to " +
77                 "continue.") % name)
78     return r
79
80 # SSH stuff
81
82 import os, os.path, re, signal, shutil, socket, subprocess, tempfile
83 def gen_ssh_keypair(filename):
84     ssh_keygen = nepi.util.environ.find_bin_or_die("ssh-keygen")
85     args = [ssh_keygen, '-q', '-N', '', '-f', filename]
86     assert subprocess.Popen(args).wait() == 0
87     return filename, "%s.pub" % filename
88
89 def add_key_to_agent(filename):
90     ssh_add = nepi.util.environ.find_bin_or_die("ssh-add")
91     args = [ssh_add, filename]
92     null = file("/dev/null", "w")
93     assert subprocess.Popen(args, stderr = null).wait() == 0
94     null.close()
95
96 def get_free_port():
97     s = socket.socket()
98     s.bind(("127.0.0.1", 0))
99     port = s.getsockname()[1]
100     return port
101
102 _SSH_CONF = """ListenAddress 127.0.0.1:%d
103 Protocol 2
104 HostKey %s
105 UsePrivilegeSeparation no
106 PubkeyAuthentication yes
107 PasswordAuthentication no
108 AuthorizedKeysFile %s
109 UsePAM no
110 AllowAgentForwarding yes
111 PermitRootLogin yes
112 StrictModes no
113 PermitUserEnvironment yes
114 """
115
116 def gen_sshd_config(filename, port, server_key, auth_keys):
117     conf = open(filename, "w")
118     text = _SSH_CONF % (port, server_key, auth_keys)
119     conf.write(text)
120     conf.close()
121     return filename
122
123 def gen_auth_keys(pubkey, output, environ):
124     #opts = ['from="127.0.0.1/32"'] # fails in stupid yans setup
125     opts = []
126     for k, v in environ.items():
127         opts.append('environment="%s=%s"' % (k, v))
128
129     lines = file(pubkey).readlines()
130     pubkey = lines[0].split()[0:2]
131     out = file(output, "w")
132     out.write("%s %s %s\n" % (",".join(opts), pubkey[0], pubkey[1]))
133     out.close()
134     return output
135
136 def start_ssh_agent():
137     ssh_agent = nepi.util.environ.find_bin_or_die("ssh-agent")
138     proc = subprocess.Popen([ssh_agent], stdout = subprocess.PIPE)
139     (out, foo) = proc.communicate()
140     assert proc.returncode == 0
141     d = {}
142     for l in out.split("\n"):
143         match = re.search("^(\w+)=([^ ;]+);.*", l)
144         if not match:
145             continue
146         k, v = match.groups()
147         os.environ[k] = v
148         d[k] = v
149     return d
150
151 def stop_ssh_agent(data):
152     # No need to gather the pid, ssh-agent knows how to kill itself; after we
153     # had set up the environment
154     ssh_agent = nepi.util.environ.find_bin_or_die("ssh-agent")
155     null = file("/dev/null", "w")
156     proc = subprocess.Popen([ssh_agent, "-k"], stdout = null)
157     null.close()
158     assert proc.wait() == 0
159     for k in data:
160         del os.environ[k]
161
162 class test_environment(object):
163     def __init__(self):
164         sshd = find_bin_or_die("sshd")
165         environ = {}
166         if 'PYTHONPATH' in os.environ:
167             environ['PYTHONPATH'] = ":".join(map(os.path.realpath, 
168                 os.environ['PYTHONPATH'].split(":")))
169         if 'NEPI_NS3BINDINGS' in os.environ:
170             environ['NEPI_NS3BINDINGS'] = \
171                     os.path.realpath(os.environ['NEPI_NS3BINDINGS'])
172         if 'NEPI_NS3LIBRARY' in os.environ:
173             environ['NEPI_NS3LIBRARY'] = \
174                     os.path.realpath(os.environ['NEPI_NS3LIBRARY'])
175
176         self.dir = tempfile.mkdtemp()
177         self.server_keypair = gen_ssh_keypair(
178                 os.path.join(self.dir, "server_key"))
179         self.client_keypair = gen_ssh_keypair(
180                 os.path.join(self.dir, "client_key"))
181         self.authorized_keys = gen_auth_keys(self.client_keypair[1],
182                 os.path.join(self.dir, "authorized_keys"), environ)
183         self.port = get_free_port()
184         self.sshd_conf = gen_sshd_config(
185                 os.path.join(self.dir, "sshd_config"),
186                 self.port, self.server_keypair[0], self.authorized_keys)
187
188         self.sshd = subprocess.Popen([sshd, '-q', '-D', '-f', self.sshd_conf])
189         self.ssh_agent_vars = start_ssh_agent()
190         add_key_to_agent(self.client_keypair[0])
191
192     def __del__(self):
193         if self.sshd:
194             os.kill(self.sshd.pid, signal.SIGTERM)
195             self.sshd.wait()
196         if self.ssh_agent_vars:
197             stop_ssh_agent(self.ssh_agent_vars)
198         shutil.rmtree(self.dir)
199