TunChannel optimization: do not even format packets if there is no stderr output...
[nepi.git] / src / nepi / testbeds / planetlab / scripts / tun_connect.py
index c416557..1126649 100644 (file)
@@ -3,7 +3,9 @@ import sys
 import socket
 import fcntl
 import os
+import os.path
 import select
+import signal
 
 import struct
 import ctypes
@@ -12,6 +14,10 @@ import threading
 import subprocess
 import re
 import functools
+import time
+import base64
+
+import tunchannel
 
 tun_name = 'tun0'
 tun_path = '/dev/net/tun'
@@ -33,6 +39,13 @@ parser.add_option(
     "-p", "--port", dest="port", metavar="PORT", type="int",
     default = 15000,
     help = "Peering TCP port to connect or listen to.")
+parser.add_option(
+    "--pass-fd", dest="pass_fd", metavar="UNIX_SOCKET",
+    default = None,
+    help = "Path to a unix-domain socket to pass the TUN file descriptor to. "
+           "If given, all other connectivity options are ignored, tun_connect will "
+           "simply wait to be killed after passing the file descriptor, and it will be "
+           "the receiver's responsability to handle the tunneling.")
 
 parser.add_option(
     "-m", "--mode", dest="mode", metavar="MODE",
@@ -71,6 +84,29 @@ parser.add_option(
     default = None,
     help = 
         "See mode. This specifies the interface's transmission queue length. " )
+parser.add_option(
+    "-u", "--udp", dest="udp", metavar="PORT", type="int",
+    default = None,
+    help = 
+        "Bind to the specified UDP port locally, and send UDP datagrams to the "
+        "remote endpoint, creating a tunnel through UDP rather than TCP." )
+parser.add_option(
+    "-k", "--key", dest="cipher_key", metavar="KEY",
+    default = None,
+    help = 
+        "Specify a symmetric encryption key with which to protect packets across "
+        "the tunnel. python-crypto must be installed on the system." )
+parser.add_option(
+    "-N", "--no-capture", dest="no_capture", 
+    action = "store_true",
+    default = False,
+    help = "If specified, packets won't be logged to standard output "
+           "(default is to log them to standard output). " )
+parser.add_option(
+    "-c", "--pcap-capture", dest="pcap_capture", metavar="FILE",
+    default = None,
+    help = "If specified, packets won't be logged to standard output, "
+           "but dumped to a pcap-formatted trace in the specified file. " )
 
 (options, remaining_args) = parser.parse_args(sys.argv[1:])
 
@@ -88,6 +124,42 @@ IFNAMSIZ = 0x00000010
 IFREQ_SZ = 0x00000028
 FIONREAD = 0x0000541b
 
+class HostLock(object):
+    # Host-wide mutual-exclusion lock. A class-level flag guarded by a
+    # Condition admits only one HostLock at a time within this process, and
+    # an exclusive flock() on the given lock file covers other processes.
+    taken = False
+    processcond = threading.Condition()
+    
+    def __init__(self, lockfile):
+        processcond = self.__class__.processcond
+        
+        processcond.acquire()
+        try:
+            # Not reentrant: wait here until the current holder releases it
+            while self.__class__.taken:
+                processcond.wait()
+            self.__class__.taken = True
+        finally:
+            processcond.release()
+        
+        self.lockfile = lockfile
+        fcntl.flock(self.lockfile, fcntl.LOCK_EX)
+    
+    def __del__(self):
+        processcond = self.__class__.processcond
+        
+        processcond.acquire()
+        try:
+            if not self.lockfile.closed:
+                fcntl.flock(self.lockfile, fcntl.LOCK_UN)
+            
+            # Mark the lock as free again and wake up one waiting thread
+            self.__class__.taken = False
+            processcond.notify()
+        finally:
+            processcond.release()
+
 def ifnam(x):
     return x+'\x00'*(IFNAMSIZ-len(x))
 
@@ -205,7 +277,7 @@ def vif_stop(tun_path, tun_name):
     
     
 def pl_tuntap_alloc(kind, tun_path, tun_name):
-    tunalloc_so = ctypes.cdll.LoadLibrary("./vsys-scripts/support/tunalloc.so")
+    tunalloc_so = ctypes.cdll.LoadLibrary("./tunalloc.so")
     c_tun_name = ctypes.c_char_p("\x00"*IFNAMSIZ) # the string will be mutated!
     kind = {"tun":IFF_TUN,
             "tap":IFF_TAP}[kind]
@@ -214,142 +286,81 @@ def pl_tuntap_alloc(kind, tun_path, tun_name):
     return str(fd), name
 
 def pl_vif_start(tun_path, tun_name):
+    out = []
+    def outreader():
+        stdout = open("/vsys/vif_up.out","r")
+        out.append(stdout.read())
+        stdout.close()
+        time.sleep(1)
+
+    # Serialize access to vsys
+    lockfile = open("/tmp/nepi-tun-connect.lock", "a")
+    lock = HostLock(lockfile)
+
     stdin = open("/vsys/vif_up.in","w")
-    stdout = open("/vsys/vif_up.out","r")
+
+    t = threading.Thread(target=outreader)
+    t.start()
+    
     stdin.write(tun_name+"\n")
     stdin.write(options.vif_addr+"\n")
     stdin.write(str(options.vif_mask)+"\n")
     if options.vif_snat:
         stdin.write("snat=1\n")
+    if options.vif_pointopoint:
+        stdin.write("pointopoint=%s\n" % (options.vif_pointopoint,))
     if options.vif_txqueuelen is not None:
         stdin.write("txqueuelen=%d\n" % (options.vif_txqueuelen,))
     stdin.close()
-    out = stdout.read()
-    stdout.close()
+    
+    t.join()
+    out = ''.join(out)
     if out.strip():
         print >>sys.stderr, out
+    
+    del lock, lockfile
 
+def pl_vif_stop(tun_path, tun_name):
+    out = []
+    def outreader():
+        stdout = open("/vsys/vif_down.out","r")
+        out.append(stdout.read())
+        stdout.close()
+        time.sleep(1)
 
-def ipfmt(ip):
-    ipbytes = map(ord,ip.decode("hex"))
-    return '.'.join(map(str,ipbytes))
-
-def formatPacket(packet):
-    packet = packet.encode("hex")
-    return '-'.join( (
-        packet[0:1], #version
-        packet[1:2], #header length
-        packet[2:4], #diffserv/ECN
-        packet[4:8], #total length
-        packet[8:12], #ident
-        packet[12:16], #flags/fragment offs
-        packet[16:18], #ttl
-        packet[18:20], #ip-proto
-        packet[20:24], #checksum
-        ipfmt(packet[24:32]), # src-ip
-        ipfmt(packet[32:40]), # dst-ip
-        packet[40:48] if (int(packet[1]) > 5) else "", # options
-        packet[48:] if (int(packet[1]) > 5) else packet[40:], # payload
-    ) )
-
-def packetReady(buf):
-    if len(buf) < 4:
-        return False
-    _,totallen = struct.unpack('HH',buf[:4])
-    totallen = socket.htons(totallen)
-    return len(buf) >= totallen
-
-def pullPacket(buf):
-    _,totallen = struct.unpack('HH',buf[:4])
-    totallen = socket.htons(totallen)
-    return buf[:totallen], buf[totallen:]
-
-def etherStrip(buf):
-    if len(buf) < 14:
-        return buf
-    if buf[12:14] == '\x08\x10' and buf[16:18] == '\x08\x00':
-        # tagged ethernet frame
-        return buf[18:]
-    elif buf[12:14] == '\x08\x00':
-        # untagged ethernet frame
-        return buf[14:]
-    else:
-        return buf
-
-def etherWrap(packet):
-    return (
-        "\x00"*6*2 # bogus src and dst mac
-        +"\x08\x00" # IPv4
-        +packet # payload
-        +"\x00"*4 # bogus crc
-    )
+    # Serialize access to vsys
+    lockfile = open("/tmp/nepi-tun-connect.lock", "a")
+    lock = HostLock(lockfile)
 
-def piStrip(buf):
-    if len(buf) < 4:
-        return buf
-    else:
-        return buf[4:]
+    stdin = open("/vsys/vif_down.in","w")
     
-def piWrap(buf):
-    return (
-        "\x00\x00\x08\x00" # PI: 16 bits flags, 16 bits proto
-        +buf
-    )
+    t = threading.Thread(target=outreader)
+    t.start()
+    
+    stdin.write(tun_name+"\n")
+    stdin.close()
+    
+    t.join()
+    out = ''.join(out)
+    if out.strip():
+        print >>sys.stderr, out
+    
+    del lock, lockfile
+
 
-abortme = False
 def tun_fwd(tun, remote):
-    global abortme
+    global TERMINATE
     
     # in PL mode, we cannot strip PI structs
     # so we'll have to handle them
-    with_pi = options.mode.startswith('pl-')
-    ether_mode = tun_name.startswith('tap')
-    
-    # Limited frame parsing, to preserve packet boundaries.
-    # Which is needed, since /dev/net/tun is unbuffered
-    fwbuf = ""
-    bkbuf = ""
-    while not abortme:
-        wset = []
-        if packetReady(bkbuf):
-            wset.append(tun)
-        if packetReady(fwbuf):
-            wset.append(remote)
-        rdrdy, wrdy, errs = select.select((tun,remote),wset,(tun,remote),1)
-        
-        # check for errors
-        if errs:
-            break
-        
-        # check to see if we can write
-        if remote in wrdy and packetReady(fwbuf):
-            packet, fwbuf = pullPacket(fwbuf)
-            os.write(remote.fileno(), packet)
-            print >>sys.stderr, '>', formatPacket(packet)
-            if ether_mode:
-                # strip ethernet crc
-                fwbuf = fwbuf[4:]
-        if tun in wrdy and packetReady(bkbuf):
-            packet, bkbuf = pullPacket(bkbuf)
-            formatted = formatPacket(packet)
-            if ether_mode:
-                packet = etherWrap(packet)
-            if with_pi:
-                packet = piWrap(packet)
-            os.write(tun.fileno(), packet)
-            print >>sys.stderr, '<', formatted
-        
-        # check incoming data packets
-        if tun in rdrdy:
-            packet = os.read(tun.fileno(),2000) # tun.read blocks until it gets 2k!
-            if with_pi:
-                packet = piStrip(packet)
-            fwbuf += packet
-            if ether_mode:
-                fwbuf = etherStrip(fwbuf)
-        if remote in rdrdy:
-            packet = os.read(remote.fileno(),2000) # remote.read blocks until it gets 2k!
-            bkbuf += packet
+    tunchannel.tun_fwd(tun, remote,
+        with_pi = options.mode.startswith('pl-'),
+        ether_mode = tun_name.startswith('tap'),
+        cipher_key = options.cipher_key,
+        udp = options.udp,
+        TERMINATE = TERMINATE,
+        stderr = None
+    )
 
 
 
@@ -374,12 +385,12 @@ MODEINFO = {
                   tunopen=tunopen, tunclose=tunclose,
                   dealloc=nop,
                   start=pl_vif_start,
-                  stop=nop),
+                  stop=pl_vif_stop),
     'pl-tap'  : dict(alloc=functools.partial(pl_tuntap_alloc, "tap"),
                   tunopen=tunopen, tunclose=tunclose,
                   dealloc=nop,
                   start=pl_vif_start,
-                  stop=nop),
+                  stop=pl_vif_stop),
 }
     
 tun_path = options.tun_path
@@ -401,23 +412,125 @@ except:
     raise
 
 
+# Track SIGTERM, and set global termination flag instead of dying
+TERMINATE = []
+def _finalize(sig,frame):
+    global TERMINATE
+    TERMINATE.append(None)
+signal.signal(signal.SIGTERM, _finalize)
+
 try:
-    # connect to remote endpoint
-    if remaining_args and not remaining_args[0].startswith('-'):
-        print >>sys.stderr, "Connecting to: %s:%d" % (remaining_args[0],options.port)
-        rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
-        rsock.connect((remaining_args[0],options.port))
+    tcpdump = None
+    
+    if options.pass_fd:
+        if options.pass_fd.startswith("base64:"):
+            options.pass_fd = base64.b64decode(
+                options.pass_fd[len("base64:"):])
+            options.pass_fd = os.path.expandvars(options.pass_fd)
+        
+        print >>sys.stderr, "Sending FD to: %r" % (options.pass_fd,)
+        
+        # send FD to whoever wants it
+        import passfd
+        
+        sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
+        retrydelay = 1.0
+        for i in xrange(30):
+            try:
+                sock.connect(options.pass_fd)
+                break
+            except socket.error:
+                # wait a while, retry
+                print >>sys.stderr, "%s: Could not connect. Retrying in a sec..." % (time.strftime('%c'),)
+                time.sleep(min(30.0,retrydelay))
+                retrydelay *= 1.1
+        else:
+            sock.connect(options.pass_fd)
+        passfd.sendfd(sock, tun.fileno(), '0')
+        
+        # just idle until the termination flag gets set
+        def tun_fwd(tun, remote):
+            while not TERMINATE:
+                time.sleep(1)
+        remote = None
+    elif options.udp:
+        # connect to remote endpoint
+        if remaining_args and not remaining_args[0].startswith('-'):
+            print >>sys.stderr, "Listening at: %s:%d" % (hostaddr,options.udp)
+            print >>sys.stderr, "Connecting to: %s:%d" % (remaining_args[0],options.port)
+            rsock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
+            retrydelay = 1.0
+            for i in xrange(30):
+                try:
+                    rsock.bind((hostaddr,options.udp))
+                    break
+                except socket.error:
+                    # wait a while, retry
+                    print >>sys.stderr, "%s: Could not bind. Retrying in a sec..." % (time.strftime('%c'),)
+                    time.sleep(min(30.0,retrydelay))
+                    retrydelay *= 1.1
+            else:
+                rsock.bind((hostaddr,options.udp))
+            rsock.connect((remaining_args[0],options.port))
+        else:
+            print >>sys.stderr, "Error: need a remote endpoint in UDP mode"
+            raise AssertionError, "Error: need a remote endpoint in UDP mode"
+        remote = os.fdopen(rsock.fileno(), 'r+b', 0)
     else:
-        print >>sys.stderr, "Listening at: %s:%d" % (hostaddr,options.port)
-        lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
-        lsock.bind((hostaddr,options.port))
-        lsock.listen(1)
-        rsock,raddr = lsock.accept()
-    remote = os.fdopen(rsock.fileno(), 'r+b', 0)
-
+        # connect to remote endpoint
+        if remaining_args and not remaining_args[0].startswith('-'):
+            print >>sys.stderr, "Connecting to: %s:%d" % (remaining_args[0],options.port)
+            rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+            retrydelay = 1.0
+            for i in xrange(30):
+                try:
+                    rsock.connect((remaining_args[0],options.port))
+                    break
+                except socket.error:
+                    # wait a while, retry
+                    print >>sys.stderr, "%s: Could not connect. Retrying in a sec..." % (time.strftime('%c'),)
+                    time.sleep(min(30.0,retrydelay))
+                    retrydelay *= 1.1
+            else:
+                rsock.connect((remaining_args[0],options.port))
+        else:
+            print >>sys.stderr, "Listening at: %s:%d" % (hostaddr,options.port)
+            lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+            retrydelay = 1.0
+            for i in xrange(30):
+                try:
+                    lsock.bind((hostaddr,options.port))
+                    break
+                except socket.error:
+                    # wait a while, retry
+                    print >>sys.stderr, "%s: Could not bind. Retrying in a sec..." % (time.strftime('%c'),)
+                    time.sleep(min(30.0,retrydelay))
+                    retrydelay *= 1.1
+            else:
+                lsock.bind((hostaddr,options.port))
+            lsock.listen(1)
+            rsock,raddr = lsock.accept()
+        remote = os.fdopen(rsock.fileno(), 'r+b', 0)
+
+    if not options.no_capture:
+        # Launch a tcpdump subprocess, to capture and dump packets.
+        # Make sure to catch sigterm and kill the tcpdump as well
+        tcpdump = subprocess.Popen(
+            ["tcpdump","-l","-n","-i",tun_name, "-s", "4096"]
+            + ["-w",options.pcap_capture,"-U"] * bool(options.pcap_capture) )
+    
     print >>sys.stderr, "Connected"
+    
+    # Try to give us high priority
+    try:
+        os.nice(-20)
+    except:
+        # Ignore errors, we might not have enough privileges,
+        # or perhaps there is no os.nice support in the system
+        pass
 
     tun_fwd(tun, remote)
+
 finally:
     try:
         print >>sys.stderr, "Shutting down..."
@@ -426,19 +539,28 @@ finally:
         pass
     
     # tidy shutdown in every case - swallow exceptions
+
     try:
-        modeinfo['tunclose'](tun_path, tun_name, tun)
+        if tcpdump:
+            os.kill(tcpdump.pid, signal.SIGTERM)
+            tcpdump.wait()
     except:
         pass
-        
+
     try:
         modeinfo['stop'](tun_path, tun_name)
     except:
         pass
 
+    try:
+        modeinfo['tunclose'](tun_path, tun_name, tun)
+    except:
+        pass
+        
     try:
         modeinfo['dealloc'](tun_path, tun_name)
     except:
         pass
-
+    
+    print >>sys.stderr, "TERMINATED GRACEFULLY"