X-Git-Url: http://git.onelab.eu/?a=blobdiff_plain;f=src%2Fnepi%2Futil%2Ftunchannel.py;h=fc627df4cbc572d5643934aa084aa9947b4c8c80;hb=4bea9f9369373d6f6be4b4b781de6d615bfaf610;hp=9e34b45f5c72ddc124c905e673ed51c9b87d2894;hpb=e1cea30060f499111e893b290557faa42908cf78;p=nepi.git diff --git a/src/nepi/util/tunchannel.py b/src/nepi/util/tunchannel.py index 9e34b45f..fc627df4 100644 --- a/src/nepi/util/tunchannel.py +++ b/src/nepi/util/tunchannel.py @@ -6,6 +6,7 @@ import socket import threading import errno import fcntl +import random import traceback import functools import collections @@ -167,15 +168,16 @@ def encrypt(packet, crypter, len=len, padmap=_padmap): return crypter.encrypt(packet) def decrypt(packet, crypter, ord=ord): - # decrypt - packet = crypter.decrypt(packet) - - # un-pad - padding = ord(packet[-1]) - if not (0 < padding <= crypter.block_size): - # wrong padding - raise RuntimeError, "Truncated packet" - packet = packet[:-padding] + if packet: + # decrypt + packet = crypter.decrypt(packet) + + # un-pad + padding = ord(packet[-1]) + if not (0 < padding <= crypter.block_size): + # wrong padding + raise RuntimeError, "Truncated packet" + packet = packet[:-padding] return packet @@ -191,12 +193,14 @@ def nonblock(fd): return False def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr=sys.stderr, reconnect=None, rwrite=None, rread=None, tunqueue=1000, tunkqueue=1000, - cipher='AES', - len=len, max=max, OSError=OSError, select=select.select, selecterror=select.error, os=os, socket=socket, + cipher='AES', accept_local=None, accept_remote=None, slowlocal=True, queueclass=None, bwlimit=None, + len=len, max=max, min=min, OSError=OSError, select=select.select, selecterror=select.error, os=os, socket=socket, retrycodes=(os.errno.EWOULDBLOCK, os.errno.EAGAIN, os.errno.EINTR) ): crypto_mode = False + crypter = None + try: - if cipher_key: + if cipher_key and cipher: import Crypto.Cipher import hashlib __import__('Crypto.Cipher.'+cipher) @@ -231,7 +235,7 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr if rread is None: def rread(remote, maxlen, os_read=os.read): return os_read(remote_fd, maxlen) - + rnonblock = nonblock(remote) tnonblock = nonblock(tun) @@ -261,46 +265,96 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr twrite = os.write tread = os.read - # Limited frame parsing, to preserve packet boundaries. - # Which is needed, since /dev/net/tun is unbuffered + encrypt_ = encrypt + decrypt_ = decrypt + xrange_ = xrange + + if accept_local is not None: + def tread(fd, maxlen, _tread=tread, accept=accept_local): + packet = _tread(fd, maxlen) + if accept(packet, 0): + return packet + else: + return None + + if accept_remote is not None: + if crypto_mode: + def decrypt_(packet, crypter, decrypt_=decrypt_, accept=accept_remote): + packet = decrypt_(packet, crypter) + if accept(packet, 1): + return packet + else: + return None + else: + def rread(fd, maxlen, _rread=rread, accept=accept_remote): + packet = _rread(fd, maxlen) + if accept(packet, 1): + return packet + else: + return None + maxbkbuf = maxfwbuf = max(10,tunqueue-tunkqueue) tunhurry = max(0,maxbkbuf/2) - fwbuf = collections.deque() - bkbuf = collections.deque() + + if queueclass is None: + queueclass = collections.deque + maxbatch = 2000 + maxtbatch = 50 + else: + maxfwbuf = maxbkbuf = 2000000000 + maxbatch = 50 + maxtbatch = 30 + tunhurry = 30 + + fwbuf = queueclass() + bkbuf = queueclass() nfwbuf = 0 nbkbuf = 0 - if ether_mode: + if ether_mode or udp: packetReady = bool - pullPacket = collections.deque.popleft - reschedule = collections.deque.appendleft + pullPacket = queueclass.popleft + reschedule = queueclass.appendleft else: packetReady = _packetReady pullPacket = _pullPacket - reschedule = collections.deque.appendleft + reschedule = queueclass.appendleft tunfd = tun.fileno() os_read = os.read os_write = os.write - encrypt_ = encrypt - decrypt_ = decrypt + + tget = time.time + maxbwfree = bwfree = 1500 * tunqueue + lastbwtime = tget() + + remoteok = True + + while not TERMINATE: wset = [] if packetReady(bkbuf): wset.append(tun) - if packetReady(fwbuf): + if remoteok and packetReady(fwbuf) and (not bwlimit or bwfree > 0): wset.append(remote) rset = [] if len(fwbuf) < maxfwbuf: rset.append(tun) - if len(bkbuf) < maxbkbuf: + if remoteok and len(bkbuf) < maxbkbuf: rset.append(remote) + if remoteok: + eset = (tun,remote) + else: + eset = (tun,) + try: - rdrdy, wrdy, errs = select(rset,wset,(tun,remote),1) + rdrdy, wrdy, errs = select(rset,wset,eset,1) except selecterror, e: if e.args[0] == errno.EINTR: # just retry continue + else: + raise # check for errors if errs: @@ -310,22 +364,30 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr remote_fd = remote.fileno() elif udp and remote in errs and tun not in errs: # In UDP mode, those are always transient errors - pass + # Usually, an error will imply a read-ready socket + # that will raise an "Connection refused" error, so + # disable read-readiness just for now, and retry + # the select + remoteok = False + continue else: break + else: + remoteok = True # check to see if we can write #rr = wr = rt = wt = 0 if remote in wrdy: + sent = 0 try: try: - while 1: + for x in xrange(maxbatch): packet = pullPacket(fwbuf) if crypto_mode: packet = encrypt_(packet, crypter) - rwrite(remote, packet) + sent += rwrite(remote, packet) #wr += 1 if not rnonblock or not packetReady(fwbuf): @@ -350,9 +412,12 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr # in UDP mode, we ignore errors - packet loss man... raise #traceback.print_exc(file=sys.stderr) + + if bwlimit: + bwfree -= sent if tun in wrdy: try: - for x in xrange(50): + for x in xrange(maxtbatch): packet = pullPacket(bkbuf) twrite(tunfd, packet) #wt += 1 @@ -361,11 +426,12 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr # behind. TUN devices discard packets if their queue is full (tunkqueue), but they # don't block either (they're always ready to write), so if we flood the device # we'll have high packet loss. - if not tnonblock or len(bkbuf) < tunhurry or not packetReady(bkbuf): + if not tnonblock or (slowlocal and len(bkbuf) < tunhurry) or not packetReady(bkbuf): break else: - # Give some time for the kernel to process the packets - time.sleep(0) + if slowlocal: + # Give some time for the kernel to process the packets + time.sleep(0) except OSError,e: # This except handles the entire While block on PURPOSE # as an optimization (setting a try/except block is expensive) @@ -379,8 +445,10 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr # check incoming data packets if tun in rdrdy: try: - while 1: + for x in xrange(maxbatch): packet = tread(tunfd,2000) # tun.read blocks until it gets 2k! + if not packet: + continue #rt += 1 fwbuf.append(packet) @@ -395,12 +463,22 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr if remote in rdrdy: try: try: - while 1: + for x in xrange(maxbatch): packet = rread(remote,2000) + #rr += 1 if crypto_mode: packet = decrypt_(packet, crypter) + if not packet: + continue + elif not packet: + if not udp and packet == "": + # Connection broken, try to reconnect (or just die) + raise RuntimeError, "Connection broken" + else: + continue + bkbuf.append(packet) if not rnonblock or len(bkbuf) >= maxbkbuf: @@ -421,9 +499,242 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr elif not udp: # in UDP mode, we ignore errors - packet loss man... raise - #traceback.print_exc(file=sys.stderr) + traceback.print_exc(file=sys.stderr) + + if bwlimit: + tnow = tget() + delta = tnow - lastbwtime + if delta > 0.001: + delta = int(bwlimit * delta) + if delta > 0: + bwfree = min(bwfree+delta, maxbwfree) + lastbwtime = tnow #print >>sys.stderr, "rr:%d\twr:%d\trt:%d\twt:%d" % (rr,wr,rt,wt) +def udp_connect(TERMINATE, local_addr, local_port, peer_addr, peer_port): + rsock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0) + retrydelay = 1.0 + for i in xrange(30): + # TERMINATE is a array. An item can be added to TERMINATE, from + # outside this function to force termination of the loop + if TERMINATE: + raise OSError, "Killed" + try: + rsock.bind((local_addr, local_port)) + break + except socket.error: + # wait a while, retry + print >>sys.stderr, "%s: Could not bind. Retrying in a sec..." % (time.strftime('%c'),) + time.sleep(min(30.0,retrydelay)) + retrydelay *= 1.1 + else: + rsock.bind((local_addr, local_port)) + print >>sys.stderr, "Listening UDP at: %s:%d" % (local_addr, local_port) + print >>sys.stderr, "Connecting UDP to: %s:%d" % (peer_addr, peer_port) + rsock.connect((peer_addr, peer_port)) + return rsock +def udp_handshake(TERMINATE, rsock): + endme = False + def keepalive(): + while not endme and not TERMINATE: + try: + rsock.send('') + except: + pass + time.sleep(1) + try: + rsock.send('') + except: + pass + keepalive_thread = threading.Thread(target=keepalive) + keepalive_thread.start() + retrydelay = 1.0 + for i in xrange(30): + if TERMINATE: + raise OSError, "Killed" + try: + heartbeat = rsock.recv(10) + break + except: + time.sleep(min(30.0,retrydelay)) + retrydelay *= 1.1 + else: + heartbeat = rsock.recv(10) + endme = True + keepalive_thread.join() + +def udp_establish(TERMINATE, local_addr, local_port, peer_addr, peer_port): + rsock = udp_connect(TERMINATE, local_addr, local_port, peer_addr, + peer_port) + udp_handshake(TERMINATE, rsock) + return rsock + +def tcp_connect(TERMINATE, stop, rsock, peer_addr, peer_port): + sock = None + retrydelay = 1.0 + # The peer has a firewall that prevents a response to the connect, we + # will be forever blocked in the connect, so we put a reasonable timeout. + sock.settimeout(10) + # We wait for + for i in xrange(30): + if stop: + break + if TERMINATE: + raise OSError, "Killed" + try: + rsock.connect((peer_addr, peer_port)) + sock = rsock + break + except socket.error: + # wait a while, retry + print >>sys.stderr, "%s: Could not connect. Retrying in a sec..." % (time.strftime('%c'),) + time.sleep(min(30.0,retrydelay)) + retrydelay *= 1.1 + else: + rsock.connect((peer_addr, peer_port)) + sock = rsock + if sock: + sock.settimeout(0) + return sock + +def tcp_listen(TERMINATE, stop, lsock, local_addr, local_port): + sock = None + retrydelay = 1.0 + # We try to bind to the local virtual interface. + # It might not exist yet so we wait in a loop. + for i in xrange(30): + if stop: + break + if TERMINATE: + raise OSError, "Killed" + try: + lsock.bind((local_addr, local_port)) + break + except socket.error: + # wait a while, retry + print >>sys.stderr, "%s: Could not bind. Retrying in a sec..." % (time.strftime('%c'),) + time.sleep(min(30.0,retrydelay)) + retrydelay *= 1.1 + else: + lsock.bind((local_addr, local_port)) + + # Now we wait until the other side connects. + # The other side might not be ready yet, so we also wait in a loop for timeouts. + timeout = 1 + lsock.listen(1) + for i in xrange(30): + if TERMINATE: + raise OSError, "Killed" + rlist, wlist, xlist = select.select([lsock], [], [], timeout) + if stop: + break + if lsock in rlist: + sock,raddr = lsock.accept() + break + timeout += 5 + return sock + +def tcp_handshake(TERMINATE, sock, listen, dice): + # we are going to use a barrier algorithm to decide wich side listen. + # each side will "roll a dice" and send the resulting value to the other + # side. + result = None + sock.settimeout(5) + for i in xrange(100): + if TERMINATE: + raise OSError, "Killed" + try: + hand = dice.read() + sock.send(hand) + peer_hand = sock.recv(1) + if hand < peer_hand: + if listen: + result = sock + break + elif hand > peer_hand: + if not listen: + result = sock + break + else: + dice.release() + dice.throw() + except socket.error: + dice.release() + break + if result: + sock.settimeout(0) + return result + +def tcp_establish(TERMINATE, local_addr, local_port, peer_addr, peer_port): + def listen(stop, result, lsock, dice): + lsock = tcp_listen(TERMINATE, stop, lsock, local_addr, local_port) + if lsock: + lsock = tcp_handshake(TERMINATE, lsock, True, dice) + if lsock: + stop[0] = True + result[0] = lsock + + def connect(stop, result, rsock, dice): + rsock = tcp_connect(TERMINATE, stop, rsock, peer_addr, peer_port) + if rsock: + rsock = tcp_handshake(TERMINATE, rsock, False, dice) + if sock: + stop[0] = True + result[0] = rsock + + dice = Dice() + dice.throw() + stop = [] + result = [] + lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + connect_thread = threading.Thread(target=connect, args=(stop, result, rsock, dice)) + listen_thread = threading.Thread(target=listen, args=(stop, result, lsock, dice)) + connect_thread.start() + listen_thread.start() + connect_thread.join() + listen_thread.join() + if not result: + raise OSError, "Error: tcp_establish could not establish connection." + sock = result[0] + if sock == lsock: + rsock.close() + else: + lsock.close() + return sock + +class Dice(object): + def __init__(self): + self._condition = threading.Condition(threading.Lock()) + self._readers = 0 + self._value = None + + def read(self): + self._condition.acquire() + try: + self._readers += 1 + finally: + self._condition.release() + return self._value + + def release(self): + self._condition.acquire() + try: + if self._readers > 0: + self._readers -= 1 + if self._readers == 0: + self._condition.notifyAll() + finally: + self._condition.release() + + def throw(self): + self._condition.acquire() + try: + while self._readers > 0: + self._condition.wait() + self._value = random.randint(1, 6) + finally: + self._condition.release()