282c074a562cec2ce86eadf61da4b4f9193968e9
[nepi.git] / src / nepi / util / server.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import base64
5 import errno
6 import os
7 import resource
8 import select
9 import socket
10 import sys
11 import subprocess
12 import threading
13 import time
14 import traceback
15 import signal
16
# Name of the UNIX control socket the server binds (and later removes) and
# that the forwarder looks up under the server's root directory.
CTRL_SOCK = "ctrl.sock"
# File that receives the daemon's stdout and stderr after daemonization.
STD_ERR = "stderr.log"
# Fallback upper bound for the "close every descriptor" loop, used when the
# RLIMIT_NOFILE limit is reported as infinite.
MAX_FD = 1024

# Control message that asks the server to shut down.
STOP_MSG = "STOP"

# Log verbosity levels accepted by Server.
ERROR_LEVEL = 0
DEBUG_LEVEL = 1

# Path of the null device; falls back to the POSIX path when the os module
# does not expose os.devnull.
DEV_NULL = getattr(os, "devnull", "/dev/null")
30
class Server(object):
    """Daemonized request/reply server listening on a UNIX control socket.

    ``run()`` double-forks into a daemon that changes directory to
    ``root_dir``, binds ``CTRL_SOCK`` there and serves one connection at a
    time.  Every message and every reply travels as a single
    base64-encoded, newline-terminated line.  Subclasses customize the
    behaviour by overriding ``post_daemonize``, ``reply_action`` and
    ``stop_action``.
    """

    def __init__(self, root_dir=".", log_level=ERROR_LEVEL):
        self._root_dir = root_dir
        self._stop = False
        self._ctrl_sock = None
        self._log_level = log_level

    def run(self):
        """Daemonize and serve requests until a stop message is received.

        In the launching process this returns normally once the daemon has
        signalled that it is ready; errors raised there (for instance
        ``RuntimeError("Daemonization failed")``) propagate to the caller.
        In the daemonized process this never returns: the process always
        ends through ``os._exit(0)``.
        """
        launcher_pid = os.getpid()
        try:
            if self.daemonize():
                self.post_daemonize()
                self.loop()
                self.cleanup()
                # ref: "os._exit(0)"
                # We can not return normally after fork because no exec was
                # done.  If we didn't os._exit(0) here, the code following
                # the call to "Server.run()" in the caller would be executed
                # again... but it has already been executed by the first
                # process (the one that did the first fork) when it returned.
                os._exit(0)
        except:
            if os.getpid() == launcher_pid:
                # Still in the caller's process: report the failure to the
                # caller instead of logging it and killing the caller.
                raise
            # Deliberately catch everything: a forked child must never
            # unwind into the caller's code.
            self.log_error()
            self.cleanup()
            os._exit(0)

    def daemonize(self):
        """Detach from the caller with the classic double fork.

        Returns 0 in the launching process and 1 in the daemonized process.

        Raises RuntimeError in the launching process when the intermediate
        child exits with a non-zero status or when the daemon goes away
        before signalling that its setup finished.
        """
        # pipe used for process synchronization
        (r, w) = os.pipe()

        pid1 = os.fork()
        if pid1 > 0:
            os.close(w)
            # Blocks until the daemon writes its "ready" byte.  An empty
            # read means every holder of the write end exited first.
            ready = os.read(r, 1)
            os.close(r)
            # os.waitpid avoids leaving a <defunct> (zombie) process
            st = os.waitpid(pid1, 0)[1]
            if st or not ready:
                raise RuntimeError("Daemonization failed")
            # return 0 to inform the caller method that this is not the
            # daemonized process
            return 0
        os.close(r)

        # Decouple from parent environment.
        os.chdir(self._root_dir)
        os.umask(0)
        os.setsid()

        # fork 2
        pid2 = os.fork()
        if pid2 > 0:
            # see ref: "os._exit(0)"
            os._exit(0)

        # close all open file descriptors, except the standard ones and the
        # synchronization pipe.
        max_fd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
        if max_fd == resource.RLIM_INFINITY:
            max_fd = MAX_FD
        for fd in range(3, max_fd):
            if fd != w:
                try:
                    os.close(fd)
                except OSError:
                    # descriptor was not open
                    pass

        # Redirect standard file descriptors.
        stdin = open(DEV_NULL, "r")
        stderr = stdout = open(STD_ERR, "a", 0)
        os.dup2(stdin.fileno(), sys.stdin.fileno())
        # NOTE: sys.stdout.write will still be buffered, even if the file
        # was opened with 0 buffer
        os.dup2(stdout.fileno(), sys.stdout.fileno())
        os.dup2(stderr.fileno(), sys.stderr.fileno())

        # let the parent process know that the daemonization is finished
        os.write(w, "\n")
        os.close(w)
        return 1

    def post_daemonize(self):
        """Hook run in the daemon right after daemonization; no-op here."""
        pass

    def loop(self):
        """Accept control connections and answer messages until stopped.

        One connection is served at a time.  A connection is dropped after
        5 seconds without a message or when a reply can not be sent, and the
        server then waits for a new connection.
        """
        self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self._ctrl_sock.bind(CTRL_SOCK)
        self._ctrl_sock.listen(0)
        while not self._stop:
            conn, addr = self._ctrl_sock.accept()
            conn.settimeout(5)
            while not self._stop:
                try:
                    msg = self.recv_msg(conn)
                except socket.timeout:
                    # idle peer: drop it and go back to accept()
                    break

                # NOTE(review): recv_msg() returns "" both for an empty
                # message and when the peer closed the connection, so
                # reply_action("") also runs on disconnect.
                if msg == STOP_MSG:
                    self._stop = True
                    reply = self.stop_action()
                else:
                    reply = self.reply_action(msg)

                try:
                    self.send_reply(conn, reply)
                except socket.error:
                    self.log_error()
                    sys.stderr.write("NOTICE: Awaiting for reconnection\n")
                    break
            try:
                conn.close()
            except:
                # Doesn't matter
                self.log_error()

    def recv_msg(self, conn):
        """Read one newline-terminated, base64-encoded message from conn.

        Reading stops at the first chunk that ends in a newline or at EOF.
        Returns the decoded message with trailing whitespace stripped.
        Socket errors other than an interrupted call (including
        socket.timeout) propagate to the caller.
        """
        data = ""
        while True:
            try:
                chunk = conn.recv(1024)
            except (OSError, socket.error) as e:
                # A recv() interrupted by a signal read nothing: retry.
                if e.errno == errno.EINTR:
                    continue
                raise
            if not chunk:
                # empty chunk = EOF
                break
            data += chunk
            if chunk[-1] == "\n":
                break
        decoded = base64.b64decode(data)
        return decoded.rstrip()

    def send_reply(self, conn, reply):
        """Send reply over conn as one base64-encoded line."""
        encoded = base64.b64encode(reply)
        # sendall: a plain send() may transmit only part of a long reply
        conn.sendall("%s\n" % encoded)

    def cleanup(self):
        """Close the control socket and remove its file, logging failures.

        The socket file is only removed when closing the socket succeeded,
        so nothing is removed when the socket was never created.
        """
        try:
            self._ctrl_sock.close()
            os.remove(CTRL_SOCK)
        except:
            self.log_error()

    def stop_action(self):
        """Return the reply sent in answer to a stop message."""
        return "Stopping server"

    def reply_action(self, msg):
        """Return the reply for msg; the default just echoes the message."""
        return "Reply to: %s" % msg

    def log_error(self, text=None, context=''):
        """Write an error entry to stderr and return its text.

        When text is None the traceback of the exception being handled is
        logged.  A non-empty context is appended, in parentheses, to the
        "ERROR" tag.
        """
        if text is None:
            text = traceback.format_exc()
        date = time.strftime("%Y-%m-%d %H:%M:%S")
        if context:
            context = " (%s)" % (context,)
        sys.stderr.write("ERROR%s: %s\n%s\n" % (context, date, text))
        return text

    def log_debug(self, text):
        """Write a debug entry to stderr when the log level is DEBUG_LEVEL."""
        if self._log_level == DEBUG_LEVEL:
            date = time.strftime("%Y-%m-%d %H:%M:%S")
            sys.stderr.write("DEBUG: %s\n%s\n" % (date, text))
194
class Forwarder(object):
    """Relay lines between stdin/stdout and a server's control socket.

    Each line read from stdin is sent unchanged to the control socket found
    at ``CTRL_SOCK`` under ``root_dir``, and the server's reply line is
    written back to stdout.
    """

    def __init__(self, root_dir="."):
        self._ctrl_sock = None
        self._root_dir = root_dir
        self._stop = False

    def forward(self):
        """Relay request/reply pairs until there is nothing left to relay.

        The loop ends after a stop message was forwarded, when stdin reaches
        EOF, or when the server closes the connection before completing a
        reply line.
        """
        self.connect()
        while not self._stop:
            data = self.read_data()
            if not data:
                # EOF on stdin: nobody is left to send requests
                break
            self.send_to_server(data)
            reply = self.recv_from_server()
            self.write_data(reply)
            if not reply.endswith("\n"):
                # The server went away in the middle of a reply; finish so
                # that whoever reads our stdout sees EOF instead of waiting.
                break
        self.disconnect()

    def read_data(self):
        """Return the next line from stdin ('' at EOF)."""
        return sys.stdin.readline()

    def write_data(self, data):
        """Write data to stdout and flush it right away."""
        sys.stdout.write(data)
        # sys.stdout.write is buffered, this is why we need to do a flush()
        sys.stdout.flush()

    def send_to_server(self, data):
        """Send one encoded line to the server, reconnecting once on EPIPE.

        Also decodes the line to detect the stop message, which makes
        forward() finish after the matching reply has been relayed.
        """
        try:
            self._ctrl_sock.sendall(data)
        except IOError as e:
            if e.errno != errno.EPIPE:
                raise
            # The server dropped the connection: reconnect and retry once.
            self.connect()
            self._ctrl_sock.sendall(data)
        encoded = data.rstrip()
        msg = base64.b64decode(encoded)
        if msg == STOP_MSG:
            self._stop = True

    def recv_from_server(self):
        """Read one reply line from the server and return it undecoded.

        Reading stops at the first chunk that ends in a newline.  If the
        server closes the connection first, whatever was received so far
        (possibly '') is returned without a trailing newline.
        """
        data = ""
        while True:
            try:
                chunk = self._ctrl_sock.recv(1024)
            except (OSError, socket.error) as e:
                # A recv() interrupted by a signal read nothing: retry.
                if e.errno == errno.EINTR:
                    continue
                raise
            if not chunk:
                # empty chunk = EOF
                break
            data += chunk
            if chunk[-1] == "\n":
                break
        return data

    def connect(self):
        """(Re)open the connection to the server's control socket."""
        self.disconnect()
        self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock_addr = os.path.join(self._root_dir, CTRL_SOCK)
        self._ctrl_sock.connect(sock_addr)

    def disconnect(self):
        """Close the control socket if there is one; best effort."""
        if self._ctrl_sock is None:
            return
        try:
            self._ctrl_sock.close()
        except Exception:
            # nothing useful can be done about a failed close
            pass
258
class Client(object):
    """Talk to a Server through a Forwarder running in a child process.

    The forwarder is started locally with ``python -c`` or, when ``host`` is
    given, on the remote host through ssh.  Messages are written to the
    child's stdin and replies read from its stdout, one base64-encoded line
    each.
    """

    def __init__(self, root_dir=".", host=None, port=None, user=None,
            agent=None):
        self.root_dir = root_dir
        self.addr = (host, port)
        self.user = user
        self.agent = agent
        self._stopped = False
        # Set before connect() so that __del__ is safe if connect() raises.
        self._process = None
        self.connect()

    def __del__(self):
        """Terminate and reap the forwarder process, if one was started."""
        process = getattr(self, "_process", None)
        if process is None:
            return
        if process.poll() is None:
            try:
                os.kill(process.pid, signal.SIGTERM)
            except OSError:
                # the process exited between poll() and kill()
                pass
        process.wait()

    def connect(self):
        """Start the forwarder process (through ssh when a host is set)."""
        root_dir = self.root_dir
        (host, port) = self.addr
        user = self.user
        agent = self.agent

        python_code = "from nepi.util import server;c=server.Forwarder(%r);\
                c.forward()" % (root_dir,)
        if host is not None:
            self._process = popen_ssh_subprocess(python_code, host, port,
                    user, agent)
        else:
            self._process = subprocess.Popen(
                    ["python", "-c", python_code],
                    stdin = subprocess.PIPE,
                    stdout = subprocess.PIPE,
                    stderr = subprocess.PIPE
                )

    def send_msg(self, msg):
        """Send msg to the server as one base64-encoded line.

        If writing fails, the forwarder is restarted and the write is
        retried once; a second failure propagates to the caller.
        """
        encoded = base64.b64encode(msg)
        data = "%s\n" % encoded

        try:
            self._process.stdin.write(data)
        except (IOError, ValueError):
            # dead process, poll it to un-zombify
            self._process.poll()

            # try again after reconnect
            # If it fails again, though, give up
            self.connect()
            self._process.stdin.write(data)

    def send_stop(self):
        """Send the stop message and remember that it was sent."""
        self.send_msg(STOP_MSG)
        self._stopped = True

    def read_reply(self):
        """Read one reply line from the forwarder and return it decoded."""
        data = self._process.stdout.readline()
        encoded = data.rstrip()
        return base64.b64decode(encoded)
316
def popen_ssh_subprocess(python_code, host, port, user, agent,
        python_path = None):
    """Run python_code in a python interpreter on host, through ssh.

    The remote interpreter is started with a small bootstrap script that
    reads one base64-encoded line from its stdin, answers "OK" and executes
    the decoded code.

    python_code -- source code to execute remotely.
    host, port  -- ssh destination; port is only passed when truthy.
    user        -- remote login name; only passed to ssh when truthy.
    agent       -- when truthy, enable ssh agent forwarding (-A).
    python_path -- optional path appended to the remote PYTHONPATH.

    Returns the subprocess.Popen object of the ssh client, with stdin,
    stdout and stderr connected to pipes.

    Raises RuntimeError when the remote side does not answer "OK".
    """
    if python_path:
        # str.replace returns a new string: keep the result, otherwise a
        # single quote in the path would end the shell quoting below.
        python_path = python_path.replace("'", r"'\''")
        cmd = """PYTHONPATH="$PYTHONPATH":'%s' """ % python_path
    else:
        cmd = ""
    # Uncomment for debug (to run everything under strace)
    # We had to verify if strace works (cannot nest them)
    #cmd += "if strace echo >/dev/null 2>&1; then CMD='strace -ff -tt -s 200 -o strace.out'; else CMD=''; fi\n"
    #cmd += "$CMD "
    cmd += "python -c '"
    cmd += "import base64, os\n"
    cmd += "cmd = \"\"\n"
    cmd += "while True:\n"
    cmd += " cmd += os.read(0, 1)\n" # one byte from stdin
    cmd += " if cmd[-1] == \"\\n\": break\n"
    cmd += "cmd = base64.b64decode(cmd)\n"
    # Uncomment for debug
    #cmd += "os.write(2, \"Executing python code: %s\\n\" % cmd)\n"
    cmd += "os.write(1, \"OK\\n\")\n" # send a sync message
    cmd += "exec(cmd)\n'"

    args = ['ssh',
            # Don't bother with localhost. Makes test easier
            '-o', 'NoHostAuthenticationForLocalhost=yes']
    if user:
        # a None user must not end up in the argument list
        args.extend(['-l', user])
    if agent:
        args.append('-A')
    if port:
        args.append('-p%d' % port)
    # options first, then destination and command
    args.append(host)
    args.append(cmd)

    # connects to the remote host and starts the remote python interpreter
    proc = subprocess.Popen(args,
            stdout = subprocess.PIPE,
            stdin = subprocess.PIPE,
            stderr = subprocess.PIPE)
    # send the command to execute
    os.write(proc.stdin.fileno(),
            base64.b64encode(python_code) + "\n")
    # wait for the sync message; a single read may return fewer bytes
    out_fd = proc.stdout.fileno()
    msg = ""
    while len(msg) < 3:
        chunk = os.read(out_fd, 3 - len(msg))
        if not chunk:
            # EOF: the remote side went away before answering
            break
        msg += chunk
    if msg != "OK\n":
        raise RuntimeError("Failed to start remote python interpreter")
    return proc
364