Merge TCP handshake stuff
[nepi.git] / src / nepi / testbeds / planetlab / scripts / tun_connect.py
index cf55e0d..f015bad 100644 (file)
@@ -19,7 +19,12 @@ import base64
 import traceback
 
 import tunchannel
-import ipaddr2
+
+try:
+    import iovec
+    HAS_IOVEC = True
+except:
+    HAS_IOVEC = False
 
 tun_name = 'tun0'
 tun_path = '/dev/net/tun'
@@ -113,8 +118,8 @@ parser.add_option(
         "Specify a symmetric encryption key with which to protect packets across "
         "the tunnel. python-crypto must be installed on the system." )
 parser.add_option(
-    "-K", "--gre-key", dest="gre_key", metavar="KEY", type="int",
-    default = None,
+    "-K", "--gre-key", dest="gre_key", metavar="KEY", type="string",
+    default = "true",
     help = 
         "Specify a demultiplexing 32-bit numeric key for GRE." )
 parser.add_option(
@@ -133,14 +138,12 @@ parser.add_option(
     help = "If specified, packets won't be logged to standard output, "
            "but dumped to a pcap-formatted trace in the specified file. " )
 parser.add_option(
-    "--multicast", dest="multicast", 
-    action = "store_true",
-    default = False,
-    help = "If specified, multicast packets will be forwarded and IGMP "
-           "join/leave packets will be generated. Routing information "
-           "must be sent to the mroute unix socket, in a format identical "
-           "to that of the kernel's MRT ioctls, prefixed with 32-bit IOCTL "
-           "code and 32-bit data length." )
+    "--multicast-forwarder", dest="multicast_fwd", 
+    default = None,
+    help = "If specified, multicast packets will be forwarded to "
+           "the specified unix-domain socket. If the device uses ethernet "
+           "frames, ethernet headers will be stripped and IP packets "
+           "will be forwarded, prefixed with the interface's address." )
 parser.add_option(
     "--filter", dest="filter_module", metavar="PATH",
     default = None,
@@ -222,65 +225,6 @@ IFNAMSIZ = 0x00000010
 IFREQ_SZ = 0x00000028
 FIONREAD = 0x0000541b
 
-class MulticastThread(threading.Thread):
-    def __init__(self, *p, **kw):
-        super(MulticastThread, self).__init__(*p, **kw)
-        self.igmp_socket = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP)
-        self.igmp_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF,
-            socket.inet_aton(options.vif_addr) )
-        self.igmp_socket.setsockopt(socket.IPPROTO_IP, socket.IP_HDRINCL, 1)
-        self.igmp_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 1)
-        self._stop = False
-        self.setDaemon(True)
-    
-    def run(self):
-        devnull = open('/dev/null','r+b')
-        maddr_re = re.compile(r"\s*inet\s*(\d{1,3}[.]\d{1,3}[.]\d{1,3}[.]\d{1,3})\s*")
-        cur_maddr = set()
-        lastfullrefresh = time.time()
-        while not self._stop:
-            # Get current subscriptions
-            proc = subprocess.Popen(['ip','maddr','show',tun_name],
-                stdout = subprocess.PIPE,
-                stderr = subprocess.STDOUT,
-                stdin = devnull)
-            new_maddr = set()
-            for line in proc.stdout:
-                match = maddr_re.match(line)
-                if match:
-                    new_maddr.add(match.group(1))
-            proc.wait()
-            
-            # Every now and then, send a full report
-            now = time.time()
-            report_new = new_maddr
-            if (now - lastfullrefresh) <= 30.0:
-                report_new = report_new - cur_maddr
-            else:
-                lastfullrefresh = now
-            
-            # Report subscriptions
-            for grp in report_new:
-                igmpp = ipaddr2.ipigmp(
-                    options.vif_addr, '224.0.0.2', 1, 0x16, 0, grp, 
-                    noipcksum=True)
-                self.igmp_socket.sendto(igmpp, 0, ('224.0.0.2',0))
-
-            # Notify group leave
-            for grp in cur_maddr - new_maddr:
-                igmpp = ipaddr2.ipigmp(
-                    options.vif_addr, '224.0.0.2', 1, 0x17, 0, grp, 
-                    noipcksum=True)
-                self.igmp_socket.sendto(igmpp, 0, ('224.0.0.2',0))
-
-            cur_maddr = new_maddr
-            
-            time.sleep(1)
-    
-    def stop(self):
-        self._stop = True
-        self.join(5)
-
 class HostLock(object):
     # This class is used as a lock to prevent concurrency issues with more
     # than one instance of netns running in the same machine. Both in 
@@ -517,7 +461,7 @@ def pl_vif_start(tun_path, tun_name):
     if options.vif_txqueuelen is not None:
         stdin.write("txqueuelen=%d\n" % (options.vif_txqueuelen,))
     if options.mode.startswith('pl-gre'):
-        stdin.write("gre=%d\n" % (options.gre_key,))
+        stdin.write("gre=%s\n" % (options.gre_key,))
         stdin.write("remote=%s\n" % (options.peer_addr,))
     stdin.close()
     
@@ -694,6 +638,12 @@ else:
     filter_close = None
     queueclass = None
 
+# install multicast forwarding hook
+if options.multicast_fwd:
+    print >>sys.stderr, "Connecting to mcast filter"
+    mcfwd_sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
+    tunchannel.nonblock(mcfwd_sock.fileno())
+
 # be careful to roll back stuff on exceptions
 tun_path, tun_name = modeinfo['alloc'](tun_path, tun_name)
 try:
@@ -719,6 +669,61 @@ try:
     tcpdump = None
     reconnect = None
     mcastthread = None
+
+    # install multicast forwarding hook
+    if options.multicast_fwd:
+        print >>sys.stderr, "Installing mcast filter"
+        
+        if HAS_IOVEC:
+            writev = iovec.writev
+        else:
+            os_write = os.write
+            map_ = map
+            str_ = str
+            def writev(fileno, *stuff):
+                os_write(''.join(map_(str_,stuff)))
+        
+        def accept_packet(packet, direction, 
+                _up_accept=accept_packet, 
+                sock=mcfwd_sock, 
+                sockno=mcfwd_sock.fileno(),
+                etherProto=tunchannel.etherProto,
+                etherStrip=tunchannel.etherStrip,
+                etherMode=tun_name.startswith('tap'),
+                multicast_fwd = options.multicast_fwd,
+                vif_addr = socket.inet_aton(options.vif_addr),
+                connected = [], writev=writev,
+                len=len, ord=ord):
+            if _up_accept:
+                rv = _up_accept(packet, direction)
+                if not rv:
+                    return rv
+
+            if direction == 1:
+                # Incoming... what?
+                if etherMode:
+                    if etherProto(packet)=='\x08\x00':
+                        fwd = etherStrip(packet)
+                    else:
+                        fwd = None
+                else:
+                    fwd = packet
+                if fwd is not None and len(fwd) >= 20:
+                    if (ord(fwd[16]) & 0xf0) == 0xe0:
+                        # Forward it
+                        if not connected:
+                            try:
+                                sock.connect(multicast_fwd)
+                                connected.append(None)
+                            except:
+                                traceback.print_exc(file=sys.stderr)
+                        if connected:
+                            try:
+                                writev(sockno, vif_addr,fwd)
+                            except:
+                                traceback.print_exc(file=sys.stderr)
+            return 1
+
     
     if options.protocol == 'fd':
         if accept_packet or filter_init:
@@ -825,11 +830,6 @@ try:
         # or perhaps there is no os.nice support in the system
         pass
     
-    if options.multicast:
-        # Start multicast forwarding daemon
-        mcastthread = MulticastThread()
-        mcastthread.start()
-
     if not filter_init:
         tun_fwd(tun, remote,
             reconnect = reconnect,