Adding X11 forwarding tests for resources/linux/node.py
authorAlina Quereilhac <alina.quereilhac@inria.fr>
Tue, 13 Nov 2012 14:41:36 +0000 (15:41 +0100)
committerAlina Quereilhac <alina.quereilhac@inria.fr>
Tue, 13 Nov 2012 14:41:36 +0000 (15:41 +0100)
src/neco/resources/linux/node.py
src/neco/util/sshfuncs.py
test/resources/linux/node.py

index 3aa063c..7eddd02 100644 (file)
@@ -1,6 +1,6 @@
 from neco.execution.resource import Resource
 from neco.util.sshfuncs import eintr_retry, rexec, rcopy, \
-        rspawn, rcheck_pid, rstatus, rkill, RUNNING 
+        rspawn, rcheck_pid, rstatus, rkill, make_control_path, RUNNING 
 
 import cStringIO
 import logging
@@ -14,6 +14,9 @@ class LinuxNode(Resource):
         self.user = None
         self.port = None
         self.identity_file = None
+        self.enable_x11 = False
+        self.forward_agent = True
+
         # packet management system - either yum or apt for now...
         self._pm = None
        
@@ -23,6 +26,20 @@ class LinuxNode(Resource):
                 self.box.guid)
         self._logger.setLevel(getattr(logging, loglevel.upper()))
 
+        # For ssh connections we use the ControlMaster option which 
+        # allows us to decrease the number of open ssh network connections.
+        # Subsequent ssh connections will reuse a same master connection.
+        # This might pose a problem when using X11 and ssh-agent, since
+        # display and agent forwarded will be those of the first connection,
+        # which created the master. 
+        # To avoid reusing a master created by a previous LinuxNode instance,
+        # we explicitly erase the ControlPath socket.
+        control_path = make_control_path(self.user, self.host, self.port)
+        try:
+            os.remove(control_path)
+        except:
+            pass
+
     @property
     def pm(self):
         if self._pm:
@@ -66,10 +83,26 @@ class LinuxNode(Resource):
         if not os.path.isfile(src):
             src = cStringIO.StringIO(src)
 
+        # Build destination as <user>@<server>:<path>
+        dst = "%s@%s:%s" % (self.user, self.host or self.ip, dst)
+
+        (out, err), proc = eintr_retry(rcopy)(
+            src, dst, 
+            port = self.port,
+            identity_file = self.identity_file)
+
+        if proc.wait():
+            msg = "Error uploading to %s got:\n%s%s" %\
+                    (self.host or self.ip, out, err)
+            self._logger.error(msg)
+            raise RuntimeError(msg)
+
+    def download(self, src, dst):
+        # Build destination as <user>@<server>:<path>
+        src = "%s@%s:%s" % (self.user, self.host or self.ip, src)
+
         (out, err), proc = eintr_retry(rcopy)(
             src, dst, 
-            self.host or self.ip, 
-            self.user,
             port = self.port,
             identity_file = self.identity_file)
 
@@ -85,7 +118,9 @@ class LinuxNode(Resource):
                 self.host or self.ip, 
                 self.user,
                 port = self.port, 
+                agent = self.forward_agent,
                 identity_file = self.identity_file,
+                x11 = self.enable_x11,
                 timeout = 60,
                 err_on_timeout = False,
                 persistent = False)
@@ -119,10 +154,10 @@ class LinuxNode(Resource):
             )
 
     def execute(self, command,
-            agent = True,
             sudo = False,
             stdin = "", 
             tty = False,
+            env = None,
             timeout = None,
             retry = 0,
             err_on_timeout = True,
@@ -136,11 +171,13 @@ class LinuxNode(Resource):
                 self.host or self.ip, 
                 self.user,
                 port = self.port, 
-                agent = agent,
+                agent = self.forward_agent,
                 sudo = sudo,
                 stdin = stdin, 
                 identity_file = self.identity_file,
                 tty = tty,
+                x11 = self.enable_x11,
+                env = env,
                 timeout = timeout,
                 retry = retry,
                 err_on_timeout = err_on_timeout,
@@ -175,6 +212,7 @@ class LinuxNode(Resource):
             host = self.host,
             user = self.user,
             port = self.port,
+            agent = self.forward_agent,
             identity_file = self.identity_file
             )
         
@@ -189,6 +227,7 @@ class LinuxNode(Resource):
             host = self.host,
             user = self.user,
             port = self.port,
+            agent = self.forward_agent,
             identity_file = self.identity_file
             )
         
@@ -200,6 +239,7 @@ class LinuxNode(Resource):
                 host = self.host,
                 user = self.user,
                 port = self.port,
+                agent = self.forward_agent,
                 identity_file = self.identity_file
                 )
            
@@ -214,6 +254,7 @@ class LinuxNode(Resource):
                 host = self.host,
                 user = self.user,
                 port = self.port,
+                agent = self.forward_agent,
                 sudo = sudo,
                 identity_file = self.identity_file
                 )
index 872143d..5690680 100644 (file)
@@ -13,7 +13,7 @@ import tempfile
 import hashlib
 
 OPENSSH_HAS_PERSIST = None
-CONTROL_PATH = "yyyyy_ssh_control_path"
+CONTROL_PATH = "yyy_ssh_ctrl_path"
 
 if hasattr(os, "devnull"):
     DEV_NULL = os.devnull
@@ -111,13 +111,19 @@ def make_connkey(user, host, port):
         connkey = hashlib.sha1(connkey).hexdigest()
     return connkey
 
+def make_control_path(user, host, port):
+    connkey = make_connkey(user, host, port)
+    return '/tmp/%s_%s' % ( CONTROL_PATH, connkey, )
+
 def rexec(command, host, user, 
         port = None, 
         agent = True,
         sudo = False,
-        stdin = "", 
+        stdin = "",
         identity_file = None,
+        env = None,
         tty = False,
+        x11 = False,
         timeout = None,
         retry = 0,
         err_on_timeout = True,
@@ -126,7 +132,6 @@ def rexec(command, host, user,
     """
     Executes a remote command, returns ((stdout,stderr),process)
     """
-    connkey = make_connkey(user, host, port)
     args = ['ssh', '-C',
             # Don't bother with localhost. Makes test easier
             '-o', 'NoHostAuthenticationForLocalhost=yes',
@@ -140,9 +145,10 @@ def rexec(command, host, user,
             '-l', user, host]
 
     if persistent and openssh_has_persist():
+        control_path = make_control_path(user, host, port)
         args.extend([
             '-o', 'ControlMaster=auto',
-            '-o', 'ControlPath=/tmp/%s_%s' % ( CONTROL_PATH, connkey, ),
+            '-o', 'ControlPath=%s' % control_path,
             '-o', 'ControlPersist=60' ])
     if agent:
         args.append('-A')
@@ -154,9 +160,18 @@ def rexec(command, host, user,
         args.append('-t')
         if sudo:
             args.append('-t')
+    if x11:
+        args.append('-X')
+
+    if env:
+        export = ''
+        for envkey, envval in env.iteritems():
+            export += '%s=%s ' % (envkey, envval)
+        command = export + command
 
     if sudo:
         command = "sudo " + command
+
     args.append(command)
 
     for x in xrange(retry or 3):
@@ -183,7 +198,7 @@ def rexec(command, host, user,
         
     return ((out, err), proc)
 
-def rcopy(source, dest, host, user,
+def rcopy(source, dest,
         port = None, 
         agent = True, 
         recursive = False,
@@ -208,7 +223,6 @@ def rcopy(source, dest, host, user,
     
     if isinstance(source, file) and source.tell() == 0:
         source = source.name
-
     elif hasattr(source, 'read'):
         tmp = tempfile.NamedTemporaryFile()
         while True:
@@ -223,8 +237,17 @@ def rcopy(source, dest, host, user,
     if isinstance(source, file) or isinstance(dest, file) \
             or hasattr(source, 'read')  or hasattr(dest, 'write'):
         assert not recursive
+    
+        # Parse source/destination as <user>@<server>:<path>
+        if isinstance(dest, basestring) and ':' in dest:
+            remspec, path = dest.split(':',1)
+        elif isinstance(source, basestring) and ':' in source:
+            remspec, path = source.split(':',1)
+        else:
+            raise ValueError, "Both endpoints cannot be local"
+        user,host = remspec.rsplit('@',1)
+        tmp_known_hosts = None
         
-        connkey = make_connkey(user,host,port)
         args = ['ssh', '-l', user, '-C',
                 # Don't bother with localhost. Makes test easier
                 '-o', 'NoHostAuthenticationForLocalhost=yes',
@@ -236,10 +259,12 @@ def rcopy(source, dest, host, user,
                 '-o', 'ServerAliveInterval=30',
                 '-o', 'TCPKeepAlive=yes',
                 host ]
+
         if openssh_has_persist():
+            control_path = make_control_path(user, host, port)
             args.extend([
                 '-o', 'ControlMaster=auto',
-                '-o', 'ControlPath=/tmp/%s_%s' % ( CONTROL_PATH, connkey, ),
+                '-o', 'ControlPath=%s' % control_path,
                 '-o', 'ControlPersist=60' ])
         if port:
             args.append('-P%d' % port)
@@ -247,12 +272,12 @@ def rcopy(source, dest, host, user,
             args.extend(('-i', identity_file))
         
         if isinstance(source, file) or hasattr(source, 'read'):
-            args.append('cat > %s' % dest)
+            args.append('cat > %s' % (shell_escape(path),))
         elif isinstance(dest, file) or hasattr(dest, 'write'):
-            args.append('cat %s' % dest)
+            args.append('cat %s' % (shell_escape(path),))
         else:
             raise AssertionError, "Unreachable code reached! :-Q"
-        
+
         # connects to the remote host and starts a remote connection
         if isinstance(source, file):
             proc = subprocess.Popen(args, 
@@ -343,6 +368,15 @@ def rcopy(source, dest, host, user,
         else:
             raise AssertionError, "Unreachable code reached! :-Q"
     else:
+        # Parse destination as <user>@<server>:<path>
+        if isinstance(dest, basestring) and ':' in dest:
+            remspec, path = dest.split(':',1)
+        elif isinstance(source, basestring) and ':' in source:
+            remspec, path = source.split(':',1)
+        else:
+            raise ValueError, "Both endpoints cannot be local"
+        user,host = remspec.rsplit('@',1)
+
         # plain scp
         args = ['scp', '-q', '-p', '-C',
                 # Don't bother with localhost. Makes test easier
@@ -366,12 +400,13 @@ def rcopy(source, dest, host, user,
             args.extend(source)
         else:
             if openssh_has_persist():
-                connkey = make_connkey(user,host,port)
+                control_path = make_control_path(user, host, port)
                 args.extend([
                     '-o', 'ControlMaster=no',
-                    '-o', 'ControlPath=/tmp/%s_%s' % ( CONTROL_PATH, connkey, )])
+                    '-o', 'ControlPath=%s' % control_path ])
             args.append(source)
-        args.append("%s@%s:%s" %(user, host, dest))
+
+        args.append(dest)
 
         # connects to the remote host and starts a remote connection
         proc = subprocess.Popen(args, 
index acb86d6..8512378 100755 (executable)
@@ -34,6 +34,21 @@ class LinuxBoxTestCase(unittest.TestCase):
 
         return node
 
+    def t_xterm(self, node, target):
+        if not node.is_alive():
+            print "*** WARNING: Skipping test: Node %s is not alive\n" % (node.host)
+            return 
+
+        node.enable_x11 = True
+
+        node.install('xterm')
+
+        out = node.execute('xterm')
+
+        node.uninstall('xterm')
+
+        self.assertEquals(out, "")
+
     def t_execute(self, node, target):
         if not node.is_alive():
             print "*** WARNING: Skipping test: Node %s is not alive\n" % (node.host)
@@ -97,11 +112,11 @@ main (void)
         command = "%s/hello" % self.home
         out = node.execute(command)
 
-        self.assertEquals(out, "Hello, world!\n")
-
         node.uninstall('gcc')
         node.rmdir(self.home)
 
+        self.assertEquals(out, "Hello, world!\n")
+
     def test_execute_fedora(self):
         self.t_execute(self.node_fedora, self.target)
 
@@ -120,6 +135,16 @@ main (void)
     def test_install_ubuntu(self):
         self.t_install(self.node_ubuntu, self.target)
 
+    def xtest_xterm_fedora(self):
+        """ PlanetLab doesn't currently support X11 forwarding.
+        Interactive test. Should not run automatically """
+        self.t_xterm(self.node_fedora, self.target)
+
+    def xtest_xterm_ubuntu(self):
+        """ Interactive test. Should not run automatically """
+        self.t_xterm(self.node_ubuntu, self.target)
+
+
 if __name__ == '__main__':
     unittest.main()