Adding X11 forwarding tests for resources/linux/node.py
[nepi.git] / src / neco / util / sshfuncs.py
1 import base64
2 import errno
3 import os
4 import os.path
5 import select
6 import signal
7 import socket
8 import subprocess
9 import time
10 import traceback
11 import re
12 import tempfile
13 import hashlib
14
15 OPENSSH_HAS_PERSIST = None
16 CONTROL_PATH = "yyy_ssh_ctrl_path"
17
18 if hasattr(os, "devnull"):
19     DEV_NULL = os.devnull
20 else:
21     DEV_NULL = "/dev/null"
22
23 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
24
25 hostbyname_cache = dict()
26
27 class STDOUT: 
28     """
29     Special value that when given to rspawn in stderr causes stderr to 
30     redirect to whatever stdout was redirected to.
31     """
32
33 class RUNNING:
34     """
35     Process is still running
36     """
37
38 class FINISHED:
39     """
40     Process is finished
41     """
42
43 class NOT_STARTED:
44     """
45     Process hasn't started running yet (this should be very rare)
46     """
47
48 def openssh_has_persist():
49     """ The ssh_config options ControlMaster and ControlPersist allow to
50     reuse a same network connection for multiple ssh sessions. In this 
51     way limitations on number of open ssh connections can be bypassed.
52     However, older versions of openSSH do not support this feature.
53     This function is used to determine if ssh connection persist features
54     can be used.
55     """
56     global OPENSSH_HAS_PERSIST
57     if OPENSSH_HAS_PERSIST is None:
58         proc = subprocess.Popen(["ssh","-v"],
59             stdout = subprocess.PIPE,
60             stderr = subprocess.STDOUT,
61             stdin = open("/dev/null","r") )
62         out,err = proc.communicate()
63         proc.wait()
64         
65         vre = re.compile(r'OpenSSH_(?:[6-9]|5[.][8-9]|5[.][1-9][0-9]|[1-9][0-9]).*', re.I)
66         OPENSSH_HAS_PERSIST = bool(vre.match(out))
67     return OPENSSH_HAS_PERSIST
68
69 def shell_escape(s):
70     """ Escapes strings so that they are safe to use as command-line 
71     arguments """
72     if SHELL_SAFE.match(s):
73         # safe string - no escaping needed
74         return s
75     else:
76         # unsafe string - escape
77         def escp(c):
78             if (32 <= ord(c) < 127 or c in ('\r','\n','\t')) and c not in ("'",'"'):
79                 return c
80             else:
81                 return "'$'\\x%02x''" % (ord(c),)
82         s = ''.join(map(escp,s))
83         return "'%s'" % (s,)
84
85 def eintr_retry(func):
86     """Retries a function invocation when a EINTR occurs"""
87     import functools
88     @functools.wraps(func)
89     def rv(*p, **kw):
90         retry = kw.pop("_retry", False)
91         for i in xrange(0 if retry else 4):
92             try:
93                 return func(*p, **kw)
94             except (select.error, socket.error), args:
95                 if args[0] == errno.EINTR:
96                     continue
97                 else:
98                     raise 
99             except OSError, e:
100                 if e.errno == errno.EINTR:
101                     continue
102                 else:
103                     raise
104         else:
105             return func(*p, **kw)
106     return rv
107
108 def make_connkey(user, host, port):
109     connkey = repr((user,host,port)).encode("base64").strip().replace('/','.')
110     if len(connkey) > 60:
111         connkey = hashlib.sha1(connkey).hexdigest()
112     return connkey
113
114 def make_control_path(user, host, port):
115     connkey = make_connkey(user, host, port)
116     return '/tmp/%s_%s' % ( CONTROL_PATH, connkey, )
117
118 def rexec(command, host, user, 
119         port = None, 
120         agent = True,
121         sudo = False,
122         stdin = "",
123         identity_file = None,
124         env = None,
125         tty = False,
126         x11 = False,
127         timeout = None,
128         retry = 0,
129         err_on_timeout = True,
130         connect_timeout = 30,
131         persistent = True):
132     """
133     Executes a remote command, returns ((stdout,stderr),process)
134     """
135     args = ['ssh', '-C',
136             # Don't bother with localhost. Makes test easier
137             '-o', 'NoHostAuthenticationForLocalhost=yes',
138             # XXX: Possible security issue
139             # Avoid interactive requests to accept new host keys
140             '-o', 'StrictHostKeyChecking=no',
141             '-o', 'ConnectTimeout=%d' % (int(connect_timeout),),
142             '-o', 'ConnectionAttempts=3',
143             '-o', 'ServerAliveInterval=30',
144             '-o', 'TCPKeepAlive=yes',
145             '-l', user, host]
146
147     if persistent and openssh_has_persist():
148         control_path = make_control_path(user, host, port)
149         args.extend([
150             '-o', 'ControlMaster=auto',
151             '-o', 'ControlPath=%s' % control_path,
152             '-o', 'ControlPersist=60' ])
153     if agent:
154         args.append('-A')
155     if port:
156         args.append('-p%d' % port)
157     if identity_file:
158         args.extend(('-i', identity_file))
159     if tty:
160         args.append('-t')
161         if sudo:
162             args.append('-t')
163     if x11:
164         args.append('-X')
165
166     if env:
167         export = ''
168         for envkey, envval in env.iteritems():
169             export += '%s=%s ' % (envkey, envval)
170         command = export + command
171
172     if sudo:
173         command = "sudo " + command
174
175     args.append(command)
176
177     for x in xrange(retry or 3):
178         # connects to the remote host and starts a remote connection
179         proc = subprocess.Popen(args, 
180                 stdout = subprocess.PIPE,
181                 stdin = subprocess.PIPE, 
182                 stderr = subprocess.PIPE)
183         
184         try:
185             out, err = _communicate(proc, stdin, timeout, err_on_timeout)
186             if proc.poll():
187                 if err.strip().startswith('ssh: ') or err.strip().startswith('mux_client_hello_exchange: '):
188                     # SSH error, can safely retry
189                     continue
190                 elif retry:
191                     # Probably timed out or plain failed but can retry
192                     continue
193             break
194         except RuntimeError,e:
195             if retry <= 0:
196                 raise
197             retry -= 1
198         
199     return ((out, err), proc)
200
201 def rcopy(source, dest,
202         port = None, 
203         agent = True, 
204         recursive = False,
205         identity_file = None):
206     """
207     Copies file from/to remote sites.
208     
209     Source and destination should have the user and host encoded
210     as per scp specs.
211     
212     If source is a file object, a special mode will be used to
213     create the remote file with the same contents.
214     
215     If dest is a file object, the remote file (source) will be
216     read and written into dest.
217     
218     In these modes, recursive cannot be True.
219     
220     Source can be a list of files to copy to a single destination,
221     in which case it is advised that the destination be a folder.
222     """
223     
224     if isinstance(source, file) and source.tell() == 0:
225         source = source.name
226     elif hasattr(source, 'read'):
227         tmp = tempfile.NamedTemporaryFile()
228         while True:
229             buf = source.read(65536)
230             if buf:
231                 tmp.write(buf)
232             else:
233                 break
234         tmp.seek(0)
235         source = tmp.name
236     
237     if isinstance(source, file) or isinstance(dest, file) \
238             or hasattr(source, 'read')  or hasattr(dest, 'write'):
239         assert not recursive
240     
241         # Parse source/destination as <user>@<server>:<path>
242         if isinstance(dest, basestring) and ':' in dest:
243             remspec, path = dest.split(':',1)
244         elif isinstance(source, basestring) and ':' in source:
245             remspec, path = source.split(':',1)
246         else:
247             raise ValueError, "Both endpoints cannot be local"
248         user,host = remspec.rsplit('@',1)
249         tmp_known_hosts = None
250         
251         args = ['ssh', '-l', user, '-C',
252                 # Don't bother with localhost. Makes test easier
253                 '-o', 'NoHostAuthenticationForLocalhost=yes',
254                 # XXX: Possible security issue
255                 # Avoid interactive requests to accept new host keys
256                 '-o', 'StrictHostKeyChecking=no',
257                 '-o', 'ConnectTimeout=30',
258                 '-o', 'ConnectionAttempts=3',
259                 '-o', 'ServerAliveInterval=30',
260                 '-o', 'TCPKeepAlive=yes',
261                 host ]
262
263         if openssh_has_persist():
264             control_path = make_control_path(user, host, port)
265             args.extend([
266                 '-o', 'ControlMaster=auto',
267                 '-o', 'ControlPath=%s' % control_path,
268                 '-o', 'ControlPersist=60' ])
269         if port:
270             args.append('-P%d' % port)
271         if identity_file:
272             args.extend(('-i', identity_file))
273         
274         if isinstance(source, file) or hasattr(source, 'read'):
275             args.append('cat > %s' % (shell_escape(path),))
276         elif isinstance(dest, file) or hasattr(dest, 'write'):
277             args.append('cat %s' % (shell_escape(path),))
278         else:
279             raise AssertionError, "Unreachable code reached! :-Q"
280
281         # connects to the remote host and starts a remote connection
282         if isinstance(source, file):
283             proc = subprocess.Popen(args, 
284                     stdout = open('/dev/null','w'),
285                     stderr = subprocess.PIPE,
286                     stdin = source)
287             err = proc.stderr.read()
288             eintr_retry(proc.wait)()
289             return ((None,err), proc)
290         elif isinstance(dest, file):
291             proc = subprocess.Popen(args, 
292                     stdout = open('/dev/null','w'),
293                     stderr = subprocess.PIPE,
294                     stdin = source)
295             err = proc.stderr.read()
296             eintr_retry(proc.wait)()
297             return ((None,err), proc)
298         elif hasattr(source, 'read'):
299             # file-like (but not file) source
300             proc = subprocess.Popen(args, 
301                     stdout = open('/dev/null','w'),
302                     stderr = subprocess.PIPE,
303                     stdin = subprocess.PIPE)
304             
305             buf = None
306             err = []
307             while True:
308                 if not buf:
309                     buf = source.read(4096)
310                 if not buf:
311                     #EOF
312                     break
313                 
314                 rdrdy, wrdy, broken = select.select(
315                     [proc.stderr],
316                     [proc.stdin],
317                     [proc.stderr,proc.stdin])
318                 
319                 if proc.stderr in rdrdy:
320                     # use os.read for fully unbuffered behavior
321                     err.append(os.read(proc.stderr.fileno(), 4096))
322                 
323                 if proc.stdin in wrdy:
324                     proc.stdin.write(buf)
325                     buf = None
326                 
327                 if broken:
328                     break
329             proc.stdin.close()
330             err.append(proc.stderr.read())
331                 
332             eintr_retry(proc.wait)()
333             return ((None,''.join(err)), proc)
334         elif hasattr(dest, 'write'):
335             # file-like (but not file) dest
336             proc = subprocess.Popen(args, 
337                     stdout = subprocess.PIPE,
338                     stderr = subprocess.PIPE,
339                     stdin = open('/dev/null','w'))
340             
341             buf = None
342             err = []
343             while True:
344                 rdrdy, wrdy, broken = select.select(
345                     [proc.stderr, proc.stdout],
346                     [],
347                     [proc.stderr, proc.stdout])
348                 
349                 if proc.stderr in rdrdy:
350                     # use os.read for fully unbuffered behavior
351                     err.append(os.read(proc.stderr.fileno(), 4096))
352                 
353                 if proc.stdout in rdrdy:
354                     # use os.read for fully unbuffered behavior
355                     buf = os.read(proc.stdout.fileno(), 4096)
356                     dest.write(buf)
357                     
358                     if not buf:
359                         #EOF
360                         break
361                 
362                 if broken:
363                     break
364             err.append(proc.stderr.read())
365                 
366             eintr_retry(proc.wait)()
367             return ((None,''.join(err)), proc)
368         else:
369             raise AssertionError, "Unreachable code reached! :-Q"
370     else:
371         # Parse destination as <user>@<server>:<path>
372         if isinstance(dest, basestring) and ':' in dest:
373             remspec, path = dest.split(':',1)
374         elif isinstance(source, basestring) and ':' in source:
375             remspec, path = source.split(':',1)
376         else:
377             raise ValueError, "Both endpoints cannot be local"
378         user,host = remspec.rsplit('@',1)
379
380         # plain scp
381         args = ['scp', '-q', '-p', '-C',
382                 # Don't bother with localhost. Makes test easier
383                 '-o', 'NoHostAuthenticationForLocalhost=yes',
384                 # XXX: Possible security issue
385                 # Avoid interactive requests to accept new host keys
386                 '-o', 'StrictHostKeyChecking=no',
387                 '-o', 'ConnectTimeout=30',
388                 '-o', 'ConnectionAttempts=3',
389                 '-o', 'ServerAliveInterval=30',
390                 '-o', 'TCPKeepAlive=yes' ]
391                 
392         if port:
393             args.append('-P%d' % port)
394         if recursive:
395             args.append('-r')
396         if identity_file:
397             args.extend(('-i', identity_file))
398
399         if isinstance(source,list):
400             args.extend(source)
401         else:
402             if openssh_has_persist():
403                 control_path = make_control_path(user, host, port)
404                 args.extend([
405                     '-o', 'ControlMaster=no',
406                     '-o', 'ControlPath=%s' % control_path ])
407             args.append(source)
408
409         args.append(dest)
410
411         # connects to the remote host and starts a remote connection
412         proc = subprocess.Popen(args, 
413                 stdout = subprocess.PIPE,
414                 stdin = subprocess.PIPE, 
415                 stderr = subprocess.PIPE)
416         
417         comm = proc.communicate()
418         eintr_retry(proc.wait)()
419         return (comm, proc)
420
421 def rspawn(command, pidfile, 
422         stdout = '/dev/null', 
423         stderr = STDOUT, 
424         stdin = '/dev/null', 
425         home = None, 
426         create_home = False, 
427         host = None, 
428         port = None, 
429         user = None, 
430         agent = None, 
431         sudo = False,
432         identity_file = None, 
433         tty = False):
434     """
435     Spawn a remote command such that it will continue working asynchronously.
436     
437     Parameters:
438         command: the command to run - it should be a single line.
439         
440         pidfile: path of a (ideally unique to this task) pidfile for tracking the process.
441         
442         stdout: path of a file to redirect standard output to - must be a string.
443             Defaults to /dev/null
444         stderr: path of a file to redirect standard error to - string or the special STDOUT value
445             to redirect to the same file stdout was redirected to. Defaults to STDOUT.
446         stdin: path of a file with input to be piped into the command's standard input
447         
448         home: path of a folder to use as working directory - should exist, unless you specify create_home
449         
450         create_home: if True, the home folder will be created first with mkdir -p
451         
452         sudo: whether the command needs to be executed as root
453         
454         host/port/user/agent/identity_file: see rexec
455     
456     Returns:
457         (stdout, stderr), process
458         
459         Of the spawning process, which only captures errors at spawning time.
460         Usually only useful for diagnostics.
461     """
462     # Start process in a "daemonized" way, using nohup and heavy
463     # stdin/out redirection to avoid connection issues
464     if stderr is STDOUT:
465         stderr = '&1'
466     else:
467         stderr = ' ' + stderr
468     
469     daemon_command = '{ { %(command)s  > %(stdout)s 2>%(stderr)s < %(stdin)s & } ; echo $! 1 > %(pidfile)s ; }' % {
470         'command' : command,
471         'pidfile' : pidfile,
472         
473         'stdout' : stdout,
474         'stderr' : stderr,
475         'stdin' : stdin,
476     }
477     
478     cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c '%(command)s' " % {
479             'command' : daemon_command,
480             
481             'sudo' : 'sudo -S' if sudo else '',
482             
483             'pidfile' : pidfile,
484             'gohome' : 'cd %s ; ' % home if home else '',
485             'create' : 'mkdir -p %s ; ' % home if create_home else '',
486         }
487
488     (out,err), proc = rexec(
489         cmd,
490         host = host,
491         port = port,
492         user = user,
493         agent = agent,
494         identity_file = identity_file,
495         tty = tty
496         )
497     
498     if proc.wait():
499         raise RuntimeError, "Failed to set up application: %s %s" % (out,err,)
500
501     return (out,err),proc
502
503 @eintr_retry
504 def rcheck_pid(pidfile,
505         host = None, 
506         port = None, 
507         user = None, 
508         agent = None, 
509         identity_file = None):
510     """
511     Check the pidfile of a process spawned with remote_spawn.
512     
513     Parameters:
514         pidfile: the pidfile passed to remote_span
515         
516         host/port/user/agent/identity_file: see rexec
517     
518     Returns:
519         
520         A (pid, ppid) tuple useful for calling remote_status and remote_kill,
521         or None if the pidfile isn't valid yet (maybe the process is still starting).
522     """
523
524     (out,err),proc = rexec(
525         "cat %(pidfile)s" % {
526             'pidfile' : pidfile,
527         },
528         host = host,
529         port = port,
530         user = user,
531         agent = agent,
532         identity_file = identity_file
533         )
534         
535     if proc.wait():
536         return None
537     
538     if out:
539         try:
540             return map(int,out.strip().split(' ',1))
541         except:
542             # Ignore, many ways to fail that don't matter that much
543             return None
544
545 @eintr_retry
546 def rstatus(pid, ppid, 
547         host = None, 
548         port = None, 
549         user = None, 
550         agent = None, 
551         identity_file = None):
552     """
553     Check the status of a process spawned with remote_spawn.
554     
555     Parameters:
556         pid/ppid: pid and parent-pid of the spawned process. See remote_check_pid
557         
558         host/port/user/agent/identity_file: see rexec
559     
560     Returns:
561         
562         One of NOT_STARTED, RUNNING, FINISHED
563     """
564
565     (out,err),proc = rexec(
566         "ps --pid %(pid)d -o pid | grep -c %(pid)d ; true" % {
567             'ppid' : ppid,
568             'pid' : pid,
569         },
570         host = host,
571         port = port,
572         user = user,
573         agent = agent,
574         identity_file = identity_file
575         )
576     
577     if proc.wait():
578         return NOT_STARTED
579     
580     status = False
581     if out:
582         try:
583             status = bool(int(out.strip()))
584         except:
585             if out or err:
586                 logging.warn("Error checking remote status:\n%s%s\n", out, err)
587             # Ignore, many ways to fail that don't matter that much
588             return NOT_STARTED
589     return RUNNING if status else FINISHED
590     
591
592 @eintr_retry
593 def rkill(pid, ppid, 
594         host = None, 
595         port = None, 
596         user = None, 
597         agent = None, 
598         sudo = False,
599         identity_file = None, 
600         nowait = False):
601     """
602     Kill a process spawned with remote_spawn.
603     
604     First tries a SIGTERM, and if the process does not end in 10 seconds,
605     it sends a SIGKILL.
606     
607     Parameters:
608         pid/ppid: pid and parent-pid of the spawned process. See remote_check_pid
609         
610         sudo: whether the command was run with sudo - careful killing like this.
611         
612         host/port/user/agent/identity_file: see rexec
613     
614     Returns:
615         
616         Nothing, should have killed the process
617     """
618     
619     subkill = "$(ps --ppid %(pid)d -o pid h)" % { 'pid' : pid }
620     cmd = """
621 SUBKILL="%(subkill)s" ;
622 %(sudo)s kill -- -%(pid)d $SUBKILL || /bin/true
623 %(sudo)s kill %(pid)d $SUBKILL || /bin/true
624 for x in 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 ; do 
625     sleep 0.2 
626     if [ `ps --pid %(pid)d -o pid | grep -c %(pid)d` == '0' ]; then
627         break
628     else
629         %(sudo)s kill -- -%(pid)d $SUBKILL || /bin/true
630         %(sudo)s kill %(pid)d $SUBKILL || /bin/true
631     fi
632     sleep 1.8
633 done
634 if [ `ps --pid %(pid)d -o pid | grep -c %(pid)d` != '0' ]; then
635     %(sudo)s kill -9 -- -%(pid)d $SUBKILL || /bin/true
636     %(sudo)s kill -9 %(pid)d $SUBKILL || /bin/true
637 fi
638 """
639     if nowait:
640         cmd = "( %s ) >/dev/null 2>/dev/null </dev/null &" % (cmd,)
641
642     (out,err),proc = rexec(
643         cmd % {
644             'ppid' : ppid,
645             'pid' : pid,
646             'sudo' : 'sudo -S' if sudo else '',
647             'subkill' : subkill,
648         },
649         host = host,
650         port = port,
651         user = user,
652         agent = agent,
653         identity_file = identity_file
654         )
655     
656     # wait, don't leave zombies around
657     proc.wait()
658
659 # POSIX
660 def _communicate(self, input, timeout=None, err_on_timeout=True):
661     read_set = []
662     write_set = []
663     stdout = None # Return
664     stderr = None # Return
665     
666     killed = False
667     
668     if timeout is not None:
669         timelimit = time.time() + timeout
670         killtime = timelimit + 4
671         bailtime = timelimit + 4
672
673     if self.stdin:
674         # Flush stdio buffer.  This might block, if the user has
675         # been writing to .stdin in an uncontrolled fashion.
676         self.stdin.flush()
677         if input:
678             write_set.append(self.stdin)
679         else:
680             self.stdin.close()
681     if self.stdout:
682         read_set.append(self.stdout)
683         stdout = []
684     if self.stderr:
685         read_set.append(self.stderr)
686         stderr = []
687
688     input_offset = 0
689     while read_set or write_set:
690         if timeout is not None:
691             curtime = time.time()
692             if timeout is None or curtime > timelimit:
693                 if curtime > bailtime:
694                     break
695                 elif curtime > killtime:
696                     signum = signal.SIGKILL
697                 else:
698                     signum = signal.SIGTERM
699                 # Lets kill it
700                 os.kill(self.pid, signum)
701                 select_timeout = 0.5
702             else:
703                 select_timeout = timelimit - curtime + 0.1
704         else:
705             select_timeout = 1.0
706         
707         if select_timeout > 1.0:
708             select_timeout = 1.0
709             
710         try:
711             rlist, wlist, xlist = select.select(read_set, write_set, [], select_timeout)
712         except select.error,e:
713             if e[0] != 4:
714                 raise
715             else:
716                 continue
717         
718         if not rlist and not wlist and not xlist and self.poll() is not None:
719             # timeout and process exited, say bye
720             break
721
722         if self.stdin in wlist:
723             # When select has indicated that the file is writable,
724             # we can write up to PIPE_BUF bytes without risk
725             # blocking.  POSIX defines PIPE_BUF >= 512
726             bytes_written = os.write(self.stdin.fileno(), buffer(input, input_offset, 512))
727             input_offset += bytes_written
728             if input_offset >= len(input):
729                 self.stdin.close()
730                 write_set.remove(self.stdin)
731
732         if self.stdout in rlist:
733             data = os.read(self.stdout.fileno(), 1024)
734             if data == "":
735                 self.stdout.close()
736                 read_set.remove(self.stdout)
737             stdout.append(data)
738
739         if self.stderr in rlist:
740             data = os.read(self.stderr.fileno(), 1024)
741             if data == "":
742                 self.stderr.close()
743                 read_set.remove(self.stderr)
744             stderr.append(data)
745     
746     # All data exchanged.  Translate lists into strings.
747     if stdout is not None:
748         stdout = ''.join(stdout)
749     if stderr is not None:
750         stderr = ''.join(stderr)
751
752     # Translate newlines, if requested.  We cannot let the file
753     # object do the translation: It is based on stdio, which is
754     # impossible to combine with select (unless forcing no
755     # buffering).
756     if self.universal_newlines and hasattr(file, 'newlines'):
757         if stdout:
758             stdout = self._translate_newlines(stdout)
759         if stderr:
760             stderr = self._translate_newlines(stderr)
761
762     if killed and err_on_timeout:
763         errcode = self.poll()
764         raise RuntimeError, ("Operation timed out", errcode, stdout, stderr)
765     else:
766         if killed:
767             self.poll()
768         else:
769             self.wait()
770         return (stdout, stderr)
771