Typos, type, environment, synchronization and other small fixes to Proxies
[nepi.git] / src / nepi / util / server.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import base64
5 import errno
6 import os
7 import os.path
8 import resource
9 import select
10 import socket
11 import sys
12 import subprocess
13 import threading
14 import time
15 import traceback
16 import signal
17 import re
18 import tempfile
19
20 CTRL_SOCK = "ctrl.sock"
21 STD_ERR = "stderr.log"
22 MAX_FD = 1024
23
24 STOP_MSG = "STOP"
25
26 ERROR_LEVEL = 0
27 DEBUG_LEVEL = 1
28 TRACE = os.environ.get("NEPI_TRACE", "false").lower() in ("true", "1", "on")
29
30 if hasattr(os, "devnull"):
31     DEV_NULL = os.devnull
32 else:
33     DEV_NULL = "/dev/null"
34
35
36
37 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
38
39 def shell_escape(s):
40     """ Escapes strings so that they are safe to use as command-line arguments """
41     if SHELL_SAFE.match(s):
42         # safe string - no escaping needed
43         return s
44     else:
45         # unsafe string - escape
46         def escp(c):
47             if (32 <= ord(c) < 127 or c in ('\r','\n','\t')) and c not in ("'",):
48                 return c
49             else:
50                 return "'$'\\x%02x''" % (ord(c),)
51         s = ''.join(map(escp,s))
52         return "'%s'" % (s,)
53
54 class Server(object):
55     def __init__(self, root_dir = ".", log_level = ERROR_LEVEL):
56         self._root_dir = root_dir
57         self._stop = False
58         self._ctrl_sock = None
59         self._log_level = log_level
60
61     def run(self):
62         try:
63             if self.daemonize():
64                 self.post_daemonize()
65                 self.loop()
66                 self.cleanup()
67                 # ref: "os._exit(0)"
68                 # can not return normally after fork beacuse no exec was done.
69                 # This means that if we don't do a os._exit(0) here the code that 
70                 # follows the call to "Server.run()" in the "caller code" will be 
71                 # executed... but by now it has already been executed after the 
72                 # first process (the one that did the first fork) returned.
73                 os._exit(0)
74         except:
75             self.log_error()
76             self.cleanup()
77             os._exit(0)
78
79     def daemonize(self):
80         # pipes for process synchronization
81         (r, w) = os.pipe()
82         
83         # build root folder
84         root = os.path.normpath(self._root_dir)
85         if not os.path.exists(root):
86             os.makedirs(root, 0755)
87
88         pid1 = os.fork()
89         if pid1 > 0:
90             os.close(w)
91             os.read(r, 1)
92             os.close(r)
93             # os.waitpid avoids leaving a <defunc> (zombie) process
94             st = os.waitpid(pid1, 0)[1]
95             if st:
96                 raise RuntimeError("Daemonization failed")
97             # return 0 to inform the caller method that this is not the 
98             # daemonized process
99             return 0
100         os.close(r)
101
102         # Decouple from parent environment.
103         os.chdir(self._root_dir)
104         os.umask(0)
105         os.setsid()
106
107         # fork 2
108         pid2 = os.fork()
109         if pid2 > 0:
110             # see ref: "os._exit(0)"
111             os._exit(0)
112
113         # close all open file descriptors.
114         max_fd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
115         if (max_fd == resource.RLIM_INFINITY):
116             max_fd = MAX_FD
117         for fd in range(3, max_fd):
118             if fd != w:
119                 try:
120                     os.close(fd)
121                 except OSError:
122                     pass
123
124         # Redirect standard file descriptors.
125         stdin = open(DEV_NULL, "r")
126         stderr = stdout = open(STD_ERR, "a", 0)
127         os.dup2(stdin.fileno(), sys.stdin.fileno())
128         # NOTE: sys.stdout.write will still be buffered, even if the file
129         # was opened with 0 buffer
130         os.dup2(stdout.fileno(), sys.stdout.fileno())
131         os.dup2(stderr.fileno(), sys.stderr.fileno())
132
133         # create control socket
134         self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
135         self._ctrl_sock.bind(CTRL_SOCK)
136         self._ctrl_sock.listen(0)
137
138         # let the parent process know that the daemonization is finished
139         os.write(w, "\n")
140         os.close(w)
141         return 1
142
143     def post_daemonize(self):
144         pass
145
146     def loop(self):
147         while not self._stop:
148             conn, addr = self._ctrl_sock.accept()
149             conn.settimeout(5)
150             while not self._stop:
151                 try:
152                     msg = self.recv_msg(conn)
153                 except socket.timeout, e:
154                     break
155                     
156                 if msg == STOP_MSG:
157                     self._stop = True
158                     reply = self.stop_action()
159                 else:
160                     reply = self.reply_action(msg)
161                 
162                 try:
163                     self.send_reply(conn, reply)
164                 except socket.error:
165                     self.log_error()
166                     self.log_error("NOTICE: Awaiting for reconnection")
167                     break
168             try:
169                 conn.close()
170             except:
171                 # Doesn't matter
172                 self.log_error()
173
174     def recv_msg(self, conn):
175         data = ""
176         while True:
177             try:
178                 chunk = conn.recv(1024)
179             except OSError, e:
180                 if e.errno != errno.EINTR:
181                     raise
182                 if chunk == '':
183                     continue
184             if chunk:
185                 data += chunk
186                 if chunk[-1] == "\n":
187                     break
188             else:
189                 # empty chunk = EOF
190                 break
191         decoded = base64.b64decode(data)
192         return decoded.rstrip()
193
194     def send_reply(self, conn, reply):
195         encoded = base64.b64encode(reply)
196         conn.send("%s\n" % encoded)
197        
198     def cleanup(self):
199         try:
200             self._ctrl_sock.close()
201             os.remove(CTRL_SOCK)
202         except:
203             self.log_error()
204
205     def stop_action(self):
206         return "Stopping server"
207
208     def reply_action(self, msg):
209         return "Reply to: %s" % msg
210
211     def log_error(self, text = None, context = ''):
212         if text == None:
213             text = traceback.format_exc()
214         date = time.strftime("%Y-%m-%d %H:%M:%S")
215         if context:
216             context = " (%s)" % (context,)
217         sys.stderr.write("ERROR%s: %s\n%s\n" % (context, date, text))
218         return text
219
220     def log_debug(self, text):
221         if self._log_level == DEBUG_LEVEL:
222             date = time.strftime("%Y-%m-%d %H:%M:%S")
223             sys.stderr.write("DEBUG: %s\n%s\n" % (date, text))
224
225 class Forwarder(object):
226     def __init__(self, root_dir = "."):
227         self._ctrl_sock = None
228         self._root_dir = root_dir
229         self._stop = False
230
231     def forward(self):
232         self.connect()
233         print >>sys.stderr, "READY."
234         while not self._stop:
235             data = self.read_data()
236             self.send_to_server(data)
237             data = self.recv_from_server()
238             self.write_data(data)
239         self.disconnect()
240
241     def read_data(self):
242         return sys.stdin.readline()
243
244     def write_data(self, data):
245         sys.stdout.write(data)
246         # sys.stdout.write is buffered, this is why we need to do a flush()
247         sys.stdout.flush()
248
249     def send_to_server(self, data):
250         try:
251             self._ctrl_sock.send(data)
252         except IOError, e:
253             if e.errno == errno.EPIPE:
254                 self.connect()
255                 self._ctrl_sock.send(data)
256             else:
257                 raise e
258         encoded = data.rstrip() 
259         msg = base64.b64decode(encoded)
260         if msg == STOP_MSG:
261             self._stop = True
262
263     def recv_from_server(self):
264         data = ""
265         while True:
266             try:
267                 chunk = self._ctrl_sock.recv(1024)
268             except OSError, e:
269                 if e.errno != errno.EINTR:
270                     raise
271                 if chunk == '':
272                     continue
273             data += chunk
274             if chunk[-1] == "\n":
275                 break
276         return data
277  
278     def connect(self):
279         self.disconnect()
280         self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
281         sock_addr = os.path.join(self._root_dir, CTRL_SOCK)
282         self._ctrl_sock.connect(sock_addr)
283
284     def disconnect(self):
285         try:
286             self._ctrl_sock.close()
287         except:
288             pass
289
290 class Client(object):
291     def __init__(self, root_dir = ".", host = None, port = None, user = None, 
292             agent = None, environment_setup = ""):
293         self.root_dir = root_dir
294         self.addr = (host, port)
295         self.user = user
296         self.agent = agent
297         self.environment_setup = environment_setup
298         self._stopped = False
299         self.connect()
300     
301     def __del__(self):
302         if self._process.poll() is None:
303             os.kill(self._process.pid, signal.SIGTERM)
304         self._process.wait()
305         
306     def connect(self):
307         root_dir = self.root_dir
308         (host, port) = self.addr
309         user = self.user
310         agent = self.agent
311         
312         python_code = "from nepi.util import server;c=server.Forwarder(%r);\
313                 c.forward()" % (root_dir,)
314         if host != None:
315             self._process = popen_ssh_subprocess(python_code, host, port, 
316                     user, agent,
317                     environment_setup = self.environment_setup)
318             # popen_ssh_subprocess already waits for readiness
319             if self._process.poll():
320                 err = proc.stderr.read()
321                 raise RuntimeError("Client could not be reached: %s" % \
322                         err)
323         else:
324             self._process = subprocess.Popen(
325                     ["python", "-c", python_code],
326                     stdin = subprocess.PIPE, 
327                     stdout = subprocess.PIPE,
328                     stderr = subprocess.PIPE
329                 )
330                 
331         # Wait for the forwarder to be ready, otherwise nobody
332         # will be able to connect to it
333         helo = self._process.stderr.readline()
334         if helo != 'READY.\n':
335             raise AssertionError, "Expected 'Ready.', got %r: %s" % (helo,
336                     helo + self._process.stderr.read())
337         
338     def send_msg(self, msg):
339         encoded = base64.b64encode(msg)
340         data = "%s\n" % encoded
341         
342         try:
343             self._process.stdin.write(data)
344         except (IOError, ValueError):
345             # dead process, poll it to un-zombify
346             self._process.poll()
347             
348             # try again after reconnect
349             # If it fails again, though, give up
350             self.connect()
351             self._process.stdin.write(data)
352
353     def send_stop(self):
354         self.send_msg(STOP_MSG)
355         self._stopped = True
356
357     def read_reply(self):
358         data = self._process.stdout.readline()
359         encoded = data.rstrip() 
360         return base64.b64decode(encoded)
361
362 def _make_server_key_args(server_key, host, port, args):
363     """ 
364     Returns a reference to the created temporary file, and adds the
365     corresponding arguments to the given argument list.
366     
367     Make sure to hold onto it until the process is done with the file
368     """
369     if port is not None:
370         host = '%s:%s' % (host,port)
371     # Create a temporary server key file
372     tmp_known_hosts = tempfile.NamedTemporaryFile()
373     
374     # Add the intended host key
375     tmp_known_hosts.write('%s,%s %s\n' % (host, socket.gethostbyname(host), server_key))
376     
377     # If we're not in strict mode, add user-configured keys
378     if os.environ.get('NEPI_STRICT_AUTH_MODE',"").lower() not in ('1','true','on'):
379         user_hosts_path = '%s/.ssh/known_hosts' % (os.environ.get('HOME',""),)
380         if os.access(user_hosts_path, os.R_OK):
381             f = open(user_hosts_path, "r")
382             tmp_known_hosts.write(f.read())
383             f.close()
384         
385     tmp_known_hosts.flush()
386     
387     args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
388     return tmp_known_hosts
389
390 def popen_ssh_command(command, host, port, user, agent, 
391             stdin="", 
392             ident_key = None,
393             server_key = None,
394             tty = False):
395         """
396         Executes a remote commands, returns ((stdout,stderr),process)
397         """
398         if TRACE:
399             print "ssh", host, command
400         
401         tmp_known_hosts = None
402         args = ['ssh',
403                 # Don't bother with localhost. Makes test easier
404                 '-o', 'NoHostAuthenticationForLocalhost=yes',
405                 '-l', user, host]
406         if agent:
407             args.append('-A')
408         if port:
409             args.append('-p%d' % port)
410         if ident_key:
411             args.extend(('-i', ident_key))
412         if tty:
413             args.append('-t')
414         if server_key:
415             # Create a temporary server key file
416             tmp_known_hosts = _make_server_key_args(
417                 server_key, host, port, args)
418         args.append(command)
419
420         # connects to the remote host and starts a remote connection
421         proc = subprocess.Popen(args, 
422                 stdout = subprocess.PIPE,
423                 stdin = subprocess.PIPE, 
424                 stderr = subprocess.PIPE)
425         
426         # attach tempfile object to the process, to make sure the file stays
427         # alive until the process is finished with it
428         proc._known_hosts = tmp_known_hosts
429         
430         out, err = proc.communicate(stdin)
431         if TRACE:
432             print " -> ", out, err
433
434         return ((out, err), proc)
435  
436 def popen_scp(source, dest, 
437             port = None, 
438             agent = None, 
439             recursive = False,
440             ident_key = None,
441             server_key = None):
442         """
443         Copies from/to remote sites.
444         
445         Source and destination should have the user and host encoded
446         as per scp specs.
447         
448         If source is a file object, a special mode will be used to
449         create the remote file with the same contents.
450         
451         If dest is a file object, the remote file (source) will be
452         read and written into dest.
453         
454         In these modes, recursive cannot be True.
455         
456         Source can be a list of files to copy to a single destination,
457         in which case it is advised that the destination be a folder.
458         """
459         
460         if TRACE:
461             print "scp", source, dest
462         
463         if isinstance(source, file) or isinstance(dest, file) \
464                 or hasattr(source, 'read')  or hasattr(dest, 'write'):
465             assert not recursive
466             
467             # Parse source/destination as <user>@<server>:<path>
468             if isinstance(dest, basestring) and ':' in dest:
469                 remspec, path = dest.split(':',1)
470             elif isinstance(source, basestring) and ':' in source:
471                 remspec, path = source.split(':',1)
472             else:
473                 raise ValueError, "Both endpoints cannot be local"
474             user,host = remspec.rsplit('@',1)
475             tmp_known_hosts = None
476             
477             args = ['ssh', '-l', user, '-C',
478                     # Don't bother with localhost. Makes test easier
479                     '-o', 'NoHostAuthenticationForLocalhost=yes',
480                     host ]
481             if port:
482                 args.append('-P%d' % port)
483             if ident_key:
484                 args.extend(('-i', ident_key))
485             if server_key:
486                 # Create a temporary server key file
487                 tmp_known_hosts = _make_server_key_args(
488                     server_key, host, port, args)
489             
490             if isinstance(source, file) or hasattr(source, 'read'):
491                 args.append('cat > %s' % (shell_escape(path),))
492             elif isinstance(dest, file) or hasattr(dest, 'write'):
493                 args.append('cat %s' % (shell_escape(path),))
494             else:
495                 raise AssertionError, "Unreachable code reached! :-Q"
496             
497             # connects to the remote host and starts a remote connection
498             if isinstance(source, file):
499                 proc = subprocess.Popen(args, 
500                         stdout = open('/dev/null','w'),
501                         stderr = subprocess.PIPE,
502                         stdin = source)
503                 err = proc.stderr.read()
504                 proc._known_hosts = tmp_known_hosts
505                 proc.wait()
506                 return ((None,err), proc)
507             elif isinstance(dest, file):
508                 proc = subprocess.Popen(args, 
509                         stdout = open('/dev/null','w'),
510                         stderr = subprocess.PIPE,
511                         stdin = source)
512                 err = proc.stderr.read()
513                 proc._known_hosts = tmp_known_hosts
514                 proc.wait()
515                 return ((None,err), proc)
516             elif hasattr(source, 'read'):
517                 # file-like (but not file) source
518                 proc = subprocess.Popen(args, 
519                         stdout = open('/dev/null','w'),
520                         stderr = subprocess.PIPE,
521                         stdin = subprocess.PIPE)
522                 
523                 buf = None
524                 err = []
525                 while True:
526                     if not buf:
527                         buf = source.read(4096)
528                     if not buf:
529                         #EOF
530                         break
531                     
532                     rdrdy, wrdy, broken = select.select(
533                         [proc.stderr],
534                         [proc.stdin],
535                         [proc.stderr,proc.stdin])
536                     
537                     if proc.stderr in rdrdy:
538                         # use os.read for fully unbuffered behavior
539                         err.append(os.read(proc.stderr.fileno(), 4096))
540                     
541                     if proc.stdin in wrdy:
542                         proc.stdin.write(buf)
543                         buf = None
544                     
545                     if broken:
546                         break
547                 proc.stdin.close()
548                 err.append(proc.stderr.read())
549                     
550                 proc._known_hosts = tmp_known_hosts
551                 proc.wait()
552                 return ((None,''.join(err)), proc)
553             elif hasattr(dest, 'write'):
554                 # file-like (but not file) dest
555                 proc = subprocess.Popen(args, 
556                         stdout = subprocess.PIPE,
557                         stderr = subprocess.PIPE,
558                         stdin = open('/dev/null','w'))
559                 
560                 buf = None
561                 err = []
562                 while True:
563                     rdrdy, wrdy, broken = select.select(
564                         [proc.stderr, proc.stdout],
565                         [],
566                         [proc.stderr, proc.stdout])
567                     
568                     if proc.stderr in rdrdy:
569                         # use os.read for fully unbuffered behavior
570                         err.append(os.read(proc.stderr.fileno(), 4096))
571                     
572                     if proc.stdout in rdrdy:
573                         # use os.read for fully unbuffered behavior
574                         buf = os.read(proc.stdout.fileno(), 4096)
575                         dest.write(buf)
576                         
577                         if not buf:
578                             #EOF
579                             break
580                     
581                     if broken:
582                         break
583                 err.append(proc.stderr.read())
584                     
585                 proc._known_hosts = tmp_known_hosts
586                 proc.wait()
587                 return ((None,''.join(err)), proc)
588             else:
589                 raise AssertionError, "Unreachable code reached! :-Q"
590         else:
591             # Parse destination as <user>@<server>:<path>
592             if isinstance(dest, basestring) and ':' in dest:
593                 remspec, path = dest.split(':',1)
594             elif isinstance(source, basestring) and ':' in source:
595                 remspec, path = source.split(':',1)
596             else:
597                 raise ValueError, "Both endpoints cannot be local"
598             user,host = remspec.rsplit('@',1)
599             
600             # plain scp
601             tmp_known_hosts = None
602             args = ['scp', '-q', '-p', '-C',
603                     # Don't bother with localhost. Makes test easier
604                     '-o', 'NoHostAuthenticationForLocalhost=yes' ]
605             if port:
606                 args.append('-P%d' % port)
607             if recursive:
608                 args.append('-r')
609             if ident_key:
610                 args.extend(('-i', ident_key))
611             if server_key:
612                 # Create a temporary server key file
613                 tmp_known_hosts = _make_server_key_args(
614                     server_key, host, port, args)
615             if isinstance(source,list):
616                 args.extend(source)
617             else:
618                 args.append(source)
619             args.append(dest)
620
621             # connects to the remote host and starts a remote connection
622             proc = subprocess.Popen(args, 
623                     stdout = subprocess.PIPE,
624                     stdin = subprocess.PIPE, 
625                     stderr = subprocess.PIPE)
626             proc._known_hosts = tmp_known_hosts
627             
628             comm = proc.communicate()
629             proc.wait()
630             return (comm, proc)
631  
632 def popen_ssh_subprocess(python_code, host, port, user, agent, 
633         python_path = None,
634         ident_key = None,
635         server_key = None,
636         tty = False,
637         environment_setup = "",
638         waitcommand = False):
639         cmd = ""
640         if python_path:
641             python_path.replace("'", r"'\''")
642             cmd = """PYTHONPATH="$PYTHONPATH":'%s' """ % python_path
643             cmd += " ; "
644         if environment_setup:
645             cmd += environment_setup
646             cmd += " ; "
647         # Uncomment for debug (to run everything under strace)
648         # We had to verify if strace works (cannot nest them)
649         #cmd += "if strace echo >/dev/null 2>&1; then CMD='strace -ff -tt -s 200 -o strace.out'; else CMD=''; fi\n"
650         #cmd += "$CMD "
651         #if self.mode == MODE_SSH:
652         #    cmd += "strace -f -tt -s 200 -o strace$$.out "
653         cmd += "python -c '"
654         cmd += "import base64, os\n"
655         cmd += "cmd = \"\"\n"
656         cmd += "while True:\n"
657         cmd += " cmd += os.read(0, 1)\n" # one byte from stdin
658         cmd += " if cmd[-1] == \"\\n\": break\n"
659         cmd += "cmd = base64.b64decode(cmd)\n"
660         # Uncomment for debug
661         #cmd += "os.write(2, \"Executing python code: %s\\n\" % cmd)\n"
662         if not waitcommand:
663             cmd += "os.write(1, \"OK\\n\")\n" # send a sync message
664         cmd += "exec(cmd)\n"
665         if waitcommand:
666             cmd += "os.write(1, \"OK\\n\")\n" # send a sync message
667         cmd += "'"
668         
669         tmp_known_hosts = None
670         args = ['ssh',
671                 # Don't bother with localhost. Makes test easier
672                 '-o', 'NoHostAuthenticationForLocalhost=yes',
673                 '-l', user, host]
674         if agent:
675             args.append('-A')
676         if port:
677             args.append('-p%d' % port)
678         if ident_key:
679             args.extend(('-i', ident_key))
680         if tty:
681             args.append('-t')
682         if server_key:
683             # Create a temporary server key file
684             tmp_known_hosts = _make_server_key_args(
685                 server_key, host, port, args)
686         args.append(cmd)
687
688         # connects to the remote host and starts a remote rpyc connection
689         proc = subprocess.Popen(args, 
690                 stdout = subprocess.PIPE,
691                 stdin = subprocess.PIPE, 
692                 stderr = subprocess.PIPE)
693         proc._known_hosts = tmp_known_hosts
694         
695         # send the command to execute
696         os.write(proc.stdin.fileno(),
697                 base64.b64encode(python_code) + "\n")
698         msg = os.read(proc.stdout.fileno(), 3)
699         if msg != "OK\n":
700             raise RuntimeError, "Failed to start remote python interpreter: \nout:\n%s%s\nerr:\n%s" % (
701                 msg, proc.stdout.read(), proc.stderr.read())
702         return proc
703