merge.
[nepi.git] / src / nepi / testbeds / planetlab / tunproto.py
index 00bfb83..84e2d9c 100644 (file)
@@ -31,6 +31,7 @@ class TunProtoBase(object):
         
         self._launcher = None
         self._started = False
+        self._started_listening = False
         self._starting = False
         self._pid = None
         self._ppid = None
@@ -58,7 +59,7 @@ class TunProtoBase(object):
         # they have to be created for deployment
         # Also remove pidfile, if there is one.
         # Old pidfiles from previous runs can be troublesome.
-        cmd = "mkdir -p %(home)s ; rm -f %(home)s/pid" % {
+        cmd = "mkdir -p %(home)s ; rm -f %(home)s/pid %(home)s/*.so" % {
             'home' : server.shell_escape(self.home_path)
         }
         (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
@@ -68,7 +69,9 @@ class TunProtoBase(object):
             user = local.node.slicename,
             agent = None,
             ident_key = local.node.ident_path,
-            server_key = local.node.server_key
+            server_key = local.node.server_key,
+            timeout = 60,
+            retry = 3
             )
         
         if proc.wait():
@@ -85,11 +88,31 @@ class TunProtoBase(object):
         
         # Install the tun_connect script and tunalloc utility
         from nepi.util import tunchannel
+        from nepi.util import ipaddr2
         sources = [
             os.path.join(os.path.dirname(__file__), 'scripts', 'tun_connect.py'),
             os.path.join(os.path.dirname(__file__), 'scripts', 'tunalloc.c'),
             re.sub(r"([.]py)[co]$", r'\1', tunchannel.__file__, 1), # pyc/o files are version-specific
+            re.sub(r"([.]py)[co]$", r'\1', ipaddr2.__file__, 1), # pyc/o files are version-specific
         ]
+        if local.filter_module:
+            filter_sources = filter(bool,map(str.strip,local.filter_module.module.split()))
+            filter_module = filter_sources[0]
+            
+            # Translate paths to builtin sources
+            for i,source in enumerate(filter_sources):
+                if not os.path.exists(source):
+                    # Um... try the builtin folder
+                    source = os.path.join(os.path.dirname(__file__), "scripts", source)
+                    if os.path.exists(source):
+                        # Yep... replace
+                        filter_sources[i] = source
+
+            sources.extend(set(filter_sources))
+                
+        else:
+            filter_module = None
+            filter_sources = None
         dest = "%s@%s:%s" % (
             local.node.slicename, local.node.hostname, 
             os.path.join(self.home_path,'.'),)
@@ -118,6 +141,16 @@ class TunProtoBase(object):
             "python setup.py install --install-lib .. && "
             "cd .. "
             
+            + ( " && "
+                "gcc -fPIC -shared %(sources)s -o %(module)s.so " % {
+                   'module' : os.path.basename(filter_module).rsplit('.',1)[0],
+                   'sources' : ' '.join(map(os.path.basename,filter_sources))
+                }
+                
+                if filter_module is not None and filter_module.endswith('.c')
+                else ""
+            )
+            
             + ( " && "
                 "wget -q -c -O python-passfd-src.tar.gz %(passfd_url)s && "
                 "mkdir -p python-passfd && "
@@ -126,7 +159,8 @@ class TunProtoBase(object):
                 "python setup.py build && "
                 "python setup.py install --install-lib .. "
                 
-                if local.tun_proto == "fd" else ""
+                if local.tun_proto == "fd" 
+                else ""
             ) 
           )
         % {
@@ -141,7 +175,8 @@ class TunProtoBase(object):
             user = local.node.slicename,
             agent = None,
             ident_key = local.node.ident_path,
-            server_key = local.node.server_key
+            server_key = local.node.server_key,
+            timeout = 300
             )
         
         if proc.wait():
@@ -172,6 +207,8 @@ class TunProtoBase(object):
         local_txq  = local.txqueuelen
         local_p2p  = local.pointopoint
         local_cipher=local.tun_cipher
+        local_mcast= local.multicast
+        local_bwlim= local.bwlimit
         
         if not local_p2p and hasattr(peer, 'address'):
             local_p2p = peer.address
@@ -190,6 +227,18 @@ class TunProtoBase(object):
 
         if check_proto == 'gre' and local_cipher.lower() != 'plain':
             raise RuntimeError, "Misconfigured TUN: %s - GRE tunnels do not support encryption. Got %s, you MUST use PLAIN" % (local, local_cipher,)
+
+        if local.filter_module:
+            if check_proto not in ('udp', 'tcp'):
+                raise RuntimeError, "Miscofnigured TUN: %s - filtered tunnels only work with udp or tcp links" % (local,)
+            filter_module = filter(bool,map(str.strip,local.filter_module.module.split()))
+            filter_module = os.path.join('.',os.path.basename(filter_module[0]))
+            if filter_module.endswith('.c'):
+                filter_module = filter_module.rsplit('.',1)[0] + '.so'
+            filter_args = local.filter_module.args
+        else:
+            filter_module = None
+            filter_args = None
         
         args = ["python", "tun_connect.py", 
             "-m", str(self.mode),
@@ -227,10 +276,18 @@ class TunProtoBase(object):
             args.append("-N")
         elif local_cap == 'pcap':
             args.extend(('-c','pcap'))
+        if local_mcast:
+            args.append("--multicast")
+        if local_bwlim:
+            args.extend(("-b",str(local_bwlim*1024)))
         if extra_args:
             args.extend(map(str,extra_args))
         if not listen and check_proto != 'fd':
             args.append(str(peer_addr))
+        if filter_module:
+            args.extend(("--filter", filter_module))
+        if filter_args:
+            args.extend(("--filter-args", filter_args))
 
         self._logger.info("Starting %s", self)
         
@@ -266,6 +323,7 @@ class TunProtoBase(object):
         # Tunnel should be still running in its node
         # Just check its pidfile and we're done
         self._started = True
+        self._started_listening = True
         self.checkpid()
     
     def _launch_and_wait(self, *p, **kw):
@@ -288,11 +346,13 @@ class TunProtoBase(object):
             time.sleep(1.0)
         
         # Wait for the connection to be established
+        retrytime = 2.0
         for spin in xrange(30):
             if self.status() != rspawn.RUNNING:
                 self._logger.warn("FAILED TO CONNECT! %s", self)
                 break
             
+            # Connected?
             (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
                 "cd %(home)s ; grep -c Connected capture" % dict(
                     home = server.shell_escape(self.home_path)),
@@ -301,18 +361,52 @@ class TunProtoBase(object):
                 user = local.node.slicename,
                 agent = None,
                 ident_key = local.node.ident_path,
-                server_key = local.node.server_key
+                server_key = local.node.server_key,
+                timeout = 60,
+                err_on_timeout = False
                 )
-            
-            if proc.wait():
-                break
-            
-            if out.strip() != '0':
+            proc.wait()
+
+            if out.strip() == '1':
                 break
+
+            # At least listening?
+            (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
+                "cd %(home)s ; grep -c Listening capture" % dict(
+                    home = server.shell_escape(self.home_path)),
+                host = local.node.hostname,
+                port = None,
+                user = local.node.slicename,
+                agent = None,
+                ident_key = local.node.ident_path,
+                server_key = local.node.server_key,
+                timeout = 60,
+                err_on_timeout = False
+                )
+            proc.wait()
+
+            if out.strip() == '1':
+                self._started_listening = True
             
-            time.sleep(1.0)
+            time.sleep(min(30.0, retrytime))
+            retrytime *= 1.1
         else:
-            self._logger.warn("FAILED TO CONNECT! %s", self)
+            (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
+                "cat %(home)s/capture" % dict(
+                    home = server.shell_escape(self.home_path)),
+                host = local.node.hostname,
+                port = None,
+                user = local.node.slicename,
+                agent = None,
+                ident_key = local.node.ident_path,
+                server_key = local.node.server_key,
+                timeout = 60,
+                retry = 3,
+                err_on_timeout = False
+                )
+            proc.wait()
+
+            raise RuntimeError, "FAILED TO CONNECT %s: %s%s" % (self,out,err)
     
     @property
     def if_name(self):
@@ -320,33 +414,68 @@ class TunProtoBase(object):
             # Inspect the trace to check the assigned iface
             local = self.local()
             if local:
+                cmd = "cd %(home)s ; grep 'Using tun:' capture | head -1" % dict(
+                            home = server.shell_escape(self.home_path))
                 for spin in xrange(30):
                     (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
-                        "cd %(home)s ; grep 'Using tun:' capture | head -1" % dict(
-                            home = server.shell_escape(self.home_path)),
+                        cmd,
                         host = local.node.hostname,
                         port = None,
                         user = local.node.slicename,
                         agent = None,
                         ident_key = local.node.ident_path,
-                        server_key = local.node.server_key
+                        server_key = local.node.server_key,
+                        timeout = 60,
+                        err_on_timeout = False
                         )
                     
                     if proc.wait():
-                        return
+                        self._logger.debug("if_name: failed cmd %s", cmd)
+                        time.sleep(1)
+                        continue
                     
                     out = out.strip()
                     
                     match = re.match(r"Using +tun: +([-a-zA-Z0-9]*).*",out)
                     if match:
                         self._if_name = match.group(1)
+                        break
                     elif out:
-                        self._logger.debug("if_name: %r does not match expected pattern", out)
-                        time.sleep(1)
+                        self._logger.debug("if_name: %r does not match expected pattern from cmd %s", out, cmd)
+                    else:
+                        self._logger.debug("if_name: empty output from cmd %s", cmd)
+                    time.sleep(1)
                 else:
                     self._logger.warn("if_name: Could not get interface name")
         return self._if_name
     
+    def if_alive(self):
+        """
+        Check over ssh whether the tunnel's network interface still exists
+        on the node.
+        
+        Returns True when the remote probe prints ALIVE. Returns False when
+        it prints DEAD, when the interface name is not known, when the local
+        endpoint is no longer available, or when none of the 30 probe
+        attempts produced a usable answer.
+        """
+        name = self.if_name
+        if not name:
+            return False
+        
+        local = self.local()
+        if not local:
+            # Endpoint is gone: there is no node left to query
+            return False
+        
+        # The object ("link") is mandatory for ip(8): without it the command
+        # is rejected and the probe would report DEAD for every device.
+        cmd = "ip link show %s >/dev/null 2>&1 && echo ALIVE || echo DEAD" % (name,)
+        
+        for i in xrange(30):
+            (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
+                cmd,
+                host = local.node.hostname,
+                port = None,
+                user = local.node.slicename,
+                agent = None,
+                ident_key = local.node.ident_path,
+                server_key = local.node.server_key,
+                timeout = 60,
+                err_on_timeout = False
+                )
+            
+            if proc.wait():
+                # The ssh command itself failed: pause briefly and retry
+                time.sleep(1)
+                continue
+            
+            out = out.strip()
+            if out in ('ALIVE', 'DEAD'):
+                return out == 'ALIVE'
+            # Any other output is inconclusive: probe again
+        return False
+    
     def async_launch(self, check_proto, listen, extra_args=[]):
         if not self._started and not self._launcher:
             self._launcher = threading.Thread(
@@ -358,12 +487,24 @@ class TunProtoBase(object):
     def async_launch_wait(self):
         if self._launcher:
             self._launcher.join()
-            if not self._started:
+            # join() has returned, so the launcher is done: report how it ended
+            if self._launcher._exc:
+                exctyp,exval,exctrace = self._launcher._exc[0]  # first recorded (type, value, traceback); presumably filled in by the launcher - not shown here
+                raise exctyp,exval,exctrace  # Python 2 three-argument raise: re-raise in the caller with that traceback
+            elif not self._started:
+                raise RuntimeError, "Failed to launch TUN forwarder"
+        elif not self._started:
+            self.launch()  # no launcher object and not started yet: launch synchronously
+
+    def async_launch_wait_listening(self):  # wait for an async launch to get at least as far as listening
+        if self._launcher:
+            for x in xrange(180):  # up to 180 polls, one second apart; running out raises nothing
                 if self._launcher._exc:
                     exctyp,exval,exctrace = self._launcher._exc[0]
                     raise exctyp,exval,exctrace
-                else:
-                    raise RuntimeError, "Failed to launch TUN forwarder"
+                elif self._started and self._started_listening:
+                    break  # both the started and the listening flags are set
+                time.sleep(1)
         elif not self._started:
             self.launch()
 
@@ -412,7 +553,7 @@ class TunProtoBase(object):
                 )
             return status
     
-    def kill(self):
+    def kill(self, nowait = True):
         local = self.local()
         
         if not local:
@@ -432,7 +573,7 @@ class TunProtoBase(object):
                 ident_key = local.node.ident_path,
                 server_key = local.node.server_key,
                 sudo = True,
-                nowait = True
+                nowait = nowait
                 )
     
     def waitkill(self):
@@ -444,6 +585,16 @@ class TunProtoBase(object):
                 break
             time.sleep(interval)
             interval = min(30.0, interval * 1.1)
+        else:
+            self.kill(nowait=False)
+
+        if self.if_name:
+            for i in xrange(30):
+                if not self.if_alive():
+                    self._logger.info("Device down %s", self)
+                    break
+                time.sleep(interval)
+                interval = min(30.0, interval * 1.1)
     
     _TRACEMAP = {
         # tracename : (remotename, localname)
@@ -451,16 +602,17 @@ class TunProtoBase(object):
         'pcap' : ('pcap','capture.pcap'),
     }
     
-    def remote_trace_path(self, whichtrace):
-        tracemap = self._TRACEMAP
+    def remote_trace_path(self, whichtrace, tracemap = None):
+        tracemap = self._TRACEMAP if not tracemap else tracemap  # None or an empty map falls back to the class-level _TRACEMAP
+        # Unknown trace names give None; known ones give element [1] of their map entry joined under home_path
         
         if whichtrace not in tracemap:
             return None
         
         return os.path.join(self.home_path, tracemap[whichtrace][1])
         
-    def sync_trace(self, local_dir, whichtrace):
-        tracemap = self._TRACEMAP
+    def sync_trace(self, local_dir, whichtrace, tracemap = None):
+        tracemap = self._TRACEMAP if not tracemap else tracemap
         
         if whichtrace not in tracemap:
             return None
@@ -605,7 +757,7 @@ class TunProtoTCP(TunProtoBase):
             # make sure our peer is ready
             peer = self.peer()
             if peer and peer.peer_proto_impl:
-                peer.peer_proto_impl.async_launch_wait()
+                peer.peer_proto_impl.async_launch_wait_listening()
             
             if not self._started:
                 self.async_launch('tcp', False)