36e54dc38244c16be7d77636e687aa53197069ee
[nepi.git] / src / nepi / util / server.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import base64
5 import errno
6 import os
7 import resource
8 import select
9 import socket
10 import sys
11 import subprocess
12 import threading
13 import time
14 import traceback
15 import signal
16
# File name of the UNIX control socket; the daemon binds it after chdir'ing
# into the server's root_dir.
CTRL_SOCK = "ctrl.sock"
# Log file that receives the daemon's redirected stdout and stderr.
STD_ERR = "stderr.log"
# Upper bound used when closing inherited fds if RLIMIT_NOFILE is unlimited.
MAX_FD = 1024

# Message that makes the server stop after replying.
STOP_MSG = "STOP"

# Logging verbosity levels understood by Server.
ERROR_LEVEL = 0
DEBUG_LEVEL = 1

# Prefer os.devnull when the platform provides it, else the POSIX path.
DEV_NULL = getattr(os, "devnull", "/dev/null")
30
class Server(object):
    """Daemonized request/reply server listening on a UNIX control socket.

    ``run()`` detaches into a daemon (double fork) that binds ``CTRL_SOCK``
    inside ``root_dir`` and answers newline-terminated, base64-encoded
    messages.  Subclasses customize it by overriding ``post_daemonize``,
    ``reply_action`` and ``stop_action``.
    """

    def __init__(self, root_dir = ".", log_level = ERROR_LEVEL):
        # Directory the daemon chdirs into; CTRL_SOCK and STD_ERR live there.
        self._root_dir = root_dir
        # Set once STOP_MSG has been received.
        self._stop = False
        # Listening socket; only assigned in the daemon once bind() succeeded.
        self._ctrl_sock = None
        # ERROR_LEVEL or DEBUG_LEVEL (see log_debug).
        self._log_level = log_level

    def run(self):
        """Start the daemon and serve requests until STOP_MSG arrives.

        In the calling process this returns once the daemon reported it is
        ready; if daemonization failed the exception propagates to the
        caller.  The forked processes never return from here: they always
        leave through ``os._exit`` (status 0 on success, 1 on error).
        """
        caller_pid = os.getpid()
        try:
            if self.daemonize():
                self.post_daemonize()
                self.loop()
                self.cleanup()
                # ref: "os._exit(0)"
                # can not return normally after fork because no exec was done.
                # This means that if we don't do a os._exit(0) here the code that
                # follows the call to "Server.run()" in the "caller code" will be
                # executed... but by now it has already been executed after the
                # first process (the one that did the first fork) returned.
                os._exit(0)
        except BaseException:
            if os.getpid() == caller_pid:
                # Still in the caller's process: report the failure instead of
                # silently terminating the caller with os._exit().
                raise
            self.log_error()
            self.cleanup()
            # Non-zero status so the waiting parent can detect the failure.
            os._exit(1)

    def daemonize(self):
        """Detach from the caller using the classic double fork.

        Returns 0 in the calling process, after the daemon has signalled that
        its control socket is listening, and 1 in the daemon itself.

        Raises:
            RuntimeError: in the calling process, if the intermediate child
                exited with a non-zero status or the daemon died before
                signalling readiness.
        """
        # pipe used by the daemon to tell the caller it is ready
        (r, w) = os.pipe()

        pid1 = os.fork()
        if pid1 > 0:
            os.close(w)
            # Yields the daemon's readiness byte, or "" once every copy of the
            # write end was closed without writing it (the children died).
            ready = os.read(r, 1)
            os.close(r)
            # os.waitpid avoids leaving a <defunct> (zombie) process
            st = os.waitpid(pid1, 0)[1]
            if st or not ready:
                raise RuntimeError("Daemonization failed")
            # return 0 to inform the caller method that this is not the
            # daemonized process
            return 0
        os.close(r)

        # Decouple from parent environment.
        os.chdir(self._root_dir)
        # NOTE(review): with umask(0) the files the daemon creates (log file,
        # control socket) get permissive modes; confirm this is intended.
        os.umask(0)
        os.setsid()

        # fork 2
        pid2 = os.fork()
        if pid2 > 0:
            # see ref: "os._exit(0)"
            os._exit(0)

        # close all inherited file descriptors except stdio and the sync pipe.
        max_fd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
        if max_fd == resource.RLIM_INFINITY:
            max_fd = MAX_FD
        os.closerange(3, w)
        os.closerange(max(3, w + 1), max_fd)

        # Redirect standard file descriptors.
        stdin = open(DEV_NULL, "r")
        stderr = stdout = open(STD_ERR, "a", 0)
        os.dup2(stdin.fileno(), sys.stdin.fileno())
        # NOTE: sys.stdout.write will still be buffered, even if the file
        # was opened with 0 buffer
        os.dup2(stdout.fileno(), sys.stdout.fileno())
        os.dup2(stderr.fileno(), sys.stderr.fileno())

        # create control socket
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock.bind(CTRL_SOCK)
        # Only remember the socket once bind() succeeded: cleanup() removes
        # the socket file, which must not happen when bind() failed because
        # another server already owns CTRL_SOCK.
        self._ctrl_sock = sock
        sock.listen(0)

        # let the parent process know that the daemonization is finished
        os.write(w, "\n")
        os.close(w)
        return 1

    def post_daemonize(self):
        """Hook run in the daemon right before the request loop starts."""
        pass

    def loop(self):
        """Accept connections and answer requests until STOP_MSG arrives.

        A connection is dropped when it stays idle for 5 seconds, when the
        peer closes it, or when a reply cannot be sent.
        """
        while not self._stop:
            conn, addr = self._ctrl_sock.accept()
            conn.settimeout(5)
            while not self._stop:
                try:
                    msg = self.recv_msg(conn)
                except socket.timeout:
                    break
                if msg is None:
                    # peer closed the connection: wait for a new one
                    break

                if msg == STOP_MSG:
                    self._stop = True
                    reply = self.stop_action()
                else:
                    reply = self.reply_action(msg)

                try:
                    self.send_reply(conn, reply)
                except socket.error:
                    self.log_error()
                    self.log_error("NOTICE: Awaiting for reconnection")
                    break
            try:
                conn.close()
            except socket.error:
                # Doesn't matter
                self.log_error()

    def recv_msg(self, conn):
        """Read one base64-encoded, newline-terminated message from conn.

        Returns the decoded message with trailing whitespace stripped, or
        None if the peer closed the connection before sending anything.
        Raises socket.timeout if conn's timeout expires.
        """
        data = ""
        while True:
            try:
                chunk = conn.recv(1024)
            except socket.error as e:
                # Retry interrupted system calls; anything else (including
                # socket.timeout, whose errno is None) propagates.
                if e.errno == errno.EINTR:
                    continue
                raise
            if not chunk:
                # empty chunk = EOF
                break
            data += chunk
            if chunk[-1] == "\n":
                break
        if not data:
            return None
        decoded = base64.b64decode(data)
        return decoded.rstrip()

    def send_reply(self, conn, reply):
        """Send reply base64-encoded and newline-terminated over conn."""
        encoded = base64.b64encode(reply)
        # sendall: a plain send() may transmit only part of a large reply
        conn.sendall("%s\n" % encoded)

    def cleanup(self):
        """Close the control socket and remove its file, if this process bound it."""
        if self._ctrl_sock is None:
            return
        try:
            self._ctrl_sock.close()
        except socket.error:
            self.log_error()
        try:
            os.remove(CTRL_SOCK)
        except OSError:
            self.log_error()
        self._ctrl_sock = None

    def stop_action(self):
        """Return the reply sent when STOP_MSG is received."""
        return "Stopping server"

    def reply_action(self, msg):
        """Return the reply for msg; subclasses override this."""
        return "Reply to: %s" % msg

    def log_error(self, text = None, context = ''):
        """Write an ERROR entry to stderr and return its text.

        When text is None the current exception's traceback is logged.
        """
        if text is None:
            text = traceback.format_exc()
        date = time.strftime("%Y-%m-%d %H:%M:%S")
        if context:
            context = " (%s)" % (context,)
        sys.stderr.write("ERROR%s: %s\n%s\n" % (context, date, text))
        return text

    def log_debug(self, text):
        """Write a DEBUG entry to stderr when log_level is DEBUG_LEVEL."""
        if self._log_level == DEBUG_LEVEL:
            date = time.strftime("%Y-%m-%d %H:%M:%S")
            sys.stderr.write("DEBUG: %s\n%s\n" % (date, text))
196
class Forwarder(object):
    """Relay lines between stdin/stdout and a Server's control socket.

    Each line read from stdin (a base64-encoded message plus newline) is sent
    to the server at ``root_dir/CTRL_SOCK``; the server's reply line is
    written to stdout.  Forwarding stops after STOP_MSG has been relayed,
    when stdin reaches EOF, or when the server closes the connection without
    replying.
    """
    def __init__(self, root_dir = "."):
        self._ctrl_sock = None
        self._root_dir = root_dir
        self._stop = False

    def forward(self):
        """Connect, announce readiness on stderr, then relay until stopped."""
        self.connect()
        # Client.connect() waits for exactly this line on our stderr.
        sys.stderr.write("READY.\n")
        sys.stderr.flush()
        while not self._stop:
            data = self.read_data()
            if not data:
                # stdin closed: nobody is left to send requests
                break
            self.send_to_server(data)
            data = self.recv_from_server()
            if not data:
                # server closed the connection without replying
                break
            self.write_data(data)
        self.disconnect()

    def read_data(self):
        """Return the next request line from stdin ("" on EOF)."""
        return sys.stdin.readline()

    def write_data(self, data):
        """Write a reply line to stdout and flush it immediately."""
        sys.stdout.write(data)
        # sys.stdout.write is buffered, this is why we need to do a flush()
        sys.stdout.flush()

    def send_to_server(self, data):
        """Send one encoded request line; flag a stop if it is STOP_MSG.

        On EPIPE (e.g. the server dropped an idle connection) reconnects
        once and retries; other socket errors propagate.
        """
        try:
            self._ctrl_sock.sendall(data)
        except IOError as e:
            if e.errno != errno.EPIPE:
                raise
            self.connect()
            self._ctrl_sock.sendall(data)
        encoded = data.rstrip()
        msg = base64.b64decode(encoded)
        if msg == STOP_MSG:
            self._stop = True

    def recv_from_server(self):
        """Return the server's reply line, newline included.

        Returns whatever was received before EOF ("" if nothing) when the
        server closes the connection.
        """
        data = ""
        while True:
            try:
                chunk = self._ctrl_sock.recv(1024)
            except socket.error as e:
                # retry interrupted system calls, propagate anything else
                if e.errno == errno.EINTR:
                    continue
                raise
            if not chunk:
                # EOF: the server closed the connection
                break
            data += chunk
            if chunk[-1] == "\n":
                break
        return data

    def connect(self):
        """(Re)connect to CTRL_SOCK inside root_dir, closing any old socket."""
        self.disconnect()
        self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock_addr = os.path.join(self._root_dir, CTRL_SOCK)
        self._ctrl_sock.connect(sock_addr)

    def disconnect(self):
        """Close the control socket, if any; close errors are ignored."""
        if self._ctrl_sock is None:
            return
        try:
            self._ctrl_sock.close()
        except socket.error:
            pass
261
262 class Client(object):
263     def __init__(self, root_dir = ".", host = None, port = None, user = None, 
264             agent = None):
265         self.root_dir = root_dir
266         self.addr = (host, port)
267         self.user = user
268         self.agent = agent
269         self._stopped = False
270         self.connect()
271     
272     def __del__(self):
273         if self._process.poll() is None:
274             os.kill(self._process.pid, signal.SIGTERM)
275         self._process.wait()
276         
277     def connect(self):
278         root_dir = self.root_dir
279         (host, port) = self.addr
280         user = self.user
281         agent = self.agent
282         
283         python_code = "from nepi.util import server;c=server.Forwarder(%r);\
284                 c.forward()" % (root_dir,)
285         if host != None:
286             self._process = popen_ssh_subprocess(python_code, host, port, 
287                     user, agent)
288             # popen_ssh_subprocess already waits for readiness
289         else:
290             self._process = subprocess.Popen(
291                     ["python", "-c", python_code],
292                     stdin = subprocess.PIPE, 
293                     stdout = subprocess.PIPE,
294                     stderr = subprocess.PIPE
295                 )
296                 
297         # Wait for the forwarder to be ready, otherwise nobody
298         # will be able to connect to it
299         helo = self._process.stderr.readline()
300         if helo != 'READY.\n':
301             raise AssertionError, "Expected 'Ready.', got %r" % (helo,)
302         
303         if self._process.poll():
304             err = self._process.stderr.read()
305             raise RuntimeError("Client could not be executed: %s" % \
306                     err)
307
308     def send_msg(self, msg):
309         encoded = base64.b64encode(msg)
310         data = "%s\n" % encoded
311         
312         try:
313             self._process.stdin.write(data)
314         except (IOError, ValueError):
315             # dead process, poll it to un-zombify
316             self._process.poll()
317             
318             # try again after reconnect
319             # If it fails again, though, give up
320             self.connect()
321             self._process.stdin.write(data)
322
323     def send_stop(self):
324         self.send_msg(STOP_MSG)
325         self._stopped = True
326
327     def read_reply(self):
328         data = self._process.stdout.readline()
329         encoded = data.rstrip() 
330         return base64.b64decode(encoded)
331
def popen_ssh_subprocess(python_code, host, port, user, agent,
        python_path = None):
    """Run python_code with a Python interpreter on host, reached over ssh.

    The ssh command is a small bootstrap script.  It reads one
    base64-encoded line from stdin, answers with an "OK" line on stdout,
    then exec()s the decoded code.

    Args:
        python_code: source code to execute remotely.
        host: ssh destination host.
        port: ssh port; falsy means ssh's default.
        user: remote login; None means ssh's default.
        agent: when truthy, enable agent forwarding (-A).
        python_path: optional path appended to the remote PYTHONPATH.

    Returns:
        The ssh subprocess.Popen object with stdin/stdout/stderr piped,
        after the remote bootstrap has acknowledged.

    Raises:
        RuntimeError: if the remote bootstrap did not answer "OK".
    """
    if python_path:
        # Escape single quotes so the path stays one single-quoted shell word.
        python_path = python_path.replace("'", r"'\''")
        cmd = """PYTHONPATH="$PYTHONPATH":'%s' """ % python_path
    else:
        cmd = ""
    # Uncomment for debug (to run everything under strace)
    # We had to verify if strace works (cannot nest them)
    #cmd += "if strace echo >/dev/null 2>&1; then CMD='strace -ff -tt -s 200 -o strace.out'; else CMD=''; fi\n"
    #cmd += "$CMD "
    cmd += "python -c '"
    cmd += "import base64, os\n"
    cmd += "cmd = \"\"\n"
    cmd += "while True:\n"
    cmd += " cmd += os.read(0, 1)\n" # one byte from stdin
    cmd += " if cmd[-1] == \"\\n\": break\n"
    cmd += "cmd = base64.b64decode(cmd)\n"
    # Uncomment for debug
    #cmd += "os.write(2, \"Executing python code: %s\\n\" % cmd)\n"
    cmd += "os.write(1, \"OK\\n\")\n" # send a sync message
    cmd += "exec(cmd)\n'"

    args = ['ssh',
            # Don't bother with localhost. Makes test easier
            '-o', 'NoHostAuthenticationForLocalhost=yes']
    if user:
        args.extend(['-l', user])
    if agent:
        args.append('-A')
    if port:
        args.append('-p%d' % port)
    args.append(host)
    args.append(cmd)

    # connects to the remote host and starts the remote bootstrap interpreter
    proc = subprocess.Popen(args,
            stdout = subprocess.PIPE,
            stdin = subprocess.PIPE,
            stderr = subprocess.PIPE)
    # send the command to execute
    os.write(proc.stdin.fileno(),
            base64.b64encode(python_code) + "\n")
    # read the 3-byte sync message; a pipe read may return fewer bytes
    msg = ""
    while len(msg) < 3:
        chunk = os.read(proc.stdout.fileno(), 3 - len(msg))
        if not chunk:
            break
        msg += chunk
    if msg != "OK\n":
        # don't leak the ssh process; collect its diagnostics instead
        if proc.poll() is None:
            proc.kill()
        err = proc.communicate()[1]
        raise RuntimeError("Failed to start remote python interpreter: %s"
                % (err,))
    return proc
379