NS3Client: replacing socat for ssh
[nepi.git] / src / nepi / util / sshfuncs.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License as published by
7 #    the Free Software Foundation, either version 3 of the License, or
8 #    (at your option) any later version.
9 #
10 #    This program is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #    GNU General Public License for more details.
14 #
15 #    You should have received a copy of the GNU General Public License
16 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
19 #         Claudio Freire <claudio-daniel.freire@inria.fr>
20
21 ## TODO: This code needs reviewing !!!
22
23 import base64
24 import errno
25 import hashlib
26 import logging
27 import os
28 import os.path
29 import re
30 import select
31 import signal
32 import socket
33 import subprocess
34 import threading
35 import time
36 import tempfile
37
38 logger = logging.getLogger("sshfuncs")
39
40 def log(msg, level, out = None, err = None):
41     if out:
42         msg += " - OUT: %s " % out
43
44     if err:
45         msg += " - ERROR: %s " % err
46
47     logger.log(level, msg)
48
49
50 if hasattr(os, "devnull"):
51     DEV_NULL = os.devnull
52 else:
53     DEV_NULL = "/dev/null"
54
55 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
56
57 class STDOUT: 
58     """
59     Special value that when given to rspawn in stderr causes stderr to 
60     redirect to whatever stdout was redirected to.
61     """
62
63 class ProcStatus:
64     """
65     Codes for status of remote spawned process
66     """
67     # Process is still running
68     RUNNING = 1
69
70     # Process is finished
71     FINISHED = 2
72     
73     # Process hasn't started running yet (this should be very rare)
74     NOT_STARTED = 3
75
76 hostbyname_cache = dict()
77 hostbyname_cache_lock = threading.Lock()
78
79 def gethostbyname(host):
80     global hostbyname_cache
81     global hostbyname_cache_lock
82     
83     hostbyname = hostbyname_cache.get(host)
84     if not hostbyname:
85         with hostbyname_cache_lock:
86             hostbyname = socket.gethostbyname(host)
87             hostbyname_cache[host] = hostbyname
88
89             msg = " Added hostbyname %s - %s " % (host, hostbyname)
90             log(msg, logging.DEBUG)
91
92     return hostbyname
93
94 OPENSSH_HAS_PERSIST = None
95
96 def openssh_has_persist():
97     """ The ssh_config options ControlMaster and ControlPersist allow to
98     reuse a same network connection for multiple ssh sessions. In this 
99     way limitations on number of open ssh connections can be bypassed.
100     However, older versions of openSSH do not support this feature.
101     This function is used to determine if ssh connection persist features
102     can be used.
103     """
104     global OPENSSH_HAS_PERSIST
105     if OPENSSH_HAS_PERSIST is None:
106         proc = subprocess.Popen(["ssh","-v"],
107             stdout = subprocess.PIPE,
108             stderr = subprocess.STDOUT,
109             stdin = open("/dev/null","r") )
110         out,err = proc.communicate()
111         proc.wait()
112         
113         vre = re.compile(r'OpenSSH_(?:[6-9]|5[.][8-9]|5[.][1-9][0-9]|[1-9][0-9]).*', re.I)
114         OPENSSH_HAS_PERSIST = bool(vre.match(out))
115     return OPENSSH_HAS_PERSIST
116
117 def make_server_key_args(server_key, host, port):
118     """ Returns a reference to a temporary known_hosts file, to which 
119     the server key has been added. 
120     
121     Make sure to hold onto the temp file reference until the process is 
122     done with it
123
124     :param server_key: the server public key
125     :type server_key: str
126
127     :param host: the hostname
128     :type host: str
129
130     :param port: the ssh port
131     :type port: str
132
133     """
134     if port is not None:
135         host = '%s:%s' % (host, str(port))
136
137     # Create a temporary server key file
138     tmp_known_hosts = tempfile.NamedTemporaryFile()
139    
140     hostbyname = gethostbyname(host) 
141
142     # Add the intended host key
143     tmp_known_hosts.write('%s,%s %s\n' % (host, hostbyname, server_key))
144     
145     # If we're not in strict mode, add user-configured keys
146     if os.environ.get('NEPI_STRICT_AUTH_MODE',"").lower() not in ('1','true','on'):
147         user_hosts_path = '%s/.ssh/known_hosts' % (os.environ.get('HOME',""),)
148         if os.access(user_hosts_path, os.R_OK):
149             f = open(user_hosts_path, "r")
150             tmp_known_hosts.write(f.read())
151             f.close()
152         
153     tmp_known_hosts.flush()
154     
155     return tmp_known_hosts
156
157 def make_control_path(agent, forward_x11):
158     ctrl_path = "/tmp/nepi_ssh"
159
160     if agent:
161         ctrl_path +="_a"
162
163     if forward_x11:
164         ctrl_path +="_x"
165
166     ctrl_path += "-%r@%h:%p"
167
168     return ctrl_path
169
170 def shell_escape(s):
171     """ Escapes strings so that they are safe to use as command-line 
172     arguments """
173     if SHELL_SAFE.match(s):
174         # safe string - no escaping needed
175         return s
176     else:
177         # unsafe string - escape
178         def escp(c):
179             if (32 <= ord(c) < 127 or c in ('\r','\n','\t')) and c not in ("'",'"'):
180                 return c
181             else:
182                 return "'$'\\x%02x''" % (ord(c),)
183         s = ''.join(map(escp,s))
184         return "'%s'" % (s,)
185
186 def eintr_retry(func):
187     """Retries a function invocation when a EINTR occurs"""
188     import functools
189     @functools.wraps(func)
190     def rv(*p, **kw):
191         retry = kw.pop("_retry", False)
192         for i in xrange(0 if retry else 4):
193             try:
194                 return func(*p, **kw)
195             except (select.error, socket.error), args:
196                 if args[0] == errno.EINTR:
197                     continue
198                 else:
199                     raise 
200             except OSError, e:
201                 if e.errno == errno.EINTR:
202                     continue
203                 else:
204                     raise
205         else:
206             return func(*p, **kw)
207     return rv
208
209 def rexec(command, host, user, 
210         port = None, 
211         agent = True,
212         sudo = False,
213         identity = None,
214         server_key = None,
215         env = None,
216         tty = False,
217         connect_timeout = 30,
218         retry = 3,
219         persistent = True,
220         forward_x11 = False,
221         blocking = True,
222         strict_host_checking = True):
223     """
224     Executes a remote command, returns ((stdout,stderr),process)
225     """
226     
227     tmp_known_hosts = None
228     hostip = gethostbyname(host)
229
230     args = ['ssh', '-C',
231             # Don't bother with localhost. Makes test easier
232             '-o', 'NoHostAuthenticationForLocalhost=yes',
233             '-o', 'ConnectTimeout=%d' % (int(connect_timeout),),
234             '-o', 'ConnectionAttempts=3',
235             '-o', 'ServerAliveInterval=30',
236             '-o', 'TCPKeepAlive=yes',
237             '-l', user, hostip or host]
238
239     if persistent and openssh_has_persist():
240         args.extend([
241             '-o', 'ControlMaster=auto',
242             '-o', 'ControlPath=%s' % (make_control_path(agent, forward_x11),),
243             '-o', 'ControlPersist=60' ])
244
245     if not strict_host_checking:
246         # Do not check for Host key. Unsafe.
247         args.extend(['-o', 'StrictHostKeyChecking=no'])
248
249     if agent:
250         args.append('-A')
251
252     if port:
253         args.append('-p%d' % port)
254
255     if identity:
256         args.extend(('-i', identity))
257
258     if tty:
259         args.append('-t')
260         args.append('-t')
261
262     if forward_x11:
263         args.append('-X')
264
265     if server_key:
266         # Create a temporary server key file
267         tmp_known_hosts = make_server_key_args(server_key, host, port)
268         args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
269
270     if sudo:
271         command = "sudo " + command
272
273     args.append(command)
274     
275     log_msg = " rexec - host %s - command %s " % (host, " ".join(args))
276
277     stdout = stderr = stdin = subprocess.PIPE
278     if forward_x11:
279         stdout = stderr = stdin = None
280
281     return _retry_rexec(args, log_msg, 
282             stderr = stderr,
283             stdin = stdin,
284             stdout = stdout,
285             env = env, 
286             retry = retry, 
287             tmp_known_hosts = tmp_known_hosts,
288             blocking = blocking)
289
290 def rcopy(source, dest,
291         port = None, 
292         agent = True, 
293         recursive = False,
294         identity = None,
295         server_key = None,
296         retry = 3,
297         strict_host_checking = True):
298     """
299     Copies from/to remote sites.
300     
301     Source and destination should have the user and host encoded
302     as per scp specs.
303     
304     Source can be a list of files to copy to a single destination,
305     in which case it is advised that the destination be a folder.
306     """
307     
308     # Parse destination as <user>@<server>:<path>
309     if isinstance(dest, basestring) and ':' in dest:
310         remspec, path = dest.split(':',1)
311     elif isinstance(source, basestring) and ':' in source:
312         remspec, path = source.split(':',1)
313     else:
314         raise ValueError, "Both endpoints cannot be local"
315     user,host = remspec.rsplit('@',1)
316     
317     # plain scp
318     tmp_known_hosts = None
319
320     args = ['scp', '-q', '-p', '-C',
321             # Speed up transfer using blowfish cypher specification which is 
322             # faster than the default one (3des)
323             '-c', 'blowfish',
324             # Don't bother with localhost. Makes test easier
325             '-o', 'NoHostAuthenticationForLocalhost=yes',
326             '-o', 'ConnectTimeout=60',
327             '-o', 'ConnectionAttempts=3',
328             '-o', 'ServerAliveInterval=30',
329             '-o', 'TCPKeepAlive=yes' ]
330             
331     if port:
332         args.append('-P%d' % port)
333
334     if recursive:
335         args.append('-r')
336
337     if identity:
338         args.extend(('-i', identity))
339
340     if server_key:
341         # Create a temporary server key file
342         tmp_known_hosts = make_server_key_args(server_key, host, port)
343         args.extend(['-o', 'UserKnownHostsFile=%s' % (tmp_known_hosts.name,)])
344
345     if not strict_host_checking:
346         # Do not check for Host key. Unsafe.
347         args.extend(['-o', 'StrictHostKeyChecking=no'])
348
349     if isinstance(source, list):
350         args.extend(source)
351     else:
352         if openssh_has_persist():
353             args.extend([
354                 '-o', 'ControlMaster=auto',
355                 '-o', 'ControlPath=%s' % (make_control_path(agent, False),)
356                 ])
357         args.append(source)
358
359     args.append(dest)
360
361     log_msg = " rcopy - host %s - command %s " % (host, " ".join(args))
362     
363     return _retry_rexec(args, log_msg, env = None, retry = retry, 
364             tmp_known_hosts = tmp_known_hosts,
365             blocking = True)
366
367 def rspawn(command, pidfile, 
368         stdout = '/dev/null', 
369         stderr = STDOUT, 
370         stdin = '/dev/null',
371         home = None, 
372         create_home = False, 
373         sudo = False,
374         host = None, 
375         port = None, 
376         user = None, 
377         agent = None, 
378         identity = None, 
379         server_key = None,
380         tty = False):
381     """
382     Spawn a remote command such that it will continue working asynchronously in 
383     background. 
384
385         :param command: The command to run, it should be a single line.
386         :type command: str
387
388         :param pidfile: Path to a file where to store the pid and ppid of the 
389                         spawned process
390         :type pidfile: str
391
392         :param stdout: Path to file to redirect standard output. 
393                        The default value is /dev/null
394         :type stdout: str
395
396         :param stderr: Path to file to redirect standard error.
397                        If the special STDOUT value is used, stderr will 
398                        be redirected to the same file as stdout
399         :type stderr: str
400
401         :param stdin: Path to a file with input to be piped into the command's standard input
402         :type stdin: str
403
404         :param home: Path to working directory folder. 
405                     It is assumed to exist unless the create_home flag is set.
406         :type home: str
407
408         :param create_home: Flag to force creation of the home folder before 
409                             running the command
410         :type create_home: bool
411  
412         :param sudo: Flag forcing execution with sudo user
413         :type sudo: bool
414         
415         :rtype: touple
416
417         (stdout, stderr), process
418         
419         Of the spawning process, which only captures errors at spawning time.
420         Usually only useful for diagnostics.
421     """
422     # Start process in a "daemonized" way, using nohup and heavy
423     # stdin/out redirection to avoid connection issues
424     if stderr is STDOUT:
425         stderr = '&1'
426     else:
427         stderr = ' ' + stderr
428     
429     daemon_command = '{ { %(command)s > %(stdout)s 2>%(stderr)s < %(stdin)s & } ; echo $! 1 > %(pidfile)s ; }' % {
430         'command' : command,
431         'pidfile' : shell_escape(pidfile),
432         'stdout' : stdout,
433         'stderr' : stderr,
434         'stdin' : stdin,
435     }
436     
437     cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c %(command)s " % {
438             'command' : shell_escape(daemon_command),
439             'sudo' : 'sudo -S' if sudo else '',
440             'pidfile' : shell_escape(pidfile),
441             'gohome' : 'cd %s ; ' % (shell_escape(home),) if home else '',
442             'create' : 'mkdir -p %s ; ' % (shell_escape(home),) if create_home and home else '',
443         }
444
445     (out,err),proc = rexec(
446         cmd,
447         host = host,
448         port = port,
449         user = user,
450         agent = agent,
451         identity = identity,
452         server_key = server_key,
453         tty = tty ,
454         )
455     
456     if proc.wait():
457         raise RuntimeError, "Failed to set up application on host %s: %s %s" % (host, out,err,)
458
459     return ((out, err), proc)
460
461 @eintr_retry
462 def rgetpid(pidfile,
463         host = None, 
464         port = None, 
465         user = None, 
466         agent = None, 
467         identity = None,
468         server_key = None):
469     """
470     Returns the pid and ppid of a process from a remote file where the 
471     information was stored.
472
473         :param home: Path to directory where the pidfile is located
474         :type home: str
475
476         :param pidfile: Name of file containing the pid information
477         :type pidfile: str
478         
479         :rtype: int
480         
481         A (pid, ppid) tuple useful for calling rstatus and rkill,
482         or None if the pidfile isn't valid yet (can happen when process is staring up)
483
484     """
485     (out,err),proc = rexec(
486         "cat %(pidfile)s" % {
487             'pidfile' : pidfile,
488         },
489         host = host,
490         port = port,
491         user = user,
492         agent = agent,
493         identity = identity,
494         server_key = server_key
495         )
496         
497     if proc.wait():
498         return None
499     
500     if out:
501         try:
502             return map(int,out.strip().split(' ',1))
503         except:
504             # Ignore, many ways to fail that don't matter that much
505             return None
506
507 @eintr_retry
508 def rstatus(pid, ppid, 
509         host = None, 
510         port = None, 
511         user = None, 
512         agent = None, 
513         identity = None,
514         server_key = None):
515     """
516     Returns a code representing the the status of a remote process
517
518         :param pid: Process id of the process
519         :type pid: int
520
521         :param ppid: Parent process id of process
522         :type ppid: int
523     
524         :rtype: int (One of NOT_STARTED, RUNNING, FINISHED)
525     
526     """
527     (out,err),proc = rexec(
528         # Check only by pid. pid+ppid does not always work (especially with sudo) 
529         " (( ps --pid %(pid)d -o pid | grep -c %(pid)d && echo 'wait')  || echo 'done' ) | tail -n 1" % {
530             'ppid' : ppid,
531             'pid' : pid,
532         },
533         host = host,
534         port = port,
535         user = user,
536         agent = agent,
537         identity = identity,
538         server_key = server_key
539         )
540     
541     if proc.wait():
542         return ProcStatus.NOT_STARTED
543     
544     status = False
545     if err:
546         if err.strip().find("Error, do this: mount -t proc none /proc") >= 0:
547             status = True
548     elif out:
549         status = (out.strip() == 'wait')
550     else:
551         return ProcStatus.NOT_STARTED
552     return ProcStatus.RUNNING if status else ProcStatus.FINISHED
553
554 @eintr_retry
555 def rkill(pid, ppid,
556         host = None, 
557         port = None, 
558         user = None, 
559         agent = None, 
560         sudo = False,
561         identity = None, 
562         server_key = None, 
563         nowait = False):
564     """
565     Sends a kill signal to a remote process.
566
567     First tries a SIGTERM, and if the process does not end in 10 seconds,
568     it sends a SIGKILL.
569  
570         :param pid: Process id of process to be killed
571         :type pid: int
572
573         :param ppid: Parent process id of process to be killed
574         :type ppid: int
575
576         :param sudo: Flag indicating if sudo should be used to kill the process
577         :type sudo: bool
578         
579     """
580     subkill = "$(ps --ppid %(pid)d -o pid h)" % { 'pid' : pid }
581     cmd = """
582 SUBKILL="%(subkill)s" ;
583 %(sudo)s kill -- -%(pid)d $SUBKILL || /bin/true
584 %(sudo)s kill %(pid)d $SUBKILL || /bin/true
585 for x in 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 ; do 
586     sleep 0.2 
587     if [ `ps --pid %(pid)d -o pid | grep -c %(pid)d` == '0' ]; then
588         break
589     else
590         %(sudo)s kill -- -%(pid)d $SUBKILL || /bin/true
591         %(sudo)s kill %(pid)d $SUBKILL || /bin/true
592     fi
593     sleep 1.8
594 done
595 if [ `ps --pid %(pid)d -o pid | grep -c %(pid)d` != '0' ]; then
596     %(sudo)s kill -9 -- -%(pid)d $SUBKILL || /bin/true
597     %(sudo)s kill -9 %(pid)d $SUBKILL || /bin/true
598 fi
599 """
600     if nowait:
601         cmd = "( %s ) >/dev/null 2>/dev/null </dev/null &" % (cmd,)
602
603     (out,err),proc = rexec(
604         cmd % {
605             'ppid' : ppid,
606             'pid' : pid,
607             'sudo' : 'sudo -S' if sudo else '',
608             'subkill' : subkill,
609         },
610         host = host,
611         port = port,
612         user = user,
613         agent = agent,
614         identity = identity,
615         server_key = server_key
616         )
617     
618     # wait, don't leave zombies around
619     proc.wait()
620
621     return (out, err), proc
622
623 def _retry_rexec(args,
624         log_msg,
625         stdout = subprocess.PIPE,
626         stdin = subprocess.PIPE, 
627         stderr = subprocess.PIPE,
628         env = None,
629         retry = 3,
630         tmp_known_hosts = None,
631         blocking = True):
632
633     for x in xrange(retry):
634         # connects to the remote host and starts a remote connection
635         proc = subprocess.Popen(args,
636                 env = env,
637                 stdout = stdout,
638                 stdin = stdin, 
639                 stderr = stderr)
640         
641         # attach tempfile object to the process, to make sure the file stays
642         # alive until the process is finished with it
643         proc._known_hosts = tmp_known_hosts
644     
645         # The argument block == False forces to rexec to return immediately, 
646         # without blocking 
647         try:
648             err = out = " "
649             if blocking:
650                 (out, err) = proc.communicate()
651             elif stdout:
652                 out = proc.stdout.read()
653                 if proc.poll() and stderr:
654                     err = proc.stderr.read()
655
656             log(log_msg, logging.DEBUG, out, err)
657
658             if proc.poll():
659                 skip = False
660
661                 if err.strip().startswith('ssh: ') or err.strip().startswith('mux_client_hello_exchange: '):
662                     # SSH error, can safely retry
663                     skip = True 
664                 elif retry:
665                     # Probably timed out or plain failed but can retry
666                     skip = True 
667                 
668                 if skip:
669                     t = x*2
670                     msg = "SLEEPING %d ... ATEMPT %d - command %s " % ( 
671                             t, x, " ".join(args))
672                     log(msg, logging.DEBUG)
673
674                     time.sleep(t)
675                     continue
676             break
677         except RuntimeError, e:
678             msg = " rexec EXCEPTION - TIMEOUT -> %s \n %s" % ( e.args, log_msg )
679             log(msg, logging.DEBUG, out, err)
680
681             if retry <= 0:
682                 raise
683             retry -= 1
684         
685     return ((out, err), proc)
686
687