Merge
[nepi.git] / src / nepi / util / server.py
index 1cfcea8..73fa84c 100644 (file)
@@ -16,6 +16,9 @@ import traceback
 import signal
 import re
 import tempfile
+import defer
+import functools
+import collections
 
 CTRL_SOCK = "ctrl.sock"
 STD_ERR = "stderr.log"
@@ -32,8 +35,6 @@ if hasattr(os, "devnull"):
 else:
     DEV_NULL = "/dev/null"
 
-
-
 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
 
 def shell_escape(s):
@@ -51,12 +52,36 @@ def shell_escape(s):
         s = ''.join(map(escp,s))
         return "'%s'" % (s,)
 
+def eintr_retry(func):
+    import functools
+    @functools.wraps(func)
+    def rv(*p, **kw):
+        retry = kw.pop("_retry", False)
+        for i in xrange(0 if retry else 4):
+            try:
+                return func(*p, **kw)
+            except (select.error, socket.error), args:
+                if args[0] == errno.EINTR:
+                    continue
+                else:
+                    raise 
+            except OSError, e:
+                if e.errno == errno.EINTR:
+                    continue
+                else:
+                    raise
+        else:
+            return func(*p, **kw)
+    return rv
+
 class Server(object):
-    def __init__(self, root_dir = ".", log_level = ERROR_LEVEL):
+    def __init__(self, root_dir = ".", log_level = ERROR_LEVEL, environment_setup = ""):
         self._root_dir = root_dir
         self._stop = False
         self._ctrl_sock = None
         self._log_level = log_level
+        self._rdbuf = ""
+        self._environment_setup = environment_setup
 
     def run(self):
         try:
@@ -88,7 +113,15 @@ class Server(object):
         pid1 = os.fork()
         if pid1 > 0:
             os.close(w)
-            os.read(r, 1)
+            while True:
+                try:
+                    os.read(r, 1)
+                except OSError, e: # pragma: no cover
+                    if e.errno == errno.EINTR:
+                        continue
+                    else:
+                        raise
+                break
             os.close(r)
             # os.waitpid avoids leaving a <defunc> (zombie) process
             st = os.waitpid(pid1, 0)[1]
@@ -129,6 +162,32 @@ class Server(object):
         # was opened with 0 buffer
         os.dup2(stdout.fileno(), sys.stdout.fileno())
         os.dup2(stderr.fileno(), sys.stderr.fileno())
+        
+        # setup environment
+        if self._environment_setup:
+            # parse environment variables and pass to child process
+            # do it by executing shell commands, in case there's some heavy setup involved
+            envproc = subprocess.Popen(
+                [ "bash", "-c", 
+                    "( %s python -c 'import os,sys ; print \"\\x01\".join(\"\\x02\".join(map(str,x)) for x in os.environ.iteritems())' ) | tail -1" %
+                        ( self._environment_setup, ) ],
+                stdin = subprocess.PIPE, 
+                stdout = subprocess.PIPE,
+                stderr = subprocess.PIPE
+            )
+            out,err = envproc.communicate()
+
+            # parse new environment
+            if out:
+                environment = dict(map(lambda x:x.split("\x02"), out.split("\x01")))
+            
+                # apply to current environment
+                for name, value in environment.iteritems():
+                    os.environ[name] = value
+                
+                # apply pythonpath
+                if 'PYTHONPATH' in environment:
+                    sys.path = environment['PYTHONPATH'].split(':') + sys.path
 
         # create control socket
         self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
@@ -151,6 +210,7 @@ class Server(object):
                 try:
                     msg = self.recv_msg(conn)
                 except socket.timeout, e:
+                    self.log_error()
                     break
                     
                 if msg == STOP_MSG:
@@ -172,22 +232,26 @@ class Server(object):
                 self.log_error()
 
     def recv_msg(self, conn):
-        data = ""
-        while True:
+        data = [self._rdbuf]
+        chunk = data[0]
+        while '\n' not in chunk:
             try:
                 chunk = conn.recv(1024)
-            except OSError, e:
-                if e.errno != errno.EINTR:
+            except (OSError, socket.error), e:
+                if e[0] != errno.EINTR:
                     raise
-                if chunk == '':
+                else:
                     continue
             if chunk:
-                data += chunk
-                if chunk[-1] == "\n":
-                    break
+                data.append(chunk)
             else:
                 # empty chunk = EOF
                 break
+        data = ''.join(data).split('\n',1)
+        while len(data) < 2:
+            data.append('')
+        data, self._rdbuf = data
+        
         decoded = base64.b64decode(data)
         return decoded.rstrip()
 
@@ -227,6 +291,7 @@ class Forwarder(object):
         self._ctrl_sock = None
         self._root_dir = root_dir
         self._stop = False
+        self._rdbuf = ""
 
     def forward(self):
         self.connect()
@@ -249,8 +314,8 @@ class Forwarder(object):
     def send_to_server(self, data):
         try:
             self._ctrl_sock.send(data)
-        except IOError, e:
-            if e.errno == errno.EPIPE:
+        except (IOError, socket.error), e:
+            if e[0] == errno.EPIPE:
                 self.connect()
                 self._ctrl_sock.send(data)
             else:
@@ -261,19 +326,26 @@ class Forwarder(object):
             self._stop = True
 
     def recv_from_server(self):
-        data = ""
-        while True:
+        data = [self._rdbuf]
+        chunk = data[0]
+        while '\n' not in chunk:
             try:
                 chunk = self._ctrl_sock.recv(1024)
-            except OSError, e:
-                if e.errno != errno.EINTR:
+            except (OSError, socket.error), e:
+                if e[0] != errno.EINTR:
                     raise
-                if chunk == '':
-                    continue
-            data += chunk
-            if chunk[-1] == "\n":
+                continue
+            if chunk:
+                data.append(chunk)
+            else:
+                # empty chunk = EOF
                 break
-        return data
+        data = ''.join(data).split('\n',1)
+        while len(data) < 2:
+            data.append('')
+        data, self._rdbuf = data
+        
+        return data+'\n'
  
     def connect(self):
         self.disconnect()
@@ -296,6 +368,7 @@ class Client(object):
         self.agent = agent
         self.environment_setup = environment_setup
         self._stopped = False
+        self._deferreds = collections.deque()
         self.connect()
     
     def __del__(self):
@@ -354,10 +427,56 @@ class Client(object):
         self.send_msg(STOP_MSG)
         self._stopped = True
 
-    def read_reply(self):
+    def defer_reply(self, transform=None):
+        defer_entry = []
+        self._deferreds.append(defer_entry)
+        return defer.Defer(
+            functools.partial(self.read_reply, defer_entry, transform)
+        )
+        
+    def _read_reply(self):
         data = self._process.stdout.readline()
         encoded = data.rstrip() 
+        if not encoded:
+            # empty == eof == dead process, poll it to un-zombify
+            self._process.poll()
+            
+            raise RuntimeError, "Forwarder died while awaiting reply: %s" % (self._process.stderr.read(),)
         return base64.b64decode(encoded)
+    
+    def read_reply(self, which=None, transform=None):
+        # Test to see if someone did it already
+        if which is not None and len(which):
+            # Ok, they did it...
+            # ...just return the deferred value
+            if transform:
+                return transform(which[0])
+            else:
+                return which[0]
+        
+        # Process all deferreds until the one we're looking for
+        # or until the queue is empty
+        while self._deferreds:
+            try:
+                deferred = self._deferreds.popleft()
+            except IndexError:
+                # emptied
+                break
+            
+            deferred.append(self._read_reply())
+            if deferred is which:
+                # We reached the one we were looking for
+                if transform:
+                    return transform(deferred[0])
+                else:
+                    return deferred[0]
+        
+        if which is None:
+            # They've requested a synchronous read
+            if transform:
+                return transform(self._read_reply())
+            else:
+                return self._read_reply()
 
 def _make_server_key_args(server_key, host, port, args):
     """ 
@@ -385,6 +504,7 @@ def _make_server_key_args(server_key, host, port, args):
     tmp_known_hosts.flush()
     
     args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
+    
     return tmp_known_hosts
 
 def popen_ssh_command(command, host, port, user, agent, 
@@ -502,7 +622,7 @@ def popen_scp(source, dest,
                         stdin = source)
                 err = proc.stderr.read()
                 proc._known_hosts = tmp_known_hosts
-                proc.wait()
+                eintr_retry(proc.wait)()
                 return ((None,err), proc)
             elif isinstance(dest, file):
                 proc = subprocess.Popen(args, 
@@ -511,7 +631,7 @@ def popen_scp(source, dest,
                         stdin = source)
                 err = proc.stderr.read()
                 proc._known_hosts = tmp_known_hosts
-                proc.wait()
+                eintr_retry(proc.wait)()
                 return ((None,err), proc)
             elif hasattr(source, 'read'):
                 # file-like (but not file) source
@@ -548,7 +668,7 @@ def popen_scp(source, dest,
                 err.append(proc.stderr.read())
                     
                 proc._known_hosts = tmp_known_hosts
-                proc.wait()
+                eintr_retry(proc.wait)()
                 return ((None,''.join(err)), proc)
             elif hasattr(dest, 'write'):
                 # file-like (but not file) dest
@@ -583,7 +703,7 @@ def popen_scp(source, dest,
                 err.append(proc.stderr.read())
                     
                 proc._known_hosts = tmp_known_hosts
-                proc.wait()
+                eintr_retry(proc.wait)()
                 return ((None,''.join(err)), proc)
             else:
                 raise AssertionError, "Unreachable code reached! :-Q"
@@ -626,7 +746,7 @@ def popen_scp(source, dest,
             proc._known_hosts = tmp_known_hosts
             
             comm = proc.communicate()
-            proc.wait()
+            eintr_retry(proc.wait)()
             return (comm, proc)
  
 def popen_ssh_subprocess(python_code, host, port, user, agent, 
@@ -648,8 +768,7 @@ def popen_ssh_subprocess(python_code, host, port, user, agent,
         # We had to verify if strace works (cannot nest them)
         #cmd += "if strace echo >/dev/null 2>&1; then CMD='strace -ff -tt -s 200 -o strace.out'; else CMD=''; fi\n"
         #cmd += "$CMD "
-        #if self.mode == MODE_SSH:
-        #    cmd += "strace -f -tt -s 200 -o strace$$.out "
+        #cmd += "strace -f -tt -s 200 -o strace$$.out "
         cmd += "python -c '"
         cmd += "import base64, os\n"
         cmd += "cmd = \"\"\n"