* Automatic provisioning
[nepi.git] / src / nepi / util / server.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import base64
5 import errno
6 import os
7 import resource
8 import select
9 import socket
10 import sys
11 import subprocess
12 import threading
13 import time
14 import traceback
15 import signal
16 import re
17 import tempfile
18
19 CTRL_SOCK = "ctrl.sock"
20 STD_ERR = "stderr.log"
21 MAX_FD = 1024
22
23 STOP_MSG = "STOP"
24
25 ERROR_LEVEL = 0
26 DEBUG_LEVEL = 1
27
28 if hasattr(os, "devnull"):
29     DEV_NULL = os.devnull
30 else:
31     DEV_NULL = "/dev/null"
32
33
34
35 SHELL_SAFE = re.compile('[-a-zA-Z0-9_=+:.,/]*')
36
37 def shell_escape(s):
38     """ Escapes strings so that they are safe to use as command-line arguments """
39     if SHELL_SAFE.match(s):
40         # safe string - no escaping needed
41         return s
42     else:
43         # unsafe string - escape
44         s = s.replace("'","\\'")
45         return "'%s'" % (s,)
46
47 class Server(object):
48     def __init__(self, root_dir = ".", log_level = ERROR_LEVEL):
49         self._root_dir = root_dir
50         self._stop = False
51         self._ctrl_sock = None
52         self._log_level = log_level
53
54     def run(self):
55         try:
56             if self.daemonize():
57                 self.post_daemonize()
58                 self.loop()
59                 self.cleanup()
60                 # ref: "os._exit(0)"
61                 # can not return normally after fork beacuse no exec was done.
62                 # This means that if we don't do a os._exit(0) here the code that 
63                 # follows the call to "Server.run()" in the "caller code" will be 
64                 # executed... but by now it has already been executed after the 
65                 # first process (the one that did the first fork) returned.
66                 os._exit(0)
67         except:
68             self.log_error()
69             self.cleanup()
70             os._exit(0)
71
72     def daemonize(self):
73         # pipes for process synchronization
74         (r, w) = os.pipe()
75
76         pid1 = os.fork()
77         if pid1 > 0:
78             os.close(w)
79             os.read(r, 1)
80             os.close(r)
81             # os.waitpid avoids leaving a <defunc> (zombie) process
82             st = os.waitpid(pid1, 0)[1]
83             if st:
84                 raise RuntimeError("Daemonization failed")
85             # return 0 to inform the caller method that this is not the 
86             # daemonized process
87             return 0
88         os.close(r)
89
90         # Decouple from parent environment.
91         os.chdir(self._root_dir)
92         os.umask(0)
93         os.setsid()
94
95         # fork 2
96         pid2 = os.fork()
97         if pid2 > 0:
98             # see ref: "os._exit(0)"
99             os._exit(0)
100
101         # close all open file descriptors.
102         max_fd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
103         if (max_fd == resource.RLIM_INFINITY):
104             max_fd = MAX_FD
105         for fd in range(3, max_fd):
106             if fd != w:
107                 try:
108                     os.close(fd)
109                 except OSError:
110                     pass
111
112         # Redirect standard file descriptors.
113         stdin = open(DEV_NULL, "r")
114         stderr = stdout = open(STD_ERR, "a", 0)
115         os.dup2(stdin.fileno(), sys.stdin.fileno())
116         # NOTE: sys.stdout.write will still be buffered, even if the file
117         # was opened with 0 buffer
118         os.dup2(stdout.fileno(), sys.stdout.fileno())
119         os.dup2(stderr.fileno(), sys.stderr.fileno())
120
121         # create control socket
122         self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
123         self._ctrl_sock.bind(CTRL_SOCK)
124         self._ctrl_sock.listen(0)
125
126         # let the parent process know that the daemonization is finished
127         os.write(w, "\n")
128         os.close(w)
129         return 1
130
131     def post_daemonize(self):
132         pass
133
134     def loop(self):
135         while not self._stop:
136             conn, addr = self._ctrl_sock.accept()
137             conn.settimeout(5)
138             while not self._stop:
139                 try:
140                     msg = self.recv_msg(conn)
141                 except socket.timeout, e:
142                     break
143                     
144                 if msg == STOP_MSG:
145                     self._stop = True
146                     reply = self.stop_action()
147                 else:
148                     reply = self.reply_action(msg)
149                 
150                 try:
151                     self.send_reply(conn, reply)
152                 except socket.error:
153                     self.log_error()
154                     self.log_error("NOTICE: Awaiting for reconnection")
155                     break
156             try:
157                 conn.close()
158             except:
159                 # Doesn't matter
160                 self.log_error()
161
162     def recv_msg(self, conn):
163         data = ""
164         while True:
165             try:
166                 chunk = conn.recv(1024)
167             except OSError, e:
168                 if e.errno != errno.EINTR:
169                     raise
170                 if chunk == '':
171                     continue
172             if chunk:
173                 data += chunk
174                 if chunk[-1] == "\n":
175                     break
176             else:
177                 # empty chunk = EOF
178                 break
179         decoded = base64.b64decode(data)
180         return decoded.rstrip()
181
182     def send_reply(self, conn, reply):
183         encoded = base64.b64encode(reply)
184         conn.send("%s\n" % encoded)
185        
186     def cleanup(self):
187         try:
188             self._ctrl_sock.close()
189             os.remove(CTRL_SOCK)
190         except:
191             self.log_error()
192
193     def stop_action(self):
194         return "Stopping server"
195
196     def reply_action(self, msg):
197         return "Reply to: %s" % msg
198
199     def log_error(self, text = None, context = ''):
200         if text == None:
201             text = traceback.format_exc()
202         date = time.strftime("%Y-%m-%d %H:%M:%S")
203         if context:
204             context = " (%s)" % (context,)
205         sys.stderr.write("ERROR%s: %s\n%s\n" % (context, date, text))
206         return text
207
208     def log_debug(self, text):
209         if self._log_level == DEBUG_LEVEL:
210             date = time.strftime("%Y-%m-%d %H:%M:%S")
211             sys.stderr.write("DEBUG: %s\n%s\n" % (date, text))
212
213 class Forwarder(object):
214     def __init__(self, root_dir = "."):
215         self._ctrl_sock = None
216         self._root_dir = root_dir
217         self._stop = False
218
219     def forward(self):
220         self.connect()
221         print >>sys.stderr, "READY."
222         while not self._stop:
223             data = self.read_data()
224             self.send_to_server(data)
225             data = self.recv_from_server()
226             self.write_data(data)
227         self.disconnect()
228
229     def read_data(self):
230         return sys.stdin.readline()
231
232     def write_data(self, data):
233         sys.stdout.write(data)
234         # sys.stdout.write is buffered, this is why we need to do a flush()
235         sys.stdout.flush()
236
237     def send_to_server(self, data):
238         try:
239             self._ctrl_sock.send(data)
240         except IOError, e:
241             if e.errno == errno.EPIPE:
242                 self.connect()
243                 self._ctrl_sock.send(data)
244             else:
245                 raise e
246         encoded = data.rstrip() 
247         msg = base64.b64decode(encoded)
248         if msg == STOP_MSG:
249             self._stop = True
250
251     def recv_from_server(self):
252         data = ""
253         while True:
254             try:
255                 chunk = self._ctrl_sock.recv(1024)
256             except OSError, e:
257                 if e.errno != errno.EINTR:
258                     raise
259                 if chunk == '':
260                     continue
261             data += chunk
262             if chunk[-1] == "\n":
263                 break
264         return data
265  
266     def connect(self):
267         self.disconnect()
268         self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
269         sock_addr = os.path.join(self._root_dir, CTRL_SOCK)
270         self._ctrl_sock.connect(sock_addr)
271
272     def disconnect(self):
273         try:
274             self._ctrl_sock.close()
275         except:
276             pass
277
278 class Client(object):
279     def __init__(self, root_dir = ".", host = None, port = None, user = None, 
280             agent = None):
281         self.root_dir = root_dir
282         self.addr = (host, port)
283         self.user = user
284         self.agent = agent
285         self._stopped = False
286         self.connect()
287     
288     def __del__(self):
289         if self._process.poll() is None:
290             os.kill(self._process.pid, signal.SIGTERM)
291         self._process.wait()
292         
293     def connect(self):
294         root_dir = self.root_dir
295         (host, port) = self.addr
296         user = self.user
297         agent = self.agent
298         
299         python_code = "from nepi.util import server;c=server.Forwarder(%r);\
300                 c.forward()" % (root_dir,)
301         if host != None:
302             self._process = popen_ssh_subprocess(python_code, host, port, 
303                     user, agent)
304             # popen_ssh_subprocess already waits for readiness
305         else:
306             self._process = subprocess.Popen(
307                     ["python", "-c", python_code],
308                     stdin = subprocess.PIPE, 
309                     stdout = subprocess.PIPE,
310                     stderr = subprocess.PIPE
311                 )
312                 
313         # Wait for the forwarder to be ready, otherwise nobody
314         # will be able to connect to it
315         helo = self._process.stderr.readline()
316         if helo != 'READY.\n':
317             raise AssertionError, "Expected 'Ready.', got %r: %s" % (helo,
318                     helo + self._process.stderr.read())
319         
320     def send_msg(self, msg):
321         encoded = base64.b64encode(msg)
322         data = "%s\n" % encoded
323         
324         try:
325             self._process.stdin.write(data)
326         except (IOError, ValueError):
327             # dead process, poll it to un-zombify
328             self._process.poll()
329             
330             # try again after reconnect
331             # If it fails again, though, give up
332             self.connect()
333             self._process.stdin.write(data)
334
335     def send_stop(self):
336         self.send_msg(STOP_MSG)
337         self._stopped = True
338
339     def read_reply(self):
340         data = self._process.stdout.readline()
341         encoded = data.rstrip() 
342         return base64.b64decode(encoded)
343
344 def _make_server_key_args(server_key, host, port, args):
345     """ 
346     Returns a reference to the created temporary file, and adds the
347     corresponding arguments to the given argument list.
348     
349     Make sure to hold onto it until the process is done with the file
350     """
351     if port is not None:
352         host = '%s:%s' % (host,port)
353     # Create a temporary server key file
354     tmp_known_hosts = tempfile.NamedTemporaryFile()
355     tmp_known_hosts.write('%s,%s %s\n' % (host, socket.gethostbyname(host), server_key))
356     tmp_known_hosts.flush()
357     args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
358     return tmp_known_hosts
359
360 def popen_ssh_command(command, host, port, user, agent, 
361             stdin="", 
362             ident_key = None,
363             server_key = None,
364             tty = False):
365         """
366         Executes a remote commands, returns ((stdout,stderr),process)
367         """
368         tmp_known_hosts = None
369         args = ['ssh',
370                 # Don't bother with localhost. Makes test easier
371                 '-o', 'NoHostAuthenticationForLocalhost=yes',
372                 '-l', user, host]
373         if agent:
374             args.append('-A')
375         if port:
376             args.append('-p%d' % port)
377         if ident_key:
378             args.extend(('-i', ident_key))
379         if tty:
380             args.append('-t')
381         if server_key:
382             # Create a temporary server key file
383             tmp_known_hosts = _make_server_key_args(
384                 server_key, host, port, args)
385         args.append(command)
386
387         # connects to the remote host and starts a remote connection
388         proc = subprocess.Popen(args, 
389                 stdout = subprocess.PIPE,
390                 stdin = subprocess.PIPE, 
391                 stderr = subprocess.PIPE)
392         
393         # attach tempfile object to the process, to make sure the file stays
394         # alive until the process is finished with it
395         proc._known_hosts = tmp_known_hosts
396         
397         return (proc.communicate(stdin), proc)
398  
399 def popen_scp(source, dest, 
400             port = None, 
401             agent = None, 
402             recursive = False,
403             ident_key = None,
404             server_key = None):
405         """
406         Copies from/to remote sites.
407         
408         Source and destination should have the user and host encoded
409         as per scp specs.
410         
411         If source is a file object, a special mode will be used to
412         create the remote file with the same contents.
413         
414         If dest is a file object, the remote file (source) will be
415         read and written into dest.
416         
417         In these modes, recursive cannot be True.
418         """
419         
420         if isinstance(source, file) or isinstance(dest, file) \
421                 or hasattr(source, 'read')  or hasattr(dest, 'write'):
422             assert not resursive
423             
424             # Parse source/destination as <user>@<server>:<path>
425             if isinstance(dest, basestring) and ':' in dest:
426                 remspec, path = dest.split(':',1)
427             elif isinstance(source, basestring) and ':' in source:
428                 remspec, path = source.split(':',1)
429             else:
430                 raise ValueError, "Both endpoints cannot be local"
431             user,host = remspec.rsplit('@',1)
432             tmp_known_hosts = None
433             
434             args = ['ssh', '-l', user, '-C',
435                     # Don't bother with localhost. Makes test easier
436                     '-o', 'NoHostAuthenticationForLocalhost=yes' ]
437             if port:
438                 args.append('-P%d' % port)
439             if ident_key:
440                 args.extend(('-i', ident_key))
441             if server_key:
442                 # Create a temporary server key file
443                 tmp_known_hosts = _make_server_key_args(
444                     server_key, host, port, args)
445             
446             if isinstance(source, file) or hasattr(source, 'read'):
447                 args.append('cat > %s' % (shell_escape(path),))
448             elif isinstance(dest, file) or hasattr(dest, 'write'):
449                 args.append('cat %s' % (shell_escape(path),))
450             else:
451                 raise AssertionError, "Unreachable code reached! :-Q"
452             
453             # connects to the remote host and starts a remote connection
454             if isinstance(source, file):
455                 proc = subprocess.Popen(args, 
456                         stdout = open('/dev/null','w'),
457                         stderr = subprocess.PIPE,
458                         stdin = source)
459                 err = proc.stderr.read()
460                 proc._known_hosts = tmp_known_hosts
461                 proc.wait()
462                 return ((None,err), proc)
463             elif isinstance(dest, file):
464                 proc = subprocess.Popen(args, 
465                         stdout = open('/dev/null','w'),
466                         stderr = subprocess.PIPE,
467                         stdin = source)
468                 err = proc.stderr.read()
469                 proc._known_hosts = tmp_known_hosts
470                 proc.wait()
471                 return ((None,err), proc)
472             elif hasattr(source, 'read'):
473                 # file-like (but not file) source
474                 proc = subprocess.Popen(args, 
475                         stdout = open('/dev/null','w'),
476                         stderr = subprocess.PIPE,
477                         stdin = source)
478                 
479                 buf = None
480                 err = []
481                 while True:
482                     if not buf:
483                         buf = source.read(4096)
484                     
485                     rdrdy, wrdy, broken = os.select(
486                         [proc.stderr],
487                         [proc.stdin],
488                         [proc.stderr,proc.stdin])
489                     
490                     if proc.stderr in rdrdy:
491                         # use os.read for fully unbuffered behavior
492                         err.append(os.read(proc.stderr.fileno(), 4096))
493                     
494                     if proc.stdin in wrdy:
495                         proc.stdin.write(buf)
496                         buf = None
497                     
498                     if broken:
499                         break
500                 err.append(proc.stderr.read())
501                     
502                 proc._known_hosts = tmp_known_hosts
503                 proc.wait()
504                 return ((None,''.join(err)), proc)
505             elif hasattr(dest, 'write'):
506                 # file-like (but not file) dest
507                 proc = subprocess.Popen(args, 
508                         stdout = open('/dev/null','w'),
509                         stderr = subprocess.PIPE,
510                         stdin = source)
511                 
512                 buf = None
513                 err = []
514                 while True:
515                     rdrdy, wrdy, broken = os.select(
516                         [proc.stderr, proc.stdout],
517                         [],
518                         [proc.stderr, proc.stdout])
519                     
520                     if proc.stderr in rdrdy:
521                         # use os.read for fully unbuffered behavior
522                         err.append(os.read(proc.stderr.fileno(), 4096))
523                     
524                     if proc.stdout in rdrdy:
525                         # use os.read for fully unbuffered behavior
526                         dest.write(os.read(proc.stdout.fileno(), 4096))
527                     
528                     if broken:
529                         break
530                 err.append(proc.stderr.read())
531                     
532                 proc._known_hosts = tmp_known_hosts
533                 proc.wait()
534                 return ((None,''.join(err)), proc)
535             else:
536                 raise AssertionError, "Unreachable code reached! :-Q"
537         else:
538             # Parse destination as <user>@<server>:<path>
539             if isinstance(dest, basestring) and ':' in dest:
540                 remspec, path = dest.split(':',1)
541             elif isinstance(source, basestring) and ':' in source:
542                 remspec, path = source.split(':',1)
543             else:
544                 raise ValueError, "Both endpoints cannot be local"
545             user,host = remspec.rsplit('@',1)
546             
547             # plain scp
548             tmp_known_hosts = None
549             args = ['scp', '-q', '-p', '-C',
550                     # Don't bother with localhost. Makes test easier
551                     '-o', 'NoHostAuthenticationForLocalhost=yes' ]
552             if port:
553                 args.append('-P%d' % port)
554             if recursive:
555                 args.append('-r')
556             if ident_key:
557                 args.extend(('-i', ident_key))
558             if server_key:
559                 # Create a temporary server key file
560                 tmp_known_hosts = _make_server_key_args(
561                     server_key, host, port, args)
562             args.append(source)
563             args.append(dest)
564
565             # connects to the remote host and starts a remote connection
566             proc = subprocess.Popen(args, 
567                     stdout = subprocess.PIPE,
568                     stdin = subprocess.PIPE, 
569                     stderr = subprocess.PIPE)
570             proc._known_hosts = tmp_known_hosts
571             
572             comm = proc.communicate()
573             proc.wait()
574             return (comm, proc)
575  
576 def popen_ssh_subprocess(python_code, host, port, user, agent, 
577         python_path = None,
578         ident_key = None,
579         server_key = None,
580         tty = False):
581         if python_path:
582             python_path.replace("'", r"'\''")
583             cmd = """PYTHONPATH="$PYTHONPATH":'%s' """ % python_path
584         else:
585             cmd = ""
586         # Uncomment for debug (to run everything under strace)
587         # We had to verify if strace works (cannot nest them)
588         #cmd += "if strace echo >/dev/null 2>&1; then CMD='strace -ff -tt -s 200 -o strace.out'; else CMD=''; fi\n"
589         #cmd += "$CMD "
590         #if self.mode == MODE_SSH:
591         #    cmd += "strace -f -tt -s 200 -o strace$$.out "
592         cmd += "python -c '"
593         cmd += "import base64, os\n"
594         cmd += "cmd = \"\"\n"
595         cmd += "while True:\n"
596         cmd += " cmd += os.read(0, 1)\n" # one byte from stdin
597         cmd += " if cmd[-1] == \"\\n\": break\n"
598         cmd += "cmd = base64.b64decode(cmd)\n"
599         # Uncomment for debug
600         #cmd += "os.write(2, \"Executing python code: %s\\n\" % cmd)\n"
601         cmd += "os.write(1, \"OK\\n\")\n" # send a sync message
602         cmd += "exec(cmd)\n'"
603
604         tmp_known_hosts = None
605         args = ['ssh',
606                 # Don't bother with localhost. Makes test easier
607                 '-o', 'NoHostAuthenticationForLocalhost=yes',
608                 '-l', user, host]
609         if agent:
610             args.append('-A')
611         if port:
612             args.append('-p%d' % port)
613         if ident_key:
614             args.extend(('-i', ident_key))
615         if tty:
616             args.append('-t')
617         if server_key:
618             # Create a temporary server key file
619             tmp_known_hosts = _make_server_key_args(
620                 server_key, host, port, args)
621         args.append(cmd)
622
623         # connects to the remote host and starts a remote rpyc connection
624         proc = subprocess.Popen(args, 
625                 stdout = subprocess.PIPE,
626                 stdin = subprocess.PIPE, 
627                 stderr = subprocess.PIPE)
628         proc._known_hosts = tmp_known_hosts
629         
630         # send the command to execute
631         os.write(proc.stdin.fileno(),
632                 base64.b64encode(python_code) + "\n")
633         msg = os.read(proc.stdout.fileno(), 3)
634         if msg != "OK\n":
635             raise RuntimeError("Failed to start remote python interpreter")
636         return proc
637