Multicast forwarding KINDA working
author Claudio-Daniel Freire <claudio-daniel.freire@inria.fr>
Tue, 30 Aug 2011 13:48:33 +0000 (15:48 +0200)
committer Claudio-Daniel Freire <claudio-daniel.freire@inria.fr>
Tue, 30 Aug 2011 13:48:33 +0000 (15:48 +0200)
src/nepi/testbeds/planetlab/interfaces.py
src/nepi/testbeds/planetlab/scripts/mcastfwd.py [new file with mode: 0644]
src/nepi/testbeds/planetlab/scripts/tun_connect.py
src/nepi/testbeds/planetlab/tunproto.py
src/nepi/util/tunchannel.py

index c853a37..9699167 100644 (file)
@@ -139,6 +139,7 @@ class TunIface(object):
         
         # These get initialized when the iface is connected to any filter
         self.filter_module = None
+        self.multicast_forwarder = None
         
         # These get initialized when the iface is configured
         self.external_iface = None
@@ -242,6 +243,7 @@ class TunIface(object):
         impl = self._PROTO_MAP[self.peer_proto](
             self, self.peer_iface, home_path, self.tun_key, listening)
         impl.port = self.tun_port
+        impl.cross_slice = not self.peer_iface or isinstance(self.peer_iface, _CrossIface)
         return impl
     
     def recover(self):
diff --git a/src/nepi/testbeds/planetlab/scripts/mcastfwd.py b/src/nepi/testbeds/planetlab/scripts/mcastfwd.py
new file mode 100644 (file)
index 0000000..0db1f5f
--- /dev/null
@@ -0,0 +1,399 @@
+import sys
+
+import signal
+import socket
+import struct
+import optparse
+import threading
+import subprocess
+import re
+import time
+import collections
+import os
+import traceback
+
+import ipaddr2
+
+# Command line: the positional arguments are the addresses of the local
+# virtual interfaces for which multicast forwarding is enabled.
+usage = "usage: %prog [options] <enabled-addresses>"
+
+parser = optparse.OptionParser(usage=usage)
+
+# How often each IGMP thread polls the kernel for subscription changes;
+# also used as the socket timeout and to size thread join timeouts.
+parser.add_option(
+    "-d", "--poll-delay", dest="poll_delay", metavar="SECONDS", type="float",
+    default = 1.0,
+    help = "Multicast subscription polling interval")
+# How often every current subscription is re-reported, not just new ones.
+parser.add_option(
+    "-D", "--refresh-delay", dest="refresh_delay", metavar="SECONDS", type="float",
+    default = 30.0,
+    help = "Full-refresh interval - time between full IGMP reports")
+# Datagram unix socket on which packets to be forwarded are received.
+parser.add_option(
+    "-p", "--fwd-path", dest="fwd_path", metavar="PATH", 
+    default = "/var/run/mcastfwd",
+    help = "Path of the unix socket in which the program will listen for packets")
+# Unix socket on which the multicast router connects and sends its
+# routing commands.
+parser.add_option(
+    "-r", "--router-path", dest="mrt_path", metavar="PATH", 
+    default = "/var/run/mcastrt",
+    help = "Path of the unix socket in which the program will listen for routing changes")
+
+# NOTE: parsed at import time; `options` and `remaining_args` are used as
+# module-level globals by the thread classes defined in this script.
+(options, remaining_args) = parser.parse_args(sys.argv[1:])
+
+# Low-level ethernet / tun-tap constants.
+# None of them is referenced by the code in this script.
+ETH_P_ALL = 0x00000003
+ETH_P_IP = 0x00000800
+TUNSETIFF = 0x400454ca
+IFF_NO_PI = 0x00001000
+IFF_TAP = 0x00000002
+IFF_TUN = 0x00000001
+IFF_VNET_HDR = 0x00004000
+TUN_PKT_STRIP = 0x00000001
+IFHWADDRLEN = 0x00000006
+IFNAMSIZ = 0x00000010
+IFREQ_SZ = 0x00000028
+FIONREAD = 0x0000541b
+
+class IGMPThread(threading.Thread):
+    """
+    Daemon thread that announces the multicast subscriptions of one interface.
+
+    Every ``options.poll_delay`` seconds it lists the groups the interface
+    is subscribed to (``ip maddr show <iface>``) and sends, through a raw
+    IGMP socket bound to the interface address:
+
+      * an IGMP message of type 0x16 (IGMPv2 membership report) for every
+        group that appeared since the previous poll, and
+      * an IGMP message of type 0x17 (IGMPv2 leave group) for every group
+        that disappeared.
+
+    Once every ``options.refresh_delay`` seconds all current groups are
+    reported again, not only the new ones.
+
+    Parameters:
+        vif_addr: dotted-quad address of the virtual interface
+            (surrounding whitespace is stripped).
+        *p, **kw: passed on to threading.Thread.
+
+    Raises:
+        RuntimeError: if no interface listed by ``ip addr show`` carries
+            ``vif_addr``.
+    """
+
+    # IGMPv2 message types (RFC 2236)
+    IGMP_V2_REPORT = 0x16
+    IGMP_V2_LEAVE = 0x17
+
+    # Destination of every message sent by this thread
+    ALL_ROUTERS = '224.0.0.2'
+
+    def __init__(self, vif_addr, *p, **kw):
+        super(IGMPThread, self).__init__(*p, **kw)
+
+        vif_addr = vif_addr.strip()
+        self.vif_addr = vif_addr
+        self.igmp_socket = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP)
+        self.igmp_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF,
+            socket.inet_aton(self.vif_addr) )
+        self.igmp_socket.setsockopt(socket.IPPROTO_IP, socket.IP_HDRINCL, 1)
+        self.igmp_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 1)
+        self._stop = False
+        self.setDaemon(True)
+
+        # Find tun name
+        self.tun_name = self._find_iface(vif_addr)
+
+    @staticmethod
+    def _find_iface(vif_addr):
+        """
+        Return the name of the interface that owns vif_addr.
+
+        Parses the output of ``ip addr show``: a heading line names an
+        interface, the ``inet`` lines that follow list its addresses.
+        Raises RuntimeError when the address is not found.
+        """
+        devnull = open('/dev/null','r+b')
+        try:
+            proc = subprocess.Popen(['ip','addr','show'],
+                stdout = subprocess.PIPE,
+                stderr = subprocess.STDOUT,
+                stdin = devnull )
+            # communicate() drains the pipe and reaps the child
+            out = proc.communicate()[0]
+        finally:
+            devnull.close()
+
+        heading = re.compile(r"\d+:\s*(\w+):.*")
+        addr = re.compile(r"\s*inet\s*(\d{1,3}[.]\d{1,3}[.]\d{1,3}[.]\d{1,3}).*")
+        tun_name = None
+        for line in out.splitlines():
+            match = heading.match(line)
+            if match:
+                tun_name = match.group(1)
+                continue
+            match = addr.match(line)
+            if match and match.group(1) == vif_addr:
+                return tun_name
+
+        # NOTE: the three-expression form "raise E, msg, x" requires x to
+        # be a traceback, so it must not be used to build the message.
+        raise RuntimeError("Could not find interface for %s" % (vif_addr,))
+
+    def _send_igmp(self, igmp_type, grp):
+        """
+        Build and send one IGMP message of the given type for group grp.
+
+        Send errors are logged to stderr and otherwise ignored, so that a
+        transient failure does not stop the polling loop.
+        """
+        # Argument layout of ipaddr2.ipigmp is kept exactly as this
+        # script has always called it.
+        igmpp = ipaddr2.ipigmp(
+            self.vif_addr, self.ALL_ROUTERS, 1, igmp_type, 0, grp,
+            noipcksum=True)
+        try:
+            self.igmp_socket.sendto(igmpp, 0, (self.ALL_ROUTERS,0))
+        except Exception:
+            traceback.print_exc(file=sys.stderr)
+
+    def run(self):
+        """Poll subscriptions and emit IGMP reports until stop() is called."""
+        maddr_re = re.compile(r"\s*inet\s*(\d{1,3}[.]\d{1,3}[.]\d{1,3}[.]\d{1,3})\s*")
+        cur_maddr = set()
+        lastfullrefresh = time.time()
+        devnull = open('/dev/null','r+b')
+        try:
+            while not self._stop:
+                # Get current subscriptions
+                proc = subprocess.Popen(['ip','maddr','show',self.tun_name],
+                    stdout = subprocess.PIPE,
+                    stderr = subprocess.STDOUT,
+                    stdin = devnull)
+                new_maddr = set()
+                for line in proc.stdout:
+                    match = maddr_re.match(line)
+                    if match:
+                        new_maddr.add(match.group(1))
+                proc.wait()
+
+                # Every now and then, send a full report;
+                # otherwise only report what is new
+                now = time.time()
+                report_new = new_maddr
+                if (now - lastfullrefresh) <= options.refresh_delay:
+                    report_new = report_new - cur_maddr
+                else:
+                    lastfullrefresh = now
+
+                # Report subscriptions
+                for grp in report_new:
+                    print >>sys.stderr, "JOINING", grp
+                    self._send_igmp(self.IGMP_V2_REPORT, grp)
+
+                # Notify group leave
+                for grp in cur_maddr - new_maddr:
+                    print >>sys.stderr, "LEAVING", grp
+                    self._send_igmp(self.IGMP_V2_LEAVE, grp)
+
+                cur_maddr = new_maddr
+
+                time.sleep(options.poll_delay)
+        finally:
+            devnull.close()
+
+    def stop(self):
+        """Ask the thread to finish and wait (bounded) for it to do so."""
+        self._stop = True
+        self.join(1+5*options.poll_delay)
+
+class FWDThread(threading.Thread):
+    """
+    Daemon thread that forwards multicast packets according to rt_cache.
+
+    Packets arrive as datagrams on the unix socket ``options.fwd_path``.
+    Each datagram is the 4-byte packed address of the interface the packet
+    came from (the "parent") followed by the raw IP packet.
+
+    For each packet:
+      * IGMP packets (IP protocol 2) are also handed to the router socket.
+      * Packets with IP protocol 0 are discarded (they are our own
+        "no route" notifications looping back).
+      * The route is looked up in rt_cache under the 12-byte key
+        parent + source address + destination group. If found, the packet
+        is re-sent through the raw socket of every enabled interface whose
+        TTL threshold the packet exceeds. If not found, the packet is
+        queued (up to ``maxpending``) and the router is notified by
+        sending it a copy with the protocol byte set to 0.
+
+    Queued packets are retried, oldest first, whenever the input socket
+    has been idle for ``options.poll_delay`` seconds.
+
+    Parameters:
+        rt_cache: dict shared with RouterThread, route key -> per-vif
+            TTL threshold string.
+        router_socket: connected socket to the multicast router.
+        vifs: dict shared with RouterThread, vif index -> vifctl tuple
+            (element 4 is the packed local address).
+        *p, **kw: passed on to threading.Thread.
+    """
+
+    def __init__(self, rt_cache, router_socket, vifs, *p, **kw):
+        super(FWDThread, self).__init__(*p, **kw)
+
+        self.in_socket = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
+        self.in_socket.bind(options.fwd_path)
+
+        self.pending = collections.deque()
+        self.maxpending = 1000
+        self.rt_cache = rt_cache
+        self.router_socket = router_socket
+        self.vifs = vifs
+
+        # One raw output socket per enabled interface, keyed by the packed
+        # interface address so it can be matched against vifctl tuples
+        self.fwd_sockets = {}
+        for fwd_target in remaining_args:
+            fwd_target = socket.inet_aton(fwd_target)
+            fwd_socket = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_RAW)
+            fwd_socket.setsockopt(socket.IPPROTO_IP, socket.IP_HDRINCL, 1)
+            fwd_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, fwd_target)
+            fwd_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 1)
+            self.fwd_sockets[fwd_target] = fwd_socket
+
+        self._stop = False
+        self.setDaemon(True)
+
+    def run(self):
+        """Receive, route and forward packets until stop() is called."""
+        in_socket = self.in_socket
+        rt_cache = self.rt_cache
+        vifs = self.vifs
+        router_socket = self.router_socket
+        pending = self.pending
+        fwd_sockets = self.fwd_sockets
+
+        # Local aliases: this is the per-packet hot loop
+        len_ = len
+        ord_ = ord
+        buffer_ = buffer
+        enumerate_ = enumerate
+        inet_ntoa = socket.inet_ntoa
+
+        # Number of queued packets still to be retried in this round
+        npending = 0
+
+        in_socket.settimeout(options.poll_delay)
+
+        while not self._stop:
+            # Get packet
+            try:
+                if pending and npending:
+                    # Oldest first: a packet that still has no route is
+                    # re-queued at the tail, so each queued packet gets
+                    # exactly one retry per round.
+                    packet = pending.popleft()
+                    npending -= 1
+                else:
+                    packet = in_socket.recv(2000)
+            except socket.timeout:
+                # Idle: start a retry round over whatever is queued
+                if pending and not npending:
+                    npending = len_(pending)
+                continue
+            # 4 bytes of parent address + minimal 20-byte IP header
+            if not packet or len_(packet) < 24:
+                continue
+
+            fullpacket = packet
+            parent = buffer_(packet,0,4)
+            packet = buffer_(packet,4)
+
+            # Byte 9 of the IP header is the protocol number
+            if packet[9] == '\x02':
+                # IGMP packet? It's for mrouted
+                router_socket.send(packet)
+            elif packet[9] == '\x00':
+                # LOOPING packet, discard
+                continue
+
+            # To-Do: PIM asserts
+
+            # Get route: IP header bytes 12..19 are the source address
+            # followed by the destination address
+            addrinfo = buffer_(packet,12,8)
+            fwd_targets = rt_cache.get(parent+addrinfo)
+
+            # Destination group, packed
+            tgt_group = addrinfo[4:]
+
+            if fwd_targets is not None:
+                # Forward
+                ttl = ord_(packet[8])
+                # sendto() wants a printable host, not packed bytes
+                tgt_addr = (inet_ntoa(tgt_group),0)
+                print >>sys.stderr, tgt_addr[0], "->",
+                # fwd_targets holds one TTL threshold byte per vif index;
+                # a threshold of 0 means "do not forward on this vif"
+                for vifi, thresh_chr in enumerate_(fwd_targets):
+                    ttl_thresh = ord_(thresh_chr)
+                    if ttl_thresh > 0 and ttl > ttl_thresh and vifi in vifs:
+                        vif = vifs[vifi]
+                        if vif[4] in fwd_sockets:
+                            print >>sys.stderr, inet_ntoa(vif[4]),
+                            try:
+                                fwd_sockets[vif[4]].sendto(packet, 0, tgt_addr)
+                            except socket.error:
+                                # One failing interface must not stop
+                                # forwarding altogether
+                                traceback.print_exc(file=sys.stderr)
+                print >>sys.stderr, "."
+            else:
+                # Mark pending
+                if len_(pending) < self.maxpending:
+                    print >>sys.stderr, inet_ntoa(tgt_group), "-> ?"
+
+                    pending.append(fullpacket)
+
+                    # Notify mrouted by forwarding it with protocol 0
+                    router_socket.send(''.join(
+                        (packet[:9],'\x00',packet[10:]) ))
+
+    def stop(self):
+        """Ask the thread to finish and wait (bounded) for it to do so."""
+        self._stop = True
+        self.join(1+5*options.poll_delay)
+
+
+class RouterThread(threading.Thread):
+    """
+    Daemon thread that applies routing commands sent by the multicast router.
+
+    Commands are read from router_socket. Each one is framed as a native
+    32-bit operation code, a native 32-bit payload length, and the payload.
+    Supported operations (numbered MRT_BASE + n):
+
+      * MRT_ADD_VIF / MRT_DEL_VIF: payload unpacked as 'HBBI4s4s'
+        (vifi, flags, threshold, rate_limit, lcl_addr, rmt_addr);
+        maintains the shared ``vifs`` dict, vif index -> that tuple.
+      * MRT_ADD_MFC / MRT_DEL_MFC: payload unpacked as '4s4sH10sIIIi'
+        (origin, mcastgrp, parent, ttls, ...); maintains the shared
+        ``rt_cache`` dict, keyed by parent's local address + origin +
+        mcastgrp, valued with the ``ttls`` string.
+
+    Unknown operations are logged and ignored; errors raised while applying
+    a command are logged and the loop goes on. When the router closes the
+    connection, the program is asked to terminate through TERMINATE.
+
+    Parameters:
+        rt_cache: dict shared with FWDThread (see above).
+        router_socket: connected socket to the multicast router.
+        vifs: dict shared with FWDThread (see above).
+        *p, **kw: passed on to threading.Thread.
+    """
+
+    def __init__(self, rt_cache, router_socket, vifs, *p, **kw):
+        super(RouterThread, self).__init__(*p, **kw)
+
+        self.rt_cache = rt_cache
+        self.vifs = vifs
+        self.router_socket = router_socket
+
+        self._stop = False
+        self.setDaemon(True)
+
+    def run(self):
+        """Read and apply router commands until stop() or disconnection."""
+        rt_cache = self.rt_cache
+        vifs = self.vifs
+        addr_vifs = {}
+        router_socket = self.router_socket
+        router_socket.settimeout(options.poll_delay)
+        len_ = len
+        buffer_ = buffer
+
+        MRT_BASE       = 200
+        MRT_ADD_VIF    = MRT_BASE+2    # Add a virtual interface
+        MRT_DEL_VIF    = MRT_BASE+3    # Delete a virtual interface
+        MRT_ADD_MFC    = MRT_BASE+4    # Add a multicast forwarding entry
+        MRT_DEL_MFC    = MRT_BASE+5    # Delete a multicast forwarding entry
+
+        HDRLEN = 8
+
+        def cmdhdr(cmd, unpack=struct.unpack, buffer=buffer):
+            # -> (operation, payload length, everything after the header)
+            op,dlen = unpack('II', buffer(cmd,0,HDRLEN))
+            cmd = buffer(cmd,HDRLEN)
+            return op,dlen,cmd
+        def vifctl(data, unpack=struct.unpack):
+            #vifi, flags,threshold,rate_limit,lcl_addr,rmt_addr = unpack('HBBI4s4s', data)
+            return unpack('HBBI4s4s', data)
+        def mfcctl(data, unpack=struct.unpack):
+            #origin,mcastgrp,parent,ttls,pkt_cnt,byte_cnt,wrong_if,expire = unpack('4s4sH10sIIIi', data)
+            return unpack('4s4sH10sIIIi', data)
+        def mfckey(cmd):
+            # -> (rt_cache key, printable form of the key, ttls)
+            origin,mcastgrp,parent,ttls = mfcctl(cmd)[:4]
+            parts = (vifs[parent][4],origin,mcastgrp)
+            return ''.join(parts), '-'.join(map(socket.inet_ntoa, parts)), ttls
+
+        def add_vif(cmd):
+            vifi = vifctl(cmd)
+            vifs[vifi[0]] = vifi
+            addr_vifs[vifi[4]] = vifi[0]
+            print >>sys.stderr, "Added VIF", vifi
+        def del_vif(cmd):
+            vifi = vifctl(cmd)
+            # Use the registered entry: the command may not carry the address
+            vifi = vifs[vifi[0]]
+            del addr_vifs[vifi[4]]
+            del vifs[vifi[0]]
+            print >>sys.stderr, "Removed VIF", vifi
+        def add_mfc(cmd):
+            addrinfo, pretty, ttls = mfckey(cmd)
+            rt_cache[addrinfo] = ttls
+            print >>sys.stderr, "Added RT", pretty
+        def del_mfc(cmd):
+            addrinfo, pretty, ttls = mfckey(cmd)
+            del rt_cache[addrinfo]
+            print >>sys.stderr, "Removed RT", pretty
+
+        commands = {
+            MRT_ADD_VIF : add_vif,
+            MRT_DEL_VIF : del_vif,
+            MRT_ADD_MFC : add_mfc,
+            MRT_DEL_MFC : del_mfc,
+        }
+
+        # Bytes received but not yet consumed; always a plain string
+        buf = ""
+
+        while not self._stop:
+            if len_(buf) < HDRLEN or len_(buf) < (cmdhdr(buf)[1]+HDRLEN):
+                # No complete command buffered: get more data
+                try:
+                    chunk = router_socket.recv(2000)
+                except socket.timeout:
+                    continue
+                if not chunk:
+                    print >>sys.stderr, "PLRT CONNECTION BROKEN"
+                    TERMINATE.append(None)
+                    break
+                # Accumulate, then re-check completeness
+                buf += chunk
+                continue
+
+            # Split the first complete command off the buffer
+            op,dlen,payload = cmdhdr(buf)
+            data = buffer_(payload, 0, dlen)
+            buf = buf[HDRLEN+dlen:]
+
+            print >>sys.stderr, "COMMAND", op, "DATA", dlen
+
+            if op in commands:
+                try:
+                    commands[op](data)
+                except Exception:
+                    # A bad command must not take the router link down
+                    traceback.print_exc(file=sys.stderr)
+            else:
+                print >>sys.stderr, "IGNORING UNKNOWN COMMAND", op
+
+    def stop(self):
+        """Ask the thread to finish and wait (bounded) for it to do so."""
+        self._stop = True
+        self.join(1+5*options.poll_delay)
+
+
+
+# One IGMP announcer per enabled interface address.
+# Constructing them already resolves the interface name, so a bad address
+# fails here, before any socket file is created.
+igmp_threads = []
+for vif_addr in remaining_args:
+    igmp_threads.append(IGMPThread(vif_addr))
+
+# State shared between the router thread (writer) and the
+# forwarder thread (reader)
+rt_cache = {}
+vifs = {}
+
+# Becomes non-empty when the program must exit: appended to by the
+# SIGTERM handler and by RouterThread when the router disconnects.
+TERMINATE = []
+def _finalize(sig,frame):
+    TERMINATE.append(None)
+signal.signal(signal.SIGTERM, _finalize)
+
+def _remove_socket_path(path):
+    # Best-effort cleanup of a unix socket file; it may never have been
+    # created, or may already be gone.
+    try:
+        os.remove(path)
+    except OSError:
+        pass
+
+try:
+    # Wait for the multicast router to connect before starting anything
+    router_socket = socket.socket(socket.AF_UNIX, socket.SOCK_SEQPACKET)
+    router_socket.bind(options.mrt_path)
+    router_socket.listen(0)
+    router_remote_socket, router_remote_addr = router_socket.accept()
+
+    fwd_thread = FWDThread(rt_cache, router_remote_socket, vifs)
+    router_thread = RouterThread(rt_cache, router_remote_socket, vifs)
+
+    for thread in igmp_threads:
+        thread.start()
+    fwd_thread.start()
+    router_thread.start()
+
+    # All threads are daemons: leaving this loop ends the program
+    while not TERMINATE:
+        time.sleep(30)
+finally:
+    _remove_socket_path(options.mrt_path)
+    _remove_socket_path(options.fwd_path)
+
+
index dce367b..bd74454 100644 (file)
@@ -21,6 +21,12 @@ import traceback
 import tunchannel
 import ipaddr2
 
+try:
+    import iovec
+    HAS_IOVEC = True
+except:
+    HAS_IOVEC = False
+
 tun_name = 'tun0'
 tun_path = '/dev/net/tun'
 hostaddr = socket.gethostbyname(socket.gethostname())
@@ -104,8 +110,8 @@ parser.add_option(
         "Specify a symmetric encryption key with which to protect packets across "
         "the tunnel. python-crypto must be installed on the system." )
 parser.add_option(
-    "-K", "--gre-key", dest="gre_key", metavar="KEY", type="int",
-    default = None,
+    "-K", "--gre-key", dest="gre_key", metavar="KEY", type="string",
+    default = "true",
     help = 
         "Specify a demultiplexing 32-bit numeric key for GRE." )
 parser.add_option(
@@ -132,6 +138,13 @@ parser.add_option(
            "must be sent to the mroute unix socket, in a format identical "
            "to that of the kernel's MRT ioctls, prefixed with 32-bit IOCTL "
            "code and 32-bit data length." )
+parser.add_option(
+    "--multicast-forwarder", dest="multicast_fwd", 
+    default = None,
+    help = "If specified, multicast packets will be forwarded to "
+           "the specified unix-domain socket. If the device uses ethernet "
+           "frames, ethernet headers will be stripped and IP packets "
+           "will be forwarded." )
 parser.add_option(
     "--filter", dest="filter_module", metavar="PATH",
     default = None,
@@ -508,7 +521,7 @@ def pl_vif_start(tun_path, tun_name):
     if options.vif_txqueuelen is not None:
         stdin.write("txqueuelen=%d\n" % (options.vif_txqueuelen,))
     if options.mode.startswith('pl-gre'):
-        stdin.write("gre=%d\n" % (options.gre_key,))
+        stdin.write("gre=%s\n" % (options.gre_key,))
         stdin.write("remote=%s\n" % (remaining_args[0],))
     stdin.close()
     
@@ -685,6 +698,12 @@ else:
     filter_close = None
     queueclass = None
 
+# install multicast forwarding hook
+# Only the datagram socket is created here; despite the log message it is
+# not connected yet - accept_packet connects it lazily on the first
+# multicast packet it has to forward.
+if options.multicast_fwd:
+    print >>sys.stderr, "Connecting to mcast filter"
+    mcfwd_sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
+    # presumably switches the descriptor to non-blocking mode - see tunchannel
+    tunchannel.nonblock(mcfwd_sock.fileno())
+
 # be careful to roll back stuff on exceptions
 tun_path, tun_name = modeinfo['alloc'](tun_path, tun_name)
 try:
@@ -710,6 +729,61 @@ try:
     tcpdump = None
     reconnect = None
     mcastthread = None
+
+    # install multicast forwarding hook
+    #
+    # Wraps whatever accept_packet filter is already installed with one that
+    # also copies incoming IP multicast packets to the multicast forwarder.
+    if options.multicast_fwd:
+        print >>sys.stderr, "Installing mcast filter"
+        
+        if HAS_IOVEC:
+            writev = iovec.writev
+        else:
+            os_write = os.write
+            map_ = map
+            str_ = str
+            def writev(fileno, *stuff):
+                # Fallback for when the iovec module is missing:
+                # concatenate all the pieces and write them in one call,
+                # so that they still go out as a single datagram.
+                os_write(fileno, ''.join(map_(str_,stuff)))
+        
+        def accept_packet(packet, direction, 
+                _up_accept=accept_packet, 
+                sock=mcfwd_sock, 
+                sockno=mcfwd_sock.fileno(),
+                etherProto=tunchannel.etherProto,
+                etherStrip=tunchannel.etherStrip,
+                etherMode=tun_name.startswith('tap'),
+                multicast_fwd = options.multicast_fwd,
+                vif_addr = socket.inet_aton(options.vif_addr),
+                connected = [], writev=writev,
+                len=len, ord=ord):
+            """
+            Packet filter that forwards multicast traffic to the forwarder.
+            
+            Callers pass only (packet, direction); every other parameter
+            is a default bound at definition time. In particular,
+            `connected` is deliberately a shared list: it becomes
+            non-empty once the forwarder socket has been connected.
+            
+            If a previous filter (_up_accept) rejects the packet, its
+            verdict is returned untouched. Otherwise the packet is always
+            accepted (returns 1), and if direction is 1 and it is an IP
+            packet whose destination is in 224.0.0.0/4, a copy prefixed
+            with the 4-byte packed interface address is written to the
+            forwarder socket. Forwarding errors are logged and ignored.
+            """
+            if _up_accept:
+                rv = _up_accept(packet, direction)
+                if not rv:
+                    return rv
+
+            if direction == 1:
+                # Incoming... what?
+                if etherMode:
+                    # Only IPv4 frames are candidates; strip the ethernet
+                    # header so that the forwarder gets a bare IP packet
+                    if etherProto(packet)=='\x08\x00':
+                        fwd = etherStrip(packet)
+                    else:
+                        fwd = None
+                else:
+                    fwd = packet
+                if fwd is not None and len(fwd) >= 20:
+                    # Byte 16 is the first byte of the destination address;
+                    # a high nibble of 0xe marks a multicast destination
+                    if (ord(fwd[16]) & 0xf0) == 0xe0:
+                        # Forward it
+                        if not connected:
+                            # The forwarder may not be up yet:
+                            # keep retrying on later packets
+                            try:
+                                sock.connect(multicast_fwd)
+                                connected.append(None)
+                            except Exception:
+                                traceback.print_exc(file=sys.stderr)
+                        if connected:
+                            try:
+                                writev(sockno, vif_addr,fwd)
+                            except Exception:
+                                traceback.print_exc(file=sys.stderr)
+            return 1
+
+
     
     if options.pass_fd:
         if accept_packet or filter_init:
index 84e2d9c..3522d05 100644 (file)
@@ -26,6 +26,7 @@ class TunProtoBase(object):
         self.port = 15000
         self.mode = 'pl-tun'
         self.key = key
+        self.cross_slice = False
         
         self.home_path = home_path
         
@@ -209,6 +210,7 @@ class TunProtoBase(object):
         local_cipher=local.tun_cipher
         local_mcast= local.multicast
         local_bwlim= local.bwlimit
+        local_mcastfwd = local.multicast_forwarder
         
         if not local_p2p and hasattr(peer, 'address'):
             local_p2p = peer.address
@@ -244,7 +246,8 @@ class TunProtoBase(object):
             "-m", str(self.mode),
             "-A", str(local_addr),
             "-M", str(local_mask),
-            "-C", str(local_cipher)]
+            "-C", str(local_cipher),
+            ]
         
         if check_proto == 'fd':
             passfd_arg = str(peer_addr)
@@ -257,9 +260,10 @@ class TunProtoBase(object):
                 "--pass-fd", passfd_arg
             ])
         elif check_proto == 'gre':
-            args.extend([
-                "-K", str(min(local_port, peer_port))
-            ])
+            if self.cross_slice:
+                args.extend([
+                    "-K", str(self.key.strip('='))
+                ])
         else:
             args.extend([
                 "-p", str(local_port if listen else peer_port),
@@ -288,6 +292,8 @@ class TunProtoBase(object):
             args.extend(("--filter", filter_module))
         if filter_args:
             args.extend(("--filter-args", filter_args))
+        if local_mcastfwd:
+            args.extend(("--multicast-forwarder", local_mcastfwd))
 
         self._logger.info("Starting %s", self)
         
@@ -354,7 +360,7 @@ class TunProtoBase(object):
             
             # Connected?
             (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
-                "cd %(home)s ; grep -c Connected capture" % dict(
+                "cd %(home)s ; grep -a -c Connected capture" % dict(
                     home = server.shell_escape(self.home_path)),
                 host = local.node.hostname,
                 port = None,
@@ -372,7 +378,7 @@ class TunProtoBase(object):
 
             # At least listening?
             (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
-                "cd %(home)s ; grep -c Listening capture" % dict(
+                "cd %(home)s ; grep -a -c Listening capture" % dict(
                     home = server.shell_escape(self.home_path)),
                 host = local.node.hostname,
                 port = None,
@@ -414,7 +420,7 @@ class TunProtoBase(object):
             # Inspect the trace to check the assigned iface
             local = self.local()
             if local:
-                cmd = "cd %(home)s ; grep 'Using tun:' capture | head -1" % dict(
+                cmd = "cd %(home)s ; grep -a 'Using tun:' capture | head -1" % dict(
                             home = server.shell_escape(self.home_path))
                 for spin in xrange(30):
                     (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
index c44803f..d464b8c 100644 (file)
@@ -118,15 +118,15 @@ def _pullPacket(buf, ether_mode=False, len=len):
             rv = buf.popleft()
         return rv
 
-def etherStrip(buf):
+def etherStrip(buf, buffer=buffer, len=len):
     if len(buf) < 14:
         return ""
     if buf[12:14] == '\x08\x10' and buf[16:18] == '\x08\x00':
         # tagged ethernet frame
-        return buf[18:]
+        return buffer(buf, 18)
     elif buf[12:14] == '\x08\x00':
         # untagged ethernet frame
-        return buf[14:]
+        return buffer(buf, 14)
     else:
         return ""