X-Git-Url: http://git.onelab.eu/?a=blobdiff_plain;f=src%2Fneco%2Futil%2Fsshfuncs.py;h=e81c558d1973ce78ec15dfd33c11e32613535724;hb=87f7bd08605070569d7c28a94302b3e24f996b15;hp=dd04ba6229acb5e9bcb0a2273920c2de6c991873;hpb=b12a961a7225625c8112069e6e523e06cc4a181a;p=nepi.git diff --git a/src/neco/util/sshfuncs.py b/src/neco/util/sshfuncs.py index dd04ba62..e81c558d 100644 --- a/src/neco/util/sshfuncs.py +++ b/src/neco/util/sshfuncs.py @@ -12,8 +12,7 @@ import re import tempfile import hashlib -OPENSSH_HAS_PERSIST = None -CONTROL_PATH = "yyy_ssh_ctrl_path" +TRACE = os.environ.get("NEPI_TRACE", "false").lower() in ("true", "1", "on") if hasattr(os, "devnull"): DEV_NULL = os.devnull @@ -22,8 +21,6 @@ else: SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$') -hostbyname_cache = dict() - class STDOUT: """ Special value that when given to rspawn in stderr causes stderr to @@ -45,6 +42,17 @@ class NOT_STARTED: Process hasn't started running yet (this should be very rare) """ +hostbyname_cache = dict() + +def gethostbyname(host): + hostbyname = hostbyname_cache.get(host) + if not hostbyname: + hostbyname = socket.gethostbyname(host) + hostbyname_cache[host] = hostbyname + return hostbyname + +OPENSSH_HAS_PERSIST = None + def openssh_has_persist(): """ The ssh_config options ControlMaster and ControlPersist allow to reuse a same network connection for multiple ssh sessions. In this @@ -66,6 +74,59 @@ def openssh_has_persist(): OPENSSH_HAS_PERSIST = bool(vre.match(out)) return OPENSSH_HAS_PERSIST +def make_server_key_args(server_key, host, port): + """ Returns a reference to a temporary known_hosts file, to which + the server key has been added. 
+ + Make sure to hold onto the temp file reference until the process is + done with it + + :param server_key: the server public key + :type server_key: str + + :param host: the hostname + :type host: str + + :param port: the ssh port + :type port: str + + """ + if port is not None: + host = '%s:%s' % (host, str(port)) + + # Create a temporary server key file + tmp_known_hosts = tempfile.NamedTemporaryFile() + + hostbyname = gethostbyname(host) + + # Add the intended host key + tmp_known_hosts.write('%s,%s %s\n' % (host, hostbyname, server_key)) + + # If we're not in strict mode, add user-configured keys + if os.environ.get('NEPI_STRICT_AUTH_MODE',"").lower() not in ('1','true','on'): + user_hosts_path = '%s/.ssh/known_hosts' % (os.environ.get('HOME',""),) + if os.access(user_hosts_path, os.R_OK): + f = open(user_hosts_path, "r") + tmp_known_hosts.write(f.read()) + f.close() + + tmp_known_hosts.flush() + + return tmp_known_hosts + +def make_control_path(agent, forward_x11): + ctrl_path = "/tmp/nepi_ssh" + + if agent: + ctrl_path +="_a" + + if forward_x11: + ctrl_path +="_x" + + ctrl_path += "-%r@%h:%p" + + return ctrl_path + def shell_escape(s): """ Escapes strings so that they are safe to use as command-line arguments """ @@ -105,81 +166,63 @@ def eintr_retry(func): return func(*p, **kw) return rv -def make_connkey(user, host, port, x11, agent): - # It is important to consider the x11 and agent forwarding - # parameters when creating the connection key since the parameters - # used for the first ssh connection will determine the - # parameters of all subsequent connections using the same key - x11 = 1 if x11 else 0 - agent = 1 if agent else 0 - - connkey = repr((user, host, port, x11, agent) - ).encode("base64").strip().replace('/','.') - - if len(connkey) > 60: - connkey = hashlib.sha1(connkey).hexdigest() - return connkey - -def make_control_path(user, host, port, x11, agent): - connkey = make_connkey(user, host, port, x11, agent) - return '/tmp/%s_%s' % ( 
CONTROL_PATH, connkey, ) - def rexec(command, host, user, port = None, agent = True, sudo = False, stdin = None, identity = None, + server_key = None, env = None, tty = False, - x11 = False, timeout = None, retry = 0, err_on_timeout = True, connect_timeout = 30, - persistent = True): + persistent = True, + forward_x11 = False): """ Executes a remote command, returns ((stdout,stderr),process) """ + + tmp_known_hosts = None + hostip = gethostbyname(host) + args = ['ssh', '-C', # Don't bother with localhost. Makes test easier '-o', 'NoHostAuthenticationForLocalhost=yes', - # XXX: Possible security issue - # Avoid interactive requests to accept new host keys - '-o', 'StrictHostKeyChecking=no', '-o', 'ConnectTimeout=%d' % (int(connect_timeout),), '-o', 'ConnectionAttempts=3', '-o', 'ServerAliveInterval=30', '-o', 'TCPKeepAlive=yes', - '-l', user, host] + '-l', user, hostip or host] if persistent and openssh_has_persist(): - control_path = make_control_path(user, host, port, x11, agent) args.extend([ '-o', 'ControlMaster=auto', - '-o', 'ControlPath=%s' % control_path, + '-o', 'ControlPath=%s' % (make_control_path(agent, forward_x11),), '-o', 'ControlPersist=60' ]) + if agent: args.append('-A') + if port: args.append('-p%d' % port) + if identity: args.extend(('-i', identity)) + if tty: args.append('-t') - if sudo: - args.append('-t') - if x11: - args.append('-X') + args.append('-t') - if env: - export = '' - for envkey, envval in env.iteritems(): - export += '%s=%s ' % (envkey, envval) - command = export + command + if forward_x11: + args.append('-X') - if sudo: - command = "sudo " + command + if server_key: + # Create a temporary server key file + tmp_known_hosts = make_server_key_args(server_key, host, port) + args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)]) args.append(command) @@ -190,8 +233,15 @@ def rexec(command, host, user, stdin = subprocess.PIPE, stderr = subprocess.PIPE) + # attach tempfile object to the process, to make sure the file 
stays + # alive until the process is finished with it + proc._known_hosts = tmp_known_hosts + try: out, err = _communicate(proc, stdin, timeout, err_on_timeout) + if TRACE: + print "COMMAND host %s, command %s, out %s, error %s" % (host, " ".join(args), out, err) + if proc.poll(): if err.strip().startswith('ssh: ') or err.strip().startswith('mux_client_hello_exchange: '): # SSH error, can safely retry @@ -200,7 +250,11 @@ def rexec(command, host, user, # Probably timed out or plain failed but can retry continue break - except RuntimeError,e: + except RuntimeError, e: + if TRACE: + print "EXCEPTION host %s, command %s, out %s, error %s, exception TIMEOUT -> %s" % ( + host, " ".join(args), out, err, e.args) + if retry <= 0: raise retry -= 1 @@ -210,67 +264,242 @@ def rexec(command, host, user, def rcopy(source, dest, port = None, agent = True, - identity = None): + recursive = False, + identity = None, + server_key = None): """ - Copies file from/to remote sites. + Copies from/to remote sites. Source and destination should have the user and host encoded as per scp specs. + If source is a file object, a special mode will be used to + create the remote file with the same contents. + + If dest is a file object, the remote file (source) will be + read and written into dest. + + In these modes, recursive cannot be True. + Source can be a list of files to copy to a single destination, in which case it is advised that the destination be a folder.
""" - # Parse destination as <user>@<host>:<path> - if isinstance(dest, basestring) and ':' in dest: - remspec, path = dest.split(':',1) - elif isinstance(source, basestring) and ':' in source: - remspec, path = source.split(':',1) - else: - raise ValueError, "Both endpoints cannot be local" - user, host = remspec.rsplit('@',1) - - raw_string = r'''rsync -rlpcSz --timeout=900 ''' - raw_string += r''' -e 'ssh -o BatchMode=yes ''' - raw_string += r''' -o NoHostAuthenticationForLocalhost=yes ''' - raw_string += r''' -o StrictHostKeyChecking=no ''' - raw_string += r''' -o ConnectionAttempts=3 ''' - - if openssh_has_persist(): - control_path = make_control_path(user, host, port, False, agent) - raw_string += r''' -o ControlMaster=auto ''' - raw_string += r''' -o ControlPath=%s ''' % control_path - - if agent: - raw_string += r''' -A ''' - - if port: - raw_string += r''' -p %d ''' % port + if TRACE: + print "scp", source, dest - if identity: - raw_string += r''' -i "%s" ''' % identity + if isinstance(source, file) and source.tell() == 0: + source = source.name + elif hasattr(source, 'read'): + tmp = tempfile.NamedTemporaryFile() + while True: + buf = source.read(65536) + if buf: + tmp.write(buf) + else: + break + tmp.seek(0) + source = tmp.name - # closing -e 'ssh...' - raw_string += r''' ' ''' - - if isinstance(source,list): - source = ' '.join(source) + if isinstance(source, file) or isinstance(dest, file) \ + or hasattr(source, 'read') or hasattr(dest, 'write'): + assert not recursive + + # Parse source/destination as <user>@<host>:<path> + if isinstance(dest, basestring) and ':' in dest: + remspec, path = dest.split(':',1) + elif isinstance(source, basestring) and ':' in source: + remspec, path = source.split(':',1) + else: + raise ValueError, "Both endpoints cannot be local" + user,host = remspec.rsplit('@',1) + + tmp_known_hosts = None + hostip = gethostbyname(host) + + args = ['ssh', '-l', user, '-C', + # Don't bother with localhost.
Makes test easier + '-o', 'NoHostAuthenticationForLocalhost=yes', + '-o', 'ConnectTimeout=60', + '-o', 'ConnectionAttempts=3', + '-o', 'ServerAliveInterval=30', + '-o', 'TCPKeepAlive=yes', + hostip or host ] + + if openssh_has_persist(): + args.extend([ + '-o', 'ControlMaster=auto', + '-o', 'ControlPath=%s' % (make_control_path(agent, False),), + '-o', 'ControlPersist=60' ]) + + if port: + args.append('-P%d' % port) + + if identity: + args.extend(('-i', identity)) + + if server_key: + # Create a temporary server key file + tmp_known_hosts = make_server_key_args(server_key, host, port) + args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)]) + + if isinstance(source, file) or hasattr(source, 'read'): + args.append('cat > %s' % (shell_escape(path),)) + elif isinstance(dest, file) or hasattr(dest, 'write'): + args.append('cat %s' % (shell_escape(path),)) + else: + raise AssertionError, "Unreachable code reached! :-Q" + + # connects to the remote host and starts a remote connection + if isinstance(source, file): + proc = subprocess.Popen(args, + stdout = open('/dev/null','w'), + stderr = subprocess.PIPE, + stdin = source) + err = proc.stderr.read() + proc._known_hosts = tmp_known_hosts + eintr_retry(proc.wait)() + return ((None,err), proc) + elif isinstance(dest, file): + proc = subprocess.Popen(args, + stdout = open('/dev/null','w'), + stderr = subprocess.PIPE, + stdin = source) + err = proc.stderr.read() + proc._known_hosts = tmp_known_hosts + eintr_retry(proc.wait)() + return ((None,err), proc) + elif hasattr(source, 'read'): + # file-like (but not file) source + proc = subprocess.Popen(args, + stdout = open('/dev/null','w'), + stderr = subprocess.PIPE, + stdin = subprocess.PIPE) + + buf = None + err = [] + while True: + if not buf: + buf = source.read(4096) + if not buf: + #EOF + break + + rdrdy, wrdy, broken = select.select( + [proc.stderr], + [proc.stdin], + [proc.stderr,proc.stdin]) + + if proc.stderr in rdrdy: + # use os.read for fully 
unbuffered behavior + err.append(os.read(proc.stderr.fileno(), 4096)) + + if proc.stdin in wrdy: + proc.stdin.write(buf) + buf = None + + if broken: + break + proc.stdin.close() + err.append(proc.stderr.read()) + + proc._known_hosts = tmp_known_hosts + eintr_retry(proc.wait)() + return ((None,''.join(err)), proc) + elif hasattr(dest, 'write'): + # file-like (but not file) dest + proc = subprocess.Popen(args, + stdout = subprocess.PIPE, + stderr = subprocess.PIPE, + stdin = open('/dev/null','w')) + + buf = None + err = [] + while True: + rdrdy, wrdy, broken = select.select( + [proc.stderr, proc.stdout], + [], + [proc.stderr, proc.stdout]) + + if proc.stderr in rdrdy: + # use os.read for fully unbuffered behavior + err.append(os.read(proc.stderr.fileno(), 4096)) + + if proc.stdout in rdrdy: + # use os.read for fully unbuffered behavior + buf = os.read(proc.stdout.fileno(), 4096) + dest.write(buf) + + if not buf: + #EOF + break + + if broken: + break + err.append(proc.stderr.read()) + + proc._known_hosts = tmp_known_hosts + eintr_retry(proc.wait)() + return ((None,''.join(err)), proc) + else: + raise AssertionError, "Unreachable code reached! :-Q" else: - source = '"%s"' % source + # Parse destination as <user>@<host>:<path> + if isinstance(dest, basestring) and ':' in dest: + remspec, path = dest.split(':',1) + elif isinstance(source, basestring) and ':' in source: + remspec, path = source.split(':',1) + else: + raise ValueError, "Both endpoints cannot be local" + user,host = remspec.rsplit('@',1) + + # plain scp + tmp_known_hosts = None + + args = ['scp', '-q', '-p', '-C', + # Don't bother with localhost.
Makes test easier + '-o', 'NoHostAuthenticationForLocalhost=yes', + '-o', 'ConnectTimeout=60', + '-o', 'ConnectionAttempts=3', + '-o', 'ServerAliveInterval=30', + '-o', 'TCPKeepAlive=yes' ] + + if port: + args.append('-P%d' % port) + + if recursive: + args.append('-r') + + if identity: + args.extend(('-i', identity)) + + if server_key: + # Create a temporary server key file + tmp_known_hosts = make_server_key_args(server_key, host, port) + args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)]) + + if isinstance(source,list): + args.extend(source) + else: + if openssh_has_persist(): + args.extend([ + '-o', 'ControlMaster=auto', + '-o', 'ControlPath=%s' % (make_control_path(agent, False),) + ]) + args.append(source) - raw_string += r''' %s ''' % source - raw_string += r''' %s ''' % dest + args.append(dest) - # connects to the remote host and starts a remote connection - proc = subprocess.Popen(raw_string, - shell=True, - stdout = subprocess.PIPE, - stdin = subprocess.PIPE, - stderr = subprocess.PIPE) - - comm = proc.communicate() - eintr_retry(proc.wait)() - return (comm, proc) + # connects to the remote host and starts a remote connection + proc = subprocess.Popen(args, + stdout = subprocess.PIPE, + stdin = subprocess.PIPE, + stderr = subprocess.PIPE) + proc._known_hosts = tmp_known_hosts + + (out, err) = proc.communicate() + eintr_retry(proc.wait)() + return ((out, err), proc) def rspawn(command, pidfile, stdout = '/dev/null', @@ -278,12 +507,13 @@ def rspawn(command, pidfile, stdin = '/dev/null', home = None, create_home = False, + sudo = False, host = None, port = None, user = None, agent = None, - sudo = False, identity = None, + server_key = None, tty = False): """ Spawn a remote command such that it will continue working asynchronously. @@ -319,41 +549,38 @@ def rspawn(command, pidfile, stderr = '&1' else: stderr = ' ' + stderr - - #XXX: ppid is always 1!!! + daemon_command = '{ { %(command)s > %(stdout)s 2>%(stderr)s < %(stdin)s & } ; echo $! 
1 > %(pidfile)s ; }' % { 'command' : command, - 'pidfile' : pidfile, - + 'pidfile' : shell_escape(pidfile), 'stdout' : stdout, 'stderr' : stderr, 'stdin' : stdin, } - cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c '%(command)s' " % { - 'command' : daemon_command, - + cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c %(command)s " % { + 'command' : shell_escape(daemon_command), 'sudo' : 'sudo -S' if sudo else '', - - 'pidfile' : pidfile, - 'gohome' : 'cd %s ; ' % home if home else '', - 'create' : 'mkdir -p %s ; ' % home if create_home else '', + 'pidfile' : shell_escape(pidfile), + 'gohome' : 'cd %s ; ' % (shell_escape(home),) if home else '', + 'create' : 'mkdir -p %s ; ' % (shell_escape(home),) if create_home and home else '', } - (out,err), proc = rexec( + (out,err),proc = rexec( cmd, host = host, port = port, user = user, agent = agent, identity = identity, - tty = tty + server_key = server_key, + tty = tty , ) if proc.wait(): - raise RuntimeError, "Failed to set up application: %s %s" % (out,err,) + raise RuntimeError, "Failed to set up application on host %s: %s %s" % (host, out,err,) - return (out,err),proc + return ((out, err), proc) @eintr_retry def rcheckpid(pidfile, @@ -361,7 +588,8 @@ def rcheckpid(pidfile, port = None, user = None, agent = None, - identity = None): + identity = None, + server_key = None): """ Check the pidfile of a process spawned with remote_spawn. @@ -384,7 +612,8 @@ def rcheckpid(pidfile, port = port, user = user, agent = agent, - identity = identity + identity = identity, + server_key = server_key ) if proc.wait(): @@ -403,7 +632,8 @@ def rstatus(pid, ppid, port = None, user = None, agent = None, - identity = None): + identity = None, + server_key = None): """ Check the status of a process spawned with remote_spawn. 
@@ -417,9 +647,9 @@ def rstatus(pid, ppid, One of NOT_STARTED, RUNNING, FINISHED """ - # XXX: ppid unused (out,err),proc = rexec( - "ps --pid %(pid)d -o pid | grep -c %(pid)d ; true" % { + # Check only by pid. pid+ppid does not always work (especially with sudo) + " (( ps --pid %(pid)d -o pid | grep -c %(pid)d && echo 'wait') || echo 'done' ) | tail -n 1" % { 'ppid' : ppid, 'pid' : pid, }, @@ -427,23 +657,22 @@ def rstatus(pid, ppid, port = port, user = user, agent = agent, - identity = identity + identity = identity, + server_key = server_key ) if proc.wait(): return NOT_STARTED status = False - if out: - try: - status = bool(int(out.strip())) - except: - if out or err: - logging.warn("Error checking remote status:\n%s%s\n", out, err) - # Ignore, many ways to fail that don't matter that much - return NOT_STARTED + if err: + if err.strip().find("Error, do this: mount -t proc none /proc") >= 0: + status = True + elif out: + status = (out.strip() == 'wait') + else: + return NOT_STARTED return RUNNING if status else FINISHED - @eintr_retry def rkill(pid, ppid, @@ -453,6 +682,7 @@ def rkill(pid, ppid, agent = None, sudo = False, identity = None, + server_key = None, nowait = False): """ Kill a process spawned with remote_spawn. @@ -506,12 +736,15 @@ fi port = port, user = user, agent = agent, - identity = identity + identity = identity, + server_key = server_key ) # wait, don't leave zombies around proc.wait() + return (out, err), proc + # POSIX def _communicate(self, input, timeout=None, err_on_timeout=True): read_set = []