Ticket #12: proxy reconnection
[nepi.git] / src / nepi / util / server.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import base64
5 import errno
6 import os
7 import resource
8 import select
9 import socket
10 import sys
11 import subprocess
12 import threading
13 import time
14 import traceback
15
# Name of the UNIX control socket; the server binds it inside its root
# directory and the Forwarder connects to <root_dir>/CTRL_SOCK.
CTRL_SOCK = "ctrl.sock"
# File (inside the root directory) receiving the daemon's stdout/stderr.
STD_ERR = "stderr.log"
# Upper bound for the fd-closing loop when RLIMIT_NOFILE is unlimited.
MAX_FD = 1024

# Request that makes the server stop after answering it.
STOP_MSG = "STOP"

# Log levels understood by Server.log_debug().
ERROR_LEVEL, DEBUG_LEVEL = 0, 1

# Null device path, falling back to the POSIX path on old Pythons.
DEV_NULL = getattr(os, "devnull", "/dev/null")
29
class Server(object):
    """Daemonized request/reply server listening on a UNIX control socket.

    ``run()`` detaches the process with a double fork and then serves
    clients on ``CTRL_SOCK`` inside ``root_dir``.  Every request is one
    base64-encoded line and is answered with one base64-encoded line built
    by ``reply_action()``, or by ``stop_action()`` for ``STOP_MSG`` (which
    also ends the server).  Subclasses customise ``post_daemonize``,
    ``reply_action`` and ``stop_action``.
    """

    def __init__(self, root_dir = ".", log_level = ERROR_LEVEL):
        self._root_dir = root_dir
        self._stop = False
        self._ctrl_sock = None
        self._log_level = log_level

    def run(self):
        """Daemonize and serve until a STOP message is processed.

        In the launching process this returns once the daemon reported it
        is ready, and raises if daemonization failed.  The forked processes
        never return from here: they always end with ``os._exit``.
        """
        launcher_pid = os.getpid()
        try:
            if not self.daemonize():
                # Launching process: the daemon is ready, back to the caller.
                return
            self.post_daemonize()
            self.loop()
            self.cleanup()
        except:
            # Bare except on purpose: no exception may carry a forked
            # process back into the caller's code.
            if os.getpid() == launcher_pid:
                # The failure happened in the caller's own process: report
                # it instead of terminating that process with os._exit().
                raise
            self.log_error()
            self.cleanup()
            # Non-zero status so that a failure in the intermediate child is
            # noticed by the launcher's waitpid() in daemonize().
            os._exit(1)
        # ref: "os._exit(0)"
        # A forked process can not return normally because no exec was done.
        # If it returned, the code that follows the call to "Server.run()" in
        # the caller would be executed again, although it was already run by
        # the launching process after it returned from here.
        os._exit(0)

    def daemonize(self):
        """Detach from the caller using a double fork.

        Returns 0 in the launching process once the daemon signalled that
        it is ready, and 1 in the daemon process.  Raises RuntimeError in
        the launching process if a forked process failed before signalling.
        """
        # pipe used by the daemon to tell the launcher it is ready
        (r, w) = os.pipe()

        pid1 = os.fork()
        if pid1 > 0:
            os.close(w)
            try:
                # Returns the daemon's byte, or '' once every write end is
                # closed because the forked processes exited without writing.
                ready = os.read(r, 1)
            finally:
                os.close(r)
            # os.waitpid avoids leaving a <defunct> (zombie) process
            st = os.waitpid(pid1, 0)[1]
            if st or not ready:
                raise RuntimeError("Daemonization failed")
            # return 0 to inform the caller method that this is not the
            # daemonized process
            return 0
        os.close(r)

        # Decouple from parent environment.
        os.chdir(self._root_dir)
        os.umask(0)
        os.setsid()

        # fork 2
        pid2 = os.fork()
        if pid2 > 0:
            # see ref: "os._exit(0)"
            os._exit(0)

        # close all open file descriptors, except the synchronization pipe.
        max_fd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
        if max_fd == resource.RLIM_INFINITY:
            max_fd = MAX_FD
        for fd in range(3, max_fd):
            if fd != w:
                try:
                    os.close(fd)
                except OSError:
                    pass

        # Redirect standard file descriptors.
        stdin = open(DEV_NULL, "r")
        stderr = stdout = open(STD_ERR, "a", 0)
        os.dup2(stdin.fileno(), sys.stdin.fileno())
        # NOTE: sys.stdout.write will still be buffered, even if the file
        # was opened with 0 buffer
        os.dup2(stdout.fileno(), sys.stdout.fileno())
        os.dup2(stderr.fileno(), sys.stderr.fileno())

        # let the parent process know that the daemonization is finished
        os.write(w, "\n")
        os.close(w)
        return 1

    def post_daemonize(self):
        """Hook run in the daemon right before serving; no-op by default."""
        pass

    def loop(self):
        """Accept connections on CTRL_SOCK and answer their requests.

        A connection is served until it is idle for 5 seconds, the peer
        closes it, or a STOP request has been answered.
        """
        self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self._ctrl_sock.bind(CTRL_SOCK)
        self._ctrl_sock.listen(0)
        while not self._stop:
            conn, _ = self._ctrl_sock.accept()
            conn.settimeout(5)
            try:
                while not self._stop:
                    try:
                        msg = self.recv_msg(conn)
                    except socket.timeout:
                        break
                    if msg is None:
                        # the peer closed the connection
                        break
                    if msg == STOP_MSG:
                        self._stop = True
                        reply = self.stop_action()
                    else:
                        reply = self.reply_action(msg)
                    self.send_reply(conn, reply)
            finally:
                conn.close()

    def recv_msg(self, conn):
        """Read one base64-encoded line from ``conn`` and decode it.

        Returns the decoded message with trailing whitespace stripped, or
        None if the peer closed the connection before sending a line.
        Interrupted reads (EINTR) are retried; other socket errors,
        including socket.timeout, propagate.
        """
        data = ""
        while True:
            try:
                chunk = conn.recv(1024)
            except EnvironmentError as e:
                # socket.error is an EnvironmentError (not an OSError) on
                # Python 2, so this is the class that carries EINTR.
                if e.errno != errno.EINTR:
                    raise
                continue
            if not chunk:
                return None
            data += chunk
            if chunk.endswith("\n"):
                break
        decoded = base64.b64decode(data)
        return decoded.rstrip()

    def send_reply(self, conn, reply):
        """Send ``reply`` to ``conn`` as one base64-encoded line."""
        encoded = base64.b64encode(reply)
        # sendall: a plain send() may transmit only part of the data
        conn.sendall("%s\n" % encoded)

    def cleanup(self):
        """Close the control socket and remove its file, logging failures."""
        if self._ctrl_sock is None:
            # the control socket was never created: nothing to release
            return
        try:
            self._ctrl_sock.close()
            os.remove(CTRL_SOCK)
        except Exception:
            self.log_error()

    def stop_action(self):
        """Return the reply sent for a STOP request."""
        return "Stopping server"

    def reply_action(self, msg):
        """Return the reply for request ``msg``."""
        return "Reply to: %s" % msg

    def log_error(self, text = None):
        """Write ``text`` (default: the current traceback) to stderr.

        Returns the text that was logged.
        """
        if text is None:
            text = traceback.format_exc()
        date = time.strftime("%Y-%m-%d %H:%M:%S")
        sys.stderr.write("ERROR: %s\n%s\n" % (date, text))
        return text

    def log_debug(self, text):
        """Write ``text`` to stderr when the log level is DEBUG_LEVEL."""
        if self._log_level == DEBUG_LEVEL:
            date = time.strftime("%Y-%m-%d %H:%M:%S")
            sys.stderr.write("DEBUG: %s\n%s\n" % (date, text))
177
class Forwarder(object):
    """Relay base64-encoded request lines between stdio and a Server.

    Each line read from stdin is sent to the server's control socket
    (``CTRL_SOCK`` under ``root_dir``) and the server's reply line is
    written to stdout.  Forwarding ends after STOP_MSG has been relayed or
    when stdin reaches end-of-file.
    """

    def __init__(self, root_dir = "."):
        self._ctrl_sock = None
        self._root_dir = root_dir
        self._stop = False

    def forward(self):
        """Relay requests and replies until stopped, then disconnect."""
        self.connect()
        try:
            while not self._stop:
                data = self.read_data()
                if not data:
                    # stdin reached EOF: nobody is left to send requests
                    break
                self.send_to_server(data)
                reply = self.recv_from_server()
                if reply is None:
                    # The server closed the connection without answering
                    # (e.g. it dropped the connection as idle while the
                    # request was in flight): reconnect and retry once.
                    self.connect()
                    self.send_to_server(data)
                    reply = self.recv_from_server()
                    if reply is None:
                        raise RuntimeError("Connection to server lost")
                self.write_data(reply)
        finally:
            self.disconnect()

    def read_data(self):
        """Return the next request line from stdin ('' at EOF)."""
        return sys.stdin.readline()

    def write_data(self, data):
        """Write a reply line to stdout and flush it immediately."""
        sys.stdout.write(data)
        # sys.stdout.write is buffered, this is why we need to do a flush()
        sys.stdout.flush()

    def send_to_server(self, data):
        """Send one raw (base64-encoded) request line to the server.

        Reconnects and resends once when the server side has gone away
        (EPIPE).  Marks the forwarder as stopped when the decoded request
        is STOP_MSG.
        """
        try:
            self._ctrl_sock.sendall(data)
        except EnvironmentError as e:
            if e.errno != errno.EPIPE:
                raise
            self.connect()
            self._ctrl_sock.sendall(data)
        encoded = data.rstrip()
        msg = base64.b64decode(encoded)
        if msg == STOP_MSG:
            self._stop = True

    def recv_from_server(self):
        """Read one reply line (newline included) from the server.

        Returns None if the server closed or reset the connection before
        sending anything.  Raises RuntimeError if the connection ends in
        the middle of a reply.  Interrupted reads (EINTR) are retried.
        """
        data = ""
        while True:
            try:
                chunk = self._ctrl_sock.recv(1024)
            except EnvironmentError as e:
                if e.errno == errno.EINTR:
                    continue
                if e.errno != errno.ECONNRESET:
                    raise
                # a reset is handled like an orderly close
                chunk = ""
            if not chunk:
                if data:
                    raise RuntimeError(
                            "Server closed the connection mid-reply")
                return None
            data += chunk
            if chunk.endswith("\n"):
                return data

    def connect(self):
        """(Re)open the connection to <root_dir>/CTRL_SOCK."""
        self.disconnect()
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock_addr = os.path.join(self._root_dir, CTRL_SOCK)
        try:
            sock.connect(sock_addr)
        except Exception:
            sock.close()
            raise
        self._ctrl_sock = sock

    def disconnect(self):
        """Close the connection to the server, if any (best effort)."""
        sock, self._ctrl_sock = self._ctrl_sock, None
        if sock is None:
            return
        try:
            sock.close()
        except EnvironmentError:
            pass
241
class Client(object):
    """Talk to a Server through a Forwarder running in a child process.

    The Forwarder is started locally with ``python -c`` or, when ``host``
    is given, remotely through ``popen_ssh_subprocess``.  Messages travel
    as base64-encoded lines over the child's stdin/stdout.
    """

    def __init__(self, root_dir = ".", host = None, port = None, user = None,
            agent = None):
        # Adjacent literals are joined before the % formatting is applied.
        python_code = ("from nepi.util import server;"
                "c=server.Forwarder(%r);c.forward()" % (root_dir,))
        if host is not None:
            self._process = popen_ssh_subprocess(python_code, host, port,
                    user, agent)
        else:
            self._process = subprocess.Popen(
                    ["python", "-c", python_code],
                    stdin = subprocess.PIPE,
                    stdout = subprocess.PIPE,
                    stderr = subprocess.PIPE
                )

    def send_msg(self, msg):
        """Send ``msg`` to the forwarder as one base64-encoded line."""
        encoded = base64.b64encode(msg)
        data = "%s\n" % encoded
        self._process.stdin.write(data)
        # Flush so the request is delivered even if the pipe is buffered;
        # the forwarder answers only after it has read the whole line.
        self._process.stdin.flush()

    def send_stop(self):
        """Ask the server to stop."""
        self.send_msg(STOP_MSG)

    def read_reply(self):
        """Read one reply line from the forwarder and base64-decode it."""
        data = self._process.stdout.readline()
        encoded = data.rstrip()
        return base64.b64decode(encoded)
270
def popen_ssh_subprocess(python_code, host, port, user, agent,
        python_path = None):
    """Start a remote python interpreter over ssh and run ``python_code``.

    The remote side runs a small bootstrap that reads one base64-encoded
    line from stdin, answers "OK\\n" and exec()s the decoded code.
    ``python_path``, if given, is appended to the remote PYTHONPATH.
    ``user`` and ``port`` are passed to ssh only when set; ``agent``
    enables agent forwarding (-A).

    Returns the ssh ``subprocess.Popen`` object whose stdin/stdout are
    connected to the remote code.  Raises RuntimeError if the "OK"
    acknowledgement is not received.
    """
    if python_path:
        # str.replace returns a new string: keep the escaped value so that
        # single quotes survive inside the single-quoted shell word below.
        python_path = python_path.replace("'", r"'\''")
        cmd = """PYTHONPATH="$PYTHONPATH":'%s' """ % python_path
    else:
        cmd = ""
    # For debugging, the remote command can be prefixed with
    # "strace -ff -tt -s 200 -o strace.out " when strace is available.
    cmd += "python -c '"
    cmd += "import base64, os\n"
    cmd += "cmd = \"\"\n"
    cmd += "while True:\n"
    cmd += " cmd += os.read(0, 1)\n" # one byte from stdin
    cmd += " if cmd[-1] == \"\\n\": break\n"
    cmd += "cmd = base64.b64decode(cmd)\n"
    cmd += "os.write(1, \"OK\\n\")\n" # send a sync message
    cmd += "exec(cmd)\n'"

    args = ['ssh',
            # Don't bother with localhost. Makes test easier
            '-o', 'NoHostAuthenticationForLocalhost=yes']
    if user:
        args.extend(['-l', user])
    if agent:
        args.append('-A')
    if port:
        args.append('-p%d' % port)
    # options first, then the destination and the remote command
    args.append(host)
    args.append(cmd)

    # connects to the remote host and starts the bootstrap interpreter
    proc = subprocess.Popen(args,
            stdout = subprocess.PIPE,
            stdin = subprocess.PIPE,
            stderr = subprocess.PIPE)
    # send the command to execute
    os.write(proc.stdin.fileno(),
            base64.b64encode(python_code) + "\n")
    # Wait for the 3-byte "OK\n" acknowledgement; os.read may return fewer
    # bytes than requested, so keep reading until 3 bytes or EOF.
    msg = ""
    while len(msg) < 3:
        chunk = os.read(proc.stdout.fileno(), 3 - len(msg))
        if not chunk:
            break
        msg += chunk
    if msg != "OK\n":
        raise RuntimeError("Failed to start remote python interpreter")
    return proc
318