bugfixing
[nepi.git] / src / nepi / util / server.py
index 1a432ef..8a3dbed 100644 (file)
@@ -9,6 +9,8 @@ import socket
 import sys
 import subprocess
 import threading
+from time import strftime
+import traceback
 
 CTRL_SOCK = "ctrl.sock"
 STD_ERR = "stderr.log"
@@ -16,12 +18,16 @@ MAX_FD = 1024
 
 STOP_MSG = "STOP"
 
+ERROR_LEVEL = 0
+DEBUG_LEVEL = 1
+
 class Server(object):
     def __init__(self, root_dir = "."):
         self._root_dir = root_dir
         self._stop = False
         self._ctrl_sock = None
-        self._stderr = None 
+        self._stderr = None
+        self._log_level = ERROR_LEVEL
 
     def run(self):
         try:
@@ -38,16 +44,26 @@ class Server(object):
                 os._exit(0)
         except:
             self.log_error()
-            raise
+            self.cleanup()
+            os._exit(0)
 
     def daemonize(self):
+        # pipes for process synchronization
+        (r, w) = os.pipe()
+
         pid1 = os.fork()
         if pid1 > 0:
-            # we do os.waitpid to avoid leaving a <defunc> (zombie) process
-            os.waitpid(pid1, 0)
+            os.close(w)
+            os.read(r, 1)
+            os.close(r)
+            # os.waitpid avoids leaving a <defunc> (zombie) process
+            st = os.waitpid(pid1, 0)[1]
+            if st:
+                raise RuntimeError("Daemonization failed")
             # return 0 to inform the caller method that this is not the 
             # daemonized process
             return 0
+        os.close(r)
 
         # Decouple from parent environment.
         os.chdir(self._root_dir)
@@ -61,11 +77,12 @@ class Server(object):
             os._exit(0)
 
         # close all open file descriptors.
-        for fd in range(2, MAX_FD):
-            try:
-                os.close(fd)
-            except OSError:
-                pass
+        for fd in range(3, MAX_FD):
+            if fd != w:
+                try:
+                    os.close(fd)
+                except OSError:
+                    pass
 
         # Redirect standard file descriptors.
         self._stderr = stdout = file(STD_ERR, "a", 0)
@@ -73,6 +90,9 @@ class Server(object):
         os.dup2(stdin.fileno(), sys.stdin.fileno())
         os.dup2(stdout.fileno(), sys.stdout.fileno())
         os.dup2(self._stderr.fileno(), sys.stderr.fileno())
+        # let the parent process know that the daemonization is finished
+        os.write(w, "\n")
+        os.close(w)
         return 1
 
     def post_daemonize(self):
@@ -93,34 +113,37 @@ class Server(object):
                     
                 if msg == STOP_MSG:
                     self._stop = True
-                    try:
-                        reply = self.stop_action()
-                    except:
-                        self.log_error()
-                    self.send_reply(conn, reply)
-                    break
+                    reply = self.stop_action()
                 else:
-                    try:
-                        reply = self.reply_action(msg)
-                    except:
-                        self.log_error()
-                    self.send_reply(conn, reply)
+                    reply = self.reply_action(msg)
+                self.send_reply(conn, reply)
             conn.close()
 
     def recv_msg(self, conn):
-       data = conn.recv(1024)
-       decoded = base64.b64decode(data)
-       return decoded.rstrip()
+        data = ""
+        while True:
+            try:
+                chunk = conn.recv(1024)
+            except OSError, e:
+                if e.errno != errno.EINTR:
+                    raise
+                if chunk == '':
+                    continue
+            data += chunk
+            if chunk[-1] == "\n":
+                break
+        decoded = base64.b64decode(data)
+        return decoded.rstrip()
 
     def send_reply(self, conn, reply):
-       encoded = base64.b64encode(reply)
-       conn.send("%s\n" % encoded)
+        encoded = base64.b64encode(reply)
+        conn.send("%s\n" % encoded)
        
     def cleanup(self):
         try:
             self._ctrl_sock.close()
             os.remove(CTRL_SOCK)
-        except e:
+        except:
             self.log_error()
 
     def stop_action(self):
@@ -129,12 +152,23 @@ class Server(object):
     def reply_action(self, msg):
         return "Reply to: %s" % msg
 
-    def log_error(self, error = None):
-        if error == None:
-            import traceback
-            error = "%s\n" %  traceback.format_exc()
-        sys.stderr.write(error)
-        return error
+    def set_error_log_level(self):
+        self._log_level = ERROR_LEVEL
+
+    def set_debug_log_level(self):
+        self._log_level = DEBUG_LEVEL
+
+    def log_error(self, text = None):
+        if text == None:
+            text = traceback.format_exc()
+        date = strftime("%Y-%m-%d %H:%M:%S")
+        sys.stderr.write("ERROR: %s\n%s\n" % (date, text))
+        return text
+
+    def log_debug(self, text):
+        if self._log_level == DEBUG_LEVEL:
+            date = strftime("%Y-%m-%d %H:%M:%S")
+            sys.stderr.write("DEBUG: %s\n%s\n" % (date, text))
 
 class Forwarder(object):
     def __init__(self, root_dir = "."):
@@ -173,7 +207,19 @@ class Forwarder(object):
             self._stop = True
 
     def recv_from_server(self):
-        return self._ctrl_sock.recv(1024)
+        data = ""
+        while True:
+            try:
+                chunk = self._ctrl_sock.recv(1024)
+            except OSError, e:
+                if e.errno != errno.EINTR:
+                    raise
+                if chunk == '':
+                    continue
+            data += chunk
+            if chunk[-1] == "\n":
+                break
+        return data
  
     def connect(self):
         self.disconnect()
@@ -195,8 +241,7 @@ class Client(object):
                         c.forward()" % root_dir
                 ],
                 stdin = subprocess.PIPE, 
-                stdout = subprocess.PIPE, 
-                env = os.environ)
+                stdout = subprocess.PIPE)
 
     def send_msg(self, msg):
         encoded = base64.b64encode(msg)