popen_scp fixes: StringIO-based (and similar) transfers were utterly broken
[nepi.git] / src / nepi / util / server.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import base64
5 import errno
6 import os
7 import resource
8 import select
9 import socket
10 import sys
11 import subprocess
12 import threading
13 import time
14 import traceback
15 import signal
16 import re
17 import tempfile
18
19 CTRL_SOCK = "ctrl.sock"
20 STD_ERR = "stderr.log"
21 MAX_FD = 1024
22
23 STOP_MSG = "STOP"
24
25 ERROR_LEVEL = 0
26 DEBUG_LEVEL = 1
27
28 if hasattr(os, "devnull"):
29     DEV_NULL = os.devnull
30 else:
31     DEV_NULL = "/dev/null"
32
33
34
35 SHELL_SAFE = re.compile('[-a-zA-Z0-9_=+:.,/]*')
36
37 def shell_escape(s):
38     """ Escapes strings so that they are safe to use as command-line arguments """
39     if SHELL_SAFE.match(s):
40         # safe string - no escaping needed
41         return s
42     else:
43         # unsafe string - escape
44         s = s.replace("'","\\'")
45         return "'%s'" % (s,)
46
47 class Server(object):
48     def __init__(self, root_dir = ".", log_level = ERROR_LEVEL):
49         self._root_dir = root_dir
50         self._stop = False
51         self._ctrl_sock = None
52         self._log_level = log_level
53
54     def run(self):
55         try:
56             if self.daemonize():
57                 self.post_daemonize()
58                 self.loop()
59                 self.cleanup()
60                 # ref: "os._exit(0)"
61                 # can not return normally after fork beacuse no exec was done.
62                 # This means that if we don't do a os._exit(0) here the code that 
63                 # follows the call to "Server.run()" in the "caller code" will be 
64                 # executed... but by now it has already been executed after the 
65                 # first process (the one that did the first fork) returned.
66                 os._exit(0)
67         except:
68             self.log_error()
69             self.cleanup()
70             os._exit(0)
71
72     def daemonize(self):
73         # pipes for process synchronization
74         (r, w) = os.pipe()
75
76         pid1 = os.fork()
77         if pid1 > 0:
78             os.close(w)
79             os.read(r, 1)
80             os.close(r)
81             # os.waitpid avoids leaving a <defunc> (zombie) process
82             st = os.waitpid(pid1, 0)[1]
83             if st:
84                 raise RuntimeError("Daemonization failed")
85             # return 0 to inform the caller method that this is not the 
86             # daemonized process
87             return 0
88         os.close(r)
89
90         # Decouple from parent environment.
91         os.chdir(self._root_dir)
92         os.umask(0)
93         os.setsid()
94
95         # fork 2
96         pid2 = os.fork()
97         if pid2 > 0:
98             # see ref: "os._exit(0)"
99             os._exit(0)
100
101         # close all open file descriptors.
102         max_fd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
103         if (max_fd == resource.RLIM_INFINITY):
104             max_fd = MAX_FD
105         for fd in range(3, max_fd):
106             if fd != w:
107                 try:
108                     os.close(fd)
109                 except OSError:
110                     pass
111
112         # Redirect standard file descriptors.
113         stdin = open(DEV_NULL, "r")
114         stderr = stdout = open(STD_ERR, "a", 0)
115         os.dup2(stdin.fileno(), sys.stdin.fileno())
116         # NOTE: sys.stdout.write will still be buffered, even if the file
117         # was opened with 0 buffer
118         os.dup2(stdout.fileno(), sys.stdout.fileno())
119         os.dup2(stderr.fileno(), sys.stderr.fileno())
120
121         # create control socket
122         self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
123         self._ctrl_sock.bind(CTRL_SOCK)
124         self._ctrl_sock.listen(0)
125
126         # let the parent process know that the daemonization is finished
127         os.write(w, "\n")
128         os.close(w)
129         return 1
130
131     def post_daemonize(self):
132         pass
133
134     def loop(self):
135         while not self._stop:
136             conn, addr = self._ctrl_sock.accept()
137             conn.settimeout(5)
138             while not self._stop:
139                 try:
140                     msg = self.recv_msg(conn)
141                 except socket.timeout, e:
142                     break
143                     
144                 if msg == STOP_MSG:
145                     self._stop = True
146                     reply = self.stop_action()
147                 else:
148                     reply = self.reply_action(msg)
149                 
150                 try:
151                     self.send_reply(conn, reply)
152                 except socket.error:
153                     self.log_error()
154                     self.log_error("NOTICE: Awaiting for reconnection")
155                     break
156             try:
157                 conn.close()
158             except:
159                 # Doesn't matter
160                 self.log_error()
161
162     def recv_msg(self, conn):
163         data = ""
164         while True:
165             try:
166                 chunk = conn.recv(1024)
167             except OSError, e:
168                 if e.errno != errno.EINTR:
169                     raise
170                 if chunk == '':
171                     continue
172             if chunk:
173                 data += chunk
174                 if chunk[-1] == "\n":
175                     break
176             else:
177                 # empty chunk = EOF
178                 break
179         decoded = base64.b64decode(data)
180         return decoded.rstrip()
181
182     def send_reply(self, conn, reply):
183         encoded = base64.b64encode(reply)
184         conn.send("%s\n" % encoded)
185        
186     def cleanup(self):
187         try:
188             self._ctrl_sock.close()
189             os.remove(CTRL_SOCK)
190         except:
191             self.log_error()
192
193     def stop_action(self):
194         return "Stopping server"
195
196     def reply_action(self, msg):
197         return "Reply to: %s" % msg
198
199     def log_error(self, text = None, context = ''):
200         if text == None:
201             text = traceback.format_exc()
202         date = time.strftime("%Y-%m-%d %H:%M:%S")
203         if context:
204             context = " (%s)" % (context,)
205         sys.stderr.write("ERROR%s: %s\n%s\n" % (context, date, text))
206         return text
207
208     def log_debug(self, text):
209         if self._log_level == DEBUG_LEVEL:
210             date = time.strftime("%Y-%m-%d %H:%M:%S")
211             sys.stderr.write("DEBUG: %s\n%s\n" % (date, text))
212
213 class Forwarder(object):
214     def __init__(self, root_dir = "."):
215         self._ctrl_sock = None
216         self._root_dir = root_dir
217         self._stop = False
218
219     def forward(self):
220         self.connect()
221         print >>sys.stderr, "READY."
222         while not self._stop:
223             data = self.read_data()
224             self.send_to_server(data)
225             data = self.recv_from_server()
226             self.write_data(data)
227         self.disconnect()
228
229     def read_data(self):
230         return sys.stdin.readline()
231
232     def write_data(self, data):
233         sys.stdout.write(data)
234         # sys.stdout.write is buffered, this is why we need to do a flush()
235         sys.stdout.flush()
236
237     def send_to_server(self, data):
238         try:
239             self._ctrl_sock.send(data)
240         except IOError, e:
241             if e.errno == errno.EPIPE:
242                 self.connect()
243                 self._ctrl_sock.send(data)
244             else:
245                 raise e
246         encoded = data.rstrip() 
247         msg = base64.b64decode(encoded)
248         if msg == STOP_MSG:
249             self._stop = True
250
251     def recv_from_server(self):
252         data = ""
253         while True:
254             try:
255                 chunk = self._ctrl_sock.recv(1024)
256             except OSError, e:
257                 if e.errno != errno.EINTR:
258                     raise
259                 if chunk == '':
260                     continue
261             data += chunk
262             if chunk[-1] == "\n":
263                 break
264         return data
265  
266     def connect(self):
267         self.disconnect()
268         self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
269         sock_addr = os.path.join(self._root_dir, CTRL_SOCK)
270         self._ctrl_sock.connect(sock_addr)
271
272     def disconnect(self):
273         try:
274             self._ctrl_sock.close()
275         except:
276             pass
277
278 class Client(object):
279     def __init__(self, root_dir = ".", host = None, port = None, user = None, 
280             agent = None):
281         self.root_dir = root_dir
282         self.addr = (host, port)
283         self.user = user
284         self.agent = agent
285         self._stopped = False
286         self.connect()
287     
288     def __del__(self):
289         if self._process.poll() is None:
290             os.kill(self._process.pid, signal.SIGTERM)
291         self._process.wait()
292         
293     def connect(self):
294         root_dir = self.root_dir
295         (host, port) = self.addr
296         user = self.user
297         agent = self.agent
298         
299         python_code = "from nepi.util import server;c=server.Forwarder(%r);\
300                 c.forward()" % (root_dir,)
301         if host != None:
302             self._process = popen_ssh_subprocess(python_code, host, port, 
303                     user, agent)
304             # popen_ssh_subprocess already waits for readiness
305         else:
306             self._process = subprocess.Popen(
307                     ["python", "-c", python_code],
308                     stdin = subprocess.PIPE, 
309                     stdout = subprocess.PIPE,
310                     stderr = subprocess.PIPE
311                 )
312                 
313         # Wait for the forwarder to be ready, otherwise nobody
314         # will be able to connect to it
315         helo = self._process.stderr.readline()
316         if helo != 'READY.\n':
317             raise AssertionError, "Expected 'Ready.', got %r: %s" % (helo,
318                     helo + self._process.stderr.read())
319         
320     def send_msg(self, msg):
321         encoded = base64.b64encode(msg)
322         data = "%s\n" % encoded
323         
324         try:
325             self._process.stdin.write(data)
326         except (IOError, ValueError):
327             # dead process, poll it to un-zombify
328             self._process.poll()
329             
330             # try again after reconnect
331             # If it fails again, though, give up
332             self.connect()
333             self._process.stdin.write(data)
334
335     def send_stop(self):
336         self.send_msg(STOP_MSG)
337         self._stopped = True
338
339     def read_reply(self):
340         data = self._process.stdout.readline()
341         encoded = data.rstrip() 
342         return base64.b64decode(encoded)
343
344 def _make_server_key_args(server_key, host, port, args):
345     """ 
346     Returns a reference to the created temporary file, and adds the
347     corresponding arguments to the given argument list.
348     
349     Make sure to hold onto it until the process is done with the file
350     """
351     if port is not None:
352         host = '%s:%s' % (host,port)
353     # Create a temporary server key file
354     tmp_known_hosts = tempfile.NamedTemporaryFile()
355     tmp_known_hosts.write('%s,%s %s\n' % (host, socket.gethostbyname(host), server_key))
356     tmp_known_hosts.flush()
357     args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
358     return tmp_known_hosts
359
360 def popen_ssh_command(command, host, port, user, agent, 
361             stdin="", 
362             ident_key = None,
363             server_key = None,
364             tty = False):
365         """
366         Executes a remote commands, returns ((stdout,stderr),process)
367         """
368         tmp_known_hosts = None
369         args = ['ssh',
370                 # Don't bother with localhost. Makes test easier
371                 '-o', 'NoHostAuthenticationForLocalhost=yes',
372                 '-l', user, host]
373         if agent:
374             args.append('-A')
375         if port:
376             args.append('-p%d' % port)
377         if ident_key:
378             args.extend(('-i', ident_key))
379         if tty:
380             args.append('-t')
381         if server_key:
382             # Create a temporary server key file
383             tmp_known_hosts = _make_server_key_args(
384                 server_key, host, port, args)
385         args.append(command)
386
387         # connects to the remote host and starts a remote connection
388         proc = subprocess.Popen(args, 
389                 stdout = subprocess.PIPE,
390                 stdin = subprocess.PIPE, 
391                 stderr = subprocess.PIPE)
392         
393         # attach tempfile object to the process, to make sure the file stays
394         # alive until the process is finished with it
395         proc._known_hosts = tmp_known_hosts
396         
397         return (proc.communicate(stdin), proc)
398  
399 def popen_scp(source, dest, 
400             port = None, 
401             agent = None, 
402             recursive = False,
403             ident_key = None,
404             server_key = None):
405         """
406         Copies from/to remote sites.
407         
408         Source and destination should have the user and host encoded
409         as per scp specs.
410         
411         If source is a file object, a special mode will be used to
412         create the remote file with the same contents.
413         
414         If dest is a file object, the remote file (source) will be
415         read and written into dest.
416         
417         In these modes, recursive cannot be True.
418         """
419         
420         if isinstance(source, file) or isinstance(dest, file) \
421                 or hasattr(source, 'read')  or hasattr(dest, 'write'):
422             assert not recursive
423             
424             # Parse source/destination as <user>@<server>:<path>
425             if isinstance(dest, basestring) and ':' in dest:
426                 remspec, path = dest.split(':',1)
427             elif isinstance(source, basestring) and ':' in source:
428                 remspec, path = source.split(':',1)
429             else:
430                 raise ValueError, "Both endpoints cannot be local"
431             user,host = remspec.rsplit('@',1)
432             tmp_known_hosts = None
433             
434             args = ['ssh', '-l', user, '-C',
435                     # Don't bother with localhost. Makes test easier
436                     '-o', 'NoHostAuthenticationForLocalhost=yes',
437                     host ]
438             if port:
439                 args.append('-P%d' % port)
440             if ident_key:
441                 args.extend(('-i', ident_key))
442             if server_key:
443                 # Create a temporary server key file
444                 tmp_known_hosts = _make_server_key_args(
445                     server_key, host, port, args)
446             
447             if isinstance(source, file) or hasattr(source, 'read'):
448                 args.append('cat > %s' % (shell_escape(path),))
449             elif isinstance(dest, file) or hasattr(dest, 'write'):
450                 args.append('cat %s' % (shell_escape(path),))
451             else:
452                 raise AssertionError, "Unreachable code reached! :-Q"
453             
454             # connects to the remote host and starts a remote connection
455             if isinstance(source, file):
456                 proc = subprocess.Popen(args, 
457                         stdout = open('/dev/null','w'),
458                         stderr = subprocess.PIPE,
459                         stdin = source)
460                 err = proc.stderr.read()
461                 proc._known_hosts = tmp_known_hosts
462                 proc.wait()
463                 return ((None,err), proc)
464             elif isinstance(dest, file):
465                 proc = subprocess.Popen(args, 
466                         stdout = open('/dev/null','w'),
467                         stderr = subprocess.PIPE,
468                         stdin = source)
469                 err = proc.stderr.read()
470                 proc._known_hosts = tmp_known_hosts
471                 proc.wait()
472                 return ((None,err), proc)
473             elif hasattr(source, 'read'):
474                 # file-like (but not file) source
475                 proc = subprocess.Popen(args, 
476                         stdout = open('/dev/null','w'),
477                         stderr = subprocess.PIPE,
478                         stdin = subprocess.PIPE)
479                 
480                 buf = None
481                 err = []
482                 while True:
483                     if not buf:
484                         buf = source.read(4096)
485                     if not buf:
486                         #EOF
487                         break
488                     
489                     rdrdy, wrdy, broken = select.select(
490                         [proc.stderr],
491                         [proc.stdin],
492                         [proc.stderr,proc.stdin])
493                     
494                     if proc.stderr in rdrdy:
495                         # use os.read for fully unbuffered behavior
496                         err.append(os.read(proc.stderr.fileno(), 4096))
497                     
498                     if proc.stdin in wrdy:
499                         proc.stdin.write(buf)
500                         buf = None
501                     
502                     if broken:
503                         break
504                 proc.stdin.close()
505                 err.append(proc.stderr.read())
506                     
507                 proc._known_hosts = tmp_known_hosts
508                 proc.wait()
509                 return ((None,''.join(err)), proc)
510             elif hasattr(dest, 'write'):
511                 # file-like (but not file) dest
512                 proc = subprocess.Popen(args, 
513                         stdout = subprocess.PIPE,
514                         stderr = subprocess.PIPE,
515                         stdin = open('/dev/null','w'))
516                 
517                 buf = None
518                 err = []
519                 while True:
520                     rdrdy, wrdy, broken = select.select(
521                         [proc.stderr, proc.stdout],
522                         [],
523                         [proc.stderr, proc.stdout])
524                     
525                     if proc.stderr in rdrdy:
526                         # use os.read for fully unbuffered behavior
527                         err.append(os.read(proc.stderr.fileno(), 4096))
528                     
529                     if proc.stdout in rdrdy:
530                         # use os.read for fully unbuffered behavior
531                         buf = os.read(proc.stdout.fileno(), 4096)
532                         dest.write(buf)
533                         
534                         if not buf:
535                             #EOF
536                             break
537                     
538                     if broken:
539                         break
540                 err.append(proc.stderr.read())
541                     
542                 proc._known_hosts = tmp_known_hosts
543                 proc.wait()
544                 return ((None,''.join(err)), proc)
545             else:
546                 raise AssertionError, "Unreachable code reached! :-Q"
547         else:
548             # Parse destination as <user>@<server>:<path>
549             if isinstance(dest, basestring) and ':' in dest:
550                 remspec, path = dest.split(':',1)
551             elif isinstance(source, basestring) and ':' in source:
552                 remspec, path = source.split(':',1)
553             else:
554                 raise ValueError, "Both endpoints cannot be local"
555             user,host = remspec.rsplit('@',1)
556             
557             # plain scp
558             tmp_known_hosts = None
559             args = ['scp', '-q', '-p', '-C',
560                     # Don't bother with localhost. Makes test easier
561                     '-o', 'NoHostAuthenticationForLocalhost=yes' ]
562             if port:
563                 args.append('-P%d' % port)
564             if recursive:
565                 args.append('-r')
566             if ident_key:
567                 args.extend(('-i', ident_key))
568             if server_key:
569                 # Create a temporary server key file
570                 tmp_known_hosts = _make_server_key_args(
571                     server_key, host, port, args)
572             args.append(source)
573             args.append(dest)
574
575             # connects to the remote host and starts a remote connection
576             proc = subprocess.Popen(args, 
577                     stdout = subprocess.PIPE,
578                     stdin = subprocess.PIPE, 
579                     stderr = subprocess.PIPE)
580             proc._known_hosts = tmp_known_hosts
581             
582             comm = proc.communicate()
583             proc.wait()
584             return (comm, proc)
585  
586 def popen_ssh_subprocess(python_code, host, port, user, agent, 
587         python_path = None,
588         ident_key = None,
589         server_key = None,
590         tty = False):
591         if python_path:
592             python_path.replace("'", r"'\''")
593             cmd = """PYTHONPATH="$PYTHONPATH":'%s' """ % python_path
594         else:
595             cmd = ""
596         # Uncomment for debug (to run everything under strace)
597         # We had to verify if strace works (cannot nest them)
598         #cmd += "if strace echo >/dev/null 2>&1; then CMD='strace -ff -tt -s 200 -o strace.out'; else CMD=''; fi\n"
599         #cmd += "$CMD "
600         #if self.mode == MODE_SSH:
601         #    cmd += "strace -f -tt -s 200 -o strace$$.out "
602         cmd += "python -c '"
603         cmd += "import base64, os\n"
604         cmd += "cmd = \"\"\n"
605         cmd += "while True:\n"
606         cmd += " cmd += os.read(0, 1)\n" # one byte from stdin
607         cmd += " if cmd[-1] == \"\\n\": break\n"
608         cmd += "cmd = base64.b64decode(cmd)\n"
609         # Uncomment for debug
610         #cmd += "os.write(2, \"Executing python code: %s\\n\" % cmd)\n"
611         cmd += "os.write(1, \"OK\\n\")\n" # send a sync message
612         cmd += "exec(cmd)\n'"
613
614         tmp_known_hosts = None
615         args = ['ssh',
616                 # Don't bother with localhost. Makes test easier
617                 '-o', 'NoHostAuthenticationForLocalhost=yes',
618                 '-l', user, host]
619         if agent:
620             args.append('-A')
621         if port:
622             args.append('-p%d' % port)
623         if ident_key:
624             args.extend(('-i', ident_key))
625         if tty:
626             args.append('-t')
627         if server_key:
628             # Create a temporary server key file
629             tmp_known_hosts = _make_server_key_args(
630                 server_key, host, port, args)
631         args.append(cmd)
632
633         # connects to the remote host and starts a remote rpyc connection
634         proc = subprocess.Popen(args, 
635                 stdout = subprocess.PIPE,
636                 stdin = subprocess.PIPE, 
637                 stderr = subprocess.PIPE)
638         proc._known_hosts = tmp_known_hosts
639         
640         # send the command to execute
641         os.write(proc.stdin.fileno(),
642                 base64.b64encode(python_code) + "\n")
643         msg = os.read(proc.stdout.fileno(), 3)
644         if msg != "OK\n":
645             raise RuntimeError("Failed to start remote python interpreter")
646         return proc
647