Fix TUN shutdown: waitkill was not effective because the if_alive check was faulty (it ran `ip show` instead of `ip link show`)
[nepi.git] / src / nepi / testbeds / planetlab / tunproto.py
index 84e2d9c..84972a6 100644 (file)
@@ -26,13 +26,12 @@ class TunProtoBase(object):
         self.port = 15000
         self.mode = 'pl-tun'
         self.key = key
+        self.cross_slice = False
         
         self.home_path = home_path
-        
-        self._launcher = None
+       
         self._started = False
-        self._started_listening = False
-        self._starting = False
+
         self._pid = None
         self._ppid = None
         self._if_name = None
@@ -76,7 +75,6 @@ class TunProtoBase(object):
         
         if proc.wait():
             raise RuntimeError, "Failed to set up TUN forwarder: %s %s" % (out,err,)
-        
     
     def _install_scripts(self):
         local = self.local()
@@ -182,12 +180,7 @@ class TunProtoBase(object):
         if proc.wait():
             raise RuntimeError, "Failed to set up TUN forwarder: %s %s" % (out,err,)
         
-    def launch(self, check_proto, listen, extra_args=[]):
-        if self._starting:
-            raise AssertionError, "Double start"
-        
-        self._starting = True
-        
+    def launch(self, check_proto):
         peer = self.peer()
         local = self.local()
         
@@ -196,8 +189,8 @@ class TunProtoBase(object):
         
         peer_port = peer.tun_port
         peer_addr = peer.tun_addr
-        peer_proto= peer.tun_proto
-        peer_cipher=peer.tun_cipher
+        peer_proto = peer.tun_proto
+        peer_cipher = peer.tun_cipher
         
         local_port = self.port
         local_cap  = local.capture
@@ -209,6 +202,7 @@ class TunProtoBase(object):
         local_cipher=local.tun_cipher
         local_mcast= local.multicast
         local_bwlim= local.bwlimit
+        local_mcastfwd = local.multicast_forwarder
         
         if not local_p2p and hasattr(peer, 'address'):
             local_p2p = peer.address
@@ -219,12 +213,6 @@ class TunProtoBase(object):
         if local_cipher != peer_cipher:
             raise RuntimeError, "Peering cipher mismatch: %s != %s" % (local_cipher, peer_cipher)
         
-        if not listen and ((peer_proto != 'fd' and not peer_port) or not peer_addr):
-            raise RuntimeError, "Misconfigured peer: %s" % (peer,)
-        
-        if listen and ((peer_proto != 'fd' and not local_port) or not local_addr or not local_mask):
-            raise RuntimeError, "Misconfigured TUN: %s" % (local,)
-
         if check_proto == 'gre' and local_cipher.lower() != 'plain':
             raise RuntimeError, "Misconfigured TUN: %s - GRE tunnels do not support encryption. Got %s, you MUST use PLAIN" % (local, local_cipher,)
 
@@ -242,9 +230,11 @@ class TunProtoBase(object):
         
         args = ["python", "tun_connect.py", 
             "-m", str(self.mode),
+            "-t", str(check_proto),
             "-A", str(local_addr),
             "-M", str(local_mask),
-            "-C", str(local_cipher)]
+            "-C", str(local_cipher),
+            ]
         
         if check_proto == 'fd':
             passfd_arg = str(peer_addr)
@@ -257,37 +247,41 @@ class TunProtoBase(object):
                 "--pass-fd", passfd_arg
             ])
         elif check_proto == 'gre':
+            if self.cross_slice:
+                args.extend([
+                    "-K", str(self.key.strip('='))
+                ])
+
             args.extend([
-                "-K", str(min(local_port, peer_port))
+                "-a", str(peer_addr),
             ])
+        # Remaining protocols (udp and tcp): pass local port, peer port, peer address and key
         else:
             args.extend([
-                "-p", str(local_port if listen else peer_port),
+                "-P", str(local_port),
+                "-p", str(peer_port),
+                "-a", str(peer_addr),
                 "-k", str(self.key)
             ])
         
         if local_snat:
             args.append("-S")
         if local_p2p:
-            args.extend(("-P",str(local_p2p)))
+            args.extend(("-Z",str(local_p2p)))
         if local_txq:
             args.extend(("-Q",str(local_txq)))
         if not local_cap:
             args.append("-N")
         elif local_cap == 'pcap':
             args.extend(('-c','pcap'))
-        if local_mcast:
-            args.append("--multicast")
         if local_bwlim:
             args.extend(("-b",str(local_bwlim*1024)))
-        if extra_args:
-            args.extend(map(str,extra_args))
-        if not listen and check_proto != 'fd':
-            args.append(str(peer_addr))
         if filter_module:
             args.extend(("--filter", filter_module))
         if filter_args:
             args.extend(("--filter-args", filter_args))
+        if local_mcast and local_mcastfwd:
+            args.extend(("--multicast-forwarder", local_mcastfwd))
 
         self._logger.info("Starting %s", self)
         
@@ -316,35 +310,18 @@ class TunProtoBase(object):
         
         if proc.wait():
             raise RuntimeError, "Failed to set up TUN: %s %s" % (out,err,)
-        
+       
         self._started = True
     
     def recover(self):
         # Tunnel should be still running in its node
         # Just check its pidfile and we're done
         self._started = True
-        self._started_listening = True
         self.checkpid()
     
-    def _launch_and_wait(self, *p, **kw):
-        try:
-            self.__launch_and_wait(*p, **kw)
-        except:
-            if self._launcher:
-                import sys
-                self._launcher._exc.append(sys.exc_info())
-            else:
-                raise
-            
-    def __launch_and_wait(self, *p, **kw):
+    def wait(self):
         local = self.local()
         
-        self.launch(*p, **kw)
-        
-        # Wait for the process to be started
-        while self.status() == rspawn.NOT_STARTED:
-            time.sleep(1.0)
-        
         # Wait for the connection to be established
         retrytime = 2.0
         for spin in xrange(30):
@@ -354,7 +331,7 @@ class TunProtoBase(object):
             
             # Connected?
             (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
-                "cd %(home)s ; grep -c Connected capture" % dict(
+                "cd %(home)s ; grep -a -c Connected capture" % dict(
                     home = server.shell_escape(self.home_path)),
                 host = local.node.hostname,
                 port = None,
@@ -372,7 +349,7 @@ class TunProtoBase(object):
 
             # At least listening?
             (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
-                "cd %(home)s ; grep -c Listening capture" % dict(
+                "cd %(home)s ; grep -a -c Listening capture" % dict(
                     home = server.shell_escape(self.home_path)),
                 host = local.node.hostname,
                 port = None,
@@ -385,9 +362,6 @@ class TunProtoBase(object):
                 )
             proc.wait()
 
-            if out.strip() == '1':
-                self._started_listening = True
-            
             time.sleep(min(30.0, retrytime))
             retrytime *= 1.1
         else:
@@ -414,7 +388,7 @@ class TunProtoBase(object):
             # Inspect the trace to check the assigned iface
             local = self.local()
             if local:
-                cmd = "cd %(home)s ; grep 'Using tun:' capture | head -1" % dict(
+                cmd = "cd %(home)s ; grep -a 'Using tun:' capture | head -1" % dict(
                             home = server.shell_escape(self.home_path))
                 for spin in xrange(30):
                     (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
@@ -444,7 +418,7 @@ class TunProtoBase(object):
                         self._logger.debug("if_name: %r does not match expected pattern from cmd %s", out, cmd)
                     else:
                         self._logger.debug("if_name: empty output from cmd %s", cmd)
-                    time.sleep(1)
+                    time.sleep(3)
                 else:
                     self._logger.warn("if_name: Could not get interface name")
         return self._if_name
@@ -455,7 +429,7 @@ class TunProtoBase(object):
             local = self.local()
             for i in xrange(30):
                 (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
-                    "ip show %s >/dev/null 2>&1 && echo ALIVE || echo DEAD" % (name,),
+                    "ip link show %s >/dev/null 2>&1 && echo ALIVE || echo DEAD" % (name,),
                     host = local.node.hostname,
                     port = None,
                     user = local.node.slicename,
@@ -476,38 +450,6 @@ class TunProtoBase(object):
                     return True
         return False
     
-    def async_launch(self, check_proto, listen, extra_args=[]):
-        if not self._started and not self._launcher:
-            self._launcher = threading.Thread(
-                target = self._launch_and_wait,
-                args = (check_proto, listen, extra_args))
-            self._launcher._exc = []
-            self._launcher.start()
-    
-    def async_launch_wait(self):
-        if self._launcher:
-            self._launcher.join()
-
-            if self._launcher._exc:
-                exctyp,exval,exctrace = self._launcher._exc[0]
-                raise exctyp,exval,exctrace
-            elif not self._started:
-                raise RuntimeError, "Failed to launch TUN forwarder"
-        elif not self._started:
-            self.launch()
-
-    def async_launch_wait_listening(self):
-        if self._launcher:
-            for x in xrange(180):
-                if self._launcher._exc:
-                    exctyp,exval,exctrace = self._launcher._exc[0]
-                    raise exctyp,exval,exctrace
-                elif self._started and self._started_listening:
-                    break
-                time.sleep(1)
-        elif not self._started:
-            self.launch()
-
     def checkpid(self):            
         local = self.local()
         
@@ -595,7 +537,23 @@ class TunProtoBase(object):
                     break
                 time.sleep(interval)
                 interval = min(30.0, interval * 1.1)
-    
+            else:
+                local = self.local()
+                
+                if local:
+                    # Forcibly shut down interface
+                    (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
+                        "sudo -S bash -c 'echo %s > /vsys/vif_down.in'" % (self.if_name,),
+                        host = local.node.hostname,
+                        port = None,
+                        user = local.node.slicename,
+                        agent = None,
+                        ident_key = local.node.ident_path,
+                        server_key = local.node.server_key,
+                        timeout = 60,
+                        err_on_timeout = False
+                        )
+                    proc.wait()    
     _TRACEMAP = {
         # tracename : (remotename, localname)
         'packets' : ('capture','capture'),
@@ -605,7 +563,6 @@ class TunProtoBase(object):
     def remote_trace_path(self, whichtrace, tracemap = None):
         tracemap = self._TRACEMAP if not tracemap else tracemap
         
-        
         if whichtrace not in tracemap:
             return None
         
@@ -650,153 +607,61 @@ class TunProtoBase(object):
         
         return local_path
         
-        
-    def prepare(self):
-        """
-        First-phase setup
-        
-        eg: set up listening ports
-        """
-        raise NotImplementedError
-    
-    def setup(self):
-        """
-        Second-phase setup
-        
-        eg: connect to peer
-        """
-        raise NotImplementedError
-    
     def shutdown(self):
-        """
-        Cleanup
-        """
-        raise NotImplementedError
+        self.kill()
     
     def destroy(self):
-        """
-        Second-phase cleanup
-        """
-        pass
-        
+        self.waitkill()
 
 class TunProtoUDP(TunProtoBase):
-    def __init__(self, local, peer, home_path, key, listening):
+    def __init__(self, local, peer, home_path, key):
         super(TunProtoUDP, self).__init__(local, peer, home_path, key)
-        self.listening = listening
-    
-    def prepare(self):
-        pass
     
-    def setup(self):
-        self.async_launch('udp', False, ("-u",str(self.port)))
-    
-    def shutdown(self):
-        self.kill()
-
-    def destroy(self):
-        self.waitkill()
-
-    def launch(self, check_proto='udp', listen=False, extra_args=None):
-        if extra_args is None:
-            extra_args = ("-u",str(self.port))
-        super(TunProtoUDP, self).launch(check_proto, listen, extra_args)
+    def launch(self):
+        super(TunProtoUDP, self).launch('udp')
 
 class TunProtoFD(TunProtoBase):
-    def __init__(self, local, peer, home_path, key, listening):
+    def __init__(self, local, peer, home_path, key):
         super(TunProtoFD, self).__init__(local, peer, home_path, key)
-        self.listening = listening
-    
-    def prepare(self):
-        pass
     
-    def setup(self):
-        self.async_launch('fd', False)
-    
-    def shutdown(self):
-        self.kill()
-
-    def destroy(self):
-        self.waitkill()
-
-    def launch(self, check_proto='fd', listen=False, extra_args=[]):
-        super(TunProtoFD, self).launch(check_proto, listen, extra_args)
+    def launch(self):
+        super(TunProtoFD, self).launch('fd')
 
 class TunProtoGRE(TunProtoBase):
-    def __init__(self, local, peer, home_path, key, listening):
+    def __init__(self, local, peer, home_path, key):
         super(TunProtoGRE, self).__init__(local, peer, home_path, key)
-        self.listening = listening
         self.mode = 'pl-gre-ip'
-    
-    def prepare(self):
-        pass
-    
-    def setup(self):
-        self.async_launch('gre', False)
-    
-    def shutdown(self):
-        self.kill()
 
-    def destroy(self):
-        self.waitkill()
-
-    def launch(self, check_proto='gre', listen=False, extra_args=[]):
-        super(TunProtoGRE, self).launch(check_proto, listen, extra_args)
+    def launch(self):
+        super(TunProtoGRE, self).launch('gre')
 
 class TunProtoTCP(TunProtoBase):
-    def __init__(self, local, peer, home_path, key, listening):
+    def __init__(self, local, peer, home_path, key):
         super(TunProtoTCP, self).__init__(local, peer, home_path, key)
-        self.listening = listening
-    
-    def prepare(self):
-        if self.listening:
-            self.async_launch('tcp', True)
     
-    def setup(self):
-        if not self.listening:
-            # make sure our peer is ready
-            peer = self.peer()
-            if peer and peer.peer_proto_impl:
-                peer.peer_proto_impl.async_launch_wait_listening()
-            
-            if not self._started:
-                self.async_launch('tcp', False)
-        
-        self.checkpid()
-    
-    def shutdown(self):
-        self.kill()
-
-    def destroy(self):
-        self.waitkill()
-
-    def launch(self, check_proto='tcp', listen=None, extra_args=[]):
-        if listen is None:
-            listen = self.listening
-        super(TunProtoTCP, self).launch(check_proto, listen, extra_args)
+    def launch(self):
+        super(TunProtoTCP, self).launch('tcp')
 
 class TapProtoUDP(TunProtoUDP):
-    def __init__(self, local, peer, home_path, key, listening):
-        super(TapProtoUDP, self).__init__(local, peer, home_path, key, listening)
+    def __init__(self, local, peer, home_path, key):
+        super(TapProtoUDP, self).__init__(local, peer, home_path, key)
         self.mode = 'pl-tap'
 
 class TapProtoTCP(TunProtoTCP):
-    def __init__(self, local, peer, home_path, key, listening):
-        super(TapProtoTCP, self).__init__(local, peer, home_path, key, listening)
+    def __init__(self, local, peer, home_path, key):
+        super(TapProtoTCP, self).__init__(local, peer, home_path, key)
         self.mode = 'pl-tap'
 
 class TapProtoFD(TunProtoFD):
-    def __init__(self, local, peer, home_path, key, listening):
-        super(TapProtoFD, self).__init__(local, peer, home_path, key, listening)
+    def __init__(self, local, peer, home_path, key):
+        super(TapProtoFD, self).__init__(local, peer, home_path, key)
         self.mode = 'pl-tap'
 
 class TapProtoGRE(TunProtoGRE):
-    def __init__(self, local, peer, home_path, key, listening):
-        super(TapProtoGRE, self).__init__(local, peer, home_path, key, listening)
+    def __init__(self, local, peer, home_path, key):
+        super(TapProtoGRE, self).__init__(local, peer, home_path, key)
         self.mode = 'pl-gre-eth'
 
-
-
 TUN_PROTO_MAP = {
     'tcp' : TunProtoTCP,
     'udp' : TunProtoUDP,
@@ -811,4 +676,3 @@ TAP_PROTO_MAP = {
     'gre' : TapProtoGRE,
 }
 
-