X-Git-Url: http://git.onelab.eu/?a=blobdiff_plain;f=src%2Fnepi%2Futil%2Fserver.py;h=73fa84c07d308d94da188b31e24bee44872d0510;hb=d8557fdefcd90b34f5126e65b7b82bea72ad5eae;hp=f52acf59fd15fdbb3c92b9a138d7cdc2fbd3c0b8;hpb=fe42f33a7a49165888f9af1776001629bbe09b94;p=nepi.git diff --git a/src/nepi/util/server.py b/src/nepi/util/server.py index f52acf59..73fa84c0 100644 --- a/src/nepi/util/server.py +++ b/src/nepi/util/server.py @@ -4,6 +4,7 @@ import base64 import errno import os +import os.path import resource import select import socket @@ -15,6 +16,9 @@ import traceback import signal import re import tempfile +import defer +import functools +import collections CTRL_SOCK = "ctrl.sock" STD_ERR = "stderr.log" @@ -24,15 +28,14 @@ STOP_MSG = "STOP" ERROR_LEVEL = 0 DEBUG_LEVEL = 1 +TRACE = os.environ.get("NEPI_TRACE", "false").lower() in ("true", "1", "on") if hasattr(os, "devnull"): DEV_NULL = os.devnull else: DEV_NULL = "/dev/null" - - -SHELL_SAFE = re.compile('[-a-zA-Z0-9_=+:.,/]*') +SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$') def shell_escape(s): """ Escapes strings so that they are safe to use as command-line arguments """ @@ -41,15 +44,44 @@ def shell_escape(s): return s else: # unsafe string - escape - s = s.replace("'","\\'") + def escp(c): + if (32 <= ord(c) < 127 or c in ('\r','\n','\t')) and c not in ("'",): + return c + else: + return "'$'\\x%02x''" % (ord(c),) + s = ''.join(map(escp,s)) return "'%s'" % (s,) +def eintr_retry(func): + import functools + @functools.wraps(func) + def rv(*p, **kw): + retry = kw.pop("_retry", False) + for i in xrange(0 if retry else 4): + try: + return func(*p, **kw) + except (select.error, socket.error), args: + if args[0] == errno.EINTR: + continue + else: + raise + except OSError, e: + if e.errno == errno.EINTR: + continue + else: + raise + else: + return func(*p, **kw) + return rv + class Server(object): - def __init__(self, root_dir = ".", log_level = ERROR_LEVEL): + def __init__(self, root_dir = ".", log_level = ERROR_LEVEL, environment_setup = ""): self._root_dir = root_dir self._stop = False self._ctrl_sock = None self._log_level = log_level + self._rdbuf = "" + self._environment_setup = environment_setup def run(self): try: @@ -72,11 +104,24 @@ class Server(object): def daemonize(self): # pipes for process synchronization (r, w) = os.pipe() + + # build root folder + root = os.path.normpath(self._root_dir) + if not os.path.exists(root): + os.makedirs(root, 0755) pid1 = os.fork() if pid1 > 0: os.close(w) - os.read(r, 1) + while True: + try: + os.read(r, 1) + except OSError, e: # pragma: no cover + if e.errno == errno.EINTR: + continue + else: + raise + break os.close(r) # os.waitpid avoids leaving a (zombie) process st = os.waitpid(pid1, 0)[1] @@ -117,6 +162,32 @@ class Server(object): # was opened with 0 buffer os.dup2(stdout.fileno(), sys.stdout.fileno()) os.dup2(stderr.fileno(), sys.stderr.fileno()) + + # setup environment + if self._environment_setup: + # parse environment variables and pass to child process + # do it by executing shell commands, in case there's some heavy setup involved + envproc = subprocess.Popen( + [ "bash", "-c", + "( %s python -c 'import os,sys ; print \"\\x01\".join(\"\\x02\".join(map(str,x)) for x in os.environ.iteritems())' ) | tail -1" % + ( self._environment_setup, ) ], + stdin = subprocess.PIPE, + stdout = subprocess.PIPE, + stderr = subprocess.PIPE + ) + out,err = envproc.communicate() + + # parse new environment + if out: + environment = dict(map(lambda x:x.split("\x02"), out.split("\x01"))) + + # apply to current environment + for name, value in environment.iteritems(): + os.environ[name] = value + + # apply pythonpath + if 'PYTHONPATH' in environment: + sys.path = environment['PYTHONPATH'].split(':') + sys.path # create control socket self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) @@ -139,6 +210,7 @@ class Server(object): try: msg = self.recv_msg(conn) except socket.timeout, e: + self.log_error() break if msg == STOP_MSG: @@ -160,22 +232,26 @@ class Server(object): self.log_error() def recv_msg(self, conn): - data = "" - while True: + data = [self._rdbuf] + chunk = data[0] + while '\n' not in chunk: try: chunk = conn.recv(1024) - except OSError, e: - if e.errno != errno.EINTR: + except (OSError, socket.error), e: + if e[0] != errno.EINTR: raise - if chunk == '': + else: continue if chunk: - data += chunk - if chunk[-1] == "\n": - break + data.append(chunk) else: # empty chunk = EOF break + data = ''.join(data).split('\n',1) + while len(data) < 2: + data.append('') + data, self._rdbuf = data + decoded = base64.b64decode(data) return decoded.rstrip() @@ -215,6 +291,7 @@ class Forwarder(object): self._ctrl_sock = None self._root_dir = root_dir self._stop = False + self._rdbuf = "" def forward(self): self.connect() @@ -237,8 +314,8 @@ class Forwarder(object): def send_to_server(self, data): try: self._ctrl_sock.send(data) - except IOError, e: - if e.errno == errno.EPIPE: + except (IOError, socket.error), e: + if e[0] == errno.EPIPE: self.connect() self._ctrl_sock.send(data) else: @@ -249,19 +326,26 @@ class Forwarder(object): self._stop = True def recv_from_server(self): - data = "" - while True: + data = [self._rdbuf] + chunk = data[0] + while '\n' not in chunk: try: chunk = self._ctrl_sock.recv(1024) - except OSError, e: - if e.errno != errno.EINTR: + except (OSError, socket.error), e: + if e[0] != errno.EINTR: raise - if chunk == '': - continue - data += chunk - if chunk[-1] == "\n": + continue + if chunk: + data.append(chunk) + else: + # empty chunk = EOF break - return data + data = ''.join(data).split('\n',1) + while len(data) < 2: + data.append('') + data, self._rdbuf = data + + return data+'\n' def connect(self): self.disconnect() @@ -277,12 +361,14 @@ class Forwarder(object): class Client(object): def __init__(self, root_dir = ".", host = None, port = None, user = None, - agent = None): + agent = None, environment_setup = ""): self.root_dir = root_dir self.addr = (host, port) self.user = user self.agent = agent + self.environment_setup = environment_setup self._stopped = False + self._deferreds = collections.deque() self.connect() def __del__(self): @@ -300,8 +386,13 @@ class Client(object): c.forward()" % (root_dir,) if host != None: self._process = popen_ssh_subprocess(python_code, host, port, - user, agent) + user, agent, + environment_setup = self.environment_setup) # popen_ssh_subprocess already waits for readiness + if self._process.poll(): + err = proc.stderr.read() + raise RuntimeError("Client could not be reached: %s" % \ + err) else: self._process = subprocess.Popen( ["python", "-c", python_code], @@ -336,10 +427,56 @@ class Client(object): self.send_msg(STOP_MSG) self._stopped = True - def read_reply(self): + def defer_reply(self, transform=None): + defer_entry = [] + self._deferreds.append(defer_entry) + return defer.Defer( + functools.partial(self.read_reply, defer_entry, transform) + ) + + def _read_reply(self): data = self._process.stdout.readline() encoded = data.rstrip() + if not encoded: + # empty == eof == dead process, poll it to un-zombify + self._process.poll() + + raise RuntimeError, "Forwarder died while awaiting reply: %s" % (self._process.stderr.read(),) return base64.b64decode(encoded) + + def read_reply(self, which=None, transform=None): + # Test to see if someone did it already + if which is not None and len(which): + # Ok, they did it... + # ...just return the deferred value + if transform: + return transform(which[0]) + else: + return which[0] + + # Process all deferreds until the one we're looking for + # or until the queue is empty + while self._deferreds: + try: + deferred = self._deferreds.popleft() + except IndexError: + # emptied + break + + deferred.append(self._read_reply()) + if deferred is which: + # We reached the one we were looking for + if transform: + return transform(deferred[0]) + else: + return deferred[0] + + if which is None: + # They've requested a synchronous read + if transform: + return transform(self._read_reply()) + else: + return self._read_reply() def _make_server_key_args(server_key, host, port, args): """ @@ -352,9 +489,22 @@ def _make_server_key_args(server_key, host, port, args): host = '%s:%s' % (host,port) # Create a temporary server key file tmp_known_hosts = tempfile.NamedTemporaryFile() + + # Add the intended host key tmp_known_hosts.write('%s,%s %s\n' % (host, socket.gethostbyname(host), server_key)) + + # If we're not in strict mode, add user-configured keys + if os.environ.get('NEPI_STRICT_AUTH_MODE',"").lower() not in ('1','true','on'): + user_hosts_path = '%s/.ssh/known_hosts' % (os.environ.get('HOME',""),) + if os.access(user_hosts_path, os.R_OK): + f = open(user_hosts_path, "r") + tmp_known_hosts.write(f.read()) + f.close() + tmp_known_hosts.flush() + args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)]) + return tmp_known_hosts def popen_ssh_command(command, host, port, user, agent, @@ -365,6 +515,9 @@ def popen_ssh_command(command, host, port, user, agent, """ Executes a remote commands, returns ((stdout,stderr),process) """ + if TRACE: + print "ssh", host, command + tmp_known_hosts = None args = ['ssh', # Don't bother with localhost. Makes test easier @@ -394,7 +547,11 @@ def popen_ssh_command(command, host, port, user, agent, # alive until the process is finished with it proc._known_hosts = tmp_known_hosts - return (proc.communicate(stdin), proc) + out, err = proc.communicate(stdin) + if TRACE: + print " -> ", out, err + + return ((out, err), proc) def popen_scp(source, dest, port = None, @@ -415,11 +572,17 @@ def popen_scp(source, dest, read and written into dest. In these modes, recursive cannot be True. + + Source can be a list of files to copy to a single destination, + in which case it is advised that the destination be a folder. """ + if TRACE: + print "scp", source, dest + if isinstance(source, file) or isinstance(dest, file) \ or hasattr(source, 'read') or hasattr(dest, 'write'): - assert not resursive + assert not recursive # Parse source/destination as @: if isinstance(dest, basestring) and ':' in dest: @@ -433,7 +596,8 @@ def popen_scp(source, dest, args = ['ssh', '-l', user, '-C', # Don't bother with localhost. Makes test easier - '-o', 'NoHostAuthenticationForLocalhost=yes' ] + '-o', 'NoHostAuthenticationForLocalhost=yes', + host ] if port: args.append('-P%d' % port) if ident_key: @@ -458,7 +622,7 @@ def popen_scp(source, dest, stdin = source) err = proc.stderr.read() proc._known_hosts = tmp_known_hosts - proc.wait() + eintr_retry(proc.wait)() return ((None,err), proc) elif isinstance(dest, file): proc = subprocess.Popen(args, @@ -467,22 +631,25 @@ def popen_scp(source, dest, stdin = source) err = proc.stderr.read() proc._known_hosts = tmp_known_hosts - proc.wait() + eintr_retry(proc.wait)() return ((None,err), proc) elif hasattr(source, 'read'): # file-like (but not file) source proc = subprocess.Popen(args, stdout = open('/dev/null','w'), stderr = subprocess.PIPE, - stdin = source) + stdin = subprocess.PIPE) buf = None err = [] while True: if not buf: buf = source.read(4096) + if not buf: + #EOF + break - rdrdy, wrdy, broken = os.select( + rdrdy, wrdy, broken = select.select( [proc.stderr], [proc.stdin], [proc.stderr,proc.stdin]) @@ -497,22 +664,23 @@ def popen_scp(source, dest, if broken: break + proc.stdin.close() err.append(proc.stderr.read()) proc._known_hosts = tmp_known_hosts - proc.wait() + eintr_retry(proc.wait)() return ((None,''.join(err)), proc) elif hasattr(dest, 'write'): # file-like (but not file) dest proc = subprocess.Popen(args, - stdout = open('/dev/null','w'), + stdout = subprocess.PIPE, stderr = subprocess.PIPE, - stdin = source) + stdin = open('/dev/null','w')) buf = None err = [] while True: - rdrdy, wrdy, broken = os.select( + rdrdy, wrdy, broken = select.select( [proc.stderr, proc.stdout], [], [proc.stderr, proc.stdout]) @@ -523,14 +691,19 @@ def popen_scp(source, dest, if proc.stdout in rdrdy: # use os.read for fully unbuffered behavior - dest.write(os.read(proc.stdout.fileno(), 4096)) + buf = os.read(proc.stdout.fileno(), 4096) + dest.write(buf) + + if not buf: + #EOF + break if broken: break err.append(proc.stderr.read()) proc._known_hosts = tmp_known_hosts - proc.wait() + eintr_retry(proc.wait)() return ((None,''.join(err)), proc) else: raise AssertionError, "Unreachable code reached! :-Q" @@ -559,7 +732,10 @@ def popen_scp(source, dest, # Create a temporary server key file tmp_known_hosts = _make_server_key_args( server_key, host, port, args) - args.append(source) + if isinstance(source,list): + args.extend(source) + else: + args.append(source) args.append(dest) # connects to the remote host and starts a remote connection @@ -570,25 +746,29 @@ def popen_scp(source, dest, proc._known_hosts = tmp_known_hosts comm = proc.communicate() - proc.wait() + eintr_retry(proc.wait)() return (comm, proc) def popen_ssh_subprocess(python_code, host, port, user, agent, python_path = None, ident_key = None, server_key = None, - tty = False): + tty = False, + environment_setup = "", + waitcommand = False): + cmd = "" if python_path: python_path.replace("'", r"'\''") cmd = """PYTHONPATH="$PYTHONPATH":'%s' """ % python_path - else: - cmd = "" + cmd += " ; " + if environment_setup: + cmd += environment_setup + cmd += " ; " # Uncomment for debug (to run everything under strace) # We had to verify if strace works (cannot nest them) #cmd += "if strace echo >/dev/null 2>&1; then CMD='strace -ff -tt -s 200 -o strace.out'; else CMD=''; fi\n" #cmd += "$CMD " - #if self.mode == MODE_SSH: - # cmd += "strace -f -tt -s 200 -o strace$$.out " + #cmd += "strace -f -tt -s 200 -o strace$$.out " cmd += "python -c '" cmd += "import base64, os\n" cmd += "cmd = \"\"\n" @@ -598,9 +778,13 @@ def popen_ssh_subprocess(python_code, host, port, user, agent, cmd += "cmd = base64.b64decode(cmd)\n" # Uncomment for debug #cmd += "os.write(2, \"Executing python code: %s\\n\" % cmd)\n" - cmd += "os.write(1, \"OK\\n\")\n" # send a sync message - cmd += "exec(cmd)\n'" - + if not waitcommand: + cmd += "os.write(1, \"OK\\n\")\n" # send a sync message + cmd += "exec(cmd)\n" + if waitcommand: + cmd += "os.write(1, \"OK\\n\")\n" # send a sync message + cmd += "'" + tmp_known_hosts = None args = ['ssh', # Don't bother with localhost. Makes test easier @@ -632,6 +816,7 @@ def popen_ssh_subprocess(python_code, host, port, user, agent, base64.b64encode(python_code) + "\n") msg = os.read(proc.stdout.fileno(), 3) if msg != "OK\n": - raise RuntimeError("Failed to start remote python interpreter") + raise RuntimeError, "Failed to start remote python interpreter: \nout:\n%s%s\nerr:\n%s" % ( + msg, proc.stdout.read(), proc.stderr.read()) return proc