Merge
[nepi.git] / src / nepi / util / server.py
index 44c6408..73fa84c 100644 (file)
@@ -4,6 +4,7 @@
 import base64
 import errno
 import os
+import os.path
 import resource
 import select
 import socket
@@ -15,6 +16,9 @@ import traceback
 import signal
 import re
 import tempfile
+import defer
+import functools
+import collections
 
 CTRL_SOCK = "ctrl.sock"
 STD_ERR = "stderr.log"
@@ -31,8 +35,6 @@ if hasattr(os, "devnull"):
 else:
     DEV_NULL = "/dev/null"
 
-
-
 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
 
 def shell_escape(s):
@@ -42,15 +44,44 @@ def shell_escape(s):
         return s
     else:
         # unsafe string - escape
-        s = s.replace("'","\\'")
+        def escp(c):
+            if (32 <= ord(c) < 127 or c in ('\r','\n','\t')) and c not in ("'",):
+                return c
+            else:
+                return "'$'\\x%02x''" % (ord(c),)
+        s = ''.join(map(escp,s))
         return "'%s'" % (s,)
 
+def eintr_retry(func):
+    import functools
+    @functools.wraps(func)
+    def rv(*p, **kw):
+        retry = kw.pop("_retry", False)
+        for i in xrange(0 if retry else 4):
+            try:
+                return func(*p, **kw)
+            except (select.error, socket.error), args:
+                if args[0] == errno.EINTR:
+                    continue
+                else:
+                    raise 
+            except OSError, e:
+                if e.errno == errno.EINTR:
+                    continue
+                else:
+                    raise
+        else:
+            return func(*p, **kw)
+    return rv
+
 class Server(object):
-    def __init__(self, root_dir = ".", log_level = ERROR_LEVEL):
+    def __init__(self, root_dir = ".", log_level = ERROR_LEVEL, environment_setup = ""):
         self._root_dir = root_dir
         self._stop = False
         self._ctrl_sock = None
         self._log_level = log_level
+        self._rdbuf = ""
+        self._environment_setup = environment_setup
 
     def run(self):
         try:
@@ -73,11 +104,24 @@ class Server(object):
     def daemonize(self):
         # pipes for process synchronization
         (r, w) = os.pipe()
+        
+        # build root folder
+        root = os.path.normpath(self._root_dir)
+        if not os.path.exists(root):
+            os.makedirs(root, 0755)
 
         pid1 = os.fork()
         if pid1 > 0:
             os.close(w)
-            os.read(r, 1)
+            while True:
+                try:
+                    os.read(r, 1)
+                except OSError, e: # pragma: no cover
+                    if e.errno == errno.EINTR:
+                        continue
+                    else:
+                        raise
+                break
             os.close(r)
             # os.waitpid avoids leaving a <defunc> (zombie) process
             st = os.waitpid(pid1, 0)[1]
@@ -118,6 +162,32 @@ class Server(object):
         # was opened with 0 buffer
         os.dup2(stdout.fileno(), sys.stdout.fileno())
         os.dup2(stderr.fileno(), sys.stderr.fileno())
+        
+        # setup environment
+        if self._environment_setup:
+            # parse environment variables and pass to child process
+            # do it by executing shell commands, in case there's some heavy setup involved
+            envproc = subprocess.Popen(
+                [ "bash", "-c", 
+                    "( %s python -c 'import os,sys ; print \"\\x01\".join(\"\\x02\".join(map(str,x)) for x in os.environ.iteritems())' ) | tail -1" %
+                        ( self._environment_setup, ) ],
+                stdin = subprocess.PIPE, 
+                stdout = subprocess.PIPE,
+                stderr = subprocess.PIPE
+            )
+            out,err = envproc.communicate()
+
+            # parse new environment
+            if out:
+                environment = dict(map(lambda x:x.split("\x02"), out.split("\x01")))
+            
+                # apply to current environment
+                for name, value in environment.iteritems():
+                    os.environ[name] = value
+                
+                # apply pythonpath
+                if 'PYTHONPATH' in environment:
+                    sys.path = environment['PYTHONPATH'].split(':') + sys.path
 
         # create control socket
         self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
@@ -140,6 +210,7 @@ class Server(object):
                 try:
                     msg = self.recv_msg(conn)
                 except socket.timeout, e:
+                    self.log_error()
                     break
                     
                 if msg == STOP_MSG:
@@ -161,22 +232,26 @@ class Server(object):
                 self.log_error()
 
     def recv_msg(self, conn):
-        data = ""
-        while True:
+        data = [self._rdbuf]
+        chunk = data[0]
+        while '\n' not in chunk:
             try:
                 chunk = conn.recv(1024)
-            except OSError, e:
-                if e.errno != errno.EINTR:
+            except (OSError, socket.error), e:
+                if e[0] != errno.EINTR:
                     raise
-                if chunk == '':
+                else:
                     continue
             if chunk:
-                data += chunk
-                if chunk[-1] == "\n":
-                    break
+                data.append(chunk)
             else:
                 # empty chunk = EOF
                 break
+        data = ''.join(data).split('\n',1)
+        while len(data) < 2:
+            data.append('')
+        data, self._rdbuf = data
+        
         decoded = base64.b64decode(data)
         return decoded.rstrip()
 
@@ -216,6 +291,7 @@ class Forwarder(object):
         self._ctrl_sock = None
         self._root_dir = root_dir
         self._stop = False
+        self._rdbuf = ""
 
     def forward(self):
         self.connect()
@@ -238,8 +314,8 @@ class Forwarder(object):
     def send_to_server(self, data):
         try:
             self._ctrl_sock.send(data)
-        except IOError, e:
-            if e.errno == errno.EPIPE:
+        except (IOError, socket.error), e:
+            if e[0] == errno.EPIPE:
                 self.connect()
                 self._ctrl_sock.send(data)
             else:
@@ -250,19 +326,26 @@ class Forwarder(object):
             self._stop = True
 
     def recv_from_server(self):
-        data = ""
-        while True:
+        data = [self._rdbuf]
+        chunk = data[0]
+        while '\n' not in chunk:
             try:
                 chunk = self._ctrl_sock.recv(1024)
-            except OSError, e:
-                if e.errno != errno.EINTR:
+            except (OSError, socket.error), e:
+                if e[0] != errno.EINTR:
                     raise
-                if chunk == '':
-                    continue
-            data += chunk
-            if chunk[-1] == "\n":
+                continue
+            if chunk:
+                data.append(chunk)
+            else:
+                # empty chunk = EOF
                 break
-        return data
+        data = ''.join(data).split('\n',1)
+        while len(data) < 2:
+            data.append('')
+        data, self._rdbuf = data
+        
+        return data+'\n'
  
     def connect(self):
         self.disconnect()
@@ -278,12 +361,14 @@ class Forwarder(object):
 
 class Client(object):
     def __init__(self, root_dir = ".", host = None, port = None, user = None, 
-            agent = None):
+            agent = None, environment_setup = ""):
         self.root_dir = root_dir
         self.addr = (host, port)
         self.user = user
         self.agent = agent
+        self.environment_setup = environment_setup
         self._stopped = False
+        self._deferreds = collections.deque()
         self.connect()
     
     def __del__(self):
@@ -301,8 +386,13 @@ class Client(object):
                 c.forward()" % (root_dir,)
         if host != None:
             self._process = popen_ssh_subprocess(python_code, host, port, 
-                    user, agent)
+                    user, agent,
+                    environment_setup = self.environment_setup)
             # popen_ssh_subprocess already waits for readiness
+            if self._process.poll():
+                err = proc.stderr.read()
+                raise RuntimeError("Client could not be reached: %s" % \
+                        err)
         else:
             self._process = subprocess.Popen(
                     ["python", "-c", python_code],
@@ -337,10 +427,56 @@ class Client(object):
         self.send_msg(STOP_MSG)
         self._stopped = True
 
-    def read_reply(self):
+    def defer_reply(self, transform=None):
+        defer_entry = []
+        self._deferreds.append(defer_entry)
+        return defer.Defer(
+            functools.partial(self.read_reply, defer_entry, transform)
+        )
+        
+    def _read_reply(self):
         data = self._process.stdout.readline()
         encoded = data.rstrip() 
+        if not encoded:
+            # empty == eof == dead process, poll it to un-zombify
+            self._process.poll()
+            
+            raise RuntimeError, "Forwarder died while awaiting reply: %s" % (self._process.stderr.read(),)
         return base64.b64decode(encoded)
+    
+    def read_reply(self, which=None, transform=None):
+        # Test to see if someone did it already
+        if which is not None and len(which):
+            # Ok, they did it...
+            # ...just return the deferred value
+            if transform:
+                return transform(which[0])
+            else:
+                return which[0]
+        
+        # Process all deferreds until the one we're looking for
+        # or until the queue is empty
+        while self._deferreds:
+            try:
+                deferred = self._deferreds.popleft()
+            except IndexError:
+                # emptied
+                break
+            
+            deferred.append(self._read_reply())
+            if deferred is which:
+                # We reached the one we were looking for
+                if transform:
+                    return transform(deferred[0])
+                else:
+                    return deferred[0]
+        
+        if which is None:
+            # They've requested a synchronous read
+            if transform:
+                return transform(self._read_reply())
+            else:
+                return self._read_reply()
 
 def _make_server_key_args(server_key, host, port, args):
     """ 
@@ -368,6 +504,7 @@ def _make_server_key_args(server_key, host, port, args):
     tmp_known_hosts.flush()
     
     args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
+    
     return tmp_known_hosts
 
 def popen_ssh_command(command, host, port, user, agent, 
@@ -485,7 +622,7 @@ def popen_scp(source, dest,
                         stdin = source)
                 err = proc.stderr.read()
                 proc._known_hosts = tmp_known_hosts
-                proc.wait()
+                eintr_retry(proc.wait)()
                 return ((None,err), proc)
             elif isinstance(dest, file):
                 proc = subprocess.Popen(args, 
@@ -494,7 +631,7 @@ def popen_scp(source, dest,
                         stdin = source)
                 err = proc.stderr.read()
                 proc._known_hosts = tmp_known_hosts
-                proc.wait()
+                eintr_retry(proc.wait)()
                 return ((None,err), proc)
             elif hasattr(source, 'read'):
                 # file-like (but not file) source
@@ -531,7 +668,7 @@ def popen_scp(source, dest,
                 err.append(proc.stderr.read())
                     
                 proc._known_hosts = tmp_known_hosts
-                proc.wait()
+                eintr_retry(proc.wait)()
                 return ((None,''.join(err)), proc)
             elif hasattr(dest, 'write'):
                 # file-like (but not file) dest
@@ -566,7 +703,7 @@ def popen_scp(source, dest,
                 err.append(proc.stderr.read())
                     
                 proc._known_hosts = tmp_known_hosts
-                proc.wait()
+                eintr_retry(proc.wait)()
                 return ((None,''.join(err)), proc)
             else:
                 raise AssertionError, "Unreachable code reached! :-Q"
@@ -609,7 +746,7 @@ def popen_scp(source, dest,
             proc._known_hosts = tmp_known_hosts
             
             comm = proc.communicate()
-            proc.wait()
+            eintr_retry(proc.wait)()
             return (comm, proc)
  
 def popen_ssh_subprocess(python_code, host, port, user, agent, 
@@ -617,21 +754,21 @@ def popen_ssh_subprocess(python_code, host, port, user, agent,
         ident_key = None,
         server_key = None,
         tty = False,
-        environment_setup = ""):
+        environment_setup = "",
+        waitcommand = False):
+        cmd = ""
         if python_path:
             python_path.replace("'", r"'\''")
             cmd = """PYTHONPATH="$PYTHONPATH":'%s' """ % python_path
-        else:
-            cmd = ""
+            cmd += " ; "
         if environment_setup:
             cmd += environment_setup
-            cmd += " "
+            cmd += " "
         # Uncomment for debug (to run everything under strace)
         # We had to verify if strace works (cannot nest them)
         #cmd += "if strace echo >/dev/null 2>&1; then CMD='strace -ff -tt -s 200 -o strace.out'; else CMD=''; fi\n"
         #cmd += "$CMD "
-        #if self.mode == MODE_SSH:
-        #    cmd += "strace -f -tt -s 200 -o strace$$.out "
+        #cmd += "strace -f -tt -s 200 -o strace$$.out "
         cmd += "python -c '"
         cmd += "import base64, os\n"
         cmd += "cmd = \"\"\n"
@@ -641,9 +778,13 @@ def popen_ssh_subprocess(python_code, host, port, user, agent,
         cmd += "cmd = base64.b64decode(cmd)\n"
         # Uncomment for debug
         #cmd += "os.write(2, \"Executing python code: %s\\n\" % cmd)\n"
-        cmd += "os.write(1, \"OK\\n\")\n" # send a sync message
-        cmd += "exec(cmd)\n'"
-
+        if not waitcommand:
+            cmd += "os.write(1, \"OK\\n\")\n" # send a sync message
+        cmd += "exec(cmd)\n"
+        if waitcommand:
+            cmd += "os.write(1, \"OK\\n\")\n" # send a sync message
+        cmd += "'"
+        
         tmp_known_hosts = None
         args = ['ssh',
                 # Don't bother with localhost. Makes test easier
@@ -675,6 +816,7 @@ def popen_ssh_subprocess(python_code, host, port, user, agent,
                 base64.b64encode(python_code) + "\n")
         msg = os.read(proc.stdout.fileno(), 3)
         if msg != "OK\n":
-            raise RuntimeError("Failed to start remote python interpreter")
+            raise RuntimeError, "Failed to start remote python interpreter: \nout:\n%s%s\nerr:\n%s" % (
+                msg, proc.stdout.read(), proc.stderr.read())
         return proc