SSHApi functionality migrated to LinuxNode
[nepi.git] / src / neco / util / sshfuncs.py
index dd04ba6..e81c558 100644 (file)
@@ -12,8 +12,7 @@ import re
 import tempfile
 import hashlib
 
-OPENSSH_HAS_PERSIST = None
-CONTROL_PATH = "yyy_ssh_ctrl_path"
+TRACE = os.environ.get("NEPI_TRACE", "false").lower() in ("true", "1", "on")
 
 if hasattr(os, "devnull"):
     DEV_NULL = os.devnull
@@ -22,8 +21,6 @@ else:
 
 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
 
-hostbyname_cache = dict()
-
 class STDOUT: 
     """
     Special value that when given to rspawn in stderr causes stderr to 
@@ -45,6 +42,17 @@ class NOT_STARTED:
     Process hasn't started running yet (this should be very rare)
     """
 
+hostbyname_cache = dict()
+
+def gethostbyname(host):
+    hostbyname = hostbyname_cache.get(host)
+    if not hostbyname:
+        hostbyname = socket.gethostbyname(host)
+        hostbyname_cache[host] = hostbyname
+    return hostbyname
+
+OPENSSH_HAS_PERSIST = None
+
 def openssh_has_persist():
     """ The ssh_config options ControlMaster and ControlPersist allow to
     reuse a same network connection for multiple ssh sessions. In this 
@@ -66,6 +74,59 @@ def openssh_has_persist():
         OPENSSH_HAS_PERSIST = bool(vre.match(out))
     return OPENSSH_HAS_PERSIST
 
+def make_server_key_args(server_key, host, port):
+    """ Returns a reference to a temporary known_hosts file, to which 
+    the server key has been added. 
+    
+    Make sure to hold onto the temp file reference until the process is 
+    done with it
+
+    :param server_key: the server public key
+    :type server_key: str
+
+    :param host: the hostname
+    :type host: str
+
+    :param port: the ssh port
+    :type port: str
+
+    """
+    if port is not None:
+        host = '%s:%s' % (host, str(port))
+
+    # Create a temporary server key file
+    tmp_known_hosts = tempfile.NamedTemporaryFile()
+   
+    hostbyname = gethostbyname(host) 
+
+    # Add the intended host key
+    tmp_known_hosts.write('%s,%s %s\n' % (host, hostbyname, server_key))
+    
+    # If we're not in strict mode, add user-configured keys
+    if os.environ.get('NEPI_STRICT_AUTH_MODE',"").lower() not in ('1','true','on'):
+        user_hosts_path = '%s/.ssh/known_hosts' % (os.environ.get('HOME',""),)
+        if os.access(user_hosts_path, os.R_OK):
+            f = open(user_hosts_path, "r")
+            tmp_known_hosts.write(f.read())
+            f.close()
+        
+    tmp_known_hosts.flush()
+    
+    return tmp_known_hosts
+
+def make_control_path(agent, forward_x11):
+    ctrl_path = "/tmp/nepi_ssh"
+
+    if agent:
+        ctrl_path +="_a"
+
+    if forward_x11:
+        ctrl_path +="_x"
+
+    ctrl_path += "-%r@%h:%p"
+
+    return ctrl_path
+
 def shell_escape(s):
     """ Escapes strings so that they are safe to use as command-line 
     arguments """
@@ -105,81 +166,63 @@ def eintr_retry(func):
             return func(*p, **kw)
     return rv
 
-def make_connkey(user, host, port, x11, agent):
-    # It is important to consider the x11 and agent forwarding
-    # parameters when creating the connection key since the parameters
-    # used for the first ssh connection will determine the
-    # parameters of all subsequent connections using the same key
-    x11 = 1 if x11 else 0
-    agent = 1 if agent else 0
-
-    connkey = repr((user, host, port, x11, agent)
-            ).encode("base64").strip().replace('/','.')
-
-    if len(connkey) > 60:
-        connkey = hashlib.sha1(connkey).hexdigest()
-    return connkey
-
-def make_control_path(user, host, port, x11, agent):
-    connkey = make_connkey(user, host, port, x11, agent)
-    return '/tmp/%s_%s' % ( CONTROL_PATH, connkey, )
-
 def rexec(command, host, user, 
         port = None, 
         agent = True,
         sudo = False,
         stdin = None,
         identity = None,
+        server_key = None,
         env = None,
         tty = False,
-        x11 = False,
         timeout = None,
         retry = 0,
         err_on_timeout = True,
         connect_timeout = 30,
-        persistent = True):
+        persistent = True,
+        forward_x11 = False):
     """
     Executes a remote command, returns ((stdout,stderr),process)
     """
+    
+    tmp_known_hosts = None
+    hostip = gethostbyname(host)
+
     args = ['ssh', '-C',
             # Don't bother with localhost. Makes test easier
             '-o', 'NoHostAuthenticationForLocalhost=yes',
-            # XXX: Possible security issue
-            # Avoid interactive requests to accept new host keys
-            '-o', 'StrictHostKeyChecking=no',
             '-o', 'ConnectTimeout=%d' % (int(connect_timeout),),
             '-o', 'ConnectionAttempts=3',
             '-o', 'ServerAliveInterval=30',
             '-o', 'TCPKeepAlive=yes',
-            '-l', user, host]
+            '-l', user, hostip or host]
 
     if persistent and openssh_has_persist():
-        control_path = make_control_path(user, host, port, x11, agent)
         args.extend([
             '-o', 'ControlMaster=auto',
-            '-o', 'ControlPath=%s' % control_path,
+            '-o', 'ControlPath=%s' % (make_control_path(agent, forward_x11),),
             '-o', 'ControlPersist=60' ])
+
     if agent:
         args.append('-A')
+
     if port:
         args.append('-p%d' % port)
+
     if identity:
         args.extend(('-i', identity))
+
     if tty:
         args.append('-t')
-        if sudo:
-            args.append('-t')
-    if x11:
-        args.append('-X')
+        args.append('-t')
 
-    if env:
-        export = ''
-        for envkey, envval in env.iteritems():
-            export += '%s=%s ' % (envkey, envval)
-        command = export + command
+    if forward_x11:
+        args.append('-X')
 
-    if sudo:
-        command = "sudo " + command
+    if server_key:
+        # Create a temporary server key file
+        tmp_known_hosts = make_server_key_args(server_key, host, port)
+        args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
 
     args.append(command)
 
@@ -190,8 +233,15 @@ def rexec(command, host, user,
                 stdin = subprocess.PIPE, 
                 stderr = subprocess.PIPE)
         
+        # attach tempfile object to the process, to make sure the file stays
+        # alive until the process is finished with it
+        proc._known_hosts = tmp_known_hosts
+    
         try:
             out, err = _communicate(proc, stdin, timeout, err_on_timeout)
+            if TRACE:
+                print "COMMAND host %s, command %s, out %s, error %s" % (host, " ".join(args), out, err)
+
             if proc.poll():
                 if err.strip().startswith('ssh: ') or err.strip().startswith('mux_client_hello_exchange: '):
                     # SSH error, can safely retry
@@ -200,7 +250,11 @@ def rexec(command, host, user,
                     # Probably timed out or plain failed but can retry
                     continue
             break
-        except RuntimeError,e:
+        except RuntimeError, e:
+            if TRACE:
+                print "EXCEPTION host %s, command %s, out %s, error %s, exception TIMEOUT ->  %s" % (
+                        host, " ".join(args), out, err, e.args)
+
             if retry <= 0:
                 raise
             retry -= 1
@@ -210,67 +264,242 @@ def rexec(command, host, user,
 def rcopy(source, dest,
         port = None, 
         agent = True, 
-        identity = None):
+        recursive = False,
+        identity = None,
+        server_key = None):
     """
-    Copies file from/to remote sites.
+    Copies from/to remote sites.
     
     Source and destination should have the user and host encoded
     as per scp specs.
     
+    If source is a file object, a special mode will be used to
+    create the remote file with the same contents.
+    
+    If dest is a file object, the remote file (source) will be
+    read and written into dest.
+    
+    In these modes, recursive cannot be True.
+    
     Source can be a list of files to copy to a single destination,
     in which case it is advised that the destination be a folder.
     """
     
-    # Parse destination as <user>@<server>:<path>
-    if isinstance(dest, basestring) and ':' in dest:
-        remspec, path = dest.split(':',1)
-    elif isinstance(source, basestring) and ':' in source:
-        remspec, path = source.split(':',1)
-    else:
-        raise ValueError, "Both endpoints cannot be local"
-    user, host = remspec.rsplit('@',1)
-
-    raw_string = r'''rsync -rlpcSz --timeout=900 '''
-    raw_string += r''' -e 'ssh -o BatchMode=yes '''
-    raw_string += r''' -o NoHostAuthenticationForLocalhost=yes '''
-    raw_string += r''' -o StrictHostKeyChecking=no '''
-    raw_string += r''' -o ConnectionAttempts=3 '''
-    if openssh_has_persist():
-        control_path = make_control_path(user, host, port, False, agent)
-        raw_string += r''' -o ControlMaster=auto '''
-        raw_string += r''' -o ControlPath=%s ''' % control_path
-    if agent:
-        raw_string += r''' -A '''
-
-    if port:
-        raw_string += r''' -p %d ''' % port
+    if TRACE:
+        print "scp", source, dest
     
-    if identity:
-        raw_string += r''' -i "%s" ''' % identity
+    if isinstance(source, file) and source.tell() == 0:
+        source = source.name
+    elif hasattr(source, 'read'):
+        tmp = tempfile.NamedTemporaryFile()
+        while True:
+            buf = source.read(65536)
+            if buf:
+                tmp.write(buf)
+            else:
+                break
+        tmp.seek(0)
+        source = tmp.name
     
-    # closing -e 'ssh...'
-    raw_string += r''' ' '''
-
-    if isinstance(source,list):
-        source = ' '.join(source)
+    if isinstance(source, file) or isinstance(dest, file) \
+            or hasattr(source, 'read')  or hasattr(dest, 'write'):
+        assert not recursive
+        
+        # Parse source/destination as <user>@<server>:<path>
+        if isinstance(dest, basestring) and ':' in dest:
+            remspec, path = dest.split(':',1)
+        elif isinstance(source, basestring) and ':' in source:
+            remspec, path = source.split(':',1)
+        else:
+            raise ValueError, "Both endpoints cannot be local"
+        user,host = remspec.rsplit('@',1)
+        
+        tmp_known_hosts = None
+        hostip = gethostbyname(host)
+        
+        args = ['ssh', '-l', user, '-C',
+                # Don't bother with localhost. Makes test easier
+                '-o', 'NoHostAuthenticationForLocalhost=yes',
+                '-o', 'ConnectTimeout=60',
+                '-o', 'ConnectionAttempts=3',
+                '-o', 'ServerAliveInterval=30',
+                '-o', 'TCPKeepAlive=yes',
+                hostip or host ]
+
+        if openssh_has_persist():
+            args.extend([
+                '-o', 'ControlMaster=auto',
+                '-o', 'ControlPath=%s' % (make_control_path(agent, False),),
+                '-o', 'ControlPersist=60' ])
+
+        if port:
+            args.append('-P%d' % port)
+
+        if identity:
+            args.extend(('-i', identity))
+
+        if server_key:
+            # Create a temporary server key file
+            tmp_known_hosts = make_server_key_args(server_key, host, port)
+            args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
+        
+        if isinstance(source, file) or hasattr(source, 'read'):
+            args.append('cat > %s' % (shell_escape(path),))
+        elif isinstance(dest, file) or hasattr(dest, 'write'):
+            args.append('cat %s' % (shell_escape(path),))
+        else:
+            raise AssertionError, "Unreachable code reached! :-Q"
+        
+        # connects to the remote host and starts a remote connection
+        if isinstance(source, file):
+            proc = subprocess.Popen(args, 
+                    stdout = open('/dev/null','w'),
+                    stderr = subprocess.PIPE,
+                    stdin = source)
+            err = proc.stderr.read()
+            proc._known_hosts = tmp_known_hosts
+            eintr_retry(proc.wait)()
+            return ((None,err), proc)
+        elif isinstance(dest, file):
+            proc = subprocess.Popen(args, 
+                    stdout = open('/dev/null','w'),
+                    stderr = subprocess.PIPE,
+                    stdin = source)
+            err = proc.stderr.read()
+            proc._known_hosts = tmp_known_hosts
+            eintr_retry(proc.wait)()
+            return ((None,err), proc)
+        elif hasattr(source, 'read'):
+            # file-like (but not file) source
+            proc = subprocess.Popen(args, 
+                    stdout = open('/dev/null','w'),
+                    stderr = subprocess.PIPE,
+                    stdin = subprocess.PIPE)
+            
+            buf = None
+            err = []
+            while True:
+                if not buf:
+                    buf = source.read(4096)
+                if not buf:
+                    #EOF
+                    break
+                
+                rdrdy, wrdy, broken = select.select(
+                    [proc.stderr],
+                    [proc.stdin],
+                    [proc.stderr,proc.stdin])
+                
+                if proc.stderr in rdrdy:
+                    # use os.read for fully unbuffered behavior
+                    err.append(os.read(proc.stderr.fileno(), 4096))
+                
+                if proc.stdin in wrdy:
+                    proc.stdin.write(buf)
+                    buf = None
+                
+                if broken:
+                    break
+            proc.stdin.close()
+            err.append(proc.stderr.read())
+                
+            proc._known_hosts = tmp_known_hosts
+            eintr_retry(proc.wait)()
+            return ((None,''.join(err)), proc)
+        elif hasattr(dest, 'write'):
+            # file-like (but not file) dest
+            proc = subprocess.Popen(args, 
+                    stdout = subprocess.PIPE,
+                    stderr = subprocess.PIPE,
+                    stdin = open('/dev/null','w'))
+            
+            buf = None
+            err = []
+            while True:
+                rdrdy, wrdy, broken = select.select(
+                    [proc.stderr, proc.stdout],
+                    [],
+                    [proc.stderr, proc.stdout])
+                
+                if proc.stderr in rdrdy:
+                    # use os.read for fully unbuffered behavior
+                    err.append(os.read(proc.stderr.fileno(), 4096))
+                
+                if proc.stdout in rdrdy:
+                    # use os.read for fully unbuffered behavior
+                    buf = os.read(proc.stdout.fileno(), 4096)
+                    dest.write(buf)
+                    
+                    if not buf:
+                        #EOF
+                        break
+                
+                if broken:
+                    break
+            err.append(proc.stderr.read())
+                
+            proc._known_hosts = tmp_known_hosts
+            eintr_retry(proc.wait)()
+            return ((None,''.join(err)), proc)
+        else:
+            raise AssertionError, "Unreachable code reached! :-Q"
     else:
-        source = '"%s"' % source
+        # Parse destination as <user>@<server>:<path>
+        if isinstance(dest, basestring) and ':' in dest:
+            remspec, path = dest.split(':',1)
+        elif isinstance(source, basestring) and ':' in source:
+            remspec, path = source.split(':',1)
+        else:
+            raise ValueError, "Both endpoints cannot be local"
+        user,host = remspec.rsplit('@',1)
+        
+        # plain scp
+        tmp_known_hosts = None
+
+        args = ['scp', '-q', '-p', '-C',
+                # Don't bother with localhost. Makes test easier
+                '-o', 'NoHostAuthenticationForLocalhost=yes',
+                '-o', 'ConnectTimeout=60',
+                '-o', 'ConnectionAttempts=3',
+                '-o', 'ServerAliveInterval=30',
+                '-o', 'TCPKeepAlive=yes' ]
+                
+        if port:
+            args.append('-P%d' % port)
+
+        if recursive:
+            args.append('-r')
+
+        if identity:
+            args.extend(('-i', identity))
+
+        if server_key:
+            # Create a temporary server key file
+            tmp_known_hosts = make_server_key_args(server_key, host, port)
+            args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
+
+        if isinstance(source,list):
+            args.extend(source)
+        else:
+            if openssh_has_persist():
+                args.extend([
+                    '-o', 'ControlMaster=auto',
+                    '-o', 'ControlPath=%s' % (make_control_path(agent, False),)
+                    ])
+            args.append(source)
 
-    raw_string += r''' %s ''' % source
-    raw_string += r''' %s ''' % dest
+        args.append(dest)
 
-    # connects to the remote host and starts a remote connection
-    proc = subprocess.Popen(raw_string,
-            shell=True,
-            stdout = subprocess.PIPE,
-            stdin = subprocess.PIPE, 
-            stderr = subprocess.PIPE)
-  
-    comm = proc.communicate()
-    eintr_retry(proc.wait)()
-    return (comm, proc)
+        # connects to the remote host and starts a remote connection
+        proc = subprocess.Popen(args, 
+                stdout = subprocess.PIPE,
+                stdin = subprocess.PIPE, 
+                stderr = subprocess.PIPE)
+        proc._known_hosts = tmp_known_hosts
+        
+        (out, err) = proc.communicate()
+        eintr_retry(proc.wait)()
+        return ((out, err), proc)
 
 def rspawn(command, pidfile, 
         stdout = '/dev/null', 
@@ -278,12 +507,13 @@ def rspawn(command, pidfile,
         stdin = '/dev/null', 
         home = None, 
         create_home = False, 
+        sudo = False,
         host = None, 
         port = None, 
         user = None, 
         agent = None, 
-        sudo = False,
         identity = None, 
+        server_key = None,
         tty = False):
     """
     Spawn a remote command such that it will continue working asynchronously.
@@ -319,41 +549,38 @@ def rspawn(command, pidfile,
         stderr = '&1'
     else:
         stderr = ' ' + stderr
-   
-    #XXX: ppid is always 1!!!
+    
     daemon_command = '{ { %(command)s  > %(stdout)s 2>%(stderr)s < %(stdin)s & } ; echo $! 1 > %(pidfile)s ; }' % {
         'command' : command,
-        'pidfile' : pidfile,
-        
+        'pidfile' : shell_escape(pidfile),
         'stdout' : stdout,
         'stderr' : stderr,
         'stdin' : stdin,
     }
     
-    cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c '%(command)s' " % {
-            'command' : daemon_command,
-            
+    cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c %(command)s " % {
+            'command' : shell_escape(daemon_command),
             'sudo' : 'sudo -S' if sudo else '',
-            
-            'pidfile' : pidfile,
-            'gohome' : 'cd %s ; ' % home if home else '',
-            'create' : 'mkdir -p %s ; ' % home if create_home else '',
+            'pidfile' : shell_escape(pidfile),
+            'gohome' : 'cd %s ; ' % (shell_escape(home),) if home else '',
+            'create' : 'mkdir -p %s ; ' % (shell_escape(home),) if create_home and home else '',
         }
 
-    (out,err), proc = rexec(
+    (out,err),proc = rexec(
         cmd,
         host = host,
         port = port,
         user = user,
         agent = agent,
         identity = identity,
-        tty = tty
+        server_key = server_key,
+        tty = tty ,
         )
     
     if proc.wait():
-        raise RuntimeError, "Failed to set up application: %s %s" % (out,err,)
+        raise RuntimeError, "Failed to set up application on host %s: %s %s" % (host, out,err,)
 
-    return (out,err),proc
+    return ((out, err), proc)
 
 @eintr_retry
 def rcheckpid(pidfile,
@@ -361,7 +588,8 @@ def rcheckpid(pidfile,
         port = None, 
         user = None, 
         agent = None, 
-        identity = None):
+        identity = None,
+        server_key = None):
     """
     Check the pidfile of a process spawned with remote_spawn.
     
@@ -384,7 +612,8 @@ def rcheckpid(pidfile,
         port = port,
         user = user,
         agent = agent,
-        identity = identity
+        identity = identity,
+        server_key = server_key
         )
         
     if proc.wait():
@@ -403,7 +632,8 @@ def rstatus(pid, ppid,
         port = None, 
         user = None, 
         agent = None, 
-        identity = None):
+        identity = None,
+        server_key = None):
     """
     Check the status of a process spawned with remote_spawn.
     
@@ -417,9 +647,9 @@ def rstatus(pid, ppid,
         One of NOT_STARTED, RUNNING, FINISHED
     """
 
-    # XXX: ppid unused
     (out,err),proc = rexec(
-        "ps --pid %(pid)d -o pid | grep -c %(pid)d ; true" % {
+        # Check only by pid. pid+ppid does not always work (especially with sudo) 
+        " (( ps --pid %(pid)d -o pid | grep -c %(pid)d && echo 'wait')  || echo 'done' ) | tail -n 1" % {
             'ppid' : ppid,
             'pid' : pid,
         },
@@ -427,23 +657,22 @@ def rstatus(pid, ppid,
         port = port,
         user = user,
         agent = agent,
-        identity = identity
+        identity = identity,
+        server_key = server_key
         )
     
     if proc.wait():
         return NOT_STARTED
     
     status = False
-    if out:
-        try:
-            status = bool(int(out.strip()))
-        except:
-            if out or err:
-                logging.warn("Error checking remote status:\n%s%s\n", out, err)
-            # Ignore, many ways to fail that don't matter that much
-            return NOT_STARTED
+    if err:
+        if err.strip().find("Error, do this: mount -t proc none /proc") >= 0:
+            status = True
+    elif out:
+        status = (out.strip() == 'wait')
+    else:
+        return NOT_STARTED
     return RUNNING if status else FINISHED
-    
 
 @eintr_retry
 def rkill(pid, ppid,
@@ -453,6 +682,7 @@ def rkill(pid, ppid,
         agent = None, 
         sudo = False,
         identity = None, 
+        server_key = None, 
         nowait = False):
     """
     Kill a process spawned with remote_spawn.
@@ -506,12 +736,15 @@ fi
         port = port,
         user = user,
         agent = agent,
-        identity = identity
+        identity = identity,
+        server_key = server_key
         )
     
     # wait, don't leave zombies around
     proc.wait()
 
+    return (out, err), proc
+
 # POSIX
 def _communicate(self, input, timeout=None, err_on_timeout=True):
     read_set = []