dd04ba6229acb5e9bcb0a2273920c2de6c991873
[nepi.git] / src / neco / util / sshfuncs.py
1 import base64
2 import errno
3 import os
4 import os.path
5 import select
6 import signal
7 import socket
8 import subprocess
9 import time
10 import traceback
11 import re
12 import tempfile
13 import hashlib
14
15 OPENSSH_HAS_PERSIST = None
16 CONTROL_PATH = "yyy_ssh_ctrl_path"
17
18 if hasattr(os, "devnull"):
19     DEV_NULL = os.devnull
20 else:
21     DEV_NULL = "/dev/null"
22
23 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
24
25 hostbyname_cache = dict()
26
27 class STDOUT: 
28     """
29     Special value that when given to rspawn in stderr causes stderr to 
30     redirect to whatever stdout was redirected to.
31     """
32
33 class RUNNING:
34     """
35     Process is still running
36     """
37
38 class FINISHED:
39     """
40     Process is finished
41     """
42
43 class NOT_STARTED:
44     """
45     Process hasn't started running yet (this should be very rare)
46     """
47
48 def openssh_has_persist():
49     """ The ssh_config options ControlMaster and ControlPersist allow to
50     reuse a same network connection for multiple ssh sessions. In this 
51     way limitations on number of open ssh connections can be bypassed.
52     However, older versions of openSSH do not support this feature.
53     This function is used to determine if ssh connection persist features
54     can be used.
55     """
56     global OPENSSH_HAS_PERSIST
57     if OPENSSH_HAS_PERSIST is None:
58         proc = subprocess.Popen(["ssh","-v"],
59             stdout = subprocess.PIPE,
60             stderr = subprocess.STDOUT,
61             stdin = open("/dev/null","r") )
62         out,err = proc.communicate()
63         proc.wait()
64         
65         vre = re.compile(r'OpenSSH_(?:[6-9]|5[.][8-9]|5[.][1-9][0-9]|[1-9][0-9]).*', re.I)
66         OPENSSH_HAS_PERSIST = bool(vre.match(out))
67     return OPENSSH_HAS_PERSIST
68
69 def shell_escape(s):
70     """ Escapes strings so that they are safe to use as command-line 
71     arguments """
72     if SHELL_SAFE.match(s):
73         # safe string - no escaping needed
74         return s
75     else:
76         # unsafe string - escape
77         def escp(c):
78             if (32 <= ord(c) < 127 or c in ('\r','\n','\t')) and c not in ("'",'"'):
79                 return c
80             else:
81                 return "'$'\\x%02x''" % (ord(c),)
82         s = ''.join(map(escp,s))
83         return "'%s'" % (s,)
84
85 def eintr_retry(func):
86     """Retries a function invocation when a EINTR occurs"""
87     import functools
88     @functools.wraps(func)
89     def rv(*p, **kw):
90         retry = kw.pop("_retry", False)
91         for i in xrange(0 if retry else 4):
92             try:
93                 return func(*p, **kw)
94             except (select.error, socket.error), args:
95                 if args[0] == errno.EINTR:
96                     continue
97                 else:
98                     raise 
99             except OSError, e:
100                 if e.errno == errno.EINTR:
101                     continue
102                 else:
103                     raise
104         else:
105             return func(*p, **kw)
106     return rv
107
108 def make_connkey(user, host, port, x11, agent):
109     # It is important to consider the x11 and agent forwarding
110     # parameters when creating the connection key since the parameters
111     # used for the first ssh connection will determine the
112     # parameters of all subsequent connections using the same key
113     x11 = 1 if x11 else 0
114     agent = 1 if agent else 0
115
116     connkey = repr((user, host, port, x11, agent)
117             ).encode("base64").strip().replace('/','.')
118
119     if len(connkey) > 60:
120         connkey = hashlib.sha1(connkey).hexdigest()
121     return connkey
122
123 def make_control_path(user, host, port, x11, agent):
124     connkey = make_connkey(user, host, port, x11, agent)
125     return '/tmp/%s_%s' % ( CONTROL_PATH, connkey, )
126
127 def rexec(command, host, user, 
128         port = None, 
129         agent = True,
130         sudo = False,
131         stdin = None,
132         identity = None,
133         env = None,
134         tty = False,
135         x11 = False,
136         timeout = None,
137         retry = 0,
138         err_on_timeout = True,
139         connect_timeout = 30,
140         persistent = True):
141     """
142     Executes a remote command, returns ((stdout,stderr),process)
143     """
144     args = ['ssh', '-C',
145             # Don't bother with localhost. Makes test easier
146             '-o', 'NoHostAuthenticationForLocalhost=yes',
147             # XXX: Possible security issue
148             # Avoid interactive requests to accept new host keys
149             '-o', 'StrictHostKeyChecking=no',
150             '-o', 'ConnectTimeout=%d' % (int(connect_timeout),),
151             '-o', 'ConnectionAttempts=3',
152             '-o', 'ServerAliveInterval=30',
153             '-o', 'TCPKeepAlive=yes',
154             '-l', user, host]
155
156     if persistent and openssh_has_persist():
157         control_path = make_control_path(user, host, port, x11, agent)
158         args.extend([
159             '-o', 'ControlMaster=auto',
160             '-o', 'ControlPath=%s' % control_path,
161             '-o', 'ControlPersist=60' ])
162     if agent:
163         args.append('-A')
164     if port:
165         args.append('-p%d' % port)
166     if identity:
167         args.extend(('-i', identity))
168     if tty:
169         args.append('-t')
170         if sudo:
171             args.append('-t')
172     if x11:
173         args.append('-X')
174
175     if env:
176         export = ''
177         for envkey, envval in env.iteritems():
178             export += '%s=%s ' % (envkey, envval)
179         command = export + command
180
181     if sudo:
182         command = "sudo " + command
183
184     args.append(command)
185
186     for x in xrange(retry or 3):
187         # connects to the remote host and starts a remote connection
188         proc = subprocess.Popen(args, 
189                 stdout = subprocess.PIPE,
190                 stdin = subprocess.PIPE, 
191                 stderr = subprocess.PIPE)
192         
193         try:
194             out, err = _communicate(proc, stdin, timeout, err_on_timeout)
195             if proc.poll():
196                 if err.strip().startswith('ssh: ') or err.strip().startswith('mux_client_hello_exchange: '):
197                     # SSH error, can safely retry
198                     continue
199                 elif retry:
200                     # Probably timed out or plain failed but can retry
201                     continue
202             break
203         except RuntimeError,e:
204             if retry <= 0:
205                 raise
206             retry -= 1
207         
208     return ((out, err), proc)
209
210 def rcopy(source, dest,
211         port = None, 
212         agent = True, 
213         identity = None):
214     """
215     Copies file from/to remote sites.
216     
217     Source and destination should have the user and host encoded
218     as per scp specs.
219     
220     Source can be a list of files to copy to a single destination,
221     in which case it is advised that the destination be a folder.
222     """
223     
224     # Parse destination as <user>@<server>:<path>
225     if isinstance(dest, basestring) and ':' in dest:
226         remspec, path = dest.split(':',1)
227     elif isinstance(source, basestring) and ':' in source:
228         remspec, path = source.split(':',1)
229     else:
230         raise ValueError, "Both endpoints cannot be local"
231     user, host = remspec.rsplit('@',1)
232
233     raw_string = r'''rsync -rlpcSz --timeout=900 '''
234     raw_string += r''' -e 'ssh -o BatchMode=yes '''
235     raw_string += r''' -o NoHostAuthenticationForLocalhost=yes '''
236     raw_string += r''' -o StrictHostKeyChecking=no '''
237     raw_string += r''' -o ConnectionAttempts=3 '''
238  
239     if openssh_has_persist():
240         control_path = make_control_path(user, host, port, False, agent)
241         raw_string += r''' -o ControlMaster=auto '''
242         raw_string += r''' -o ControlPath=%s ''' % control_path
243  
244     if agent:
245         raw_string += r''' -A '''
246
247     if port:
248         raw_string += r''' -p %d ''' % port
249     
250     if identity:
251         raw_string += r''' -i "%s" ''' % identity
252     
253     # closing -e 'ssh...'
254     raw_string += r''' ' '''
255
256     if isinstance(source,list):
257         source = ' '.join(source)
258     else:
259         source = '"%s"' % source
260
261     raw_string += r''' %s ''' % source
262     raw_string += r''' %s ''' % dest
263
264     # connects to the remote host and starts a remote connection
265     proc = subprocess.Popen(raw_string,
266             shell=True,
267             stdout = subprocess.PIPE,
268             stdin = subprocess.PIPE, 
269             stderr = subprocess.PIPE)
270   
271     comm = proc.communicate()
272     eintr_retry(proc.wait)()
273     return (comm, proc)
274
275 def rspawn(command, pidfile, 
276         stdout = '/dev/null', 
277         stderr = STDOUT, 
278         stdin = '/dev/null', 
279         home = None, 
280         create_home = False, 
281         host = None, 
282         port = None, 
283         user = None, 
284         agent = None, 
285         sudo = False,
286         identity = None, 
287         tty = False):
288     """
289     Spawn a remote command such that it will continue working asynchronously.
290     
291     Parameters:
292         command: the command to run - it should be a single line.
293         
294         pidfile: path of a (ideally unique to this task) pidfile for tracking the process.
295         
296         stdout: path of a file to redirect standard output to - must be a string.
297             Defaults to /dev/null
298         stderr: path of a file to redirect standard error to - string or the special STDOUT value
299             to redirect to the same file stdout was redirected to. Defaults to STDOUT.
300         stdin: path of a file with input to be piped into the command's standard input
301         
302         home: path of a folder to use as working directory - should exist, unless you specify create_home
303         
304         create_home: if True, the home folder will be created first with mkdir -p
305         
306         sudo: whether the command needs to be executed as root
307         
308         host/port/user/agent/identity: see rexec
309     
310     Returns:
311         (stdout, stderr), process
312         
313         Of the spawning process, which only captures errors at spawning time.
314         Usually only useful for diagnostics.
315     """
316     # Start process in a "daemonized" way, using nohup and heavy
317     # stdin/out redirection to avoid connection issues
318     if stderr is STDOUT:
319         stderr = '&1'
320     else:
321         stderr = ' ' + stderr
322    
323     #XXX: ppid is always 1!!!
324     daemon_command = '{ { %(command)s  > %(stdout)s 2>%(stderr)s < %(stdin)s & } ; echo $! 1 > %(pidfile)s ; }' % {
325         'command' : command,
326         'pidfile' : pidfile,
327         
328         'stdout' : stdout,
329         'stderr' : stderr,
330         'stdin' : stdin,
331     }
332     
333     cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c '%(command)s' " % {
334             'command' : daemon_command,
335             
336             'sudo' : 'sudo -S' if sudo else '',
337             
338             'pidfile' : pidfile,
339             'gohome' : 'cd %s ; ' % home if home else '',
340             'create' : 'mkdir -p %s ; ' % home if create_home else '',
341         }
342
343     (out,err), proc = rexec(
344         cmd,
345         host = host,
346         port = port,
347         user = user,
348         agent = agent,
349         identity = identity,
350         tty = tty
351         )
352     
353     if proc.wait():
354         raise RuntimeError, "Failed to set up application: %s %s" % (out,err,)
355
356     return (out,err),proc
357
358 @eintr_retry
359 def rcheckpid(pidfile,
360         host = None, 
361         port = None, 
362         user = None, 
363         agent = None, 
364         identity = None):
365     """
366     Check the pidfile of a process spawned with remote_spawn.
367     
368     Parameters:
369         pidfile: the pidfile passed to remote_span
370         
371         host/port/user/agent/identity: see rexec
372     
373     Returns:
374         
375         A (pid, ppid) tuple useful for calling remote_status and remote_kill,
376         or None if the pidfile isn't valid yet (maybe the process is still starting).
377     """
378
379     (out,err),proc = rexec(
380         "cat %(pidfile)s" % {
381             'pidfile' : pidfile,
382         },
383         host = host,
384         port = port,
385         user = user,
386         agent = agent,
387         identity = identity
388         )
389         
390     if proc.wait():
391         return None
392     
393     if out:
394         try:
395             return map(int,out.strip().split(' ',1))
396         except:
397             # Ignore, many ways to fail that don't matter that much
398             return None
399
400 @eintr_retry
401 def rstatus(pid, ppid, 
402         host = None, 
403         port = None, 
404         user = None, 
405         agent = None, 
406         identity = None):
407     """
408     Check the status of a process spawned with remote_spawn.
409     
410     Parameters:
411         pid/ppid: pid and parent-pid of the spawned process. See remote_check_pid
412         
413         host/port/user/agent/identity: see rexec
414     
415     Returns:
416         
417         One of NOT_STARTED, RUNNING, FINISHED
418     """
419
420     # XXX: ppid unused
421     (out,err),proc = rexec(
422         "ps --pid %(pid)d -o pid | grep -c %(pid)d ; true" % {
423             'ppid' : ppid,
424             'pid' : pid,
425         },
426         host = host,
427         port = port,
428         user = user,
429         agent = agent,
430         identity = identity
431         )
432     
433     if proc.wait():
434         return NOT_STARTED
435     
436     status = False
437     if out:
438         try:
439             status = bool(int(out.strip()))
440         except:
441             if out or err:
442                 logging.warn("Error checking remote status:\n%s%s\n", out, err)
443             # Ignore, many ways to fail that don't matter that much
444             return NOT_STARTED
445     return RUNNING if status else FINISHED
446     
447
448 @eintr_retry
449 def rkill(pid, ppid,
450         host = None, 
451         port = None, 
452         user = None, 
453         agent = None, 
454         sudo = False,
455         identity = None, 
456         nowait = False):
457     """
458     Kill a process spawned with remote_spawn.
459     
460     First tries a SIGTERM, and if the process does not end in 10 seconds,
461     it sends a SIGKILL.
462     
463     Parameters:
464         pid/ppid: pid and parent-pid of the spawned process. See remote_check_pid
465         
466         sudo: whether the command was run with sudo - careful killing like this.
467         
468         host/port/user/agent/identity: see rexec
469     
470     Returns:
471         
472         Nothing, should have killed the process
473     """
474     
475     subkill = "$(ps --ppid %(pid)d -o pid h)" % { 'pid' : pid }
476     cmd = """
477 SUBKILL="%(subkill)s" ;
478 %(sudo)s kill -- -%(pid)d $SUBKILL || /bin/true
479 %(sudo)s kill %(pid)d $SUBKILL || /bin/true
480 for x in 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 ; do 
481     sleep 0.2 
482     if [ `ps --pid %(pid)d -o pid | grep -c %(pid)d` == '0' ]; then
483         break
484     else
485         %(sudo)s kill -- -%(pid)d $SUBKILL || /bin/true
486         %(sudo)s kill %(pid)d $SUBKILL || /bin/true
487     fi
488     sleep 1.8
489 done
490 if [ `ps --pid %(pid)d -o pid | grep -c %(pid)d` != '0' ]; then
491     %(sudo)s kill -9 -- -%(pid)d $SUBKILL || /bin/true
492     %(sudo)s kill -9 %(pid)d $SUBKILL || /bin/true
493 fi
494 """
495     if nowait:
496         cmd = "( %s ) >/dev/null 2>/dev/null </dev/null &" % (cmd,)
497
498     (out,err),proc = rexec(
499         cmd % {
500             'ppid' : ppid,
501             'pid' : pid,
502             'sudo' : 'sudo -S' if sudo else '',
503             'subkill' : subkill,
504         },
505         host = host,
506         port = port,
507         user = user,
508         agent = agent,
509         identity = identity
510         )
511     
512     # wait, don't leave zombies around
513     proc.wait()
514
515 # POSIX
516 def _communicate(self, input, timeout=None, err_on_timeout=True):
517     read_set = []
518     write_set = []
519     stdout = None # Return
520     stderr = None # Return
521     
522     killed = False
523     
524     if timeout is not None:
525         timelimit = time.time() + timeout
526         killtime = timelimit + 4
527         bailtime = timelimit + 4
528
529     if self.stdin:
530         # Flush stdio buffer.  This might block, if the user has
531         # been writing to .stdin in an uncontrolled fashion.
532         self.stdin.flush()
533         if input:
534             write_set.append(self.stdin)
535         else:
536             self.stdin.close()
537     if self.stdout:
538         read_set.append(self.stdout)
539         stdout = []
540     if self.stderr:
541         read_set.append(self.stderr)
542         stderr = []
543
544     input_offset = 0
545     while read_set or write_set:
546         if timeout is not None:
547             curtime = time.time()
548             if timeout is None or curtime > timelimit:
549                 if curtime > bailtime:
550                     break
551                 elif curtime > killtime:
552                     signum = signal.SIGKILL
553                 else:
554                     signum = signal.SIGTERM
555                 # Lets kill it
556                 os.kill(self.pid, signum)
557                 select_timeout = 0.5
558             else:
559                 select_timeout = timelimit - curtime + 0.1
560         else:
561             select_timeout = 1.0
562         
563         if select_timeout > 1.0:
564             select_timeout = 1.0
565             
566         try:
567             rlist, wlist, xlist = select.select(read_set, write_set, [], select_timeout)
568         except select.error,e:
569             if e[0] != 4:
570                 raise
571             else:
572                 continue
573         
574         if not rlist and not wlist and not xlist and self.poll() is not None:
575             # timeout and process exited, say bye
576             break
577
578         if self.stdin in wlist:
579             # When select has indicated that the file is writable,
580             # we can write up to PIPE_BUF bytes without risk
581             # blocking.  POSIX defines PIPE_BUF >= 512
582             bytes_written = os.write(self.stdin.fileno(), buffer(input, input_offset, 512))
583             input_offset += bytes_written
584             if input_offset >= len(input):
585                 self.stdin.close()
586                 write_set.remove(self.stdin)
587
588         if self.stdout in rlist:
589             data = os.read(self.stdout.fileno(), 1024)
590             if data == "":
591                 self.stdout.close()
592                 read_set.remove(self.stdout)
593             stdout.append(data)
594
595         if self.stderr in rlist:
596             data = os.read(self.stderr.fileno(), 1024)
597             if data == "":
598                 self.stderr.close()
599                 read_set.remove(self.stderr)
600             stderr.append(data)
601     
602     # All data exchanged.  Translate lists into strings.
603     if stdout is not None:
604         stdout = ''.join(stdout)
605     if stderr is not None:
606         stderr = ''.join(stderr)
607
608     # Translate newlines, if requested.  We cannot let the file
609     # object do the translation: It is based on stdio, which is
610     # impossible to combine with select (unless forcing no
611     # buffering).
612     if self.universal_newlines and hasattr(file, 'newlines'):
613         if stdout:
614             stdout = self._translate_newlines(stdout)
615         if stderr:
616             stderr = self._translate_newlines(stderr)
617
618     if killed and err_on_timeout:
619         errcode = self.poll()
620         raise RuntimeError, ("Operation timed out", errcode, stdout, stderr)
621     else:
622         if killed:
623             self.poll()
624         else:
625             self.wait()
626         return (stdout, stderr)
627