Adding X11 forwarding tests for resources/linux/node.py
[nepi.git] / src / neco / util / sshfuncs.py
index 77698ca..5690680 100644 (file)
@@ -13,7 +13,7 @@ import tempfile
 import hashlib
 
 OPENSSH_HAS_PERSIST = None
-CONTROL_PATH = "yyyyy_ssh_control_path"
+CONTROL_PATH = "yyy_ssh_ctrl_path"
 
 if hasattr(os, "devnull"):
     DEV_NULL = os.devnull
@@ -111,14 +111,19 @@ def make_connkey(user, host, port):
         connkey = hashlib.sha1(connkey).hexdigest()
     return connkey
 
+def make_control_path(user, host, port):
+    connkey = make_connkey(user, host, port)
+    return '/tmp/%s_%s' % ( CONTROL_PATH, connkey, )
+
 def rexec(command, host, user, 
         port = None, 
         agent = True,
         sudo = False,
-        stdin = "", 
+        stdin = "",
         identity_file = None,
+        env = None,
         tty = False,
-        tty2 = False,
+        x11 = False,
         timeout = None,
         retry = 0,
         err_on_timeout = True,
@@ -127,7 +132,6 @@ def rexec(command, host, user,
     """
     Executes a remote command, returns ((stdout,stderr),process)
     """
-    connkey = make_connkey(user, host, port)
     args = ['ssh', '-C',
             # Don't bother with localhost. Makes test easier
             '-o', 'NoHostAuthenticationForLocalhost=yes',
@@ -141,9 +145,10 @@ def rexec(command, host, user,
             '-l', user, host]
 
     if persistent and openssh_has_persist():
+        control_path = make_control_path(user, host, port)
         args.extend([
             '-o', 'ControlMaster=auto',
-            '-o', 'ControlPath=/tmp/%s_%s' % ( CONTROL_PATH, connkey, ),
+            '-o', 'ControlPath=%s' % control_path,
             '-o', 'ControlPersist=60' ])
     if agent:
         args.append('-A')
@@ -153,14 +158,21 @@ def rexec(command, host, user,
         args.extend(('-i', identity_file))
     if tty:
         args.append('-t')
-    elif tty2:
-        args.append('-t')
-        args.append('-t')
+        if sudo:
+            args.append('-t')
+    if x11:
+        args.append('-X')
+
+    if env:
+        export = ''
+        for envkey, envval in env.iteritems():
+            export += '%s=%s ' % (envkey, envval)
+        command = export + command
+
     if sudo:
         command = "sudo " + command
-    args.append(command)
 
-    print " ".join(args)
+    args.append(command)
 
     for x in xrange(retry or 3):
         # connects to the remote host and starts a remote connection
@@ -186,7 +198,7 @@ def rexec(command, host, user,
         
     return ((out, err), proc)
 
-def rcopy(source, dest, host, user,
+def rcopy(source, dest,
         port = None, 
         agent = True, 
         recursive = False,
@@ -211,7 +223,6 @@ def rcopy(source, dest, host, user,
     
     if isinstance(source, file) and source.tell() == 0:
         source = source.name
-
     elif hasattr(source, 'read'):
         tmp = tempfile.NamedTemporaryFile()
         while True:
@@ -226,8 +237,17 @@ def rcopy(source, dest, host, user,
     if isinstance(source, file) or isinstance(dest, file) \
             or hasattr(source, 'read')  or hasattr(dest, 'write'):
         assert not recursive
+    
+        # Parse source/destination as <user>@<server>:<path>
+        if isinstance(dest, basestring) and ':' in dest:
+            remspec, path = dest.split(':',1)
+        elif isinstance(source, basestring) and ':' in source:
+            remspec, path = source.split(':',1)
+        else:
+            raise ValueError, "Both endpoints cannot be local"
+        user,host = remspec.rsplit('@',1)
+        tmp_known_hosts = None
         
-        connkey = make_connkey(user,host,port)
         args = ['ssh', '-l', user, '-C',
                 # Don't bother with localhost. Makes test easier
                 '-o', 'NoHostAuthenticationForLocalhost=yes',
@@ -239,10 +259,12 @@ def rcopy(source, dest, host, user,
                 '-o', 'ServerAliveInterval=30',
                 '-o', 'TCPKeepAlive=yes',
                 host ]
+
         if openssh_has_persist():
+            control_path = make_control_path(user, host, port)
             args.extend([
                 '-o', 'ControlMaster=auto',
-                '-o', 'ControlPath=/tmp/%s_%s' % ( CONTROL_PATH, connkey, ),
+                '-o', 'ControlPath=%s' % control_path,
                 '-o', 'ControlPersist=60' ])
         if port:
             args.append('-P%d' % port)
@@ -250,12 +272,12 @@ def rcopy(source, dest, host, user,
             args.extend(('-i', identity_file))
         
         if isinstance(source, file) or hasattr(source, 'read'):
-            args.append('cat > %s' % (shell_escape(dest),))
+            args.append('cat > %s' % (shell_escape(path),))
         elif isinstance(dest, file) or hasattr(dest, 'write'):
-            args.append('cat %s' % (shell_escape(dest),))
+            args.append('cat %s' % (shell_escape(path),))
         else:
             raise AssertionError, "Unreachable code reached! :-Q"
-        
+
         # connects to the remote host and starts a remote connection
         if isinstance(source, file):
             proc = subprocess.Popen(args, 
@@ -346,6 +368,15 @@ def rcopy(source, dest, host, user,
         else:
             raise AssertionError, "Unreachable code reached! :-Q"
     else:
+        # Parse destination as <user>@<server>:<path>
+        if isinstance(dest, basestring) and ':' in dest:
+            remspec, path = dest.split(':',1)
+        elif isinstance(source, basestring) and ':' in source:
+            remspec, path = source.split(':',1)
+        else:
+            raise ValueError, "Both endpoints cannot be local"
+        user,host = remspec.rsplit('@',1)
+
         # plain scp
         args = ['scp', '-q', '-p', '-C',
                 # Don't bother with localhost. Makes test easier
@@ -369,12 +400,13 @@ def rcopy(source, dest, host, user,
             args.extend(source)
         else:
             if openssh_has_persist():
-                connkey = make_connkey(user,host,port)
+                control_path = make_control_path(user, host, port)
                 args.extend([
                     '-o', 'ControlMaster=no',
-                    '-o', 'ControlPath=/tmp/%s_%s' % ( CONTROL_PATH, connkey, )])
+                    '-o', 'ControlPath=%s' % control_path ])
             args.append(source)
-        args.append("%s@%s:%s" %(user, host, dest))
+
+        args.append(dest)
 
         # connects to the remote host and starts a remote connection
         proc = subprocess.Popen(args, 
@@ -436,21 +468,21 @@ def rspawn(command, pidfile,
     
     daemon_command = '{ { %(command)s  > %(stdout)s 2>%(stderr)s < %(stdin)s & } ; echo $! 1 > %(pidfile)s ; }' % {
         'command' : command,
-        'pidfile' : shell_escape(pidfile),
+        'pidfile' : pidfile,
         
         'stdout' : stdout,
         'stderr' : stderr,
         'stdin' : stdin,
     }
     
-    cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c %(command)s " % {
-            'command' : shell_escape(daemon_command),
+    cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c '%(command)s' " % {
+            'command' : daemon_command,
             
             'sudo' : 'sudo -S' if sudo else '',
             
-            'pidfile' : shell_escape(pidfile),
-            'gohome' : 'cd %s ; ' % (shell_escape(home),) if home else '',
-            'create' : 'mkdir -p %s ; ' % (shell_escape,) if create_home else '',
+            'pidfile' : pidfile,
+            'gohome' : 'cd %s ; ' % home if home else '',
+            'create' : 'mkdir -p %s ; ' % home if create_home else '',
         }
 
     (out,err), proc = rexec(
@@ -584,10 +616,7 @@ def rkill(pid, ppid,
         Nothing, should have killed the process
     """
     
-    if sudo:
-        subkill = "$(ps --ppid %(pid)d -o pid h)" % { 'pid' : pid }
-    else:
-        subkill = ""
+    subkill = "$(ps --ppid %(pid)d -o pid h)" % { 'pid' : pid }
     cmd = """
 SUBKILL="%(subkill)s" ;
 %(sudo)s kill -- -%(pid)d $SUBKILL || /bin/true