Add NEPI_STRICT_AUTH_MODE, when not enabled, it takes user-configured host SSH keys.
[nepi.git] / src / nepi / util / server.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import base64
5 import errno
6 import os
7 import resource
8 import select
9 import socket
10 import sys
11 import subprocess
12 import threading
13 import time
14 import traceback
15 import signal
16 import re
17 import tempfile
18
19 CTRL_SOCK = "ctrl.sock"
20 STD_ERR = "stderr.log"
21 MAX_FD = 1024
22
23 STOP_MSG = "STOP"
24
25 ERROR_LEVEL = 0
26 DEBUG_LEVEL = 1
27 TRACE = False
28
29 if hasattr(os, "devnull"):
30     DEV_NULL = os.devnull
31 else:
32     DEV_NULL = "/dev/null"
33
34
35
36 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
37
38 def shell_escape(s):
39     """ Escapes strings so that they are safe to use as command-line arguments """
40     if SHELL_SAFE.match(s):
41         # safe string - no escaping needed
42         return s
43     else:
44         # unsafe string - escape
45         s = s.replace("'","\\'")
46         return "'%s'" % (s,)
47
48 class Server(object):
49     def __init__(self, root_dir = ".", log_level = ERROR_LEVEL):
50         self._root_dir = root_dir
51         self._stop = False
52         self._ctrl_sock = None
53         self._log_level = log_level
54
55     def run(self):
56         try:
57             if self.daemonize():
58                 self.post_daemonize()
59                 self.loop()
60                 self.cleanup()
61                 # ref: "os._exit(0)"
62                 # can not return normally after fork beacuse no exec was done.
63                 # This means that if we don't do a os._exit(0) here the code that 
64                 # follows the call to "Server.run()" in the "caller code" will be 
65                 # executed... but by now it has already been executed after the 
66                 # first process (the one that did the first fork) returned.
67                 os._exit(0)
68         except:
69             self.log_error()
70             self.cleanup()
71             os._exit(0)
72
73     def daemonize(self):
74         # pipes for process synchronization
75         (r, w) = os.pipe()
76
77         pid1 = os.fork()
78         if pid1 > 0:
79             os.close(w)
80             os.read(r, 1)
81             os.close(r)
82             # os.waitpid avoids leaving a <defunc> (zombie) process
83             st = os.waitpid(pid1, 0)[1]
84             if st:
85                 raise RuntimeError("Daemonization failed")
86             # return 0 to inform the caller method that this is not the 
87             # daemonized process
88             return 0
89         os.close(r)
90
91         # Decouple from parent environment.
92         os.chdir(self._root_dir)
93         os.umask(0)
94         os.setsid()
95
96         # fork 2
97         pid2 = os.fork()
98         if pid2 > 0:
99             # see ref: "os._exit(0)"
100             os._exit(0)
101
102         # close all open file descriptors.
103         max_fd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
104         if (max_fd == resource.RLIM_INFINITY):
105             max_fd = MAX_FD
106         for fd in range(3, max_fd):
107             if fd != w:
108                 try:
109                     os.close(fd)
110                 except OSError:
111                     pass
112
113         # Redirect standard file descriptors.
114         stdin = open(DEV_NULL, "r")
115         stderr = stdout = open(STD_ERR, "a", 0)
116         os.dup2(stdin.fileno(), sys.stdin.fileno())
117         # NOTE: sys.stdout.write will still be buffered, even if the file
118         # was opened with 0 buffer
119         os.dup2(stdout.fileno(), sys.stdout.fileno())
120         os.dup2(stderr.fileno(), sys.stderr.fileno())
121
122         # create control socket
123         self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
124         self._ctrl_sock.bind(CTRL_SOCK)
125         self._ctrl_sock.listen(0)
126
127         # let the parent process know that the daemonization is finished
128         os.write(w, "\n")
129         os.close(w)
130         return 1
131
132     def post_daemonize(self):
133         pass
134
135     def loop(self):
136         while not self._stop:
137             conn, addr = self._ctrl_sock.accept()
138             conn.settimeout(5)
139             while not self._stop:
140                 try:
141                     msg = self.recv_msg(conn)
142                 except socket.timeout, e:
143                     break
144                     
145                 if msg == STOP_MSG:
146                     self._stop = True
147                     reply = self.stop_action()
148                 else:
149                     reply = self.reply_action(msg)
150                 
151                 try:
152                     self.send_reply(conn, reply)
153                 except socket.error:
154                     self.log_error()
155                     self.log_error("NOTICE: Awaiting for reconnection")
156                     break
157             try:
158                 conn.close()
159             except:
160                 # Doesn't matter
161                 self.log_error()
162
163     def recv_msg(self, conn):
164         data = ""
165         while True:
166             try:
167                 chunk = conn.recv(1024)
168             except OSError, e:
169                 if e.errno != errno.EINTR:
170                     raise
171                 if chunk == '':
172                     continue
173             if chunk:
174                 data += chunk
175                 if chunk[-1] == "\n":
176                     break
177             else:
178                 # empty chunk = EOF
179                 break
180         decoded = base64.b64decode(data)
181         return decoded.rstrip()
182
183     def send_reply(self, conn, reply):
184         encoded = base64.b64encode(reply)
185         conn.send("%s\n" % encoded)
186        
187     def cleanup(self):
188         try:
189             self._ctrl_sock.close()
190             os.remove(CTRL_SOCK)
191         except:
192             self.log_error()
193
194     def stop_action(self):
195         return "Stopping server"
196
197     def reply_action(self, msg):
198         return "Reply to: %s" % msg
199
200     def log_error(self, text = None, context = ''):
201         if text == None:
202             text = traceback.format_exc()
203         date = time.strftime("%Y-%m-%d %H:%M:%S")
204         if context:
205             context = " (%s)" % (context,)
206         sys.stderr.write("ERROR%s: %s\n%s\n" % (context, date, text))
207         return text
208
209     def log_debug(self, text):
210         if self._log_level == DEBUG_LEVEL:
211             date = time.strftime("%Y-%m-%d %H:%M:%S")
212             sys.stderr.write("DEBUG: %s\n%s\n" % (date, text))
213
214 class Forwarder(object):
215     def __init__(self, root_dir = "."):
216         self._ctrl_sock = None
217         self._root_dir = root_dir
218         self._stop = False
219
220     def forward(self):
221         self.connect()
222         print >>sys.stderr, "READY."
223         while not self._stop:
224             data = self.read_data()
225             self.send_to_server(data)
226             data = self.recv_from_server()
227             self.write_data(data)
228         self.disconnect()
229
230     def read_data(self):
231         return sys.stdin.readline()
232
233     def write_data(self, data):
234         sys.stdout.write(data)
235         # sys.stdout.write is buffered, this is why we need to do a flush()
236         sys.stdout.flush()
237
238     def send_to_server(self, data):
239         try:
240             self._ctrl_sock.send(data)
241         except IOError, e:
242             if e.errno == errno.EPIPE:
243                 self.connect()
244                 self._ctrl_sock.send(data)
245             else:
246                 raise e
247         encoded = data.rstrip() 
248         msg = base64.b64decode(encoded)
249         if msg == STOP_MSG:
250             self._stop = True
251
252     def recv_from_server(self):
253         data = ""
254         while True:
255             try:
256                 chunk = self._ctrl_sock.recv(1024)
257             except OSError, e:
258                 if e.errno != errno.EINTR:
259                     raise
260                 if chunk == '':
261                     continue
262             data += chunk
263             if chunk[-1] == "\n":
264                 break
265         return data
266  
267     def connect(self):
268         self.disconnect()
269         self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
270         sock_addr = os.path.join(self._root_dir, CTRL_SOCK)
271         self._ctrl_sock.connect(sock_addr)
272
273     def disconnect(self):
274         try:
275             self._ctrl_sock.close()
276         except:
277             pass
278
279 class Client(object):
280     def __init__(self, root_dir = ".", host = None, port = None, user = None, 
281             agent = None):
282         self.root_dir = root_dir
283         self.addr = (host, port)
284         self.user = user
285         self.agent = agent
286         self._stopped = False
287         self.connect()
288     
289     def __del__(self):
290         if self._process.poll() is None:
291             os.kill(self._process.pid, signal.SIGTERM)
292         self._process.wait()
293         
294     def connect(self):
295         root_dir = self.root_dir
296         (host, port) = self.addr
297         user = self.user
298         agent = self.agent
299         
300         python_code = "from nepi.util import server;c=server.Forwarder(%r);\
301                 c.forward()" % (root_dir,)
302         if host != None:
303             self._process = popen_ssh_subprocess(python_code, host, port, 
304                     user, agent)
305             # popen_ssh_subprocess already waits for readiness
306         else:
307             self._process = subprocess.Popen(
308                     ["python", "-c", python_code],
309                     stdin = subprocess.PIPE, 
310                     stdout = subprocess.PIPE,
311                     stderr = subprocess.PIPE
312                 )
313                 
314         # Wait for the forwarder to be ready, otherwise nobody
315         # will be able to connect to it
316         helo = self._process.stderr.readline()
317         if helo != 'READY.\n':
318             raise AssertionError, "Expected 'Ready.', got %r: %s" % (helo,
319                     helo + self._process.stderr.read())
320         
321     def send_msg(self, msg):
322         encoded = base64.b64encode(msg)
323         data = "%s\n" % encoded
324         
325         try:
326             self._process.stdin.write(data)
327         except (IOError, ValueError):
328             # dead process, poll it to un-zombify
329             self._process.poll()
330             
331             # try again after reconnect
332             # If it fails again, though, give up
333             self.connect()
334             self._process.stdin.write(data)
335
336     def send_stop(self):
337         self.send_msg(STOP_MSG)
338         self._stopped = True
339
340     def read_reply(self):
341         data = self._process.stdout.readline()
342         encoded = data.rstrip() 
343         return base64.b64decode(encoded)
344
345 def _make_server_key_args(server_key, host, port, args):
346     """ 
347     Returns a reference to the created temporary file, and adds the
348     corresponding arguments to the given argument list.
349     
350     Make sure to hold onto it until the process is done with the file
351     """
352     if port is not None:
353         host = '%s:%s' % (host,port)
354     # Create a temporary server key file
355     tmp_known_hosts = tempfile.NamedTemporaryFile()
356     
357     # Add the intended host key
358     tmp_known_hosts.write('%s,%s %s\n' % (host, socket.gethostbyname(host), server_key))
359     
360     # If we're not in strict mode, add user-configured keys
361     if os.environ.get('NEPI_STRICT_AUTH_MODE',"").lower() not in ('1','true','on'):
362         user_hosts_path = '%s/.ssh/known_hosts' % (os.environ.get('HOME',""),)
363         if os.access(user_hosts_path, os.R_OK):
364             f = open(user_hosts_path, "r")
365             tmp_known_hosts.write(f.read())
366             f.close()
367         
368     tmp_known_hosts.flush()
369     
370     args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
371     return tmp_known_hosts
372
373 def popen_ssh_command(command, host, port, user, agent, 
374             stdin="", 
375             ident_key = None,
376             server_key = None,
377             tty = False):
378         """
379         Executes a remote commands, returns ((stdout,stderr),process)
380         """
381         if TRACE:
382             print "ssh", host, command
383         
384         tmp_known_hosts = None
385         args = ['ssh',
386                 # Don't bother with localhost. Makes test easier
387                 '-o', 'NoHostAuthenticationForLocalhost=yes',
388                 '-l', user, host]
389         if agent:
390             args.append('-A')
391         if port:
392             args.append('-p%d' % port)
393         if ident_key:
394             args.extend(('-i', ident_key))
395         if tty:
396             args.append('-t')
397         if server_key:
398             # Create a temporary server key file
399             tmp_known_hosts = _make_server_key_args(
400                 server_key, host, port, args)
401         args.append(command)
402
403         # connects to the remote host and starts a remote connection
404         proc = subprocess.Popen(args, 
405                 stdout = subprocess.PIPE,
406                 stdin = subprocess.PIPE, 
407                 stderr = subprocess.PIPE)
408         
409         # attach tempfile object to the process, to make sure the file stays
410         # alive until the process is finished with it
411         proc._known_hosts = tmp_known_hosts
412         
413         return (proc.communicate(stdin), proc)
414  
415 def popen_scp(source, dest, 
416             port = None, 
417             agent = None, 
418             recursive = False,
419             ident_key = None,
420             server_key = None):
421         """
422         Copies from/to remote sites.
423         
424         Source and destination should have the user and host encoded
425         as per scp specs.
426         
427         If source is a file object, a special mode will be used to
428         create the remote file with the same contents.
429         
430         If dest is a file object, the remote file (source) will be
431         read and written into dest.
432         
433         In these modes, recursive cannot be True.
434         
435         Source can be a list of files to copy to a single destination,
436         in which case it is advised that the destination be a folder.
437         """
438         
439         if TRACE:
440             print "scp", source, dest
441         
442         if isinstance(source, file) or isinstance(dest, file) \
443                 or hasattr(source, 'read')  or hasattr(dest, 'write'):
444             assert not recursive
445             
446             # Parse source/destination as <user>@<server>:<path>
447             if isinstance(dest, basestring) and ':' in dest:
448                 remspec, path = dest.split(':',1)
449             elif isinstance(source, basestring) and ':' in source:
450                 remspec, path = source.split(':',1)
451             else:
452                 raise ValueError, "Both endpoints cannot be local"
453             user,host = remspec.rsplit('@',1)
454             tmp_known_hosts = None
455             
456             args = ['ssh', '-l', user, '-C',
457                     # Don't bother with localhost. Makes test easier
458                     '-o', 'NoHostAuthenticationForLocalhost=yes',
459                     host ]
460             if port:
461                 args.append('-P%d' % port)
462             if ident_key:
463                 args.extend(('-i', ident_key))
464             if server_key:
465                 # Create a temporary server key file
466                 tmp_known_hosts = _make_server_key_args(
467                     server_key, host, port, args)
468             
469             if isinstance(source, file) or hasattr(source, 'read'):
470                 args.append('cat > %s' % (shell_escape(path),))
471             elif isinstance(dest, file) or hasattr(dest, 'write'):
472                 args.append('cat %s' % (shell_escape(path),))
473             else:
474                 raise AssertionError, "Unreachable code reached! :-Q"
475             
476             # connects to the remote host and starts a remote connection
477             if isinstance(source, file):
478                 proc = subprocess.Popen(args, 
479                         stdout = open('/dev/null','w'),
480                         stderr = subprocess.PIPE,
481                         stdin = source)
482                 err = proc.stderr.read()
483                 proc._known_hosts = tmp_known_hosts
484                 proc.wait()
485                 return ((None,err), proc)
486             elif isinstance(dest, file):
487                 proc = subprocess.Popen(args, 
488                         stdout = open('/dev/null','w'),
489                         stderr = subprocess.PIPE,
490                         stdin = source)
491                 err = proc.stderr.read()
492                 proc._known_hosts = tmp_known_hosts
493                 proc.wait()
494                 return ((None,err), proc)
495             elif hasattr(source, 'read'):
496                 # file-like (but not file) source
497                 proc = subprocess.Popen(args, 
498                         stdout = open('/dev/null','w'),
499                         stderr = subprocess.PIPE,
500                         stdin = subprocess.PIPE)
501                 
502                 buf = None
503                 err = []
504                 while True:
505                     if not buf:
506                         buf = source.read(4096)
507                     if not buf:
508                         #EOF
509                         break
510                     
511                     rdrdy, wrdy, broken = select.select(
512                         [proc.stderr],
513                         [proc.stdin],
514                         [proc.stderr,proc.stdin])
515                     
516                     if proc.stderr in rdrdy:
517                         # use os.read for fully unbuffered behavior
518                         err.append(os.read(proc.stderr.fileno(), 4096))
519                     
520                     if proc.stdin in wrdy:
521                         proc.stdin.write(buf)
522                         buf = None
523                     
524                     if broken:
525                         break
526                 proc.stdin.close()
527                 err.append(proc.stderr.read())
528                     
529                 proc._known_hosts = tmp_known_hosts
530                 proc.wait()
531                 return ((None,''.join(err)), proc)
532             elif hasattr(dest, 'write'):
533                 # file-like (but not file) dest
534                 proc = subprocess.Popen(args, 
535                         stdout = subprocess.PIPE,
536                         stderr = subprocess.PIPE,
537                         stdin = open('/dev/null','w'))
538                 
539                 buf = None
540                 err = []
541                 while True:
542                     rdrdy, wrdy, broken = select.select(
543                         [proc.stderr, proc.stdout],
544                         [],
545                         [proc.stderr, proc.stdout])
546                     
547                     if proc.stderr in rdrdy:
548                         # use os.read for fully unbuffered behavior
549                         err.append(os.read(proc.stderr.fileno(), 4096))
550                     
551                     if proc.stdout in rdrdy:
552                         # use os.read for fully unbuffered behavior
553                         buf = os.read(proc.stdout.fileno(), 4096)
554                         dest.write(buf)
555                         
556                         if not buf:
557                             #EOF
558                             break
559                     
560                     if broken:
561                         break
562                 err.append(proc.stderr.read())
563                     
564                 proc._known_hosts = tmp_known_hosts
565                 proc.wait()
566                 return ((None,''.join(err)), proc)
567             else:
568                 raise AssertionError, "Unreachable code reached! :-Q"
569         else:
570             # Parse destination as <user>@<server>:<path>
571             if isinstance(dest, basestring) and ':' in dest:
572                 remspec, path = dest.split(':',1)
573             elif isinstance(source, basestring) and ':' in source:
574                 remspec, path = source.split(':',1)
575             else:
576                 raise ValueError, "Both endpoints cannot be local"
577             user,host = remspec.rsplit('@',1)
578             
579             # plain scp
580             tmp_known_hosts = None
581             args = ['scp', '-q', '-p', '-C',
582                     # Don't bother with localhost. Makes test easier
583                     '-o', 'NoHostAuthenticationForLocalhost=yes' ]
584             if port:
585                 args.append('-P%d' % port)
586             if recursive:
587                 args.append('-r')
588             if ident_key:
589                 args.extend(('-i', ident_key))
590             if server_key:
591                 # Create a temporary server key file
592                 tmp_known_hosts = _make_server_key_args(
593                     server_key, host, port, args)
594             if isinstance(source,list):
595                 args.extend(source)
596             else:
597                 args.append(source)
598             args.append(dest)
599
600             # connects to the remote host and starts a remote connection
601             proc = subprocess.Popen(args, 
602                     stdout = subprocess.PIPE,
603                     stdin = subprocess.PIPE, 
604                     stderr = subprocess.PIPE)
605             proc._known_hosts = tmp_known_hosts
606             
607             comm = proc.communicate()
608             proc.wait()
609             return (comm, proc)
610  
611 def popen_ssh_subprocess(python_code, host, port, user, agent, 
612         python_path = None,
613         ident_key = None,
614         server_key = None,
615         tty = False):
616         if python_path:
617             python_path.replace("'", r"'\''")
618             cmd = """PYTHONPATH="$PYTHONPATH":'%s' """ % python_path
619         else:
620             cmd = ""
621         # Uncomment for debug (to run everything under strace)
622         # We had to verify if strace works (cannot nest them)
623         #cmd += "if strace echo >/dev/null 2>&1; then CMD='strace -ff -tt -s 200 -o strace.out'; else CMD=''; fi\n"
624         #cmd += "$CMD "
625         #if self.mode == MODE_SSH:
626         #    cmd += "strace -f -tt -s 200 -o strace$$.out "
627         cmd += "python -c '"
628         cmd += "import base64, os\n"
629         cmd += "cmd = \"\"\n"
630         cmd += "while True:\n"
631         cmd += " cmd += os.read(0, 1)\n" # one byte from stdin
632         cmd += " if cmd[-1] == \"\\n\": break\n"
633         cmd += "cmd = base64.b64decode(cmd)\n"
634         # Uncomment for debug
635         #cmd += "os.write(2, \"Executing python code: %s\\n\" % cmd)\n"
636         cmd += "os.write(1, \"OK\\n\")\n" # send a sync message
637         cmd += "exec(cmd)\n'"
638
639         tmp_known_hosts = None
640         args = ['ssh',
641                 # Don't bother with localhost. Makes test easier
642                 '-o', 'NoHostAuthenticationForLocalhost=yes',
643                 '-l', user, host]
644         if agent:
645             args.append('-A')
646         if port:
647             args.append('-p%d' % port)
648         if ident_key:
649             args.extend(('-i', ident_key))
650         if tty:
651             args.append('-t')
652         if server_key:
653             # Create a temporary server key file
654             tmp_known_hosts = _make_server_key_args(
655                 server_key, host, port, args)
656         args.append(cmd)
657
658         # connects to the remote host and starts a remote rpyc connection
659         proc = subprocess.Popen(args, 
660                 stdout = subprocess.PIPE,
661                 stdin = subprocess.PIPE, 
662                 stderr = subprocess.PIPE)
663         proc._known_hosts = tmp_known_hosts
664         
665         # send the command to execute
666         os.write(proc.stdin.fileno(),
667                 base64.b64encode(python_code) + "\n")
668         msg = os.read(proc.stdout.fileno(), 3)
669         if msg != "OK\n":
670             raise RuntimeError("Failed to start remote python interpreter")
671         return proc
672