NECo: A tool to design and run experiments on arbitrary platforms.
[nepi.git] / src / neco / util / sshfuncs.py
1 import base64
2 import errno
3 import os
4 import os.path
5 import select
6 import signal
7 import socket
8 import subprocess
9 import time
10 import traceback
11 import re
12 import tempfile
13 import hashlib
14
15 OPENSSH_HAS_PERSIST = None
16 CONTROL_PATH = "yyyyy_ssh_control_path"
17
18 if hasattr(os, "devnull"):
19     DEV_NULL = os.devnull
20 else:
21     DEV_NULL = "/dev/null"
22
23 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
24
25 hostbyname_cache = dict()
26
27 class STDOUT: 
28     """
29     Special value that when given to rspawn in stderr causes stderr to 
30     redirect to whatever stdout was redirected to.
31     """
32
33 class RUNNING:
34     """
35     Process is still running
36     """
37
38 class FINISHED:
39     """
40     Process is finished
41     """
42
43 class NOT_STARTED:
44     """
45     Process hasn't started running yet (this should be very rare)
46     """
47
48 def openssh_has_persist():
49     """ The ssh_config options ControlMaster and ControlPersist allow to
50     reuse a same network connection for multiple ssh sessions. In this 
51     way limitations on number of open ssh connections can be bypassed.
52     However, older versions of openSSH do not support this feature.
53     This function is used to determine if ssh connection persist features
54     can be used.
55     """
56     global OPENSSH_HAS_PERSIST
57     if OPENSSH_HAS_PERSIST is None:
58         proc = subprocess.Popen(["ssh","-v"],
59             stdout = subprocess.PIPE,
60             stderr = subprocess.STDOUT,
61             stdin = open("/dev/null","r") )
62         out,err = proc.communicate()
63         proc.wait()
64         
65         vre = re.compile(r'OpenSSH_(?:[6-9]|5[.][8-9]|5[.][1-9][0-9]|[1-9][0-9]).*', re.I)
66         OPENSSH_HAS_PERSIST = bool(vre.match(out))
67     return OPENSSH_HAS_PERSIST
68
69 def shell_escape(s):
70     """ Escapes strings so that they are safe to use as command-line 
71     arguments """
72     if SHELL_SAFE.match(s):
73         # safe string - no escaping needed
74         return s
75     else:
76         # unsafe string - escape
77         def escp(c):
78             if (32 <= ord(c) < 127 or c in ('\r','\n','\t')) and c not in ("'",'"'):
79                 return c
80             else:
81                 return "'$'\\x%02x''" % (ord(c),)
82         s = ''.join(map(escp,s))
83         return "'%s'" % (s,)
84
85 def eintr_retry(func):
86     """Retries a function invocation when a EINTR occurs"""
87     import functools
88     @functools.wraps(func)
89     def rv(*p, **kw):
90         retry = kw.pop("_retry", False)
91         for i in xrange(0 if retry else 4):
92             try:
93                 return func(*p, **kw)
94             except (select.error, socket.error), args:
95                 if args[0] == errno.EINTR:
96                     continue
97                 else:
98                     raise 
99             except OSError, e:
100                 if e.errno == errno.EINTR:
101                     continue
102                 else:
103                     raise
104         else:
105             return func(*p, **kw)
106     return rv
107
108 def make_connkey(user, host, port):
109     connkey = repr((user,host,port)).encode("base64").strip().replace('/','.')
110     if len(connkey) > 60:
111         connkey = hashlib.sha1(connkey).hexdigest()
112     return connkey
113
114 def rexec(command, host, user, 
115         port = None, 
116         agent = True,
117         sudo = False,
118         stdin = "", 
119         identity_file = None,
120         tty = False,
121         tty2 = False,
122         timeout = None,
123         retry = 0,
124         err_on_timeout = True,
125         connect_timeout = 30,
126         persistent = True):
127     """
128     Executes a remote command, returns ((stdout,stderr),process)
129     """
130     connkey = make_connkey(user, host, port)
131     args = ['ssh', '-C',
132             # Don't bother with localhost. Makes test easier
133             '-o', 'NoHostAuthenticationForLocalhost=yes',
134             # XXX: Possible security issue
135             # Avoid interactive requests to accept new host keys
136             '-o', 'StrictHostKeyChecking=no',
137             '-o', 'ConnectTimeout=%d' % (int(connect_timeout),),
138             '-o', 'ConnectionAttempts=3',
139             '-o', 'ServerAliveInterval=30',
140             '-o', 'TCPKeepAlive=yes',
141             '-l', user, host]
142
143     if persistent and openssh_has_persist():
144         args.extend([
145             '-o', 'ControlMaster=auto',
146             '-o', 'ControlPath=/tmp/%s_%s' % ( CONTROL_PATH, connkey, ),
147             '-o', 'ControlPersist=60' ])
148     if agent:
149         args.append('-A')
150     if port:
151         args.append('-p%d' % port)
152     if identity_file:
153         args.extend(('-i', identity_file))
154     if tty:
155         args.append('-t')
156     elif tty2:
157         args.append('-t')
158         args.append('-t')
159     if sudo:
160         command = "sudo " + command
161     args.append(command)
162
163     print " ".join(args)
164
165     for x in xrange(retry or 3):
166         # connects to the remote host and starts a remote connection
167         proc = subprocess.Popen(args, 
168                 stdout = subprocess.PIPE,
169                 stdin = subprocess.PIPE, 
170                 stderr = subprocess.PIPE)
171         
172         try:
173             out, err = _communicate(proc, stdin, timeout, err_on_timeout)
174             if proc.poll():
175                 if err.strip().startswith('ssh: ') or err.strip().startswith('mux_client_hello_exchange: '):
176                     # SSH error, can safely retry
177                     continue
178                 elif retry:
179                     # Probably timed out or plain failed but can retry
180                     continue
181             break
182         except RuntimeError,e:
183             if retry <= 0:
184                 raise
185             retry -= 1
186         
187     return ((out, err), proc)
188
189 def rcopy(source, dest, host, user,
190         port = None, 
191         agent = True, 
192         recursive = False,
193         identity_file = None):
194     """
195     Copies file from/to remote sites.
196     
197     Source and destination should have the user and host encoded
198     as per scp specs.
199     
200     If source is a file object, a special mode will be used to
201     create the remote file with the same contents.
202     
203     If dest is a file object, the remote file (source) will be
204     read and written into dest.
205     
206     In these modes, recursive cannot be True.
207     
208     Source can be a list of files to copy to a single destination,
209     in which case it is advised that the destination be a folder.
210     """
211     
212     if isinstance(source, file) and source.tell() == 0:
213         source = source.name
214
215     elif hasattr(source, 'read'):
216         tmp = tempfile.NamedTemporaryFile()
217         while True:
218             buf = source.read(65536)
219             if buf:
220                 tmp.write(buf)
221             else:
222                 break
223         tmp.seek(0)
224         source = tmp.name
225     
226     if isinstance(source, file) or isinstance(dest, file) \
227             or hasattr(source, 'read')  or hasattr(dest, 'write'):
228         assert not recursive
229         
230         connkey = make_connkey(user,host,port)
231         args = ['ssh', '-l', user, '-C',
232                 # Don't bother with localhost. Makes test easier
233                 '-o', 'NoHostAuthenticationForLocalhost=yes',
234                 # XXX: Possible security issue
235                 # Avoid interactive requests to accept new host keys
236                 '-o', 'StrictHostKeyChecking=no',
237                 '-o', 'ConnectTimeout=30',
238                 '-o', 'ConnectionAttempts=3',
239                 '-o', 'ServerAliveInterval=30',
240                 '-o', 'TCPKeepAlive=yes',
241                 host ]
242         if openssh_has_persist():
243             args.extend([
244                 '-o', 'ControlMaster=auto',
245                 '-o', 'ControlPath=/tmp/%s_%s' % ( CONTROL_PATH, connkey, ),
246                 '-o', 'ControlPersist=60' ])
247         if port:
248             args.append('-P%d' % port)
249         if identity_file:
250             args.extend(('-i', identity_file))
251         
252         if isinstance(source, file) or hasattr(source, 'read'):
253             args.append('cat > %s' % (shell_escape(dest),))
254         elif isinstance(dest, file) or hasattr(dest, 'write'):
255             args.append('cat %s' % (shell_escape(dest),))
256         else:
257             raise AssertionError, "Unreachable code reached! :-Q"
258         
259         # connects to the remote host and starts a remote connection
260         if isinstance(source, file):
261             proc = subprocess.Popen(args, 
262                     stdout = open('/dev/null','w'),
263                     stderr = subprocess.PIPE,
264                     stdin = source)
265             err = proc.stderr.read()
266             eintr_retry(proc.wait)()
267             return ((None,err), proc)
268         elif isinstance(dest, file):
269             proc = subprocess.Popen(args, 
270                     stdout = open('/dev/null','w'),
271                     stderr = subprocess.PIPE,
272                     stdin = source)
273             err = proc.stderr.read()
274             eintr_retry(proc.wait)()
275             return ((None,err), proc)
276         elif hasattr(source, 'read'):
277             # file-like (but not file) source
278             proc = subprocess.Popen(args, 
279                     stdout = open('/dev/null','w'),
280                     stderr = subprocess.PIPE,
281                     stdin = subprocess.PIPE)
282             
283             buf = None
284             err = []
285             while True:
286                 if not buf:
287                     buf = source.read(4096)
288                 if not buf:
289                     #EOF
290                     break
291                 
292                 rdrdy, wrdy, broken = select.select(
293                     [proc.stderr],
294                     [proc.stdin],
295                     [proc.stderr,proc.stdin])
296                 
297                 if proc.stderr in rdrdy:
298                     # use os.read for fully unbuffered behavior
299                     err.append(os.read(proc.stderr.fileno(), 4096))
300                 
301                 if proc.stdin in wrdy:
302                     proc.stdin.write(buf)
303                     buf = None
304                 
305                 if broken:
306                     break
307             proc.stdin.close()
308             err.append(proc.stderr.read())
309                 
310             eintr_retry(proc.wait)()
311             return ((None,''.join(err)), proc)
312         elif hasattr(dest, 'write'):
313             # file-like (but not file) dest
314             proc = subprocess.Popen(args, 
315                     stdout = subprocess.PIPE,
316                     stderr = subprocess.PIPE,
317                     stdin = open('/dev/null','w'))
318             
319             buf = None
320             err = []
321             while True:
322                 rdrdy, wrdy, broken = select.select(
323                     [proc.stderr, proc.stdout],
324                     [],
325                     [proc.stderr, proc.stdout])
326                 
327                 if proc.stderr in rdrdy:
328                     # use os.read for fully unbuffered behavior
329                     err.append(os.read(proc.stderr.fileno(), 4096))
330                 
331                 if proc.stdout in rdrdy:
332                     # use os.read for fully unbuffered behavior
333                     buf = os.read(proc.stdout.fileno(), 4096)
334                     dest.write(buf)
335                     
336                     if not buf:
337                         #EOF
338                         break
339                 
340                 if broken:
341                     break
342             err.append(proc.stderr.read())
343                 
344             eintr_retry(proc.wait)()
345             return ((None,''.join(err)), proc)
346         else:
347             raise AssertionError, "Unreachable code reached! :-Q"
348     else:
349         # plain scp
350         args = ['scp', '-q', '-p', '-C',
351                 # Don't bother with localhost. Makes test easier
352                 '-o', 'NoHostAuthenticationForLocalhost=yes',
353                 # XXX: Possible security issue
354                 # Avoid interactive requests to accept new host keys
355                 '-o', 'StrictHostKeyChecking=no',
356                 '-o', 'ConnectTimeout=30',
357                 '-o', 'ConnectionAttempts=3',
358                 '-o', 'ServerAliveInterval=30',
359                 '-o', 'TCPKeepAlive=yes' ]
360                 
361         if port:
362             args.append('-P%d' % port)
363         if recursive:
364             args.append('-r')
365         if identity_file:
366             args.extend(('-i', identity_file))
367
368         if isinstance(source,list):
369             args.extend(source)
370         else:
371             if openssh_has_persist():
372                 connkey = make_connkey(user,host,port)
373                 args.extend([
374                     '-o', 'ControlMaster=no',
375                     '-o', 'ControlPath=/tmp/%s_%s' % ( CONTROL_PATH, connkey, )])
376             args.append(source)
377         args.append("%s@%s:%s" %(user, host, dest))
378
379         # connects to the remote host and starts a remote connection
380         proc = subprocess.Popen(args, 
381                 stdout = subprocess.PIPE,
382                 stdin = subprocess.PIPE, 
383                 stderr = subprocess.PIPE)
384         
385         comm = proc.communicate()
386         eintr_retry(proc.wait)()
387         return (comm, proc)
388
389 def rspawn(command, pidfile, 
390         stdout = '/dev/null', 
391         stderr = STDOUT, 
392         stdin = '/dev/null', 
393         home = None, 
394         create_home = False, 
395         host = None, 
396         port = None, 
397         user = None, 
398         agent = None, 
399         sudo = False,
400         identity_file = None, 
401         tty = False):
402     """
403     Spawn a remote command such that it will continue working asynchronously.
404     
405     Parameters:
406         command: the command to run - it should be a single line.
407         
408         pidfile: path of a (ideally unique to this task) pidfile for tracking the process.
409         
410         stdout: path of a file to redirect standard output to - must be a string.
411             Defaults to /dev/null
412         stderr: path of a file to redirect standard error to - string or the special STDOUT value
413             to redirect to the same file stdout was redirected to. Defaults to STDOUT.
414         stdin: path of a file with input to be piped into the command's standard input
415         
416         home: path of a folder to use as working directory - should exist, unless you specify create_home
417         
418         create_home: if True, the home folder will be created first with mkdir -p
419         
420         sudo: whether the command needs to be executed as root
421         
422         host/port/user/agent/identity_file: see rexec
423     
424     Returns:
425         (stdout, stderr), process
426         
427         Of the spawning process, which only captures errors at spawning time.
428         Usually only useful for diagnostics.
429     """
430     # Start process in a "daemonized" way, using nohup and heavy
431     # stdin/out redirection to avoid connection issues
432     if stderr is STDOUT:
433         stderr = '&1'
434     else:
435         stderr = ' ' + stderr
436     
437     daemon_command = '{ { %(command)s  > %(stdout)s 2>%(stderr)s < %(stdin)s & } ; echo $! 1 > %(pidfile)s ; }' % {
438         'command' : command,
439         'pidfile' : shell_escape(pidfile),
440         
441         'stdout' : stdout,
442         'stderr' : stderr,
443         'stdin' : stdin,
444     }
445     
446     cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c %(command)s " % {
447             'command' : shell_escape(daemon_command),
448             
449             'sudo' : 'sudo -S' if sudo else '',
450             
451             'pidfile' : shell_escape(pidfile),
452             'gohome' : 'cd %s ; ' % (shell_escape(home),) if home else '',
453             'create' : 'mkdir -p %s ; ' % (shell_escape,) if create_home else '',
454         }
455
456     (out,err), proc = rexec(
457         cmd,
458         host = host,
459         port = port,
460         user = user,
461         agent = agent,
462         identity_file = identity_file,
463         tty = tty
464         )
465     
466     if proc.wait():
467         raise RuntimeError, "Failed to set up application: %s %s" % (out,err,)
468
469     return (out,err),proc
470
471 @eintr_retry
472 def rcheck_pid(pidfile,
473         host = None, 
474         port = None, 
475         user = None, 
476         agent = None, 
477         identity_file = None):
478     """
479     Check the pidfile of a process spawned with remote_spawn.
480     
481     Parameters:
482         pidfile: the pidfile passed to remote_span
483         
484         host/port/user/agent/identity_file: see rexec
485     
486     Returns:
487         
488         A (pid, ppid) tuple useful for calling remote_status and remote_kill,
489         or None if the pidfile isn't valid yet (maybe the process is still starting).
490     """
491
492     (out,err),proc = rexec(
493         "cat %(pidfile)s" % {
494             'pidfile' : pidfile,
495         },
496         host = host,
497         port = port,
498         user = user,
499         agent = agent,
500         identity_file = identity_file
501         )
502         
503     if proc.wait():
504         return None
505     
506     if out:
507         try:
508             return map(int,out.strip().split(' ',1))
509         except:
510             # Ignore, many ways to fail that don't matter that much
511             return None
512
513 @eintr_retry
514 def rstatus(pid, ppid, 
515         host = None, 
516         port = None, 
517         user = None, 
518         agent = None, 
519         identity_file = None):
520     """
521     Check the status of a process spawned with remote_spawn.
522     
523     Parameters:
524         pid/ppid: pid and parent-pid of the spawned process. See remote_check_pid
525         
526         host/port/user/agent/identity_file: see rexec
527     
528     Returns:
529         
530         One of NOT_STARTED, RUNNING, FINISHED
531     """
532
533     (out,err),proc = rexec(
534         "ps --pid %(pid)d -o pid | grep -c %(pid)d ; true" % {
535             'ppid' : ppid,
536             'pid' : pid,
537         },
538         host = host,
539         port = port,
540         user = user,
541         agent = agent,
542         identity_file = identity_file
543         )
544     
545     if proc.wait():
546         return NOT_STARTED
547     
548     status = False
549     if out:
550         try:
551             status = bool(int(out.strip()))
552         except:
553             if out or err:
554                 logging.warn("Error checking remote status:\n%s%s\n", out, err)
555             # Ignore, many ways to fail that don't matter that much
556             return NOT_STARTED
557     return RUNNING if status else FINISHED
558     
559
560 @eintr_retry
561 def rkill(pid, ppid, 
562         host = None, 
563         port = None, 
564         user = None, 
565         agent = None, 
566         sudo = False,
567         identity_file = None, 
568         nowait = False):
569     """
570     Kill a process spawned with remote_spawn.
571     
572     First tries a SIGTERM, and if the process does not end in 10 seconds,
573     it sends a SIGKILL.
574     
575     Parameters:
576         pid/ppid: pid and parent-pid of the spawned process. See remote_check_pid
577         
578         sudo: whether the command was run with sudo - careful killing like this.
579         
580         host/port/user/agent/identity_file: see rexec
581     
582     Returns:
583         
584         Nothing, should have killed the process
585     """
586     
587     if sudo:
588         subkill = "$(ps --ppid %(pid)d -o pid h)" % { 'pid' : pid }
589     else:
590         subkill = ""
591     cmd = """
592 SUBKILL="%(subkill)s" ;
593 %(sudo)s kill -- -%(pid)d $SUBKILL || /bin/true
594 %(sudo)s kill %(pid)d $SUBKILL || /bin/true
595 for x in 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 ; do 
596     sleep 0.2 
597     if [ `ps --pid %(pid)d -o pid | grep -c %(pid)d` == '0' ]; then
598         break
599     else
600         %(sudo)s kill -- -%(pid)d $SUBKILL || /bin/true
601         %(sudo)s kill %(pid)d $SUBKILL || /bin/true
602     fi
603     sleep 1.8
604 done
605 if [ `ps --pid %(pid)d -o pid | grep -c %(pid)d` != '0' ]; then
606     %(sudo)s kill -9 -- -%(pid)d $SUBKILL || /bin/true
607     %(sudo)s kill -9 %(pid)d $SUBKILL || /bin/true
608 fi
609 """
610     if nowait:
611         cmd = "( %s ) >/dev/null 2>/dev/null </dev/null &" % (cmd,)
612
613     (out,err),proc = rexec(
614         cmd % {
615             'ppid' : ppid,
616             'pid' : pid,
617             'sudo' : 'sudo -S' if sudo else '',
618             'subkill' : subkill,
619         },
620         host = host,
621         port = port,
622         user = user,
623         agent = agent,
624         identity_file = identity_file
625         )
626     
627     # wait, don't leave zombies around
628     proc.wait()
629
630 # POSIX
631 def _communicate(self, input, timeout=None, err_on_timeout=True):
632     read_set = []
633     write_set = []
634     stdout = None # Return
635     stderr = None # Return
636     
637     killed = False
638     
639     if timeout is not None:
640         timelimit = time.time() + timeout
641         killtime = timelimit + 4
642         bailtime = timelimit + 4
643
644     if self.stdin:
645         # Flush stdio buffer.  This might block, if the user has
646         # been writing to .stdin in an uncontrolled fashion.
647         self.stdin.flush()
648         if input:
649             write_set.append(self.stdin)
650         else:
651             self.stdin.close()
652     if self.stdout:
653         read_set.append(self.stdout)
654         stdout = []
655     if self.stderr:
656         read_set.append(self.stderr)
657         stderr = []
658
659     input_offset = 0
660     while read_set or write_set:
661         if timeout is not None:
662             curtime = time.time()
663             if timeout is None or curtime > timelimit:
664                 if curtime > bailtime:
665                     break
666                 elif curtime > killtime:
667                     signum = signal.SIGKILL
668                 else:
669                     signum = signal.SIGTERM
670                 # Lets kill it
671                 os.kill(self.pid, signum)
672                 select_timeout = 0.5
673             else:
674                 select_timeout = timelimit - curtime + 0.1
675         else:
676             select_timeout = 1.0
677         
678         if select_timeout > 1.0:
679             select_timeout = 1.0
680             
681         try:
682             rlist, wlist, xlist = select.select(read_set, write_set, [], select_timeout)
683         except select.error,e:
684             if e[0] != 4:
685                 raise
686             else:
687                 continue
688         
689         if not rlist and not wlist and not xlist and self.poll() is not None:
690             # timeout and process exited, say bye
691             break
692
693         if self.stdin in wlist:
694             # When select has indicated that the file is writable,
695             # we can write up to PIPE_BUF bytes without risk
696             # blocking.  POSIX defines PIPE_BUF >= 512
697             bytes_written = os.write(self.stdin.fileno(), buffer(input, input_offset, 512))
698             input_offset += bytes_written
699             if input_offset >= len(input):
700                 self.stdin.close()
701                 write_set.remove(self.stdin)
702
703         if self.stdout in rlist:
704             data = os.read(self.stdout.fileno(), 1024)
705             if data == "":
706                 self.stdout.close()
707                 read_set.remove(self.stdout)
708             stdout.append(data)
709
710         if self.stderr in rlist:
711             data = os.read(self.stderr.fileno(), 1024)
712             if data == "":
713                 self.stderr.close()
714                 read_set.remove(self.stderr)
715             stderr.append(data)
716     
717     # All data exchanged.  Translate lists into strings.
718     if stdout is not None:
719         stdout = ''.join(stdout)
720     if stderr is not None:
721         stderr = ''.join(stderr)
722
723     # Translate newlines, if requested.  We cannot let the file
724     # object do the translation: It is based on stdio, which is
725     # impossible to combine with select (unless forcing no
726     # buffering).
727     if self.universal_newlines and hasattr(file, 'newlines'):
728         if stdout:
729             stdout = self._translate_newlines(stdout)
730         if stderr:
731             stderr = self._translate_newlines(stderr)
732
733     if killed and err_on_timeout:
734         errcode = self.poll()
735         raise RuntimeError, ("Operation timed out", errcode, stdout, stderr)
736     else:
737         if killed:
738             self.poll()
739         else:
740             self.wait()
741         return (stdout, stderr)
742