Ticket #45: spanning tree deployment
[nepi.git] / src / nepi / testbeds / planetlab / tunproto.py
index 3b762e6..6e518f4 100644 (file)
@@ -29,8 +29,10 @@ class TunProtoBase(object):
         
         self._launcher = None
         self._started = False
+        self._starting = False
         self._pid = None
         self._ppid = None
+        self._if_name = None
 
     def _make_home(self):
         local = self.local()
@@ -42,7 +44,11 @@ class TunProtoBase(object):
         
         # Make sure all the paths are created where 
         # they have to be created for deployment
-        cmd = "mkdir -p %s" % (server.shell_escape(self.home_path),)
+        # Also remove pidfile, if there is one.
+        # Old pidfiles from previous runs can be troublesome.
+        cmd = "mkdir -p %(home)s ; rm -f %(home)s/pid" % {
+            'home' : server.shell_escape(self.home_path)
+        }
         (out,err),proc = server.popen_ssh_command(
             cmd,
             host = local.node.hostname,
@@ -118,6 +124,11 @@ class TunProtoBase(object):
             raise RuntimeError, "Failed to set up TUN forwarder: %s %s" % (out,err,)
         
     def launch(self, check_proto, listen, extra_args=[]):
+        if self._starting:
+            raise AssertionError, "Double start"
+        
+        self._starting = True
+        
         peer = self.peer()
         local = self.local()
         
@@ -134,7 +145,11 @@ class TunProtoBase(object):
         local_mask = local.netprefix
         local_snat = local.snat
         local_txq  = local.txqueuelen
+        local_p2p  = local.pointopoint
         
+        if not local_p2p and hasattr(peer, 'address'):
+            local_p2p = peer.address
+
         if check_proto != peer_proto:
             raise RuntimeError, "Peering protocol mismatch: %s != %s" % (check_proto, peer_proto)
         
@@ -167,8 +182,12 @@ class TunProtoBase(object):
         
         if local_snat:
             args.append("-S")
+        if local_p2p:
+            args.extend(("-P",str(local_p2p)))
         if local_txq:
             args.extend(("-Q",str(local_txq)))
+        if not local_cap:
+            args.append("-N")
         if extra_args:
             args.extend(map(str,extra_args))
         if not listen and check_proto != 'fd':
@@ -185,7 +204,7 @@ class TunProtoBase(object):
             pidfile = './pid',
             home = self.home_path,
             stdin = '/dev/null',
-            stdout = 'capture' if local_cap else '/dev/null',
+            stdout = 'capture',
             stderr = rspawn.STDOUT,
             sudo = True,
             
@@ -203,6 +222,16 @@ class TunProtoBase(object):
         self._started = True
     
     def _launch_and_wait(self, *p, **kw):
+        try:
+            self.__launch_and_wait(*p, **kw)
+        except:
+            if self._launcher:
+                import sys
+                self._launcher._exc.append(sys.exc_info())
+            else:
+                raise
+            
+    def __launch_and_wait(self, *p, **kw):
         local = self.local()
         
         self.launch(*p, **kw)
@@ -212,42 +241,74 @@ class TunProtoBase(object):
             time.sleep(1.0)
         
         # Wait for the connection to be established
-        if local.capture:
-            for spin in xrange(30):
-                if self.status() != rspawn.RUNNING:
-                    break
-                
-                (out,err),proc = server.popen_ssh_command(
-                    "cd %(home)s ; grep -c Connected capture" % dict(
-                        home = server.shell_escape(self.home_path)),
-                    host = local.node.hostname,
-                    port = None,
-                    user = local.node.slicename,
-                    agent = None,
-                    ident_key = local.node.ident_path,
-                    server_key = local.node.server_key
-                    )
-                
-                if proc.wait():
-                    break
-                
-                if out.strip() != '0':
-                    break
-                
-                time.sleep(1.0)
+        for spin in xrange(30):
+            if self.status() != rspawn.RUNNING:
+                break
+            
+            (out,err),proc = server.popen_ssh_command(
+                "cd %(home)s ; grep -c Connected capture" % dict(
+                    home = server.shell_escape(self.home_path)),
+                host = local.node.hostname,
+                port = None,
+                user = local.node.slicename,
+                agent = None,
+                ident_key = local.node.ident_path,
+                server_key = local.node.server_key
+                )
+            
+            if proc.wait():
+                break
+            
+            if out.strip() != '0':
+                break
+            
+            time.sleep(1.0)
+    
+    @property
+    def if_name(self):
+        if not self._if_name:
+            # Inspect the trace to check the assigned iface
+            local = self.local()
+            if local:
+                for spin in xrange(30):
+                    (out,err),proc = server.popen_ssh_command(
+                        "cd %(home)s ; grep 'Using tun:' capture | head -1" % dict(
+                            home = server.shell_escape(self.home_path)),
+                        host = local.node.hostname,
+                        port = None,
+                        user = local.node.slicename,
+                        agent = None,
+                        ident_key = local.node.ident_path,
+                        server_key = local.node.server_key
+                        )
+                    
+                    if proc.wait():
+                        return
+                    
+                    out = out.strip()
+                    
+                    match = re.match(r"Using +tun: +([-a-zA-Z0-9]*) +.*",out)
+                    if match:
+                        self._if_name = match.group(1)
+        return self._if_name
     
     def async_launch(self, check_proto, listen, extra_args=[]):
         if not self._launcher:
             self._launcher = threading.Thread(
                 target = self._launch_and_wait,
                 args = (check_proto, listen, extra_args))
+            self._launcher._exc = []
             self._launcher.start()
     
     def async_launch_wait(self):
         if self._launcher:
             self._launcher.join()
             if not self._started:
-                raise RuntimeError, "Failed to launch TUN forwarder"
+                if self._launcher._exc:
+                    exctyp,exval,exctrace = self._launcher._exc[0]
+                    raise exctyp,exval,exctrace
+                else:
+                    raise RuntimeError, "Failed to launch TUN forwarder"
         elif not self._started:
             self.launch()
 
@@ -312,8 +373,18 @@ class TunProtoBase(object):
                 agent = None,
                 ident_key = local.node.ident_path,
                 server_key = local.node.server_key,
-                sudo = True
+                sudo = True,
+                nowait = True
                 )
+    
+    def waitkill(self):
+        interval = 1.0
+        for i in xrange(30):
+            status = self.status()
+            if status != rspawn.RUNNING:
+                break
+            time.sleep(interval)
+            interval = min(30.0, interval * 1.1)
         
     def sync_trace(self, local_dir, whichtrace):
         if whichtrace != 'packets':
@@ -373,6 +444,12 @@ class TunProtoBase(object):
         Cleanup
         """
         raise NotImplementedError
+    
+    def destroy(self):
+        """
+        Second-phase cleanup
+        """
+        pass
         
 
 class TunProtoUDP(TunProtoBase):
@@ -389,6 +466,14 @@ class TunProtoUDP(TunProtoBase):
     def shutdown(self):
         self.kill()
 
+    def destroy(self):
+        self.waitkill()
+
+    def launch(self, check_proto='udp', listen=False, extra_args=None):
+        if extra_args is None:
+            extra_args = ("-u",str(self.port))
+        super(TunProtoUDP, self).launch(check_proto, listen, extra_args)
+
 class TunProtoFD(TunProtoBase):
     def __init__(self, local, peer, home_path, key, listening):
         super(TunProtoFD, self).__init__(local, peer, home_path, key)
@@ -403,6 +488,12 @@ class TunProtoFD(TunProtoBase):
     def shutdown(self):
         self.kill()
 
+    def destroy(self):
+        self.waitkill()
+
+    def launch(self, check_proto='fd', listen=False, extra_args=[]):
+        super(TunProtoFD, self).launch(check_proto, listen, extra_args)
+
 class TunProtoTCP(TunProtoBase):
     def __init__(self, local, peer, home_path, key, listening):
         super(TunProtoTCP, self).__init__(local, peer, home_path, key)
@@ -420,16 +511,21 @@ class TunProtoTCP(TunProtoBase):
                 peer.peer_proto_impl.async_launch_wait()
             
             if not self._started:
-                self.launch('tcp', False)
-        else:
-            # make sure WE are ready
-            self.async_launch_wait()
+                self.async_launch('tcp', False)
         
         self.checkpid()
     
     def shutdown(self):
         self.kill()
 
+    def destroy(self):
+        self.waitkill()
+
+    def launch(self, check_proto='tcp', listen=None, extra_args=[]):
+        if listen is None:
+            listen = self.listening
+        super(TunProtoTCP, self).launch(check_proto, listen, extra_args)
+
 class TapProtoUDP(TunProtoUDP):
     def __init__(self, local, peer, home_path, key, listening):
         super(TapProtoUDP, self).__init__(local, peer, home_path, key, listening)