Bug fixing in integration tests
[nepi.git] / src / nepi / util / server.py
index e1f68c1..9136b90 100644 (file)
@@ -12,6 +12,7 @@ import subprocess
 import threading
 import time
 import traceback
+import signal
 
 CTRL_SOCK = "ctrl.sock"
 STD_ERR = "stderr.log"
@@ -116,7 +117,7 @@ class Server(object):
         while not self._stop:
             conn, addr = self._ctrl_sock.accept()
             conn.settimeout(5)
-            while True:
+            while not self._stop:
                 try:
                     msg = self.recv_msg(conn)
                 except socket.timeout, e:
@@ -127,8 +128,18 @@ class Server(object):
                     reply = self.stop_action()
                 else:
                     reply = self.reply_action(msg)
-                self.send_reply(conn, reply)
-            conn.close()
+                
+                try:
+                    self.send_reply(conn, reply)
+                except socket.error:
+                    self.log_error()
+                    print >>sys.stderr, "NOTICE: Awaiting for reconnection"
+                    break
+            try:
+                conn.close()
+            except:
+                # Doesn't matter
+                self.log_error()
 
     def recv_msg(self, conn):
         data = ""
@@ -140,8 +151,12 @@ class Server(object):
                     raise
                 if chunk == '':
                     continue
-            data += chunk
-            if chunk[-1] == "\n":
+            if chunk:
+                data += chunk
+                if chunk[-1] == "\n":
+                    break
+            else:
+                # empty chunk = EOF
                 break
         decoded = base64.b64decode(data)
         return decoded.rstrip()
@@ -163,11 +178,13 @@ class Server(object):
     def reply_action(self, msg):
         return "Reply to: %s" % msg
 
-    def log_error(self, text = None):
+    def log_error(self, text = None, context = ''):
         if text == None:
             text = traceback.format_exc()
         date = time.strftime("%Y-%m-%d %H:%M:%S")
-        sys.stderr.write("ERROR: %s\n%s\n" % (date, text))
+        if context:
+            context = " (%s)" % (context,)
+        sys.stderr.write("ERROR%s: %s\n%s\n" % (context, date, text))
         return text
 
     def log_debug(self, text):
@@ -242,8 +259,26 @@ class Forwarder(object):
 class Client(object):
     def __init__(self, root_dir = ".", host = None, port = None, user = None, 
             agent = None):
-        python_code = "from nepi.util import server;c=server.Forwarder('%s');\
-                c.forward()" % root_dir
+        self.root_dir = root_dir
+        self.addr = (host, port)
+        self.user = user
+        self.agent = agent
+        self._stopped = False
+        self.connect()
+    
+    def __del__(self):
+        if self._process.poll() is None:
+            os.kill(self._process.pid, signal.SIGTERM)
+        self._process.wait()
+        
+    def connect(self):
+        root_dir = self.root_dir
+        (host, port) = self.addr
+        user = self.user
+        agent = self.agent
+        
+        python_code = "from nepi.util import server;c=server.Forwarder(%r);\
+                c.forward()" % (root_dir,)
         if host != None:
             self._process = popen_ssh_subprocess(python_code, host, port, 
                     user, agent)
@@ -254,14 +289,29 @@ class Client(object):
                     stdout = subprocess.PIPE,
                     stderr = subprocess.PIPE
                 )
+            if self._process.poll():
+                err = self._process.stderr.read()
+                raise RuntimeError("Client could not be executed: %s" % \
+                        err)
 
     def send_msg(self, msg):
         encoded = base64.b64encode(msg)
         data = "%s\n" % encoded
-        self._process.stdin.write(data)
+        
+        try:
+            self._process.stdin.write(data)
+        except (IOError, ValueError):
+            # dead process, poll it to un-zombify
+            self._process.poll()
+            
+            # try again after reconnect
+            # If it fails again, though, give up
+            self.connect()
+            self._process.stdin.write(data)
 
     def send_stop(self):
         self.send_msg(STOP_MSG)
+        self._stopped = True
 
     def read_reply(self):
         data = self._process.stdout.readline()