server_scp can copy multiple files at once.
[nepi.git] / src / nepi / util / server.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import base64
5 import errno
6 import os
7 import resource
8 import select
9 import socket
10 import sys
11 import subprocess
12 import threading
13 import time
14 import traceback
15 import signal
16 import re
17 import tempfile
18
19 CTRL_SOCK = "ctrl.sock"
20 STD_ERR = "stderr.log"
21 MAX_FD = 1024
22
23 STOP_MSG = "STOP"
24
25 ERROR_LEVEL = 0
26 DEBUG_LEVEL = 1
27
28 if hasattr(os, "devnull"):
29     DEV_NULL = os.devnull
30 else:
31     DEV_NULL = "/dev/null"
32
33
34
35 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
36
37 def shell_escape(s):
38     """ Escapes strings so that they are safe to use as command-line arguments """
39     if SHELL_SAFE.match(s):
40         # safe string - no escaping needed
41         return s
42     else:
43         # unsafe string - escape
44         s = s.replace("'","\\'")
45         return "'%s'" % (s,)
46
47 class Server(object):
48     def __init__(self, root_dir = ".", log_level = ERROR_LEVEL):
49         self._root_dir = root_dir
50         self._stop = False
51         self._ctrl_sock = None
52         self._log_level = log_level
53
54     def run(self):
55         try:
56             if self.daemonize():
57                 self.post_daemonize()
58                 self.loop()
59                 self.cleanup()
60                 # ref: "os._exit(0)"
61                 # can not return normally after fork beacuse no exec was done.
62                 # This means that if we don't do a os._exit(0) here the code that 
63                 # follows the call to "Server.run()" in the "caller code" will be 
64                 # executed... but by now it has already been executed after the 
65                 # first process (the one that did the first fork) returned.
66                 os._exit(0)
67         except:
68             self.log_error()
69             self.cleanup()
70             os._exit(0)
71
72     def daemonize(self):
73         # pipes for process synchronization
74         (r, w) = os.pipe()
75
76         pid1 = os.fork()
77         if pid1 > 0:
78             os.close(w)
79             os.read(r, 1)
80             os.close(r)
81             # os.waitpid avoids leaving a <defunc> (zombie) process
82             st = os.waitpid(pid1, 0)[1]
83             if st:
84                 raise RuntimeError("Daemonization failed")
85             # return 0 to inform the caller method that this is not the 
86             # daemonized process
87             return 0
88         os.close(r)
89
90         # Decouple from parent environment.
91         os.chdir(self._root_dir)
92         os.umask(0)
93         os.setsid()
94
95         # fork 2
96         pid2 = os.fork()
97         if pid2 > 0:
98             # see ref: "os._exit(0)"
99             os._exit(0)
100
101         # close all open file descriptors.
102         max_fd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
103         if (max_fd == resource.RLIM_INFINITY):
104             max_fd = MAX_FD
105         for fd in range(3, max_fd):
106             if fd != w:
107                 try:
108                     os.close(fd)
109                 except OSError:
110                     pass
111
112         # Redirect standard file descriptors.
113         stdin = open(DEV_NULL, "r")
114         stderr = stdout = open(STD_ERR, "a", 0)
115         os.dup2(stdin.fileno(), sys.stdin.fileno())
116         # NOTE: sys.stdout.write will still be buffered, even if the file
117         # was opened with 0 buffer
118         os.dup2(stdout.fileno(), sys.stdout.fileno())
119         os.dup2(stderr.fileno(), sys.stderr.fileno())
120
121         # create control socket
122         self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
123         self._ctrl_sock.bind(CTRL_SOCK)
124         self._ctrl_sock.listen(0)
125
126         # let the parent process know that the daemonization is finished
127         os.write(w, "\n")
128         os.close(w)
129         return 1
130
131     def post_daemonize(self):
132         pass
133
134     def loop(self):
135         while not self._stop:
136             conn, addr = self._ctrl_sock.accept()
137             conn.settimeout(5)
138             while not self._stop:
139                 try:
140                     msg = self.recv_msg(conn)
141                 except socket.timeout, e:
142                     break
143                     
144                 if msg == STOP_MSG:
145                     self._stop = True
146                     reply = self.stop_action()
147                 else:
148                     reply = self.reply_action(msg)
149                 
150                 try:
151                     self.send_reply(conn, reply)
152                 except socket.error:
153                     self.log_error()
154                     self.log_error("NOTICE: Awaiting for reconnection")
155                     break
156             try:
157                 conn.close()
158             except:
159                 # Doesn't matter
160                 self.log_error()
161
162     def recv_msg(self, conn):
163         data = ""
164         while True:
165             try:
166                 chunk = conn.recv(1024)
167             except OSError, e:
168                 if e.errno != errno.EINTR:
169                     raise
170                 if chunk == '':
171                     continue
172             if chunk:
173                 data += chunk
174                 if chunk[-1] == "\n":
175                     break
176             else:
177                 # empty chunk = EOF
178                 break
179         decoded = base64.b64decode(data)
180         return decoded.rstrip()
181
182     def send_reply(self, conn, reply):
183         encoded = base64.b64encode(reply)
184         conn.send("%s\n" % encoded)
185        
186     def cleanup(self):
187         try:
188             self._ctrl_sock.close()
189             os.remove(CTRL_SOCK)
190         except:
191             self.log_error()
192
193     def stop_action(self):
194         return "Stopping server"
195
196     def reply_action(self, msg):
197         return "Reply to: %s" % msg
198
199     def log_error(self, text = None, context = ''):
200         if text == None:
201             text = traceback.format_exc()
202         date = time.strftime("%Y-%m-%d %H:%M:%S")
203         if context:
204             context = " (%s)" % (context,)
205         sys.stderr.write("ERROR%s: %s\n%s\n" % (context, date, text))
206         return text
207
208     def log_debug(self, text):
209         if self._log_level == DEBUG_LEVEL:
210             date = time.strftime("%Y-%m-%d %H:%M:%S")
211             sys.stderr.write("DEBUG: %s\n%s\n" % (date, text))
212
213 class Forwarder(object):
214     def __init__(self, root_dir = "."):
215         self._ctrl_sock = None
216         self._root_dir = root_dir
217         self._stop = False
218
219     def forward(self):
220         self.connect()
221         print >>sys.stderr, "READY."
222         while not self._stop:
223             data = self.read_data()
224             self.send_to_server(data)
225             data = self.recv_from_server()
226             self.write_data(data)
227         self.disconnect()
228
229     def read_data(self):
230         return sys.stdin.readline()
231
232     def write_data(self, data):
233         sys.stdout.write(data)
234         # sys.stdout.write is buffered, this is why we need to do a flush()
235         sys.stdout.flush()
236
237     def send_to_server(self, data):
238         try:
239             self._ctrl_sock.send(data)
240         except IOError, e:
241             if e.errno == errno.EPIPE:
242                 self.connect()
243                 self._ctrl_sock.send(data)
244             else:
245                 raise e
246         encoded = data.rstrip() 
247         msg = base64.b64decode(encoded)
248         if msg == STOP_MSG:
249             self._stop = True
250
251     def recv_from_server(self):
252         data = ""
253         while True:
254             try:
255                 chunk = self._ctrl_sock.recv(1024)
256             except OSError, e:
257                 if e.errno != errno.EINTR:
258                     raise
259                 if chunk == '':
260                     continue
261             data += chunk
262             if chunk[-1] == "\n":
263                 break
264         return data
265  
266     def connect(self):
267         self.disconnect()
268         self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
269         sock_addr = os.path.join(self._root_dir, CTRL_SOCK)
270         self._ctrl_sock.connect(sock_addr)
271
272     def disconnect(self):
273         try:
274             self._ctrl_sock.close()
275         except:
276             pass
277
278 class Client(object):
279     def __init__(self, root_dir = ".", host = None, port = None, user = None, 
280             agent = None):
281         self.root_dir = root_dir
282         self.addr = (host, port)
283         self.user = user
284         self.agent = agent
285         self._stopped = False
286         self.connect()
287     
288     def __del__(self):
289         if self._process.poll() is None:
290             os.kill(self._process.pid, signal.SIGTERM)
291         self._process.wait()
292         
293     def connect(self):
294         root_dir = self.root_dir
295         (host, port) = self.addr
296         user = self.user
297         agent = self.agent
298         
299         python_code = "from nepi.util import server;c=server.Forwarder(%r);\
300                 c.forward()" % (root_dir,)
301         if host != None:
302             self._process = popen_ssh_subprocess(python_code, host, port, 
303                     user, agent)
304             # popen_ssh_subprocess already waits for readiness
305         else:
306             self._process = subprocess.Popen(
307                     ["python", "-c", python_code],
308                     stdin = subprocess.PIPE, 
309                     stdout = subprocess.PIPE,
310                     stderr = subprocess.PIPE
311                 )
312                 
313         # Wait for the forwarder to be ready, otherwise nobody
314         # will be able to connect to it
315         helo = self._process.stderr.readline()
316         if helo != 'READY.\n':
317             raise AssertionError, "Expected 'Ready.', got %r: %s" % (helo,
318                     helo + self._process.stderr.read())
319         
320     def send_msg(self, msg):
321         encoded = base64.b64encode(msg)
322         data = "%s\n" % encoded
323         
324         try:
325             self._process.stdin.write(data)
326         except (IOError, ValueError):
327             # dead process, poll it to un-zombify
328             self._process.poll()
329             
330             # try again after reconnect
331             # If it fails again, though, give up
332             self.connect()
333             self._process.stdin.write(data)
334
335     def send_stop(self):
336         self.send_msg(STOP_MSG)
337         self._stopped = True
338
339     def read_reply(self):
340         data = self._process.stdout.readline()
341         encoded = data.rstrip() 
342         return base64.b64decode(encoded)
343
344 def _make_server_key_args(server_key, host, port, args):
345     """ 
346     Returns a reference to the created temporary file, and adds the
347     corresponding arguments to the given argument list.
348     
349     Make sure to hold onto it until the process is done with the file
350     """
351     if port is not None:
352         host = '%s:%s' % (host,port)
353     # Create a temporary server key file
354     tmp_known_hosts = tempfile.NamedTemporaryFile()
355     tmp_known_hosts.write('%s,%s %s\n' % (host, socket.gethostbyname(host), server_key))
356     tmp_known_hosts.flush()
357     args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
358     return tmp_known_hosts
359
360 def popen_ssh_command(command, host, port, user, agent, 
361             stdin="", 
362             ident_key = None,
363             server_key = None,
364             tty = False):
365         """
366         Executes a remote commands, returns ((stdout,stderr),process)
367         """
368         tmp_known_hosts = None
369         args = ['ssh',
370                 # Don't bother with localhost. Makes test easier
371                 '-o', 'NoHostAuthenticationForLocalhost=yes',
372                 '-l', user, host]
373         if agent:
374             args.append('-A')
375         if port:
376             args.append('-p%d' % port)
377         if ident_key:
378             args.extend(('-i', ident_key))
379         if tty:
380             args.append('-t')
381         if server_key:
382             # Create a temporary server key file
383             tmp_known_hosts = _make_server_key_args(
384                 server_key, host, port, args)
385         args.append(command)
386
387         # connects to the remote host and starts a remote connection
388         proc = subprocess.Popen(args, 
389                 stdout = subprocess.PIPE,
390                 stdin = subprocess.PIPE, 
391                 stderr = subprocess.PIPE)
392         
393         # attach tempfile object to the process, to make sure the file stays
394         # alive until the process is finished with it
395         proc._known_hosts = tmp_known_hosts
396         
397         return (proc.communicate(stdin), proc)
398  
399 def popen_scp(source, dest, 
400             port = None, 
401             agent = None, 
402             recursive = False,
403             ident_key = None,
404             server_key = None):
405         """
406         Copies from/to remote sites.
407         
408         Source and destination should have the user and host encoded
409         as per scp specs.
410         
411         If source is a file object, a special mode will be used to
412         create the remote file with the same contents.
413         
414         If dest is a file object, the remote file (source) will be
415         read and written into dest.
416         
417         In these modes, recursive cannot be True.
418         
419         Source can be a list of files to copy to a single destination,
420         in which case it is advised that the destination be a folder.
421         """
422         
423         if isinstance(source, file) or isinstance(dest, file) \
424                 or hasattr(source, 'read')  or hasattr(dest, 'write'):
425             assert not recursive
426             
427             # Parse source/destination as <user>@<server>:<path>
428             if isinstance(dest, basestring) and ':' in dest:
429                 remspec, path = dest.split(':',1)
430             elif isinstance(source, basestring) and ':' in source:
431                 remspec, path = source.split(':',1)
432             else:
433                 raise ValueError, "Both endpoints cannot be local"
434             user,host = remspec.rsplit('@',1)
435             tmp_known_hosts = None
436             
437             args = ['ssh', '-l', user, '-C',
438                     # Don't bother with localhost. Makes test easier
439                     '-o', 'NoHostAuthenticationForLocalhost=yes',
440                     host ]
441             if port:
442                 args.append('-P%d' % port)
443             if ident_key:
444                 args.extend(('-i', ident_key))
445             if server_key:
446                 # Create a temporary server key file
447                 tmp_known_hosts = _make_server_key_args(
448                     server_key, host, port, args)
449             
450             if isinstance(source, file) or hasattr(source, 'read'):
451                 args.append('cat > %s' % (shell_escape(path),))
452             elif isinstance(dest, file) or hasattr(dest, 'write'):
453                 args.append('cat %s' % (shell_escape(path),))
454             else:
455                 raise AssertionError, "Unreachable code reached! :-Q"
456             
457             # connects to the remote host and starts a remote connection
458             if isinstance(source, file):
459                 proc = subprocess.Popen(args, 
460                         stdout = open('/dev/null','w'),
461                         stderr = subprocess.PIPE,
462                         stdin = source)
463                 err = proc.stderr.read()
464                 proc._known_hosts = tmp_known_hosts
465                 proc.wait()
466                 return ((None,err), proc)
467             elif isinstance(dest, file):
468                 proc = subprocess.Popen(args, 
469                         stdout = open('/dev/null','w'),
470                         stderr = subprocess.PIPE,
471                         stdin = source)
472                 err = proc.stderr.read()
473                 proc._known_hosts = tmp_known_hosts
474                 proc.wait()
475                 return ((None,err), proc)
476             elif hasattr(source, 'read'):
477                 # file-like (but not file) source
478                 proc = subprocess.Popen(args, 
479                         stdout = open('/dev/null','w'),
480                         stderr = subprocess.PIPE,
481                         stdin = subprocess.PIPE)
482                 
483                 buf = None
484                 err = []
485                 while True:
486                     if not buf:
487                         buf = source.read(4096)
488                     if not buf:
489                         #EOF
490                         break
491                     
492                     rdrdy, wrdy, broken = select.select(
493                         [proc.stderr],
494                         [proc.stdin],
495                         [proc.stderr,proc.stdin])
496                     
497                     if proc.stderr in rdrdy:
498                         # use os.read for fully unbuffered behavior
499                         err.append(os.read(proc.stderr.fileno(), 4096))
500                     
501                     if proc.stdin in wrdy:
502                         proc.stdin.write(buf)
503                         buf = None
504                     
505                     if broken:
506                         break
507                 proc.stdin.close()
508                 err.append(proc.stderr.read())
509                     
510                 proc._known_hosts = tmp_known_hosts
511                 proc.wait()
512                 return ((None,''.join(err)), proc)
513             elif hasattr(dest, 'write'):
514                 # file-like (but not file) dest
515                 proc = subprocess.Popen(args, 
516                         stdout = subprocess.PIPE,
517                         stderr = subprocess.PIPE,
518                         stdin = open('/dev/null','w'))
519                 
520                 buf = None
521                 err = []
522                 while True:
523                     rdrdy, wrdy, broken = select.select(
524                         [proc.stderr, proc.stdout],
525                         [],
526                         [proc.stderr, proc.stdout])
527                     
528                     if proc.stderr in rdrdy:
529                         # use os.read for fully unbuffered behavior
530                         err.append(os.read(proc.stderr.fileno(), 4096))
531                     
532                     if proc.stdout in rdrdy:
533                         # use os.read for fully unbuffered behavior
534                         buf = os.read(proc.stdout.fileno(), 4096)
535                         dest.write(buf)
536                         
537                         if not buf:
538                             #EOF
539                             break
540                     
541                     if broken:
542                         break
543                 err.append(proc.stderr.read())
544                     
545                 proc._known_hosts = tmp_known_hosts
546                 proc.wait()
547                 return ((None,''.join(err)), proc)
548             else:
549                 raise AssertionError, "Unreachable code reached! :-Q"
550         else:
551             # Parse destination as <user>@<server>:<path>
552             if isinstance(dest, basestring) and ':' in dest:
553                 remspec, path = dest.split(':',1)
554             elif isinstance(source, basestring) and ':' in source:
555                 remspec, path = source.split(':',1)
556             else:
557                 raise ValueError, "Both endpoints cannot be local"
558             user,host = remspec.rsplit('@',1)
559             
560             # plain scp
561             tmp_known_hosts = None
562             args = ['scp', '-q', '-p', '-C',
563                     # Don't bother with localhost. Makes test easier
564                     '-o', 'NoHostAuthenticationForLocalhost=yes' ]
565             if port:
566                 args.append('-P%d' % port)
567             if recursive:
568                 args.append('-r')
569             if ident_key:
570                 args.extend(('-i', ident_key))
571             if server_key:
572                 # Create a temporary server key file
573                 tmp_known_hosts = _make_server_key_args(
574                     server_key, host, port, args)
575             if isinstance(source,list):
576                 args.extend(source)
577             else:
578                 args.append(source)
579             args.append(dest)
580
581             # connects to the remote host and starts a remote connection
582             proc = subprocess.Popen(args, 
583                     stdout = subprocess.PIPE,
584                     stdin = subprocess.PIPE, 
585                     stderr = subprocess.PIPE)
586             proc._known_hosts = tmp_known_hosts
587             
588             comm = proc.communicate()
589             proc.wait()
590             return (comm, proc)
591  
592 def popen_ssh_subprocess(python_code, host, port, user, agent, 
593         python_path = None,
594         ident_key = None,
595         server_key = None,
596         tty = False):
597         if python_path:
598             python_path.replace("'", r"'\''")
599             cmd = """PYTHONPATH="$PYTHONPATH":'%s' """ % python_path
600         else:
601             cmd = ""
602         # Uncomment for debug (to run everything under strace)
603         # We had to verify if strace works (cannot nest them)
604         #cmd += "if strace echo >/dev/null 2>&1; then CMD='strace -ff -tt -s 200 -o strace.out'; else CMD=''; fi\n"
605         #cmd += "$CMD "
606         #if self.mode == MODE_SSH:
607         #    cmd += "strace -f -tt -s 200 -o strace$$.out "
608         cmd += "python -c '"
609         cmd += "import base64, os\n"
610         cmd += "cmd = \"\"\n"
611         cmd += "while True:\n"
612         cmd += " cmd += os.read(0, 1)\n" # one byte from stdin
613         cmd += " if cmd[-1] == \"\\n\": break\n"
614         cmd += "cmd = base64.b64decode(cmd)\n"
615         # Uncomment for debug
616         #cmd += "os.write(2, \"Executing python code: %s\\n\" % cmd)\n"
617         cmd += "os.write(1, \"OK\\n\")\n" # send a sync message
618         cmd += "exec(cmd)\n'"
619
620         tmp_known_hosts = None
621         args = ['ssh',
622                 # Don't bother with localhost. Makes test easier
623                 '-o', 'NoHostAuthenticationForLocalhost=yes',
624                 '-l', user, host]
625         if agent:
626             args.append('-A')
627         if port:
628             args.append('-p%d' % port)
629         if ident_key:
630             args.extend(('-i', ident_key))
631         if tty:
632             args.append('-t')
633         if server_key:
634             # Create a temporary server key file
635             tmp_known_hosts = _make_server_key_args(
636                 server_key, host, port, args)
637         args.append(cmd)
638
639         # connects to the remote host and starts a remote rpyc connection
640         proc = subprocess.Popen(args, 
641                 stdout = subprocess.PIPE,
642                 stdin = subprocess.PIPE, 
643                 stderr = subprocess.PIPE)
644         proc._known_hosts = tmp_known_hosts
645         
646         # send the command to execute
647         os.write(proc.stdin.fileno(),
648                 base64.b64encode(python_code) + "\n")
649         msg = os.read(proc.stdout.fileno(), 3)
650         if msg != "OK\n":
651             raise RuntimeError("Failed to start remote python interpreter")
652         return proc
653