import base64
import errno
import hashlib
import logging
import os
import os.path
import re
import select
import signal
import socket
import subprocess
import threading
import time
import tempfile

logger = logging.getLogger("sshfuncs")


def log(msg, level, out = None, err = None):
    """ Logs ``msg`` at ``level`` on the "sshfuncs" logger, appending the
    captured stdout (``out``) and stderr (``err``) when they are non-empty.
    """
    if out:
        msg += " - OUT: %s " % out

    if err:
        msg += " - ERROR: %s " % err

    logger.log(level, msg)


if hasattr(os, "devnull"):
    DEV_NULL = os.devnull
else:
    DEV_NULL = "/dev/null"

# Strings made only of these characters can be passed to a shell unquoted.
SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')


class STDOUT:
    """ Special value that when given to rspawn in stderr causes stderr to
    redirect to whatever stdout was redirected to.
    """


class RUNNING:
    """ Process is still running """


class FINISHED:
    """ Process is finished """


class NOT_STARTED:
    """ Process hasn't started running yet (this should be very rare) """


hostbyname_cache = dict()
hostbyname_cache_lock = threading.Lock()


def gethostbyname(host):
    """ Resolves ``host`` with socket.gethostbyname, memoizing the result in
    the module-level ``hostbyname_cache``.

    Resolution failures propagate as raised by socket.gethostbyname.
    """
    hostbyname = hostbyname_cache.get(host)
    if not hostbyname:
        with hostbyname_cache_lock:
            # Another thread may have filled the entry while we waited.
            hostbyname = hostbyname_cache.get(host)
            if not hostbyname:
                hostbyname = socket.gethostbyname(host)
                hostbyname_cache[host] = hostbyname

                msg = " Added hostbyname %s - %s " % (host, hostbyname)
                log(msg, logging.DEBUG)

    return hostbyname


OPENSSH_HAS_PERSIST = None


def openssh_has_persist():
    """ The ssh_config options ControlMaster and ControlPersist allow to
    reuse a same network connection for multiple ssh sessions. In this
    way limitations on number of open ssh connections can be bypassed.
    However, older versions of openSSH do not support this feature.
    This function is used to determine if ssh connection persist features
    can be used.

    The answer is computed once (by parsing the version banner printed by
    ``ssh -v``) and cached in OPENSSH_HAS_PERSIST.
    """
    global OPENSSH_HAS_PERSIST
    if OPENSSH_HAS_PERSIST is None:
        with open(DEV_NULL, "r") as devnull:
            proc = subprocess.Popen(["ssh", "-v"],
                    stdout = subprocess.PIPE,
                    stderr = subprocess.STDOUT,
                    stdin = devnull,
                    # text output so the str regex below also works on py3
                    universal_newlines = True,
                    )
            out, _ = proc.communicate()

        vre = re.compile(r'OpenSSH_(?:[6-9]|5[.][8-9]|5[.][1-9][0-9]|[1-9][0-9]).*', re.I)
        OPENSSH_HAS_PERSIST = bool(vre.match(out))
    return OPENSSH_HAS_PERSIST


def make_server_key_args(server_key, host, port):
    """ Returns a reference to a temporary known_hosts file, to which
    the server key has been added.

    Make sure to hold onto the temp file reference until the process is
    done with it

    The key is registered for ``host`` and for its resolved address. For a
    non-default port (anything truthy other than 22) the entries use the
    known_hosts ``[host]:port`` syntax.

    :param server_key: the server public key
    :type server_key: str

    :param host: the hostname
    :type host: str

    :param port: the ssh port
    :type port: str
    """
    # Resolve the bare host name: resolving "host:port" fails.
    hostbyname = gethostbyname(host)

    names = [host]
    if hostbyname and hostbyname != host:
        names.append(hostbyname)

    if port and int(port) != 22:
        names = ['[%s]:%d' % (name, int(port)) for name in names]

    # Create a temporary server key file (text mode: we write str)
    tmp_known_hosts = tempfile.NamedTemporaryFile(mode = "w+")

    # Add the intended host key
    tmp_known_hosts.write('%s %s\n' % (','.join(names), server_key))

    # If we're not in strict mode, add user-configured keys
    if os.environ.get('NEPI_STRICT_AUTH_MODE', "").lower() not in ('1', 'true', 'on'):
        user_hosts_path = '%s/.ssh/known_hosts' % (os.environ.get('HOME', ""),)
        if os.access(user_hosts_path, os.R_OK):
            with open(user_hosts_path, "r") as f:
                tmp_known_hosts.write(f.read())

    tmp_known_hosts.flush()

    return tmp_known_hosts


def make_control_path(agent, forward_x11):
    """ Returns the ssh ControlPath template for multiplexed connections.

    Separate sockets are used depending on whether agent forwarding ("_a")
    and/or X11 forwarding ("_x") are enabled; %r, %h and %p are expanded by
    ssh to the remote user, host and port.
    """
    suffix = ""
    if agent:
        suffix += "_a"
    if forward_x11:
        suffix += "_x"
    return "/tmp/nepi_ssh" + suffix + "-%r@%h:%p"


def shell_escape(s):
    """ Escapes strings so that they are safe to use as command-line
    arguments

    Strings matching SHELL_SAFE are returned unchanged. Anything else is
    wrapped in single quotes; quote characters and non-printable characters
    (other than \\r, \\n and \\t) are emitted as bash $'\\xNN' sequences.
    """
    if SHELL_SAFE.match(s):
        # safe string - no escaping needed
        return s

    def escp(c):
        printable = 32 <= ord(c) < 127 or c in ('\r', '\n', '\t')
        if printable and c not in ("'", '"'):
            return c
        return "'$'\\x%02x''" % (ord(c),)

    return "'%s'" % (''.join(escp(c) for c in s),)


def eintr_retry(func):
    """Retries a function invocation when a EINTR occurs

    By default up to 4 attempts swallow EINTR (from select.error,
    socket.error or OSError); then one last unguarded call is made.
    Passing ``_retry=True`` to the wrapped function disables the guarded
    attempts so the function is called exactly once.
    """
    import functools
    @functools.wraps(func)
    def rv(*p, **kw):
        retry = kw.pop("_retry", False)
        for i in xrange(0 if retry else 4):
            try:
                return func(*p, **kw)
            except (select.error, socket.error) as args:
                if args.args and args.args[0] == errno.EINTR:
                    continue
                else:
                    raise
            except OSError as e:
                if e.errno == errno.EINTR:
                    continue
                else:
                    raise
        else:
            return func(*p, **kw)
    return rv


def rexec(command, host, user,
        port = None,
        agent = True,
        sudo = False,
        stdin = None,
        identity = None,
        server_key = None,
        env = None,
        tty = False,
        timeout = None,
        retry = 3,
        err_on_timeout = True,
        connect_timeout = 30,
        persistent = True,
        forward_x11 = False,
        strict_host_checking = True):
    """
    Executes a remote command, returns ((stdout,stderr),process)

    ``retry`` is the total number of attempts (at least one is always made).
    Any attempt ending with a non-zero exit status is retried after sleeping
    2*attempt seconds; the last attempt's result is returned either way.
    A timeout (RuntimeError raised by _communicate when ``err_on_timeout``)
    is retried too, and re-raised if it happens on the last attempt.

    ``sudo`` is accepted for interface compatibility but is not used here.
    """
    tmp_known_hosts = None
    hostip = gethostbyname(host)

    args = ['ssh', '-C',
            # Don't bother with localhost. Makes test easier
            '-o', 'NoHostAuthenticationForLocalhost=yes',
            '-o', 'ConnectTimeout=%d' % (int(connect_timeout),),
            '-o', 'ConnectionAttempts=3',
            '-o', 'ServerAliveInterval=30',
            '-o', 'TCPKeepAlive=yes',
            '-l', user, hostip or host]

    if persistent and openssh_has_persist():
        args.extend([
            '-o', 'ControlMaster=auto',
            '-o', 'ControlPath=%s' % (make_control_path(agent, forward_x11),),
            '-o', 'ControlPersist=60' ])

    if not strict_host_checking:
        # Do not check for Host key. Unsafe.
        args.extend(['-o', 'StrictHostKeyChecking=no'])

    if agent:
        args.append('-A')

    if port:
        # int() so both int and str ports are accepted
        args.append('-p%d' % int(port))

    if identity:
        args.extend(('-i', identity))

    if tty:
        # Giving -t twice forces tty allocation even when ssh has no local tty
        args.append('-t')
        args.append('-t')

    if forward_x11:
        args.append('-X')

    if server_key:
        # Create a temporary server key file
        tmp_known_hosts = make_server_key_args(server_key, host, port)
        args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])

    args.append(command)

    attempts = max(1, retry)
    # Bound up-front so the except handler and the final return never see
    # unbound names, even if every attempt times out.
    out = err = None
    proc = None

    for x in xrange(attempts):
        last_attempt = (x == attempts - 1)

        # connects to the remote host and starts a remote connection
        proc = subprocess.Popen(args,
                env = env,
                stdout = subprocess.PIPE,
                stdin = subprocess.PIPE,
                stderr = subprocess.PIPE)

        # attach tempfile object to the process, to make sure the file stays
        # alive until the process is finished with it
        proc._known_hosts = tmp_known_hosts

        try:
            out, err = _communicate(proc, stdin, timeout, err_on_timeout)
        except RuntimeError as e:
            # _communicate raises RuntimeError("Operation timed out",
            # errcode, stdout, stderr): keep the partial output.
            if len(e.args) >= 4:
                out, err = e.args[2], e.args[3]
            msg = " rexec EXCEPTION - host %s - command %s - TIMEOUT -> %s" % (host, " ".join(args), e.args)
            log(msg, logging.DEBUG, out, err)
            if last_attempt:
                raise
            continue

        msg = " rexec - host %s - command %s " % (host, " ".join(args))
        log(msg, logging.DEBUG, out, err)

        if proc.poll() and not last_attempt:
            # Both ssh-level errors ("ssh: ...", "mux_client_hello_exchange:
            # ...") and plain command failures are retried.
            t = x*2
            msg = "SLEEPING %d ... ATEMPT %d - host %s - command %s " % (
                    t, x, host, " ".join(args))
            log(msg, logging.DEBUG)

            time.sleep(t)
            continue
        break

    return ((out, err), proc)


def rcopy(source, dest,
        port = None,
        agent = True,
        recursive = False,
        identity = None,
        server_key = None,
        retry = 3,
        strict_host_checking = True):
    """
    Copies from/to remote sites.

    Source and destination should have the user and host encoded
    as per scp specs.

    If source is a file object (or any object with a ``read`` method), it
    is resolved to a local path first (its name, if it is an unread file,
    or a temporary copy of its remaining contents) and copied with scp.

    If dest is a file object (or any object with a ``write`` method), the
    remote file (source) will be read through ``ssh ... cat`` and written
    into dest. In this mode, recursive cannot be True.

    Source can be a list of files to copy to a single destination,
    in which case it is advised that the destination be a folder.

    Returns ((stdout, stderr), process); stdout is None in stream mode.
    """
    if isinstance(source, file) and source.tell() == 0:
        source = source.name
    elif hasattr(source, 'read'):
        tmp = tempfile.NamedTemporaryFile()
        while True:
            buf = source.read(65536)
            if buf:
                tmp.write(buf)
            else:
                break
        tmp.flush()
        tmp.seek(0)
        # ``tmp`` stays referenced until this function returns
        source = tmp.name

    # From here on ``source`` is never a stream, only ``dest`` can be.
    if isinstance(dest, file) or hasattr(dest, 'write'):
        if recursive:
            raise ValueError("recursive copy is not supported with a file-like destination")

        # Parse source as <user>@<host>:<path>
        if isinstance(dest, basestring) and ':' in dest:
            remspec, path = dest.split(':', 1)
        elif isinstance(source, basestring) and ':' in source:
            remspec, path = source.split(':', 1)
        else:
            raise ValueError("Both endpoints cannot be local")
        user, host = remspec.rsplit('@', 1)

        tmp_known_hosts = None
        hostip = gethostbyname(host)

        args = ['ssh', '-l', user, '-C',
                # Don't bother with localhost. Makes test easier
                '-o', 'NoHostAuthenticationForLocalhost=yes',
                '-o', 'ConnectTimeout=60',
                '-o', 'ConnectionAttempts=3',
                '-o', 'ServerAliveInterval=30',
                '-o', 'TCPKeepAlive=yes',
                hostip or host ]

        if openssh_has_persist():
            args.extend([
                '-o', 'ControlMaster=auto',
                '-o', 'ControlPath=%s' % (make_control_path(agent, False),),
                '-o', 'ControlPersist=60' ])

        if port:
            # ssh takes the port with lowercase -p (-P is scp's flag)
            args.append('-p%d' % int(port))

        if identity:
            args.extend(('-i', identity))

        if server_key:
            # Create a temporary server key file
            tmp_known_hosts = make_server_key_args(server_key, host, port)
            args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])

        if not strict_host_checking:
            # Do not check for Host key. Unsafe.
            args.extend(['-o', 'StrictHostKeyChecking=no'])

        args.append('cat %s' % (shell_escape(path),))

        if isinstance(dest, file):
            # Real file: let the child write straight into it
            with open(DEV_NULL, 'r') as devnull:
                proc = subprocess.Popen(args,
                        stdout = dest,
                        stderr = subprocess.PIPE,
                        stdin = devnull)
            proc._known_hosts = tmp_known_hosts
            err = proc.stderr.read()
            eintr_retry(proc.wait)()
            return ((None, err), proc)

        # file-like (but not file) dest: pump stdout into dest.write
        with open(DEV_NULL, 'r') as devnull:
            proc = subprocess.Popen(args,
                    stdout = subprocess.PIPE,
                    stderr = subprocess.PIPE,
                    stdin = devnull)
        proc._known_hosts = tmp_known_hosts

        err = []
        read_set = [proc.stderr, proc.stdout]
        while True:
            rdrdy, _, broken = select.select(read_set, [], read_set)
            if proc.stderr in rdrdy:
                # use os.read for fully unbuffered behavior
                chunk = os.read(proc.stderr.fileno(), 4096)
                if chunk:
                    err.append(chunk)
                else:
                    # stderr EOF: stop selecting on it to avoid spinning
                    read_set.remove(proc.stderr)
            if proc.stdout in rdrdy:
                # use os.read for fully unbuffered behavior
                buf = os.read(proc.stdout.fileno(), 4096)
                if not buf:
                    #EOF
                    break
                dest.write(buf)
            if broken:
                break
        err.append(proc.stderr.read())
        eintr_retry(proc.wait)()
        return ((None, ''.join(err)), proc)

    # Parse destination as <user>@<host>:<path>
    if isinstance(dest, basestring) and ':' in dest:
        remspec, path = dest.split(':', 1)
    elif isinstance(source, basestring) and ':' in source:
        remspec, path = source.split(':', 1)
    else:
        raise ValueError("Both endpoints cannot be local")
    user, host = remspec.rsplit('@', 1)

    # plain scp
    tmp_known_hosts = None

    args = ['scp', '-q', '-p', '-C',
            # Don't bother with localhost. Makes test easier
            '-o', 'NoHostAuthenticationForLocalhost=yes',
            '-o', 'ConnectTimeout=60',
            '-o', 'ConnectionAttempts=3',
            '-o', 'ServerAliveInterval=30',
            '-o', 'TCPKeepAlive=yes' ]

    if port:
        args.append('-P%d' % int(port))

    if recursive:
        args.append('-r')

    if identity:
        args.extend(('-i', identity))

    if server_key:
        # Create a temporary server key file
        tmp_known_hosts = make_server_key_args(server_key, host, port)
        args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])

    if not strict_host_checking:
        # Do not check for Host key. Unsafe.
        args.extend(['-o', 'StrictHostKeyChecking=no'])

    if isinstance(source, list):
        args.extend(source)
    else:
        if openssh_has_persist():
            args.extend([
                '-o', 'ControlMaster=auto',
                '-o', 'ControlPath=%s' % (make_control_path(agent, False),)
                ])
        args.append(source)

    args.append(dest)

    attempts = max(1, retry)
    out = err = None
    proc = None

    for x in xrange(attempts):
        # connects to the remote host and starts a remote connection
        proc = subprocess.Popen(args,
                stdout = subprocess.PIPE,
                stdin = subprocess.PIPE,
                stderr = subprocess.PIPE)

        # attach tempfile object to the process, to make sure the file stays
        # alive until the process is finished with it
        proc._known_hosts = tmp_known_hosts

        (out, err) = proc.communicate()
        eintr_retry(proc.wait)()
        msg = " rcopy - host %s - command %s " % (host, " ".join(args))
        log(msg, logging.DEBUG, out, err)

        if proc.poll() and x < attempts - 1:
            t = x*2
            msg = "SLEEPING %d ... ATEMPT %d - host %s - command %s " % (
                    t, x, host, " ".join(args))
            log(msg, logging.DEBUG)

            time.sleep(t)
            continue
        break

    return ((out, err), proc)


def rspawn(command, pidfile,
        stdout = '/dev/null',
        stderr = STDOUT,
        stdin = '/dev/null',
        home = None,
        create_home = False,
        sudo = False,
        host = None,
        port = None,
        user = None,
        agent = None,
        identity = None,
        server_key = None,
        tty = False):
    """
    Spawn a remote command such that it will continue working asynchronously.

    Parameters:
        command: the command to run - it should be a single line.

        pidfile: path of a (ideally unique to this task) pidfile for tracking the process.

        stdout: path of a file to redirect standard output to - must be a string.
            Defaults to /dev/null
        stderr: path of a file to redirect standard error to - string or the special STDOUT value
            to redirect to the same file stdout was redirected to. Defaults to STDOUT.
        stdin: path of a file with input to be piped into the command's standard input

        home: path of a folder to use as working directory - should exist, unless you specify create_home

        create_home: if True, the home folder will be created first with mkdir -p

        sudo: whether the command needs to be executed as root

        host/port/user/agent/identity: see rexec

    Returns:
        (stdout, stderr), process

        Of the spawning process, which only captures errors at spawning time.
        Usually only useful for diagnostics.

    Raises:
        RuntimeError if the spawning ssh command exits with a non-zero status.

    Note: command, stdout, stderr and stdin are inserted into the remote
    shell command unescaped; pidfile and home are shell-escaped.
    """
    # Start process in a "daemonized" way, using nohup and heavy
    # stdin/out redirection to avoid connection issues
    if stderr is STDOUT:
        stderr = '&1'
    else:
        stderr = ' ' + stderr

    daemon_command = '{ { %(command)s  > %(stdout)s 2>%(stderr)s < %(stdin)s & } ; echo $! 1 > %(pidfile)s ; }' % {
        'command' : command,
        'pidfile' : shell_escape(pidfile),
        'stdout' : stdout,
        'stderr' : stderr,
        'stdin' : stdin,
    }

    cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c %(command)s " % {
            'command' : shell_escape(daemon_command),
            'sudo' : 'sudo -S' if sudo else '',
            'pidfile' : shell_escape(pidfile),
            'gohome' : 'cd %s ; ' % (shell_escape(home),) if home else '',
            'create' : 'mkdir -p %s ; ' % (shell_escape(home),) if create_home and home else '',
        }

    (out, err), proc = rexec(
        cmd,
        host = host,
        port = port,
        user = user,
        agent = agent,
        identity = identity,
        server_key = server_key,
        tty = tty ,
        )

    if proc.wait():
        raise RuntimeError("Failed to set up application on host %s: %s %s" % (host, out, err,))

    return ((out, err), proc)


@eintr_retry
def rcheckpid(pidfile,
        host = None,
        port = None,
        user = None,
        agent = None,
        identity = None,
        server_key = None):
    """
    Check the pidfile of a process spawned with remote_spawn.

    Parameters:
        pidfile: the pidfile passed to remote_span

        host/port/user/agent/identity: see rexec

    Returns:

        A [pid, ppid] list useful for calling remote_status and remote_kill,
        or None if the pidfile isn't valid yet (maybe the process is still starting).
    """
    # shell-escaped exactly like rspawn does when it writes the pidfile
    (out, err), proc = rexec(
        "cat %(pidfile)s" % {
            'pidfile' : shell_escape(pidfile),
        },
        host = host,
        port = port,
        user = user,
        agent = agent,
        identity = identity,
        server_key = server_key
        )

    if proc.wait():
        return None

    if out:
        try:
            return [int(field) for field in out.strip().split(' ', 1)]
        except ValueError:
            # Ignore, many ways to fail that don't matter that much
            return None


@eintr_retry
def rstatus(pid, ppid,
        host = None,
        port = None,
        user = None,
        agent = None,
        identity = None,
        server_key = None):
    """
    Check the status of a process spawned with remote_spawn.

    Parameters:
        pid/ppid: pid and parent-pid of the spawned process. See remote_check_pid

        host/port/user/agent/identity: see rexec

    Returns:

        One of NOT_STARTED, RUNNING, FINISHED
    """
    (out, err), proc = rexec(
        # Check only by pid. pid+ppid does not always work (especially with sudo)
        " (( ps --pid %(pid)d -o pid | grep -c %(pid)d && echo 'wait')  || echo 'done' ) | tail -n 1" % {
            'ppid' : ppid,
            'pid' : pid,
        },
        host = host,
        port = port,
        user = user,
        agent = agent,
        identity = identity,
        server_key = server_key
        )

    if proc.wait():
        return NOT_STARTED

    status = False
    if err:
        # A missing /proc makes ps complain; treat the process as running
        if err.strip().find("Error, do this: mount -t proc none /proc") >= 0:
            status = True
    elif out:
        status = (out.strip() == 'wait')
    else:
        return NOT_STARTED
    return RUNNING if status else FINISHED


@eintr_retry
def rkill(pid, ppid,
        host = None,
        port = None,
        user = None,
        agent = None,
        sudo = False,
        identity = None,
        server_key = None,
        nowait = False):
    """
    Kill a process spawned with remote_spawn.

    Sends SIGTERM to the process group, the process and its direct children,
    re-sending it every ~2 seconds for up to ~60 seconds while the process
    is still alive; if it survives that, it sends SIGKILL.

    Parameters:
        pid/ppid: pid and parent-pid of the spawned process. See remote_check_pid

        sudo: whether the command was run with sudo - careful killing like this.

        host/port/user/agent/identity: see rexec

        nowait: if True, the kill script is run in the remote background and
            this call does not wait for it to finish.

    Returns:
        (stdout, stderr), process of the ssh invocation.
    """
    subkill = "$(ps --ppid %(pid)d -o pid h)" % { 'pid' : pid }
    cmd = """
SUBKILL="%(subkill)s" ;
%(sudo)s kill -- -%(pid)d $SUBKILL || /bin/true
%(sudo)s kill %(pid)d $SUBKILL || /bin/true
for x in 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 ; do
    sleep 0.2
    if [ `ps --pid %(pid)d -o pid | grep -c %(pid)d` == '0' ]; then
        break
    else
        %(sudo)s kill -- -%(pid)d $SUBKILL || /bin/true
        %(sudo)s kill %(pid)d $SUBKILL || /bin/true
    fi
    sleep 1.8
done
if [ `ps --pid %(pid)d -o pid | grep -c %(pid)d` != '0' ]; then
    %(sudo)s kill -9 -- -%(pid)d $SUBKILL || /bin/true
    %(sudo)s kill -9 %(pid)d $SUBKILL || /bin/true
fi
"""
    if nowait:
        cmd = "( %s ) >/dev/null 2>/dev/null </dev/null &" % (cmd,)

    # NOTE(review): from here to the start of _communicate's main loop the
    # reviewed copy of this file had lost its text (everything between
    # '</dev/null' and the next '>' was stripped). This part is
    # reconstructed from the placeholders used in ``cmd`` and from the
    # sibling functions; confirm against version control.
    (out, err), proc = rexec(
        cmd % {
            'ppid' : ppid,
            'pid' : pid,
            'sudo' : 'sudo -S' if sudo else '',
            'subkill' : subkill,
        },
        host = host,
        port = port,
        user = user,
        agent = agent,
        identity = identity,
        server_key = server_key
        )

    # wait, don't leave zombies around
    proc.wait()

    return (out, err), proc


# POSIX
def _communicate(self, input, timeout=None, err_on_timeout=True):
    """ Timeout-aware variant of the POSIX subprocess.Popen._communicate.

    ``self`` is the Popen object (called as ``_communicate(proc, ...)``).
    Writes ``input`` to its stdin and collects stdout/stderr until the
    pipes close.

    Once ``timeout`` seconds elapse the process is sent SIGTERM, then
    SIGKILL after the kill deadline, and waiting stops at the bail deadline.
    If the timeout was hit and ``err_on_timeout`` is true, raises
    RuntimeError("Operation timed out", returncode, stdout, stderr);
    otherwise returns (stdout, stderr).
    """
    read_set = []
    write_set = []
    stdout = None # Return
    stderr = None # Return

    killed = False

    if timeout is not None:
        # NOTE(review): reconstructed (see rkill); the grace periods after
        # ``timelimit`` are assumptions - confirm against version control.
        timelimit = time.time() + timeout
        killtime = timelimit + 4
        bailtime = timelimit + 8

    if self.stdin:
        # Flush stdio buffer.  This might block, if the user has
        # been writing to .stdin in an uncontrolled fashion.
        self.stdin.flush()
        if input:
            write_set.append(self.stdin)
        else:
            self.stdin.close()
    if self.stdout:
        read_set.append(self.stdout)
        stdout = []
    if self.stderr:
        read_set.append(self.stderr)
        stderr = []

    input_offset = 0
    while read_set or write_set:
        if timeout is not None:
            curtime = time.time()
            if curtime > timelimit:
                if curtime > bailtime:
                    break
                elif curtime > killtime:
                    signum = signal.SIGKILL
                else:
                    signum = signal.SIGTERM
                # Lets kill it - unless it was already reaped, in which case
                # the pid may have been recycled
                if self.poll() is None:
                    os.kill(self.pid, signum)
                killed = True
                select_timeout = 0.5
            else:
                select_timeout = timelimit - curtime + 0.1
        else:
            select_timeout = 1.0

        if select_timeout > 1.0:
            select_timeout = 1.0

        try:
            rlist, wlist, xlist = select.select(read_set, write_set, [], select_timeout)
        except select.error as e:
            if e.args[0] != errno.EINTR:
                raise
            else:
                continue

        if not rlist and not wlist and not xlist and self.poll() is not None:
            # timeout and process exited, say bye
            break

        if self.stdin in wlist:
            # When select has indicated that the file is writable,
            # we can write up to PIPE_BUF bytes without risk
            # blocking.  POSIX defines PIPE_BUF >= 512
            bytes_written = os.write(self.stdin.fileno(),
                    buffer(input, input_offset, 512))
            input_offset += bytes_written
            if input_offset >= len(input):
                self.stdin.close()
                write_set.remove(self.stdin)

        if self.stdout in rlist:
            data = os.read(self.stdout.fileno(), 1024)
            if data == "":
                self.stdout.close()
                read_set.remove(self.stdout)
            stdout.append(data)

        if self.stderr in rlist:
            data = os.read(self.stderr.fileno(), 1024)
            if data == "":
                self.stderr.close()
                read_set.remove(self.stderr)
            stderr.append(data)

    # All data exchanged.  Translate lists into strings.
    if stdout is not None:
        stdout = ''.join(stdout)
    if stderr is not None:
        stderr = ''.join(stderr)

    # Translate newlines, if requested.  We cannot let the file
    # object do the translation: It is based on stdio, which is
    # impossible to combine with select (unless forcing no
    # buffering).
    if self.universal_newlines and hasattr(file, 'newlines'):
        if stdout:
            stdout = self._translate_newlines(stdout)
        if stderr:
            stderr = self._translate_newlines(stderr)

    if killed and err_on_timeout:
        errcode = self.poll()
        raise RuntimeError("Operation timed out", errcode, stdout, stderr)
    else:
        if killed:
            self.poll()
        else:
            self.wait()
        return (stdout, stderr)