applied the except and raise fixers to the master branch to close the gap with py3
[nepi.git] / src / nepi / util / environ.py
index 8a2a90d..e67249a 100644 (file)
@@ -1,11 +1,53 @@
-# vim:ts=4:sw=4:et:ai:sts=4
+#
+#    NEPI, a framework to manage network experiments
+#    Copyright (C) 2013 INRIA
+#
+#    This program is free software: you can redistribute it and/or modify
+#    it under the terms of the GNU General Public License version 2 as
+#    published by the Free Software Foundation;
+#
+#    This program is distributed in the hope that it will be useful,
+#    but WITHOUT ANY WARRANTY; without even the implied warranty of
+#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+#    GNU General Public License for more details.
+#
+#    You should have received a copy of the GNU General Public License
+#    along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+# Author: Alina Quereilhac <alina.quereilhac@inria.fr>
+#         Martin Ferrari <martin.ferrari@inria.fr>
 
-import os, subprocess, os.path
+
+
+import ctypes
+import imp
+import sys
+
+import os, os.path, re, signal, shutil, socket, subprocess, tempfile
 
 __all__ =  ["python", "ssh_path"]
 __all__ += ["rsh", "tcpdump_path", "sshd_path"]
 __all__ += ["execute", "backticks"]
 
+
+# Unittest from Python 2.6 doesn't have these decorators
+def _bannerwrap(f, text):
+    name = f.__name__
+    def banner(*args, **kwargs):
+        sys.stderr.write("*** WARNING: Skipping test %s: `%s'\n" %
+                (name, text))
+        return None
+    return banner
+
+def skip(text):
+    return lambda f: _bannerwrap(f, text)
+
+def skipUnless(cond, text):
+    return (lambda f: _bannerwrap(f, text)) if not cond else lambda f: f
+
+def skipIf(cond, text):
+    return (lambda f: _bannerwrap(f, text)) if cond else lambda f: f
+
 def find_bin(name, extra_path = None):
     search = []
     if "PATH" in os.environ:
@@ -20,7 +62,7 @@ def find_bin(name, extra_path = None):
             try:
                 os.stat(d + "/" + name)
                 return d + "/" + name
-            except OSError, e:
+            except OSError as e:
                 if e.errno != os.errno.ENOENT:
                     raise
     return None
@@ -32,6 +74,25 @@ def find_bin_or_die(name, extra_path = None):
                 "continue.") % name)
     return r
 
+def find_bin(name, extra_path = None):
+    search = []
+    if "PATH" in os.environ:
+        search += os.environ["PATH"].split(":")
+    for pref in ("/", "/usr/", "/usr/local/"):
+        for d in ("bin", "sbin"):
+            search.append(pref + d)
+    if extra_path:
+        search += extra_path
+
+    for d in search:
+            try:
+                os.stat(d + "/" + name)
+                return d + "/" + name
+            except OSError as e:
+                if e.errno != os.errno.ENOENT:
+                    raise
+    return None
+
 ssh_path = find_bin_or_die("ssh")
 python_path = find_bin_or_die("python")
 
@@ -43,11 +104,11 @@ sshd_path = find_bin("sshd")
 def execute(cmd):
     # FIXME: create a global debug variable
     #print "[pid %d]" % os.getpid(), " ".join(cmd)
-    null = open("/dev/null", "r+")
-    p = subprocess.Popen(cmd, stdout = null, stderr = subprocess.PIPE)
-    out, err = p.communicate()
-    if p.returncode != 0:
-        raise RuntimeError("Error executing `%s': %s" % (" ".join(cmd), err))
+    with open("/dev/null", "r+") as null:
+        p = subprocess.Popen(cmd, stdout = null, stderr = subprocess.PIPE)
+        out, err = p.communicate()
+        if p.returncode != 0:
+            raise RuntimeError("Error executing `%s': %s" % (" ".join(cmd), err))
 
 def backticks(cmd):
     p = subprocess.Popen(cmd, stdout = subprocess.PIPE,
@@ -57,41 +118,82 @@ def backticks(cmd):
         raise RuntimeError("Error executing `%s': %s" % (" ".join(cmd), err))
     return out
 
-def homepath(path, app='.nepi', mode = 0500):
-    home = os.environ.get('HOME')
-    if home is None:
-        home = os.path.join(os.sep, 'home', os.getlogin())
-    
-    path = os.path.join(home, app, path)
-    dirname = os.path.dirname(path)
-    if not os.path.exists(dirname):
-        os.makedirs(dirname)
-    
-    return path
-
-def find_testbed(testbed_id):
-    mod_name = None
-    
-    # look for environment-specified testbeds
-    if 'NEPI_TESTBEDS' in os.environ:
-        try:
-            # parse testbed map
-            #   split space-separated items, filter empty items
-            testbed_map = filter(bool,os.environ['NEPI_TESTBEDS'].strip().split(' '))
-            #   split items, keep pairs only, build map
-            testbed_map = dict([map(str.strip,i.split(':',1)) for i in testbed_map if ':' in i])
-        except:
-            import traceback, sys
-            traceback.print_exc(file=sys.stderr)
-            
-            # ignore malformed environment
-            testbed_map = {}
-        
-        mod_name = testbed_map.get(testbed_id)
-    
-    if mod_name is None:
-        # no explicit map, load built-in testbeds
-        mod_name = "nepi.testbeds.%s" % (testbed_id.lower())
-
-    return mod_name
+
+# SSH stuff
+
+def gen_ssh_keypair(filename):
+    ssh_keygen = nepi.util.environ.find_bin_or_die("ssh-keygen")
+    args = [ssh_keygen, '-q', '-N', '', '-f', filename]
+    assert subprocess.Popen(args).wait() == 0
+    return filename, "%s.pub" % filename
+
+def add_key_to_agent(filename):
+    ssh_add = nepi.util.environ.find_bin_or_die("ssh-add")
+    args = [ssh_add, filename]
+    with open("/dev/null", "w") as null:
+        assert subprocess.Popen(args, stderr = null).wait() == 0
+
+def get_free_port():
+    s = socket.socket()
+    s.bind(("127.0.0.1", 0))
+    port = s.getsockname()[1]
+    return port
+
+_SSH_CONF = """ListenAddress 127.0.0.1:%d
+Protocol 2
+HostKey %s
+UsePrivilegeSeparation no
+PubkeyAuthentication yes
+PasswordAuthentication no
+AuthorizedKeysFile %s
+UsePAM no
+AllowAgentForwarding yes
+PermitRootLogin yes
+StrictModes no
+PermitUserEnvironment yes
+"""
+
+def gen_sshd_config(filename, port, server_key, auth_keys):
+    with open(filename, "w") as conf:
+        text = _SSH_CONF % (port, server_key, auth_keys)
+        conf.write(text)
+    return filename
+
+def gen_auth_keys(pubkey, output, environ):
+    #opts = ['from="127.0.0.1/32"'] # fails in stupid yans setup
+    opts = []
+    for k, v in environ.items():
+        opts.append('environment="%s=%s"' % (k, v))
+
+    with open(pubkey) as f:
+        lines = f.readlines()
+    pubkey = lines[0].split()[0:2]
+    with open(output, "w") as out:
+        out.write("%s %s %s\n" % (",".join(opts), pubkey[0], pubkey[1]))
+    return output
+
+def start_ssh_agent():
+    ssh_agent = nepi.util.environ.find_bin_or_die("ssh-agent")
+    proc = subprocess.Popen([ssh_agent], stdout = subprocess.PIPE)
+    (out, foo) = proc.communicate()
+    assert proc.returncode == 0
+    d = {}
+    for l in out.split("\n"):
+        match = re.search("^(\w+)=([^ ;]+);.*", l)
+        if not match:
+            continue
+        k, v = match.groups()
+        os.environ[k] = v
+        d[k] = v
+    return d
+
+def stop_ssh_agent(data):
+    # No need to gather the pid, ssh-agent knows how to kill itself; after we
+    # had set up the environment
+    ssh_agent = nepi.util.environ.find_bin_or_die("ssh-agent")
+    with open("/dev/null", "w") as null:
+        proc = subprocess.Popen([ssh_agent, "-k"], stdout = null)
+        assert proc.wait() == 0
+    for k in data:
+        del os.environ[k]