Retry retriable operations when we get an EINTR
[nepi.git] / src / nepi / util / server.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import base64
5 import errno
6 import os
7 import os.path
8 import resource
9 import select
10 import socket
11 import sys
12 import subprocess
13 import threading
14 import time
15 import traceback
16 import signal
17 import re
18 import tempfile
19
20 CTRL_SOCK = "ctrl.sock"
21 STD_ERR = "stderr.log"
22 MAX_FD = 1024
23
24 STOP_MSG = "STOP"
25
26 ERROR_LEVEL = 0
27 DEBUG_LEVEL = 1
28 TRACE = os.environ.get("NEPI_TRACE", "false").lower() in ("true", "1", "on")
29
30 if hasattr(os, "devnull"):
31     DEV_NULL = os.devnull
32 else:
33     DEV_NULL = "/dev/null"
34
35 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
36
37 def shell_escape(s):
38     """ Escapes strings so that they are safe to use as command-line arguments """
39     if SHELL_SAFE.match(s):
40         # safe string - no escaping needed
41         return s
42     else:
43         # unsafe string - escape
44         def escp(c):
45             if (32 <= ord(c) < 127 or c in ('\r','\n','\t')) and c not in ("'",):
46                 return c
47             else:
48                 return "'$'\\x%02x''" % (ord(c),)
49         s = ''.join(map(escp,s))
50         return "'%s'" % (s,)
51
52 def eintr_retry(func):
53     import functools
54     @functools.wraps(func)
55     def rv(*p, **kw):
56         retry = kw.pop("_retry", False)
57         for i in xrange(0 if retry else 4):
58             try:
59                 return func(*p, **kw)
60             except select.error, args:
61                 if args[0] == errno.EINTR:
62                     continue
63                 else:
64                     raise 
65         else:
66             return func(*p, **kw)
67     return rv
68
69 class Server(object):
70     def __init__(self, root_dir = ".", log_level = ERROR_LEVEL):
71         self._root_dir = root_dir
72         self._stop = False
73         self._ctrl_sock = None
74         self._log_level = log_level
75
76     def run(self):
77         try:
78             if self.daemonize():
79                 self.post_daemonize()
80                 self.loop()
81                 self.cleanup()
82                 # ref: "os._exit(0)"
83                 # can not return normally after fork beacuse no exec was done.
84                 # This means that if we don't do a os._exit(0) here the code that 
85                 # follows the call to "Server.run()" in the "caller code" will be 
86                 # executed... but by now it has already been executed after the 
87                 # first process (the one that did the first fork) returned.
88                 os._exit(0)
89         except:
90             self.log_error()
91             self.cleanup()
92             os._exit(0)
93
94     def daemonize(self):
95         # pipes for process synchronization
96         (r, w) = os.pipe()
97         
98         # build root folder
99         root = os.path.normpath(self._root_dir)
100         if not os.path.exists(root):
101             os.makedirs(root, 0755)
102
103         pid1 = os.fork()
104         if pid1 > 0:
105             os.close(w)
106             while True:
107                 try:
108                     os.read(r, 1)
109                 except OSError, e: # pragma: no cover
110                     if e.errno == errno.EINTR:
111                         continue
112                     else:
113                         raise
114                 break
115             os.close(r)
116             # os.waitpid avoids leaving a <defunc> (zombie) process
117             st = os.waitpid(pid1, 0)[1]
118             if st:
119                 raise RuntimeError("Daemonization failed")
120             # return 0 to inform the caller method that this is not the 
121             # daemonized process
122             return 0
123         os.close(r)
124
125         # Decouple from parent environment.
126         os.chdir(self._root_dir)
127         os.umask(0)
128         os.setsid()
129
130         # fork 2
131         pid2 = os.fork()
132         if pid2 > 0:
133             # see ref: "os._exit(0)"
134             os._exit(0)
135
136         # close all open file descriptors.
137         max_fd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
138         if (max_fd == resource.RLIM_INFINITY):
139             max_fd = MAX_FD
140         for fd in range(3, max_fd):
141             if fd != w:
142                 try:
143                     os.close(fd)
144                 except OSError:
145                     pass
146
147         # Redirect standard file descriptors.
148         stdin = open(DEV_NULL, "r")
149         stderr = stdout = open(STD_ERR, "a", 0)
150         os.dup2(stdin.fileno(), sys.stdin.fileno())
151         # NOTE: sys.stdout.write will still be buffered, even if the file
152         # was opened with 0 buffer
153         os.dup2(stdout.fileno(), sys.stdout.fileno())
154         os.dup2(stderr.fileno(), sys.stderr.fileno())
155
156         # create control socket
157         self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
158         self._ctrl_sock.bind(CTRL_SOCK)
159         self._ctrl_sock.listen(0)
160
161         # let the parent process know that the daemonization is finished
162         os.write(w, "\n")
163         os.close(w)
164         return 1
165
166     def post_daemonize(self):
167         pass
168
169     def loop(self):
170         while not self._stop:
171             conn, addr = self._ctrl_sock.accept()
172             conn.settimeout(5)
173             while not self._stop:
174                 try:
175                     msg = self.recv_msg(conn)
176                 except socket.timeout, e:
177                     break
178                     
179                 if msg == STOP_MSG:
180                     self._stop = True
181                     reply = self.stop_action()
182                 else:
183                     reply = self.reply_action(msg)
184                 
185                 try:
186                     self.send_reply(conn, reply)
187                 except socket.error:
188                     self.log_error()
189                     self.log_error("NOTICE: Awaiting for reconnection")
190                     break
191             try:
192                 conn.close()
193             except:
194                 # Doesn't matter
195                 self.log_error()
196
197     def recv_msg(self, conn):
198         data = ""
199         while True:
200             try:
201                 chunk = conn.recv(1024)
202             except OSError, e:
203                 if e.errno != errno.EINTR:
204                     raise
205                 if chunk == '':
206                     continue
207             if chunk:
208                 data += chunk
209                 if chunk[-1] == "\n":
210                     break
211             else:
212                 # empty chunk = EOF
213                 break
214         decoded = base64.b64decode(data)
215         return decoded.rstrip()
216
217     def send_reply(self, conn, reply):
218         encoded = base64.b64encode(reply)
219         conn.send("%s\n" % encoded)
220        
221     def cleanup(self):
222         try:
223             self._ctrl_sock.close()
224             os.remove(CTRL_SOCK)
225         except:
226             self.log_error()
227
228     def stop_action(self):
229         return "Stopping server"
230
231     def reply_action(self, msg):
232         return "Reply to: %s" % msg
233
234     def log_error(self, text = None, context = ''):
235         if text == None:
236             text = traceback.format_exc()
237         date = time.strftime("%Y-%m-%d %H:%M:%S")
238         if context:
239             context = " (%s)" % (context,)
240         sys.stderr.write("ERROR%s: %s\n%s\n" % (context, date, text))
241         return text
242
243     def log_debug(self, text):
244         if self._log_level == DEBUG_LEVEL:
245             date = time.strftime("%Y-%m-%d %H:%M:%S")
246             sys.stderr.write("DEBUG: %s\n%s\n" % (date, text))
247
248 class Forwarder(object):
249     def __init__(self, root_dir = "."):
250         self._ctrl_sock = None
251         self._root_dir = root_dir
252         self._stop = False
253
254     def forward(self):
255         self.connect()
256         print >>sys.stderr, "READY."
257         while not self._stop:
258             data = self.read_data()
259             self.send_to_server(data)
260             data = self.recv_from_server()
261             self.write_data(data)
262         self.disconnect()
263
264     def read_data(self):
265         return sys.stdin.readline()
266
267     def write_data(self, data):
268         sys.stdout.write(data)
269         # sys.stdout.write is buffered, this is why we need to do a flush()
270         sys.stdout.flush()
271
272     def send_to_server(self, data):
273         try:
274             self._ctrl_sock.send(data)
275         except IOError, e:
276             if e.errno == errno.EPIPE:
277                 self.connect()
278                 self._ctrl_sock.send(data)
279             else:
280                 raise e
281         encoded = data.rstrip() 
282         msg = base64.b64decode(encoded)
283         if msg == STOP_MSG:
284             self._stop = True
285
286     def recv_from_server(self):
287         data = ""
288         while True:
289             try:
290                 chunk = self._ctrl_sock.recv(1024)
291             except OSError, e:
292                 if e.errno != errno.EINTR:
293                     raise
294                 if chunk == '':
295                     continue
296             data += chunk
297             if chunk[-1] == "\n":
298                 break
299         return data
300  
301     def connect(self):
302         self.disconnect()
303         self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
304         sock_addr = os.path.join(self._root_dir, CTRL_SOCK)
305         self._ctrl_sock.connect(sock_addr)
306
307     def disconnect(self):
308         try:
309             self._ctrl_sock.close()
310         except:
311             pass
312
313 class Client(object):
314     def __init__(self, root_dir = ".", host = None, port = None, user = None, 
315             agent = None, environment_setup = ""):
316         self.root_dir = root_dir
317         self.addr = (host, port)
318         self.user = user
319         self.agent = agent
320         self.environment_setup = environment_setup
321         self._stopped = False
322         self.connect()
323     
324     def __del__(self):
325         if self._process.poll() is None:
326             os.kill(self._process.pid, signal.SIGTERM)
327         self._process.wait()
328         
329     def connect(self):
330         root_dir = self.root_dir
331         (host, port) = self.addr
332         user = self.user
333         agent = self.agent
334         
335         python_code = "from nepi.util import server;c=server.Forwarder(%r);\
336                 c.forward()" % (root_dir,)
337         if host != None:
338             self._process = popen_ssh_subprocess(python_code, host, port, 
339                     user, agent,
340                     environment_setup = self.environment_setup)
341             # popen_ssh_subprocess already waits for readiness
342             if self._process.poll():
343                 err = proc.stderr.read()
344                 raise RuntimeError("Client could not be reached: %s" % \
345                         err)
346         else:
347             self._process = subprocess.Popen(
348                     ["python", "-c", python_code],
349                     stdin = subprocess.PIPE, 
350                     stdout = subprocess.PIPE,
351                     stderr = subprocess.PIPE
352                 )
353                 
354         # Wait for the forwarder to be ready, otherwise nobody
355         # will be able to connect to it
356         helo = self._process.stderr.readline()
357         if helo != 'READY.\n':
358             raise AssertionError, "Expected 'Ready.', got %r: %s" % (helo,
359                     helo + self._process.stderr.read())
360         
361     def send_msg(self, msg):
362         encoded = base64.b64encode(msg)
363         data = "%s\n" % encoded
364         
365         try:
366             self._process.stdin.write(data)
367         except (IOError, ValueError):
368             # dead process, poll it to un-zombify
369             self._process.poll()
370             
371             # try again after reconnect
372             # If it fails again, though, give up
373             self.connect()
374             self._process.stdin.write(data)
375
376     def send_stop(self):
377         self.send_msg(STOP_MSG)
378         self._stopped = True
379
380     def read_reply(self):
381         data = self._process.stdout.readline()
382         encoded = data.rstrip() 
383         return base64.b64decode(encoded)
384
385 def _make_server_key_args(server_key, host, port, args):
386     """ 
387     Returns a reference to the created temporary file, and adds the
388     corresponding arguments to the given argument list.
389     
390     Make sure to hold onto it until the process is done with the file
391     """
392     if port is not None:
393         host = '%s:%s' % (host,port)
394     # Create a temporary server key file
395     tmp_known_hosts = tempfile.NamedTemporaryFile()
396     
397     # Add the intended host key
398     tmp_known_hosts.write('%s,%s %s\n' % (host, socket.gethostbyname(host), server_key))
399     
400     # If we're not in strict mode, add user-configured keys
401     if os.environ.get('NEPI_STRICT_AUTH_MODE',"").lower() not in ('1','true','on'):
402         user_hosts_path = '%s/.ssh/known_hosts' % (os.environ.get('HOME',""),)
403         if os.access(user_hosts_path, os.R_OK):
404             f = open(user_hosts_path, "r")
405             tmp_known_hosts.write(f.read())
406             f.close()
407         
408     tmp_known_hosts.flush()
409     
410     args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
411     return tmp_known_hosts
412
413 def popen_ssh_command(command, host, port, user, agent, 
414             stdin="", 
415             ident_key = None,
416             server_key = None,
417             tty = False):
418         """
419         Executes a remote commands, returns ((stdout,stderr),process)
420         """
421         if TRACE:
422             print "ssh", host, command
423         
424         tmp_known_hosts = None
425         args = ['ssh',
426                 # Don't bother with localhost. Makes test easier
427                 '-o', 'NoHostAuthenticationForLocalhost=yes',
428                 '-l', user, host]
429         if agent:
430             args.append('-A')
431         if port:
432             args.append('-p%d' % port)
433         if ident_key:
434             args.extend(('-i', ident_key))
435         if tty:
436             args.append('-t')
437         if server_key:
438             # Create a temporary server key file
439             tmp_known_hosts = _make_server_key_args(
440                 server_key, host, port, args)
441         args.append(command)
442
443         # connects to the remote host and starts a remote connection
444         proc = subprocess.Popen(args, 
445                 stdout = subprocess.PIPE,
446                 stdin = subprocess.PIPE, 
447                 stderr = subprocess.PIPE)
448         
449         # attach tempfile object to the process, to make sure the file stays
450         # alive until the process is finished with it
451         proc._known_hosts = tmp_known_hosts
452         
453         out, err = proc.communicate(stdin)
454         if TRACE:
455             print " -> ", out, err
456
457         return ((out, err), proc)
458  
459 def popen_scp(source, dest, 
460             port = None, 
461             agent = None, 
462             recursive = False,
463             ident_key = None,
464             server_key = None):
465         """
466         Copies from/to remote sites.
467         
468         Source and destination should have the user and host encoded
469         as per scp specs.
470         
471         If source is a file object, a special mode will be used to
472         create the remote file with the same contents.
473         
474         If dest is a file object, the remote file (source) will be
475         read and written into dest.
476         
477         In these modes, recursive cannot be True.
478         
479         Source can be a list of files to copy to a single destination,
480         in which case it is advised that the destination be a folder.
481         """
482         
483         if TRACE:
484             print "scp", source, dest
485         
486         if isinstance(source, file) or isinstance(dest, file) \
487                 or hasattr(source, 'read')  or hasattr(dest, 'write'):
488             assert not recursive
489             
490             # Parse source/destination as <user>@<server>:<path>
491             if isinstance(dest, basestring) and ':' in dest:
492                 remspec, path = dest.split(':',1)
493             elif isinstance(source, basestring) and ':' in source:
494                 remspec, path = source.split(':',1)
495             else:
496                 raise ValueError, "Both endpoints cannot be local"
497             user,host = remspec.rsplit('@',1)
498             tmp_known_hosts = None
499             
500             args = ['ssh', '-l', user, '-C',
501                     # Don't bother with localhost. Makes test easier
502                     '-o', 'NoHostAuthenticationForLocalhost=yes',
503                     host ]
504             if port:
505                 args.append('-P%d' % port)
506             if ident_key:
507                 args.extend(('-i', ident_key))
508             if server_key:
509                 # Create a temporary server key file
510                 tmp_known_hosts = _make_server_key_args(
511                     server_key, host, port, args)
512             
513             if isinstance(source, file) or hasattr(source, 'read'):
514                 args.append('cat > %s' % (shell_escape(path),))
515             elif isinstance(dest, file) or hasattr(dest, 'write'):
516                 args.append('cat %s' % (shell_escape(path),))
517             else:
518                 raise AssertionError, "Unreachable code reached! :-Q"
519             
520             # connects to the remote host and starts a remote connection
521             if isinstance(source, file):
522                 proc = subprocess.Popen(args, 
523                         stdout = open('/dev/null','w'),
524                         stderr = subprocess.PIPE,
525                         stdin = source)
526                 err = proc.stderr.read()
527                 proc._known_hosts = tmp_known_hosts
528                 proc.wait()
529                 return ((None,err), proc)
530             elif isinstance(dest, file):
531                 proc = subprocess.Popen(args, 
532                         stdout = open('/dev/null','w'),
533                         stderr = subprocess.PIPE,
534                         stdin = source)
535                 err = proc.stderr.read()
536                 proc._known_hosts = tmp_known_hosts
537                 proc.wait()
538                 return ((None,err), proc)
539             elif hasattr(source, 'read'):
540                 # file-like (but not file) source
541                 proc = subprocess.Popen(args, 
542                         stdout = open('/dev/null','w'),
543                         stderr = subprocess.PIPE,
544                         stdin = subprocess.PIPE)
545                 
546                 buf = None
547                 err = []
548                 while True:
549                     if not buf:
550                         buf = source.read(4096)
551                     if not buf:
552                         #EOF
553                         break
554                     
555                     rdrdy, wrdy, broken = select.select(
556                         [proc.stderr],
557                         [proc.stdin],
558                         [proc.stderr,proc.stdin])
559                     
560                     if proc.stderr in rdrdy:
561                         # use os.read for fully unbuffered behavior
562                         err.append(os.read(proc.stderr.fileno(), 4096))
563                     
564                     if proc.stdin in wrdy:
565                         proc.stdin.write(buf)
566                         buf = None
567                     
568                     if broken:
569                         break
570                 proc.stdin.close()
571                 err.append(proc.stderr.read())
572                     
573                 proc._known_hosts = tmp_known_hosts
574                 proc.wait()
575                 return ((None,''.join(err)), proc)
576             elif hasattr(dest, 'write'):
577                 # file-like (but not file) dest
578                 proc = subprocess.Popen(args, 
579                         stdout = subprocess.PIPE,
580                         stderr = subprocess.PIPE,
581                         stdin = open('/dev/null','w'))
582                 
583                 buf = None
584                 err = []
585                 while True:
586                     rdrdy, wrdy, broken = select.select(
587                         [proc.stderr, proc.stdout],
588                         [],
589                         [proc.stderr, proc.stdout])
590                     
591                     if proc.stderr in rdrdy:
592                         # use os.read for fully unbuffered behavior
593                         err.append(os.read(proc.stderr.fileno(), 4096))
594                     
595                     if proc.stdout in rdrdy:
596                         # use os.read for fully unbuffered behavior
597                         buf = os.read(proc.stdout.fileno(), 4096)
598                         dest.write(buf)
599                         
600                         if not buf:
601                             #EOF
602                             break
603                     
604                     if broken:
605                         break
606                 err.append(proc.stderr.read())
607                     
608                 proc._known_hosts = tmp_known_hosts
609                 proc.wait()
610                 return ((None,''.join(err)), proc)
611             else:
612                 raise AssertionError, "Unreachable code reached! :-Q"
613         else:
614             # Parse destination as <user>@<server>:<path>
615             if isinstance(dest, basestring) and ':' in dest:
616                 remspec, path = dest.split(':',1)
617             elif isinstance(source, basestring) and ':' in source:
618                 remspec, path = source.split(':',1)
619             else:
620                 raise ValueError, "Both endpoints cannot be local"
621             user,host = remspec.rsplit('@',1)
622             
623             # plain scp
624             tmp_known_hosts = None
625             args = ['scp', '-q', '-p', '-C',
626                     # Don't bother with localhost. Makes test easier
627                     '-o', 'NoHostAuthenticationForLocalhost=yes' ]
628             if port:
629                 args.append('-P%d' % port)
630             if recursive:
631                 args.append('-r')
632             if ident_key:
633                 args.extend(('-i', ident_key))
634             if server_key:
635                 # Create a temporary server key file
636                 tmp_known_hosts = _make_server_key_args(
637                     server_key, host, port, args)
638             if isinstance(source,list):
639                 args.extend(source)
640             else:
641                 args.append(source)
642             args.append(dest)
643
644             # connects to the remote host and starts a remote connection
645             proc = subprocess.Popen(args, 
646                     stdout = subprocess.PIPE,
647                     stdin = subprocess.PIPE, 
648                     stderr = subprocess.PIPE)
649             proc._known_hosts = tmp_known_hosts
650             
651             comm = proc.communicate()
652             proc.wait()
653             return (comm, proc)
654  
655 def popen_ssh_subprocess(python_code, host, port, user, agent, 
656         python_path = None,
657         ident_key = None,
658         server_key = None,
659         tty = False,
660         environment_setup = "",
661         waitcommand = False):
662         cmd = ""
663         if python_path:
664             python_path.replace("'", r"'\''")
665             cmd = """PYTHONPATH="$PYTHONPATH":'%s' """ % python_path
666             cmd += " ; "
667         if environment_setup:
668             cmd += environment_setup
669             cmd += " ; "
670         # Uncomment for debug (to run everything under strace)
671         # We had to verify if strace works (cannot nest them)
672         #cmd += "if strace echo >/dev/null 2>&1; then CMD='strace -ff -tt -s 200 -o strace.out'; else CMD=''; fi\n"
673         #cmd += "$CMD "
674         #cmd += "strace -f -tt -s 200 -o strace$$.out "
675         cmd += "python -c '"
676         cmd += "import base64, os\n"
677         cmd += "cmd = \"\"\n"
678         cmd += "while True:\n"
679         cmd += " cmd += os.read(0, 1)\n" # one byte from stdin
680         cmd += " if cmd[-1] == \"\\n\": break\n"
681         cmd += "cmd = base64.b64decode(cmd)\n"
682         # Uncomment for debug
683         #cmd += "os.write(2, \"Executing python code: %s\\n\" % cmd)\n"
684         if not waitcommand:
685             cmd += "os.write(1, \"OK\\n\")\n" # send a sync message
686         cmd += "exec(cmd)\n"
687         if waitcommand:
688             cmd += "os.write(1, \"OK\\n\")\n" # send a sync message
689         cmd += "'"
690         
691         tmp_known_hosts = None
692         args = ['ssh',
693                 # Don't bother with localhost. Makes test easier
694                 '-o', 'NoHostAuthenticationForLocalhost=yes',
695                 '-l', user, host]
696         if agent:
697             args.append('-A')
698         if port:
699             args.append('-p%d' % port)
700         if ident_key:
701             args.extend(('-i', ident_key))
702         if tty:
703             args.append('-t')
704         if server_key:
705             # Create a temporary server key file
706             tmp_known_hosts = _make_server_key_args(
707                 server_key, host, port, args)
708         args.append(cmd)
709
710         # connects to the remote host and starts a remote rpyc connection
711         proc = subprocess.Popen(args, 
712                 stdout = subprocess.PIPE,
713                 stdin = subprocess.PIPE, 
714                 stderr = subprocess.PIPE)
715         proc._known_hosts = tmp_known_hosts
716         
717         # send the command to execute
718         os.write(proc.stdin.fileno(),
719                 base64.b64encode(python_code) + "\n")
720         msg = os.read(proc.stdout.fileno(), 3)
721         if msg != "OK\n":
722             raise RuntimeError, "Failed to start remote python interpreter: \nout:\n%s%s\nerr:\n%s" % (
723                 msg, proc.stdout.read(), proc.stderr.read())
724         return proc
725