bugfix: EINTRs caught
[nepi.git] / src / nepi / util / server.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import base64
5 import errno
6 import os
7 import os.path
8 import resource
9 import select
10 import socket
11 import sys
12 import subprocess
13 import threading
14 import time
15 import traceback
16 import signal
17 import re
18 import tempfile
19
20 CTRL_SOCK = "ctrl.sock"
21 STD_ERR = "stderr.log"
22 MAX_FD = 1024
23
24 STOP_MSG = "STOP"
25
26 ERROR_LEVEL = 0
27 DEBUG_LEVEL = 1
28 TRACE = os.environ.get("NEPI_TRACE", "false").lower() in ("true", "1", "on")
29
30 if hasattr(os, "devnull"):
31     DEV_NULL = os.devnull
32 else:
33     DEV_NULL = "/dev/null"
34
35 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
36
37 def shell_escape(s):
38     """ Escapes strings so that they are safe to use as command-line arguments """
39     if SHELL_SAFE.match(s):
40         # safe string - no escaping needed
41         return s
42     else:
43         # unsafe string - escape
44         def escp(c):
45             if (32 <= ord(c) < 127 or c in ('\r','\n','\t')) and c not in ("'",):
46                 return c
47             else:
48                 return "'$'\\x%02x''" % (ord(c),)
49         s = ''.join(map(escp,s))
50         return "'%s'" % (s,)
51
52 class Server(object):
53     def __init__(self, root_dir = ".", log_level = ERROR_LEVEL):
54         self._root_dir = root_dir
55         self._stop = False
56         self._ctrl_sock = None
57         self._log_level = log_level
58
59     def run(self):
60         try:
61             if self.daemonize():
62                 self.post_daemonize()
63                 self.loop()
64                 self.cleanup()
65                 # ref: "os._exit(0)"
66                 # can not return normally after fork beacuse no exec was done.
67                 # This means that if we don't do a os._exit(0) here the code that 
68                 # follows the call to "Server.run()" in the "caller code" will be 
69                 # executed... but by now it has already been executed after the 
70                 # first process (the one that did the first fork) returned.
71                 os._exit(0)
72         except:
73             self.log_error()
74             self.cleanup()
75             os._exit(0)
76
77     def daemonize(self):
78         # pipes for process synchronization
79         (r, w) = os.pipe()
80         
81         # build root folder
82         root = os.path.normpath(self._root_dir)
83         if not os.path.exists(root):
84             os.makedirs(root, 0755)
85
86         pid1 = os.fork()
87         if pid1 > 0:
88             os.close(w)
89             while True:
90                 try:
91                     os.read(r, 1)
92                 except OSError, e: # pragma: no cover
93                     if e.errno == errno.EINTR:
94                         continue
95                     else:
96                         raise
97                 break
98             os.close(r)
99             # os.waitpid avoids leaving a <defunc> (zombie) process
100             st = os.waitpid(pid1, 0)[1]
101             if st:
102                 raise RuntimeError("Daemonization failed")
103             # return 0 to inform the caller method that this is not the 
104             # daemonized process
105             return 0
106         os.close(r)
107
108         # Decouple from parent environment.
109         os.chdir(self._root_dir)
110         os.umask(0)
111         os.setsid()
112
113         # fork 2
114         pid2 = os.fork()
115         if pid2 > 0:
116             # see ref: "os._exit(0)"
117             os._exit(0)
118
119         # close all open file descriptors.
120         max_fd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
121         if (max_fd == resource.RLIM_INFINITY):
122             max_fd = MAX_FD
123         for fd in range(3, max_fd):
124             if fd != w:
125                 try:
126                     os.close(fd)
127                 except OSError:
128                     pass
129
130         # Redirect standard file descriptors.
131         stdin = open(DEV_NULL, "r")
132         stderr = stdout = open(STD_ERR, "a", 0)
133         os.dup2(stdin.fileno(), sys.stdin.fileno())
134         # NOTE: sys.stdout.write will still be buffered, even if the file
135         # was opened with 0 buffer
136         os.dup2(stdout.fileno(), sys.stdout.fileno())
137         os.dup2(stderr.fileno(), sys.stderr.fileno())
138
139         # create control socket
140         self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
141         self._ctrl_sock.bind(CTRL_SOCK)
142         self._ctrl_sock.listen(0)
143
144         # let the parent process know that the daemonization is finished
145         os.write(w, "\n")
146         os.close(w)
147         return 1
148
149     def post_daemonize(self):
150         pass
151
152     def loop(self):
153         while not self._stop:
154             conn, addr = self._ctrl_sock.accept()
155             conn.settimeout(5)
156             while not self._stop:
157                 try:
158                     msg = self.recv_msg(conn)
159                 except socket.timeout, e:
160                     break
161                     
162                 if msg == STOP_MSG:
163                     self._stop = True
164                     reply = self.stop_action()
165                 else:
166                     reply = self.reply_action(msg)
167                 
168                 try:
169                     self.send_reply(conn, reply)
170                 except socket.error:
171                     self.log_error()
172                     self.log_error("NOTICE: Awaiting for reconnection")
173                     break
174             try:
175                 conn.close()
176             except:
177                 # Doesn't matter
178                 self.log_error()
179
180     def recv_msg(self, conn):
181         data = ""
182         while True:
183             try:
184                 chunk = conn.recv(1024)
185             except OSError, e:
186                 if e.errno != errno.EINTR:
187                     raise
188                 if chunk == '':
189                     continue
190             if chunk:
191                 data += chunk
192                 if chunk[-1] == "\n":
193                     break
194             else:
195                 # empty chunk = EOF
196                 break
197         decoded = base64.b64decode(data)
198         return decoded.rstrip()
199
200     def send_reply(self, conn, reply):
201         encoded = base64.b64encode(reply)
202         conn.send("%s\n" % encoded)
203        
204     def cleanup(self):
205         try:
206             self._ctrl_sock.close()
207             os.remove(CTRL_SOCK)
208         except:
209             self.log_error()
210
211     def stop_action(self):
212         return "Stopping server"
213
214     def reply_action(self, msg):
215         return "Reply to: %s" % msg
216
217     def log_error(self, text = None, context = ''):
218         if text == None:
219             text = traceback.format_exc()
220         date = time.strftime("%Y-%m-%d %H:%M:%S")
221         if context:
222             context = " (%s)" % (context,)
223         sys.stderr.write("ERROR%s: %s\n%s\n" % (context, date, text))
224         return text
225
226     def log_debug(self, text):
227         if self._log_level == DEBUG_LEVEL:
228             date = time.strftime("%Y-%m-%d %H:%M:%S")
229             sys.stderr.write("DEBUG: %s\n%s\n" % (date, text))
230
231 class Forwarder(object):
232     def __init__(self, root_dir = "."):
233         self._ctrl_sock = None
234         self._root_dir = root_dir
235         self._stop = False
236
237     def forward(self):
238         self.connect()
239         print >>sys.stderr, "READY."
240         while not self._stop:
241             data = self.read_data()
242             self.send_to_server(data)
243             data = self.recv_from_server()
244             self.write_data(data)
245         self.disconnect()
246
247     def read_data(self):
248         return sys.stdin.readline()
249
250     def write_data(self, data):
251         sys.stdout.write(data)
252         # sys.stdout.write is buffered, this is why we need to do a flush()
253         sys.stdout.flush()
254
255     def send_to_server(self, data):
256         try:
257             self._ctrl_sock.send(data)
258         except IOError, e:
259             if e.errno == errno.EPIPE:
260                 self.connect()
261                 self._ctrl_sock.send(data)
262             else:
263                 raise e
264         encoded = data.rstrip() 
265         msg = base64.b64decode(encoded)
266         if msg == STOP_MSG:
267             self._stop = True
268
269     def recv_from_server(self):
270         data = ""
271         while True:
272             try:
273                 chunk = self._ctrl_sock.recv(1024)
274             except OSError, e:
275                 if e.errno != errno.EINTR:
276                     raise
277                 if chunk == '':
278                     continue
279             data += chunk
280             if chunk[-1] == "\n":
281                 break
282         return data
283  
284     def connect(self):
285         self.disconnect()
286         self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
287         sock_addr = os.path.join(self._root_dir, CTRL_SOCK)
288         self._ctrl_sock.connect(sock_addr)
289
290     def disconnect(self):
291         try:
292             self._ctrl_sock.close()
293         except:
294             pass
295
296 class Client(object):
297     def __init__(self, root_dir = ".", host = None, port = None, user = None, 
298             agent = None, environment_setup = ""):
299         self.root_dir = root_dir
300         self.addr = (host, port)
301         self.user = user
302         self.agent = agent
303         self.environment_setup = environment_setup
304         self._stopped = False
305         self.connect()
306     
307     def __del__(self):
308         if self._process.poll() is None:
309             os.kill(self._process.pid, signal.SIGTERM)
310         self._process.wait()
311         
312     def connect(self):
313         root_dir = self.root_dir
314         (host, port) = self.addr
315         user = self.user
316         agent = self.agent
317         
318         python_code = "from nepi.util import server;c=server.Forwarder(%r);\
319                 c.forward()" % (root_dir,)
320         if host != None:
321             self._process = popen_ssh_subprocess(python_code, host, port, 
322                     user, agent,
323                     environment_setup = self.environment_setup)
324             # popen_ssh_subprocess already waits for readiness
325             if self._process.poll():
326                 err = proc.stderr.read()
327                 raise RuntimeError("Client could not be reached: %s" % \
328                         err)
329         else:
330             self._process = subprocess.Popen(
331                     ["python", "-c", python_code],
332                     stdin = subprocess.PIPE, 
333                     stdout = subprocess.PIPE,
334                     stderr = subprocess.PIPE
335                 )
336                 
337         # Wait for the forwarder to be ready, otherwise nobody
338         # will be able to connect to it
339         helo = self._process.stderr.readline()
340         if helo != 'READY.\n':
341             raise AssertionError, "Expected 'Ready.', got %r: %s" % (helo,
342                     helo + self._process.stderr.read())
343         
344     def send_msg(self, msg):
345         encoded = base64.b64encode(msg)
346         data = "%s\n" % encoded
347         
348         try:
349             self._process.stdin.write(data)
350         except (IOError, ValueError):
351             # dead process, poll it to un-zombify
352             self._process.poll()
353             
354             # try again after reconnect
355             # If it fails again, though, give up
356             self.connect()
357             self._process.stdin.write(data)
358
359     def send_stop(self):
360         self.send_msg(STOP_MSG)
361         self._stopped = True
362
363     def read_reply(self):
364         data = self._process.stdout.readline()
365         encoded = data.rstrip() 
366         return base64.b64decode(encoded)
367
368 def _make_server_key_args(server_key, host, port, args):
369     """ 
370     Returns a reference to the created temporary file, and adds the
371     corresponding arguments to the given argument list.
372     
373     Make sure to hold onto it until the process is done with the file
374     """
375     if port is not None:
376         host = '%s:%s' % (host,port)
377     # Create a temporary server key file
378     tmp_known_hosts = tempfile.NamedTemporaryFile()
379     
380     # Add the intended host key
381     tmp_known_hosts.write('%s,%s %s\n' % (host, socket.gethostbyname(host), server_key))
382     
383     # If we're not in strict mode, add user-configured keys
384     if os.environ.get('NEPI_STRICT_AUTH_MODE',"").lower() not in ('1','true','on'):
385         user_hosts_path = '%s/.ssh/known_hosts' % (os.environ.get('HOME',""),)
386         if os.access(user_hosts_path, os.R_OK):
387             f = open(user_hosts_path, "r")
388             tmp_known_hosts.write(f.read())
389             f.close()
390         
391     tmp_known_hosts.flush()
392     
393     args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
394     return tmp_known_hosts
395
396 def popen_ssh_command(command, host, port, user, agent, 
397             stdin="", 
398             ident_key = None,
399             server_key = None,
400             tty = False):
401         """
402         Executes a remote commands, returns ((stdout,stderr),process)
403         """
404         if TRACE:
405             print "ssh", host, command
406         
407         tmp_known_hosts = None
408         args = ['ssh',
409                 # Don't bother with localhost. Makes test easier
410                 '-o', 'NoHostAuthenticationForLocalhost=yes',
411                 '-l', user, host]
412         if agent:
413             args.append('-A')
414         if port:
415             args.append('-p%d' % port)
416         if ident_key:
417             args.extend(('-i', ident_key))
418         if tty:
419             args.append('-t')
420         if server_key:
421             # Create a temporary server key file
422             tmp_known_hosts = _make_server_key_args(
423                 server_key, host, port, args)
424         args.append(command)
425
426         # connects to the remote host and starts a remote connection
427         proc = subprocess.Popen(args, 
428                 stdout = subprocess.PIPE,
429                 stdin = subprocess.PIPE, 
430                 stderr = subprocess.PIPE)
431         
432         # attach tempfile object to the process, to make sure the file stays
433         # alive until the process is finished with it
434         proc._known_hosts = tmp_known_hosts
435         
436         out, err = proc.communicate(stdin)
437         if TRACE:
438             print " -> ", out, err
439
440         return ((out, err), proc)
441  
442 def popen_scp(source, dest, 
443             port = None, 
444             agent = None, 
445             recursive = False,
446             ident_key = None,
447             server_key = None):
448         """
449         Copies from/to remote sites.
450         
451         Source and destination should have the user and host encoded
452         as per scp specs.
453         
454         If source is a file object, a special mode will be used to
455         create the remote file with the same contents.
456         
457         If dest is a file object, the remote file (source) will be
458         read and written into dest.
459         
460         In these modes, recursive cannot be True.
461         
462         Source can be a list of files to copy to a single destination,
463         in which case it is advised that the destination be a folder.
464         """
465         
466         if TRACE:
467             print "scp", source, dest
468         
469         if isinstance(source, file) or isinstance(dest, file) \
470                 or hasattr(source, 'read')  or hasattr(dest, 'write'):
471             assert not recursive
472             
473             # Parse source/destination as <user>@<server>:<path>
474             if isinstance(dest, basestring) and ':' in dest:
475                 remspec, path = dest.split(':',1)
476             elif isinstance(source, basestring) and ':' in source:
477                 remspec, path = source.split(':',1)
478             else:
479                 raise ValueError, "Both endpoints cannot be local"
480             user,host = remspec.rsplit('@',1)
481             tmp_known_hosts = None
482             
483             args = ['ssh', '-l', user, '-C',
484                     # Don't bother with localhost. Makes test easier
485                     '-o', 'NoHostAuthenticationForLocalhost=yes',
486                     host ]
487             if port:
488                 args.append('-P%d' % port)
489             if ident_key:
490                 args.extend(('-i', ident_key))
491             if server_key:
492                 # Create a temporary server key file
493                 tmp_known_hosts = _make_server_key_args(
494                     server_key, host, port, args)
495             
496             if isinstance(source, file) or hasattr(source, 'read'):
497                 args.append('cat > %s' % (shell_escape(path),))
498             elif isinstance(dest, file) or hasattr(dest, 'write'):
499                 args.append('cat %s' % (shell_escape(path),))
500             else:
501                 raise AssertionError, "Unreachable code reached! :-Q"
502             
503             # connects to the remote host and starts a remote connection
504             if isinstance(source, file):
505                 proc = subprocess.Popen(args, 
506                         stdout = open('/dev/null','w'),
507                         stderr = subprocess.PIPE,
508                         stdin = source)
509                 err = proc.stderr.read()
510                 proc._known_hosts = tmp_known_hosts
511                 proc.wait()
512                 return ((None,err), proc)
513             elif isinstance(dest, file):
514                 proc = subprocess.Popen(args, 
515                         stdout = open('/dev/null','w'),
516                         stderr = subprocess.PIPE,
517                         stdin = source)
518                 err = proc.stderr.read()
519                 proc._known_hosts = tmp_known_hosts
520                 proc.wait()
521                 return ((None,err), proc)
522             elif hasattr(source, 'read'):
523                 # file-like (but not file) source
524                 proc = subprocess.Popen(args, 
525                         stdout = open('/dev/null','w'),
526                         stderr = subprocess.PIPE,
527                         stdin = subprocess.PIPE)
528                 
529                 buf = None
530                 err = []
531                 while True:
532                     if not buf:
533                         buf = source.read(4096)
534                     if not buf:
535                         #EOF
536                         break
537                     
538                     rdrdy, wrdy, broken = select.select(
539                         [proc.stderr],
540                         [proc.stdin],
541                         [proc.stderr,proc.stdin])
542                     
543                     if proc.stderr in rdrdy:
544                         # use os.read for fully unbuffered behavior
545                         err.append(os.read(proc.stderr.fileno(), 4096))
546                     
547                     if proc.stdin in wrdy:
548                         proc.stdin.write(buf)
549                         buf = None
550                     
551                     if broken:
552                         break
553                 proc.stdin.close()
554                 err.append(proc.stderr.read())
555                     
556                 proc._known_hosts = tmp_known_hosts
557                 proc.wait()
558                 return ((None,''.join(err)), proc)
559             elif hasattr(dest, 'write'):
560                 # file-like (but not file) dest
561                 proc = subprocess.Popen(args, 
562                         stdout = subprocess.PIPE,
563                         stderr = subprocess.PIPE,
564                         stdin = open('/dev/null','w'))
565                 
566                 buf = None
567                 err = []
568                 while True:
569                     rdrdy, wrdy, broken = select.select(
570                         [proc.stderr, proc.stdout],
571                         [],
572                         [proc.stderr, proc.stdout])
573                     
574                     if proc.stderr in rdrdy:
575                         # use os.read for fully unbuffered behavior
576                         err.append(os.read(proc.stderr.fileno(), 4096))
577                     
578                     if proc.stdout in rdrdy:
579                         # use os.read for fully unbuffered behavior
580                         buf = os.read(proc.stdout.fileno(), 4096)
581                         dest.write(buf)
582                         
583                         if not buf:
584                             #EOF
585                             break
586                     
587                     if broken:
588                         break
589                 err.append(proc.stderr.read())
590                     
591                 proc._known_hosts = tmp_known_hosts
592                 proc.wait()
593                 return ((None,''.join(err)), proc)
594             else:
595                 raise AssertionError, "Unreachable code reached! :-Q"
596         else:
597             # Parse destination as <user>@<server>:<path>
598             if isinstance(dest, basestring) and ':' in dest:
599                 remspec, path = dest.split(':',1)
600             elif isinstance(source, basestring) and ':' in source:
601                 remspec, path = source.split(':',1)
602             else:
603                 raise ValueError, "Both endpoints cannot be local"
604             user,host = remspec.rsplit('@',1)
605             
606             # plain scp
607             tmp_known_hosts = None
608             args = ['scp', '-q', '-p', '-C',
609                     # Don't bother with localhost. Makes test easier
610                     '-o', 'NoHostAuthenticationForLocalhost=yes' ]
611             if port:
612                 args.append('-P%d' % port)
613             if recursive:
614                 args.append('-r')
615             if ident_key:
616                 args.extend(('-i', ident_key))
617             if server_key:
618                 # Create a temporary server key file
619                 tmp_known_hosts = _make_server_key_args(
620                     server_key, host, port, args)
621             if isinstance(source,list):
622                 args.extend(source)
623             else:
624                 args.append(source)
625             args.append(dest)
626
627             # connects to the remote host and starts a remote connection
628             proc = subprocess.Popen(args, 
629                     stdout = subprocess.PIPE,
630                     stdin = subprocess.PIPE, 
631                     stderr = subprocess.PIPE)
632             proc._known_hosts = tmp_known_hosts
633             
634             comm = proc.communicate()
635             proc.wait()
636             return (comm, proc)
637  
638 def popen_ssh_subprocess(python_code, host, port, user, agent, 
639         python_path = None,
640         ident_key = None,
641         server_key = None,
642         tty = False,
643         environment_setup = "",
644         waitcommand = False):
645         cmd = ""
646         if python_path:
647             python_path.replace("'", r"'\''")
648             cmd = """PYTHONPATH="$PYTHONPATH":'%s' """ % python_path
649             cmd += " ; "
650         if environment_setup:
651             cmd += environment_setup
652             cmd += " ; "
653         # Uncomment for debug (to run everything under strace)
654         # We had to verify if strace works (cannot nest them)
655         #cmd += "if strace echo >/dev/null 2>&1; then CMD='strace -ff -tt -s 200 -o strace.out'; else CMD=''; fi\n"
656         #cmd += "$CMD "
657         #cmd += "strace -f -tt -s 200 -o strace$$.out "
658         cmd += "python -c '"
659         cmd += "import base64, os\n"
660         cmd += "cmd = \"\"\n"
661         cmd += "while True:\n"
662         cmd += " cmd += os.read(0, 1)\n" # one byte from stdin
663         cmd += " if cmd[-1] == \"\\n\": break\n"
664         cmd += "cmd = base64.b64decode(cmd)\n"
665         # Uncomment for debug
666         #cmd += "os.write(2, \"Executing python code: %s\\n\" % cmd)\n"
667         if not waitcommand:
668             cmd += "os.write(1, \"OK\\n\")\n" # send a sync message
669         cmd += "exec(cmd)\n"
670         if waitcommand:
671             cmd += "os.write(1, \"OK\\n\")\n" # send a sync message
672         cmd += "'"
673         
674         tmp_known_hosts = None
675         args = ['ssh',
676                 # Don't bother with localhost. Makes test easier
677                 '-o', 'NoHostAuthenticationForLocalhost=yes',
678                 '-l', user, host]
679         if agent:
680             args.append('-A')
681         if port:
682             args.append('-p%d' % port)
683         if ident_key:
684             args.extend(('-i', ident_key))
685         if tty:
686             args.append('-t')
687         if server_key:
688             # Create a temporary server key file
689             tmp_known_hosts = _make_server_key_args(
690                 server_key, host, port, args)
691         args.append(cmd)
692
693         # connects to the remote host and starts a remote rpyc connection
694         proc = subprocess.Popen(args, 
695                 stdout = subprocess.PIPE,
696                 stdin = subprocess.PIPE, 
697                 stderr = subprocess.PIPE)
698         proc._known_hosts = tmp_known_hosts
699         
700         # send the command to execute
701         os.write(proc.stdin.fileno(),
702                 base64.b64encode(python_code) + "\n")
703         msg = os.read(proc.stdout.fileno(), 3)
704         if msg != "OK\n":
705             raise RuntimeError, "Failed to start remote python interpreter: \nout:\n%s%s\nerr:\n%s" % (
706                 msg, proc.stdout.read(), proc.stderr.read())
707         return proc
708