still making both branches closer
[nepi.git] / src / nepi / util / sshfuncs.py
index 176afd3..0d0c17d 100644 (file)
@@ -38,13 +38,11 @@ _re_inet = re.compile("\d+:\s+(?P<name>[a-z0-9_-]+)\s+inet6?\s+(?P<inet>[a-f0-9.
 
 logger = logging.getLogger("sshfuncs")
 
 
 logger = logging.getLogger("sshfuncs")
 
-def log(msg, level, out = None, err = None):
+def log(msg, level = logging.DEBUG, out = None, err = None):
     if out:
         msg += " - OUT: %s " % out
     if out:
         msg += " - OUT: %s " % out
-
     if err:
         msg += " - ERROR: %s " % err
     if err:
         msg += " - ERROR: %s " % err
-
     logger.log(level, msg)
 
 if hasattr(os, "devnull"):
     logger.log(level, msg)
 
 if hasattr(os, "devnull"):
@@ -59,6 +57,7 @@ class STDOUT:
     Special value that when given to rspawn in stderr causes stderr to 
     redirect to whatever stdout was redirected to.
     """
     Special value that when given to rspawn in stderr causes stderr to 
     redirect to whatever stdout was redirected to.
     """
+    pass
 
 class ProcStatus:
     """
 
 class ProcStatus:
     """
@@ -80,8 +79,12 @@ def resolve_hostname(host):
     ip = None
 
     if host in ["localhost", "127.0.0.1", "::1"]:
     ip = None
 
     if host in ["localhost", "127.0.0.1", "::1"]:
-        p = subprocess.Popen("ip -o addr list", shell=True,
-                stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        p = subprocess.Popen(
+            "ip -o addr list",
+            shell=True,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+        )
         stdout, stderr = p.communicate()
         m = _re_inet.findall(stdout)
         ip = m[0][1].split("/")[0]
         stdout, stderr = p.communicate()
         m = _re_inet.findall(stdout)
         ip = m[0][1].split("/")[0]
@@ -117,12 +120,15 @@ def openssh_has_persist():
     """
     global OPENSSH_HAS_PERSIST
     if OPENSSH_HAS_PERSIST is None:
     """
     global OPENSSH_HAS_PERSIST
     if OPENSSH_HAS_PERSIST is None:
-        proc = subprocess.Popen(["ssh","-v"],
-            stdout = subprocess.PIPE,
-            stderr = subprocess.STDOUT,
-            stdin = open("/dev/null","r") )
-        out,err = proc.communicate()
-        proc.wait()
+        with open("/dev/null") as null:
+            proc = subprocess.Popen(
+                ["ssh", "-v"],
+                stdout = subprocess.PIPE,
+                stderr = subprocess.STDOUT,
+                stdin = null,
+            )
+            out,err = proc.communicate()
+            proc.wait()
         
         vre = re.compile(r'OpenSSH_(?:[6-9]|5[.][8-9]|5[.][1-9][0-9]|[1-9][0-9]).*', re.I)
         OPENSSH_HAS_PERSIST = bool(vre.match(out))
         
         vre = re.compile(r'OpenSSH_(?:[6-9]|5[.][8-9]|5[.][1-9][0-9]|[1-9][0-9]).*', re.I)
         OPENSSH_HAS_PERSIST = bool(vre.match(out))
@@ -160,9 +166,8 @@ def make_server_key_args(server_key, host, port):
     if os.environ.get('NEPI_STRICT_AUTH_MODE',"").lower() not in ('1','true','on'):
         user_hosts_path = '%s/.ssh/known_hosts' % (os.environ.get('HOME',""),)
         if os.access(user_hosts_path, os.R_OK):
     if os.environ.get('NEPI_STRICT_AUTH_MODE',"").lower() not in ('1','true','on'):
         user_hosts_path = '%s/.ssh/known_hosts' % (os.environ.get('HOME',""),)
         if os.access(user_hosts_path, os.R_OK):
-            f = open(user_hosts_path, "r")
-            tmp_known_hosts.write(f.read())
-            f.close()
+            with open(user_hosts_path, "r") as f:
+                tmp_known_hosts.write(f.read())
         
     tmp_known_hosts.flush()
     
         
     tmp_known_hosts.flush()
     
@@ -189,12 +194,12 @@ def shell_escape(s):
         return s
     else:
         # unsafe string - escape
         return s
     else:
         # unsafe string - escape
-        def escp(c):
+        def escape(c):
             if (32 <= ord(c) < 127 or c in ('\r','\n','\t')) and c not in ("'",'"'):
                 return c
             else:
                 return "'$'\\x%02x''" % (ord(c),)
             if (32 <= ord(c) < 127 or c in ('\r','\n','\t')) and c not in ("'",'"'):
                 return c
             else:
                 return "'$'\\x%02x''" % (ord(c),)
-        s = ''.join(map(escp,s))
+        s = ''.join(map(escape, s))
         return "'%s'" % (s,)
 
 def eintr_retry(func):
         return "'%s'" % (s,)
 
 def eintr_retry(func):
@@ -203,15 +208,15 @@ def eintr_retry(func):
     @functools.wraps(func)
     def rv(*p, **kw):
         retry = kw.pop("_retry", False)
     @functools.wraps(func)
     def rv(*p, **kw):
         retry = kw.pop("_retry", False)
-        for i in xrange(0 if retry else 4):
+        for i in range(0 if retry else 4):
             try:
                 return func(*p, **kw)
             try:
                 return func(*p, **kw)
-            except (select.error, socket.error), args:
+            except (select.error, socket.error) as args:
                 if args[0] == errno.EINTR:
                     continue
                 else:
                     raise 
                 if args[0] == errno.EINTR:
                     continue
                 else:
                     raise 
-            except OSError, e:
+            except OSError as e:
                 if e.errno == errno.EINTR:
                     continue
                 else:
                 if e.errno == errno.EINTR:
                     continue
                 else:
@@ -303,13 +308,13 @@ def rexec(command, host, user,
         stdout = stderr = stdin = None
 
     return _retry_rexec(args, log_msg, 
         stdout = stderr = stdin = None
 
     return _retry_rexec(args, log_msg, 
-            stderr = stderr,
-            stdin = stdin,
-            stdout = stdout,
-            env = env, 
-            retry = retry, 
-            tmp_known_hosts = tmp_known_hosts,
-            blocking = blocking)
+                        stderr = stderr,
+                        stdin = stdin,
+                        stdout = stdout,
+                        env = env, 
+                        retry = retry, 
+                        tmp_known_hosts = tmp_known_hosts,
+                        blocking = blocking)
 
 def rcopy(source, dest,
         port = None,
 
 def rcopy(source, dest,
         port = None,
@@ -337,7 +342,7 @@ def rcopy(source, dest,
     elif isinstance(source, str) and ':' in source:
         remspec, path = source.split(':',1)
     else:
     elif isinstance(source, str) and ':' in source:
         remspec, path = source.split(':',1)
     else:
-        raise ValueError, "Both endpoints cannot be local"
+        raise ValueError("Both endpoints cannot be local")
     user,host = remspec.rsplit('@',1)
     
     # plain scp
     user,host = remspec.rsplit('@',1)
     
     # plain scp
@@ -403,22 +408,22 @@ def rcopy(source, dest,
             blocking = True)
 
 def rspawn(command, pidfile, 
             blocking = True)
 
 def rspawn(command, pidfile, 
-        stdout = '/dev/null', 
-        stderr = STDOUT, 
-        stdin = '/dev/null',
-        home = None, 
-        create_home = False, 
-        sudo = False,
-        host = None, 
-        port = None, 
-        user = None, 
-        gwuser = None,
-        gw = None,
-        agent = None, 
-        identity = None, 
-        server_key = None,
-        tty = False,
-        strict_host_checking = True):
+           stdout = '/dev/null', 
+           stderr = STDOUT, 
+           stdin = '/dev/null',
+           home = None, 
+           create_home = False, 
+           sudo = False,
+           host = None, 
+           port = None, 
+           user = None, 
+           gwuser = None,
+           gw = None,
+           agent = None, 
+           identity = None, 
+           server_key = None,
+           tty = False,
+           strict_host_checking = True):
     """
     Spawn a remote command such that it will continue working asynchronously in 
     background. 
     """
     Spawn a remote command such that it will continue working asynchronously in 
     background. 
@@ -498,21 +503,21 @@ def rspawn(command, pidfile,
         )
     
     if proc.wait():
         )
     
     if proc.wait():
-        raise RuntimeError, "Failed to set up application on host %s: %s %s" % (host, out,err,)
+        raise RuntimeError("Failed to set up application on host %s: %s %s" % (host, out,err,))
 
     return ((out, err), proc)
 
 @eintr_retry
 def rgetpid(pidfile,
 
     return ((out, err), proc)
 
 @eintr_retry
 def rgetpid(pidfile,
-        host = None, 
-        port = None, 
-        user = None, 
-        gwuser = None,
-        gw = None,
-        agent = None, 
-        identity = None,
-        server_key = None,
-        strict_host_checking = True):
+            host = None, 
+            port = None, 
+            user = None, 
+            gwuser = None,
+            gw = None,
+            agent = None, 
+            identity = None,
+            server_key = None,
+            strict_host_checking = True):
     """
     Returns the pid and ppid of a process from a remote file where the 
     information was stored.
     """
     Returns the pid and ppid of a process from a remote file where the 
     information was stored.
@@ -549,7 +554,7 @@ def rgetpid(pidfile,
     
     if out:
         try:
     
     if out:
         try:
-            return map(int,out.strip().split(' ',1))
+            return [ int(x) for x in out.strip().split(' ',1)) ]
         except:
             # Ignore, many ways to fail that don't matter that much
             return None
         except:
             # Ignore, many ways to fail that don't matter that much
             return None
@@ -683,26 +688,27 @@ fi
     return (out, err), proc
 
 def _retry_rexec(args,
     return (out, err), proc
 
 def _retry_rexec(args,
-        log_msg,
-        stdout = subprocess.PIPE,
-        stdin = subprocess.PIPE, 
-        stderr = subprocess.PIPE,
-        env = None,
-        retry = 3,
-        tmp_known_hosts = None,
-        blocking = True):
-
-    for x in xrange(retry):
+                 log_msg,
+                 stdout = subprocess.PIPE,
+                 stdin = subprocess.PIPE, 
+                 stderr = subprocess.PIPE,
+                 env = None,
+                 retry = 3,
+                 tmp_known_hosts = None,
+                 blocking = True):
+
+    for x in range(retry):
         # display command actually invoked when debug is turned on
         message = " ".join( [ "'{}'".format(arg) for arg in args ] )
         log("sshfuncs: invoking {}".format(message), logging.DEBUG)
         # connects to the remote host and starts a remote connection
         # display command actually invoked when debug is turned on
         message = " ".join( [ "'{}'".format(arg) for arg in args ] )
         log("sshfuncs: invoking {}".format(message), logging.DEBUG)
         # connects to the remote host and starts a remote connection
-        proc = subprocess.Popen(args,
-                env = env,
-                stdout = stdout,
-                stdin = stdin, 
-                stderr = stderr)
-        
+        proc = subprocess.Popen(
+            args,
+            env = env,
+            stdout = stdout,
+            stdin = stdin, 
+            stderr = stderr,
+        )        
         # attach tempfile object to the process, to make sure the file stays
         # alive until the process is finished with it
         proc._known_hosts = tmp_known_hosts
         # attach tempfile object to the process, to make sure the file stays
         # alive until the process is finished with it
         proc._known_hosts = tmp_known_hosts
@@ -716,7 +722,9 @@ def _retry_rexec(args,
                 # The method communicate was re implemented for performance issues
                 # when using python subprocess communicate method the ssh commands 
                 # last one minute each
                 # The method communicate was re implemented for performance issues
                 # when using python subprocess communicate method the ssh commands 
                 # last one minute each
+                #log("BEFORE communicate", level=logging.INFO); import time; beg=time.time()
                 out, err = _communicate(proc, input=None)
                 out, err = _communicate(proc, input=None)
+                #log("AFTER communicate - {}s".format(time.time()-beg), level=logging.INFO)
 
             elif stdout:
                 out = proc.stdout.read()
 
             elif stdout:
                 out = proc.stdout.read()
@@ -744,7 +752,7 @@ def _retry_rexec(args,
                     time.sleep(t)
                     continue
             break
                     time.sleep(t)
                     continue
             break
-        except RuntimeError, e:
+        except RuntimeError as e:
             msg = " rexec EXCEPTION - TIMEOUT -> %s \n %s" % ( e.args, log_msg )
             log(msg, logging.DEBUG, out, err)
 
             msg = " rexec EXCEPTION - TIMEOUT -> %s \n %s" % ( e.args, log_msg )
             log(msg, logging.DEBUG, out, err)
 
@@ -810,7 +818,7 @@ def _communicate(proc, input, timeout=None, err_on_timeout=True):
 
         try:
             rlist, wlist, xlist = select.select(read_set, write_set, [], select_timeout)
 
         try:
             rlist, wlist, xlist = select.select(read_set, write_set, [], select_timeout)
-        except select.error,e:
+        except select.error as e:
             if e[0] != 4:
                 raise
             else:
             if e[0] != 4:
                 raise
             else:
@@ -864,7 +872,7 @@ def _communicate(proc, input, timeout=None, err_on_timeout=True):
 
     if killed and err_on_timeout:
         errcode = proc.poll()
 
     if killed and err_on_timeout:
         errcode = proc.poll()
-        raise RuntimeError("Operation timed out", errcode, stdout, stderr)
+        raise RuntimeError("Operation timed out", errcode, stdout, stderr)
     else:
         if killed:
             proc.poll()
     else:
         if killed:
             proc.poll()