Adding remote command execution for linux boxes
[nepi.git] / src / neco / util / sshfuncs.py
1 import base64
2 import errno
3 import os
4 import os.path
5 import select
6 import signal
7 import socket
8 import subprocess
9 import time
10 import traceback
11 import re
12 import tempfile
13 import hashlib
14
15 OPENSSH_HAS_PERSIST = None
16 CONTROL_PATH = "yyyyy_ssh_control_path"
17
18 if hasattr(os, "devnull"):
19     DEV_NULL = os.devnull
20 else:
21     DEV_NULL = "/dev/null"
22
23 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
24
25 hostbyname_cache = dict()
26
27 class STDOUT: 
28     """
29     Special value that when given to rspawn in stderr causes stderr to 
30     redirect to whatever stdout was redirected to.
31     """
32
33 class RUNNING:
34     """
35     Process is still running
36     """
37
38 class FINISHED:
39     """
40     Process is finished
41     """
42
43 class NOT_STARTED:
44     """
45     Process hasn't started running yet (this should be very rare)
46     """
47
48 def openssh_has_persist():
49     """ The ssh_config options ControlMaster and ControlPersist allow to
50     reuse a same network connection for multiple ssh sessions. In this 
51     way limitations on number of open ssh connections can be bypassed.
52     However, older versions of openSSH do not support this feature.
53     This function is used to determine if ssh connection persist features
54     can be used.
55     """
56     global OPENSSH_HAS_PERSIST
57     if OPENSSH_HAS_PERSIST is None:
58         proc = subprocess.Popen(["ssh","-v"],
59             stdout = subprocess.PIPE,
60             stderr = subprocess.STDOUT,
61             stdin = open("/dev/null","r") )
62         out,err = proc.communicate()
63         proc.wait()
64         
65         vre = re.compile(r'OpenSSH_(?:[6-9]|5[.][8-9]|5[.][1-9][0-9]|[1-9][0-9]).*', re.I)
66         OPENSSH_HAS_PERSIST = bool(vre.match(out))
67     return OPENSSH_HAS_PERSIST
68
69 def shell_escape(s):
70     """ Escapes strings so that they are safe to use as command-line 
71     arguments """
72     if SHELL_SAFE.match(s):
73         # safe string - no escaping needed
74         return s
75     else:
76         # unsafe string - escape
77         def escp(c):
78             if (32 <= ord(c) < 127 or c in ('\r','\n','\t')) and c not in ("'",'"'):
79                 return c
80             else:
81                 return "'$'\\x%02x''" % (ord(c),)
82         s = ''.join(map(escp,s))
83         return "'%s'" % (s,)
84
85 def eintr_retry(func):
86     """Retries a function invocation when a EINTR occurs"""
87     import functools
88     @functools.wraps(func)
89     def rv(*p, **kw):
90         retry = kw.pop("_retry", False)
91         for i in xrange(0 if retry else 4):
92             try:
93                 return func(*p, **kw)
94             except (select.error, socket.error), args:
95                 if args[0] == errno.EINTR:
96                     continue
97                 else:
98                     raise 
99             except OSError, e:
100                 if e.errno == errno.EINTR:
101                     continue
102                 else:
103                     raise
104         else:
105             return func(*p, **kw)
106     return rv
107
108 def make_connkey(user, host, port):
109     connkey = repr((user,host,port)).encode("base64").strip().replace('/','.')
110     if len(connkey) > 60:
111         connkey = hashlib.sha1(connkey).hexdigest()
112     return connkey
113
114 def rexec(command, host, user, 
115         port = None, 
116         agent = True,
117         sudo = False,
118         stdin = "", 
119         identity_file = None,
120         tty = False,
121         tty2 = False,
122         timeout = None,
123         retry = 0,
124         err_on_timeout = True,
125         connect_timeout = 30,
126         persistent = True):
127     """
128     Executes a remote command, returns ((stdout,stderr),process)
129     """
130     connkey = make_connkey(user, host, port)
131     args = ['ssh', '-C',
132             # Don't bother with localhost. Makes test easier
133             '-o', 'NoHostAuthenticationForLocalhost=yes',
134             # XXX: Possible security issue
135             # Avoid interactive requests to accept new host keys
136             '-o', 'StrictHostKeyChecking=no',
137             '-o', 'ConnectTimeout=%d' % (int(connect_timeout),),
138             '-o', 'ConnectionAttempts=3',
139             '-o', 'ServerAliveInterval=30',
140             '-o', 'TCPKeepAlive=yes',
141             '-l', user, host]
142
143     if persistent and openssh_has_persist():
144         args.extend([
145             '-o', 'ControlMaster=auto',
146             '-o', 'ControlPath=/tmp/%s_%s' % ( CONTROL_PATH, connkey, ),
147             '-o', 'ControlPersist=60' ])
148     if agent:
149         args.append('-A')
150     if port:
151         args.append('-p%d' % port)
152     if identity_file:
153         args.extend(('-i', identity_file))
154     if tty:
155         args.append('-t')
156     elif tty2:
157         args.append('-t')
158         args.append('-t')
159     if sudo:
160         command = "sudo " + command
161     args.append(command)
162
163     for x in xrange(retry or 3):
164         # connects to the remote host and starts a remote connection
165         proc = subprocess.Popen(args, 
166                 stdout = subprocess.PIPE,
167                 stdin = subprocess.PIPE, 
168                 stderr = subprocess.PIPE)
169         
170         try:
171             out, err = _communicate(proc, stdin, timeout, err_on_timeout)
172             if proc.poll():
173                 if err.strip().startswith('ssh: ') or err.strip().startswith('mux_client_hello_exchange: '):
174                     # SSH error, can safely retry
175                     continue
176                 elif retry:
177                     # Probably timed out or plain failed but can retry
178                     continue
179             break
180         except RuntimeError,e:
181             if retry <= 0:
182                 raise
183             retry -= 1
184         
185     return ((out, err), proc)
186
187 def rcopy(source, dest, host, user,
188         port = None, 
189         agent = True, 
190         recursive = False,
191         identity_file = None):
192     """
193     Copies file from/to remote sites.
194     
195     Source and destination should have the user and host encoded
196     as per scp specs.
197     
198     If source is a file object, a special mode will be used to
199     create the remote file with the same contents.
200     
201     If dest is a file object, the remote file (source) will be
202     read and written into dest.
203     
204     In these modes, recursive cannot be True.
205     
206     Source can be a list of files to copy to a single destination,
207     in which case it is advised that the destination be a folder.
208     """
209     
210     if isinstance(source, file) and source.tell() == 0:
211         source = source.name
212
213     elif hasattr(source, 'read'):
214         tmp = tempfile.NamedTemporaryFile()
215         while True:
216             buf = source.read(65536)
217             if buf:
218                 tmp.write(buf)
219             else:
220                 break
221         tmp.seek(0)
222         source = tmp.name
223     
224     if isinstance(source, file) or isinstance(dest, file) \
225             or hasattr(source, 'read')  or hasattr(dest, 'write'):
226         assert not recursive
227         
228         connkey = make_connkey(user,host,port)
229         args = ['ssh', '-l', user, '-C',
230                 # Don't bother with localhost. Makes test easier
231                 '-o', 'NoHostAuthenticationForLocalhost=yes',
232                 # XXX: Possible security issue
233                 # Avoid interactive requests to accept new host keys
234                 '-o', 'StrictHostKeyChecking=no',
235                 '-o', 'ConnectTimeout=30',
236                 '-o', 'ConnectionAttempts=3',
237                 '-o', 'ServerAliveInterval=30',
238                 '-o', 'TCPKeepAlive=yes',
239                 host ]
240         if openssh_has_persist():
241             args.extend([
242                 '-o', 'ControlMaster=auto',
243                 '-o', 'ControlPath=/tmp/%s_%s' % ( CONTROL_PATH, connkey, ),
244                 '-o', 'ControlPersist=60' ])
245         if port:
246             args.append('-P%d' % port)
247         if identity_file:
248             args.extend(('-i', identity_file))
249         
250         if isinstance(source, file) or hasattr(source, 'read'):
251             args.append('cat > %s' % dest)
252         elif isinstance(dest, file) or hasattr(dest, 'write'):
253             args.append('cat %s' % dest)
254         else:
255             raise AssertionError, "Unreachable code reached! :-Q"
256         
257         # connects to the remote host and starts a remote connection
258         if isinstance(source, file):
259             proc = subprocess.Popen(args, 
260                     stdout = open('/dev/null','w'),
261                     stderr = subprocess.PIPE,
262                     stdin = source)
263             err = proc.stderr.read()
264             eintr_retry(proc.wait)()
265             return ((None,err), proc)
266         elif isinstance(dest, file):
267             proc = subprocess.Popen(args, 
268                     stdout = open('/dev/null','w'),
269                     stderr = subprocess.PIPE,
270                     stdin = source)
271             err = proc.stderr.read()
272             eintr_retry(proc.wait)()
273             return ((None,err), proc)
274         elif hasattr(source, 'read'):
275             # file-like (but not file) source
276             proc = subprocess.Popen(args, 
277                     stdout = open('/dev/null','w'),
278                     stderr = subprocess.PIPE,
279                     stdin = subprocess.PIPE)
280             
281             buf = None
282             err = []
283             while True:
284                 if not buf:
285                     buf = source.read(4096)
286                 if not buf:
287                     #EOF
288                     break
289                 
290                 rdrdy, wrdy, broken = select.select(
291                     [proc.stderr],
292                     [proc.stdin],
293                     [proc.stderr,proc.stdin])
294                 
295                 if proc.stderr in rdrdy:
296                     # use os.read for fully unbuffered behavior
297                     err.append(os.read(proc.stderr.fileno(), 4096))
298                 
299                 if proc.stdin in wrdy:
300                     proc.stdin.write(buf)
301                     buf = None
302                 
303                 if broken:
304                     break
305             proc.stdin.close()
306             err.append(proc.stderr.read())
307                 
308             eintr_retry(proc.wait)()
309             return ((None,''.join(err)), proc)
310         elif hasattr(dest, 'write'):
311             # file-like (but not file) dest
312             proc = subprocess.Popen(args, 
313                     stdout = subprocess.PIPE,
314                     stderr = subprocess.PIPE,
315                     stdin = open('/dev/null','w'))
316             
317             buf = None
318             err = []
319             while True:
320                 rdrdy, wrdy, broken = select.select(
321                     [proc.stderr, proc.stdout],
322                     [],
323                     [proc.stderr, proc.stdout])
324                 
325                 if proc.stderr in rdrdy:
326                     # use os.read for fully unbuffered behavior
327                     err.append(os.read(proc.stderr.fileno(), 4096))
328                 
329                 if proc.stdout in rdrdy:
330                     # use os.read for fully unbuffered behavior
331                     buf = os.read(proc.stdout.fileno(), 4096)
332                     dest.write(buf)
333                     
334                     if not buf:
335                         #EOF
336                         break
337                 
338                 if broken:
339                     break
340             err.append(proc.stderr.read())
341                 
342             eintr_retry(proc.wait)()
343             return ((None,''.join(err)), proc)
344         else:
345             raise AssertionError, "Unreachable code reached! :-Q"
346     else:
347         # plain scp
348         args = ['scp', '-q', '-p', '-C',
349                 # Don't bother with localhost. Makes test easier
350                 '-o', 'NoHostAuthenticationForLocalhost=yes',
351                 # XXX: Possible security issue
352                 # Avoid interactive requests to accept new host keys
353                 '-o', 'StrictHostKeyChecking=no',
354                 '-o', 'ConnectTimeout=30',
355                 '-o', 'ConnectionAttempts=3',
356                 '-o', 'ServerAliveInterval=30',
357                 '-o', 'TCPKeepAlive=yes' ]
358                 
359         if port:
360             args.append('-P%d' % port)
361         if recursive:
362             args.append('-r')
363         if identity_file:
364             args.extend(('-i', identity_file))
365
366         if isinstance(source,list):
367             args.extend(source)
368         else:
369             if openssh_has_persist():
370                 connkey = make_connkey(user,host,port)
371                 args.extend([
372                     '-o', 'ControlMaster=no',
373                     '-o', 'ControlPath=/tmp/%s_%s' % ( CONTROL_PATH, connkey, )])
374             args.append(source)
375         args.append("%s@%s:%s" %(user, host, dest))
376
377         # connects to the remote host and starts a remote connection
378         proc = subprocess.Popen(args, 
379                 stdout = subprocess.PIPE,
380                 stdin = subprocess.PIPE, 
381                 stderr = subprocess.PIPE)
382         
383         comm = proc.communicate()
384         eintr_retry(proc.wait)()
385         return (comm, proc)
386
387 def rspawn(command, pidfile, 
388         stdout = '/dev/null', 
389         stderr = STDOUT, 
390         stdin = '/dev/null', 
391         home = None, 
392         create_home = False, 
393         host = None, 
394         port = None, 
395         user = None, 
396         agent = None, 
397         sudo = False,
398         identity_file = None, 
399         tty = False):
400     """
401     Spawn a remote command such that it will continue working asynchronously.
402     
403     Parameters:
404         command: the command to run - it should be a single line.
405         
406         pidfile: path of a (ideally unique to this task) pidfile for tracking the process.
407         
408         stdout: path of a file to redirect standard output to - must be a string.
409             Defaults to /dev/null
410         stderr: path of a file to redirect standard error to - string or the special STDOUT value
411             to redirect to the same file stdout was redirected to. Defaults to STDOUT.
412         stdin: path of a file with input to be piped into the command's standard input
413         
414         home: path of a folder to use as working directory - should exist, unless you specify create_home
415         
416         create_home: if True, the home folder will be created first with mkdir -p
417         
418         sudo: whether the command needs to be executed as root
419         
420         host/port/user/agent/identity_file: see rexec
421     
422     Returns:
423         (stdout, stderr), process
424         
425         Of the spawning process, which only captures errors at spawning time.
426         Usually only useful for diagnostics.
427     """
428     # Start process in a "daemonized" way, using nohup and heavy
429     # stdin/out redirection to avoid connection issues
430     if stderr is STDOUT:
431         stderr = '&1'
432     else:
433         stderr = ' ' + stderr
434     
435     daemon_command = '{ { %(command)s  > %(stdout)s 2>%(stderr)s < %(stdin)s & } ; echo $! 1 > %(pidfile)s ; }' % {
436         'command' : command,
437         'pidfile' : pidfile,
438         
439         'stdout' : stdout,
440         'stderr' : stderr,
441         'stdin' : stdin,
442     }
443     
444     cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c '%(command)s' " % {
445             'command' : daemon_command,
446             
447             'sudo' : 'sudo -S' if sudo else '',
448             
449             'pidfile' : pidfile,
450             'gohome' : 'cd %s ; ' % home if home else '',
451             'create' : 'mkdir -p %s ; ' % home if create_home else '',
452         }
453
454     (out,err), proc = rexec(
455         cmd,
456         host = host,
457         port = port,
458         user = user,
459         agent = agent,
460         identity_file = identity_file,
461         tty = tty
462         )
463     
464     if proc.wait():
465         raise RuntimeError, "Failed to set up application: %s %s" % (out,err,)
466
467     return (out,err),proc
468
469 @eintr_retry
470 def rcheck_pid(pidfile,
471         host = None, 
472         port = None, 
473         user = None, 
474         agent = None, 
475         identity_file = None):
476     """
477     Check the pidfile of a process spawned with remote_spawn.
478     
479     Parameters:
480         pidfile: the pidfile passed to remote_span
481         
482         host/port/user/agent/identity_file: see rexec
483     
484     Returns:
485         
486         A (pid, ppid) tuple useful for calling remote_status and remote_kill,
487         or None if the pidfile isn't valid yet (maybe the process is still starting).
488     """
489
490     (out,err),proc = rexec(
491         "cat %(pidfile)s" % {
492             'pidfile' : pidfile,
493         },
494         host = host,
495         port = port,
496         user = user,
497         agent = agent,
498         identity_file = identity_file
499         )
500         
501     if proc.wait():
502         return None
503     
504     if out:
505         try:
506             return map(int,out.strip().split(' ',1))
507         except:
508             # Ignore, many ways to fail that don't matter that much
509             return None
510
511 @eintr_retry
512 def rstatus(pid, ppid, 
513         host = None, 
514         port = None, 
515         user = None, 
516         agent = None, 
517         identity_file = None):
518     """
519     Check the status of a process spawned with remote_spawn.
520     
521     Parameters:
522         pid/ppid: pid and parent-pid of the spawned process. See remote_check_pid
523         
524         host/port/user/agent/identity_file: see rexec
525     
526     Returns:
527         
528         One of NOT_STARTED, RUNNING, FINISHED
529     """
530
531     (out,err),proc = rexec(
532         "ps --pid %(pid)d -o pid | grep -c %(pid)d ; true" % {
533             'ppid' : ppid,
534             'pid' : pid,
535         },
536         host = host,
537         port = port,
538         user = user,
539         agent = agent,
540         identity_file = identity_file
541         )
542     
543     if proc.wait():
544         return NOT_STARTED
545     
546     status = False
547     if out:
548         try:
549             status = bool(int(out.strip()))
550         except:
551             if out or err:
552                 logging.warn("Error checking remote status:\n%s%s\n", out, err)
553             # Ignore, many ways to fail that don't matter that much
554             return NOT_STARTED
555     return RUNNING if status else FINISHED
556     
557
558 @eintr_retry
559 def rkill(pid, ppid, 
560         host = None, 
561         port = None, 
562         user = None, 
563         agent = None, 
564         sudo = False,
565         identity_file = None, 
566         nowait = False):
567     """
568     Kill a process spawned with remote_spawn.
569     
570     First tries a SIGTERM, and if the process does not end in 10 seconds,
571     it sends a SIGKILL.
572     
573     Parameters:
574         pid/ppid: pid and parent-pid of the spawned process. See remote_check_pid
575         
576         sudo: whether the command was run with sudo - careful killing like this.
577         
578         host/port/user/agent/identity_file: see rexec
579     
580     Returns:
581         
582         Nothing, should have killed the process
583     """
584     
585     if sudo:
586         subkill = "$(ps --ppid %(pid)d -o pid h)" % { 'pid' : pid }
587     else:
588         subkill = ""
589     cmd = """
590 SUBKILL="%(subkill)s" ;
591 %(sudo)s kill -- -%(pid)d $SUBKILL || /bin/true
592 %(sudo)s kill %(pid)d $SUBKILL || /bin/true
593 for x in 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 ; do 
594     sleep 0.2 
595     if [ `ps --pid %(pid)d -o pid | grep -c %(pid)d` == '0' ]; then
596         break
597     else
598         %(sudo)s kill -- -%(pid)d $SUBKILL || /bin/true
599         %(sudo)s kill %(pid)d $SUBKILL || /bin/true
600     fi
601     sleep 1.8
602 done
603 if [ `ps --pid %(pid)d -o pid | grep -c %(pid)d` != '0' ]; then
604     %(sudo)s kill -9 -- -%(pid)d $SUBKILL || /bin/true
605     %(sudo)s kill -9 %(pid)d $SUBKILL || /bin/true
606 fi
607 """
608     if nowait:
609         cmd = "( %s ) >/dev/null 2>/dev/null </dev/null &" % (cmd,)
610
611     (out,err),proc = rexec(
612         cmd % {
613             'ppid' : ppid,
614             'pid' : pid,
615             'sudo' : 'sudo -S' if sudo else '',
616             'subkill' : subkill,
617         },
618         host = host,
619         port = port,
620         user = user,
621         agent = agent,
622         identity_file = identity_file
623         )
624     
625     # wait, don't leave zombies around
626     proc.wait()
627
628 # POSIX
629 def _communicate(self, input, timeout=None, err_on_timeout=True):
630     read_set = []
631     write_set = []
632     stdout = None # Return
633     stderr = None # Return
634     
635     killed = False
636     
637     if timeout is not None:
638         timelimit = time.time() + timeout
639         killtime = timelimit + 4
640         bailtime = timelimit + 4
641
642     if self.stdin:
643         # Flush stdio buffer.  This might block, if the user has
644         # been writing to .stdin in an uncontrolled fashion.
645         self.stdin.flush()
646         if input:
647             write_set.append(self.stdin)
648         else:
649             self.stdin.close()
650     if self.stdout:
651         read_set.append(self.stdout)
652         stdout = []
653     if self.stderr:
654         read_set.append(self.stderr)
655         stderr = []
656
657     input_offset = 0
658     while read_set or write_set:
659         if timeout is not None:
660             curtime = time.time()
661             if timeout is None or curtime > timelimit:
662                 if curtime > bailtime:
663                     break
664                 elif curtime > killtime:
665                     signum = signal.SIGKILL
666                 else:
667                     signum = signal.SIGTERM
668                 # Lets kill it
669                 os.kill(self.pid, signum)
670                 select_timeout = 0.5
671             else:
672                 select_timeout = timelimit - curtime + 0.1
673         else:
674             select_timeout = 1.0
675         
676         if select_timeout > 1.0:
677             select_timeout = 1.0
678             
679         try:
680             rlist, wlist, xlist = select.select(read_set, write_set, [], select_timeout)
681         except select.error,e:
682             if e[0] != 4:
683                 raise
684             else:
685                 continue
686         
687         if not rlist and not wlist and not xlist and self.poll() is not None:
688             # timeout and process exited, say bye
689             break
690
691         if self.stdin in wlist:
692             # When select has indicated that the file is writable,
693             # we can write up to PIPE_BUF bytes without risk
694             # blocking.  POSIX defines PIPE_BUF >= 512
695             bytes_written = os.write(self.stdin.fileno(), buffer(input, input_offset, 512))
696             input_offset += bytes_written
697             if input_offset >= len(input):
698                 self.stdin.close()
699                 write_set.remove(self.stdin)
700
701         if self.stdout in rlist:
702             data = os.read(self.stdout.fileno(), 1024)
703             if data == "":
704                 self.stdout.close()
705                 read_set.remove(self.stdout)
706             stdout.append(data)
707
708         if self.stderr in rlist:
709             data = os.read(self.stderr.fileno(), 1024)
710             if data == "":
711                 self.stderr.close()
712                 read_set.remove(self.stderr)
713             stderr.append(data)
714     
715     # All data exchanged.  Translate lists into strings.
716     if stdout is not None:
717         stdout = ''.join(stdout)
718     if stderr is not None:
719         stderr = ''.join(stderr)
720
721     # Translate newlines, if requested.  We cannot let the file
722     # object do the translation: It is based on stdio, which is
723     # impossible to combine with select (unless forcing no
724     # buffering).
725     if self.universal_newlines and hasattr(file, 'newlines'):
726         if stdout:
727             stdout = self._translate_newlines(stdout)
728         if stderr:
729             stderr = self._translate_newlines(stderr)
730
731     if killed and err_on_timeout:
732         errcode = self.poll()
733         raise RuntimeError, ("Operation timed out", errcode, stdout, stderr)
734     else:
735         if killed:
736             self.poll()
737         else:
738             self.wait()
739         return (stdout, stderr)
740