SSHApi functionality migrated to LinuxNode
[nepi.git] / src / neco / util / sshfuncs.py
index 872143d..e81c558 100644 (file)
@@ -12,8 +12,7 @@ import re
 import tempfile
 import hashlib
 
-OPENSSH_HAS_PERSIST = None
-CONTROL_PATH = "yyyyy_ssh_control_path"
+TRACE = os.environ.get("NEPI_TRACE", "false").lower() in ("true", "1", "on")
 
 if hasattr(os, "devnull"):
     DEV_NULL = os.devnull
@@ -22,8 +21,6 @@ else:
 
 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
 
-hostbyname_cache = dict()
-
 class STDOUT: 
     """
     Special value that when given to rspawn in stderr causes stderr to 
@@ -45,6 +42,17 @@ class NOT_STARTED:
     Process hasn't started running yet (this should be very rare)
     """
 
+hostbyname_cache = dict()
+
+def gethostbyname(host):
+    hostbyname = hostbyname_cache.get(host)
+    if not hostbyname:
+        hostbyname = socket.gethostbyname(host)
+        hostbyname_cache[host] = hostbyname
+    return hostbyname
+
+OPENSSH_HAS_PERSIST = None
+
 def openssh_has_persist():
     """ The ssh_config options ControlMaster and ControlPersist allow to
     reuse a same network connection for multiple ssh sessions. In this 
@@ -66,6 +74,59 @@ def openssh_has_persist():
         OPENSSH_HAS_PERSIST = bool(vre.match(out))
     return OPENSSH_HAS_PERSIST
 
+def make_server_key_args(server_key, host, port):
+    """ Returns a reference to a temporary known_hosts file, to which 
+    the server key has been added. 
+    
+    Make sure to hold onto the temp file reference until the process is 
+    done with it
+
+    :param server_key: the server public key
+    :type server_key: str
+
+    :param host: the hostname
+    :type host: str
+
+    :param port: the ssh port
+    :type port: str
+
+    """
+    if port is not None:
+        host = '%s:%s' % (host, str(port))
+
+    # Create a temporary server key file
+    tmp_known_hosts = tempfile.NamedTemporaryFile()
+   
+    hostbyname = gethostbyname(host) 
+
+    # Add the intended host key
+    tmp_known_hosts.write('%s,%s %s\n' % (host, hostbyname, server_key))
+    
+    # If we're not in strict mode, add user-configured keys
+    if os.environ.get('NEPI_STRICT_AUTH_MODE',"").lower() not in ('1','true','on'):
+        user_hosts_path = '%s/.ssh/known_hosts' % (os.environ.get('HOME',""),)
+        if os.access(user_hosts_path, os.R_OK):
+            f = open(user_hosts_path, "r")
+            tmp_known_hosts.write(f.read())
+            f.close()
+        
+    tmp_known_hosts.flush()
+    
+    return tmp_known_hosts
+
+def make_control_path(agent, forward_x11):
+    ctrl_path = "/tmp/nepi_ssh"
+
+    if agent:
+        ctrl_path +="_a"
+
+    if forward_x11:
+        ctrl_path +="_x"
+
+    ctrl_path += "-%r@%h:%p"
+
+    return ctrl_path
+
 def shell_escape(s):
     """ Escapes strings so that they are safe to use as command-line 
     arguments """
@@ -105,58 +166,64 @@ def eintr_retry(func):
             return func(*p, **kw)
     return rv
 
-def make_connkey(user, host, port):
-    connkey = repr((user,host,port)).encode("base64").strip().replace('/','.')
-    if len(connkey) > 60:
-        connkey = hashlib.sha1(connkey).hexdigest()
-    return connkey
-
 def rexec(command, host, user, 
         port = None, 
         agent = True,
         sudo = False,
-        stdin = "", 
-        identity_file = None,
+        stdin = None,
+        identity = None,
+        server_key = None,
+        env = None,
         tty = False,
         timeout = None,
         retry = 0,
         err_on_timeout = True,
         connect_timeout = 30,
-        persistent = True):
+        persistent = True,
+        forward_x11 = False):
     """
     Executes a remote command, returns ((stdout,stderr),process)
     """
-    connkey = make_connkey(user, host, port)
+    
+    tmp_known_hosts = None
+    hostip = gethostbyname(host)
+
     args = ['ssh', '-C',
             # Don't bother with localhost. Makes test easier
             '-o', 'NoHostAuthenticationForLocalhost=yes',
-            # XXX: Possible security issue
-            # Avoid interactive requests to accept new host keys
-            '-o', 'StrictHostKeyChecking=no',
             '-o', 'ConnectTimeout=%d' % (int(connect_timeout),),
             '-o', 'ConnectionAttempts=3',
             '-o', 'ServerAliveInterval=30',
             '-o', 'TCPKeepAlive=yes',
-            '-l', user, host]
+            '-l', user, hostip or host]
 
     if persistent and openssh_has_persist():
         args.extend([
             '-o', 'ControlMaster=auto',
-            '-o', 'ControlPath=/tmp/%s_%s' % ( CONTROL_PATH, connkey, ),
+            '-o', 'ControlPath=%s' % (make_control_path(agent, forward_x11),),
             '-o', 'ControlPersist=60' ])
+
     if agent:
         args.append('-A')
+
     if port:
         args.append('-p%d' % port)
-    if identity_file:
-        args.extend(('-i', identity_file))
+
+    if identity:
+        args.extend(('-i', identity))
+
     if tty:
         args.append('-t')
-        if sudo:
-            args.append('-t')
+        args.append('-t')
+
+    if forward_x11:
+        args.append('-X')
+
+    if server_key:
+        # Create a temporary server key file
+        tmp_known_hosts = make_server_key_args(server_key, host, port)
+        args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
 
-    if sudo:
-        command = "sudo " + command
     args.append(command)
 
     for x in xrange(retry or 3):
@@ -166,8 +233,15 @@ def rexec(command, host, user,
                 stdin = subprocess.PIPE, 
                 stderr = subprocess.PIPE)
         
+        # attach tempfile object to the process, to make sure the file stays
+        # alive until the process is finished with it
+        proc._known_hosts = tmp_known_hosts
+    
         try:
             out, err = _communicate(proc, stdin, timeout, err_on_timeout)
+            if TRACE:
+                print "COMMAND host %s, command %s, out %s, error %s" % (host, " ".join(args), out, err)
+
             if proc.poll():
                 if err.strip().startswith('ssh: ') or err.strip().startswith('mux_client_hello_exchange: '):
                     # SSH error, can safely retry
@@ -176,20 +250,25 @@ def rexec(command, host, user,
                     # Probably timed out or plain failed but can retry
                     continue
             break
-        except RuntimeError,e:
+        except RuntimeError, e:
+            if TRACE:
+                print "EXCEPTION host %s, command %s, out %s, error %s, exception TIMEOUT ->  %s" % (
+                        host, " ".join(args), out, err, e.args)
+
             if retry <= 0:
                 raise
             retry -= 1
         
     return ((out, err), proc)
 
-def rcopy(source, dest, host, user,
+def rcopy(source, dest,
         port = None, 
         agent = True, 
         recursive = False,
-        identity_file = None):
+        identity = None,
+        server_key = None):
     """
-    Copies file from/to remote sites.
+    Copies from/to remote sites.
     
     Source and destination should have the user and host encoded
     as per scp specs.
@@ -206,9 +285,11 @@ def rcopy(source, dest, host, user,
     in which case it is advised that the destination be a folder.
     """
     
+    if TRACE:
+        print "scp", source, dest
+    
     if isinstance(source, file) and source.tell() == 0:
         source = source.name
-
     elif hasattr(source, 'read'):
         tmp = tempfile.NamedTemporaryFile()
         while True:
@@ -224,32 +305,48 @@ def rcopy(source, dest, host, user,
             or hasattr(source, 'read')  or hasattr(dest, 'write'):
         assert not recursive
         
-        connkey = make_connkey(user,host,port)
+        # Parse source/destination as <user>@<server>:<path>
+        if isinstance(dest, basestring) and ':' in dest:
+            remspec, path = dest.split(':',1)
+        elif isinstance(source, basestring) and ':' in source:
+            remspec, path = source.split(':',1)
+        else:
+            raise ValueError, "Both endpoints cannot be local"
+        user,host = remspec.rsplit('@',1)
+        
+        tmp_known_hosts = None
+        hostip = gethostbyname(host)
+        
         args = ['ssh', '-l', user, '-C',
                 # Don't bother with localhost. Makes test easier
                 '-o', 'NoHostAuthenticationForLocalhost=yes',
-                # XXX: Possible security issue
-                # Avoid interactive requests to accept new host keys
-                '-o', 'StrictHostKeyChecking=no',
-                '-o', 'ConnectTimeout=30',
+                '-o', 'ConnectTimeout=60',
                 '-o', 'ConnectionAttempts=3',
                 '-o', 'ServerAliveInterval=30',
                 '-o', 'TCPKeepAlive=yes',
-                host ]
+                hostip or host ]
+
         if openssh_has_persist():
             args.extend([
                 '-o', 'ControlMaster=auto',
-                '-o', 'ControlPath=/tmp/%s_%s' % ( CONTROL_PATH, connkey, ),
+                '-o', 'ControlPath=%s' % (make_control_path(agent, False),),
                 '-o', 'ControlPersist=60' ])
+
         if port:
             args.append('-P%d' % port)
-        if identity_file:
-            args.extend(('-i', identity_file))
+
+        if identity:
+            args.extend(('-i', identity))
+
+        if server_key:
+            # Create a temporary server key file
+            tmp_known_hosts = make_server_key_args(server_key, host, port)
+            args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
         
         if isinstance(source, file) or hasattr(source, 'read'):
-            args.append('cat > %s' % dest)
+            args.append('cat > %s' % (shell_escape(path),))
         elif isinstance(dest, file) or hasattr(dest, 'write'):
-            args.append('cat %s' % dest)
+            args.append('cat %s' % (shell_escape(path),))
         else:
             raise AssertionError, "Unreachable code reached! :-Q"
         
@@ -260,6 +357,7 @@ def rcopy(source, dest, host, user,
                     stderr = subprocess.PIPE,
                     stdin = source)
             err = proc.stderr.read()
+            proc._known_hosts = tmp_known_hosts
             eintr_retry(proc.wait)()
             return ((None,err), proc)
         elif isinstance(dest, file):
@@ -268,6 +366,7 @@ def rcopy(source, dest, host, user,
                     stderr = subprocess.PIPE,
                     stdin = source)
             err = proc.stderr.read()
+            proc._known_hosts = tmp_known_hosts
             eintr_retry(proc.wait)()
             return ((None,err), proc)
         elif hasattr(source, 'read'):
@@ -304,6 +403,7 @@ def rcopy(source, dest, host, user,
             proc.stdin.close()
             err.append(proc.stderr.read())
                 
+            proc._known_hosts = tmp_known_hosts
             eintr_retry(proc.wait)()
             return ((None,''.join(err)), proc)
         elif hasattr(dest, 'write'):
@@ -338,50 +438,68 @@ def rcopy(source, dest, host, user,
                     break
             err.append(proc.stderr.read())
                 
+            proc._known_hosts = tmp_known_hosts
             eintr_retry(proc.wait)()
             return ((None,''.join(err)), proc)
         else:
             raise AssertionError, "Unreachable code reached! :-Q"
     else:
+        # Parse destination as <user>@<server>:<path>
+        if isinstance(dest, basestring) and ':' in dest:
+            remspec, path = dest.split(':',1)
+        elif isinstance(source, basestring) and ':' in source:
+            remspec, path = source.split(':',1)
+        else:
+            raise ValueError, "Both endpoints cannot be local"
+        user,host = remspec.rsplit('@',1)
+        
         # plain scp
+        tmp_known_hosts = None
+
         args = ['scp', '-q', '-p', '-C',
                 # Don't bother with localhost. Makes test easier
                 '-o', 'NoHostAuthenticationForLocalhost=yes',
-                # XXX: Possible security issue
-                # Avoid interactive requests to accept new host keys
-                '-o', 'StrictHostKeyChecking=no',
-                '-o', 'ConnectTimeout=30',
+                '-o', 'ConnectTimeout=60',
                 '-o', 'ConnectionAttempts=3',
                 '-o', 'ServerAliveInterval=30',
                 '-o', 'TCPKeepAlive=yes' ]
                 
         if port:
             args.append('-P%d' % port)
+
         if recursive:
             args.append('-r')
-        if identity_file:
-            args.extend(('-i', identity_file))
+
+        if identity:
+            args.extend(('-i', identity))
+
+        if server_key:
+            # Create a temporary server key file
+            tmp_known_hosts = make_server_key_args(server_key, host, port)
+            args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
 
         if isinstance(source,list):
             args.extend(source)
         else:
             if openssh_has_persist():
-                connkey = make_connkey(user,host,port)
                 args.extend([
-                    '-o', 'ControlMaster=no',
-                    '-o', 'ControlPath=/tmp/%s_%s' % ( CONTROL_PATH, connkey, )])
+                    '-o', 'ControlMaster=auto',
+                    '-o', 'ControlPath=%s' % (make_control_path(agent, False),)
+                    ])
             args.append(source)
-        args.append("%s@%s:%s" %(user, host, dest))
+
+        args.append(dest)
 
         # connects to the remote host and starts a remote connection
         proc = subprocess.Popen(args, 
                 stdout = subprocess.PIPE,
                 stdin = subprocess.PIPE, 
                 stderr = subprocess.PIPE)
+        proc._known_hosts = tmp_known_hosts
         
-        comm = proc.communicate()
+        (out, err) = proc.communicate()
         eintr_retry(proc.wait)()
-        return (comm, proc)
+        return ((out, err), proc)
 
 def rspawn(command, pidfile, 
         stdout = '/dev/null', 
@@ -389,12 +507,13 @@ def rspawn(command, pidfile,
         stdin = '/dev/null', 
         home = None, 
         create_home = False, 
+        sudo = False,
         host = None, 
         port = None, 
         user = None, 
         agent = None, 
-        sudo = False,
-        identity_file = None, 
+        identity = None, 
+        server_key = None,
         tty = False):
     """
     Spawn a remote command such that it will continue working asynchronously.
@@ -416,7 +535,7 @@ def rspawn(command, pidfile,
         
         sudo: whether the command needs to be executed as root
         
-        host/port/user/agent/identity_file: see rexec
+        host/port/user/agent/identity: see rexec
     
     Returns:
         (stdout, stderr), process
@@ -433,52 +552,51 @@ def rspawn(command, pidfile,
     
     daemon_command = '{ { %(command)s  > %(stdout)s 2>%(stderr)s < %(stdin)s & } ; echo $! 1 > %(pidfile)s ; }' % {
         'command' : command,
-        'pidfile' : pidfile,
-        
+        'pidfile' : shell_escape(pidfile),
         'stdout' : stdout,
         'stderr' : stderr,
         'stdin' : stdin,
     }
     
-    cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c '%(command)s' " % {
-            'command' : daemon_command,
-            
+    cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c %(command)s " % {
+            'command' : shell_escape(daemon_command),
             'sudo' : 'sudo -S' if sudo else '',
-            
-            'pidfile' : pidfile,
-            'gohome' : 'cd %s ; ' % home if home else '',
-            'create' : 'mkdir -p %s ; ' % home if create_home else '',
+            'pidfile' : shell_escape(pidfile),
+            'gohome' : 'cd %s ; ' % (shell_escape(home),) if home else '',
+            'create' : 'mkdir -p %s ; ' % (shell_escape(home),) if create_home and home else '',
         }
 
-    (out,err), proc = rexec(
+    (out,err),proc = rexec(
         cmd,
         host = host,
         port = port,
         user = user,
         agent = agent,
-        identity_file = identity_file,
-        tty = tty
+        identity = identity,
+        server_key = server_key,
+        tty = tty ,
         )
     
     if proc.wait():
-        raise RuntimeError, "Failed to set up application: %s %s" % (out,err,)
+        raise RuntimeError, "Failed to set up application on host %s: %s %s" % (host, out,err,)
 
-    return (out,err),proc
+    return ((out, err), proc)
 
 @eintr_retry
-def rcheck_pid(pidfile,
+def rcheckpid(pidfile,
         host = None, 
         port = None, 
         user = None, 
         agent = None, 
-        identity_file = None):
+        identity = None,
+        server_key = None):
     """
     Check the pidfile of a process spawned with remote_spawn.
     
     Parameters:
         pidfile: the pidfile passed to remote_span
         
-        host/port/user/agent/identity_file: see rexec
+        host/port/user/agent/identity: see rexec
     
     Returns:
         
@@ -494,7 +612,8 @@ def rcheck_pid(pidfile,
         port = port,
         user = user,
         agent = agent,
-        identity_file = identity_file
+        identity = identity,
+        server_key = server_key
         )
         
     if proc.wait():
@@ -513,14 +632,15 @@ def rstatus(pid, ppid,
         port = None, 
         user = None, 
         agent = None, 
-        identity_file = None):
+        identity = None,
+        server_key = None):
     """
     Check the status of a process spawned with remote_spawn.
     
     Parameters:
         pid/ppid: pid and parent-pid of the spawned process. See remote_check_pid
         
-        host/port/user/agent/identity_file: see rexec
+        host/port/user/agent/identity: see rexec
     
     Returns:
         
@@ -528,7 +648,8 @@ def rstatus(pid, ppid,
     """
 
     (out,err),proc = rexec(
-        "ps --pid %(pid)d -o pid | grep -c %(pid)d ; true" % {
+        # Check only by pid. pid+ppid does not always work (especially with sudo) 
+        " (( ps --pid %(pid)d -o pid | grep -c %(pid)d && echo 'wait')  || echo 'done' ) | tail -n 1" % {
             'ppid' : ppid,
             'pid' : pid,
         },
@@ -536,32 +657,32 @@ def rstatus(pid, ppid,
         port = port,
         user = user,
         agent = agent,
-        identity_file = identity_file
+        identity = identity,
+        server_key = server_key
         )
     
     if proc.wait():
         return NOT_STARTED
     
     status = False
-    if out:
-        try:
-            status = bool(int(out.strip()))
-        except:
-            if out or err:
-                logging.warn("Error checking remote status:\n%s%s\n", out, err)
-            # Ignore, many ways to fail that don't matter that much
-            return NOT_STARTED
+    if err:
+        if err.strip().find("Error, do this: mount -t proc none /proc") >= 0:
+            status = True
+    elif out:
+        status = (out.strip() == 'wait')
+    else:
+        return NOT_STARTED
     return RUNNING if status else FINISHED
-    
 
 @eintr_retry
-def rkill(pid, ppid, 
+def rkill(pid, ppid,
         host = None, 
         port = None, 
         user = None, 
         agent = None, 
         sudo = False,
-        identity_file = None, 
+        identity = None, 
+        server_key = None, 
         nowait = False):
     """
     Kill a process spawned with remote_spawn.
@@ -574,7 +695,7 @@ def rkill(pid, ppid,
         
         sudo: whether the command was run with sudo - careful killing like this.
         
-        host/port/user/agent/identity_file: see rexec
+        host/port/user/agent/identity: see rexec
     
     Returns:
         
@@ -615,12 +736,15 @@ fi
         port = port,
         user = user,
         agent = agent,
-        identity_file = identity_file
+        identity = identity,
+        server_key = server_key
         )
     
     # wait, don't leave zombies around
     proc.wait()
 
+    return (out, err), proc
+
 # POSIX
 def _communicate(self, input, timeout=None, err_on_timeout=True):
     read_set = []