f8d1cfc9d093c1852f2632c3046cd2058f1aea05
[nepi.git] / src / neco / util / sshfuncs.py
1 import base64
2 import errno
3 import os
4 import os.path
5 import select
6 import signal
7 import socket
8 import subprocess
9 import time
10 import traceback
11 import re
12 import tempfile
13 import hashlib
14
15 OPENSSH_HAS_PERSIST = None
16 CONTROL_PATH = "yyy_ssh_ctrl_path"
17
18 if hasattr(os, "devnull"):
19     DEV_NULL = os.devnull
20 else:
21     DEV_NULL = "/dev/null"
22
23 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
24
25 hostbyname_cache = dict()
26
27 class STDOUT: 
28     """
29     Special value that when given to rspawn in stderr causes stderr to 
30     redirect to whatever stdout was redirected to.
31     """
32
33 class RUNNING:
34     """
35     Process is still running
36     """
37
38 class FINISHED:
39     """
40     Process is finished
41     """
42
43 class NOT_STARTED:
44     """
45     Process hasn't started running yet (this should be very rare)
46     """
47
48 def openssh_has_persist():
49     """ The ssh_config options ControlMaster and ControlPersist allow to
50     reuse a same network connection for multiple ssh sessions. In this 
51     way limitations on number of open ssh connections can be bypassed.
52     However, older versions of openSSH do not support this feature.
53     This function is used to determine if ssh connection persist features
54     can be used.
55     """
56     global OPENSSH_HAS_PERSIST
57     if OPENSSH_HAS_PERSIST is None:
58         proc = subprocess.Popen(["ssh","-v"],
59             stdout = subprocess.PIPE,
60             stderr = subprocess.STDOUT,
61             stdin = open("/dev/null","r") )
62         out,err = proc.communicate()
63         proc.wait()
64         
65         vre = re.compile(r'OpenSSH_(?:[6-9]|5[.][8-9]|5[.][1-9][0-9]|[1-9][0-9]).*', re.I)
66         OPENSSH_HAS_PERSIST = bool(vre.match(out))
67     return OPENSSH_HAS_PERSIST
68
69 def shell_escape(s):
70     """ Escapes strings so that they are safe to use as command-line 
71     arguments """
72     if SHELL_SAFE.match(s):
73         # safe string - no escaping needed
74         return s
75     else:
76         # unsafe string - escape
77         def escp(c):
78             if (32 <= ord(c) < 127 or c in ('\r','\n','\t')) and c not in ("'",'"'):
79                 return c
80             else:
81                 return "'$'\\x%02x''" % (ord(c),)
82         s = ''.join(map(escp,s))
83         return "'%s'" % (s,)
84
85 def eintr_retry(func):
86     """Retries a function invocation when a EINTR occurs"""
87     import functools
88     @functools.wraps(func)
89     def rv(*p, **kw):
90         retry = kw.pop("_retry", False)
91         for i in xrange(0 if retry else 4):
92             try:
93                 return func(*p, **kw)
94             except (select.error, socket.error), args:
95                 if args[0] == errno.EINTR:
96                     continue
97                 else:
98                     raise 
99             except OSError, e:
100                 if e.errno == errno.EINTR:
101                     continue
102                 else:
103                     raise
104         else:
105             return func(*p, **kw)
106     return rv
107
108 def make_connkey(user, host, port):
109     connkey = repr((user,host,port)).encode("base64").strip().replace('/','.')
110     if len(connkey) > 60:
111         connkey = hashlib.sha1(connkey).hexdigest()
112     return connkey
113
114 def make_control_path(user, host, port):
115     connkey = make_connkey(user, host, port)
116     return '/tmp/%s_%s' % ( CONTROL_PATH, connkey, )
117
118 def rexec(command, host, user, 
119         port = None, 
120         agent = True,
121         sudo = False,
122         stdin = None,
123         identity_file = None,
124         env = None,
125         tty = False,
126         x11 = False,
127         timeout = None,
128         retry = 0,
129         err_on_timeout = True,
130         connect_timeout = 30,
131         persistent = True):
132     """
133     Executes a remote command, returns ((stdout,stderr),process)
134     """
135     args = ['ssh', '-C',
136             # Don't bother with localhost. Makes test easier
137             '-o', 'NoHostAuthenticationForLocalhost=yes',
138             # XXX: Possible security issue
139             # Avoid interactive requests to accept new host keys
140             '-o', 'StrictHostKeyChecking=no',
141             '-o', 'ConnectTimeout=%d' % (int(connect_timeout),),
142             '-o', 'ConnectionAttempts=3',
143             '-o', 'ServerAliveInterval=30',
144             '-o', 'TCPKeepAlive=yes',
145             '-l', user, host]
146
147     if persistent and openssh_has_persist():
148         control_path = make_control_path(user, host, port)
149         args.extend([
150             '-o', 'ControlMaster=auto',
151             '-o', 'ControlPath=%s' % control_path,
152             '-o', 'ControlPersist=60' ])
153     if agent:
154         args.append('-A')
155     if port:
156         args.append('-p%d' % port)
157     if identity_file:
158         args.extend(('-i', identity_file))
159     if tty:
160         args.append('-t')
161         if sudo:
162             args.append('-t')
163     if x11:
164         args.append('-X')
165
166     if env:
167         export = ''
168         for envkey, envval in env.iteritems():
169             export += '%s=%s ' % (envkey, envval)
170         command = export + command
171
172     if sudo:
173         command = "sudo " + command
174
175     args.append(command)
176
177     for x in xrange(retry or 3):
178         # connects to the remote host and starts a remote connection
179         proc = subprocess.Popen(args, 
180                 stdout = subprocess.PIPE,
181                 stdin = subprocess.PIPE, 
182                 stderr = subprocess.PIPE)
183         
184         try:
185             out, err = _communicate(proc, stdin, timeout, err_on_timeout)
186             if proc.poll():
187                 if err.strip().startswith('ssh: ') or err.strip().startswith('mux_client_hello_exchange: '):
188                     # SSH error, can safely retry
189                     continue
190                 elif retry:
191                     # Probably timed out or plain failed but can retry
192                     continue
193             break
194         except RuntimeError,e:
195             if retry <= 0:
196                 raise
197             retry -= 1
198         
199     return ((out, err), proc)
200
201 def rcopy(source, dest,
202         port = None, 
203         agent = True, 
204         identity_file = None):
205     """
206     Copies file from/to remote sites.
207     
208     Source and destination should have the user and host encoded
209     as per scp specs.
210     
211     Source can be a list of files to copy to a single destination,
212     in which case it is advised that the destination be a folder.
213     """
214     
215     # Parse destination as <user>@<server>:<path>
216     if isinstance(dest, basestring) and ':' in dest:
217         remspec, path = dest.split(':',1)
218     elif isinstance(source, basestring) and ':' in source:
219         remspec, path = source.split(':',1)
220     else:
221         raise ValueError, "Both endpoints cannot be local"
222     user, host = remspec.rsplit('@',1)
223
224     raw_string = r'''rsync -rlpcSz --timeout=900 '''
225     raw_string += r''' -e 'ssh -o BatchMode=yes '''
226     raw_string += r''' -o NoHostAuthenticationForLocalhost=yes '''
227     raw_string += r''' -o StrictHostKeyChecking=no '''
228     raw_string += r''' -o ConnectionAttempts=3 '''
229  
230     if openssh_has_persist():
231         control_path = make_control_path(user, host, port)
232         raw_string += r''' -o ControlMaster=auto '''
233         raw_string += r''' -o ControlPath=%s ''' % control_path
234   
235     if port:
236         raw_string += r''' -p %d ''' % port
237     
238     if identity_file:
239         raw_string += r''' -i "%s" ''' % identity_file
240     
241     # closing -e 'ssh...'
242     raw_string += r''' ' '''
243
244     if isinstance(source,list):
245         source = ' '.join(source)
246     else:
247         source = '"%s"' % source
248
249     raw_string += r''' %s ''' % source
250     raw_string += r''' %s ''' % dest
251
252     # connects to the remote host and starts a remote connection
253     proc = subprocess.Popen(raw_string,
254             shell=True,
255             stdout = subprocess.PIPE,
256             stdin = subprocess.PIPE, 
257             stderr = subprocess.PIPE)
258   
259     comm = proc.communicate()
260     eintr_retry(proc.wait)()
261     return (comm, proc)
262
263 def rspawn(command, pidfile, 
264         stdout = '/dev/null', 
265         stderr = STDOUT, 
266         stdin = '/dev/null', 
267         home = None, 
268         create_home = False, 
269         host = None, 
270         port = None, 
271         user = None, 
272         agent = None, 
273         sudo = False,
274         identity_file = None, 
275         tty = False):
276     """
277     Spawn a remote command such that it will continue working asynchronously.
278     
279     Parameters:
280         command: the command to run - it should be a single line.
281         
282         pidfile: path of a (ideally unique to this task) pidfile for tracking the process.
283         
284         stdout: path of a file to redirect standard output to - must be a string.
285             Defaults to /dev/null
286         stderr: path of a file to redirect standard error to - string or the special STDOUT value
287             to redirect to the same file stdout was redirected to. Defaults to STDOUT.
288         stdin: path of a file with input to be piped into the command's standard input
289         
290         home: path of a folder to use as working directory - should exist, unless you specify create_home
291         
292         create_home: if True, the home folder will be created first with mkdir -p
293         
294         sudo: whether the command needs to be executed as root
295         
296         host/port/user/agent/identity_file: see rexec
297     
298     Returns:
299         (stdout, stderr), process
300         
301         Of the spawning process, which only captures errors at spawning time.
302         Usually only useful for diagnostics.
303     """
304     # Start process in a "daemonized" way, using nohup and heavy
305     # stdin/out redirection to avoid connection issues
306     if stderr is STDOUT:
307         stderr = '&1'
308     else:
309         stderr = ' ' + stderr
310    
311     #XXX: ppid is always 1!!!
312     daemon_command = '{ { %(command)s  > %(stdout)s 2>%(stderr)s < %(stdin)s & } ; echo $! 1 > %(pidfile)s ; }' % {
313         'command' : command,
314         'pidfile' : pidfile,
315         
316         'stdout' : stdout,
317         'stderr' : stderr,
318         'stdin' : stdin,
319     }
320     
321     cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c '%(command)s' " % {
322             'command' : daemon_command,
323             
324             'sudo' : 'sudo -S' if sudo else '',
325             
326             'pidfile' : pidfile,
327             'gohome' : 'cd %s ; ' % home if home else '',
328             'create' : 'mkdir -p %s ; ' % home if create_home else '',
329         }
330
331     (out,err), proc = rexec(
332         cmd,
333         host = host,
334         port = port,
335         user = user,
336         agent = agent,
337         identity_file = identity_file,
338         tty = tty
339         )
340     
341     if proc.wait():
342         raise RuntimeError, "Failed to set up application: %s %s" % (out,err,)
343
344     return (out,err),proc
345
346 @eintr_retry
347 def rcheck_pid(pidfile,
348         host = None, 
349         port = None, 
350         user = None, 
351         agent = None, 
352         identity_file = None):
353     """
354     Check the pidfile of a process spawned with remote_spawn.
355     
356     Parameters:
357         pidfile: the pidfile passed to remote_span
358         
359         host/port/user/agent/identity_file: see rexec
360     
361     Returns:
362         
363         A (pid, ppid) tuple useful for calling remote_status and remote_kill,
364         or None if the pidfile isn't valid yet (maybe the process is still starting).
365     """
366
367     (out,err),proc = rexec(
368         "cat %(pidfile)s" % {
369             'pidfile' : pidfile,
370         },
371         host = host,
372         port = port,
373         user = user,
374         agent = agent,
375         identity_file = identity_file
376         )
377         
378     if proc.wait():
379         return None
380     
381     if out:
382         try:
383             return map(int,out.strip().split(' ',1))
384         except:
385             # Ignore, many ways to fail that don't matter that much
386             return None
387
388 @eintr_retry
389 def rstatus(pid, ppid, 
390         host = None, 
391         port = None, 
392         user = None, 
393         agent = None, 
394         identity_file = None):
395     """
396     Check the status of a process spawned with remote_spawn.
397     
398     Parameters:
399         pid/ppid: pid and parent-pid of the spawned process. See remote_check_pid
400         
401         host/port/user/agent/identity_file: see rexec
402     
403     Returns:
404         
405         One of NOT_STARTED, RUNNING, FINISHED
406     """
407
408     # XXX: ppid unused
409     (out,err),proc = rexec(
410         "ps --pid %(pid)d -o pid | grep -c %(pid)d ; true" % {
411             'ppid' : ppid,
412             'pid' : pid,
413         },
414         host = host,
415         port = port,
416         user = user,
417         agent = agent,
418         identity_file = identity_file
419         )
420     
421     if proc.wait():
422         return NOT_STARTED
423     
424     status = False
425     if out:
426         try:
427             status = bool(int(out.strip()))
428         except:
429             if out or err:
430                 logging.warn("Error checking remote status:\n%s%s\n", out, err)
431             # Ignore, many ways to fail that don't matter that much
432             return NOT_STARTED
433     return RUNNING if status else FINISHED
434     
435
436 @eintr_retry
437 def rkill(pid, ppid,
438         host = None, 
439         port = None, 
440         user = None, 
441         agent = None, 
442         sudo = False,
443         identity_file = None, 
444         nowait = False):
445     """
446     Kill a process spawned with remote_spawn.
447     
448     First tries a SIGTERM, and if the process does not end in 10 seconds,
449     it sends a SIGKILL.
450     
451     Parameters:
452         pid/ppid: pid and parent-pid of the spawned process. See remote_check_pid
453         
454         sudo: whether the command was run with sudo - careful killing like this.
455         
456         host/port/user/agent/identity_file: see rexec
457     
458     Returns:
459         
460         Nothing, should have killed the process
461     """
462     
463     subkill = "$(ps --ppid %(pid)d -o pid h)" % { 'pid' : pid }
464     cmd = """
465 SUBKILL="%(subkill)s" ;
466 %(sudo)s kill -- -%(pid)d $SUBKILL || /bin/true
467 %(sudo)s kill %(pid)d $SUBKILL || /bin/true
468 for x in 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 ; do 
469     sleep 0.2 
470     if [ `ps --pid %(pid)d -o pid | grep -c %(pid)d` == '0' ]; then
471         break
472     else
473         %(sudo)s kill -- -%(pid)d $SUBKILL || /bin/true
474         %(sudo)s kill %(pid)d $SUBKILL || /bin/true
475     fi
476     sleep 1.8
477 done
478 if [ `ps --pid %(pid)d -o pid | grep -c %(pid)d` != '0' ]; then
479     %(sudo)s kill -9 -- -%(pid)d $SUBKILL || /bin/true
480     %(sudo)s kill -9 %(pid)d $SUBKILL || /bin/true
481 fi
482 """
483     if nowait:
484         cmd = "( %s ) >/dev/null 2>/dev/null </dev/null &" % (cmd,)
485
486     (out,err),proc = rexec(
487         cmd % {
488             'ppid' : ppid,
489             'pid' : pid,
490             'sudo' : 'sudo -S' if sudo else '',
491             'subkill' : subkill,
492         },
493         host = host,
494         port = port,
495         user = user,
496         agent = agent,
497         identity_file = identity_file
498         )
499     
500     # wait, don't leave zombies around
501     proc.wait()
502
503 # POSIX
504 def _communicate(self, input, timeout=None, err_on_timeout=True):
505     read_set = []
506     write_set = []
507     stdout = None # Return
508     stderr = None # Return
509     
510     killed = False
511     
512     if timeout is not None:
513         timelimit = time.time() + timeout
514         killtime = timelimit + 4
515         bailtime = timelimit + 4
516
517     if self.stdin:
518         # Flush stdio buffer.  This might block, if the user has
519         # been writing to .stdin in an uncontrolled fashion.
520         self.stdin.flush()
521         if input:
522             write_set.append(self.stdin)
523         else:
524             self.stdin.close()
525     if self.stdout:
526         read_set.append(self.stdout)
527         stdout = []
528     if self.stderr:
529         read_set.append(self.stderr)
530         stderr = []
531
532     input_offset = 0
533     while read_set or write_set:
534         if timeout is not None:
535             curtime = time.time()
536             if timeout is None or curtime > timelimit:
537                 if curtime > bailtime:
538                     break
539                 elif curtime > killtime:
540                     signum = signal.SIGKILL
541                 else:
542                     signum = signal.SIGTERM
543                 # Lets kill it
544                 os.kill(self.pid, signum)
545                 select_timeout = 0.5
546             else:
547                 select_timeout = timelimit - curtime + 0.1
548         else:
549             select_timeout = 1.0
550         
551         if select_timeout > 1.0:
552             select_timeout = 1.0
553             
554         try:
555             rlist, wlist, xlist = select.select(read_set, write_set, [], select_timeout)
556         except select.error,e:
557             if e[0] != 4:
558                 raise
559             else:
560                 continue
561         
562         if not rlist and not wlist and not xlist and self.poll() is not None:
563             # timeout and process exited, say bye
564             break
565
566         if self.stdin in wlist:
567             # When select has indicated that the file is writable,
568             # we can write up to PIPE_BUF bytes without risk
569             # blocking.  POSIX defines PIPE_BUF >= 512
570             bytes_written = os.write(self.stdin.fileno(), buffer(input, input_offset, 512))
571             input_offset += bytes_written
572             if input_offset >= len(input):
573                 self.stdin.close()
574                 write_set.remove(self.stdin)
575
576         if self.stdout in rlist:
577             data = os.read(self.stdout.fileno(), 1024)
578             if data == "":
579                 self.stdout.close()
580                 read_set.remove(self.stdout)
581             stdout.append(data)
582
583         if self.stderr in rlist:
584             data = os.read(self.stderr.fileno(), 1024)
585             if data == "":
586                 self.stderr.close()
587                 read_set.remove(self.stderr)
588             stderr.append(data)
589     
590     # All data exchanged.  Translate lists into strings.
591     if stdout is not None:
592         stdout = ''.join(stdout)
593     if stderr is not None:
594         stderr = ''.join(stderr)
595
596     # Translate newlines, if requested.  We cannot let the file
597     # object do the translation: It is based on stdio, which is
598     # impossible to combine with select (unless forcing no
599     # buffering).
600     if self.universal_newlines and hasattr(file, 'newlines'):
601         if stdout:
602             stdout = self._translate_newlines(stdout)
603         if stderr:
604             stderr = self._translate_newlines(stderr)
605
606     if killed and err_on_timeout:
607         errcode = self.poll()
608         raise RuntimeError, ("Operation timed out", errcode, stdout, stderr)
609     else:
610         if killed:
611             self.poll()
612         else:
613             self.wait()
614         return (stdout, stderr)
615