Adding ICN PlanetLab large experiment scenarios
[nepi.git] / src / nepi / util / server.py
index fb4bab8..ed9bccd 100644 (file)
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
 from nepi.util.constants import DeploymentConfiguration as DC
@@ -9,30 +8,32 @@ import os
 import os.path
 import resource
 import select
-import socket
+import shutil
 import signal
+import socket
 import sys
 import subprocess
 import threading
 import time
 import traceback
-import signal
 import re
 import tempfile
 import defer
 import functools
 import collections
+import hashlib
 
 CTRL_SOCK = "ctrl.sock"
+CTRL_PID = "ctrl.pid"
 STD_ERR = "stderr.log"
 MAX_FD = 1024
 
 STOP_MSG = "STOP"
 
-ERROR_LEVEL = 0
-DEBUG_LEVEL = 1
 TRACE = os.environ.get("NEPI_TRACE", "false").lower() in ("true", "1", "on")
 
+OPENSSH_HAS_PERSIST = None
+
 if hasattr(os, "devnull"):
     DEV_NULL = os.devnull
 else:
@@ -40,6 +41,29 @@ else:
 
 SHELL_SAFE = re.compile('^[-a-zA-Z0-9_=+:.,/]*$')
 
+hostbyname_cache = dict()
+
+def gethostbyname(host):
+    hostbyname = hostbyname_cache.get(host)
+    if not hostbyname:
+        hostbyname = socket.gethostbyname(host)
+        hostbyname_cache[host] = hostbyname
+    return hostbyname
+
+def openssh_has_persist():
+    global OPENSSH_HAS_PERSIST
+    if OPENSSH_HAS_PERSIST is None:
+        proc = subprocess.Popen(["ssh","-v"],
+            stdout = subprocess.PIPE,
+            stderr = subprocess.STDOUT,
+            stdin = open("/dev/null","r") )
+        out,err = proc.communicate()
+        proc.wait()
+        
+        vre = re.compile(r'OpenSSH_(?:[6-9]|5[.][8-9]|5[.][1-9][0-9]|[1-9][0-9]).*', re.I)
+        OPENSSH_HAS_PERSIST = bool(vre.match(out))
+    return OPENSSH_HAS_PERSIST
+
 def shell_escape(s):
     """ Escapes strings so that they are safe to use as command-line arguments """
     if SHELL_SAFE.match(s):
@@ -48,7 +72,7 @@ def shell_escape(s):
     else:
         # unsafe string - escape
         def escp(c):
-            if (32 <= ord(c) < 127 or c in ('\r','\n','\t')) and c not in ("'",):
+            if (32 <= ord(c) < 127 or c in ('\r','\n','\t')) and c not in ("'",'"'):
                 return c
             else:
                 return "'$'\\x%02x''" % (ord(c),)
@@ -78,8 +102,10 @@ def eintr_retry(func):
     return rv
 
 class Server(object):
-    def __init__(self, root_dir = ".", log_level = ERROR_LEVEL, environment_setup = ""):
+    def __init__(self, root_dir = ".", log_level = DC.ERROR_LEVEL, 
+            environment_setup = "", clean_root = False):
         self._root_dir = root_dir
+        self._clean_root = clean_root
         self._stop = False
         self._ctrl_sock = None
         self._log_level = log_level
@@ -112,6 +138,9 @@ class Server(object):
         
         # build root folder
         root = os.path.normpath(self._root_dir)
+        if self._root_dir not in [".", ""] and os.path.exists(root) \
+                and self._clean_root:
+            shutil.rmtree(root)
         if not os.path.exists(root):
             os.makedirs(root, 0755)
 
@@ -196,8 +225,35 @@ class Server(object):
 
         # create control socket
         self._ctrl_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
-        self._ctrl_sock.bind(CTRL_SOCK)
+        try:
+            self._ctrl_sock.bind(CTRL_SOCK)
+        except socket.error:
+            # Address in use, check pidfile
+            pid = None
+            try:
+                pidfile = open(CTRL_PID, "r")
+                pid = pidfile.read()
+                pidfile.close()
+                pid = int(pid)
+            except:
+                # no pidfile
+                pass
+            
+            if pid is not None:
+                # Check process liveliness
+                if not os.path.exists("/proc/%d" % (pid,)):
+                    # Ok, it's dead, clean the socket
+                    os.remove(CTRL_SOCK)
+            
+            # try again
+            self._ctrl_sock.bind(CTRL_SOCK)
+            
         self._ctrl_sock.listen(0)
+        
+        # Save pidfile
+        pidfile = open(CTRL_PID, "w")
+        pidfile.write(str(os.getpid()))
+        pidfile.close()
 
         # let the parent process know that the daemonization is finished
         os.write(w, "\n")
@@ -205,6 +261,7 @@ class Server(object):
         return 1
 
     def post_daemonize(self):
+        os.environ["NEPI_CONTROLLER_LOGLEVEL"] = self._log_level
         # QT, for some strange reason, redefines the SIGCHILD handler to write
         # a \0 to a fd (lets say fileno 'x'), when ever a SIGCHILD is received.
         # Server dameonization closes all file descriptors from fileno '3',
@@ -227,8 +284,8 @@ class Server(object):
                 try:
                     msg = self.recv_msg(conn)
                 except socket.timeout, e:
-                    self.log_error()
-                    break
+                    #self.log_error("SERVER recv_msg: connection timedout ")
+                    continue
                 
                 if not msg:
                     self.log_error("CONNECTION LOST")
@@ -303,7 +360,7 @@ class Server(object):
         return text
 
     def log_debug(self, text):
-        if self._log_level == DEBUG_LEVEL:
+        if self._log_level == DC.DEBUG_LEVEL:
             date = time.strftime("%Y-%m-%d %H:%M:%S")
             sys.stderr.write("DEBUG: %s\n%s\n" % (date, text))
 
@@ -430,10 +487,15 @@ class Client(object):
                
         # Wait for the forwarder to be ready, otherwise nobody
         # will be able to connect to it
-        helo = self._process.stderr.readline()
-        if helo != 'FORWARDER_READY.\n':
-            raise AssertionError, "Expected 'FORWARDER_READY.', got %r: %s" % (helo,
-                    helo + self._process.stderr.read())
+        err = []
+        helo = "nope"
+        while helo:
+            helo = self._process.stderr.readline()
+            if helo == 'FORWARDER_READY.\n':
+                break
+            err.append(helo)
+        else:
+            raise AssertionError, "Expected 'FORWARDER_READY.', got: %s" % (''.join(err),)
         
     def send_msg(self, msg):
         encoded = base64.b64encode(msg)
@@ -516,9 +578,11 @@ def _make_server_key_args(server_key, host, port, args):
         host = '%s:%s' % (host,port)
     # Create a temporary server key file
     tmp_known_hosts = tempfile.NamedTemporaryFile()
-    
+   
+    hostbyname = gethostbyname(host) 
+
     # Add the intended host key
-    tmp_known_hosts.write('%s,%s %s\n' % (host, socket.gethostbyname(host), server_key))
+    tmp_known_hosts.write('%s,%s %s\n' % (host, hostbyname, server_key))
     
     # If we're not in strict mode, add user-configured keys
     if os.environ.get('NEPI_STRICT_AUTH_MODE',"").lower() not in ('1','true','on'):
@@ -541,18 +605,28 @@ def popen_ssh_command(command, host, port, user, agent,
         tty = False,
         timeout = None,
         retry = 0,
-        err_on_timeout = True):
+        err_on_timeout = True,
+        connect_timeout = 60,
+        persistent = True,
+        hostip = None):
     """
     Executes a remote commands, returns ((stdout,stderr),process)
     """
-    if TRACE:
-        print "ssh", host, command
-    
+   
     tmp_known_hosts = None
-    args = ['ssh',
+    args = ['ssh', '-C',
             # Don't bother with localhost. Makes test easier
             '-o', 'NoHostAuthenticationForLocalhost=yes',
-            '-l', user, host]
+            '-o', 'ConnectTimeout=%d' % (int(connect_timeout),),
+            '-o', 'ConnectionAttempts=3',
+            '-o', 'ServerAliveInterval=30',
+            '-o', 'TCPKeepAlive=yes',
+            '-l', user, hostip or host]
+    if persistent and openssh_has_persist():
+        args.extend([
+            '-o', 'ControlMaster=auto',
+            '-o', 'ControlPath=/tmp/nepi_ssh-%r@%h:%p',
+            '-o', 'ControlPersist=60' ])
     if agent:
         args.append('-A')
     if port:
@@ -561,13 +635,14 @@ def popen_ssh_command(command, host, port, user, agent,
         args.extend(('-i', ident_key))
     if tty:
         args.append('-t')
+        args.append('-t')
     if server_key:
         # Create a temporary server key file
         tmp_known_hosts = _make_server_key_args(
             server_key, host, port, args)
     args.append(command)
 
-    while 1:
+    for x in xrange(retry or 3):
         # connects to the remote host and starts a remote connection
         proc = subprocess.Popen(args, 
                 stdout = subprocess.PIPE,
@@ -577,20 +652,34 @@ def popen_ssh_command(command, host, port, user, agent,
         # attach tempfile object to the process, to make sure the file stays
         # alive until the process is finished with it
         proc._known_hosts = tmp_known_hosts
-        
+    
         try:
             out, err = _communicate(proc, stdin, timeout, err_on_timeout)
+            if TRACE:
+                print "COMMAND host %s, command %s, out %s, error %s" % (host, " ".join(args), out, err)
+
+            if proc.poll():
+                if err.strip().startswith('ssh: ') or err.strip().startswith('mux_client_hello_exchange: '):
+                    # SSH error, can safely retry
+                    continue
+                elif :
+                    ControlSocket /tmp/nepi_ssh-inria_alina@planetlab04.cnds.unibe.ch:22 already exists, disabling multiplexing
+                    # SSH error, can safely retry (but need to delete controlpath file)
+                    # TODO: delete file
+                    continue
+                elif retry:
+                    # Probably timed out or plain failed but can retry
+                    continue
             break
         except RuntimeError,e:
+            if TRACE:
+                print "EXCEPTION host %s, command %s, out %s, error %s, exception TIMEOUT ->  %s" % (
+                        host, " ".join(args), out, err, e.args)
+
             if retry <= 0:
                 raise
-            if TRACE:
-                print " timedout -> ", e.args
             retry -= 1
         
-    if TRACE:
-        print " -> ", out, err
-
     return ((out, err), proc)
 
 def popen_scp(source, dest, 
@@ -650,7 +739,16 @@ def popen_scp(source, dest,
         args = ['ssh', '-l', user, '-C',
                 # Don't bother with localhost. Makes test easier
                 '-o', 'NoHostAuthenticationForLocalhost=yes',
+                '-o', 'ConnectTimeout=60',
+                '-o', 'ConnectionAttempts=3',
+                '-o', 'ServerAliveInterval=30',
+                '-o', 'TCPKeepAlive=yes',
                 host ]
+        if openssh_has_persist():
+            args.extend([
+                '-o', 'ControlMaster=auto',
+                '-o', 'ControlPath=/tmp/nepi_ssh-%r@%h:%p',
+                '-o', 'ControlPersist=60' ])
         if port:
             args.append('-P%d' % port)
         if ident_key:
@@ -774,7 +872,12 @@ def popen_scp(source, dest,
         tmp_known_hosts = None
         args = ['scp', '-q', '-p', '-C',
                 # Don't bother with localhost. Makes test easier
-                '-o', 'NoHostAuthenticationForLocalhost=yes' ]
+                '-o', 'NoHostAuthenticationForLocalhost=yes',
+                '-o', 'ConnectTimeout=60',
+                '-o', 'ConnectionAttempts=3',
+                '-o', 'ServerAliveInterval=30',
+                '-o', 'TCPKeepAlive=yes' ]
+                
         if port:
             args.append('-P%d' % port)
         if recursive:
@@ -788,6 +891,10 @@ def popen_scp(source, dest,
         if isinstance(source,list):
             args.extend(source)
         else:
+            if openssh_has_persist():
+                args.extend([
+                    '-o', 'ControlMaster=auto',
+                    '-o', 'ControlPath=/tmp/nepi_ssh-%r@%h:%p'])
             args.append(source)
         args.append(dest)
 
@@ -810,7 +917,13 @@ def decode_and_execute():
     import base64, os
     cmd = ""
     while True:
-        cmd += os.read(0, 1)# one byte from stdin
+        try:
+            cmd += os.read(0, 1)# one byte from stdin
+        except OSError, e:            
+            if e.errno == errno.EINTR:
+                continue
+            else:
+                raise
         if cmd[-1] == "\n": 
             break
     cmd = base64.b64decode(cmd)
@@ -832,10 +945,7 @@ def popen_python(python_code,
         sudo = False, 
         environment_setup = ""):
 
-    shell = False
     cmd = ""
-    if sudo:
-        cmd +="sudo "
     if python_path:
         python_path.replace("'", r"'\''")
         cmd = """PYTHONPATH="$PYTHONPATH":'%s' """ % python_path
@@ -849,15 +959,24 @@ def popen_python(python_code,
     #cmd += "$CMD "
     #cmd += "strace -f -tt -s 200 -o strace$$.out "
     import nepi
-    cmd += "python -c 'import sys; sys.path.append(%s); from nepi.util import server; server.decode_and_execute()'" % (
+    cmd += "python -c 'import sys; sys.path.insert(0,%s); from nepi.util import server; server.decode_and_execute()'" % (
         repr(os.path.dirname(os.path.dirname(nepi.__file__))).replace("'",'"'),
     )
 
+    if sudo:
+        if ';' in cmd:
+            cmd = "sudo bash -c " + shell_escape(cmd)
+        else:
+            cmd = "sudo " + cmd
+
     if communication == DC.ACCESS_SSH:
         tmp_known_hosts = None
-        args = ['ssh',
+        args = ['ssh', '-C',
                 # Don't bother with localhost. Makes test easier
                 '-o', 'NoHostAuthenticationForLocalhost=yes',
+                '-o', 'ConnectionAttempts=3',
+                '-o', 'ServerAliveInterval=30',
+                '-o', 'TCPKeepAlive=yes',
                 '-l', user, host]
         if agent:
             args.append('-A')
@@ -873,12 +992,11 @@ def popen_python(python_code,
                 server_key, host, port, args)
         args.append(cmd)
     else:
-        args = [cmd]
-        shell = True
+        args = [ "/bin/bash", "-c", cmd ]
 
     # connects to the remote host and starts a remote
     proc = subprocess.Popen(args,
-            shell = shell
+            shell = False
             stdout = subprocess.PIPE,
             stdin = subprocess.PIPE, 
             stderr = subprocess.PIPE)
@@ -889,7 +1007,17 @@ def popen_python(python_code,
     # send the command to execute
     os.write(proc.stdin.fileno(),
             base64.b64encode(python_code) + "\n")
-    msg = os.read(proc.stdout.fileno(), 3)
+    while True: 
+        try:
+            msg = os.read(proc.stdout.fileno(), 3)
+            break
+        except OSError, e:            
+            if e.errno == errno.EINTR:
+                continue
+            else:
+                raise
+    
     if msg != "OK\n":
         raise RuntimeError, "Failed to start remote python interpreter: \nout:\n%s%s\nerr:\n%s" % (
             msg, proc.stdout.read(), proc.stderr.read())
@@ -942,7 +1070,10 @@ def _communicate(self, input, timeout=None, err_on_timeout=True):
             else:
                 select_timeout = timelimit - curtime + 0.1
         else:
-            select_timeout = None
+            select_timeout = 1.0
+        
+        if select_timeout > 1.0:
+            select_timeout = 1.0
             
         try:
             rlist, wlist, xlist = select.select(read_set, write_set, [], select_timeout)
@@ -951,6 +1082,10 @@ def _communicate(self, input, timeout=None, err_on_timeout=True):
                 raise
             else:
                 continue
+        
+        if not rlist and not wlist and not xlist and self.poll() is not None:
+            # timeout and process exited, say bye
+            break
 
         if self.stdin in wlist:
             # When select has indicated that the file is writable,