X-Git-Url: http://git.onelab.eu/?a=blobdiff_plain;ds=inline;f=src%2Fnepi%2Fresources%2Flinux%2Fgretunnel.py;h=2d866c2290bbe68af34513257f8ae7390c477565;hb=47bfadde39e0d22c3df7e2bd1cd4d52f07ad8c0a;hp=856208868a3ffb7490a49e470d4dd99b836f8e08;hpb=0fb87d99cce02d9806a4557459e279d96a975b08;p=nepi.git diff --git a/src/nepi/resources/linux/gretunnel.py b/src/nepi/resources/linux/gretunnel.py index 85620886..2d866c22 100644 --- a/src/nepi/resources/linux/gretunnel.py +++ b/src/nepi/resources/linux/gretunnel.py @@ -20,7 +20,7 @@ from nepi.execution.attribute import Attribute, Flags, Types from nepi.execution.resource import clsinit_copy, ResourceState, \ reschedule_delay -from nepi.resources.linux.application import LinuxApplication +from nepi.resources.linux.tunnel import LinuxTunnel from nepi.util.sshfuncs import ProcStatus from nepi.util.timefuncs import tnow, tdiffsec @@ -29,24 +29,11 @@ import socket import time @clsinit_copy -class LinuxGRETunnel(LinuxApplication): +class LinuxGRETunnel(LinuxTunnel): _rtype = "LinuxGRETunnel" _help = "Constructs a tunnel between two Linux endpoints using a UDP connection " _backend = "linux" - @classmethod - def _register_attributes(cls): - bwlimit = Attribute("bwLimit", - "Specifies the interface's emulated bandwidth in bytes " - "per second.", - type = Types.Integer, - flags = Flags.Design) - - cls._register_attribute(bwlimit) - - def __init__(self, ec, guid): - super(LinuxGRETunnel, self).__init__(ec, guid) - def log_message(self, msg): return " guid %d - GRE tunnel %s - %s - %s " % (self.guid, self.endpoint1.node.get("hostname"), @@ -59,46 +46,19 @@ class LinuxGRETunnel(LinuxApplication): connected = [] for guid in self.connections: rm = self.ec.get_resource(guid) - if hasattr(rm, "udp_connect_command"): + if hasattr(rm, "gre_connect_command"): connected.append(rm) return connected - @property - def endpoint1(self): - endpoints = self.get_endpoints() - if endpoints: return endpoints[0] - return None - - @property - def endpoint2(self): - endpoints = self.get_endpoints() - if endpoints and len(endpoints) > 1: return endpoints[1] - return None - - def app_home(self, endpoint): - return os.path.join(endpoint.node.exp_home, self._home) - - def run_home(self, endpoint): - return os.path.join(self.app_home(endpoint), self.ec.run_id) - - def udp_connect(self, endpoint, remote_ip): - # Get udp connect command - local_port_file = os.path.join(self.run_home(endpoint), - "local_port") - remote_port_file = os.path.join(self.run_home(endpoint), - "remote_port") - ret_file = os.path.join(self.run_home(endpoint), - "ret_file") - cipher = self.get("cipher") - cipher_key = self.get("cipherKey") - bwlimit = self.get("bwLimit") - txqueuelen = self.get("txQueueLen") - udp_connect_command = endpoint.udp_connect_command( - remote_ip, local_port_file, remote_port_file, - ret_file, cipher, cipher_key, bwlimit, txqueuelen) + def initiate_connection(self, endpoint, remote_endpoint): + # Return the command to execute to initiate the connection to the + # other endpoint + connection_run_home = self.run_home(endpoint) + gre_connect_command = endpoint.gre_connect_command( + remote_endpoint, connection_run_home) # upload command to connect.sh script - shfile = os.path.join(self.app_home(endpoint), "udp-connect.sh") + shfile = os.path.join(self.app_home(endpoint), "gre-connect.sh") endpoint.node.upload(udp_connect_command, shfile, text = True, @@ -128,162 +88,21 @@ class LinuxGRETunnel(LinuxApplication): self.error(msg, out, err) raise RuntimeError, msg - # wait until port is written to file - port = self.wait_local_port(endpoint) - return (port, pid, ppid) - - def do_provision(self): - # create run dir for tunnel on each node - self.endpoint1.node.mkdir(self.run_home(self.endpoint1)) - self.endpoint2.node.mkdir(self.run_home(self.endpoint2)) - - # Invoke connect script in endpoint 1 - remote_ip1 = socket.gethostbyname(self.endpoint2.node.get("hostname")) - (port1, self._pid1, self._ppid1) = self.udp_connect(self.endpoint1, - remote_ip1) - - # Invoke connect script in endpoint 2 - remote_ip2 = socket.gethostbyname(self.endpoint1.node.get("hostname")) - (port2, self._pid2, self._ppid2) = self.udp_connect(self.endpoint2, - remote_ip2) - - # upload file with port 2 to endpoint 1 - self.upload_remote_port(self.endpoint1, port2) - - # upload file with port 1 to endpoint 2 - self.upload_remote_port(self.endpoint2, port1) - - # check if connection was successful on both sides - self.wait_result(self.endpoint1) - self.wait_result(self.endpoint2) - - self.info("Provisioning finished") - - self.set_provisioned() - - def do_deploy(self): - if (not self.endpoint1 or self.endpoint1.state < ResourceState.READY) or \ - (not self.endpoint2 or self.endpoint2.state < ResourceState.READY): - self.ec.schedule(reschedule_delay, self.deploy) - else: - self.do_discover() - self.do_provision() - - self.set_ready() - - def do_start(self): - if self.state == ResourceState.READY: - command = self.get("command") - self.info("Starting command '%s'" % command) - - self.set_started() - else: - msg = " Failed to execute command '%s'" % command - self.error(msg, out, err) - raise RuntimeError, msg - - def do_stop(self): - """ Stops application execution - """ - if self.state == ResourceState.STARTED: - self.info("Stopping tunnel") - - # Only try to kill the process if the pid and ppid - # were retrieved - if self._pid1 and self._ppid1 and self._pid2 and self._ppid2: - (out1, err1), proc1 = self.endpoint1.node.kill(self._pid1, - self._ppid1, sudo = True) - (out2, err2), proc2 = self.endpoint2.node.kill(self._pid2, - self._ppid2, sudo = True) - - if (proc1.poll() and err1) or (proc2.poll() and err2): - # check if execution errors occurred - msg = " Failed to STOP tunnel" - self.error(msg, err1, err2) - raise RuntimeError, msg - - self.set_stopped() - - @property - def state(self): - """ Returns the state of the application - """ - if self._state == ResourceState.STARTED: - # In order to avoid overwhelming the remote host and - # the local processor with too many ssh queries, the state is only - # requested every 'state_check_delay' seconds. - state_check_delay = 0.5 - if tdiffsec(tnow(), self._last_state_check) > state_check_delay: - if self._pid1 and self._ppid1 and self._pid2 and self._ppid2: - # Make sure the process is still running in background - # No execution errors occurred. Make sure the background - # process with the recorded pid is still running. - status1 = self.endpoint1.node.status(self._pid1, self._ppid1) - status2 = self.endpoint2.node.status(self._pid2, self._ppid2) - - if status1 == ProcStatus.FINISHED and \ - status2 == ProcStatus.FINISHED: - - # check if execution errors occurred - (out1, err1), proc1 = self.endpoint1.node.check_errors( - self.run_home(self.endpoint1)) - - (out2, err2), proc2 = self.endpoint2.node.check_errors( - self.run_home(self.endpoint2)) - - if err1 or err2: - msg = "Error occurred in tunnel" - self.error(msg, err1, err2) - self.fail() - else: - self.set_stopped() - - self._last_state_check = tnow() - - return self._state - - def wait_local_port(self, endpoint): - """ Waits until the local_port file for the endpoint is generated, - and returns the port number - - """ - return self.wait_file(endpoint, "local_port") - - def wait_result(self, endpoint): - """ Waits until the return code file for the endpoint is generated - - """ - return self.wait_file(endpoint, "ret_file") - - def wait_file(self, endpoint, filename): - """ Waits until file on endpoint is generated """ - result = None - delay = 1.0 + # Wait if name + return True - for i in xrange(20): - (out, err), proc = endpoint.node.check_output( - self.run_home(endpoint), filename) + def establish_connection(self, endpoint, remote_endpoint, data): + pass - if out: - result = out.strip() - break - else: - time.sleep(delay) - delay = delay * 1.5 - else: - msg = "Couldn't retrieve %s" % filename - self.error(msg, out, err) - raise RuntimeError, msg + def verify_connection(self, endpoint, remote_endpoint): + # Execute a ping from both sides to verify that the tunnel works + pass - return result + def terminate_connection(self, endpoint, remote_endpoint): + pass - def upload_remote_port(self, endpoint, port): - # upload remote port number to file - port = "%s\n" % port - endpoint.node.upload(port, - os.path.join(self.run_home(endpoint), "remote_port"), - text = True, - overwrite = False) + def check_state_connection(self, endpoint, remote_endpoint): + raise NotImplementedError def valid_connection(self, guid): # TODO: Validate!