872143dc0919d807c49976d2daee70625612d08b
[nepi.git] / src / neco / util / sshfuncs.py
1 import base64
2 import errno
3 import os
4 import os.path
5 import select
6 import signal
7 import socket
8 import subprocess
9 import time
10 import traceback
11 import re
12 import tempfile
13 import hashlib
14
15 OPENSSH_HAS_PERSIST = None
16 CONTROL_PATH = "yyyyy_ssh_control_path"
17
18 if hasattr(os, "devnull"):
19     DEV_NULL = os.devnull
20 else:
21     DEV_NULL = "/dev/null"
22
23 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
24
25 hostbyname_cache = dict()
26
27 class STDOUT: 
28     """
29     Special value that when given to rspawn in stderr causes stderr to 
30     redirect to whatever stdout was redirected to.
31     """
32
33 class RUNNING:
34     """
35     Process is still running
36     """
37
38 class FINISHED:
39     """
40     Process is finished
41     """
42
43 class NOT_STARTED:
44     """
45     Process hasn't started running yet (this should be very rare)
46     """
47
48 def openssh_has_persist():
49     """ The ssh_config options ControlMaster and ControlPersist allow to
50     reuse a same network connection for multiple ssh sessions. In this 
51     way limitations on number of open ssh connections can be bypassed.
52     However, older versions of openSSH do not support this feature.
53     This function is used to determine if ssh connection persist features
54     can be used.
55     """
56     global OPENSSH_HAS_PERSIST
57     if OPENSSH_HAS_PERSIST is None:
58         proc = subprocess.Popen(["ssh","-v"],
59             stdout = subprocess.PIPE,
60             stderr = subprocess.STDOUT,
61             stdin = open("/dev/null","r") )
62         out,err = proc.communicate()
63         proc.wait()
64         
65         vre = re.compile(r'OpenSSH_(?:[6-9]|5[.][8-9]|5[.][1-9][0-9]|[1-9][0-9]).*', re.I)
66         OPENSSH_HAS_PERSIST = bool(vre.match(out))
67     return OPENSSH_HAS_PERSIST
68
69 def shell_escape(s):
70     """ Escapes strings so that they are safe to use as command-line 
71     arguments """
72     if SHELL_SAFE.match(s):
73         # safe string - no escaping needed
74         return s
75     else:
76         # unsafe string - escape
77         def escp(c):
78             if (32 <= ord(c) < 127 or c in ('\r','\n','\t')) and c not in ("'",'"'):
79                 return c
80             else:
81                 return "'$'\\x%02x''" % (ord(c),)
82         s = ''.join(map(escp,s))
83         return "'%s'" % (s,)
84
85 def eintr_retry(func):
86     """Retries a function invocation when a EINTR occurs"""
87     import functools
88     @functools.wraps(func)
89     def rv(*p, **kw):
90         retry = kw.pop("_retry", False)
91         for i in xrange(0 if retry else 4):
92             try:
93                 return func(*p, **kw)
94             except (select.error, socket.error), args:
95                 if args[0] == errno.EINTR:
96                     continue
97                 else:
98                     raise 
99             except OSError, e:
100                 if e.errno == errno.EINTR:
101                     continue
102                 else:
103                     raise
104         else:
105             return func(*p, **kw)
106     return rv
107
108 def make_connkey(user, host, port):
109     connkey = repr((user,host,port)).encode("base64").strip().replace('/','.')
110     if len(connkey) > 60:
111         connkey = hashlib.sha1(connkey).hexdigest()
112     return connkey
113
114 def rexec(command, host, user, 
115         port = None, 
116         agent = True,
117         sudo = False,
118         stdin = "", 
119         identity_file = None,
120         tty = False,
121         timeout = None,
122         retry = 0,
123         err_on_timeout = True,
124         connect_timeout = 30,
125         persistent = True):
126     """
127     Executes a remote command, returns ((stdout,stderr),process)
128     """
129     connkey = make_connkey(user, host, port)
130     args = ['ssh', '-C',
131             # Don't bother with localhost. Makes test easier
132             '-o', 'NoHostAuthenticationForLocalhost=yes',
133             # XXX: Possible security issue
134             # Avoid interactive requests to accept new host keys
135             '-o', 'StrictHostKeyChecking=no',
136             '-o', 'ConnectTimeout=%d' % (int(connect_timeout),),
137             '-o', 'ConnectionAttempts=3',
138             '-o', 'ServerAliveInterval=30',
139             '-o', 'TCPKeepAlive=yes',
140             '-l', user, host]
141
142     if persistent and openssh_has_persist():
143         args.extend([
144             '-o', 'ControlMaster=auto',
145             '-o', 'ControlPath=/tmp/%s_%s' % ( CONTROL_PATH, connkey, ),
146             '-o', 'ControlPersist=60' ])
147     if agent:
148         args.append('-A')
149     if port:
150         args.append('-p%d' % port)
151     if identity_file:
152         args.extend(('-i', identity_file))
153     if tty:
154         args.append('-t')
155         if sudo:
156             args.append('-t')
157
158     if sudo:
159         command = "sudo " + command
160     args.append(command)
161
162     for x in xrange(retry or 3):
163         # connects to the remote host and starts a remote connection
164         proc = subprocess.Popen(args, 
165                 stdout = subprocess.PIPE,
166                 stdin = subprocess.PIPE, 
167                 stderr = subprocess.PIPE)
168         
169         try:
170             out, err = _communicate(proc, stdin, timeout, err_on_timeout)
171             if proc.poll():
172                 if err.strip().startswith('ssh: ') or err.strip().startswith('mux_client_hello_exchange: '):
173                     # SSH error, can safely retry
174                     continue
175                 elif retry:
176                     # Probably timed out or plain failed but can retry
177                     continue
178             break
179         except RuntimeError,e:
180             if retry <= 0:
181                 raise
182             retry -= 1
183         
184     return ((out, err), proc)
185
186 def rcopy(source, dest, host, user,
187         port = None, 
188         agent = True, 
189         recursive = False,
190         identity_file = None):
191     """
192     Copies file from/to remote sites.
193     
194     Source and destination should have the user and host encoded
195     as per scp specs.
196     
197     If source is a file object, a special mode will be used to
198     create the remote file with the same contents.
199     
200     If dest is a file object, the remote file (source) will be
201     read and written into dest.
202     
203     In these modes, recursive cannot be True.
204     
205     Source can be a list of files to copy to a single destination,
206     in which case it is advised that the destination be a folder.
207     """
208     
209     if isinstance(source, file) and source.tell() == 0:
210         source = source.name
211
212     elif hasattr(source, 'read'):
213         tmp = tempfile.NamedTemporaryFile()
214         while True:
215             buf = source.read(65536)
216             if buf:
217                 tmp.write(buf)
218             else:
219                 break
220         tmp.seek(0)
221         source = tmp.name
222     
223     if isinstance(source, file) or isinstance(dest, file) \
224             or hasattr(source, 'read')  or hasattr(dest, 'write'):
225         assert not recursive
226         
227         connkey = make_connkey(user,host,port)
228         args = ['ssh', '-l', user, '-C',
229                 # Don't bother with localhost. Makes test easier
230                 '-o', 'NoHostAuthenticationForLocalhost=yes',
231                 # XXX: Possible security issue
232                 # Avoid interactive requests to accept new host keys
233                 '-o', 'StrictHostKeyChecking=no',
234                 '-o', 'ConnectTimeout=30',
235                 '-o', 'ConnectionAttempts=3',
236                 '-o', 'ServerAliveInterval=30',
237                 '-o', 'TCPKeepAlive=yes',
238                 host ]
239         if openssh_has_persist():
240             args.extend([
241                 '-o', 'ControlMaster=auto',
242                 '-o', 'ControlPath=/tmp/%s_%s' % ( CONTROL_PATH, connkey, ),
243                 '-o', 'ControlPersist=60' ])
244         if port:
245             args.append('-P%d' % port)
246         if identity_file:
247             args.extend(('-i', identity_file))
248         
249         if isinstance(source, file) or hasattr(source, 'read'):
250             args.append('cat > %s' % dest)
251         elif isinstance(dest, file) or hasattr(dest, 'write'):
252             args.append('cat %s' % dest)
253         else:
254             raise AssertionError, "Unreachable code reached! :-Q"
255         
256         # connects to the remote host and starts a remote connection
257         if isinstance(source, file):
258             proc = subprocess.Popen(args, 
259                     stdout = open('/dev/null','w'),
260                     stderr = subprocess.PIPE,
261                     stdin = source)
262             err = proc.stderr.read()
263             eintr_retry(proc.wait)()
264             return ((None,err), proc)
265         elif isinstance(dest, file):
266             proc = subprocess.Popen(args, 
267                     stdout = open('/dev/null','w'),
268                     stderr = subprocess.PIPE,
269                     stdin = source)
270             err = proc.stderr.read()
271             eintr_retry(proc.wait)()
272             return ((None,err), proc)
273         elif hasattr(source, 'read'):
274             # file-like (but not file) source
275             proc = subprocess.Popen(args, 
276                     stdout = open('/dev/null','w'),
277                     stderr = subprocess.PIPE,
278                     stdin = subprocess.PIPE)
279             
280             buf = None
281             err = []
282             while True:
283                 if not buf:
284                     buf = source.read(4096)
285                 if not buf:
286                     #EOF
287                     break
288                 
289                 rdrdy, wrdy, broken = select.select(
290                     [proc.stderr],
291                     [proc.stdin],
292                     [proc.stderr,proc.stdin])
293                 
294                 if proc.stderr in rdrdy:
295                     # use os.read for fully unbuffered behavior
296                     err.append(os.read(proc.stderr.fileno(), 4096))
297                 
298                 if proc.stdin in wrdy:
299                     proc.stdin.write(buf)
300                     buf = None
301                 
302                 if broken:
303                     break
304             proc.stdin.close()
305             err.append(proc.stderr.read())
306                 
307             eintr_retry(proc.wait)()
308             return ((None,''.join(err)), proc)
309         elif hasattr(dest, 'write'):
310             # file-like (but not file) dest
311             proc = subprocess.Popen(args, 
312                     stdout = subprocess.PIPE,
313                     stderr = subprocess.PIPE,
314                     stdin = open('/dev/null','w'))
315             
316             buf = None
317             err = []
318             while True:
319                 rdrdy, wrdy, broken = select.select(
320                     [proc.stderr, proc.stdout],
321                     [],
322                     [proc.stderr, proc.stdout])
323                 
324                 if proc.stderr in rdrdy:
325                     # use os.read for fully unbuffered behavior
326                     err.append(os.read(proc.stderr.fileno(), 4096))
327                 
328                 if proc.stdout in rdrdy:
329                     # use os.read for fully unbuffered behavior
330                     buf = os.read(proc.stdout.fileno(), 4096)
331                     dest.write(buf)
332                     
333                     if not buf:
334                         #EOF
335                         break
336                 
337                 if broken:
338                     break
339             err.append(proc.stderr.read())
340                 
341             eintr_retry(proc.wait)()
342             return ((None,''.join(err)), proc)
343         else:
344             raise AssertionError, "Unreachable code reached! :-Q"
345     else:
346         # plain scp
347         args = ['scp', '-q', '-p', '-C',
348                 # Don't bother with localhost. Makes test easier
349                 '-o', 'NoHostAuthenticationForLocalhost=yes',
350                 # XXX: Possible security issue
351                 # Avoid interactive requests to accept new host keys
352                 '-o', 'StrictHostKeyChecking=no',
353                 '-o', 'ConnectTimeout=30',
354                 '-o', 'ConnectionAttempts=3',
355                 '-o', 'ServerAliveInterval=30',
356                 '-o', 'TCPKeepAlive=yes' ]
357                 
358         if port:
359             args.append('-P%d' % port)
360         if recursive:
361             args.append('-r')
362         if identity_file:
363             args.extend(('-i', identity_file))
364
365         if isinstance(source,list):
366             args.extend(source)
367         else:
368             if openssh_has_persist():
369                 connkey = make_connkey(user,host,port)
370                 args.extend([
371                     '-o', 'ControlMaster=no',
372                     '-o', 'ControlPath=/tmp/%s_%s' % ( CONTROL_PATH, connkey, )])
373             args.append(source)
374         args.append("%s@%s:%s" %(user, host, dest))
375
376         # connects to the remote host and starts a remote connection
377         proc = subprocess.Popen(args, 
378                 stdout = subprocess.PIPE,
379                 stdin = subprocess.PIPE, 
380                 stderr = subprocess.PIPE)
381         
382         comm = proc.communicate()
383         eintr_retry(proc.wait)()
384         return (comm, proc)
385
386 def rspawn(command, pidfile, 
387         stdout = '/dev/null', 
388         stderr = STDOUT, 
389         stdin = '/dev/null', 
390         home = None, 
391         create_home = False, 
392         host = None, 
393         port = None, 
394         user = None, 
395         agent = None, 
396         sudo = False,
397         identity_file = None, 
398         tty = False):
399     """
400     Spawn a remote command such that it will continue working asynchronously.
401     
402     Parameters:
403         command: the command to run - it should be a single line.
404         
405         pidfile: path of a (ideally unique to this task) pidfile for tracking the process.
406         
407         stdout: path of a file to redirect standard output to - must be a string.
408             Defaults to /dev/null
409         stderr: path of a file to redirect standard error to - string or the special STDOUT value
410             to redirect to the same file stdout was redirected to. Defaults to STDOUT.
411         stdin: path of a file with input to be piped into the command's standard input
412         
413         home: path of a folder to use as working directory - should exist, unless you specify create_home
414         
415         create_home: if True, the home folder will be created first with mkdir -p
416         
417         sudo: whether the command needs to be executed as root
418         
419         host/port/user/agent/identity_file: see rexec
420     
421     Returns:
422         (stdout, stderr), process
423         
424         Of the spawning process, which only captures errors at spawning time.
425         Usually only useful for diagnostics.
426     """
427     # Start process in a "daemonized" way, using nohup and heavy
428     # stdin/out redirection to avoid connection issues
429     if stderr is STDOUT:
430         stderr = '&1'
431     else:
432         stderr = ' ' + stderr
433     
434     daemon_command = '{ { %(command)s  > %(stdout)s 2>%(stderr)s < %(stdin)s & } ; echo $! 1 > %(pidfile)s ; }' % {
435         'command' : command,
436         'pidfile' : pidfile,
437         
438         'stdout' : stdout,
439         'stderr' : stderr,
440         'stdin' : stdin,
441     }
442     
443     cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c '%(command)s' " % {
444             'command' : daemon_command,
445             
446             'sudo' : 'sudo -S' if sudo else '',
447             
448             'pidfile' : pidfile,
449             'gohome' : 'cd %s ; ' % home if home else '',
450             'create' : 'mkdir -p %s ; ' % home if create_home else '',
451         }
452
453     (out,err), proc = rexec(
454         cmd,
455         host = host,
456         port = port,
457         user = user,
458         agent = agent,
459         identity_file = identity_file,
460         tty = tty
461         )
462     
463     if proc.wait():
464         raise RuntimeError, "Failed to set up application: %s %s" % (out,err,)
465
466     return (out,err),proc
467
468 @eintr_retry
469 def rcheck_pid(pidfile,
470         host = None, 
471         port = None, 
472         user = None, 
473         agent = None, 
474         identity_file = None):
475     """
476     Check the pidfile of a process spawned with remote_spawn.
477     
478     Parameters:
479         pidfile: the pidfile passed to remote_span
480         
481         host/port/user/agent/identity_file: see rexec
482     
483     Returns:
484         
485         A (pid, ppid) tuple useful for calling remote_status and remote_kill,
486         or None if the pidfile isn't valid yet (maybe the process is still starting).
487     """
488
489     (out,err),proc = rexec(
490         "cat %(pidfile)s" % {
491             'pidfile' : pidfile,
492         },
493         host = host,
494         port = port,
495         user = user,
496         agent = agent,
497         identity_file = identity_file
498         )
499         
500     if proc.wait():
501         return None
502     
503     if out:
504         try:
505             return map(int,out.strip().split(' ',1))
506         except:
507             # Ignore, many ways to fail that don't matter that much
508             return None
509
510 @eintr_retry
511 def rstatus(pid, ppid, 
512         host = None, 
513         port = None, 
514         user = None, 
515         agent = None, 
516         identity_file = None):
517     """
518     Check the status of a process spawned with remote_spawn.
519     
520     Parameters:
521         pid/ppid: pid and parent-pid of the spawned process. See remote_check_pid
522         
523         host/port/user/agent/identity_file: see rexec
524     
525     Returns:
526         
527         One of NOT_STARTED, RUNNING, FINISHED
528     """
529
530     (out,err),proc = rexec(
531         "ps --pid %(pid)d -o pid | grep -c %(pid)d ; true" % {
532             'ppid' : ppid,
533             'pid' : pid,
534         },
535         host = host,
536         port = port,
537         user = user,
538         agent = agent,
539         identity_file = identity_file
540         )
541     
542     if proc.wait():
543         return NOT_STARTED
544     
545     status = False
546     if out:
547         try:
548             status = bool(int(out.strip()))
549         except:
550             if out or err:
551                 logging.warn("Error checking remote status:\n%s%s\n", out, err)
552             # Ignore, many ways to fail that don't matter that much
553             return NOT_STARTED
554     return RUNNING if status else FINISHED
555     
556
557 @eintr_retry
558 def rkill(pid, ppid, 
559         host = None, 
560         port = None, 
561         user = None, 
562         agent = None, 
563         sudo = False,
564         identity_file = None, 
565         nowait = False):
566     """
567     Kill a process spawned with remote_spawn.
568     
569     First tries a SIGTERM, and if the process does not end in 10 seconds,
570     it sends a SIGKILL.
571     
572     Parameters:
573         pid/ppid: pid and parent-pid of the spawned process. See remote_check_pid
574         
575         sudo: whether the command was run with sudo - careful killing like this.
576         
577         host/port/user/agent/identity_file: see rexec
578     
579     Returns:
580         
581         Nothing, should have killed the process
582     """
583     
584     subkill = "$(ps --ppid %(pid)d -o pid h)" % { 'pid' : pid }
585     cmd = """
586 SUBKILL="%(subkill)s" ;
587 %(sudo)s kill -- -%(pid)d $SUBKILL || /bin/true
588 %(sudo)s kill %(pid)d $SUBKILL || /bin/true
589 for x in 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 ; do 
590     sleep 0.2 
591     if [ `ps --pid %(pid)d -o pid | grep -c %(pid)d` == '0' ]; then
592         break
593     else
594         %(sudo)s kill -- -%(pid)d $SUBKILL || /bin/true
595         %(sudo)s kill %(pid)d $SUBKILL || /bin/true
596     fi
597     sleep 1.8
598 done
599 if [ `ps --pid %(pid)d -o pid | grep -c %(pid)d` != '0' ]; then
600     %(sudo)s kill -9 -- -%(pid)d $SUBKILL || /bin/true
601     %(sudo)s kill -9 %(pid)d $SUBKILL || /bin/true
602 fi
603 """
604     if nowait:
605         cmd = "( %s ) >/dev/null 2>/dev/null </dev/null &" % (cmd,)
606
607     (out,err),proc = rexec(
608         cmd % {
609             'ppid' : ppid,
610             'pid' : pid,
611             'sudo' : 'sudo -S' if sudo else '',
612             'subkill' : subkill,
613         },
614         host = host,
615         port = port,
616         user = user,
617         agent = agent,
618         identity_file = identity_file
619         )
620     
621     # wait, don't leave zombies around
622     proc.wait()
623
624 # POSIX
625 def _communicate(self, input, timeout=None, err_on_timeout=True):
626     read_set = []
627     write_set = []
628     stdout = None # Return
629     stderr = None # Return
630     
631     killed = False
632     
633     if timeout is not None:
634         timelimit = time.time() + timeout
635         killtime = timelimit + 4
636         bailtime = timelimit + 4
637
638     if self.stdin:
639         # Flush stdio buffer.  This might block, if the user has
640         # been writing to .stdin in an uncontrolled fashion.
641         self.stdin.flush()
642         if input:
643             write_set.append(self.stdin)
644         else:
645             self.stdin.close()
646     if self.stdout:
647         read_set.append(self.stdout)
648         stdout = []
649     if self.stderr:
650         read_set.append(self.stderr)
651         stderr = []
652
653     input_offset = 0
654     while read_set or write_set:
655         if timeout is not None:
656             curtime = time.time()
657             if timeout is None or curtime > timelimit:
658                 if curtime > bailtime:
659                     break
660                 elif curtime > killtime:
661                     signum = signal.SIGKILL
662                 else:
663                     signum = signal.SIGTERM
664                 # Lets kill it
665                 os.kill(self.pid, signum)
666                 select_timeout = 0.5
667             else:
668                 select_timeout = timelimit - curtime + 0.1
669         else:
670             select_timeout = 1.0
671         
672         if select_timeout > 1.0:
673             select_timeout = 1.0
674             
675         try:
676             rlist, wlist, xlist = select.select(read_set, write_set, [], select_timeout)
677         except select.error,e:
678             if e[0] != 4:
679                 raise
680             else:
681                 continue
682         
683         if not rlist and not wlist and not xlist and self.poll() is not None:
684             # timeout and process exited, say bye
685             break
686
687         if self.stdin in wlist:
688             # When select has indicated that the file is writable,
689             # we can write up to PIPE_BUF bytes without risk
690             # blocking.  POSIX defines PIPE_BUF >= 512
691             bytes_written = os.write(self.stdin.fileno(), buffer(input, input_offset, 512))
692             input_offset += bytes_written
693             if input_offset >= len(input):
694                 self.stdin.close()
695                 write_set.remove(self.stdin)
696
697         if self.stdout in rlist:
698             data = os.read(self.stdout.fileno(), 1024)
699             if data == "":
700                 self.stdout.close()
701                 read_set.remove(self.stdout)
702             stdout.append(data)
703
704         if self.stderr in rlist:
705             data = os.read(self.stderr.fileno(), 1024)
706             if data == "":
707                 self.stderr.close()
708                 read_set.remove(self.stderr)
709             stderr.append(data)
710     
711     # All data exchanged.  Translate lists into strings.
712     if stdout is not None:
713         stdout = ''.join(stdout)
714     if stderr is not None:
715         stderr = ''.join(stderr)
716
717     # Translate newlines, if requested.  We cannot let the file
718     # object do the translation: It is based on stdio, which is
719     # impossible to combine with select (unless forcing no
720     # buffering).
721     if self.universal_newlines and hasattr(file, 'newlines'):
722         if stdout:
723             stdout = self._translate_newlines(stdout)
724         if stderr:
725             stderr = self._translate_newlines(stderr)
726
727     if killed and err_on_timeout:
728         errcode = self.poll()
729         raise RuntimeError, ("Operation timed out", errcode, stdout, stderr)
730     else:
731         if killed:
732             self.poll()
733         else:
734             self.wait()
735         return (stdout, stderr)
736