Added example for Linux Application using CCNx
[nepi.git] / src / neco / util / sshfuncs.py
index aad6039..982e23a 100644 (file)
@@ -1,19 +1,29 @@
 import base64
 import errno
+import hashlib
+import logging
 import os
 import os.path
+import re
 import select
 import signal
 import socket
 import subprocess
 import time
-import traceback
-import re
 import tempfile
-import hashlib
 
-OPENSSH_HAS_PERSIST = None
-CONTROL_PATH = "yyy_ssh_ctrl_path"
+
+logger = logging.getLogger("sshfuncs")
+
+def log(msg, level, out = None, err = None):
+    if out:
+        msg += " - OUT: %s " % out
+
+    if err:
+        msg += " - ERROR: %s " % err
+
+    logger.log(level, msg)
+
 
 if hasattr(os, "devnull"):
     DEV_NULL = os.devnull
@@ -22,8 +32,6 @@ else:
 
 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
 
-hostbyname_cache = dict()
-
 class STDOUT: 
     """
     Special value that when given to rspawn in stderr causes stderr to 
@@ -45,6 +53,17 @@ class NOT_STARTED:
     Process hasn't started running yet (this should be very rare)
     """
 
+hostbyname_cache = dict()
+
+def gethostbyname(host):
+    hostbyname = hostbyname_cache.get(host)
+    if not hostbyname:
+        hostbyname = socket.gethostbyname(host)
+        hostbyname_cache[host] = hostbyname
+    return hostbyname
+
+OPENSSH_HAS_PERSIST = None
+
 def openssh_has_persist():
     """ The ssh_config options ControlMaster and ControlPersist allow to
     reuse a same network connection for multiple ssh sessions. In this 
@@ -66,6 +85,59 @@ def openssh_has_persist():
         OPENSSH_HAS_PERSIST = bool(vre.match(out))
     return OPENSSH_HAS_PERSIST
 
+def make_server_key_args(server_key, host, port):
+    """ Returns a reference to a temporary known_hosts file, to which 
+    the server key has been added. 
+    
+    Make sure to hold onto the temp file reference until the process is 
+    done with it
+
+    :param server_key: the server public key
+    :type server_key: str
+
+    :param host: the hostname
+    :type host: str
+
+    :param port: the ssh port
+    :type port: str
+
+    """
+    if port is not None:
+        host = '%s:%s' % (host, str(port))
+
+    # Create a temporary server key file
+    tmp_known_hosts = tempfile.NamedTemporaryFile()
+   
+    hostbyname = gethostbyname(host) 
+
+    # Add the intended host key
+    tmp_known_hosts.write('%s,%s %s\n' % (host, hostbyname, server_key))
+    
+    # If we're not in strict mode, add user-configured keys
+    if os.environ.get('NEPI_STRICT_AUTH_MODE',"").lower() not in ('1','true','on'):
+        user_hosts_path = '%s/.ssh/known_hosts' % (os.environ.get('HOME',""),)
+        if os.access(user_hosts_path, os.R_OK):
+            f = open(user_hosts_path, "r")
+            tmp_known_hosts.write(f.read())
+            f.close()
+        
+    tmp_known_hosts.flush()
+    
+    return tmp_known_hosts
+
+def make_control_path(agent, forward_x11):
+    ctrl_path = "/tmp/nepi_ssh"
+
+    if agent:
+        ctrl_path +="_a"
+
+    if forward_x11:
+        ctrl_path +="_x"
+
+    ctrl_path += "-%r@%h:%p"
+
+    return ctrl_path
+
 def shell_escape(s):
     """ Escapes strings so that they are safe to use as command-line 
     arguments """
@@ -105,93 +177,106 @@ def eintr_retry(func):
             return func(*p, **kw)
     return rv
 
-def make_connkey(user, host, port):
-    connkey = repr((user,host,port)).encode("base64").strip().replace('/','.')
-    if len(connkey) > 60:
-        connkey = hashlib.sha1(connkey).hexdigest()
-    return connkey
-
-def make_control_path(user, host, port):
-    connkey = make_connkey(user, host, port)
-    return '/tmp/%s_%s' % ( CONTROL_PATH, connkey, )
-
 def rexec(command, host, user, 
         port = None, 
         agent = True,
         sudo = False,
         stdin = None,
-        identity_file = None,
+        identity = None,
+        server_key = None,
         env = None,
         tty = False,
-        x11 = False,
         timeout = None,
-        retry = 0,
+        retry = 3,
         err_on_timeout = True,
         connect_timeout = 30,
-        persistent = True):
+        persistent = True,
+        forward_x11 = False):
     """
     Executes a remote command, returns ((stdout,stderr),process)
     """
+    
+    tmp_known_hosts = None
+    hostip = gethostbyname(host)
+
     args = ['ssh', '-C',
             # Don't bother with localhost. Makes test easier
             '-o', 'NoHostAuthenticationForLocalhost=yes',
-            # XXX: Possible security issue
-            # Avoid interactive requests to accept new host keys
-            '-o', 'StrictHostKeyChecking=no',
             '-o', 'ConnectTimeout=%d' % (int(connect_timeout),),
             '-o', 'ConnectionAttempts=3',
             '-o', 'ServerAliveInterval=30',
             '-o', 'TCPKeepAlive=yes',
-            '-l', user, host]
+            '-l', user, hostip or host]
 
     if persistent and openssh_has_persist():
-        control_path = make_control_path(user, host, port)
         args.extend([
             '-o', 'ControlMaster=auto',
-            '-o', 'ControlPath=%s' % control_path,
+            '-o', 'ControlPath=%s' % (make_control_path(agent, forward_x11),),
             '-o', 'ControlPersist=60' ])
+
     if agent:
         args.append('-A')
+
     if port:
         args.append('-p%d' % port)
-    if identity_file:
-        args.extend(('-i', identity_file))
+
+    if identity:
+        args.extend(('-i', identity))
+
     if tty:
         args.append('-t')
-        if sudo:
-            args.append('-t')
-    if x11:
-        args.append('-X')
+        args.append('-t')
 
-    if env:
-        export = ''
-        for envkey, envval in env.iteritems():
-            export += '%s=%s ' % (envkey, envval)
-        command = export + command
+    if forward_x11:
+        args.append('-X')
 
-    if sudo:
-        command = "sudo " + command
+    if server_key:
+        # Create a temporary server key file
+        tmp_known_hosts = make_server_key_args(server_key, host, port)
+        args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
 
     args.append(command)
 
-    for x in xrange(retry or 3):
+    for x in xrange(retry):
         # connects to the remote host and starts a remote connection
-        proc = subprocess.Popen(args, 
+        proc = subprocess.Popen(args,
+                env = env,
                 stdout = subprocess.PIPE,
                 stdin = subprocess.PIPE, 
                 stderr = subprocess.PIPE)
         
+        # attach tempfile object to the process, to make sure the file stays
+        # alive until the process is finished with it
+        proc._known_hosts = tmp_known_hosts
+    
         try:
             out, err = _communicate(proc, stdin, timeout, err_on_timeout)
+            msg = " rexec - host %s - command %s " % (host, " ".join(args))
+            log(msg, logging.DEBUG, out, err)
+
             if proc.poll():
+                skip = False
+
                 if err.strip().startswith('ssh: ') or err.strip().startswith('mux_client_hello_exchange: '):
                     # SSH error, can safely retry
-                    continue
+                    skip = True 
                 elif retry:
                     # Probably timed out or plain failed but can retry
+                    skip = True 
+                
+                if skip:
+                    t = x*2
+                    msg = "SLEEPING %d ... ATEMPT %d - host %s - command %s " % ( 
+                            t, x, host, " ".join(args))
+                    log(msg, logging.DEBUG)
+
+                    time.sleep(t)
                     continue
             break
-        except RuntimeError,e:
+        except RuntimeError, e:
+            msg = " rexec EXCEPTION - host %s - command %s - TIMEOUT -> %s" % (host, " ".join(args), e.args)
+            log(msg, logging.DEBUG, out, err)
+
             if retry <= 0:
                 raise
             retry -= 1
@@ -202,9 +287,10 @@ def rcopy(source, dest,
         port = None, 
         agent = True, 
         recursive = False,
-        identity_file = None):
+        identity = None,
+        server_key = None):
     """
-    Copies file from/to remote sites.
+    Copies from/to remote sites.
     
     Source and destination should have the user and host encoded
     as per scp specs.
@@ -221,6 +307,9 @@ def rcopy(source, dest,
     in which case it is advised that the destination be a folder.
     """
     
+    msg = " rcopy - scp %s %s " % (source, dest)
+    log(msg, logging.DEBUG)
+    
     if isinstance(source, file) and source.tell() == 0:
         source = source.name
     elif hasattr(source, 'read'):
@@ -237,7 +326,7 @@ def rcopy(source, dest,
     if isinstance(source, file) or isinstance(dest, file) \
             or hasattr(source, 'read')  or hasattr(dest, 'write'):
         assert not recursive
-    
+        
         # Parse source/destination as <user>@<server>:<path>
         if isinstance(dest, basestring) and ':' in dest:
             remspec, path = dest.split(':',1)
@@ -246,30 +335,35 @@ def rcopy(source, dest,
         else:
             raise ValueError, "Both endpoints cannot be local"
         user,host = remspec.rsplit('@',1)
+        
         tmp_known_hosts = None
+        hostip = gethostbyname(host)
         
         args = ['ssh', '-l', user, '-C',
                 # Don't bother with localhost. Makes test easier
                 '-o', 'NoHostAuthenticationForLocalhost=yes',
-                # XXX: Possible security issue
-                # Avoid interactive requests to accept new host keys
-                '-o', 'StrictHostKeyChecking=no',
-                '-o', 'ConnectTimeout=30',
+                '-o', 'ConnectTimeout=60',
                 '-o', 'ConnectionAttempts=3',
                 '-o', 'ServerAliveInterval=30',
                 '-o', 'TCPKeepAlive=yes',
-                host ]
+                hostip or host ]
 
         if openssh_has_persist():
-            control_path = make_control_path(user, host, port)
             args.extend([
                 '-o', 'ControlMaster=auto',
-                '-o', 'ControlPath=%s' % control_path,
+                '-o', 'ControlPath=%s' % (make_control_path(agent, False),),
                 '-o', 'ControlPersist=60' ])
+
         if port:
             args.append('-P%d' % port)
-        if identity_file:
-            args.extend(('-i', identity_file))
+
+        if identity:
+            args.extend(('-i', identity))
+
+        if server_key:
+            # Create a temporary server key file
+            tmp_known_hosts = make_server_key_args(server_key, host, port)
+            args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
         
         if isinstance(source, file) or hasattr(source, 'read'):
             args.append('cat > %s' % (shell_escape(path),))
@@ -277,7 +371,7 @@ def rcopy(source, dest,
             args.append('cat %s' % (shell_escape(path),))
         else:
             raise AssertionError, "Unreachable code reached! :-Q"
-
+        
         # connects to the remote host and starts a remote connection
         if isinstance(source, file):
             proc = subprocess.Popen(args, 
@@ -285,6 +379,7 @@ def rcopy(source, dest,
                     stderr = subprocess.PIPE,
                     stdin = source)
             err = proc.stderr.read()
+            proc._known_hosts = tmp_known_hosts
             eintr_retry(proc.wait)()
             return ((None,err), proc)
         elif isinstance(dest, file):
@@ -293,6 +388,7 @@ def rcopy(source, dest,
                     stderr = subprocess.PIPE,
                     stdin = source)
             err = proc.stderr.read()
+            proc._known_hosts = tmp_known_hosts
             eintr_retry(proc.wait)()
             return ((None,err), proc)
         elif hasattr(source, 'read'):
@@ -329,6 +425,7 @@ def rcopy(source, dest,
             proc.stdin.close()
             err.append(proc.stderr.read())
                 
+            proc._known_hosts = tmp_known_hosts
             eintr_retry(proc.wait)()
             return ((None,''.join(err)), proc)
         elif hasattr(dest, 'write'):
@@ -363,6 +460,7 @@ def rcopy(source, dest,
                     break
             err.append(proc.stderr.read())
                 
+            proc._known_hosts = tmp_known_hosts
             eintr_retry(proc.wait)()
             return ((None,''.join(err)), proc)
         else:
@@ -376,34 +474,40 @@ def rcopy(source, dest,
         else:
             raise ValueError, "Both endpoints cannot be local"
         user,host = remspec.rsplit('@',1)
-
+        
         # plain scp
+        tmp_known_hosts = None
+
         args = ['scp', '-q', '-p', '-C',
                 # Don't bother with localhost. Makes test easier
                 '-o', 'NoHostAuthenticationForLocalhost=yes',
-                # XXX: Possible security issue
-                # Avoid interactive requests to accept new host keys
-                '-o', 'StrictHostKeyChecking=no',
-                '-o', 'ConnectTimeout=30',
+                '-o', 'ConnectTimeout=60',
                 '-o', 'ConnectionAttempts=3',
                 '-o', 'ServerAliveInterval=30',
                 '-o', 'TCPKeepAlive=yes' ]
                 
         if port:
             args.append('-P%d' % port)
+
         if recursive:
             args.append('-r')
-        if identity_file:
-            args.extend(('-i', identity_file))
+
+        if identity:
+            args.extend(('-i', identity))
+
+        if server_key:
+            # Create a temporary server key file
+            tmp_known_hosts = make_server_key_args(server_key, host, port)
+            args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
 
         if isinstance(source,list):
             args.extend(source)
         else:
             if openssh_has_persist():
-                control_path = make_control_path(user, host, port)
                 args.extend([
-                    '-o', 'ControlMaster=no',
-                    '-o', 'ControlPath=%s' % control_path ])
+                    '-o', 'ControlMaster=auto',
+                    '-o', 'ControlPath=%s' % (make_control_path(agent, False),)
+                    ])
             args.append(source)
 
         args.append(dest)
@@ -413,10 +517,11 @@ def rcopy(source, dest,
                 stdout = subprocess.PIPE,
                 stdin = subprocess.PIPE, 
                 stderr = subprocess.PIPE)
+        proc._known_hosts = tmp_known_hosts
         
-        comm = proc.communicate()
+        (out, err) = proc.communicate()
         eintr_retry(proc.wait)()
-        return (comm, proc)
+        return ((out, err), proc)
 
 def rspawn(command, pidfile, 
         stdout = '/dev/null', 
@@ -424,12 +529,13 @@ def rspawn(command, pidfile,
         stdin = '/dev/null', 
         home = None, 
         create_home = False, 
+        sudo = False,
         host = None, 
         port = None, 
         user = None, 
         agent = None, 
-        sudo = False,
-        identity_file = None, 
+        identity = None, 
+        server_key = None,
         tty = False):
     """
     Spawn a remote command such that it will continue working asynchronously.
@@ -451,7 +557,7 @@ def rspawn(command, pidfile,
         
         sudo: whether the command needs to be executed as root
         
-        host/port/user/agent/identity_file: see rexec
+        host/port/user/agent/identity: see rexec
     
     Returns:
         (stdout, stderr), process
@@ -468,52 +574,51 @@ def rspawn(command, pidfile,
     
     daemon_command = '{ { %(command)s  > %(stdout)s 2>%(stderr)s < %(stdin)s & } ; echo $! 1 > %(pidfile)s ; }' % {
         'command' : command,
-        'pidfile' : pidfile,
-        
+        'pidfile' : shell_escape(pidfile),
         'stdout' : stdout,
         'stderr' : stderr,
         'stdin' : stdin,
     }
     
-    cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c '%(command)s' " % {
-            'command' : daemon_command,
-            
+    cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c %(command)s " % {
+            'command' : shell_escape(daemon_command),
             'sudo' : 'sudo -S' if sudo else '',
-            
-            'pidfile' : pidfile,
-            'gohome' : 'cd %s ; ' % home if home else '',
-            'create' : 'mkdir -p %s ; ' % home if create_home else '',
+            'pidfile' : shell_escape(pidfile),
+            'gohome' : 'cd %s ; ' % (shell_escape(home),) if home else '',
+            'create' : 'mkdir -p %s ; ' % (shell_escape(home),) if create_home and home else '',
         }
 
-    (out,err), proc = rexec(
+    (out,err),proc = rexec(
         cmd,
         host = host,
         port = port,
         user = user,
         agent = agent,
-        identity_file = identity_file,
-        tty = tty
+        identity = identity,
+        server_key = server_key,
+        tty = tty ,
         )
     
     if proc.wait():
-        raise RuntimeError, "Failed to set up application: %s %s" % (out,err,)
+        raise RuntimeError, "Failed to set up application on host %s: %s %s" % (host, out,err,)
 
-    return (out,err),proc
+    return ((out, err), proc)
 
 @eintr_retry
-def rcheck_pid(pidfile,
+def rcheckpid(pidfile,
         host = None, 
         port = None, 
         user = None, 
         agent = None, 
-        identity_file = None):
+        identity = None,
+        server_key = None):
     """
     Check the pidfile of a process spawned with remote_spawn.
     
     Parameters:
         pidfile: the pidfile passed to remote_span
         
-        host/port/user/agent/identity_file: see rexec
+        host/port/user/agent/identity: see rexec
     
     Returns:
         
@@ -529,7 +634,8 @@ def rcheck_pid(pidfile,
         port = port,
         user = user,
         agent = agent,
-        identity_file = identity_file
+        identity = identity,
+        server_key = server_key
         )
         
     if proc.wait():
@@ -548,14 +654,15 @@ def rstatus(pid, ppid,
         port = None, 
         user = None, 
         agent = None, 
-        identity_file = None):
+        identity = None,
+        server_key = None):
     """
     Check the status of a process spawned with remote_spawn.
     
     Parameters:
         pid/ppid: pid and parent-pid of the spawned process. See remote_check_pid
         
-        host/port/user/agent/identity_file: see rexec
+        host/port/user/agent/identity: see rexec
     
     Returns:
         
@@ -563,7 +670,8 @@ def rstatus(pid, ppid,
     """
 
     (out,err),proc = rexec(
-        "ps --pid %(pid)d -o pid | grep -c %(pid)d ; true" % {
+        # Check only by pid. pid+ppid does not always work (especially with sudo) 
+        " (( ps --pid %(pid)d -o pid | grep -c %(pid)d && echo 'wait')  || echo 'done' ) | tail -n 1" % {
             'ppid' : ppid,
             'pid' : pid,
         },
@@ -571,32 +679,32 @@ def rstatus(pid, ppid,
         port = port,
         user = user,
         agent = agent,
-        identity_file = identity_file
+        identity = identity,
+        server_key = server_key
         )
     
     if proc.wait():
         return NOT_STARTED
     
     status = False
-    if out:
-        try:
-            status = bool(int(out.strip()))
-        except:
-            if out or err:
-                logging.warn("Error checking remote status:\n%s%s\n", out, err)
-            # Ignore, many ways to fail that don't matter that much
-            return NOT_STARTED
+    if err:
+        if err.strip().find("Error, do this: mount -t proc none /proc") >= 0:
+            status = True
+    elif out:
+        status = (out.strip() == 'wait')
+    else:
+        return NOT_STARTED
     return RUNNING if status else FINISHED
-    
 
 @eintr_retry
-def rkill(pid, ppid, 
+def rkill(pid, ppid,
         host = None, 
         port = None, 
         user = None, 
         agent = None, 
         sudo = False,
-        identity_file = None, 
+        identity = None, 
+        server_key = None, 
         nowait = False):
     """
     Kill a process spawned with remote_spawn.
@@ -609,7 +717,7 @@ def rkill(pid, ppid,
         
         sudo: whether the command was run with sudo - careful killing like this.
         
-        host/port/user/agent/identity_file: see rexec
+        host/port/user/agent/identity: see rexec
     
     Returns:
         
@@ -650,12 +758,15 @@ fi
         port = port,
         user = user,
         agent = agent,
-        identity_file = identity_file
+        identity = identity,
+        server_key = server_key
         )
     
     # wait, don't leave zombies around
     proc.wait()
 
+    return (out, err), proc
+
 # POSIX
 def _communicate(self, input, timeout=None, err_on_timeout=True):
     read_set = []