Adding X11 forwarding tests for resources/linux/node.py
[nepi.git] / src / neco / util / sshfuncs.py
index 872143d..5690680 100644 (file)
@@ -13,7 +13,7 @@ import tempfile
 import hashlib
 
 OPENSSH_HAS_PERSIST = None
-CONTROL_PATH = "yyyyy_ssh_control_path"
+CONTROL_PATH = "yyy_ssh_ctrl_path"
 
 if hasattr(os, "devnull"):
     DEV_NULL = os.devnull
@@ -111,13 +111,19 @@ def make_connkey(user, host, port):
         connkey = hashlib.sha1(connkey).hexdigest()
     return connkey
 
+def make_control_path(user, host, port):
+    connkey = make_connkey(user, host, port)
+    return '/tmp/%s_%s' % ( CONTROL_PATH, connkey, )
+
 def rexec(command, host, user, 
         port = None, 
         agent = True,
         sudo = False,
-        stdin = "", 
+        stdin = "",
         identity_file = None,
+        env = None,
         tty = False,
+        x11 = False,
         timeout = None,
         retry = 0,
         err_on_timeout = True,
@@ -126,7 +132,6 @@ def rexec(command, host, user,
     """
     Executes a remote command, returns ((stdout,stderr),process)
     """
-    connkey = make_connkey(user, host, port)
     args = ['ssh', '-C',
             # Don't bother with localhost. Makes test easier
             '-o', 'NoHostAuthenticationForLocalhost=yes',
@@ -140,9 +145,10 @@ def rexec(command, host, user,
             '-l', user, host]
 
     if persistent and openssh_has_persist():
+        control_path = make_control_path(user, host, port)
         args.extend([
             '-o', 'ControlMaster=auto',
-            '-o', 'ControlPath=/tmp/%s_%s' % ( CONTROL_PATH, connkey, ),
+            '-o', 'ControlPath=%s' % control_path,
             '-o', 'ControlPersist=60' ])
     if agent:
         args.append('-A')
@@ -154,9 +160,18 @@ def rexec(command, host, user,
         args.append('-t')
         if sudo:
             args.append('-t')
+    if x11:
+        args.append('-X')
+
+    if env:
+        export = ''
+        for envkey, envval in env.iteritems():
+            export += '%s=%s ' % (envkey, envval)
+        command = export + command
 
     if sudo:
         command = "sudo " + command
+
     args.append(command)
 
     for x in xrange(retry or 3):
@@ -183,7 +198,7 @@ def rexec(command, host, user,
         
     return ((out, err), proc)
 
-def rcopy(source, dest, host, user,
+def rcopy(source, dest,
         port = None, 
         agent = True, 
         recursive = False,
@@ -208,7 +223,6 @@ def rcopy(source, dest, host, user,
     
     if isinstance(source, file) and source.tell() == 0:
         source = source.name
-
     elif hasattr(source, 'read'):
         tmp = tempfile.NamedTemporaryFile()
         while True:
@@ -223,8 +237,17 @@ def rcopy(source, dest, host, user,
     if isinstance(source, file) or isinstance(dest, file) \
             or hasattr(source, 'read')  or hasattr(dest, 'write'):
         assert not recursive
+    
+        # Parse source/destination as <user>@<server>:<path>
+        if isinstance(dest, basestring) and ':' in dest:
+            remspec, path = dest.split(':',1)
+        elif isinstance(source, basestring) and ':' in source:
+            remspec, path = source.split(':',1)
+        else:
+            raise ValueError, "Both endpoints cannot be local"
+        user,host = remspec.rsplit('@',1)
+        tmp_known_hosts = None
         
-        connkey = make_connkey(user,host,port)
         args = ['ssh', '-l', user, '-C',
                 # Don't bother with localhost. Makes test easier
                 '-o', 'NoHostAuthenticationForLocalhost=yes',
@@ -236,10 +259,12 @@ def rcopy(source, dest, host, user,
                 '-o', 'ServerAliveInterval=30',
                 '-o', 'TCPKeepAlive=yes',
                 host ]
+
         if openssh_has_persist():
+            control_path = make_control_path(user, host, port)
             args.extend([
                 '-o', 'ControlMaster=auto',
-                '-o', 'ControlPath=/tmp/%s_%s' % ( CONTROL_PATH, connkey, ),
+                '-o', 'ControlPath=%s' % control_path,
                 '-o', 'ControlPersist=60' ])
         if port:
             args.append('-P%d' % port)
@@ -247,12 +272,12 @@ def rcopy(source, dest, host, user,
             args.extend(('-i', identity_file))
         
         if isinstance(source, file) or hasattr(source, 'read'):
-            args.append('cat > %s' % dest)
+            args.append('cat > %s' % (shell_escape(path),))
         elif isinstance(dest, file) or hasattr(dest, 'write'):
-            args.append('cat %s' % dest)
+            args.append('cat %s' % (shell_escape(path),))
         else:
             raise AssertionError, "Unreachable code reached! :-Q"
-        
+
         # connects to the remote host and starts a remote connection
         if isinstance(source, file):
             proc = subprocess.Popen(args, 
@@ -343,6 +368,15 @@ def rcopy(source, dest, host, user,
         else:
             raise AssertionError, "Unreachable code reached! :-Q"
     else:
+        # Parse destination as <user>@<server>:<path>
+        if isinstance(dest, basestring) and ':' in dest:
+            remspec, path = dest.split(':',1)
+        elif isinstance(source, basestring) and ':' in source:
+            remspec, path = source.split(':',1)
+        else:
+            raise ValueError, "Both endpoints cannot be local"
+        user,host = remspec.rsplit('@',1)
+
         # plain scp
         args = ['scp', '-q', '-p', '-C',
                 # Don't bother with localhost. Makes test easier
@@ -366,12 +400,13 @@ def rcopy(source, dest, host, user,
             args.extend(source)
         else:
             if openssh_has_persist():
-                connkey = make_connkey(user,host,port)
+                control_path = make_control_path(user, host, port)
                 args.extend([
                     '-o', 'ControlMaster=no',
-                    '-o', 'ControlPath=/tmp/%s_%s' % ( CONTROL_PATH, connkey, )])
+                    '-o', 'ControlPath=%s' % control_path ])
             args.append(source)
-        args.append("%s@%s:%s" %(user, host, dest))
+
+        args.append(dest)
 
         # connects to the remote host and starts a remote connection
         proc = subprocess.Popen(args,