Added unit tests for SFA PlanetLab.
[nepi.git] / src / nepi / util / server.py
index 2d15b38..6ce1176 100644 (file)
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
 from nepi.util.constants import DeploymentConfiguration as DC
@@ -22,6 +21,7 @@ import tempfile
 import defer
 import functools
 import collections
+import hashlib
 
 CTRL_SOCK = "ctrl.sock"
 CTRL_PID = "ctrl.pid"
@@ -32,6 +32,8 @@ STOP_MSG = "STOP"
 
 TRACE = os.environ.get("NEPI_TRACE", "false").lower() in ("true", "1", "on")
 
+OPENSSH_HAS_PERSIST = None
+
 if hasattr(os, "devnull"):
     DEV_NULL = os.devnull
 else:
@@ -39,6 +41,20 @@ else:
 
 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
 
+def openssh_has_persist():
+    """
+    Tells whether the local ssh client reports an OpenSSH version that
+    this module accepts for persistent (ControlPersist) connections,
+    i.e. a version banner of 5.8 or later.
+
+    The ssh binary is queried only on the first call; the answer is
+    cached in the module-level OPENSSH_HAS_PERSIST and returned as-is
+    on every later call.
+    """
+    global OPENSSH_HAS_PERSIST
+    if OPENSSH_HAS_PERSIST is None:
+        devnull = open("/dev/null","r")
+        try:
+            # Ask for the version explicitly: -V is the documented
+            # "print version and exit" flag, so the output starts with
+            # the banner instead of depending on what -v happens to
+            # print before ssh gives up for lack of a host.
+            proc = subprocess.Popen(["ssh","-V"],
+                stdout = subprocess.PIPE,
+                stderr = subprocess.STDOUT,
+                stdin = devnull )
+            # communicate() reads all output and waits for the child,
+            # no separate wait() is needed
+            out,err = proc.communicate()
+        finally:
+            # Don't leak the descriptor opened for the child's stdin
+            devnull.close()
+
+        # Accepts banners 5.8, 5.9, 5.10-5.99, 6-9 and two-digit majors
+        vre = re.compile(r'OpenSSH_(?:[6-9]|5[.][8-9]|5[.][1-9][0-9]|[1-9][0-9]).*', re.I)
+        OPENSSH_HAS_PERSIST = bool(vre.match(out))
+    return OPENSSH_HAS_PERSIST
+
 def shell_escape(s):
     """ Escapes strings so that they are safe to use as command-line arguments """
     if SHELL_SAFE.match(s):
@@ -571,6 +587,12 @@ def _make_server_key_args(server_key, host, port, args):
     
     return tmp_known_hosts
 
+def make_connkey(user, host, port):
+    """
+    Builds a short token identifying a (user, host, port) connection.
+
+    The token is the base64 form of the tuple's repr with '/' swapped
+    for '.'; when that exceeds 60 characters its SHA-1 hex digest is
+    returned instead.
+    """
+    encoded = repr((user,host,port)).encode("base64")
+    token = encoded.strip().replace('/','.')
+    if len(token) <= 60:
+        return token
+    # Long keys are collapsed to a fixed-length digest
+    return hashlib.sha1(token).hexdigest()
+
 def popen_ssh_command(command, host, port, user, agent, 
         stdin="", 
         ident_key = None,
@@ -579,7 +601,9 @@ def popen_ssh_command(command, host, port, user, agent,
         timeout = None,
         retry = 0,
         err_on_timeout = True,
-        connect_timeout = 30):
+        connect_timeout = 30,
+        persistent = True,
+        hostip = None):
     """
     Executes a remote commands, returns ((stdout,stderr),process)
     """
@@ -587,11 +611,20 @@ def popen_ssh_command(command, host, port, user, agent,
         print "ssh", host, command
     
     tmp_known_hosts = None
-    args = ['ssh',
+    connkey = make_connkey(user,host,port)
+    args = ['ssh', '-C',
             # Don't bother with localhost. Makes test easier
             '-o', 'NoHostAuthenticationForLocalhost=yes',
             '-o', 'ConnectTimeout=%d' % (int(connect_timeout),),
-            '-l', user, host]
+            '-o', 'ConnectionAttempts=3',
+            '-o', 'ServerAliveInterval=30',
+            '-o', 'TCPKeepAlive=yes',
+            '-l', user, hostip or host]
+    if persistent and openssh_has_persist():
+        args.extend([
+            '-o', 'ControlMaster=auto',
+            '-o', 'ControlPath=/tmp/nepi_ssh_pl_%s' % ( connkey, ),
+            '-o', 'ControlPersist=60' ])
     if agent:
         args.append('-A')
     if port:
@@ -619,9 +652,13 @@ def popen_ssh_command(command, host, port, user, agent,
         
         try:
             out, err = _communicate(proc, stdin, timeout, err_on_timeout)
-            if proc.poll() and err.strip().startswith('ssh: '):
-                # SSH error, can safely retry
-                continue
+            if proc.poll():
+                if err.strip().startswith('ssh: ') or err.strip().startswith('mux_client_hello_exchange: '):
+                    # SSH error, can safely retry
+                    continue
+                elif retry:
+                    # Probably timed out or plain failed but can retry
+                    continue
             break
         except RuntimeError,e:
             if retry <= 0:
@@ -689,10 +726,20 @@ def popen_scp(source, dest,
         user,host = remspec.rsplit('@',1)
         tmp_known_hosts = None
         
+        connkey = make_connkey(user,host,port)
         args = ['ssh', '-l', user, '-C',
                 # Don't bother with localhost. Makes test easier
                 '-o', 'NoHostAuthenticationForLocalhost=yes',
+                '-o', 'ConnectTimeout=30',
+                '-o', 'ConnectionAttempts=3',
+                '-o', 'ServerAliveInterval=30',
+                '-o', 'TCPKeepAlive=yes',
                 host ]
+        if openssh_has_persist():
+            args.extend([
+                '-o', 'ControlMaster=auto',
+                '-o', 'ControlPath=/tmp/nepi_ssh_pl_%s' % ( connkey, ),
+                '-o', 'ControlPersist=60' ])
         if port:
             args.append('-P%d' % port)
         if ident_key:
@@ -816,7 +863,12 @@ def popen_scp(source, dest,
         tmp_known_hosts = None
         args = ['scp', '-q', '-p', '-C',
                 # Don't bother with localhost. Makes test easier
-                '-o', 'NoHostAuthenticationForLocalhost=yes' ]
+                '-o', 'NoHostAuthenticationForLocalhost=yes',
+                '-o', 'ConnectTimeout=30',
+                '-o', 'ConnectionAttempts=3',
+                '-o', 'ServerAliveInterval=30',
+                '-o', 'TCPKeepAlive=yes' ]
+                
         if port:
             args.append('-P%d' % port)
         if recursive:
@@ -830,6 +882,11 @@ def popen_scp(source, dest,
         if isinstance(source,list):
             args.extend(source)
         else:
+            if openssh_has_persist():
+                connkey = make_connkey(user,host,port)
+                args.extend([
+                    '-o', 'ControlMaster=no',
+                    '-o', 'ControlPath=/tmp/nepi_ssh_pl_%s' % ( connkey, ) ])
             args.append(source)
         args.append(dest)
 
@@ -906,9 +963,12 @@ def popen_python(python_code,
 
     if communication == DC.ACCESS_SSH:
         tmp_known_hosts = None
-        args = ['ssh',
+        args = ['ssh', '-C',
                 # Don't bother with localhost. Makes test easier
                 '-o', 'NoHostAuthenticationForLocalhost=yes',
+                '-o', 'ConnectionAttempts=3',
+                '-o', 'ServerAliveInterval=30',
+                '-o', 'TCPKeepAlive=yes',
                 '-l', user, host]
         if agent:
             args.append('-A')
@@ -1002,7 +1062,10 @@ def _communicate(self, input, timeout=None, err_on_timeout=True):
             else:
                 select_timeout = timelimit - curtime + 0.1
         else:
-            select_timeout = None
+            select_timeout = 1.0
+        
+        if select_timeout > 1.0:
+            select_timeout = 1.0
             
         try:
             rlist, wlist, xlist = select.select(read_set, write_set, [], select_timeout)
@@ -1011,6 +1074,10 @@ def _communicate(self, input, timeout=None, err_on_timeout=True):
                 raise
             else:
                 continue
+        
+        if not rlist and not wlist and not xlist and self.poll() is not None:
+            # timeout and process exited, say bye
+            break
 
         if self.stdin in wlist:
             # When select has indicated that the file is writable,