applied the except and raise fixers to the master branch to close the gap with py3
[nepi.git] / src / nepi / util / sshfuncs.py
index b06abdb..1899de2 100644 (file)
@@ -3,9 +3,8 @@
 #    Copyright (C) 2013 INRIA
 #
 #    This program is free software: you can redistribute it and/or modify
 #    Copyright (C) 2013 INRIA
 #
 #    This program is free software: you can redistribute it and/or modify
-#    it under the terms of the GNU General Public License as published by
-#    the Free Software Foundation, either version 3 of the License, or
-#    (at your option) any later version.
+#    it under the terms of the GNU General Public License version 2 as
+#    published by the Free Software Foundation;
 #
 #    This program is distributed in the hope that it will be useful,
 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
 #
 #    This program is distributed in the hope that it will be useful,
 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
@@ -35,15 +34,15 @@ import threading
 import time
 import tempfile
 
 import time
 import tempfile
 
+_re_inet = re.compile("\d+:\s+(?P<name>[a-z0-9_-]+)\s+inet6?\s+(?P<inet>[a-f0-9.:/]+)\s+(brd\s+[0-9.]+)?.*scope\s+global.*") 
+
 logger = logging.getLogger("sshfuncs")
 
 logger = logging.getLogger("sshfuncs")
 
-def log(msg, level, out = None, err = None):
+def log(msg, level = logging.DEBUG, out = None, err = None):
     if out:
         msg += " - OUT: %s " % out
     if out:
         msg += " - OUT: %s " % out
-
     if err:
         msg += " - ERROR: %s " % err
     if err:
         msg += " - ERROR: %s " % err
-
     logger.log(level, msg)
 
 if hasattr(os, "devnull"):
     logger.log(level, msg)
 
 if hasattr(os, "devnull"):
@@ -58,6 +57,7 @@ class STDOUT:
     Special value that when given to rspawn in stderr causes stderr to 
     redirect to whatever stdout was redirected to.
     """
     Special value that when given to rspawn in stderr causes stderr to 
     redirect to whatever stdout was redirected to.
     """
+    pass
 
 class ProcStatus:
     """
 
 class ProcStatus:
     """
@@ -75,6 +75,24 @@ class ProcStatus:
 hostbyname_cache = dict()
 hostbyname_cache_lock = threading.Lock()
 
 hostbyname_cache = dict()
 hostbyname_cache_lock = threading.Lock()
 
+def resolve_hostname(host):
+    ip = None
+
+    if host in ["localhost", "127.0.0.1", "::1"]:
+        p = subprocess.Popen(
+            "ip -o addr list",
+            shell=True,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+        )
+        stdout, stderr = p.communicate()
+        m = _re_inet.findall(stdout)
+        ip = m[0][1].split("/")[0]
+    else:
+        ip = socket.gethostbyname(host)
+
+    return ip
+
 def gethostbyname(host):
     global hostbyname_cache
     global hostbyname_cache_lock
 def gethostbyname(host):
     global hostbyname_cache
     global hostbyname_cache_lock
@@ -82,7 +100,7 @@ def gethostbyname(host):
     hostbyname = hostbyname_cache.get(host)
     if not hostbyname:
         with hostbyname_cache_lock:
     hostbyname = hostbyname_cache.get(host)
     if not hostbyname:
         with hostbyname_cache_lock:
-            hostbyname = socket.gethostbyname(host)
+            hostbyname = resolve_hostname(host)
             hostbyname_cache[host] = hostbyname
 
             msg = " Added hostbyname %s - %s " % (host, hostbyname)
             hostbyname_cache[host] = hostbyname
 
             msg = " Added hostbyname %s - %s " % (host, hostbyname)
@@ -102,12 +120,15 @@ def openssh_has_persist():
     """
     global OPENSSH_HAS_PERSIST
     if OPENSSH_HAS_PERSIST is None:
     """
     global OPENSSH_HAS_PERSIST
     if OPENSSH_HAS_PERSIST is None:
-        proc = subprocess.Popen(["ssh","-v"],
-            stdout = subprocess.PIPE,
-            stderr = subprocess.STDOUT,
-            stdin = open("/dev/null","r") )
-        out,err = proc.communicate()
-        proc.wait()
+        with open("/dev/null") as null:
+            proc = subprocess.Popen(
+                ["ssh", "-v"],
+                stdout = subprocess.PIPE,
+                stderr = subprocess.STDOUT,
+                stdin = null,
+            )
+            out,err = proc.communicate()
+            proc.wait()
         
         vre = re.compile(r'OpenSSH_(?:[6-9]|5[.][8-9]|5[.][1-9][0-9]|[1-9][0-9]).*', re.I)
         OPENSSH_HAS_PERSIST = bool(vre.match(out))
         
         vre = re.compile(r'OpenSSH_(?:[6-9]|5[.][8-9]|5[.][1-9][0-9]|[1-9][0-9]).*', re.I)
         OPENSSH_HAS_PERSIST = bool(vre.match(out))
@@ -145,9 +166,8 @@ def make_server_key_args(server_key, host, port):
     if os.environ.get('NEPI_STRICT_AUTH_MODE',"").lower() not in ('1','true','on'):
         user_hosts_path = '%s/.ssh/known_hosts' % (os.environ.get('HOME',""),)
         if os.access(user_hosts_path, os.R_OK):
     if os.environ.get('NEPI_STRICT_AUTH_MODE',"").lower() not in ('1','true','on'):
         user_hosts_path = '%s/.ssh/known_hosts' % (os.environ.get('HOME',""),)
         if os.access(user_hosts_path, os.R_OK):
-            f = open(user_hosts_path, "r")
-            tmp_known_hosts.write(f.read())
-            f.close()
+            with open(user_hosts_path, "r") as f:
+                tmp_known_hosts.write(f.read())
         
     tmp_known_hosts.flush()
     
         
     tmp_known_hosts.flush()
     
@@ -174,12 +194,12 @@ def shell_escape(s):
         return s
     else:
         # unsafe string - escape
         return s
     else:
         # unsafe string - escape
-        def escp(c):
+        def escape(c):
             if (32 <= ord(c) < 127 or c in ('\r','\n','\t')) and c not in ("'",'"'):
                 return c
             else:
                 return "'$'\\x%02x''" % (ord(c),)
             if (32 <= ord(c) < 127 or c in ('\r','\n','\t')) and c not in ("'",'"'):
                 return c
             else:
                 return "'$'\\x%02x''" % (ord(c),)
-        s = ''.join(map(escp,s))
+        s = ''.join(map(escape, s))
         return "'%s'" % (s,)
 
 def eintr_retry(func):
         return "'%s'" % (s,)
 
 def eintr_retry(func):
@@ -191,12 +211,12 @@ def eintr_retry(func):
         for i in xrange(0 if retry else 4):
             try:
                 return func(*p, **kw)
         for i in xrange(0 if retry else 4):
             try:
                 return func(*p, **kw)
-            except (select.error, socket.error), args:
+            except (select.error, socket.error) as args:
                 if args[0] == errno.EINTR:
                     continue
                 else:
                     raise 
                 if args[0] == errno.EINTR:
                     continue
                 else:
                     raise 
-            except OSError, e:
+            except OSError as e:
                 if e.errno == errno.EINTR:
                     continue
                 else:
                 if e.errno == errno.EINTR:
                     continue
                 else:
@@ -251,10 +271,7 @@ def rexec(command, host, user,
         args.extend(['-o', 'StrictHostKeyChecking=no'])
 
     if gw:
         args.extend(['-o', 'StrictHostKeyChecking=no'])
 
     if gw:
-        if gwuser:
-            proxycommand = 'ProxyCommand=ssh -q %s@%s -W %%h:%%p' % (gwuser, gw)
-        else:
-            proxycommand = 'ProxyCommand=ssh -q %%r@%s -W %%h:%%p' % gw
+        proxycommand = _proxy_command(gw, gwuser, identity)
         args.extend(['-o', proxycommand])
 
     if agent:
         args.extend(['-o', proxycommand])
 
     if agent:
@@ -291,13 +308,13 @@ def rexec(command, host, user,
         stdout = stderr = stdin = None
 
     return _retry_rexec(args, log_msg, 
         stdout = stderr = stdin = None
 
     return _retry_rexec(args, log_msg, 
-            stderr = stderr,
-            stdin = stdin,
-            stdout = stdout,
-            env = env, 
-            retry = retry, 
-            tmp_known_hosts = tmp_known_hosts,
-            blocking = blocking)
+                        stderr = stderr,
+                        stdin = stdin,
+                        stdout = stdout,
+                        env = env, 
+                        retry = retry, 
+                        tmp_known_hosts = tmp_known_hosts,
+                        blocking = blocking)
 
 def rcopy(source, dest,
         port = None,
 
 def rcopy(source, dest,
         port = None,
@@ -325,16 +342,20 @@ def rcopy(source, dest,
     elif isinstance(source, str) and ':' in source:
         remspec, path = source.split(':',1)
     else:
     elif isinstance(source, str) and ':' in source:
         remspec, path = source.split(':',1)
     else:
-        raise ValueError, "Both endpoints cannot be local"
+        raise ValueError("Both endpoints cannot be local")
     user,host = remspec.rsplit('@',1)
     
     # plain scp
     tmp_known_hosts = None
 
     args = ['scp', '-q', '-p', '-C',
     user,host = remspec.rsplit('@',1)
     
     # plain scp
     tmp_known_hosts = None
 
     args = ['scp', '-q', '-p', '-C',
+            # 2015-06-01 Thierry: I am commenting off blowfish
+            # as this is not available on a plain ubuntu 15.04 install
+            # this IMHO is too fragile, shoud be something the user
+            # decides explicitly (so he is at least aware of that dependency)
             # Speed up transfer using blowfish cypher specification which is 
             # faster than the default one (3des)
             # Speed up transfer using blowfish cypher specification which is 
             # faster than the default one (3des)
-            '-c', 'blowfish',
+            '-c', 'blowfish',
             # Don't bother with localhost. Makes test easier
             '-o', 'NoHostAuthenticationForLocalhost=yes',
             '-o', 'ConnectTimeout=60',
             # Don't bother with localhost. Makes test easier
             '-o', 'NoHostAuthenticationForLocalhost=yes',
             '-o', 'ConnectTimeout=60',
@@ -346,10 +367,7 @@ def rcopy(source, dest,
         args.append('-P%d' % port)
 
     if gw:
         args.append('-P%d' % port)
 
     if gw:
-        if gwuser:
-            proxycommand = 'ProxyCommand=ssh -q %s@%s -W %%h:%%p' % (gwuser, gw)
-        else:
-            proxycommand = 'ProxyCommand=ssh -q %%r@%s -W %%h:%%p' % gw
+        proxycommand = _proxy_command(gw, gwuser, identity)
         args.extend(['-o', proxycommand])
 
     if recursive:
         args.extend(['-o', proxycommand])
 
     if recursive:
@@ -390,21 +408,22 @@ def rcopy(source, dest,
             blocking = True)
 
 def rspawn(command, pidfile, 
             blocking = True)
 
 def rspawn(command, pidfile, 
-        stdout = '/dev/null', 
-        stderr = STDOUT, 
-        stdin = '/dev/null',
-        home = None, 
-        create_home = False, 
-        sudo = False,
-        host = None, 
-        port = None, 
-        user = None, 
-        gwuser = None,
-        gw = None,
-        agent = None, 
-        identity = None, 
-        server_key = None,
-        tty = False):
+           stdout = '/dev/null', 
+           stderr = STDOUT, 
+           stdin = '/dev/null',
+           home = None, 
+           create_home = False, 
+           sudo = False,
+           host = None, 
+           port = None, 
+           user = None, 
+           gwuser = None,
+           gw = None,
+           agent = None, 
+           identity = None, 
+           server_key = None,
+           tty = False,
+           strict_host_checking = True):
     """
     Spawn a remote command such that it will continue working asynchronously in 
     background. 
     """
     Spawn a remote command such that it will continue working asynchronously in 
     background. 
@@ -439,7 +458,7 @@ def rspawn(command, pidfile,
         :param sudo: Flag forcing execution with sudo user
         :type sudo: bool
         
         :param sudo: Flag forcing execution with sudo user
         :type sudo: bool
         
-        :rtype: touple
+        :rtype: tuple
 
         (stdout, stderr), process
         
 
         (stdout, stderr), process
         
@@ -479,24 +498,26 @@ def rspawn(command, pidfile,
         agent = agent,
         identity = identity,
         server_key = server_key,
         agent = agent,
         identity = identity,
         server_key = server_key,
-        tty = tty ,
+        tty = tty,
+        strict_host_checking = strict_host_checking ,
         )
     
     if proc.wait():
         )
     
     if proc.wait():
-        raise RuntimeError, "Failed to set up application on host %s: %s %s" % (host, out,err,)
+        raise RuntimeError("Failed to set up application on host %s: %s %s" % (host, out,err,))
 
     return ((out, err), proc)
 
 @eintr_retry
 def rgetpid(pidfile,
 
     return ((out, err), proc)
 
 @eintr_retry
 def rgetpid(pidfile,
-        host = None, 
-        port = None, 
-        user = None, 
-        gwuser = None,
-        gw = None,
-        agent = None, 
-        identity = None,
-        server_key = None):
+            host = None, 
+            port = None, 
+            user = None, 
+            gwuser = None,
+            gw = None,
+            agent = None, 
+            identity = None,
+            server_key = None,
+            strict_host_checking = True):
     """
     Returns the pid and ppid of a process from a remote file where the 
     information was stored.
     """
     Returns the pid and ppid of a process from a remote file where the 
     information was stored.
@@ -524,7 +545,8 @@ def rgetpid(pidfile,
         gw = gw,
         agent = agent,
         identity = identity,
         gw = gw,
         agent = agent,
         identity = identity,
-        server_key = server_key
+        server_key = server_key,
+        strict_host_checking = strict_host_checking
         )
         
     if proc.wait():
         )
         
     if proc.wait():
@@ -546,7 +568,8 @@ def rstatus(pid, ppid,
         gw = None,
         agent = None, 
         identity = None,
         gw = None,
         agent = None, 
         identity = None,
-        server_key = None):
+        server_key = None,
+        strict_host_checking = True):
     """
     Returns a code representing the the status of a remote process
 
     """
     Returns a code representing the the status of a remote process
 
@@ -572,7 +595,8 @@ def rstatus(pid, ppid,
         gw = gw,
         agent = agent,
         identity = identity,
         gw = gw,
         agent = agent,
         identity = identity,
-        server_key = server_key
+        server_key = server_key,
+        strict_host_checking = strict_host_checking
         )
     
     if proc.wait():
         )
     
     if proc.wait():
@@ -599,7 +623,8 @@ def rkill(pid, ppid,
         sudo = False,
         identity = None, 
         server_key = None, 
         sudo = False,
         identity = None, 
         server_key = None, 
-        nowait = False):
+        nowait = False,
+        strict_host_checking = True):
     """
     Sends a kill signal to a remote process.
 
     """
     Sends a kill signal to a remote process.
 
@@ -653,7 +678,8 @@ fi
         gw = gw,
         agent = agent,
         identity = identity,
         gw = gw,
         agent = agent,
         identity = identity,
-        server_key = server_key
+        server_key = server_key,
+        strict_host_checking = strict_host_checking
         )
     
     # wait, don't leave zombies around
         )
     
     # wait, don't leave zombies around
@@ -662,23 +688,27 @@ fi
     return (out, err), proc
 
 def _retry_rexec(args,
     return (out, err), proc
 
 def _retry_rexec(args,
-        log_msg,
-        stdout = subprocess.PIPE,
-        stdin = subprocess.PIPE, 
-        stderr = subprocess.PIPE,
-        env = None,
-        retry = 3,
-        tmp_known_hosts = None,
-        blocking = True):
+                 log_msg,
+                 stdout = subprocess.PIPE,
+                 stdin = subprocess.PIPE, 
+                 stderr = subprocess.PIPE,
+                 env = None,
+                 retry = 3,
+                 tmp_known_hosts = None,
+                 blocking = True):
 
     for x in xrange(retry):
 
     for x in xrange(retry):
+        # display command actually invoked when debug is turned on
+        message = " ".join( [ "'{}'".format(arg) for arg in args ] )
+        log("sshfuncs: invoking {}".format(message), logging.DEBUG)
         # connects to the remote host and starts a remote connection
         # connects to the remote host and starts a remote connection
-        proc = subprocess.Popen(args,
-                env = env,
-                stdout = stdout,
-                stdin = stdin, 
-                stderr = stderr)
-        
+        proc = subprocess.Popen(
+            args,
+            env = env,
+            stdout = stdout,
+            stdin = stdin, 
+            stderr = stderr,
+        )        
         # attach tempfile object to the process, to make sure the file stays
         # alive until the process is finished with it
         proc._known_hosts = tmp_known_hosts
         # attach tempfile object to the process, to make sure the file stays
         # alive until the process is finished with it
         proc._known_hosts = tmp_known_hosts
@@ -720,7 +750,7 @@ def _retry_rexec(args,
                     time.sleep(t)
                     continue
             break
                     time.sleep(t)
                     continue
             break
-        except RuntimeError, e:
+        except RuntimeError as e:
             msg = " rexec EXCEPTION - TIMEOUT -> %s \n %s" % ( e.args, log_msg )
             log(msg, logging.DEBUG, out, err)
 
             msg = " rexec EXCEPTION - TIMEOUT -> %s \n %s" % ( e.args, log_msg )
             log(msg, logging.DEBUG, out, err)
 
@@ -786,7 +816,7 @@ def _communicate(proc, input, timeout=None, err_on_timeout=True):
 
         try:
             rlist, wlist, xlist = select.select(read_set, write_set, [], select_timeout)
 
         try:
             rlist, wlist, xlist = select.select(read_set, write_set, [], select_timeout)
-        except select.error,e:
+        except select.error as e:
             if e[0] != 4:
                 raise
             else:
             if e[0] != 4:
                 raise
             else:
@@ -840,7 +870,7 @@ def _communicate(proc, input, timeout=None, err_on_timeout=True):
 
     if killed and err_on_timeout:
         errcode = proc.poll()
 
     if killed and err_on_timeout:
         errcode = proc.poll()
-        raise RuntimeError("Operation timed out", errcode, stdout, stderr)
+        raise RuntimeError("Operation timed out", errcode, stdout, stderr)
     else:
         if killed:
             proc.poll()
     else:
         if killed:
             proc.poll()
@@ -848,4 +878,33 @@ def _communicate(proc, input, timeout=None, err_on_timeout=True):
             proc.wait()
         return (stdout, stderr)
 
             proc.wait()
         return (stdout, stderr)
 
+def _proxy_command(gw, gwuser, gwidentity):
+    """
+    Constructs the SSH ProxyCommand option to add to the SSH command to connect
+    via a proxy
+        :param gw: SSH proxy hostname
+        :type gw:  str 
+       
+        :param gwuser: SSH proxy username
+        :type gwuser:  str
+
+        :param gwidentity: SSH proxy identity file 
+        :type gwidentity:  str
+
+  
+        :rtype: str 
+        
+        returns the SSH ProxyCommand option.
+    """
+
+    proxycommand = 'ProxyCommand=ssh -q '
+    if gwidentity:
+        proxycommand += '-i %s ' % os.path.expanduser(gwidentity)
+    if gwuser:
+        proxycommand += '%s' % gwuser
+    else:
+        proxycommand += '%r'
+    proxycommand += '@%s -W %%h:%%p' % gw
+
+    return proxycommand