Ticket #29: introduce some "standard" box attributes to support testbed-in-testbed...
[nepi.git] / src / nepi / util / server.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import base64
5 import errno
6 import os
7 import resource
8 import select
9 import socket
10 import sys
11 import subprocess
12 import threading
13 import time
14 import traceback
15 import signal
16 import re
17 import tempfile
18
19 CTRL_SOCK = "ctrl.sock"
20 STD_ERR = "stderr.log"
21 MAX_FD = 1024
22
23 STOP_MSG = "STOP"
24
25 ERROR_LEVEL = 0
26 DEBUG_LEVEL = 1
27 TRACE = os.environ.get("NEPI_TRACE", "false").lower() in ("true", "1", "on")
28
29 if hasattr(os, "devnull"):
30     DEV_NULL = os.devnull
31 else:
32     DEV_NULL = "/dev/null"
33
34
35
36 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
37
38 def shell_escape(s):
39     """ Escapes strings so that they are safe to use as command-line arguments """
40     if SHELL_SAFE.match(s):
41         # safe string - no escaping needed
42         return s
43     else:
44         # unsafe string - escape
45         s = s.replace("'","\\'")
46         return "'%s'" % (s,)
47
48 class Server(object):
49     def __init__(self, root_dir = ".", log_level = ERROR_LEVEL):
50         self._root_dir = root_dir
51         self._stop = False
52         self._ctrl_sock = None
53         self._log_level = log_level
54
55     def run(self):
56         try:
57             if self.daemonize():
58                 self.post_daemonize()
59                 self.loop()
60                 self.cleanup()
61                 # ref: "os._exit(0)"
62                 # can not return normally after fork beacuse no exec was done.
63                 # This means that if we don't do a os._exit(0) here the code that 
64                 # follows the call to "Server.run()" in the "caller code" will be 
65                 # executed... but by now it has already been executed after the 
66                 # first process (the one that did the first fork) returned.
67                 os._exit(0)
68         except:
69             self.log_error()
70             self.cleanup()
71             os._exit(0)
72
73     def daemonize(self):
74         # pipes for process synchronization
75         (r, w) = os.pipe()
76
77         pid1 = os.fork()
78         if pid1 > 0:
79             os.close(w)
80             os.read(r, 1)
81             os.close(r)
82             # os.waitpid avoids leaving a <defunc> (zombie) process
83             st = os.waitpid(pid1, 0)[1]
84             if st:
85                 raise RuntimeError("Daemonization failed")
86             # return 0 to inform the caller method that this is not the 
87             # daemonized process
88             return 0
89         os.close(r)
90
91         # Decouple from parent environment.
92         os.chdir(self._root_dir)
93         os.umask(0)
94         os.setsid()
95
96         # fork 2
97         pid2 = os.fork()
98         if pid2 > 0:
99             # see ref: "os._exit(0)"
100             os._exit(0)
101
102         # close all open file descriptors.
103         max_fd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
104         if (max_fd == resource.RLIM_INFINITY):
105             max_fd = MAX_FD
106         for fd in range(3, max_fd):
107             if fd != w:
108                 try:
109                     os.close(fd)
110                 except OSError:
111                     pass
112
113         # Redirect standard file descriptors.
114         stdin = open(DEV_NULL, "r")
115         stderr = stdout = open(STD_ERR, "a", 0)
116         os.dup2(stdin.fileno(), sys.stdin.fileno())
117         # NOTE: sys.stdout.write will still be buffered, even if the file
118         # was opened with 0 buffer
119         os.dup2(stdout.fileno(), sys.stdout.fileno())
120         os.dup2(stderr.fileno(), sys.stderr.fileno())
121
122         # create control socket
123         self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
124         self._ctrl_sock.bind(CTRL_SOCK)
125         self._ctrl_sock.listen(0)
126
127         # let the parent process know that the daemonization is finished
128         os.write(w, "\n")
129         os.close(w)
130         return 1
131
132     def post_daemonize(self):
133         pass
134
135     def loop(self):
136         while not self._stop:
137             conn, addr = self._ctrl_sock.accept()
138             conn.settimeout(5)
139             while not self._stop:
140                 try:
141                     msg = self.recv_msg(conn)
142                 except socket.timeout, e:
143                     break
144                     
145                 if msg == STOP_MSG:
146                     self._stop = True
147                     reply = self.stop_action()
148                 else:
149                     reply = self.reply_action(msg)
150                 
151                 try:
152                     self.send_reply(conn, reply)
153                 except socket.error:
154                     self.log_error()
155                     self.log_error("NOTICE: Awaiting for reconnection")
156                     break
157             try:
158                 conn.close()
159             except:
160                 # Doesn't matter
161                 self.log_error()
162
163     def recv_msg(self, conn):
164         data = ""
165         while True:
166             try:
167                 chunk = conn.recv(1024)
168             except OSError, e:
169                 if e.errno != errno.EINTR:
170                     raise
171                 if chunk == '':
172                     continue
173             if chunk:
174                 data += chunk
175                 if chunk[-1] == "\n":
176                     break
177             else:
178                 # empty chunk = EOF
179                 break
180         decoded = base64.b64decode(data)
181         return decoded.rstrip()
182
183     def send_reply(self, conn, reply):
184         encoded = base64.b64encode(reply)
185         conn.send("%s\n" % encoded)
186        
187     def cleanup(self):
188         try:
189             self._ctrl_sock.close()
190             os.remove(CTRL_SOCK)
191         except:
192             self.log_error()
193
194     def stop_action(self):
195         return "Stopping server"
196
197     def reply_action(self, msg):
198         return "Reply to: %s" % msg
199
200     def log_error(self, text = None, context = ''):
201         if text == None:
202             text = traceback.format_exc()
203         date = time.strftime("%Y-%m-%d %H:%M:%S")
204         if context:
205             context = " (%s)" % (context,)
206         sys.stderr.write("ERROR%s: %s\n%s\n" % (context, date, text))
207         return text
208
209     def log_debug(self, text):
210         if self._log_level == DEBUG_LEVEL:
211             date = time.strftime("%Y-%m-%d %H:%M:%S")
212             sys.stderr.write("DEBUG: %s\n%s\n" % (date, text))
213
214 class Forwarder(object):
215     def __init__(self, root_dir = "."):
216         self._ctrl_sock = None
217         self._root_dir = root_dir
218         self._stop = False
219
220     def forward(self):
221         self.connect()
222         print >>sys.stderr, "READY."
223         while not self._stop:
224             data = self.read_data()
225             self.send_to_server(data)
226             data = self.recv_from_server()
227             self.write_data(data)
228         self.disconnect()
229
230     def read_data(self):
231         return sys.stdin.readline()
232
233     def write_data(self, data):
234         sys.stdout.write(data)
235         # sys.stdout.write is buffered, this is why we need to do a flush()
236         sys.stdout.flush()
237
238     def send_to_server(self, data):
239         try:
240             self._ctrl_sock.send(data)
241         except IOError, e:
242             if e.errno == errno.EPIPE:
243                 self.connect()
244                 self._ctrl_sock.send(data)
245             else:
246                 raise e
247         encoded = data.rstrip() 
248         msg = base64.b64decode(encoded)
249         if msg == STOP_MSG:
250             self._stop = True
251
252     def recv_from_server(self):
253         data = ""
254         while True:
255             try:
256                 chunk = self._ctrl_sock.recv(1024)
257             except OSError, e:
258                 if e.errno != errno.EINTR:
259                     raise
260                 if chunk == '':
261                     continue
262             data += chunk
263             if chunk[-1] == "\n":
264                 break
265         return data
266  
267     def connect(self):
268         self.disconnect()
269         self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
270         sock_addr = os.path.join(self._root_dir, CTRL_SOCK)
271         self._ctrl_sock.connect(sock_addr)
272
273     def disconnect(self):
274         try:
275             self._ctrl_sock.close()
276         except:
277             pass
278
279 class Client(object):
280     def __init__(self, root_dir = ".", host = None, port = None, user = None, 
281             agent = None):
282         self.root_dir = root_dir
283         self.addr = (host, port)
284         self.user = user
285         self.agent = agent
286         self._stopped = False
287         self.connect()
288     
289     def __del__(self):
290         if self._process.poll() is None:
291             os.kill(self._process.pid, signal.SIGTERM)
292         self._process.wait()
293         
294     def connect(self):
295         root_dir = self.root_dir
296         (host, port) = self.addr
297         user = self.user
298         agent = self.agent
299         
300         python_code = "from nepi.util import server;c=server.Forwarder(%r);\
301                 c.forward()" % (root_dir,)
302         if host != None:
303             self._process = popen_ssh_subprocess(python_code, host, port, 
304                     user, agent)
305             # popen_ssh_subprocess already waits for readiness
306         else:
307             self._process = subprocess.Popen(
308                     ["python", "-c", python_code],
309                     stdin = subprocess.PIPE, 
310                     stdout = subprocess.PIPE,
311                     stderr = subprocess.PIPE
312                 )
313                 
314         # Wait for the forwarder to be ready, otherwise nobody
315         # will be able to connect to it
316         helo = self._process.stderr.readline()
317         if helo != 'READY.\n':
318             raise AssertionError, "Expected 'Ready.', got %r: %s" % (helo,
319                     helo + self._process.stderr.read())
320         
321     def send_msg(self, msg):
322         encoded = base64.b64encode(msg)
323         data = "%s\n" % encoded
324         
325         try:
326             self._process.stdin.write(data)
327         except (IOError, ValueError):
328             # dead process, poll it to un-zombify
329             self._process.poll()
330             
331             # try again after reconnect
332             # If it fails again, though, give up
333             self.connect()
334             self._process.stdin.write(data)
335
336     def send_stop(self):
337         self.send_msg(STOP_MSG)
338         self._stopped = True
339
340     def read_reply(self):
341         data = self._process.stdout.readline()
342         encoded = data.rstrip() 
343         return base64.b64decode(encoded)
344
345 def _make_server_key_args(server_key, host, port, args):
346     """ 
347     Returns a reference to the created temporary file, and adds the
348     corresponding arguments to the given argument list.
349     
350     Make sure to hold onto it until the process is done with the file
351     """
352     if port is not None:
353         host = '%s:%s' % (host,port)
354     # Create a temporary server key file
355     tmp_known_hosts = tempfile.NamedTemporaryFile()
356     
357     # Add the intended host key
358     tmp_known_hosts.write('%s,%s %s\n' % (host, socket.gethostbyname(host), server_key))
359     
360     # If we're not in strict mode, add user-configured keys
361     if os.environ.get('NEPI_STRICT_AUTH_MODE',"").lower() not in ('1','true','on'):
362         user_hosts_path = '%s/.ssh/known_hosts' % (os.environ.get('HOME',""),)
363         if os.access(user_hosts_path, os.R_OK):
364             f = open(user_hosts_path, "r")
365             tmp_known_hosts.write(f.read())
366             f.close()
367         
368     tmp_known_hosts.flush()
369     
370     args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
371     return tmp_known_hosts
372
373 def popen_ssh_command(command, host, port, user, agent, 
374             stdin="", 
375             ident_key = None,
376             server_key = None,
377             tty = False):
378         """
379         Executes a remote commands, returns ((stdout,stderr),process)
380         """
381         if TRACE:
382             print "ssh", host, command
383         
384         tmp_known_hosts = None
385         args = ['ssh',
386                 # Don't bother with localhost. Makes test easier
387                 '-o', 'NoHostAuthenticationForLocalhost=yes',
388                 '-l', user, host]
389         if agent:
390             args.append('-A')
391         if port:
392             args.append('-p%d' % port)
393         if ident_key:
394             args.extend(('-i', ident_key))
395         if tty:
396             args.append('-t')
397         if server_key:
398             # Create a temporary server key file
399             tmp_known_hosts = _make_server_key_args(
400                 server_key, host, port, args)
401         args.append(command)
402
403         # connects to the remote host and starts a remote connection
404         proc = subprocess.Popen(args, 
405                 stdout = subprocess.PIPE,
406                 stdin = subprocess.PIPE, 
407                 stderr = subprocess.PIPE)
408         
409         # attach tempfile object to the process, to make sure the file stays
410         # alive until the process is finished with it
411         proc._known_hosts = tmp_known_hosts
412         
413         out, err = proc.communicate(stdin)
414         if TRACE:
415             print " -> ", out, err
416
417         return ((out, err), proc)
418  
419 def popen_scp(source, dest, 
420             port = None, 
421             agent = None, 
422             recursive = False,
423             ident_key = None,
424             server_key = None):
425         """
426         Copies from/to remote sites.
427         
428         Source and destination should have the user and host encoded
429         as per scp specs.
430         
431         If source is a file object, a special mode will be used to
432         create the remote file with the same contents.
433         
434         If dest is a file object, the remote file (source) will be
435         read and written into dest.
436         
437         In these modes, recursive cannot be True.
438         
439         Source can be a list of files to copy to a single destination,
440         in which case it is advised that the destination be a folder.
441         """
442         
443         if TRACE:
444             print "scp", source, dest
445         
446         if isinstance(source, file) or isinstance(dest, file) \
447                 or hasattr(source, 'read')  or hasattr(dest, 'write'):
448             assert not recursive
449             
450             # Parse source/destination as <user>@<server>:<path>
451             if isinstance(dest, basestring) and ':' in dest:
452                 remspec, path = dest.split(':',1)
453             elif isinstance(source, basestring) and ':' in source:
454                 remspec, path = source.split(':',1)
455             else:
456                 raise ValueError, "Both endpoints cannot be local"
457             user,host = remspec.rsplit('@',1)
458             tmp_known_hosts = None
459             
460             args = ['ssh', '-l', user, '-C',
461                     # Don't bother with localhost. Makes test easier
462                     '-o', 'NoHostAuthenticationForLocalhost=yes',
463                     host ]
464             if port:
465                 args.append('-P%d' % port)
466             if ident_key:
467                 args.extend(('-i', ident_key))
468             if server_key:
469                 # Create a temporary server key file
470                 tmp_known_hosts = _make_server_key_args(
471                     server_key, host, port, args)
472             
473             if isinstance(source, file) or hasattr(source, 'read'):
474                 args.append('cat > %s' % (shell_escape(path),))
475             elif isinstance(dest, file) or hasattr(dest, 'write'):
476                 args.append('cat %s' % (shell_escape(path),))
477             else:
478                 raise AssertionError, "Unreachable code reached! :-Q"
479             
480             # connects to the remote host and starts a remote connection
481             if isinstance(source, file):
482                 proc = subprocess.Popen(args, 
483                         stdout = open('/dev/null','w'),
484                         stderr = subprocess.PIPE,
485                         stdin = source)
486                 err = proc.stderr.read()
487                 proc._known_hosts = tmp_known_hosts
488                 proc.wait()
489                 return ((None,err), proc)
490             elif isinstance(dest, file):
491                 proc = subprocess.Popen(args, 
492                         stdout = open('/dev/null','w'),
493                         stderr = subprocess.PIPE,
494                         stdin = source)
495                 err = proc.stderr.read()
496                 proc._known_hosts = tmp_known_hosts
497                 proc.wait()
498                 return ((None,err), proc)
499             elif hasattr(source, 'read'):
500                 # file-like (but not file) source
501                 proc = subprocess.Popen(args, 
502                         stdout = open('/dev/null','w'),
503                         stderr = subprocess.PIPE,
504                         stdin = subprocess.PIPE)
505                 
506                 buf = None
507                 err = []
508                 while True:
509                     if not buf:
510                         buf = source.read(4096)
511                     if not buf:
512                         #EOF
513                         break
514                     
515                     rdrdy, wrdy, broken = select.select(
516                         [proc.stderr],
517                         [proc.stdin],
518                         [proc.stderr,proc.stdin])
519                     
520                     if proc.stderr in rdrdy:
521                         # use os.read for fully unbuffered behavior
522                         err.append(os.read(proc.stderr.fileno(), 4096))
523                     
524                     if proc.stdin in wrdy:
525                         proc.stdin.write(buf)
526                         buf = None
527                     
528                     if broken:
529                         break
530                 proc.stdin.close()
531                 err.append(proc.stderr.read())
532                     
533                 proc._known_hosts = tmp_known_hosts
534                 proc.wait()
535                 return ((None,''.join(err)), proc)
536             elif hasattr(dest, 'write'):
537                 # file-like (but not file) dest
538                 proc = subprocess.Popen(args, 
539                         stdout = subprocess.PIPE,
540                         stderr = subprocess.PIPE,
541                         stdin = open('/dev/null','w'))
542                 
543                 buf = None
544                 err = []
545                 while True:
546                     rdrdy, wrdy, broken = select.select(
547                         [proc.stderr, proc.stdout],
548                         [],
549                         [proc.stderr, proc.stdout])
550                     
551                     if proc.stderr in rdrdy:
552                         # use os.read for fully unbuffered behavior
553                         err.append(os.read(proc.stderr.fileno(), 4096))
554                     
555                     if proc.stdout in rdrdy:
556                         # use os.read for fully unbuffered behavior
557                         buf = os.read(proc.stdout.fileno(), 4096)
558                         dest.write(buf)
559                         
560                         if not buf:
561                             #EOF
562                             break
563                     
564                     if broken:
565                         break
566                 err.append(proc.stderr.read())
567                     
568                 proc._known_hosts = tmp_known_hosts
569                 proc.wait()
570                 return ((None,''.join(err)), proc)
571             else:
572                 raise AssertionError, "Unreachable code reached! :-Q"
573         else:
574             # Parse destination as <user>@<server>:<path>
575             if isinstance(dest, basestring) and ':' in dest:
576                 remspec, path = dest.split(':',1)
577             elif isinstance(source, basestring) and ':' in source:
578                 remspec, path = source.split(':',1)
579             else:
580                 raise ValueError, "Both endpoints cannot be local"
581             user,host = remspec.rsplit('@',1)
582             
583             # plain scp
584             tmp_known_hosts = None
585             args = ['scp', '-q', '-p', '-C',
586                     # Don't bother with localhost. Makes test easier
587                     '-o', 'NoHostAuthenticationForLocalhost=yes' ]
588             if port:
589                 args.append('-P%d' % port)
590             if recursive:
591                 args.append('-r')
592             if ident_key:
593                 args.extend(('-i', ident_key))
594             if server_key:
595                 # Create a temporary server key file
596                 tmp_known_hosts = _make_server_key_args(
597                     server_key, host, port, args)
598             if isinstance(source,list):
599                 args.extend(source)
600             else:
601                 args.append(source)
602             args.append(dest)
603
604             # connects to the remote host and starts a remote connection
605             proc = subprocess.Popen(args, 
606                     stdout = subprocess.PIPE,
607                     stdin = subprocess.PIPE, 
608                     stderr = subprocess.PIPE)
609             proc._known_hosts = tmp_known_hosts
610             
611             comm = proc.communicate()
612             proc.wait()
613             return (comm, proc)
614  
615 def popen_ssh_subprocess(python_code, host, port, user, agent, 
616         python_path = None,
617         ident_key = None,
618         server_key = None,
619         tty = False,
620         environment_setup = ""):
621         if python_path:
622             python_path.replace("'", r"'\''")
623             cmd = """PYTHONPATH="$PYTHONPATH":'%s' """ % python_path
624         else:
625             cmd = ""
626         if environment_setup:
627             cmd += environment_setup
628             cmd += " "
629         # Uncomment for debug (to run everything under strace)
630         # We had to verify if strace works (cannot nest them)
631         #cmd += "if strace echo >/dev/null 2>&1; then CMD='strace -ff -tt -s 200 -o strace.out'; else CMD=''; fi\n"
632         #cmd += "$CMD "
633         #if self.mode == MODE_SSH:
634         #    cmd += "strace -f -tt -s 200 -o strace$$.out "
635         cmd += "python -c '"
636         cmd += "import base64, os\n"
637         cmd += "cmd = \"\"\n"
638         cmd += "while True:\n"
639         cmd += " cmd += os.read(0, 1)\n" # one byte from stdin
640         cmd += " if cmd[-1] == \"\\n\": break\n"
641         cmd += "cmd = base64.b64decode(cmd)\n"
642         # Uncomment for debug
643         #cmd += "os.write(2, \"Executing python code: %s\\n\" % cmd)\n"
644         cmd += "os.write(1, \"OK\\n\")\n" # send a sync message
645         cmd += "exec(cmd)\n'"
646
647         tmp_known_hosts = None
648         args = ['ssh',
649                 # Don't bother with localhost. Makes test easier
650                 '-o', 'NoHostAuthenticationForLocalhost=yes',
651                 '-l', user, host]
652         if agent:
653             args.append('-A')
654         if port:
655             args.append('-p%d' % port)
656         if ident_key:
657             args.extend(('-i', ident_key))
658         if tty:
659             args.append('-t')
660         if server_key:
661             # Create a temporary server key file
662             tmp_known_hosts = _make_server_key_args(
663                 server_key, host, port, args)
664         args.append(cmd)
665
666         # connects to the remote host and starts a remote rpyc connection
667         proc = subprocess.Popen(args, 
668                 stdout = subprocess.PIPE,
669                 stdin = subprocess.PIPE, 
670                 stderr = subprocess.PIPE)
671         proc._known_hosts = tmp_known_hosts
672         
673         # send the command to execute
674         os.write(proc.stdin.fileno(),
675                 base64.b64encode(python_code) + "\n")
676         msg = os.read(proc.stdout.fileno(), 3)
677         if msg != "OK\n":
678             raise RuntimeError("Failed to start remote python interpreter")
679         return proc
680