applied the except and raise fixers to the master branch to close the gap with py3
[nepi.git] / src / nepi / util / sshfuncs.py
index 81bf5f9..1899de2 100644 (file)
@@ -3,9 +3,8 @@
 #    Copyright (C) 2013 INRIA
 #
 #    This program is free software: you can redistribute it and/or modify
-#    it under the terms of the GNU General Public License as published by
-#    the Free Software Foundation, either version 3 of the License, or
-#    (at your option) any later version.
+#    it under the terms of the GNU General Public License version 2 as
+#    published by the Free Software Foundation;
 #
 #    This program is distributed in the hope that it will be useful,
 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
@@ -35,18 +34,17 @@ import threading
 import time
 import tempfile
 
+_re_inet = re.compile("\d+:\s+(?P<name>[a-z0-9_-]+)\s+inet6?\s+(?P<inet>[a-f0-9.:/]+)\s+(brd\s+[0-9.]+)?.*scope\s+global.*") 
+
 logger = logging.getLogger("sshfuncs")
 
-def log(msg, level, out = None, err = None):
+def log(msg, level = logging.DEBUG, out = None, err = None):
     if out:
         msg += " - OUT: %s " % out
-
     if err:
         msg += " - ERROR: %s " % err
-
     logger.log(level, msg)
 
-
 if hasattr(os, "devnull"):
     DEV_NULL = os.devnull
 else:
@@ -59,6 +57,7 @@ class STDOUT:
     Special value that when given to rspawn in stderr causes stderr to 
     redirect to whatever stdout was redirected to.
     """
+    pass
 
 class ProcStatus:
     """
@@ -76,6 +75,24 @@ class ProcStatus:
 hostbyname_cache = dict()
 hostbyname_cache_lock = threading.Lock()
 
+def resolve_hostname(host):
+    ip = None
+
+    if host in ["localhost", "127.0.0.1", "::1"]:
+        p = subprocess.Popen(
+            "ip -o addr list",
+            shell=True,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+        )
+        stdout, stderr = p.communicate()
+        m = _re_inet.findall(stdout)
+        ip = m[0][1].split("/")[0]
+    else:
+        ip = socket.gethostbyname(host)
+
+    return ip
+
 def gethostbyname(host):
     global hostbyname_cache
     global hostbyname_cache_lock
@@ -83,7 +100,7 @@ def gethostbyname(host):
     hostbyname = hostbyname_cache.get(host)
     if not hostbyname:
         with hostbyname_cache_lock:
-            hostbyname = socket.gethostbyname(host)
+            hostbyname = resolve_hostname(host)
             hostbyname_cache[host] = hostbyname
 
             msg = " Added hostbyname %s - %s " % (host, hostbyname)
@@ -103,12 +120,15 @@ def openssh_has_persist():
     """
     global OPENSSH_HAS_PERSIST
     if OPENSSH_HAS_PERSIST is None:
-        proc = subprocess.Popen(["ssh","-v"],
-            stdout = subprocess.PIPE,
-            stderr = subprocess.STDOUT,
-            stdin = open("/dev/null","r") )
-        out,err = proc.communicate()
-        proc.wait()
+        with open("/dev/null") as null:
+            proc = subprocess.Popen(
+                ["ssh", "-v"],
+                stdout = subprocess.PIPE,
+                stderr = subprocess.STDOUT,
+                stdin = null,
+            )
+            out,err = proc.communicate()
+            proc.wait()
         
         vre = re.compile(r'OpenSSH_(?:[6-9]|5[.][8-9]|5[.][1-9][0-9]|[1-9][0-9]).*', re.I)
         OPENSSH_HAS_PERSIST = bool(vre.match(out))
@@ -146,9 +166,8 @@ def make_server_key_args(server_key, host, port):
     if os.environ.get('NEPI_STRICT_AUTH_MODE',"").lower() not in ('1','true','on'):
         user_hosts_path = '%s/.ssh/known_hosts' % (os.environ.get('HOME',""),)
         if os.access(user_hosts_path, os.R_OK):
-            f = open(user_hosts_path, "r")
-            tmp_known_hosts.write(f.read())
-            f.close()
+            with open(user_hosts_path, "r") as f:
+                tmp_known_hosts.write(f.read())
         
     tmp_known_hosts.flush()
     
@@ -175,12 +194,12 @@ def shell_escape(s):
         return s
     else:
         # unsafe string - escape
-        def escp(c):
+        def escape(c):
             if (32 <= ord(c) < 127 or c in ('\r','\n','\t')) and c not in ("'",'"'):
                 return c
             else:
                 return "'$'\\x%02x''" % (ord(c),)
-        s = ''.join(map(escp,s))
+        s = ''.join(map(escape, s))
         return "'%s'" % (s,)
 
 def eintr_retry(func):
@@ -192,12 +211,12 @@ def eintr_retry(func):
         for i in xrange(0 if retry else 4):
             try:
                 return func(*p, **kw)
-            except (select.error, socket.error), args:
+            except (select.error, socket.error) as args:
                 if args[0] == errno.EINTR:
                     continue
                 else:
                     raise 
-            except OSError, e:
+            except OSError as e:
                 if e.errno == errno.EINTR:
                     continue
                 else:
@@ -207,7 +226,9 @@ def eintr_retry(func):
     return rv
 
 def rexec(command, host, user, 
-        port = None, 
+        port = None,
+        gwuser = None,
+        gw = None, 
         agent = True,
         sudo = False,
         identity = None,
@@ -223,9 +244,11 @@ def rexec(command, host, user,
     """
     Executes a remote command, returns ((stdout,stderr),process)
     """
-    
+
     tmp_known_hosts = None
-    hostip = gethostbyname(host)
+    if not gw:
+        hostip = gethostbyname(host)
+    else: hostip = None
 
     args = ['ssh', '-C',
             # Don't bother with localhost. Makes test easier
@@ -234,6 +257,7 @@ def rexec(command, host, user,
             '-o', 'ConnectionAttempts=3',
             '-o', 'ServerAliveInterval=30',
             '-o', 'TCPKeepAlive=yes',
+            '-o', 'Batchmode=yes',
             '-l', user, hostip or host]
 
     if persistent and openssh_has_persist():
@@ -246,6 +270,10 @@ def rexec(command, host, user,
         # Do not check for Host key. Unsafe.
         args.extend(['-o', 'StrictHostKeyChecking=no'])
 
+    if gw:
+        proxycommand = _proxy_command(gw, gwuser, identity)
+        args.extend(['-o', proxycommand])
+
     if agent:
         args.append('-A')
 
@@ -253,6 +281,7 @@ def rexec(command, host, user,
         args.append('-p%d' % port)
 
     if identity:
+        identity = os.path.expanduser(identity)
         args.extend(('-i', identity))
 
     if tty:
@@ -271,25 +300,26 @@ def rexec(command, host, user,
         command = "sudo " + command
 
     args.append(command)
-    
-    log_msg = " rexec - host %s - command %s " % (host, " ".join(args))
+
+    log_msg = " rexec - host %s - command %s " % (str(host), " ".join(map(str, args))) 
 
     stdout = stderr = stdin = subprocess.PIPE
     if forward_x11:
         stdout = stderr = stdin = None
 
     return _retry_rexec(args, log_msg, 
-            stderr = stderr,
-            stdin = stdin,
-            stdout = stdout,
-            env = env, 
-            retry = retry, 
-            tmp_known_hosts = tmp_known_hosts,
-            blocking = blocking)
+                        stderr = stderr,
+                        stdin = stdin,
+                        stdout = stdout,
+                        env = env, 
+                        retry = retry, 
+                        tmp_known_hosts = tmp_known_hosts,
+                        blocking = blocking)
 
 def rcopy(source, dest,
-        port = None, 
-        agent = True, 
+        port = None,
+        gwuser = None,
+        gw = None,
         recursive = False,
         identity = None,
         server_key = None,
@@ -301,26 +331,31 @@ def rcopy(source, dest,
     Source and destination should have the user and host encoded
     as per scp specs.
     
-    Source can be a list of files to copy to a single destination,
-    in which case it is advised that the destination be a folder.
+    Source can be a list of files to copy to a single destination, 
+    (in which case it is advised that the destination be a folder),
+    or a single file in a string.
     """
-    
+
     # Parse destination as <user>@<server>:<path>
-    if isinstance(dest, basestring) and ':' in dest:
+    if isinstance(dest, str) and ':' in dest:
         remspec, path = dest.split(':',1)
-    elif isinstance(source, basestring) and ':' in source:
+    elif isinstance(source, str) and ':' in source:
         remspec, path = source.split(':',1)
     else:
-        raise ValueError, "Both endpoints cannot be local"
+        raise ValueError("Both endpoints cannot be local")
     user,host = remspec.rsplit('@',1)
     
     # plain scp
     tmp_known_hosts = None
 
     args = ['scp', '-q', '-p', '-C',
+            # 2015-06-01 Thierry: I am commenting off blowfish
+            # as this is not available on a plain ubuntu 15.04 install
+            # this IMHO is too fragile, shoud be something the user
+            # decides explicitly (so he is at least aware of that dependency)
             # Speed up transfer using blowfish cypher specification which is 
             # faster than the default one (3des)
-            '-c', 'blowfish',
+            '-c', 'blowfish',
             # Don't bother with localhost. Makes test easier
             '-o', 'NoHostAuthenticationForLocalhost=yes',
             '-o', 'ConnectTimeout=60',
@@ -331,10 +366,15 @@ def rcopy(source, dest,
     if port:
         args.append('-P%d' % port)
 
+    if gw:
+        proxycommand = _proxy_command(gw, gwuser, identity)
+        args.extend(['-o', proxycommand])
+
     if recursive:
         args.append('-r')
 
     if identity:
+        identity = os.path.expanduser(identity)
         args.extend(('-i', identity))
 
     if server_key:
@@ -345,39 +385,45 @@ def rcopy(source, dest,
     if not strict_host_checking:
         # Do not check for Host key. Unsafe.
         args.extend(['-o', 'StrictHostKeyChecking=no'])
-
+    
     if isinstance(source, list):
         args.extend(source)
     else:
         if openssh_has_persist():
             args.extend([
                 '-o', 'ControlMaster=auto',
-                '-o', 'ControlPath=%s' % (make_control_path(agent, False),)
+                '-o', 'ControlPath=%s' % (make_control_path(False, False),)
                 ])
         args.append(source)
 
-    args.append(dest)
+    if isinstance(dest, list):
+        args.extend(dest)
+    else:
+        args.append(dest)
 
-    log_msg = " rcopy - host %s - command %s " % (host, " ".join(args))
+    log_msg = " rcopy - host %s - command %s " % (str(host), " ".join(map(str, args)))
     
     return _retry_rexec(args, log_msg, env = None, retry = retry, 
             tmp_known_hosts = tmp_known_hosts,
             blocking = True)
 
 def rspawn(command, pidfile, 
-        stdout = '/dev/null', 
-        stderr = STDOUT, 
-        stdin = '/dev/null',
-        home = None, 
-        create_home = False, 
-        sudo = False,
-        host = None, 
-        port = None, 
-        user = None, 
-        agent = None, 
-        identity = None, 
-        server_key = None,
-        tty = False):
+           stdout = '/dev/null', 
+           stderr = STDOUT, 
+           stdin = '/dev/null',
+           home = None, 
+           create_home = False, 
+           sudo = False,
+           host = None, 
+           port = None, 
+           user = None, 
+           gwuser = None,
+           gw = None,
+           agent = None, 
+           identity = None, 
+           server_key = None,
+           tty = False,
+           strict_host_checking = True):
     """
     Spawn a remote command such that it will continue working asynchronously in 
     background. 
@@ -412,7 +458,7 @@ def rspawn(command, pidfile,
         :param sudo: Flag forcing execution with sudo user
         :type sudo: bool
         
-        :rtype: touple
+        :rtype: tuple
 
         (stdout, stderr), process
         
@@ -447,25 +493,31 @@ def rspawn(command, pidfile,
         host = host,
         port = port,
         user = user,
+        gwuser = gwuser,
+        gw = gw,
         agent = agent,
         identity = identity,
         server_key = server_key,
-        tty = tty ,
+        tty = tty,
+        strict_host_checking = strict_host_checking ,
         )
     
     if proc.wait():
-        raise RuntimeError, "Failed to set up application on host %s: %s %s" % (host, out,err,)
+        raise RuntimeError("Failed to set up application on host %s: %s %s" % (host, out,err,))
 
     return ((out, err), proc)
 
 @eintr_retry
 def rgetpid(pidfile,
-        host = None, 
-        port = None, 
-        user = None, 
-        agent = None, 
-        identity = None,
-        server_key = None):
+            host = None, 
+            port = None, 
+            user = None, 
+            gwuser = None,
+            gw = None,
+            agent = None, 
+            identity = None,
+            server_key = None,
+            strict_host_checking = True):
     """
     Returns the pid and ppid of a process from a remote file where the 
     information was stored.
@@ -489,9 +541,12 @@ def rgetpid(pidfile,
         host = host,
         port = port,
         user = user,
+        gwuser = gwuser,
+        gw = gw,
         agent = agent,
         identity = identity,
-        server_key = server_key
+        server_key = server_key,
+        strict_host_checking = strict_host_checking
         )
         
     if proc.wait():
@@ -509,9 +564,12 @@ def rstatus(pid, ppid,
         host = None, 
         port = None, 
         user = None, 
+        gwuser = None,
+        gw = None,
         agent = None, 
         identity = None,
-        server_key = None):
+        server_key = None,
+        strict_host_checking = True):
     """
     Returns a code representing the the status of a remote process
 
@@ -533,9 +591,12 @@ def rstatus(pid, ppid,
         host = host,
         port = port,
         user = user,
+        gwuser = gwuser,
+        gw = gw,
         agent = agent,
         identity = identity,
-        server_key = server_key
+        server_key = server_key,
+        strict_host_checking = strict_host_checking
         )
     
     if proc.wait():
@@ -556,11 +617,14 @@ def rkill(pid, ppid,
         host = None, 
         port = None, 
         user = None, 
+        gwuser = None,
+        gw = None,
         agent = None, 
         sudo = False,
         identity = None, 
         server_key = None, 
-        nowait = False):
+        nowait = False,
+        strict_host_checking = True):
     """
     Sends a kill signal to a remote process.
 
@@ -610,9 +674,12 @@ fi
         host = host,
         port = port,
         user = user,
+        gwuser = gwuser,
+        gw = gw,
         agent = agent,
         identity = identity,
-        server_key = server_key
+        server_key = server_key,
+        strict_host_checking = strict_host_checking
         )
     
     # wait, don't leave zombies around
@@ -621,23 +688,27 @@ fi
     return (out, err), proc
 
 def _retry_rexec(args,
-        log_msg,
-        stdout = subprocess.PIPE,
-        stdin = subprocess.PIPE, 
-        stderr = subprocess.PIPE,
-        env = None,
-        retry = 3,
-        tmp_known_hosts = None,
-        blocking = True):
+                 log_msg,
+                 stdout = subprocess.PIPE,
+                 stdin = subprocess.PIPE, 
+                 stderr = subprocess.PIPE,
+                 env = None,
+                 retry = 3,
+                 tmp_known_hosts = None,
+                 blocking = True):
 
     for x in xrange(retry):
+        # display command actually invoked when debug is turned on
+        message = " ".join( [ "'{}'".format(arg) for arg in args ] )
+        log("sshfuncs: invoking {}".format(message), logging.DEBUG)
         # connects to the remote host and starts a remote connection
-        proc = subprocess.Popen(args,
-                env = env,
-                stdout = stdout,
-                stdin = stdin, 
-                stderr = stderr)
-        
+        proc = subprocess.Popen(
+            args,
+            env = env,
+            stdout = stdout,
+            stdin = stdin, 
+            stderr = stderr,
+        )        
         # attach tempfile object to the process, to make sure the file stays
         # alive until the process is finished with it
         proc._known_hosts = tmp_known_hosts
@@ -647,7 +718,12 @@ def _retry_rexec(args,
         try:
             err = out = " "
             if blocking:
-                (out, err) = proc.communicate()
+                #(out, err) = proc.communicate()
+                # The method communicate was re implemented for performance issues
+                # when using python subprocess communicate method the ssh commands 
+                # last one minute each
+                out, err = _communicate(proc, input=None)
+
             elif stdout:
                 out = proc.stdout.read()
                 if proc.poll() and stderr:
@@ -674,14 +750,161 @@ def _retry_rexec(args,
                     time.sleep(t)
                     continue
             break
-        except RuntimeError, e:
+        except RuntimeError as e:
             msg = " rexec EXCEPTION - TIMEOUT -> %s \n %s" % ( e.args, log_msg )
             log(msg, logging.DEBUG, out, err)
 
             if retry <= 0:
                 raise
             retry -= 1
-        
+
     return ((out, err), proc)
 
+# POSIX
+# Don't remove. The method communicate was re implemented for performance issues
+def _communicate(proc, input, timeout=None, err_on_timeout=True):
+    read_set = []
+    write_set = []
+    stdout = None # Return
+    stderr = None # Return
+
+    killed = False
+
+    if timeout is not None:
+        timelimit = time.time() + timeout
+        killtime = timelimit + 4
+        bailtime = timelimit + 4
+
+    if proc.stdin:
+        # Flush stdio buffer.  This might block, if the user has
+        # been writing to .stdin in an uncontrolled fashion.
+        proc.stdin.flush()
+        if input:
+            write_set.append(proc.stdin)
+        else:
+            proc.stdin.close()
+
+    if proc.stdout:
+        read_set.append(proc.stdout)
+        stdout = []
+
+    if proc.stderr:
+        read_set.append(proc.stderr)
+        stderr = []
+
+    input_offset = 0
+    while read_set or write_set:
+        if timeout is not None:
+            curtime = time.time()
+            if timeout is None or curtime > timelimit:
+                if curtime > bailtime:
+                    break
+                elif curtime > killtime:
+                    signum = signal.SIGKILL
+                else:
+                    signum = signal.SIGTERM
+                # Lets kill it
+                os.kill(proc.pid, signum)
+                select_timeout = 0.5
+            else:
+                select_timeout = timelimit - curtime + 0.1
+        else:
+            select_timeout = 1.0
+
+        if select_timeout > 1.0:
+            select_timeout = 1.0
+
+        try:
+            rlist, wlist, xlist = select.select(read_set, write_set, [], select_timeout)
+        except select.error as e:
+            if e[0] != 4:
+                raise
+            else:
+                continue
+
+        if not rlist and not wlist and not xlist and proc.poll() is not None:
+            # timeout and process exited, say bye
+            break
+
+        if proc.stdin in wlist:
+            # When select has indicated that the file is writable,
+            # we can write up to PIPE_BUF bytes without risk
+            # blocking.  POSIX defines PIPE_BUF >= 512
+            bytes_written = os.write(proc.stdin.fileno(),
+                    buffer(input, input_offset, 512))
+            input_offset += bytes_written
+
+            if input_offset >= len(input):
+                proc.stdin.close()
+                write_set.remove(proc.stdin)
+
+        if proc.stdout in rlist:
+            data = os.read(proc.stdout.fileno(), 1024)
+            if data == "":
+                proc.stdout.close()
+                read_set.remove(proc.stdout)
+            stdout.append(data)
+
+        if proc.stderr in rlist:
+            data = os.read(proc.stderr.fileno(), 1024)
+            if data == "":
+                proc.stderr.close()
+                read_set.remove(proc.stderr)
+            stderr.append(data)
+
+    # All data exchanged.  Translate lists into strings.
+    if stdout is not None:
+        stdout = ''.join(stdout)
+    if stderr is not None:
+        stderr = ''.join(stderr)
+
+    # Translate newlines, if requested.  We cannot let the file
+    # object do the translation: It is based on stdio, which is
+    # impossible to combine with select (unless forcing no
+    # buffering).
+    if proc.universal_newlines and hasattr(file, 'newlines'):
+        if stdout:
+            stdout = proc._translate_newlines(stdout)
+        if stderr:
+            stderr = proc._translate_newlines(stderr)
+
+    if killed and err_on_timeout:
+        errcode = proc.poll()
+        raise RuntimeError("Operation timed out", errcode, stdout, stderr)
+    else:
+        if killed:
+            proc.poll()
+        else:
+            proc.wait()
+        return (stdout, stderr)
+
+def _proxy_command(gw, gwuser, gwidentity):
+    """
+    Constructs the SSH ProxyCommand option to add to the SSH command to connect
+    via a proxy
+        :param gw: SSH proxy hostname
+        :type gw:  str 
+       
+        :param gwuser: SSH proxy username
+        :type gwuser:  str
+
+        :param gwidentity: SSH proxy identity file 
+        :type gwidentity:  str
+
+  
+        :rtype: str 
+        
+        returns the SSH ProxyCommand option.
+    """
+
+    proxycommand = 'ProxyCommand=ssh -q '
+    if gwidentity:
+        proxycommand += '-i %s ' % os.path.expanduser(gwidentity)
+    if gwuser:
+        proxycommand += '%s' % gwuser
+    else:
+        proxycommand += '%r'
+    proxycommand += '@%s -W %%h:%%p' % gw
+
+    return proxycommand