199a00955d84407385a2f091ada5b803310ebe30
[nepi.git] / src / nepi / util / server.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import base64
5 import errno
6 import os
7 import resource
8 import select
9 import socket
10 import sys
11 import subprocess
12 import threading
13 import time
14 import traceback
15 import signal
16 import re
17 import tempfile
18
19 CTRL_SOCK = "ctrl.sock"
20 STD_ERR = "stderr.log"
21 MAX_FD = 1024
22
23 STOP_MSG = "STOP"
24
25 ERROR_LEVEL = 0
26 DEBUG_LEVEL = 1
27 TRACE = False
28
29 if hasattr(os, "devnull"):
30     DEV_NULL = os.devnull
31 else:
32     DEV_NULL = "/dev/null"
33
34
35
36 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
37
38 def shell_escape(s):
39     """ Escapes strings so that they are safe to use as command-line arguments """
40     if SHELL_SAFE.match(s):
41         # safe string - no escaping needed
42         return s
43     else:
44         # unsafe string - escape
45         s = s.replace("'","\\'")
46         return "'%s'" % (s,)
47
48 class Server(object):
49     def __init__(self, root_dir = ".", log_level = ERROR_LEVEL):
50         self._root_dir = root_dir
51         self._stop = False
52         self._ctrl_sock = None
53         self._log_level = log_level
54
55     def run(self):
56         try:
57             if self.daemonize():
58                 self.post_daemonize()
59                 self.loop()
60                 self.cleanup()
61                 # ref: "os._exit(0)"
62                 # can not return normally after fork beacuse no exec was done.
63                 # This means that if we don't do a os._exit(0) here the code that 
64                 # follows the call to "Server.run()" in the "caller code" will be 
65                 # executed... but by now it has already been executed after the 
66                 # first process (the one that did the first fork) returned.
67                 os._exit(0)
68         except:
69             self.log_error()
70             self.cleanup()
71             os._exit(0)
72
73     def daemonize(self):
74         # pipes for process synchronization
75         (r, w) = os.pipe()
76
77         pid1 = os.fork()
78         if pid1 > 0:
79             os.close(w)
80             os.read(r, 1)
81             os.close(r)
82             # os.waitpid avoids leaving a <defunc> (zombie) process
83             st = os.waitpid(pid1, 0)[1]
84             if st:
85                 raise RuntimeError("Daemonization failed")
86             # return 0 to inform the caller method that this is not the 
87             # daemonized process
88             return 0
89         os.close(r)
90
91         # Decouple from parent environment.
92         os.chdir(self._root_dir)
93         os.umask(0)
94         os.setsid()
95
96         # fork 2
97         pid2 = os.fork()
98         if pid2 > 0:
99             # see ref: "os._exit(0)"
100             os._exit(0)
101
102         # close all open file descriptors.
103         max_fd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
104         if (max_fd == resource.RLIM_INFINITY):
105             max_fd = MAX_FD
106         for fd in range(3, max_fd):
107             if fd != w:
108                 try:
109                     os.close(fd)
110                 except OSError:
111                     pass
112
113         # Redirect standard file descriptors.
114         stdin = open(DEV_NULL, "r")
115         stderr = stdout = open(STD_ERR, "a", 0)
116         os.dup2(stdin.fileno(), sys.stdin.fileno())
117         # NOTE: sys.stdout.write will still be buffered, even if the file
118         # was opened with 0 buffer
119         os.dup2(stdout.fileno(), sys.stdout.fileno())
120         os.dup2(stderr.fileno(), sys.stderr.fileno())
121
122         # create control socket
123         self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
124         self._ctrl_sock.bind(CTRL_SOCK)
125         self._ctrl_sock.listen(0)
126
127         # let the parent process know that the daemonization is finished
128         os.write(w, "\n")
129         os.close(w)
130         return 1
131
132     def post_daemonize(self):
133         pass
134
135     def loop(self):
136         while not self._stop:
137             conn, addr = self._ctrl_sock.accept()
138             conn.settimeout(5)
139             while not self._stop:
140                 try:
141                     msg = self.recv_msg(conn)
142                 except socket.timeout, e:
143                     break
144                     
145                 if msg == STOP_MSG:
146                     self._stop = True
147                     reply = self.stop_action()
148                 else:
149                     reply = self.reply_action(msg)
150                 
151                 try:
152                     self.send_reply(conn, reply)
153                 except socket.error:
154                     self.log_error()
155                     self.log_error("NOTICE: Awaiting for reconnection")
156                     break
157             try:
158                 conn.close()
159             except:
160                 # Doesn't matter
161                 self.log_error()
162
163     def recv_msg(self, conn):
164         data = ""
165         while True:
166             try:
167                 chunk = conn.recv(1024)
168             except OSError, e:
169                 if e.errno != errno.EINTR:
170                     raise
171                 if chunk == '':
172                     continue
173             if chunk:
174                 data += chunk
175                 if chunk[-1] == "\n":
176                     break
177             else:
178                 # empty chunk = EOF
179                 break
180         decoded = base64.b64decode(data)
181         return decoded.rstrip()
182
183     def send_reply(self, conn, reply):
184         encoded = base64.b64encode(reply)
185         conn.send("%s\n" % encoded)
186        
187     def cleanup(self):
188         try:
189             self._ctrl_sock.close()
190             os.remove(CTRL_SOCK)
191         except:
192             self.log_error()
193
194     def stop_action(self):
195         return "Stopping server"
196
197     def reply_action(self, msg):
198         return "Reply to: %s" % msg
199
200     def log_error(self, text = None, context = ''):
201         if text == None:
202             text = traceback.format_exc()
203         date = time.strftime("%Y-%m-%d %H:%M:%S")
204         if context:
205             context = " (%s)" % (context,)
206         sys.stderr.write("ERROR%s: %s\n%s\n" % (context, date, text))
207         return text
208
209     def log_debug(self, text):
210         if self._log_level == DEBUG_LEVEL:
211             date = time.strftime("%Y-%m-%d %H:%M:%S")
212             sys.stderr.write("DEBUG: %s\n%s\n" % (date, text))
213
214 class Forwarder(object):
215     def __init__(self, root_dir = "."):
216         self._ctrl_sock = None
217         self._root_dir = root_dir
218         self._stop = False
219
220     def forward(self):
221         self.connect()
222         print >>sys.stderr, "READY."
223         while not self._stop:
224             data = self.read_data()
225             self.send_to_server(data)
226             data = self.recv_from_server()
227             self.write_data(data)
228         self.disconnect()
229
230     def read_data(self):
231         return sys.stdin.readline()
232
233     def write_data(self, data):
234         sys.stdout.write(data)
235         # sys.stdout.write is buffered, this is why we need to do a flush()
236         sys.stdout.flush()
237
238     def send_to_server(self, data):
239         try:
240             self._ctrl_sock.send(data)
241         except IOError, e:
242             if e.errno == errno.EPIPE:
243                 self.connect()
244                 self._ctrl_sock.send(data)
245             else:
246                 raise e
247         encoded = data.rstrip() 
248         msg = base64.b64decode(encoded)
249         if msg == STOP_MSG:
250             self._stop = True
251
252     def recv_from_server(self):
253         data = ""
254         while True:
255             try:
256                 chunk = self._ctrl_sock.recv(1024)
257             except OSError, e:
258                 if e.errno != errno.EINTR:
259                     raise
260                 if chunk == '':
261                     continue
262             data += chunk
263             if chunk[-1] == "\n":
264                 break
265         return data
266  
267     def connect(self):
268         self.disconnect()
269         self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
270         sock_addr = os.path.join(self._root_dir, CTRL_SOCK)
271         self._ctrl_sock.connect(sock_addr)
272
273     def disconnect(self):
274         try:
275             self._ctrl_sock.close()
276         except:
277             pass
278
279 class Client(object):
280     def __init__(self, root_dir = ".", host = None, port = None, user = None, 
281             agent = None):
282         self.root_dir = root_dir
283         self.addr = (host, port)
284         self.user = user
285         self.agent = agent
286         self._stopped = False
287         self.connect()
288     
289     def __del__(self):
290         if self._process.poll() is None:
291             os.kill(self._process.pid, signal.SIGTERM)
292         self._process.wait()
293         
294     def connect(self):
295         root_dir = self.root_dir
296         (host, port) = self.addr
297         user = self.user
298         agent = self.agent
299         
300         python_code = "from nepi.util import server;c=server.Forwarder(%r);\
301                 c.forward()" % (root_dir,)
302         if host != None:
303             self._process = popen_ssh_subprocess(python_code, host, port, 
304                     user, agent)
305             # popen_ssh_subprocess already waits for readiness
306         else:
307             self._process = subprocess.Popen(
308                     ["python", "-c", python_code],
309                     stdin = subprocess.PIPE, 
310                     stdout = subprocess.PIPE,
311                     stderr = subprocess.PIPE
312                 )
313                 
314         # Wait for the forwarder to be ready, otherwise nobody
315         # will be able to connect to it
316         helo = self._process.stderr.readline()
317         if helo != 'READY.\n':
318             raise AssertionError, "Expected 'Ready.', got %r: %s" % (helo,
319                     helo + self._process.stderr.read())
320         
321     def send_msg(self, msg):
322         encoded = base64.b64encode(msg)
323         data = "%s\n" % encoded
324         
325         try:
326             self._process.stdin.write(data)
327         except (IOError, ValueError):
328             # dead process, poll it to un-zombify
329             self._process.poll()
330             
331             # try again after reconnect
332             # If it fails again, though, give up
333             self.connect()
334             self._process.stdin.write(data)
335
336     def send_stop(self):
337         self.send_msg(STOP_MSG)
338         self._stopped = True
339
340     def read_reply(self):
341         data = self._process.stdout.readline()
342         encoded = data.rstrip() 
343         return base64.b64decode(encoded)
344
345 def _make_server_key_args(server_key, host, port, args):
346     """ 
347     Returns a reference to the created temporary file, and adds the
348     corresponding arguments to the given argument list.
349     
350     Make sure to hold onto it until the process is done with the file
351     """
352     if port is not None:
353         host = '%s:%s' % (host,port)
354     # Create a temporary server key file
355     tmp_known_hosts = tempfile.NamedTemporaryFile()
356     tmp_known_hosts.write('%s,%s %s\n' % (host, socket.gethostbyname(host), server_key))
357     tmp_known_hosts.flush()
358     args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
359     return tmp_known_hosts
360
361 def popen_ssh_command(command, host, port, user, agent, 
362             stdin="", 
363             ident_key = None,
364             server_key = None,
365             tty = False):
366         """
367         Executes a remote commands, returns ((stdout,stderr),process)
368         """
369         if TRACE:
370             print "ssh", host, command
371         
372         tmp_known_hosts = None
373         args = ['ssh',
374                 # Don't bother with localhost. Makes test easier
375                 '-o', 'NoHostAuthenticationForLocalhost=yes',
376                 '-l', user, host]
377         if agent:
378             args.append('-A')
379         if port:
380             args.append('-p%d' % port)
381         if ident_key:
382             args.extend(('-i', ident_key))
383         if tty:
384             args.append('-t')
385         if server_key:
386             # Create a temporary server key file
387             tmp_known_hosts = _make_server_key_args(
388                 server_key, host, port, args)
389         args.append(command)
390
391         # connects to the remote host and starts a remote connection
392         proc = subprocess.Popen(args, 
393                 stdout = subprocess.PIPE,
394                 stdin = subprocess.PIPE, 
395                 stderr = subprocess.PIPE)
396         
397         # attach tempfile object to the process, to make sure the file stays
398         # alive until the process is finished with it
399         proc._known_hosts = tmp_known_hosts
400         
401         return (proc.communicate(stdin), proc)
402  
403 def popen_scp(source, dest, 
404             port = None, 
405             agent = None, 
406             recursive = False,
407             ident_key = None,
408             server_key = None):
409         """
410         Copies from/to remote sites.
411         
412         Source and destination should have the user and host encoded
413         as per scp specs.
414         
415         If source is a file object, a special mode will be used to
416         create the remote file with the same contents.
417         
418         If dest is a file object, the remote file (source) will be
419         read and written into dest.
420         
421         In these modes, recursive cannot be True.
422         
423         Source can be a list of files to copy to a single destination,
424         in which case it is advised that the destination be a folder.
425         """
426         
427         if TRACE:
428             print "scp", source, dest
429         
430         if isinstance(source, file) or isinstance(dest, file) \
431                 or hasattr(source, 'read')  or hasattr(dest, 'write'):
432             assert not recursive
433             
434             # Parse source/destination as <user>@<server>:<path>
435             if isinstance(dest, basestring) and ':' in dest:
436                 remspec, path = dest.split(':',1)
437             elif isinstance(source, basestring) and ':' in source:
438                 remspec, path = source.split(':',1)
439             else:
440                 raise ValueError, "Both endpoints cannot be local"
441             user,host = remspec.rsplit('@',1)
442             tmp_known_hosts = None
443             
444             args = ['ssh', '-l', user, '-C',
445                     # Don't bother with localhost. Makes test easier
446                     '-o', 'NoHostAuthenticationForLocalhost=yes',
447                     host ]
448             if port:
449                 args.append('-P%d' % port)
450             if ident_key:
451                 args.extend(('-i', ident_key))
452             if server_key:
453                 # Create a temporary server key file
454                 tmp_known_hosts = _make_server_key_args(
455                     server_key, host, port, args)
456             
457             if isinstance(source, file) or hasattr(source, 'read'):
458                 args.append('cat > %s' % (shell_escape(path),))
459             elif isinstance(dest, file) or hasattr(dest, 'write'):
460                 args.append('cat %s' % (shell_escape(path),))
461             else:
462                 raise AssertionError, "Unreachable code reached! :-Q"
463             
464             # connects to the remote host and starts a remote connection
465             if isinstance(source, file):
466                 proc = subprocess.Popen(args, 
467                         stdout = open('/dev/null','w'),
468                         stderr = subprocess.PIPE,
469                         stdin = source)
470                 err = proc.stderr.read()
471                 proc._known_hosts = tmp_known_hosts
472                 proc.wait()
473                 return ((None,err), proc)
474             elif isinstance(dest, file):
475                 proc = subprocess.Popen(args, 
476                         stdout = open('/dev/null','w'),
477                         stderr = subprocess.PIPE,
478                         stdin = source)
479                 err = proc.stderr.read()
480                 proc._known_hosts = tmp_known_hosts
481                 proc.wait()
482                 return ((None,err), proc)
483             elif hasattr(source, 'read'):
484                 # file-like (but not file) source
485                 proc = subprocess.Popen(args, 
486                         stdout = open('/dev/null','w'),
487                         stderr = subprocess.PIPE,
488                         stdin = subprocess.PIPE)
489                 
490                 buf = None
491                 err = []
492                 while True:
493                     if not buf:
494                         buf = source.read(4096)
495                     if not buf:
496                         #EOF
497                         break
498                     
499                     rdrdy, wrdy, broken = select.select(
500                         [proc.stderr],
501                         [proc.stdin],
502                         [proc.stderr,proc.stdin])
503                     
504                     if proc.stderr in rdrdy:
505                         # use os.read for fully unbuffered behavior
506                         err.append(os.read(proc.stderr.fileno(), 4096))
507                     
508                     if proc.stdin in wrdy:
509                         proc.stdin.write(buf)
510                         buf = None
511                     
512                     if broken:
513                         break
514                 proc.stdin.close()
515                 err.append(proc.stderr.read())
516                     
517                 proc._known_hosts = tmp_known_hosts
518                 proc.wait()
519                 return ((None,''.join(err)), proc)
520             elif hasattr(dest, 'write'):
521                 # file-like (but not file) dest
522                 proc = subprocess.Popen(args, 
523                         stdout = subprocess.PIPE,
524                         stderr = subprocess.PIPE,
525                         stdin = open('/dev/null','w'))
526                 
527                 buf = None
528                 err = []
529                 while True:
530                     rdrdy, wrdy, broken = select.select(
531                         [proc.stderr, proc.stdout],
532                         [],
533                         [proc.stderr, proc.stdout])
534                     
535                     if proc.stderr in rdrdy:
536                         # use os.read for fully unbuffered behavior
537                         err.append(os.read(proc.stderr.fileno(), 4096))
538                     
539                     if proc.stdout in rdrdy:
540                         # use os.read for fully unbuffered behavior
541                         buf = os.read(proc.stdout.fileno(), 4096)
542                         dest.write(buf)
543                         
544                         if not buf:
545                             #EOF
546                             break
547                     
548                     if broken:
549                         break
550                 err.append(proc.stderr.read())
551                     
552                 proc._known_hosts = tmp_known_hosts
553                 proc.wait()
554                 return ((None,''.join(err)), proc)
555             else:
556                 raise AssertionError, "Unreachable code reached! :-Q"
557         else:
558             # Parse destination as <user>@<server>:<path>
559             if isinstance(dest, basestring) and ':' in dest:
560                 remspec, path = dest.split(':',1)
561             elif isinstance(source, basestring) and ':' in source:
562                 remspec, path = source.split(':',1)
563             else:
564                 raise ValueError, "Both endpoints cannot be local"
565             user,host = remspec.rsplit('@',1)
566             
567             # plain scp
568             tmp_known_hosts = None
569             args = ['scp', '-q', '-p', '-C',
570                     # Don't bother with localhost. Makes test easier
571                     '-o', 'NoHostAuthenticationForLocalhost=yes' ]
572             if port:
573                 args.append('-P%d' % port)
574             if recursive:
575                 args.append('-r')
576             if ident_key:
577                 args.extend(('-i', ident_key))
578             if server_key:
579                 # Create a temporary server key file
580                 tmp_known_hosts = _make_server_key_args(
581                     server_key, host, port, args)
582             if isinstance(source,list):
583                 args.extend(source)
584             else:
585                 args.append(source)
586             args.append(dest)
587
588             # connects to the remote host and starts a remote connection
589             proc = subprocess.Popen(args, 
590                     stdout = subprocess.PIPE,
591                     stdin = subprocess.PIPE, 
592                     stderr = subprocess.PIPE)
593             proc._known_hosts = tmp_known_hosts
594             
595             comm = proc.communicate()
596             proc.wait()
597             return (comm, proc)
598  
599 def popen_ssh_subprocess(python_code, host, port, user, agent, 
600         python_path = None,
601         ident_key = None,
602         server_key = None,
603         tty = False):
604         if python_path:
605             python_path.replace("'", r"'\''")
606             cmd = """PYTHONPATH="$PYTHONPATH":'%s' """ % python_path
607         else:
608             cmd = ""
609         # Uncomment for debug (to run everything under strace)
610         # We had to verify if strace works (cannot nest them)
611         #cmd += "if strace echo >/dev/null 2>&1; then CMD='strace -ff -tt -s 200 -o strace.out'; else CMD=''; fi\n"
612         #cmd += "$CMD "
613         #if self.mode == MODE_SSH:
614         #    cmd += "strace -f -tt -s 200 -o strace$$.out "
615         cmd += "python -c '"
616         cmd += "import base64, os\n"
617         cmd += "cmd = \"\"\n"
618         cmd += "while True:\n"
619         cmd += " cmd += os.read(0, 1)\n" # one byte from stdin
620         cmd += " if cmd[-1] == \"\\n\": break\n"
621         cmd += "cmd = base64.b64decode(cmd)\n"
622         # Uncomment for debug
623         #cmd += "os.write(2, \"Executing python code: %s\\n\" % cmd)\n"
624         cmd += "os.write(1, \"OK\\n\")\n" # send a sync message
625         cmd += "exec(cmd)\n'"
626
627         tmp_known_hosts = None
628         args = ['ssh',
629                 # Don't bother with localhost. Makes test easier
630                 '-o', 'NoHostAuthenticationForLocalhost=yes',
631                 '-l', user, host]
632         if agent:
633             args.append('-A')
634         if port:
635             args.append('-p%d' % port)
636         if ident_key:
637             args.extend(('-i', ident_key))
638         if tty:
639             args.append('-t')
640         if server_key:
641             # Create a temporary server key file
642             tmp_known_hosts = _make_server_key_args(
643                 server_key, host, port, args)
644         args.append(cmd)
645
646         # connects to the remote host and starts a remote rpyc connection
647         proc = subprocess.Popen(args, 
648                 stdout = subprocess.PIPE,
649                 stdin = subprocess.PIPE, 
650                 stderr = subprocess.PIPE)
651         proc._known_hosts = tmp_known_hosts
652         
653         # send the command to execute
654         os.write(proc.stdin.fileno(),
655                 base64.b64encode(python_code) + "\n")
656         msg = os.read(proc.stdout.fileno(), 3)
657         if msg != "OK\n":
658             raise RuntimeError("Failed to start remote python interpreter")
659         return proc
660