TUN/TAP optimizations, generalizations, and a benchmark
author Claudio-Daniel Freire <claudio-daniel.freire@inria.fr>
Wed, 3 Aug 2011 11:47:39 +0000 (13:47 +0200)
committer Claudio-Daniel Freire <claudio-daniel.freire@inria.fr>
Wed, 3 Aug 2011 11:47:39 +0000 (13:47 +0200)
src/nepi/testbeds/planetlab/scripts/tun_connect.py
src/nepi/util/server.py
src/nepi/util/tunchannel.py
tunbench.py [new file with mode: 0644]

index 1126649..8f79430 100644 (file)
@@ -348,9 +348,12 @@ def pl_vif_stop(tun_path, tun_name):
     del lock, lockfile
 
 
-def tun_fwd(tun, remote):
+def tun_fwd(tun, remote, reconnect = None):
     global TERMINATE
     
+    tunqueue = options.vif_txqueuelen or 1000
+    tunkqueue = 500
+    
     # in PL mode, we cannot strip PI structs
     # so we'll have to handle them
     tunchannel.tun_fwd(tun, remote,
@@ -359,7 +362,10 @@ def tun_fwd(tun, remote):
         cipher_key = options.cipher_key,
         udp = options.udp,
         TERMINATE = TERMINATE,
-        stderr = None
+        stderr = None,
+        reconnect = reconnect,
+        tunqueue = tunqueue,
+        tunkqueue = tunkqueue
     )
 
 
@@ -421,6 +427,7 @@ signal.signal(signal.SIGTERM, _finalize)
 
 try:
     tcpdump = None
+    reconnect = None
     
     if options.pass_fd:
         if options.pass_fd.startswith("base64:"):
@@ -529,7 +536,8 @@ try:
         # or perhaps there is no os.nice support in the system
         pass
 
-    tun_fwd(tun, remote)
+    tun_fwd(tun, remote,
+        reconnect = reconnect)
 
 finally:
     try:
index 73fa84c..833d16b 100644 (file)
@@ -580,6 +580,19 @@ def popen_scp(source, dest,
         if TRACE:
             print "scp", source, dest
         
+        if isinstance(source, file) and source.tell() == 0:
+            source = source.name
+        elif hasattr(source, 'read'):
+            tmp = tempfile.NamedTemporaryFile()
+            while True:
+                buf = source.read(65536)
+                if buf:
+                    tmp.write(buf)
+                else:
+                    break
+            tmp.seek(0)
+            source = tmp.name
+        
         if isinstance(source, file) or isinstance(dest, file) \
                 or hasattr(source, 'read')  or hasattr(dest, 'write'):
             assert not recursive
index 2ec2d11..f36211e 100644 (file)
@@ -5,6 +5,10 @@ import struct
 import socket
 import threading
 import errno
+import fcntl
+import traceback
+import functools
+import collections
 
 def ipfmt(ip):
     ipbytes = map(ord,ip.decode("hex"))
@@ -18,8 +22,7 @@ tagtype = {
     '8864' : 'PPPoE',
     '86dd' : 'ipv6',
 }
-def etherProto(packet):
-    packet = packet.encode("hex")
+def etherProto(packet, len=len):
     if len(packet) > 14:
         if packet[12:14] == "\x81\x00":
             # tagged
@@ -78,23 +81,40 @@ def formatPacket(packet, ether_mode):
             packet[48:] if (int(packet[1],16) > 5) else packet[40:], # payload
         ) ) )
 
-def packetReady(buf, ether_mode):
-    if len(buf) < 4:
+def _packetReady(buf, ether_mode=False, len=len):
+    if not buf:
         return False
-    elif ether_mode:
-        return True
-    else:
-        _,totallen = struct.unpack('HH',buf[:4])
-        totallen = socket.htons(totallen)
-        return len(buf) >= totallen
+        
+    rv = False
+    while not rv:
+        if len(buf[0]) < 4:
+            rv = False
+        elif ether_mode:
+            rv = True
+        else:
+            _,totallen = struct.unpack('HH',buf[0][:4])
+            totallen = socket.htons(totallen)
+            rv = len(buf[0]) >= totallen
+        if not rv and len(buf) > 1:
+            nbuf = ''.join(buf)
+            buf.clear()
+            buf.append(nbuf)
+        else:
+            return rv
+    return rv
 
-def pullPacket(buf, ether_mode):
+def _pullPacket(buf, ether_mode=False, len=len): # pops one packet off the head of the deque
     if ether_mode:
-        return buf, ""
+        return buf.popleft()
     else:
-        _,totallen = struct.unpack('HH',buf[:4])
+        _,totallen = struct.unpack('HH',buf[0][:4])
         totallen = socket.htons(totallen)
-        return buf[:totallen], buf[totallen:]
+        if len(buf[0]) > totallen: # head chunk holds more than one packet: split it
+            rv = buf[0][:totallen]
+            buf[0] = buf[0][totallen:]
+        else:
+            rv = buf.popleft()
+        return rv
 
 def etherStrip(buf):
     if len(buf) < 14:
@@ -109,39 +129,42 @@ def etherStrip(buf):
         return ""
 
 def etherWrap(packet):
-    return (
+    return ''.join((
         "\x00"*6*2 # bogus src and dst mac
-        +"\x08\x00" # IPv4
-        +packet # payload
-        +"\x00"*4 # bogus crc
-    )
+        +"\x08\x00", # IPv4
+        packet, # payload
+        "\x00"*4, # bogus crc
+    ))
 
-def piStrip(buf):
+def piStrip(buf, len=len):
     if len(buf) < 4:
         return buf
     else:
         return buf[4:]
     
-def piWrap(buf, ether_mode):
+def piWrap(buf, ether_mode, etherProto=etherProto):
     if ether_mode:
         proto = etherProto(buf)
     else:
         proto = "\x08\x00"
-    return (
-        "\x00\x00" # PI: 16 bits flags
-        +proto # 16 bits proto
-        +buf
-    )
+    return ''.join((
+        "\x00\x00", # PI: 16 bits flags
+        proto, # 16 bits proto
+        buf,
+    ))
+
+_padmap = [ chr(padding) * padding for padding in xrange(127) ]
+del padding
 
-def encrypt(packet, crypter):
+def encrypt(packet, crypter, len=len, padmap=_padmap):
     # pad
     padding = crypter.block_size - len(packet) % crypter.block_size
-    packet += chr(padding) * padding
+    packet += padmap[padding]
     
     # encrypt
     return crypter.encrypt(packet)
 
-def decrypt(packet, crypter):
+def decrypt(packet, crypter, ord=ord):
     # decrypt
     packet = crypter.decrypt(packet)
     
@@ -154,22 +177,39 @@ def decrypt(packet, crypter):
     
     return packet
 
+def nonblock(fd):
+    try:
+        fl = fcntl.fcntl(fd, fcntl.F_GETFL)
+        fl |= os.O_NONBLOCK
+        fcntl.fcntl(fd, fcntl.F_SETFL, fl)
+        return True
+    except:
+        traceback.print_exc(file=sys.stderr)
+        # Just ignore
+        return False
 
-def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr=sys.stderr):
+def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr=sys.stderr, reconnect=None, rwrite=None, rread=None, tunqueue=1000, tunkqueue=1000,
+        len=len, max=max, OSError=OSError, cipher='AES'):
     crypto_mode = False
     try:
         if cipher_key:
-            import Crypto.Cipher.AES
+            import Crypto.Cipher
             import hashlib
+            __import__('Crypto.Cipher.'+cipher)
             
+            ciphername = cipher
+            cipher = getattr(Crypto.Cipher, cipher)
             hashed_key = hashlib.sha256(cipher_key).digest()
-            crypter = Crypto.Cipher.AES.new(
+            if getattr(cipher, 'key_size', None): # tolerate cipher modules lacking key_size
+                hashed_key = hashed_key[:cipher.key_size]
+            elif ciphername == 'DES3':
+                hashed_key = hashed_key[:24]
+            crypter = cipher.new(
                 hashed_key, 
-                Crypto.Cipher.AES.MODE_ECB)
+                cipher.MODE_ECB)
             crypto_mode = True
     except:
-        import traceback
-        traceback.print_exc()
+        traceback.print_exc(file=sys.stderr)
         crypto_mode = False
         crypter = None
 
@@ -179,69 +219,195 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
         else:
             print >>stderr, "Packets are transmitted in PLAINTEXT"
     
+    if hasattr(remote, 'fileno'):
+        remote_fd = remote.fileno()
+        if rwrite is None:
+            def rwrite(remote, packet, os_write=os.write):
+                return os_write(remote_fd, packet)
+        if rread is None:
+            def rread(remote, maxlen, os_read=os.read):
+                return os_read(remote_fd, maxlen)
+    
+    rnonblock = nonblock(remote)
+    tnonblock = nonblock(tun)
+    
     # Limited frame parsing, to preserve packet boundaries.
     # Which is needed, since /dev/net/tun is unbuffered
-    fwbuf = ""
-    bkbuf = ""
+    maxbkbuf = maxfwbuf = max(10,tunqueue-tunkqueue)
+    tunhurry = max(0,maxbkbuf/2)
+    fwbuf = collections.deque()
+    bkbuf = collections.deque()
+    if ether_mode:
+        packetReady = bool
+        pullPacket = collections.deque.popleft
+    else:
+        packetReady = _packetReady
+        pullPacket = _pullPacket
+    tunfd = tun.fileno()
+    os_read = os.read
+    os_write = os.write
     while not TERMINATE:
         wset = []
-        if packetReady(bkbuf, ether_mode):
+        if packetReady(bkbuf):
             wset.append(tun)
-        if packetReady(fwbuf, ether_mode):
+        if packetReady(fwbuf):
             wset.append(remote)
         
+        rset = []
+        if len(fwbuf) < maxfwbuf:
+            rset.append(tun)
+        if len(bkbuf) < maxbkbuf:
+            rset.append(remote)
+        
         try:
-            rdrdy, wrdy, errs = select.select((tun,remote),wset,(tun,remote),1)
+            rdrdy, wrdy, errs = select.select(rset,wset,(tun,remote),1)
         except select.error, e:
             if e.args[0] == errno.EINTR:
                 # just retry
                 continue
-        
+
         # check for errors
         if errs:
-            break
+            if reconnect is not None and remote in errs and tun not in errs:
+                remote = reconnect()
+                if hasattr(remote, 'fileno'):
+                    remote_fd = remote.fileno()
+            elif udp and remote in errs and tun not in errs:
+                # In UDP mode, those are always transient errors
+                pass
+            else:
+                break
         
         # check to see if we can write
-        if remote in wrdy and packetReady(fwbuf, ether_mode):
-            packet, fwbuf = pullPacket(fwbuf, ether_mode)
+        #rr = wr = rt = wt = 0
+        if remote in wrdy:
             try:
-                if crypto_mode:
-                    enpacket = encrypt(packet, crypter)
-                else:
-                    enpacket = packet
-                os.write(remote.fileno(), enpacket)
+                try:
+                    while True:
+                        packet = pullPacket(fwbuf)
+
+                        if crypto_mode:
+                            enpacket = encrypt(packet, crypter)
+                        else:
+                            enpacket = packet
+                        
+                        # try twice - sometimes it barks the first time,
+                        # due to ICMP Port Unreachable packets from previous writes
+                        try:
+                            rwrite(remote, enpacket)
+                        except socket.error:
+                            rwrite(remote, enpacket)
+                        #wr += 1
+                        
+                        if stderr is not None:
+                            print >>stderr, '>', formatPacket(packet, ether_mode)
+                        
+                        if not rnonblock or not packetReady(fwbuf):
+                            break
+                except OSError,e:
+                    # This except handles the entire While block on PURPOSE
+                    # as an optimization (setting a try/except block is expensive)
+                    # The only operation that can raise this exception is rwrite
+                    if e.errno == os.errno.EWOULDBLOCK:
+                        # re-schedule packet at the head (Python 2 deques have no insert())
+                        fwbuf.appendleft(packet)
+                    else:
+                        raise
             except:
-                if not udp:
+                if reconnect is not None:
+                    # in UDP mode, sometimes connected sockets can return a connection refused.
+                    # Give the caller a chance to reconnect
+                    remote = reconnect()
+                    if hasattr(remote, 'fileno'):
+                        remote_fd = remote.fileno()
+                elif not udp:
                     # in UDP mode, we ignore errors - packet loss man...
                     raise
-            if stderr is not None:
-                print >>stderr, '>', formatPacket(packet, ether_mode)
-        if tun in wrdy and packetReady(bkbuf, ether_mode):
-            packet, bkbuf = pullPacket(bkbuf, ether_mode)
-            if stderr is not None:
-                formatted = formatPacket(packet, ether_mode)
-            if with_pi:
-                packet = piWrap(packet, ether_mode)
-            os.write(tun.fileno(), packet)
-            if stderr is not None:
-                print >>stderr, '<', formatted
+                traceback.print_exc(file=sys.stderr)
+        if tun in wrdy:
+            try:
+                while True:
+                    packet = pullPacket(bkbuf)
+                    if stderr is not None:
+                        formatted = formatPacket(packet, ether_mode)
+                    # wrap into a separate name, so a re-scheduled packet is not wrapped twice
+                    wpacket = piWrap(packet, ether_mode) if with_pi else packet
+                    os_write(tunfd, wpacket)
+                    #wt += 1
+                    if stderr is not None:
+                        print >>stderr, '<', formatted
+                    
+                    # Do not inject packets into the TUN faster than they arrive, unless we're falling
+                    # behind. TUN devices discard packets if their queue is full (tunkqueue), but they
+                    # don't block either (they're always ready to write), so if we flood the device 
+                    # we'll have high packet loss.
+                    if not tnonblock or len(bkbuf) < tunhurry or not packetReady(bkbuf):
+                        break
+            except OSError,e:
+                # This except handles the entire While block on PURPOSE
+                # as an optimization (setting a try/except block is expensive)
+                # The only operation that can raise this exception is os_write
+                if e.errno == os.errno.EWOULDBLOCK:
+                    # re-schedule the unwrapped packet at the head (Python 2 deques have no insert())
+                    bkbuf.appendleft(packet)
+                else:
+                    raise
         
         # check incoming data packets
         if tun in rdrdy:
-            packet = os.read(tun.fileno(),2000) # tun.read blocks until it gets 2k!
-            if with_pi:
-                packet = piStrip(packet)
-            fwbuf += packet
+            try:
+                while True:
+                    packet = os_read(tunfd,2000) # tun.read blocks until it gets 2k!
+                    #rt += 1
+                    if with_pi:
+                        packet = piStrip(packet)
+                    fwbuf.append(packet)
+                    
+                    if not tnonblock or len(fwbuf) >= maxfwbuf:
+                        break
+            except OSError,e:
+                # This except handles the entire While block on PURPOSE
+                # as an optimization (setting a try/except block is expensive)
+                # The only operation that can raise this exception is os_read
+                if e.errno != os.errno.EWOULDBLOCK:
+                    raise
         if remote in rdrdy:
             try:
-                packet = os.read(remote.fileno(),2000) # remote.read blocks until it gets 2k!
-                if crypto_mode:
-                    packet = decrypt(packet, crypter)
-            except:
-                if not udp:
+                try:
+                    while True:
+                        # Try twice, sometimes it barks the first time, 
+                        # due to ICMP Port Unreachable packets from previous writes
+                        try:
+                            packet = rread(remote,2000)
+                        except socket.error:
+                            packet = rread(remote,2000)
+                        #rr += 1
+                        
+                        if crypto_mode:
+                            packet = decrypt(packet, crypter)
+                        bkbuf.append(packet)
+                        
+                        if not rnonblock or len(bkbuf) >= maxbkbuf:
+                            break
+                except OSError,e:
+                    # This except handles the entire While block on PURPOSE
+                    # as an optimization (setting a try/except block is expensive)
+                    # The only operation that can raise this exception is rread
+                    if e.errno != os.errno.EWOULDBLOCK:
+                        raise
+            except Exception, e:
+                if reconnect is not None:
+                    # in UDP mode, sometimes connected sockets can return a connection refused
+                    # on read. Give the caller a chance to reconnect
+                    remote = reconnect()
+                    if hasattr(remote, 'fileno'):
+                        remote_fd = remote.fileno()
+                elif not udp:
                     # in UDP mode, we ignore errors - packet loss man...
                     raise
-            bkbuf += packet
+                traceback.print_exc(file=sys.stderr)
+        
+        #print >>sys.stderr, "rr:%d\twr:%d\trt:%d\twt:%d" % (rr,wr,rt,wt)
 
 
 
diff --git a/tunbench.py b/tunbench.py
new file mode 100644 (file)
index 0000000..b8ac97c
--- /dev/null
@@ -0,0 +1,55 @@
+import os
+import sys
+import threading
+import time
+import cProfile
+import pstats
+
+from nepi.util import tunchannel
+
+remote = open("/dev/zero","r+b")
+tun = open("/dev/zero","r+b")
+
+def rwrite(remote, packet, remote_fd = remote.fileno(), os_write=os.write, len=len):
+    global bytes
+    bytes += len(packet)
+    return os_write(remote_fd, packet)
+
+def rread(remote, maxlen, remote_fd = remote.fileno(), os_read=os.read):
+    global bytes
+    rv = os_read(remote_fd, maxlen)
+    bytes += len(rv)
+    return rv
+
+def test(cipher, passphrase):
+   TERMINATE = []
+   def stopme():
+       time.sleep(100)
+       TERMINATE.append(None)
+   t = threading.Thread(target=stopme)
+   t.start()
+   tunchannel.tun_fwd(tun, remote, True, True, passphrase, True, TERMINATE, None, tunkqueue=500,
+        rwrite = rwrite, rread = rread, cipher=cipher)
+
+# Swallow exceptions on decryption
+def decrypt(packet, crypter, super=tunchannel.decrypt):
+    try:
+        return super(packet, crypter)
+    except:
+        return packet
+tunchannel.decrypt = decrypt
+
+for cipher in (None, 'AES', 'Blowfish', 'DES', 'DES3'):
+    if cipher is None:
+        passphrase = None
+    else:
+        passphrase = 'Abracadabra'
+    bytes = 0
+    cProfile.runctx('test(%r,%r)' % (cipher, passphrase),globals(),locals(),'tunchannel.%s.profile' % (cipher,))
+    
+    print "Profile (%s):" % ( cipher, )
+    pstats.Stats('tunchannel.%s.profile' % cipher).strip_dirs().sort_stats('time').print_stats()
+    
+    print "Bandwidth (%s): %.4fMb/s" % ( cipher, bytes / 200.0 * 8 / 2**20, )
+
+