Initially working version of PlanetLab testbed implementation.
[nepi.git] / src / nepi / util / server.py
index 36e54dc..c1558d1 100644 (file)
@@ -13,6 +13,7 @@ import threading
 import time
 import traceback
 import signal
+import re
 
 CTRL_SOCK = "ctrl.sock"
 STD_ERR = "stderr.log"
@@ -28,6 +29,20 @@ if hasattr(os, "devnull"):
 else:
     DEV_NULL = "/dev/null"
 
+
+
+SHELL_SAFE = re.compile('[-a-zA-Z0-9_=+:.,/]*')
+
+def shell_escape(s):
+    """ Escapes strings so that they are safe to use as command-line arguments """
+    if SHELL_SAFE.match(s):
+        # safe string - no escaping needed
+        return s
+    else:
+        # unsafe string - escape
+        s = s.replace("'","\\'")
+        return "'%s'" % (s,)
+
 class Server(object):
     def __init__(self, root_dir = ".", log_level = ERROR_LEVEL):
         self._root_dir = root_dir
@@ -298,13 +313,9 @@ class Client(object):
         # will be able to connect to it
         helo = self._process.stderr.readline()
         if helo != 'READY.\n':
-            raise AssertionError, "Expected 'Ready.', got %r" % (helo,)
+            raise AssertionError, "Expected 'Ready.', got %r: %s" % (helo,
+                    helo + self._process.stderr.read())
         
-        if self._process.poll():
-            err = self._process.stderr.read()
-            raise RuntimeError("Client could not be executed: %s" % \
-                    err)
-
     def send_msg(self, msg):
         encoded = base64.b64encode(msg)
         data = "%s\n" % encoded
@@ -329,8 +340,178 @@ class Client(object):
         encoded = data.rstrip() 
         return base64.b64decode(encoded)
 
+def popen_ssh_command(command, host, port, user, agent, 
+            stdin="", 
+            ident_key = None):
+        """
+        Executes a remote commands, returns ((stdout,stderr),process)
+        """
+        args = ['ssh',
+                # Don't bother with localhost. Makes test easier
+                '-o', 'NoHostAuthenticationForLocalhost=yes',
+                '-l', user, host]
+        if agent:
+            args.append('-A')
+        if port:
+            args.append('-p%d' % port)
+        if ident_key:
+            args.extend(('-i', ident_key))
+        args.append(command)
+
+        # connects to the remote host and starts a remote connection
+        proc = subprocess.Popen(args, 
+                stdout = subprocess.PIPE,
+                stdin = subprocess.PIPE, 
+                stderr = subprocess.PIPE)
+        return (proc.communicate(stdin), proc)
+def popen_scp(source, dest, port, agent, 
+            recursive = False,
+            ident_key = None):
+        """
+        Copies from/to remote sites.
+        
+        Source and destination should have the user and host encoded
+        as per scp specs.
+        
+        If source is a file object, a special mode will be used to
+        create the remote file with the same contents.
+        
+        If dest is a file object, the remote file (source) will be
+        read and written into dest.
+        
+        In these modes, recursive cannot be True.
+        """
+        
+        if isinstance(source, file) or isinstance(dest, file) \
+                or hasattr(source, 'read')  or hasattr(dest, 'write'):
+            assert not resursive
+            
+            # Parse destination as <user>@<server>:<path>
+            tgtspec, path = dest.split(':',1)
+            user,host = tgtspec.rsplit('@',1)
+            
+            args = ['ssh', '-l', user, '-C',
+                    # Don't bother with localhost. Makes test easier
+                    '-o', 'NoHostAuthenticationForLocalhost=yes' ]
+            if port:
+                args.append('-P%d' % port)
+            if ident_key:
+                args.extend(('-i', ident_key))
+            
+            if isinstance(source, file) or hasattr(source, 'read'):
+                args.append('cat > %s' % (shell_escape(path),))
+            elif isinstance(dest, file) or hasattr(dest, 'write'):
+                args.append('cat %s' % (shell_escape(path),))
+            else:
+                raise AssertionError, "Unreachable code reached! :-Q"
+            
+            # connects to the remote host and starts a remote connection
+            if isinstance(source, file):
+                proc = subprocess.Popen(args, 
+                        stdout = open('/dev/null','w'),
+                        stderr = subprocess.PIPE,
+                        stdin = source)
+                err = proc.stderr.read()
+                proc.wait()
+                return ((None,err), proc)
+            elif isinstance(dest, file):
+                proc = subprocess.Popen(args, 
+                        stdout = open('/dev/null','w'),
+                        stderr = subprocess.PIPE,
+                        stdin = source)
+                err = proc.stderr.read()
+                proc.wait()
+                return ((None,err), proc)
+            elif hasattr(source, 'read'):
+                # file-like (but not file) source
+                proc = subprocess.Popen(args, 
+                        stdout = open('/dev/null','w'),
+                        stderr = subprocess.PIPE,
+                        stdin = source)
+                
+                buf = None
+                err = []
+                while True:
+                    if not buf:
+                        buf = source.read(4096)
+                    
+                    rdrdy, wrdy, broken = os.select(
+                        [proc.stderr],
+                        [proc.stdin],
+                        [proc.stderr,proc.stdin])
+                    
+                    if proc.stderr in rdrdy:
+                        # use os.read for fully unbuffered behavior
+                        err.append(os.read(proc.stderr.fileno(), 4096))
+                    
+                    if proc.stdin in wrdy:
+                        proc.stdin.write(buf)
+                        buf = None
+                    
+                    if broken:
+                        break
+                err.append(proc.stderr.read())
+                    
+                proc.wait()
+                return ((None,''.join(err)), proc)
+            elif hasattr(dest, 'write'):
+                # file-like (but not file) dest
+                proc = subprocess.Popen(args, 
+                        stdout = open('/dev/null','w'),
+                        stderr = subprocess.PIPE,
+                        stdin = source)
+                
+                buf = None
+                err = []
+                while True:
+                    rdrdy, wrdy, broken = os.select(
+                        [proc.stderr, proc.stdout],
+                        [],
+                        [proc.stderr, proc.stdout])
+                    
+                    if proc.stderr in rdrdy:
+                        # use os.read for fully unbuffered behavior
+                        err.append(os.read(proc.stderr.fileno(), 4096))
+                    
+                    if proc.stdout in rdrdy:
+                        # use os.read for fully unbuffered behavior
+                        dest.write(os.read(proc.stdout.fileno(), 4096))
+                    
+                    if broken:
+                        break
+                err.append(proc.stderr.read())
+                    
+                proc.wait()
+                return ((None,''.join(err)), proc)
+            else:
+                raise AssertionError, "Unreachable code reached! :-Q"
+        else:
+            # plain scp
+            args = ['scp', '-q', '-p', '-C',
+                    # Don't bother with localhost. Makes test easier
+                    '-o', 'NoHostAuthenticationForLocalhost=yes' ]
+            if port:
+                args.append('-P%d' % port)
+            if recursive:
+                args.append('-r')
+            if ident_key:
+                args.extend(('-i', ident_key))
+            args.append(source)
+            args.append(dest)
+
+            # connects to the remote host and starts a remote connection
+            proc = subprocess.Popen(args, 
+                    stdout = subprocess.PIPE,
+                    stdin = subprocess.PIPE, 
+                    stderr = subprocess.PIPE)
+            comm = proc.communicate()
+            proc.wait()
+            return (comm, proc)
 def popen_ssh_subprocess(python_code, host, port, user, agent, 
-        python_path = None):
+        python_path = None,
+        ident_key = None):
         if python_path:
             python_path.replace("'", r"'\''")
             cmd = """PYTHONPATH="$PYTHONPATH":'%s' """ % python_path
@@ -362,6 +543,8 @@ def popen_ssh_subprocess(python_code, host, port, user, agent,
             args.append('-A')
         if port:
             args.append('-p%d' % port)
+        if ident_key:
+            args.extend(('-i', ident_key))
         args.append(cmd)
 
         # connects to the remote host and starts a remote rpyc connection