Fixing UdpTunnel unit tests for PlanetLab
[nepi.git] / src / nepi / resources / linux / udptunnel.py
index d240416..01f6898 100644 (file)
 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
 
 from nepi.execution.attribute import Attribute, Flags, Types
-from nepi.execution.resource import ResourceManager, clsinit_copy, ResourceState, \
+from nepi.execution.resource import clsinit_copy, ResourceState, \
         reschedule_delay
 from nepi.resources.linux.application import LinuxApplication
+from nepi.util.sshfuncs import ProcStatus
 from nepi.util.timefuncs import tnow, tdiffsec
 
 import os
@@ -30,6 +31,41 @@ import time
 @clsinit_copy
 class UdpTunnel(LinuxApplication):
     _rtype = "UdpTunnel"
+    _help = "Constructs a tunnel between two Linux endpoints using a UDP connection "
+    _backend = "linux"
+
+    @classmethod
+    def _register_attributes(cls):
+        cipher = Attribute("cipher",
+               "Cipher to encript communication. "
+                "One of PLAIN, AES, Blowfish, DES, DES3. ",
+                default = None,
+                allowed = ["PLAIN", "AES", "Blowfish", "DES", "DES3"],
+                type = Types.Enumerate, 
+                flags = Flags.Design)
+
+        cipher_key = Attribute("cipherKey",
+                "Specify a symmetric encryption key with which to protect "
+                "packets across the tunnel. python-crypto must be installed "
+                "on the system." ,
+                flags = Flags.Design)
+
+        txqueuelen = Attribute("txQueueLen",
+                "Specifies the interface's transmission queue length. "
+                "Defaults to 1000. ", 
+                type = Types.Integer, 
+                flags = Flags.Design)
+
+        bwlimit = Attribute("bwLimit",
+                "Specifies the interface's emulated bandwidth in bytes "
+                "per second.",
+                type = Types.Integer, 
+                flags = Flags.Design)
+
+        cls._register_attribute(cipher)
+        cls._register_attribute(cipher_key)
+        cls._register_attribute(txqueuelen)
+        cls._register_attribute(bwlimit)
 
     def __init__(self, ec, guid):
         super(UdpTunnel, self).__init__(ec, guid)
@@ -81,9 +117,13 @@ class UdpTunnel(LinuxApplication):
                 "remote_port")
         ret_file = os.path.join(self.run_home(endpoint), 
                 "ret_file")
+        cipher = self.get("cipher")
+        cipher_key = self.get("cipherKey")
+        bwlimit = self.get("bwLimit")
+        txqueuelen = self.get("txQueueLen")
         udp_connect_command = endpoint.udp_connect_command(
                 remote_ip, local_port_file, remote_port_file,
-                ret_file)
+                ret_file, cipher, cipher_key, bwlimit, txqueuelen)
 
         # upload command to connect.sh script
         shfile = os.path.join(self.app_home(endpoint), "udp-connect.sh")
@@ -100,7 +140,6 @@ class UdpTunnel(LinuxApplication):
         msg = " Failed to connect endpoints "
         
         if proc.poll():
-            self.fail()
             self.error(msg, out, err)
             raise RuntimeError, msg
     
@@ -113,7 +152,6 @@ class UdpTunnel(LinuxApplication):
             (out, err), proc = endpoint.node.check_errors(self.run_home(endpoint))
             # Out is what was written in the stderr file
             if err:
-                self.fail()
                 msg = " Failed to start command '%s' " % command
                 self.error(msg, out, err)
                 raise RuntimeError, msg
@@ -122,7 +160,7 @@ class UdpTunnel(LinuxApplication):
         port = self.wait_local_port(endpoint)
         return (port, pid, ppid)
 
-    def provision(self):
+    def do_provision(self):
         # create run dir for tunnel on each node 
         self.endpoint1.node.mkdir(self.run_home(self.endpoint1))
         self.endpoint2.node.mkdir(self.run_home(self.endpoint2))
@@ -149,58 +187,33 @@ class UdpTunnel(LinuxApplication):
        
         self.info("Provisioning finished")
  
-        self.debug("----- READY ---- ")
-        self._provision_time = tnow()
-        self._state = ResourceState.PROVISIONED
+        self.set_provisioned()
 
-    def deploy(self):
+    def do_deploy(self):
         if (not self.endpoint1 or self.endpoint1.state < ResourceState.READY) or \
             (not self.endpoint2 or self.endpoint2.state < ResourceState.READY):
             self.ec.schedule(reschedule_delay, self.deploy)
         else:
-            try:
-                self.discover()
-                self.provision()
-            except:
-                self.fail()
-                raise
+            self.do_discover()
+            self.do_provision()
  
-            self.debug("----- READY ---- ")
-            self._ready_time = tnow()
-            self._state = ResourceState.READY
+            self.set_ready()
 
-    def start(self):
-        if self._state == ResourceState.READY:
+    def do_start(self):
+        if self.state == ResourceState.READY:
             command = self.get("command")
             self.info("Starting command '%s'" % command)
-
-            self._start_time = tnow()
-            self._state = ResourceState.STARTED
+            
+            self.set_started()
         else:
             msg = " Failed to execute command '%s'" % command
             self.error(msg, out, err)
-            self._state = ResourceState.FAILED
             raise RuntimeError, msg
 
-    def stop(self):
-        command = self.get('command') or ''
-        state = self.state
-        
-        if state == ResourceState.STARTED:
-            self.info("Stopping command '%s'" % command)
-
-            command = "bash %s" % os.path.join(self.app_home, "stop.sh")
-            (out, err), proc = self.execute_command(command,
-                    blocking = True)
-
-            self._stop_time = tnow()
-            self._state = ResourceState.STOPPED
-
-    def stop(self):
+    def do_stop(self):
         """ Stops application execution
         """
         if self.state == ResourceState.STARTED:
-            stopped = True
             self.info("Stopping tunnel")
     
             # Only try to kill the process if the pid and ppid
@@ -211,16 +224,13 @@ class UdpTunnel(LinuxApplication):
                 (out2, err2), proc2 = self.endpoint2.node.kill(self._pid2, 
                         self._ppid2, sudo = True) 
 
-                if err1 or err2 or pro1.poll() or proc2.poll():
+                if (proc1.poll() and err1) or (proc2.poll() and err2):
                     # check if execution errors occurred
                     msg = " Failed to STOP tunnel"
-                    self.error(msg, out, err)
-                    self.fail()
-                    stopped = False
+                    self.error(msg, err1, err2)
+                    raise RuntimeError, msg
 
-            if stopped:
-                self._stop_time = tnow()
-                self._state = ResourceState.STOPPED
+            self.set_stopped()
 
     @property
     def state(self):
@@ -232,27 +242,29 @@ class UdpTunnel(LinuxApplication):
             # requested every 'state_check_delay' seconds.
             state_check_delay = 0.5
             if tdiffsec(tnow(), self._last_state_check) > state_check_delay:
-                # check if execution errors occurred
-                (out1, err1), proc1 = self.endpoint1.node.check_errors(
-                        self.run_home(self.endpoint1))
-
-                (out2, err2), proc2 = self.endpoint2.node.check_errors(
-                        self.run_home(self.endpoint2))
-
-                if err1 or err2:
-                    msg = " Failed to connect endpoints "
-                    self.error(msg, err1, err2)
-                    self.fail()
-
-                elif self._pid1 and self._ppid1 and self._pid2 and self._ppid2:
+                if self._pid1 and self._ppid1 and self._pid2 and self._ppid2:
+                    # Make sure the process is still running in background
                     # No execution errors occurred. Make sure the background
                     # process with the recorded pid is still running.
-                    status1 = self.node.status(self._pid1, self._ppid1)
-                    status2 = self.node.status(self._pid2, self._ppid2)
+                    status1 = self.endpoint1.node.status(self._pid1, self._ppid1)
+                    status2 = self.endpoint2.node.status(self._pid2, self._ppid2)
 
                     if status1 == ProcStatus.FINISHED and \
-                            satus2 == ProcStatus.FINISHED:
-                        self._state = ResourceState.FINISHED
+                            status2 == ProcStatus.FINISHED:
+
+                        # check if execution errors occurred
+                        (out1, err1), proc1 = self.endpoint1.node.check_errors(
+                                self.run_home(self.endpoint1))
+
+                        (out2, err2), proc2 = self.endpoint2.node.check_errors(
+                                self.run_home(self.endpoint2))
+
+                        if err1 or err2: 
+                            msg = "Error occurred in tunnel"
+                            self.error(msg, err1, err2)
+                            self.fail()
+                        else:
+                            self.set_stopped()
 
                 self._last_state_check = tnow()
 
@@ -260,11 +272,15 @@ class UdpTunnel(LinuxApplication):
 
     def wait_local_port(self, endpoint):
         """ Waits until the local_port file for the endpoint is generated, 
-            and returns the port number """
+        and returns the port number 
+        
+        """
         return self.wait_file(endpoint, "local_port")
 
     def wait_result(self, endpoint):
-        """ Waits until the return code file for the endpoint is generated """ 
+        """ Waits until the return code file for the endpoint is generated 
+        
+        """ 
         return self.wait_file(endpoint, "ret_file")
  
     def wait_file(self, endpoint, filename):
@@ -272,7 +288,7 @@ class UdpTunnel(LinuxApplication):
         result = None
         delay = 1.0
 
-        for i in xrange(4):
+        for i in xrange(20):
             (out, err), proc = endpoint.node.check_output(
                     self.run_home(endpoint), filename)
 
@@ -285,7 +301,6 @@ class UdpTunnel(LinuxApplication):
         else:
             msg = "Couldn't retrieve %s" % filename
             self.error(msg, out, err)
-            self.fail()
             raise RuntimeError, msg
 
         return result