From: Claudio-Daniel Freire Date: Wed, 3 Aug 2011 11:47:39 +0000 (+0200) Subject: TUN/TAP optimizations, generizations, and a benchmark X-Git-Tag: nepi-3.0.0~323 X-Git-Url: http://git.onelab.eu/?a=commitdiff_plain;h=d7f2fcb17b385962eabfbc24eca3cda987134c76;p=nepi.git TUN/TAP optimizations, generizations, and a benchmark --- diff --git a/src/nepi/testbeds/planetlab/scripts/tun_connect.py b/src/nepi/testbeds/planetlab/scripts/tun_connect.py index 11266498..8f794309 100644 --- a/src/nepi/testbeds/planetlab/scripts/tun_connect.py +++ b/src/nepi/testbeds/planetlab/scripts/tun_connect.py @@ -348,9 +348,12 @@ def pl_vif_stop(tun_path, tun_name): del lock, lockfile -def tun_fwd(tun, remote): +def tun_fwd(tun, remote, reconnect = None): global TERMINATE + tunqueue = options.vif_txqueuelen or 1000 + tunkqueue = 500 + # in PL mode, we cannot strip PI structs # so we'll have to handle them tunchannel.tun_fwd(tun, remote, @@ -359,7 +362,10 @@ def tun_fwd(tun, remote): cipher_key = options.cipher_key, udp = options.udp, TERMINATE = TERMINATE, - stderr = None + stderr = None, + reconnect = reconnect, + tunqueue = tunqueue, + tunkqueue = tunkqueue ) @@ -421,6 +427,7 @@ signal.signal(signal.SIGTERM, _finalize) try: tcpdump = None + reconnect = None if options.pass_fd: if options.pass_fd.startswith("base64:"): @@ -529,7 +536,8 @@ try: # or perhaps there is no os.nice support in the system pass - tun_fwd(tun, remote) + tun_fwd(tun, remote, + reconnect = reconnect) finally: try: diff --git a/src/nepi/util/server.py b/src/nepi/util/server.py index 73fa84c0..833d16b2 100644 --- a/src/nepi/util/server.py +++ b/src/nepi/util/server.py @@ -580,6 +580,19 @@ def popen_scp(source, dest, if TRACE: print "scp", source, dest + if isinstance(source, file) and source.tell() == 0: + source = source.name + elif hasattr(source, 'read'): + tmp = tempfile.NamedTemporaryFile() + while True: + buf = source.read(65536) + if buf: + tmp.write(buf) + else: + break + tmp.seek(0) + source = tmp.name + if isinstance(source, file) or isinstance(dest, file) \ or hasattr(source, 'read') or hasattr(dest, 'write'): assert not recursive diff --git a/src/nepi/util/tunchannel.py b/src/nepi/util/tunchannel.py index 2ec2d115..f36211e3 100644 --- a/src/nepi/util/tunchannel.py +++ b/src/nepi/util/tunchannel.py @@ -5,6 +5,10 @@ import struct import socket import threading import errno +import fcntl +import traceback +import functools +import collections def ipfmt(ip): ipbytes = map(ord,ip.decode("hex")) @@ -18,8 +22,7 @@ tagtype = { '8864' : 'PPPoE', '86dd' : 'ipv6', } -def etherProto(packet): - packet = packet.encode("hex") +def etherProto(packet, len=len): if len(packet) > 14: if packet[12:14] == "\x81\x00": # tagged @@ -78,23 +81,40 @@ def formatPacket(packet, ether_mode): packet[48:] if (int(packet[1],16) > 5) else packet[40:], # payload ) ) ) -def packetReady(buf, ether_mode): - if len(buf) < 4: +def _packetReady(buf, ether_mode=False, len=len): + if not buf: return False - elif ether_mode: - return True - else: - _,totallen = struct.unpack('HH',buf[:4]) - totallen = socket.htons(totallen) - return len(buf) >= totallen + + rv = False + while not rv: + if len(buf[0]) < 4: + rv = False + elif ether_mode: + rv = True + else: + _,totallen = struct.unpack('HH',buf[0][:4]) + totallen = socket.htons(totallen) + rv = len(buf[0]) >= totallen + if not rv and len(buf) > 1: + nbuf = ''.join(buf) + buf.clear() + buf.append(nbuf) + else: + return rv + return rv -def pullPacket(buf, ether_mode): +def _pullPacket(buf, ether_mode=False, len=len): if ether_mode: - return buf, "" + return buf.popleft() else: - _,totallen = struct.unpack('HH',buf[:4]) + _,totallen = struct.unpack('HH',buf[0][:4]) totallen = socket.htons(totallen) - return buf[:totallen], buf[totallen:] + if len(buf[0]) < totallen: + rv = buf[0][:totallen] + buf[0] = buf[0][totallen:] + else: + rv = buf.popleft() + return rv def etherStrip(buf): if len(buf) < 14: @@ -109,39 +129,42 @@ def etherStrip(buf): return "" def etherWrap(packet): - return ( + return ''.join(( "\x00"*6*2 # bogus src and dst mac - +"\x08\x00" # IPv4 - +packet # payload - +"\x00"*4 # bogus crc - ) + +"\x08\x00", # IPv4 + packet, # payload + "\x00"*4, # bogus crc + )) -def piStrip(buf): +def piStrip(buf, len=len): if len(buf) < 4: return buf else: return buf[4:] -def piWrap(buf, ether_mode): +def piWrap(buf, ether_mode, etherProto=etherProto): if ether_mode: proto = etherProto(buf) else: proto = "\x08\x00" - return ( - "\x00\x00" # PI: 16 bits flags - +proto # 16 bits proto - +buf - ) + return ''.join(( + "\x00\x00", # PI: 16 bits flags + proto, # 16 bits proto + buf, + )) + +_padmap = [ chr(padding) * padding for padding in xrange(127) ] +del padding -def encrypt(packet, crypter): +def encrypt(packet, crypter, len=len, padmap=_padmap): # pad padding = crypter.block_size - len(packet) % crypter.block_size - packet += chr(padding) * padding + packet += padmap[padding] # encrypt return crypter.encrypt(packet) -def decrypt(packet, crypter): +def decrypt(packet, crypter, ord=ord): # decrypt packet = crypter.decrypt(packet) @@ -154,22 +177,39 @@ def decrypt(packet, crypter): return packet +def nonblock(fd): + try: + fl = fcntl.fcntl(fd, fcntl.F_GETFL) + fl |= os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, fl) + return True + except: + traceback.print_exc(file=sys.stderr) + # Just ignore + return False -def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr=sys.stderr): +def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr=sys.stderr, reconnect=None, rwrite=None, rread=None, tunqueue=1000, tunkqueue=1000, + len=len, max=max, OSError=OSError, cipher='AES'): crypto_mode = False try: if cipher_key: - import Crypto.Cipher.AES + import Crypto.Cipher import hashlib + __import__('Crypto.Cipher.'+cipher) + ciphername = cipher + cipher = getattr(Crypto.Cipher, cipher) hashed_key = hashlib.sha256(cipher_key).digest() - crypter = Crypto.Cipher.AES.new( + if getattr(cipher, 'key_size'): + hashed_key = hashed_key[:cipher.key_size] + elif ciphername == 'DES3': + hashed_key = hashed_key[:24] + crypter = cipher.new( hashed_key, - Crypto.Cipher.AES.MODE_ECB) + cipher.MODE_ECB) crypto_mode = True except: - import traceback - traceback.print_exc() + traceback.print_exc(file=sys.stderr) crypto_mode = False crypter = None @@ -179,69 +219,195 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr else: print >>stderr, "Packets are transmitted in PLAINTEXT" + if hasattr(remote, 'fileno'): + remote_fd = remote.fileno() + if rwrite is None: + def rwrite(remote, packet, os_write=os.write): + return os_write(remote_fd, packet) + if rread is None: + def rread(remote, maxlen, os_read=os.read): + return os_read(remote_fd, maxlen) + + rnonblock = nonblock(remote) + tnonblock = nonblock(tun) + # Limited frame parsing, to preserve packet boundaries. # Which is needed, since /dev/net/tun is unbuffered - fwbuf = "" - bkbuf = "" + maxbkbuf = maxfwbuf = max(10,tunqueue-tunkqueue) + tunhurry = max(0,maxbkbuf/2) + fwbuf = collections.deque() + bkbuf = collections.deque() + if ether_mode: + packetReady = bool + pullPacket = collections.deque.popleft + else: + packetReady = _packetReady + pullPacket = _pullPacket + tunfd = tun.fileno() + os_read = os.read + os_write = os.write while not TERMINATE: wset = [] - if packetReady(bkbuf, ether_mode): + if packetReady(bkbuf): wset.append(tun) - if packetReady(fwbuf, ether_mode): + if packetReady(fwbuf): wset.append(remote) + rset = [] + if len(fwbuf) < maxfwbuf: + rset.append(tun) + if len(bkbuf) < maxbkbuf: + rset.append(remote) + try: - rdrdy, wrdy, errs = select.select((tun,remote),wset,(tun,remote),1) + rdrdy, wrdy, errs = select.select(rset,wset,(tun,remote),1) except select.error, e: if e.args[0] == errno.EINTR: # just retry continue - + # check for errors if errs: - break + if reconnect is not None and remote in errs and tun not in errs: + remote = reconnect() + if hasattr(remote, 'fileno'): + remote_fd = remote.fileno() + elif udp and remote in errs and tun not in errs: + # In UDP mode, those are always transient errors + pass + else: + break # check to see if we can write - if remote in wrdy and packetReady(fwbuf, ether_mode): - packet, fwbuf = pullPacket(fwbuf, ether_mode) + #rr = wr = rt = wt = 0 + if remote in wrdy: try: - if crypto_mode: - enpacket = encrypt(packet, crypter) - else: - enpacket = packet - os.write(remote.fileno(), enpacket) + try: + while True: + packet = pullPacket(fwbuf) + + if crypto_mode: + enpacket = encrypt(packet, crypter) + else: + enpacket = packet + + # try twice - sometimes it barks the first time, + # due to ICMP Port Unreachable packets from previous writes + try: + rwrite(remote, enpacket) + except socket.error: + rwrite(remote, enpacket) + #wr += 1 + + if stderr is not None: + print >>stderr, '>', formatPacket(packet, ether_mode) + + if not rnonblock or not packetReady(fwbuf): + break + except OSError,e: + # This except handles the entire While block on PURPOSE + # as an optimization (setting a try/except block is expensive) + # The only operation that can raise this exception is rwrite + if e.errno == os.errno.EWOULDBLOCK: + # re-schedule packet + fwbuf.insert(0, packet) + else: + raise except: - if not udp: + if reconnect is not None: + # in UDP mode, sometimes connected sockets can return a connection refused. + # Give the caller a chance to reconnect + remote = reconnect() + if hasattr(remote, 'fileno'): + remote_fd = remote.fileno() + elif not udp: # in UDP mode, we ignore errors - packet loss man... raise - if stderr is not None: - print >>stderr, '>', formatPacket(packet, ether_mode) - if tun in wrdy and packetReady(bkbuf, ether_mode): - packet, bkbuf = pullPacket(bkbuf, ether_mode) - if stderr is not None: - formatted = formatPacket(packet, ether_mode) - if with_pi: - packet = piWrap(packet, ether_mode) - os.write(tun.fileno(), packet) - if stderr is not None: - print >>stderr, '<', formatted + traceback.print_exc(file=sys.stderr) + if tun in wrdy: + try: + while True: + packet = pullPacket(bkbuf) + if stderr is not None: + formatted = formatPacket(packet, ether_mode) + if with_pi: + packet = piWrap(packet, ether_mode) + os_write(tunfd, packet) + #wt += 1 + if stderr is not None: + print >>stderr, '<', formatted + + # Do not inject packets into the TUN faster than they arrive, unless we're falling + # behind. TUN devices discard packets if their queue is full (tunkqueue), but they + # don't block either (they're always ready to write), so if we flood the device + # we'll have high packet loss. + if not tnonblock or len(bkbuf) < tunhurry or not packetReady(bkbuf): + break + except OSError,e: + # This except handles the entire While block on PURPOSE + # as an optimization (setting a try/except block is expensive) + # The only operation that can raise this exception is os_write + if e.errno == os.errno.EWOULDBLOCK: + # re-schedule packet + bkbuf.insert(0, packet) + else: + raise # check incoming data packets if tun in rdrdy: - packet = os.read(tun.fileno(),2000) # tun.read blocks until it gets 2k! - if with_pi: - packet = piStrip(packet) - fwbuf += packet + try: + while True: + packet = os_read(tunfd,2000) # tun.read blocks until it gets 2k! + #rt += 1 + if with_pi: + packet = piStrip(packet) + fwbuf.append(packet) + + if not tnonblock or len(fwbuf) >= maxfwbuf: + break + except OSError,e: + # This except handles the entire While block on PURPOSE + # as an optimization (setting a try/except block is expensive) + # The only operation that can raise this exception is os_read + if e.errno != os.errno.EWOULDBLOCK: + raise if remote in rdrdy: try: - packet = os.read(remote.fileno(),2000) # remote.read blocks until it gets 2k! - if crypto_mode: - packet = decrypt(packet, crypter) - except: - if not udp: + try: + while True: + # Try twice, sometimes it barks the first time, + # due to ICMP Port Unreachable packets from previous writes + try: + packet = rread(remote,2000) + except socket.error: + packet = rread(remote,2000) + #rr += 1 + + if crypto_mode: + packet = decrypt(packet, crypter) + bkbuf.append(packet) + + if not rnonblock or len(bkbuf) >= maxbkbuf: + break + except OSError,e: + # This except handles the entire While block on PURPOSE + # as an optimization (setting a try/except block is expensive) + # The only operation that can raise this exception is rread + if e.errno != os.errno.EWOULDBLOCK: + raise + except Exception, e: + if reconnect is not None: + # in UDP mode, sometimes connected sockets can return a connection refused + # on read. Give the caller a chance to reconnect + remote = reconnect() + if hasattr(remote, 'fileno'): + remote_fd = remote.fileno() + elif not udp: # in UDP mode, we ignore errors - packet loss man... raise - bkbuf += packet + traceback.print_exc(file=sys.stderr) + + #print >>sys.stderr, "rr:%d\twr:%d\trt:%d\twt:%d" % (rr,wr,rt,wt) diff --git a/tunbench.py b/tunbench.py new file mode 100644 index 00000000..b8ac97c8 --- /dev/null +++ b/tunbench.py @@ -0,0 +1,55 @@ +import os +import sys +import threading +import time +import cProfile +import pstats + +from nepi.util import tunchannel + +remote = open("/dev/zero","r+b") +tun = open("/dev/zero","r+b") + +def rwrite(remote, packet, remote_fd = remote.fileno(), os_write=os.write, len=len): + global bytes + bytes += len(packet) + return os_write(remote_fd, packet) + +def rread(remote, maxlen, remote_fd = remote.fileno(), os_read=os.read): + global bytes + rv = os_read(remote_fd, maxlen) + bytes += len(rv) + return rv + +def test(cipher, passphrase): + TERMINATE = [] + def stopme(): + time.sleep(100) + TERMINATE.append(None) + t = threading.Thread(target=stopme) + t.start() + tunchannel.tun_fwd(tun, remote, True, True, passphrase, True, TERMINATE, None, tunkqueue=500, + rwrite = rwrite, rread = rread, cipher=cipher) + +# Swallow exceptions on decryption +def decrypt(packet, crypter, super=tunchannel.decrypt): + try: + return super(packet, crypter) + except: + return packet +tunchannel.decrypt = decrypt + +for cipher in (None, 'AES', 'Blowfish', 'DES', 'DES3'): + if cipher is None: + passphrase = None + else: + passphrase = 'Abracadabra' + bytes = 0 + cProfile.runctx('test(%r,%r)' % (cipher, passphrase),globals(),locals(),'tunchannel.%s.profile' % (cipher,)) + + print "Profile (%s):" % ( cipher, ) + pstats.Stats('tunchannel.%s.profile' % cipher).strip_dirs().sort_stats('time').print_stats() + + print "Bandwidth (%s): %.4fMb/s" % ( cipher, bytes / 200.0 * 8 / 2**20, ) + +