Simplifying code of sshfuncs.py. Removing not used functionality
[nepi.git] / src / nepi / util / sshfuncs.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License as published by
7 #    the Free Software Foundation, either version 3 of the License, or
8 #    (at your option) any later version.
9 #
10 #    This program is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #    GNU General Public License for more details.
14 #
15 #    You should have received a copy of the GNU General Public License
16 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
19 #         Claudio Freire <claudio-daniel.freire@inria.fr>
20
21 ## TODO: This code needs reviewing !!!
22
23 import base64
24 import errno
25 import hashlib
26 import logging
27 import os
28 import os.path
29 import re
30 import select
31 import signal
32 import socket
33 import subprocess
34 import threading
35 import time
36 import tempfile
37
38 logger = logging.getLogger("sshfuncs")
39
40 def log(msg, level, out = None, err = None):
41     if out:
42         msg += " - OUT: %s " % out
43
44     if err:
45         msg += " - ERROR: %s " % err
46
47     logger.log(level, msg)
48
49
50 if hasattr(os, "devnull"):
51     DEV_NULL = os.devnull
52 else:
53     DEV_NULL = "/dev/null"
54
55 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
56
57 class STDOUT: 
58     """
59     Special value that when given to rspawn in stderr causes stderr to 
60     redirect to whatever stdout was redirected to.
61     """
62
63 class ProcStatus:
64     """
65     Codes for status of remote spawned process
66     """
67     # Process is still running
68     RUNNING = 1
69
70     # Process is finished
71     FINISHED = 2
72     
73     # Process hasn't started running yet (this should be very rare)
74     NOT_STARTED = 3
75
76 hostbyname_cache = dict()
77 hostbyname_cache_lock = threading.Lock()
78
79 def gethostbyname(host):
80     global hostbyname_cache
81     global hostbyname_cache_lock
82     
83     hostbyname = hostbyname_cache.get(host)
84     if not hostbyname:
85         with hostbyname_cache_lock:
86             hostbyname = socket.gethostbyname(host)
87             hostbyname_cache[host] = hostbyname
88
89             msg = " Added hostbyname %s - %s " % (host, hostbyname)
90             log(msg, logging.DEBUG)
91
92     return hostbyname
93
94 OPENSSH_HAS_PERSIST = None
95
96 def openssh_has_persist():
97     """ The ssh_config options ControlMaster and ControlPersist allow to
98     reuse a same network connection for multiple ssh sessions. In this 
99     way limitations on number of open ssh connections can be bypassed.
100     However, older versions of openSSH do not support this feature.
101     This function is used to determine if ssh connection persist features
102     can be used.
103     """
104     global OPENSSH_HAS_PERSIST
105     if OPENSSH_HAS_PERSIST is None:
106         proc = subprocess.Popen(["ssh","-v"],
107             stdout = subprocess.PIPE,
108             stderr = subprocess.STDOUT,
109             stdin = open("/dev/null","r") )
110         out,err = proc.communicate()
111         proc.wait()
112         
113         vre = re.compile(r'OpenSSH_(?:[6-9]|5[.][8-9]|5[.][1-9][0-9]|[1-9][0-9]).*', re.I)
114         OPENSSH_HAS_PERSIST = bool(vre.match(out))
115     return OPENSSH_HAS_PERSIST
116
117 def make_server_key_args(server_key, host, port):
118     """ Returns a reference to a temporary known_hosts file, to which 
119     the server key has been added. 
120     
121     Make sure to hold onto the temp file reference until the process is 
122     done with it
123
124     :param server_key: the server public key
125     :type server_key: str
126
127     :param host: the hostname
128     :type host: str
129
130     :param port: the ssh port
131     :type port: str
132
133     """
134     if port is not None:
135         host = '%s:%s' % (host, str(port))
136
137     # Create a temporary server key file
138     tmp_known_hosts = tempfile.NamedTemporaryFile()
139    
140     hostbyname = gethostbyname(host) 
141
142     # Add the intended host key
143     tmp_known_hosts.write('%s,%s %s\n' % (host, hostbyname, server_key))
144     
145     # If we're not in strict mode, add user-configured keys
146     if os.environ.get('NEPI_STRICT_AUTH_MODE',"").lower() not in ('1','true','on'):
147         user_hosts_path = '%s/.ssh/known_hosts' % (os.environ.get('HOME',""),)
148         if os.access(user_hosts_path, os.R_OK):
149             f = open(user_hosts_path, "r")
150             tmp_known_hosts.write(f.read())
151             f.close()
152         
153     tmp_known_hosts.flush()
154     
155     return tmp_known_hosts
156
157 def make_control_path(agent, forward_x11):
158     ctrl_path = "/tmp/nepi_ssh"
159
160     if agent:
161         ctrl_path +="_a"
162
163     if forward_x11:
164         ctrl_path +="_x"
165
166     ctrl_path += "-%r@%h:%p"
167
168     return ctrl_path
169
170 def shell_escape(s):
171     """ Escapes strings so that they are safe to use as command-line 
172     arguments """
173     if SHELL_SAFE.match(s):
174         # safe string - no escaping needed
175         return s
176     else:
177         # unsafe string - escape
178         def escp(c):
179             if (32 <= ord(c) < 127 or c in ('\r','\n','\t')) and c not in ("'",'"'):
180                 return c
181             else:
182                 return "'$'\\x%02x''" % (ord(c),)
183         s = ''.join(map(escp,s))
184         return "'%s'" % (s,)
185
186 def eintr_retry(func):
187     """Retries a function invocation when a EINTR occurs"""
188     import functools
189     @functools.wraps(func)
190     def rv(*p, **kw):
191         retry = kw.pop("_retry", False)
192         for i in xrange(0 if retry else 4):
193             try:
194                 return func(*p, **kw)
195             except (select.error, socket.error), args:
196                 if args[0] == errno.EINTR:
197                     continue
198                 else:
199                     raise 
200             except OSError, e:
201                 if e.errno == errno.EINTR:
202                     continue
203                 else:
204                     raise
205         else:
206             return func(*p, **kw)
207     return rv
208
209 def rexec(command, host, user, 
210         port = None, 
211         agent = True,
212         sudo = False,
213         stdin = None,
214         identity = None,
215         server_key = None,
216         env = None,
217         tty = False,
218         connect_timeout = 30,
219         retry = 3,
220         persistent = True,
221         forward_x11 = False,
222         blocking = True,
223         strict_host_checking = True):
224     """
225     Executes a remote command, returns ((stdout,stderr),process)
226     """
227     
228     tmp_known_hosts = None
229     hostip = gethostbyname(host)
230
231     args = ['ssh', '-C',
232             # Don't bother with localhost. Makes test easier
233             '-o', 'NoHostAuthenticationForLocalhost=yes',
234             '-o', 'ConnectTimeout=%d' % (int(connect_timeout),),
235             '-o', 'ConnectionAttempts=3',
236             '-o', 'ServerAliveInterval=30',
237             '-o', 'TCPKeepAlive=yes',
238             '-l', user, hostip or host]
239
240     if persistent and openssh_has_persist():
241         args.extend([
242             '-o', 'ControlMaster=auto',
243             '-o', 'ControlPath=%s' % (make_control_path(agent, forward_x11),),
244             '-o', 'ControlPersist=60' ])
245
246     if not strict_host_checking:
247         # Do not check for Host key. Unsafe.
248         args.extend(['-o', 'StrictHostKeyChecking=no'])
249
250     if agent:
251         args.append('-A')
252
253     if port:
254         args.append('-p%d' % port)
255
256     if identity:
257         args.extend(('-i', identity))
258
259     if tty:
260         args.append('-t')
261         args.append('-t')
262
263     if forward_x11:
264         args.append('-X')
265
266     if server_key:
267         # Create a temporary server key file
268         tmp_known_hosts = make_server_key_args(server_key, host, port)
269         args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
270
271     if sudo:
272         command = "sudo " + command
273
274     args.append(command)
275     
276     log_msg = " rexec - host %s - command %s " % (host, " ".join(args))
277
278     return _retry_rexec(args, log_msg, env = env, retry = retry, 
279             tmp_known_hosts = tmp_known_hosts,
280             blocking = blocking)
281
282 def rcopy(source, dest,
283         port = None, 
284         agent = True, 
285         recursive = False,
286         identity = None,
287         server_key = None,
288         retry = 3,
289         strict_host_checking = True):
290     """
291     Copies from/to remote sites.
292     
293     Source and destination should have the user and host encoded
294     as per scp specs.
295     
296     Source can be a list of files to copy to a single destination,
297     in which case it is advised that the destination be a folder.
298     """
299     
300     # Parse destination as <user>@<server>:<path>
301     if isinstance(dest, basestring) and ':' in dest:
302         remspec, path = dest.split(':',1)
303     elif isinstance(source, basestring) and ':' in source:
304         remspec, path = source.split(':',1)
305     else:
306         raise ValueError, "Both endpoints cannot be local"
307     user,host = remspec.rsplit('@',1)
308     
309     # plain scp
310     tmp_known_hosts = None
311
312     args = ['scp', '-q', '-p', '-C',
313             # Speed up transfer using blowfish cypher specification which is 
314             # faster than the default one (3des)
315             '-c', 'blowfish',
316             # Don't bother with localhost. Makes test easier
317             '-o', 'NoHostAuthenticationForLocalhost=yes',
318             '-o', 'ConnectTimeout=60',
319             '-o', 'ConnectionAttempts=3',
320             '-o', 'ServerAliveInterval=30',
321             '-o', 'TCPKeepAlive=yes' ]
322             
323     if port:
324         args.append('-P%d' % port)
325
326     if recursive:
327         args.append('-r')
328
329     if identity:
330         args.extend(('-i', identity))
331
332     if server_key:
333         # Create a temporary server key file
334         tmp_known_hosts = make_server_key_args(server_key, host, port)
335         args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
336
337     if not strict_host_checking:
338         # Do not check for Host key. Unsafe.
339         args.extend(['-o', 'StrictHostKeyChecking=no'])
340
341     if isinstance(source,list):
342         args.extend(source)
343     else:
344         if openssh_has_persist():
345             args.extend([
346                 '-o', 'ControlMaster=auto',
347                 '-o', 'ControlPath=%s' % (make_control_path(agent, False),)
348                 ])
349         args.append(source)
350
351     args.append(dest)
352
353     log_msg = " rcopy - host %s - command %s " % (host, " ".join(args))
354     
355     return _retry_rexec(args, log_msg, env = None, retry = retry, 
356             tmp_known_hosts = tmp_known_hosts,
357             blocking = True)
358
359 def rspawn(command, pidfile, 
360         stdout = '/dev/null', 
361         stderr = STDOUT, 
362         stdin = '/dev/null',
363         home = None, 
364         create_home = False, 
365         sudo = False,
366         host = None, 
367         port = None, 
368         user = None, 
369         agent = None, 
370         identity = None, 
371         server_key = None,
372         tty = False):
373     """
374     Spawn a remote command such that it will continue working asynchronously in 
375     background. 
376
377         :param command: The command to run, it should be a single line.
378         :type command: str
379
380         :param pidfile: Path to a file where to store the pid and ppid of the 
381                         spawned process
382         :type pidfile: str
383
384         :param stdout: Path to file to redirect standard output. 
385                        The default value is /dev/null
386         :type stdout: str
387
388         :param stderr: Path to file to redirect standard error.
389                        If the special STDOUT value is used, stderr will 
390                        be redirected to the same file as stdout
391         :type stderr: str
392
393         :param stdin: Path to a file with input to be piped into the command's standard input
394         :type stdin: str
395
396         :param home: Path to working directory folder. 
397                     It is assumed to exist unless the create_home flag is set.
398         :type home: str
399
400         :param create_home: Flag to force creation of the home folder before 
401                             running the command
402         :type create_home: bool
403  
404         :param sudo: Flag forcing execution with sudo user
405         :type sudo: bool
406         
407         :rtype: touple
408
409         (stdout, stderr), process
410         
411         Of the spawning process, which only captures errors at spawning time.
412         Usually only useful for diagnostics.
413     """
414     # Start process in a "daemonized" way, using nohup and heavy
415     # stdin/out redirection to avoid connection issues
416     if stderr is STDOUT:
417         stderr = '&1'
418     else:
419         stderr = ' ' + stderr
420     
421     daemon_command = '{ { %(command)s > %(stdout)s 2>%(stderr)s < %(stdin)s & } ; echo $! 1 > %(pidfile)s ; }' % {
422         'command' : command,
423         'pidfile' : shell_escape(pidfile),
424         'stdout' : stdout,
425         'stderr' : stderr,
426         'stdin' : stdin,
427     }
428     
429     cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c %(command)s " % {
430             'command' : shell_escape(daemon_command),
431             'sudo' : 'sudo -S' if sudo else '',
432             'pidfile' : shell_escape(pidfile),
433             'gohome' : 'cd %s ; ' % (shell_escape(home),) if home else '',
434             'create' : 'mkdir -p %s ; ' % (shell_escape(home),) if create_home and home else '',
435         }
436
437     (out,err),proc = rexec(
438         cmd,
439         host = host,
440         port = port,
441         user = user,
442         agent = agent,
443         identity = identity,
444         server_key = server_key,
445         tty = tty ,
446         )
447     
448     if proc.wait():
449         raise RuntimeError, "Failed to set up application on host %s: %s %s" % (host, out,err,)
450
451     return ((out, err), proc)
452
453 @eintr_retry
454 def rgetpid(pidfile,
455         host = None, 
456         port = None, 
457         user = None, 
458         agent = None, 
459         identity = None,
460         server_key = None):
461     """
462     Returns the pid and ppid of a process from a remote file where the 
463     information was stored.
464
465         :param home: Path to directory where the pidfile is located
466         :type home: str
467
468         :param pidfile: Name of file containing the pid information
469         :type pidfile: str
470         
471         :rtype: int
472         
473         A (pid, ppid) tuple useful for calling rstatus and rkill,
474         or None if the pidfile isn't valid yet (can happen when process is staring up)
475
476     """
477     (out,err),proc = rexec(
478         "cat %(pidfile)s" % {
479             'pidfile' : pidfile,
480         },
481         host = host,
482         port = port,
483         user = user,
484         agent = agent,
485         identity = identity,
486         server_key = server_key
487         )
488         
489     if proc.wait():
490         return None
491     
492     if out:
493         try:
494             return map(int,out.strip().split(' ',1))
495         except:
496             # Ignore, many ways to fail that don't matter that much
497             return None
498
499 @eintr_retry
500 def rstatus(pid, ppid, 
501         host = None, 
502         port = None, 
503         user = None, 
504         agent = None, 
505         identity = None,
506         server_key = None):
507     """
508     Returns a code representing the the status of a remote process
509
510         :param pid: Process id of the process
511         :type pid: int
512
513         :param ppid: Parent process id of process
514         :type ppid: int
515     
516         :rtype: int (One of NOT_STARTED, RUNNING, FINISHED)
517     
518     """
519     (out,err),proc = rexec(
520         # Check only by pid. pid+ppid does not always work (especially with sudo) 
521         " (( ps --pid %(pid)d -o pid | grep -c %(pid)d && echo 'wait')  || echo 'done' ) | tail -n 1" % {
522             'ppid' : ppid,
523             'pid' : pid,
524         },
525         host = host,
526         port = port,
527         user = user,
528         agent = agent,
529         identity = identity,
530         server_key = server_key
531         )
532     
533     if proc.wait():
534         return ProcStatus.NOT_STARTED
535     
536     status = False
537     if err:
538         if err.strip().find("Error, do this: mount -t proc none /proc") >= 0:
539             status = True
540     elif out:
541         status = (out.strip() == 'wait')
542     else:
543         return ProcStatus.NOT_STARTED
544     return ProcStatus.RUNNING if status else ProcStatus.FINISHED
545
546 @eintr_retry
547 def rkill(pid, ppid,
548         host = None, 
549         port = None, 
550         user = None, 
551         agent = None, 
552         sudo = False,
553         identity = None, 
554         server_key = None, 
555         nowait = False):
556     """
557     Sends a kill signal to a remote process.
558
559     First tries a SIGTERM, and if the process does not end in 10 seconds,
560     it sends a SIGKILL.
561  
562         :param pid: Process id of process to be killed
563         :type pid: int
564
565         :param ppid: Parent process id of process to be killed
566         :type ppid: int
567
568         :param sudo: Flag indicating if sudo should be used to kill the process
569         :type sudo: bool
570         
571     """
572     subkill = "$(ps --ppid %(pid)d -o pid h)" % { 'pid' : pid }
573     cmd = """
574 SUBKILL="%(subkill)s" ;
575 %(sudo)s kill -- -%(pid)d $SUBKILL || /bin/true
576 %(sudo)s kill %(pid)d $SUBKILL || /bin/true
577 for x in 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 ; do 
578     sleep 0.2 
579     if [ `ps --pid %(pid)d -o pid | grep -c %(pid)d` == '0' ]; then
580         break
581     else
582         %(sudo)s kill -- -%(pid)d $SUBKILL || /bin/true
583         %(sudo)s kill %(pid)d $SUBKILL || /bin/true
584     fi
585     sleep 1.8
586 done
587 if [ `ps --pid %(pid)d -o pid | grep -c %(pid)d` != '0' ]; then
588     %(sudo)s kill -9 -- -%(pid)d $SUBKILL || /bin/true
589     %(sudo)s kill -9 %(pid)d $SUBKILL || /bin/true
590 fi
591 """
592     if nowait:
593         cmd = "( %s ) >/dev/null 2>/dev/null </dev/null &" % (cmd,)
594
595     (out,err),proc = rexec(
596         cmd % {
597             'ppid' : ppid,
598             'pid' : pid,
599             'sudo' : 'sudo -S' if sudo else '',
600             'subkill' : subkill,
601         },
602         host = host,
603         port = port,
604         user = user,
605         agent = agent,
606         identity = identity,
607         server_key = server_key
608         )
609     
610     # wait, don't leave zombies around
611     proc.wait()
612
613     return (out, err), proc
614
615 def _retry_rexec(args,
616         log_msg,
617         env = None,
618         retry = 3,
619         tmp_known_hosts = None,
620         blocking = True):
621
622     for x in xrange(retry):
623         # connects to the remote host and starts a remote connection
624         proc = subprocess.Popen(args,
625                 env = env,
626                 stdout = subprocess.PIPE,
627                 stdin = subprocess.PIPE, 
628                 stderr = subprocess.PIPE)
629         
630         # attach tempfile object to the process, to make sure the file stays
631         # alive until the process is finished with it
632         proc._known_hosts = tmp_known_hosts
633     
634         # The argument block == False forces to rexec to return immediately, 
635         # without blocking 
636         try:
637             if blocking:
638                 (out, err) = proc.communicate()
639             else:
640                 err = proc.stderr.read()
641                 out = proc.stdout.read()
642
643             log(log_msg, logging.DEBUG, out, err)
644
645             if proc.poll():
646                 skip = False
647
648                 if err.strip().startswith('ssh: ') or err.strip().startswith('mux_client_hello_exchange: '):
649                     # SSH error, can safely retry
650                     skip = True 
651                 elif retry:
652                     # Probably timed out or plain failed but can retry
653                     skip = True 
654                 
655                 if skip:
656                     t = x*2
657                     msg = "SLEEPING %d ... ATEMPT %d - command %s " % ( 
658                             t, x, " ".join(args))
659                     log(msg, logging.DEBUG)
660
661                     time.sleep(t)
662                     continue
663             break
664         except RuntimeError, e:
665             msg = " rexec EXCEPTION - TIMEOUT -> %s \n %s" % ( e.args, log_msg )
666             log(msg, logging.DEBUG, out, err)
667
668             if retry <= 0:
669                 raise
670             retry -= 1
671         
672     return ((out, err), proc)
673
674