Merge TCP handshake stuff
[nepi.git] / src / nepi / testbeds / planetlab / scripts / tun_connect.py
index 98fc311..f015bad 100644 (file)
@@ -20,11 +20,17 @@ import traceback
 
 import tunchannel
 
+try:
+    import iovec
+    HAS_IOVEC = True
+except:
+    HAS_IOVEC = False
+
 tun_name = 'tun0'
 tun_path = '/dev/net/tun'
 hostaddr = socket.gethostbyname(socket.gethostname())
 
-usage = "usage: %prog [options] <remote-endpoint>"
+usage = "usage: %prog [options]"
 
 parser = optparse.OptionParser(usage=usage)
 
@@ -37,9 +43,9 @@ parser.add_option(
     default = "/dev/net/tun",
     help = "TUN/TAP device file path or file descriptor number")
 parser.add_option(
-    "-p", "--port", dest="port", metavar="PORT", type="int",
+    "-p", "--peer-port", dest="peer_port", metavar="PEER_PORT", type="int",
     default = 15000,
-    help = "Peering TCP port to connect or listen to.")
+    help = "Remote TCP/UDP port to connect to.")
 parser.add_option(
     "--pass-fd", dest="pass_fd", metavar="UNIX_SOCKET",
     default = None,
@@ -47,7 +53,6 @@ parser.add_option(
            "If given, all other connectivity options are ignored, tun_connect will "
            "simply wait to be killed after passing the file descriptor, and it will be "
            "the receiver's responsability to handle the tunneling.")
-
 parser.add_option(
     "-m", "--mode", dest="mode", metavar="MODE",
     default = "none",
@@ -56,6 +61,11 @@ parser.add_option(
         "by using the proper interface (tunctl for tun/tap, /vsys/fd_tuntap.control for pl-tun/pl-tap), "
         "and it will be brought up (with ifconfig for tun/tap, with /vsys/vif_up for pl-tun/pl-tap). You have "
         "to specify an VIF_ADDRESS and VIF_MASK in any case (except for none).")
+parser.add_option(
+    "-t", "--protocol", dest="protocol", metavar="PROTOCOL",
+    default = None,
+    help = 
+        "Set protocol. One of tcp, udp, fd, gre. In any mode except none, a TUN/TAP will be created.")
 parser.add_option(
     "-A", "--vif-address", dest="vif_addr", metavar="VIF_ADDRESS",
     default = None,
@@ -68,13 +78,18 @@ parser.add_option(
     help = 
         "See mode. This specifies the VIF_MASK, "
         "a number indicating the network type (ie: 24 for a C-class network).")
+parser.add_option(
+    "-P", "--port", dest="port", type="int", metavar="PORT", 
+    default = None,
+    help = 
+        "This specifies the LOCAL_PORT. This will be the local bind port for UDP/TCP.")
 parser.add_option(
     "-S", "--vif-snat", dest="vif_snat", 
     action = "store_true",
     default = False,
     help = "See mode. This specifies whether SNAT will be enabled for the virtual interface. " )
 parser.add_option(
-    "-P", "--vif-pointopoint", dest="vif_pointopoint",  metavar="DST_ADDR",
+    "-Z", "--vif-pointopoint", dest="vif_pointopoint",  metavar="DST_ADDR",
     default = None,
     help = 
         "See mode. This specifies the remote endpoint's virtual address, "
@@ -86,11 +101,16 @@ parser.add_option(
     help = 
         "See mode. This specifies the interface's transmission queue length. " )
 parser.add_option(
-    "-u", "--udp", dest="udp", metavar="PORT", type="int",
+    "-b", "--bwlimit", dest="bwlimit", metavar="BYTESPERSECOND", type="int",
     default = None,
     help = 
-        "Bind to the specified UDP port locally, and send UDP datagrams to the "
-        "remote endpoint, creating a tunnel through UDP rather than TCP." )
+        "This specifies the interface's emulated bandwidth in bytes per second." )
+parser.add_option(
+    "-a", "--peer-address", dest="peer_addr", metavar="PEER_ADDRESS",
+    default = None,
+    help = 
+        "This specifies the PEER_ADDRESS, "
+        "the IP address of the remote interface.")
 parser.add_option(
     "-k", "--key", dest="cipher_key", metavar="KEY",
     default = None,
@@ -98,8 +118,8 @@ parser.add_option(
         "Specify a symmetric encryption key with which to protect packets across "
         "the tunnel. python-crypto must be installed on the system." )
 parser.add_option(
-    "-K", "--gre-key", dest="gre_key", metavar="KEY", type="int",
-    default = None,
+    "-K", "--gre-key", dest="gre_key", metavar="KEY", type="string",
+    default = "true",
     help = 
         "Specify a demultiplexing 32-bit numeric key for GRE." )
 parser.add_option(
@@ -117,8 +137,72 @@ parser.add_option(
     default = None,
     help = "If specified, packets won't be logged to standard output, "
            "but dumped to a pcap-formatted trace in the specified file. " )
+parser.add_option(
+    "--multicast-forwarder", dest="multicast_fwd", 
+    default = None,
+    help = "If specified, multicast packets will be forwarded to "
+           "the specified unix-domain socket. If the device uses ethernet "
+           "frames, ethernet headers will be stripped and IP packets "
+           "will be forwarded, prefixed with the interface's address." )
+parser.add_option(
+    "--filter", dest="filter_module", metavar="PATH",
+    default = None,
+    help = "If specified, it should be either a .py or .so module. "
+           "It will be loaded, and all incoming and outgoing packets "
+           "will be routed through it. The filter will not be responsible "
+           "for buffering, packet queueing is performed in tun_connect "
+           "already, so it should not concern itself with it. It should "
+           "not, however, block in one direction if the other is congested.\n"
+           "\n"
+           "Modules are expected to have the following methods:\n"
+           "\tinit(**args)\n"
+           "\t\tIf arguments are given, this method will be called with the\n"
+           "\t\tgiven arguments (as keyword args in python modules, or a single\n"
+           "\t\tstring in c modules).\n"
+           "\taccept_packet(packet, direction):\n"
+           "\t\tDecide whether to drop the packet. Direction is 0 for packets "
+               "coming from the local side to the remote, and 1 is for packets "
+               "coming from the remote side to the local. Return a boolean, "
+               "true if the packet is not to be dropped.\n"
+           "\tfilter_init():\n"
+           "\t\tInitializes a filtering pipe (filter_run). It should "
+               "return two file descriptors to use as a bidirectional "
+               "pipe: local and remote. 'local' is where packets from the "
+               "local side will be written to. After filtering, those packets "
+               "should be written to 'remote', where tun_connect will read "
+               "from, and it will forward them to the remote peer. "
+               "Packets from the remote peer will be written to 'remote', "
+               "where the filter is expected to read from, and eventually "
+               "forward them to the local side. If the file descriptors are "
+               "not nonblocking, they will be set to nonblocking. So it's "
+               "better to set them from the start like that.\n"
+           "\tfilter_run(local, remote):\n"
+           "\t\tIf filter_init is provided, it will be called repeatedly, "
+               "in a separate thread until the process is killed. It should "
+               "sleep at most for a second.\n"
+           "\tfilter_close(local, remote):\n"
+           "\t\tCalled then the process is killed, if filter_init was provided. "
+               "It should, among other things, close the file descriptors.\n"
+           "\n"
+           "Python modules are expected to return a tuple in filter_init, "
+           "either of file descriptors or file objects, while native ones "
+           "will receive two int*.\n"
+           "\n"
+           "Python modules can additionally contain a custom queue class "
+           "that will replace the FIFO used by default. The class should "
+           "be named 'queueclass' and contain an interface compatible with "
+           "collections.deque. That is, indexing (especiall for q[0]), "
+           "bool(q), popleft, appendleft, pop (right), append (right), "
+           "len(q) and clear. When using a custom queue, queue size will "
+           "have no effect, pass an effective queue size to the module "
+           "by using filter_args" )
+parser.add_option(
+    "--filter-args", dest="filter_args", metavar="FILE",
+    default = None,
+    help = "If specified, packets won't be logged to standard output, "
+           "but dumped to a pcap-formatted trace in the specified file. " )
 
-(options, remaining_args) = parser.parse_args(sys.argv[1:])
+(options,args) = parser.parse_args(sys.argv[1:])
 
 options.cipher = {
     'aes' : 'AES',
@@ -212,11 +296,11 @@ def tunopen(tun_path, tun_name):
     return tun
 
 def tunclose(tun_path, tun_name, tun):
-    if tun_path.isdigit():
+    if tun_path and tun_path.isdigit():
         # close TUN fd
         os.close(int(tun_path))
         tun.close()
-    else:
+    elif tun:
         # close TUN object
         tun.close()
 
@@ -320,7 +404,7 @@ def pl_tuntap_namealloc(kind, tun_path, tun_name):
     global _name_reservation
     # Serialize access
     lockfile = open("/tmp/nepi-tun-connect.lock", "a")
-    _name_reservation = lock = HostLock(lockfile)
+    lock = HostLock(lockfile)
     
     # We need to do this, fd_tuntap is the only one who can
     # tell us our slice id (this script runs as root, so no uid),
@@ -343,6 +427,8 @@ def pl_tuntap_namealloc(kind, tun_path, tun_name):
     else:
         raise RuntimeError, "Could not assign interface name"
     
+    _name_reservation = lock
+    
     return None, name
 
 def pl_vif_start(tun_path, tun_name):
@@ -350,7 +436,6 @@ def pl_vif_start(tun_path, tun_name):
 
     out = []
     def outreader():
-        stdout = open("/vsys/vif_up.out","r")
         out.append(stdout.read())
         stdout.close()
         time.sleep(1)
@@ -361,6 +446,7 @@ def pl_vif_start(tun_path, tun_name):
     _name_reservation = None
     
     stdin = open("/vsys/vif_up.in","w")
+    stdout = open("/vsys/vif_up.out","r")
 
     t = threading.Thread(target=outreader)
     t.start()
@@ -375,8 +461,8 @@ def pl_vif_start(tun_path, tun_name):
     if options.vif_txqueuelen is not None:
         stdin.write("txqueuelen=%d\n" % (options.vif_txqueuelen,))
     if options.mode.startswith('pl-gre'):
-        stdin.write("gre=%d\n" % (options.gre_key,))
-        stdin.write("remote=%s\n" % (remaining_args[0],))
+        stdin.write("gre=%s\n" % (options.gre_key,))
+        stdin.write("remote=%s\n" % (options.peer_addr,))
     stdin.close()
     
     t.join()
@@ -389,7 +475,6 @@ def pl_vif_start(tun_path, tun_name):
 def pl_vif_stop(tun_path, tun_name):
     out = []
     def outreader():
-        stdout = open("/vsys/vif_down.out","r")
         out.append(stdout.read())
         stdout.close()
         
@@ -410,6 +495,7 @@ def pl_vif_stop(tun_path, tun_name):
     lock = HostLock(lockfile)
 
     stdin = open("/vsys/vif_down.in","w")
+    stdout = open("/vsys/vif_down.out","r")
     
     t = threading.Thread(target=outreader)
     t.start()
@@ -425,7 +511,7 @@ def pl_vif_stop(tun_path, tun_name):
     del lock, lockfile
 
 
-def tun_fwd(tun, remote, reconnect = None):
+def tun_fwd(tun, remote, reconnect = None, accept_local = None, accept_remote = None, slowlocal = True, bwlimit = None):
     global TERMINATE
     
     tunqueue = options.vif_txqueuelen or 1000
@@ -437,13 +523,18 @@ def tun_fwd(tun, remote, reconnect = None):
         with_pi = options.mode.startswith('pl-'),
         ether_mode = tun_name.startswith('tap'),
         cipher_key = options.cipher_key,
-        udp = options.udp,
+        udp = options.protocol == 'udp',
         TERMINATE = TERMINATE,
         stderr = None,
         reconnect = reconnect,
         tunqueue = tunqueue,
         tunkqueue = tunkqueue,
-        cipher = options.cipher
+        cipher = options.cipher,
+        accept_local = accept_local,
+        accept_remote = accept_remote,
+        queueclass = queueclass,
+        slowlocal = slowlocal,
+        bwlimit = bwlimit
     )
 
 
@@ -492,6 +583,67 @@ tun_name = options.tun_name
 
 modeinfo = MODEINFO[options.mode]
 
+# Try to load filter module
+filter_thread = None
+if options.filter_module:
+    print >>sys.stderr, "Loading module", options.filter_module, "with args", options.filter_args
+    if options.filter_module.endswith('.py'):
+        sys.path.append(os.path.dirname(options.filter_module))
+        filter_module = __import__(os.path.basename(options.filter_module).rsplit('.',1)[0])
+        if options.filter_args:
+            try:
+                filter_args = dict(map(lambda x:x.split('=',1),options.filter_args.split(',')))
+                filter_module.init(**filter_args)
+            except:
+                pass
+    elif options.filter_module.endswith('.so'):
+        filter_module = ctypes.cdll.LoadLibrary(options.filter_module)
+        if options.filter_args:
+            try:
+                filter_module.init(options.filter_args)
+            except:
+                pass
+    try:
+        accept_packet = filter_module.accept_packet
+        print >>sys.stderr, "Installing packet filter (accept_packet)"
+    except:
+        accept_packet = None
+    
+    try:
+        queueclass = filter_module.queueclass
+        print >>sys.stderr, "Installing custom queue"
+    except:
+        queueclass = None
+    
+    try:
+        _filter_init = filter_module.filter_init
+        filter_run = filter_module.filter_run
+        filter_close = filter_module.filter_close
+        
+        def filter_init():
+            filter_local = ctypes.c_int(0)
+            filter_remote = ctypes.c_int(0)
+            _filter_init(filter_local, filter_remote)
+            return filter_local, filter_remote
+
+        print >>sys.stderr, "Installing packet filter (stream filter)"
+    except:
+        filter_init = None
+        filter_run = None
+        filter_close = None
+else:
+    accept_packet = None
+    filter_init = None
+    filter_run = None
+    filter_close = None
+    queueclass = None
+
+# install multicast forwarding hook
+if options.multicast_fwd:
+    print >>sys.stderr, "Connecting to mcast filter"
+    mcfwd_sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
+    tunchannel.nonblock(mcfwd_sock.fileno())
+
 # be careful to roll back stuff on exceptions
 tun_path, tun_name = modeinfo['alloc'](tun_path, tun_name)
 try:
@@ -516,8 +668,67 @@ signal.signal(signal.SIGTERM, _finalize)
 try:
     tcpdump = None
     reconnect = None
+    mcastthread = None
+
+    # install multicast forwarding hook
+    if options.multicast_fwd:
+        print >>sys.stderr, "Installing mcast filter"
+        
+        if HAS_IOVEC:
+            writev = iovec.writev
+        else:
+            os_write = os.write
+            map_ = map
+            str_ = str
+            def writev(fileno, *stuff):
+                os_write(''.join(map_(str_,stuff)))
+        
+        def accept_packet(packet, direction, 
+                _up_accept=accept_packet, 
+                sock=mcfwd_sock, 
+                sockno=mcfwd_sock.fileno(),
+                etherProto=tunchannel.etherProto,
+                etherStrip=tunchannel.etherStrip,
+                etherMode=tun_name.startswith('tap'),
+                multicast_fwd = options.multicast_fwd,
+                vif_addr = socket.inet_aton(options.vif_addr),
+                connected = [], writev=writev,
+                len=len, ord=ord):
+            if _up_accept:
+                rv = _up_accept(packet, direction)
+                if not rv:
+                    return rv
+
+            if direction == 1:
+                # Incoming... what?
+                if etherMode:
+                    if etherProto(packet)=='\x08\x00':
+                        fwd = etherStrip(packet)
+                    else:
+                        fwd = None
+                else:
+                    fwd = packet
+                if fwd is not None and len(fwd) >= 20:
+                    if (ord(fwd[16]) & 0xf0) == 0xe0:
+                        # Forward it
+                        if not connected:
+                            try:
+                                sock.connect(multicast_fwd)
+                                connected.append(None)
+                            except:
+                                traceback.print_exc(file=sys.stderr)
+                        if connected:
+                            try:
+                                writev(sockno, vif_addr,fwd)
+                            except:
+                                traceback.print_exc(file=sys.stderr)
+            return 1
+
     
-    if options.pass_fd:
+    if options.protocol == 'fd':
+        if accept_packet or filter_init:
+            raise NotImplementedError, "--pass-fd and --filter are not compatible"
+        
         if options.pass_fd.startswith("base64:"):
             options.pass_fd = base64.b64decode(
                 options.pass_fd[len("base64:"):])
@@ -531,6 +742,8 @@ try:
         sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
         retrydelay = 1.0
         for i in xrange(30):
+            if TERMINATE:
+                raise OSError, "Killed"
             try:
                 sock.connect(options.pass_fd)
                 break
@@ -545,73 +758,62 @@ try:
         
         # just wait forever
         def tun_fwd(tun, remote, **kw):
-            while not TERMINATE:
+            global TERMINATE
+            TERM = TERMINATE
+            while not TERM:
                 time.sleep(1)
         remote = None
-    elif options.mode.startswith('pl-gre'):
+    elif options.protocol == "gre":
+        if accept_packet or filter_init:
+            raise NotImplementedError, "--mode %s and --filter are not compatible" % (options.mode,)
+        
         # just wait forever
         def tun_fwd(tun, remote, **kw):
-            while not TERMINATE:
+            global TERMINATE
+            TERM = TERMINATE
+            while not TERM:
                 time.sleep(1)
-        remote = remaining_args[0]
-    elif options.udp:
+        remote = options.peer_addr
+    elif options.protocol == "udp":
         # connect to remote endpoint
-        if remaining_args and not remaining_args[0].startswith('-'):
-            print >>sys.stderr, "Listening at: %s:%d" % (hostaddr,options.udp)
-            print >>sys.stderr, "Connecting to: %s:%d" % (remaining_args[0],options.port)
-            rsock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
-            retrydelay = 1.0
-            for i in xrange(30):
-                try:
-                    rsock.bind((hostaddr,options.udp))
-                    break
-                except socket.error:
-                    # wait a while, retry
-                    print >>sys.stderr, "%s: Could not bind. Retrying in a sec..." % (time.strftime('%c'),)
-                    time.sleep(min(30.0,retrydelay))
-                    retrydelay *= 1.1
-            else:
-                rsock.bind((hostaddr,options.udp))
-            rsock.connect((remaining_args[0],options.port))
+        if options.peer_addr and options.peer_port:
+            rsock = tunchannel.udp_establish(TERMINATE, hostaddr, options.port, 
+                    options.peer_addr, options.peer_port)
+            remote = os.fdopen(rsock.fileno(), 'r+b', 0)
         else:
             print >>sys.stderr, "Error: need a remote endpoint in UDP mode"
             raise AssertionError, "Error: need a remote endpoint in UDP mode"
-        remote = os.fdopen(rsock.fileno(), 'r+b', 0)
-    else:
+    elif options.protocol == "tcp":
         # connect to remote endpoint
-        if remaining_args and not remaining_args[0].startswith('-'):
-            print >>sys.stderr, "Connecting to: %s:%d" % (remaining_args[0],options.port)
-            rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
-            retrydelay = 1.0
-            for i in xrange(30):
-                try:
-                    rsock.connect((remaining_args[0],options.port))
-                    break
-                except socket.error:
-                    # wait a while, retry
-                    print >>sys.stderr, "%s: Could not connect. Retrying in a sec..." % (time.strftime('%c'),)
-                    time.sleep(min(30.0,retrydelay))
-                    retrydelay *= 1.1
-            else:
-                rsock.connect((remaining_args[0],options.port))
+        if options.peer_addr and options.peer_port:
+            rsock = tunchannel.tcp_establish(TERMINATE, hostaddr, options.port,
+                    options.peer_addr, options.peer_port)
+            remote = os.fdopen(rsock.fileno(), 'r+b', 0)
         else:
-            print >>sys.stderr, "Listening at: %s:%d" % (hostaddr,options.port)
-            lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
-            retrydelay = 1.0
-            for i in xrange(30):
-                try:
-                    lsock.bind((hostaddr,options.port))
-                    break
-                except socket.error:
-                    # wait a while, retry
-                    print >>sys.stderr, "%s: Could not bind. Retrying in a sec..." % (time.strftime('%c'),)
-                    time.sleep(min(30.0,retrydelay))
-                    retrydelay *= 1.1
-            else:
-                lsock.bind((hostaddr,options.port))
-            lsock.listen(1)
-            rsock,raddr = lsock.accept()
-        remote = os.fdopen(rsock.fileno(), 'r+b', 0)
+            print >>sys.stderr, "Error: need a remote endpoint in TCP mode"
+            raise AssertionError, "Error: need a remote endpoint in TCP mode"
+    else:
+        msg = "Error: Invalid protocol %s" % options.protocol
+        print >>sys.stderr, msg 
+        raise AssertionError, msg
+
+    if filter_init:
+        filter_local, filter_remote = filter_init()
+        
+        def filter_loop():
+            global TERMINATE
+            TERM = TERMINATE
+            run = filter_run
+            local = filter_local
+            remote = filter_remote
+            while not TERM:
+                run(local, remote)
+            filter_close(local, remote)
+            
+        filter_thread = threading.Thread(target=filter_loop)
+        filter_thread.start()
+    
+    print >>sys.stderr, "Connected"
 
     if not options.no_capture:
         # Launch a tcpdump subprocess, to capture and dump packets.
@@ -620,8 +822,6 @@ try:
             ["tcpdump","-l","-n","-i",tun_name, "-s", "4096"]
             + ["-w",options.pcap_capture,"-U"] * bool(options.pcap_capture) )
     
-    print >>sys.stderr, "Connected"
-    
     # Try to give us high priority
     try:
         os.nice(-20)
@@ -629,9 +829,52 @@ try:
         # Ignore errors, we might not have enough privileges,
         # or perhaps there is no os.nice support in the system
         pass
+    
+    if not filter_init:
+        tun_fwd(tun, remote,
+            reconnect = reconnect,
+            accept_local = accept_packet,
+            accept_remote = accept_packet,
+            bwlimit = options.bwlimit,
+            slowlocal = True)
+    else:
+        # Hm...
+        # ...ok, we need to:
+        #  1. Forward packets from tun to filter
+        #  2. Forward packets from remote to filter
+        #
+        # 1. needs TUN rate-limiting, while 
+        # 2. needs reconnection
+        #
+        # 1. needs ONLY TUN-side acceptance checks, while
+        # 2. needs ONLY remote-side acceptance checks
+        if isinstance(filter_local, ctypes.c_int):
+            filter_local_fd = filter_local.value
+        else:
+            filter_local_fd = filter_local
+        if isinstance(filter_remote, ctypes.c_int):
+            filter_remote_fd = filter_remote.value
+        else:
+            filter_remote_fd = filter_remote
 
-    tun_fwd(tun, remote,
-        reconnect = reconnect)
+        def localside():
+            tun_fwd(tun, filter_local_fd,
+                accept_local = accept_packet,
+                slowlocal = True)
+        
+        def remoteside():
+            tun_fwd(filter_remote_fd, remote,
+                reconnect = reconnect,
+                accept_remote = accept_packet,
+                bwlimit = options.bwlimit,
+                slowlocal = False)
+        
+        localthread = threading.Thread(target=localside)
+        remotethread = threading.Thread(target=remoteside)
+        localthread.start()
+        remotethread.start()
+        localthread.join()
+        remotethread.join()
 
 finally:
     try:
@@ -641,6 +884,19 @@ finally:
         pass
     
     # tidy shutdown in every case - swallow exceptions
+    TERMINATE.append(None)
+    
+    if mcastthread:
+        try:
+            mcastthread.stop()
+        except:
+            pass
+    
+    if filter_thread:
+        try:
+            filter_thread.join()
+        except:
+            pass
 
     try:
         if tcpdump: