Make servers able to launch when a stale ctrl.sock from a pervious server remains.
[nepi.git] / test / lib / test_util.py
1 #!/usr/bin/env python
2
3 import nepi.util.environ
4 import ctypes
5 import imp
6 import sys
7
8 # Unittest from Python 2.6 doesn't have these decorators
9 def _bannerwrap(f, text):
10     name = f.__name__
11     def banner(*args, **kwargs):
12         sys.stderr.write("*** WARNING: Skipping test %s: `%s'\n" %
13                 (name, text))
14         return None
15     return banner
16 def skip(text):
17     return lambda f: _bannerwrap(f, text)
18 def skipUnless(cond, text):
19     return (lambda f: _bannerwrap(f, text)) if not cond else lambda f: f
20 def skipIf(cond, text):
21     return (lambda f: _bannerwrap(f, text)) if cond else lambda f: f
22
23 def ns3_bindings_path():
24     if "NEPI_NS3BINDINGS" in os.environ:
25         return os.environ["NEPI_NS3BINDINGS"]
26     return None
27
28 def ns3_library_path():
29     if "NEPI_NS3LIBRARY" in os.environ:
30         return os.environ["NEPI_NS3LIBRARY"]
31     return None
32
33 def ns3_usable():
34     try:
35         from nepi.testbeds.ns3 import execute
36         execute.load_ns3_module()
37     except:
38         import traceback
39         import sys
40         traceback.print_exc(file = sys.stderr)
41         return False
42     return True
43
44 def pl_auth():
45     user = os.environ.get('PL_USER')
46     pwd = os.environ.get('PL_PASS')
47      
48     if user and pwd:
49         return (user,pwd)
50     else:
51         return None
52
53 def find_bin(name, extra_path = None):
54     search = []
55     if "PATH" in os.environ:
56         search += os.environ["PATH"].split(":")
57     for pref in ("/", "/usr/", "/usr/local/"):
58         for d in ("bin", "sbin"):
59             search.append(pref + d)
60     if extra_path:
61         search += extra_path
62
63     for d in search:
64             try:
65                 os.stat(d + "/" + name)
66                 return d + "/" + name
67             except OSError, e:
68                 if e.errno != os.errno.ENOENT:
69                     raise
70     return None
71
72 def find_bin_or_die(name, extra_path = None):
73     r = find_bin(name)
74     if not r:
75         raise RuntimeError(("Cannot find `%s' command, impossible to " +
76                 "continue.") % name)
77     return r
78
79 # SSH stuff
80
81 import os, os.path, re, signal, shutil, socket, subprocess, tempfile
82 def gen_ssh_keypair(filename):
83     ssh_keygen = nepi.util.environ.find_bin_or_die("ssh-keygen")
84     args = [ssh_keygen, '-q', '-N', '', '-f', filename]
85     assert subprocess.Popen(args).wait() == 0
86     return filename, "%s.pub" % filename
87
88 def add_key_to_agent(filename):
89     ssh_add = nepi.util.environ.find_bin_or_die("ssh-add")
90     args = [ssh_add, filename]
91     null = file("/dev/null", "w")
92     assert subprocess.Popen(args, stderr = null).wait() == 0
93     null.close()
94
95 def get_free_port():
96     s = socket.socket()
97     s.bind(("127.0.0.1", 0))
98     port = s.getsockname()[1]
99     return port
100
101 _SSH_CONF = """ListenAddress 127.0.0.1:%d
102 Protocol 2
103 HostKey %s
104 UsePrivilegeSeparation no
105 PubkeyAuthentication yes
106 PasswordAuthentication no
107 AuthorizedKeysFile %s
108 UsePAM no
109 AllowAgentForwarding yes
110 PermitRootLogin yes
111 StrictModes no
112 PermitUserEnvironment yes
113 """
114
115 def gen_sshd_config(filename, port, server_key, auth_keys):
116     conf = open(filename, "w")
117     text = _SSH_CONF % (port, server_key, auth_keys)
118     conf.write(text)
119     conf.close()
120     return filename
121
122 def gen_auth_keys(pubkey, output, environ):
123     #opts = ['from="127.0.0.1/32"'] # fails in stupid yans setup
124     opts = []
125     for k, v in environ.items():
126         opts.append('environment="%s=%s"' % (k, v))
127
128     lines = file(pubkey).readlines()
129     pubkey = lines[0].split()[0:2]
130     out = file(output, "w")
131     out.write("%s %s %s\n" % (",".join(opts), pubkey[0], pubkey[1]))
132     out.close()
133     return output
134
135 def start_ssh_agent():
136     ssh_agent = nepi.util.environ.find_bin_or_die("ssh-agent")
137     proc = subprocess.Popen([ssh_agent], stdout = subprocess.PIPE)
138     (out, foo) = proc.communicate()
139     assert proc.returncode == 0
140     d = {}
141     for l in out.split("\n"):
142         match = re.search("^(\w+)=([^ ;]+);.*", l)
143         if not match:
144             continue
145         k, v = match.groups()
146         os.environ[k] = v
147         d[k] = v
148     return d
149
150 def stop_ssh_agent(data):
151     # No need to gather the pid, ssh-agent knows how to kill itself; after we
152     # had set up the environment
153     ssh_agent = nepi.util.environ.find_bin_or_die("ssh-agent")
154     null = file("/dev/null", "w")
155     proc = subprocess.Popen([ssh_agent, "-k"], stdout = null)
156     null.close()
157     assert proc.wait() == 0
158     for k in data:
159         del os.environ[k]
160
161 class test_environment(object):
162     def __init__(self):
163         sshd = find_bin_or_die("sshd")
164         environ = {}
165         if 'PYTHONPATH' in os.environ:
166             environ['PYTHONPATH'] = ":".join(map(os.path.realpath, 
167                 os.environ['PYTHONPATH'].split(":")))
168         if 'NEPI_NS3BINDINGS' in os.environ:
169             environ['NEPI_NS3BINDINGS'] = \
170                     os.path.realpath(os.environ['NEPI_NS3BINDINGS'])
171         if 'NEPI_NS3LIBRARY' in os.environ:
172             environ['NEPI_NS3LIBRARY'] = \
173                     os.path.realpath(os.environ['NEPI_NS3LIBRARY'])
174
175         self.dir = tempfile.mkdtemp()
176         self.server_keypair = gen_ssh_keypair(
177                 os.path.join(self.dir, "server_key"))
178         self.client_keypair = gen_ssh_keypair(
179                 os.path.join(self.dir, "client_key"))
180         self.authorized_keys = gen_auth_keys(self.client_keypair[1],
181                 os.path.join(self.dir, "authorized_keys"), environ)
182         self.port = get_free_port()
183         self.sshd_conf = gen_sshd_config(
184                 os.path.join(self.dir, "sshd_config"),
185                 self.port, self.server_keypair[0], self.authorized_keys)
186
187         self.sshd = subprocess.Popen([sshd, '-q', '-D', '-f', self.sshd_conf])
188         self.ssh_agent_vars = start_ssh_agent()
189         add_key_to_agent(self.client_keypair[0])
190
191     def __del__(self):
192         if self.sshd:
193             os.kill(self.sshd.pid, signal.SIGTERM)
194             self.sshd.wait()
195         if self.ssh_agent_vars:
196             stop_ssh_agent(self.ssh_agent_vars)
197         shutil.rmtree(self.dir)
198