Fix bug strict host checking
[nepi.git] / src / nepi / util / sshfuncs.py
index 458af49..451a7b4 100644 (file)
@@ -46,7 +46,6 @@ def log(msg, level, out = None, err = None):
 
     logger.log(level, msg)
 
 
     logger.log(level, msg)
 
-
 if hasattr(os, "devnull"):
     DEV_NULL = os.devnull
 else:
 if hasattr(os, "devnull"):
     DEV_NULL = os.devnull
 else:
@@ -206,80 +205,10 @@ def eintr_retry(func):
             return func(*p, **kw)
     return rv
 
             return func(*p, **kw)
     return rv
 
-def socat(local_socket_name, remote_socket_name,
-        host, user,
-        port = None, 
-        agent = True,
-        sudo = False,
-        identity = None,
-        server_key = None,
-        env = None,
-        tty = False,
-        connect_timeout = 30,
-        retry = 3,
-        strict_host_checking = True):
-    """
-    Executes a remote command, returns ((stdout,stderr),process)
-    """
-    
-    tmp_known_hosts = None
-    hostip = gethostbyname(host)
-
-
-    args = ["socat"]
-    args.append("UNIX-LISTEN:%s,unlink-early,fork" % local_socket_name)
-
-    ssh_args = ['ssh', '-C',
-            # Don't bother with localhost. Makes test easier
-            '-o', 'NoHostAuthenticationForLocalhost=yes',
-            '-o', 'ConnectTimeout=%d' % (int(connect_timeout),),
-            '-o', 'ConnectionAttempts=3',
-            '-o', 'ServerAliveInterval=30',
-            '-o', 'TCPKeepAlive=yes',
-            '-l', user, hostip or host]
-
-    if not strict_host_checking:
-        # Do not check for Host key. Unsafe.
-        ssh_args.extend(['-o', 'StrictHostKeyChecking=no'])
-
-    if agent:
-        ssh_args.append('-A')
-
-    if port:
-        ssh_args.append('-p%d' % port)
-
-    if identity:
-        ssh_args.extend(('-i', identity))
-
-    if tty:
-        ssh_args.append('-t')
-        ssh_args.append('-t')
-
-    if server_key:
-        # Create a temporary server key file
-        tmp_known_hosts = make_server_key_args(server_key, host, port)
-        ssh_args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
-
-    ssh_cmd = " ".join(ssh_args)
-
-    exec_cmd = "EXEC:'%s socat STDIO UNIX-CONNECT\:%s'" % (ssh_cmd, 
-            remote_socket_name)
-
-    args.append(exec_cmd)
-    
-    log_msg = " socat - host %s - command %s " % (host, " ".join(args))
-
-    return _retry_rexec(args, log_msg, 
-            stdout = None,
-            stdin = None,
-            stderr = None,
-            env = env, 
-            retry = retry, 
-            tmp_known_hosts = tmp_known_hosts,
-            blocking = False)
-
 def rexec(command, host, user, 
 def rexec(command, host, user, 
-        port = None, 
+        port = None,
+        gwuser = None,
+        gw = None, 
         agent = True,
         sudo = False,
         identity = None,
         agent = True,
         sudo = False,
         identity = None,
@@ -295,9 +224,11 @@ def rexec(command, host, user,
     """
     Executes a remote command, returns ((stdout,stderr),process)
     """
     """
     Executes a remote command, returns ((stdout,stderr),process)
     """
-    
+
     tmp_known_hosts = None
     tmp_known_hosts = None
-    hostip = gethostbyname(host)
+    if not gw:
+        hostip = gethostbyname(host)
+    else: hostip = None
 
     args = ['ssh', '-C',
             # Don't bother with localhost. Makes test easier
 
     args = ['ssh', '-C',
             # Don't bother with localhost. Makes test easier
@@ -306,6 +237,7 @@ def rexec(command, host, user,
             '-o', 'ConnectionAttempts=3',
             '-o', 'ServerAliveInterval=30',
             '-o', 'TCPKeepAlive=yes',
             '-o', 'ConnectionAttempts=3',
             '-o', 'ServerAliveInterval=30',
             '-o', 'TCPKeepAlive=yes',
+            '-o', 'Batchmode=yes',
             '-l', user, hostip or host]
 
     if persistent and openssh_has_persist():
             '-l', user, hostip or host]
 
     if persistent and openssh_has_persist():
@@ -318,6 +250,13 @@ def rexec(command, host, user,
         # Do not check for Host key. Unsafe.
         args.extend(['-o', 'StrictHostKeyChecking=no'])
 
         # Do not check for Host key. Unsafe.
         args.extend(['-o', 'StrictHostKeyChecking=no'])
 
+    if gw:
+        if gwuser:
+            proxycommand = 'ProxyCommand=ssh -q %s@%s -W %%h:%%p' % (gwuser, gw)
+        else:
+            proxycommand = 'ProxyCommand=ssh -q %%r@%s -W %%h:%%p' % gw
+        args.extend(['-o', proxycommand])
+
     if agent:
         args.append('-A')
 
     if agent:
         args.append('-A')
 
@@ -325,6 +264,7 @@ def rexec(command, host, user,
         args.append('-p%d' % port)
 
     if identity:
         args.append('-p%d' % port)
 
     if identity:
+        identity = os.path.expanduser(identity)
         args.extend(('-i', identity))
 
     if tty:
         args.extend(('-i', identity))
 
     if tty:
@@ -343,8 +283,8 @@ def rexec(command, host, user,
         command = "sudo " + command
 
     args.append(command)
         command = "sudo " + command
 
     args.append(command)
-    
-    log_msg = " rexec - host %s - command %s " % (host, " ".join(args))
+
+    log_msg = " rexec - host %s - command %s " % (str(host), " ".join(map(str, args))) 
 
     stdout = stderr = stdin = subprocess.PIPE
     if forward_x11:
 
     stdout = stderr = stdin = subprocess.PIPE
     if forward_x11:
@@ -360,8 +300,9 @@ def rexec(command, host, user,
             blocking = blocking)
 
 def rcopy(source, dest,
             blocking = blocking)
 
 def rcopy(source, dest,
-        port = None, 
-        agent = True, 
+        port = None,
+        gwuser = None,
+        gw = None,
         recursive = False,
         identity = None,
         server_key = None,
         recursive = False,
         identity = None,
         server_key = None,
@@ -373,14 +314,15 @@ def rcopy(source, dest,
     Source and destination should have the user and host encoded
     as per scp specs.
     
     Source and destination should have the user and host encoded
     as per scp specs.
     
-    Source can be a list of files to copy to a single destination,
-    in which case it is advised that the destination be a folder.
+    Source can be a list of files to copy to a single destination, 
+    (in which case it is advised that the destination be a folder),
+    or a single file in a string.
     """
     """
-    
+
     # Parse destination as <user>@<server>:<path>
     # Parse destination as <user>@<server>:<path>
-    if isinstance(dest, basestring) and ':' in dest:
+    if isinstance(dest, str) and ':' in dest:
         remspec, path = dest.split(':',1)
         remspec, path = dest.split(':',1)
-    elif isinstance(source, basestring) and ':' in source:
+    elif isinstance(source, str) and ':' in source:
         remspec, path = source.split(':',1)
     else:
         raise ValueError, "Both endpoints cannot be local"
         remspec, path = source.split(':',1)
     else:
         raise ValueError, "Both endpoints cannot be local"
@@ -403,10 +345,18 @@ def rcopy(source, dest,
     if port:
         args.append('-P%d' % port)
 
     if port:
         args.append('-P%d' % port)
 
+    if gw:
+        if gwuser:
+            proxycommand = 'ProxyCommand=ssh -q %s@%s -W %%h:%%p' % (gwuser, gw)
+        else:
+            proxycommand = 'ProxyCommand=ssh -q %%r@%s -W %%h:%%p' % gw
+        args.extend(['-o', proxycommand])
+
     if recursive:
         args.append('-r')
 
     if identity:
     if recursive:
         args.append('-r')
 
     if identity:
+        identity = os.path.expanduser(identity)
         args.extend(('-i', identity))
 
     if server_key:
         args.extend(('-i', identity))
 
     if server_key:
@@ -417,20 +367,23 @@ def rcopy(source, dest,
     if not strict_host_checking:
         # Do not check for Host key. Unsafe.
         args.extend(['-o', 'StrictHostKeyChecking=no'])
     if not strict_host_checking:
         # Do not check for Host key. Unsafe.
         args.extend(['-o', 'StrictHostKeyChecking=no'])
-
+    
     if isinstance(source, list):
         args.extend(source)
     else:
         if openssh_has_persist():
             args.extend([
                 '-o', 'ControlMaster=auto',
     if isinstance(source, list):
         args.extend(source)
     else:
         if openssh_has_persist():
             args.extend([
                 '-o', 'ControlMaster=auto',
-                '-o', 'ControlPath=%s' % (make_control_path(agent, False),)
+                '-o', 'ControlPath=%s' % (make_control_path(False, False),)
                 ])
         args.append(source)
 
                 ])
         args.append(source)
 
-    args.append(dest)
+    if isinstance(dest, list):
+        args.extend(dest)
+    else:
+        args.append(dest)
 
 
-    log_msg = " rcopy - host %s - command %s " % (host, " ".join(args))
+    log_msg = " rcopy - host %s - command %s " % (str(host), " ".join(map(str, args)))
     
     return _retry_rexec(args, log_msg, env = None, retry = retry, 
             tmp_known_hosts = tmp_known_hosts,
     
     return _retry_rexec(args, log_msg, env = None, retry = retry, 
             tmp_known_hosts = tmp_known_hosts,
@@ -446,10 +399,13 @@ def rspawn(command, pidfile,
         host = None, 
         port = None, 
         user = None, 
         host = None, 
         port = None, 
         user = None, 
+        gwuser = None,
+        gw = None,
         agent = None, 
         identity = None, 
         server_key = None,
         agent = None, 
         identity = None, 
         server_key = None,
-        tty = False):
+        tty = False,
+        strict_host_checking = True):
     """
     Spawn a remote command such that it will continue working asynchronously in 
     background. 
     """
     Spawn a remote command such that it will continue working asynchronously in 
     background. 
@@ -519,10 +475,13 @@ def rspawn(command, pidfile,
         host = host,
         port = port,
         user = user,
         host = host,
         port = port,
         user = user,
+        gwuser = gwuser,
+        gw = gw,
         agent = agent,
         identity = identity,
         server_key = server_key,
         agent = agent,
         identity = identity,
         server_key = server_key,
-        tty = tty ,
+        tty = tty,
+        strict_host_checking = strict_host_checking ,
         )
     
     if proc.wait():
         )
     
     if proc.wait():
@@ -535,9 +494,12 @@ def rgetpid(pidfile,
         host = None, 
         port = None, 
         user = None, 
         host = None, 
         port = None, 
         user = None, 
+        gwuser = None,
+        gw = None,
         agent = None, 
         identity = None,
         agent = None, 
         identity = None,
-        server_key = None):
+        server_key = None,
+        strict_host_checking = True):
     """
     Returns the pid and ppid of a process from a remote file where the 
     information was stored.
     """
     Returns the pid and ppid of a process from a remote file where the 
     information was stored.
@@ -561,9 +523,12 @@ def rgetpid(pidfile,
         host = host,
         port = port,
         user = user,
         host = host,
         port = port,
         user = user,
+        gwuser = gwuser,
+        gw = gw,
         agent = agent,
         identity = identity,
         agent = agent,
         identity = identity,
-        server_key = server_key
+        server_key = server_key,
+        strict_host_checking = strict_host_checking
         )
         
     if proc.wait():
         )
         
     if proc.wait():
@@ -581,9 +546,12 @@ def rstatus(pid, ppid,
         host = None, 
         port = None, 
         user = None, 
         host = None, 
         port = None, 
         user = None, 
+        gwuser = None,
+        gw = None,
         agent = None, 
         identity = None,
         agent = None, 
         identity = None,
-        server_key = None):
+        server_key = None,
+        strict_host_checking = True):
     """
     Returns a code representing the the status of a remote process
 
     """
     Returns a code representing the the status of a remote process
 
@@ -605,9 +573,12 @@ def rstatus(pid, ppid,
         host = host,
         port = port,
         user = user,
         host = host,
         port = port,
         user = user,
+        gwuser = gwuser,
+        gw = gw,
         agent = agent,
         identity = identity,
         agent = agent,
         identity = identity,
-        server_key = server_key
+        server_key = server_key,
+        strict_host_checking = strict_host_checking
         )
     
     if proc.wait():
         )
     
     if proc.wait():
@@ -628,11 +599,14 @@ def rkill(pid, ppid,
         host = None, 
         port = None, 
         user = None, 
         host = None, 
         port = None, 
         user = None, 
+        gwuser = None,
+        gw = None,
         agent = None, 
         sudo = False,
         identity = None, 
         server_key = None, 
         agent = None, 
         sudo = False,
         identity = None, 
         server_key = None, 
-        nowait = False):
+        nowait = False,
+        strict_host_checking = True):
     """
     Sends a kill signal to a remote process.
 
     """
     Sends a kill signal to a remote process.
 
@@ -682,9 +656,12 @@ fi
         host = host,
         port = port,
         user = user,
         host = host,
         port = port,
         user = user,
+        gwuser = gwuser,
+        gw = gw,
         agent = agent,
         identity = identity,
         agent = agent,
         identity = identity,
-        server_key = server_key
+        server_key = server_key,
+        strict_host_checking = strict_host_checking
         )
     
     # wait, don't leave zombies around
         )
     
     # wait, don't leave zombies around
@@ -719,7 +696,12 @@ def _retry_rexec(args,
         try:
             err = out = " "
             if blocking:
         try:
             err = out = " "
             if blocking:
-                (out, err) = proc.communicate()
+                #(out, err) = proc.communicate()
+                # The method communicate was re implemented for performance issues
+                # when using python subprocess communicate method the ssh commands 
+                # last one minute each
+                out, err = _communicate(proc, input=None)
+
             elif stdout:
                 out = proc.stdout.read()
                 if proc.poll() and stderr:
             elif stdout:
                 out = proc.stdout.read()
                 if proc.poll() and stderr:
@@ -753,7 +735,125 @@ def _retry_rexec(args,
             if retry <= 0:
                 raise
             retry -= 1
             if retry <= 0:
                 raise
             retry -= 1
-        
+
     return ((out, err), proc)
 
     return ((out, err), proc)
 
+# POSIX
+# Don't remove. The method communicate was re implemented for performance issues
+def _communicate(proc, input, timeout=None, err_on_timeout=True):
+    read_set = []
+    write_set = []
+    stdout = None # Return
+    stderr = None # Return
+
+    killed = False
+
+    if timeout is not None:
+        timelimit = time.time() + timeout
+        killtime = timelimit + 4
+        bailtime = timelimit + 4
+
+    if proc.stdin:
+        # Flush stdio buffer.  This might block, if the user has
+        # been writing to .stdin in an uncontrolled fashion.
+        proc.stdin.flush()
+        if input:
+            write_set.append(proc.stdin)
+        else:
+            proc.stdin.close()
+
+    if proc.stdout:
+        read_set.append(proc.stdout)
+        stdout = []
+
+    if proc.stderr:
+        read_set.append(proc.stderr)
+        stderr = []
+
+    input_offset = 0
+    while read_set or write_set:
+        if timeout is not None:
+            curtime = time.time()
+            if timeout is None or curtime > timelimit:
+                if curtime > bailtime:
+                    break
+                elif curtime > killtime:
+                    signum = signal.SIGKILL
+                else:
+                    signum = signal.SIGTERM
+                # Lets kill it
+                os.kill(proc.pid, signum)
+                select_timeout = 0.5
+            else:
+                select_timeout = timelimit - curtime + 0.1
+        else:
+            select_timeout = 1.0
+
+        if select_timeout > 1.0:
+            select_timeout = 1.0
+
+        try:
+            rlist, wlist, xlist = select.select(read_set, write_set, [], select_timeout)
+        except select.error,e:
+            if e[0] != 4:
+                raise
+            else:
+                continue
+
+        if not rlist and not wlist and not xlist and proc.poll() is not None:
+            # timeout and process exited, say bye
+            break
+
+        if proc.stdin in wlist:
+            # When select has indicated that the file is writable,
+            # we can write up to PIPE_BUF bytes without risk
+            # blocking.  POSIX defines PIPE_BUF >= 512
+            bytes_written = os.write(proc.stdin.fileno(),
+                    buffer(input, input_offset, 512))
+            input_offset += bytes_written
+
+            if input_offset >= len(input):
+                proc.stdin.close()
+                write_set.remove(proc.stdin)
+
+        if proc.stdout in rlist:
+            data = os.read(proc.stdout.fileno(), 1024)
+            if data == "":
+                proc.stdout.close()
+                read_set.remove(proc.stdout)
+            stdout.append(data)
+
+        if proc.stderr in rlist:
+            data = os.read(proc.stderr.fileno(), 1024)
+            if data == "":
+                proc.stderr.close()
+                read_set.remove(proc.stderr)
+            stderr.append(data)
+
+    # All data exchanged.  Translate lists into strings.
+    if stdout is not None:
+        stdout = ''.join(stdout)
+    if stderr is not None:
+        stderr = ''.join(stderr)
+
+    # Translate newlines, if requested.  We cannot let the file
+    # object do the translation: It is based on stdio, which is
+    # impossible to combine with select (unless forcing no
+    # buffering).
+    if proc.universal_newlines and hasattr(file, 'newlines'):
+        if stdout:
+            stdout = proc._translate_newlines(stdout)
+        if stderr:
+            stderr = proc._translate_newlines(stderr)
+
+    if killed and err_on_timeout:
+        errcode = proc.poll()
+        raise RuntimeError, ("Operation timed out", errcode, stdout, stderr)
+    else:
+        if killed:
+            proc.poll()
+        else:
+            proc.wait()
+        return (stdout, stderr)
+