Added TCP-handshake for TunChannel and tun_connect.py
[nepi.git] / src / nepi / util / tunchannel.py
index 9e34b45..fc627df 100644 (file)
@@ -6,6 +6,7 @@ import socket
 import threading
 import errno
 import fcntl
+import random
 import traceback
 import functools
 import collections
@@ -167,15 +168,16 @@ def encrypt(packet, crypter, len=len, padmap=_padmap):
     return crypter.encrypt(packet)
 
 def decrypt(packet, crypter, ord=ord):
-    # decrypt
-    packet = crypter.decrypt(packet)
-    
-    # un-pad
-    padding = ord(packet[-1])
-    if not (0 < padding <= crypter.block_size):
-        # wrong padding
-        raise RuntimeError, "Truncated packet"
-    packet = packet[:-padding]
+    if packet:
+        # decrypt
+        packet = crypter.decrypt(packet)
+        
+        # un-pad
+        padding = ord(packet[-1])
+        if not (0 < padding <= crypter.block_size):
+            # wrong padding
+            raise RuntimeError, "Truncated packet"
+        packet = packet[:-padding]
     
     return packet
 
@@ -191,12 +193,14 @@ def nonblock(fd):
         return False
 
 def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr=sys.stderr, reconnect=None, rwrite=None, rread=None, tunqueue=1000, tunkqueue=1000,
-        cipher='AES',
-        len=len, max=max, OSError=OSError, select=select.select, selecterror=select.error, os=os, socket=socket,
+        cipher='AES', accept_local=None, accept_remote=None, slowlocal=True, queueclass=None, bwlimit=None,
+        len=len, max=max, min=min, OSError=OSError, select=select.select, selecterror=select.error, os=os, socket=socket,
         retrycodes=(os.errno.EWOULDBLOCK, os.errno.EAGAIN, os.errno.EINTR) ):
     crypto_mode = False
+    crypter = None
+
     try:
-        if cipher_key:
+        if cipher_key and cipher:
             import Crypto.Cipher
             import hashlib
             __import__('Crypto.Cipher.'+cipher)
@@ -231,7 +235,7 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
         if rread is None:
             def rread(remote, maxlen, os_read=os.read):
                 return os_read(remote_fd, maxlen)
-    
     rnonblock = nonblock(remote)
     tnonblock = nonblock(tun)
     
@@ -261,46 +265,96 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
         twrite = os.write
         tread = os.read
     
-    # Limited frame parsing, to preserve packet boundaries.
-    # Which is needed, since /dev/net/tun is unbuffered
+    encrypt_ = encrypt
+    decrypt_ = decrypt
+    xrange_ = xrange
+
+    if accept_local is not None:
+        def tread(fd, maxlen, _tread=tread, accept=accept_local):
+            packet = _tread(fd, maxlen)
+            if accept(packet, 0):
+                return packet
+            else:
+                return None
+
+    if accept_remote is not None:
+        if crypto_mode:
+            def decrypt_(packet, crypter, decrypt_=decrypt_, accept=accept_remote):
+                packet = decrypt_(packet, crypter)
+                if accept(packet, 1):
+                    return packet
+                else:
+                    return None
+        else:
+            def rread(fd, maxlen, _rread=rread, accept=accept_remote):
+                packet = _rread(fd, maxlen)
+                if accept(packet, 1):
+                    return packet
+                else:
+                    return None
+    
     maxbkbuf = maxfwbuf = max(10,tunqueue-tunkqueue)
     tunhurry = max(0,maxbkbuf/2)
-    fwbuf = collections.deque()
-    bkbuf = collections.deque()
+    
+    if queueclass is None:
+        queueclass = collections.deque
+        maxbatch = 2000
+        maxtbatch = 50
+    else:
+        maxfwbuf = maxbkbuf = 2000000000
+        maxbatch = 50
+        maxtbatch = 30
+        tunhurry = 30
+    
+    fwbuf = queueclass()
+    bkbuf = queueclass()
     nfwbuf = 0
     nbkbuf = 0
-    if ether_mode:
+    if ether_mode or udp:
         packetReady = bool
-        pullPacket = collections.deque.popleft
-        reschedule = collections.deque.appendleft
+        pullPacket = queueclass.popleft
+        reschedule = queueclass.appendleft
     else:
         packetReady = _packetReady
         pullPacket = _pullPacket
-        reschedule = collections.deque.appendleft
+        reschedule = queueclass.appendleft
     tunfd = tun.fileno()
     os_read = os.read
     os_write = os.write
-    encrypt_ = encrypt
-    decrypt_ = decrypt
+    
+    tget = time.time
+    maxbwfree = bwfree = 1500 * tunqueue
+    lastbwtime = tget()
+    
+    remoteok = True
+    
+    
     while not TERMINATE:
         wset = []
         if packetReady(bkbuf):
             wset.append(tun)
-        if packetReady(fwbuf):
+        if remoteok and packetReady(fwbuf) and (not bwlimit or bwfree > 0):
             wset.append(remote)
         
         rset = []
         if len(fwbuf) < maxfwbuf:
             rset.append(tun)
-        if len(bkbuf) < maxbkbuf:
+        if remoteok and len(bkbuf) < maxbkbuf:
             rset.append(remote)
         
+        if remoteok:
+            eset = (tun,remote)
+        else:
+            eset = (tun,)
+        
         try:
-            rdrdy, wrdy, errs = select(rset,wset,(tun,remote),1)
+            rdrdy, wrdy, errs = select(rset,wset,eset,1)
         except selecterror, e:
             if e.args[0] == errno.EINTR:
                 # just retry
                 continue
+            else:
+                raise
 
         # check for errors
         if errs:
@@ -310,22 +364,30 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
                     remote_fd = remote.fileno()
             elif udp and remote in errs and tun not in errs:
                 # In UDP mode, those are always transient errors
-                pass
+                # Usually, an error will imply a read-ready socket
+                # that will raise an "Connection refused" error, so
+                # disable read-readiness just for now, and retry
+                # the select
+                remoteok = False
+                continue
             else:
                 break
+        else:
+            remoteok = True
         
         # check to see if we can write
         #rr = wr = rt = wt = 0
         if remote in wrdy:
+            sent = 0
             try:
                 try:
-                    while 1:
+                    for x in xrange(maxbatch):
                         packet = pullPacket(fwbuf)
 
                         if crypto_mode:
                             packet = encrypt_(packet, crypter)
                         
-                        rwrite(remote, packet)
+                        sent += rwrite(remote, packet)
                         #wr += 1
                         
                         if not rnonblock or not packetReady(fwbuf):
@@ -350,9 +412,12 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
                     # in UDP mode, we ignore errors - packet loss man...
                     raise
                 #traceback.print_exc(file=sys.stderr)
+            
+            if bwlimit:
+                bwfree -= sent
         if tun in wrdy:
             try:
-                for x in xrange(50):
+                for x in xrange(maxtbatch):
                     packet = pullPacket(bkbuf)
                     twrite(tunfd, packet)
                     #wt += 1
@@ -361,11 +426,12 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
                     # behind. TUN devices discard packets if their queue is full (tunkqueue), but they
                     # don't block either (they're always ready to write), so if we flood the device 
                     # we'll have high packet loss.
-                    if not tnonblock or len(bkbuf) < tunhurry or not packetReady(bkbuf):
+                    if not tnonblock or (slowlocal and len(bkbuf) < tunhurry) or not packetReady(bkbuf):
                         break
                 else:
-                    # Give some time for the kernel to process the packets
-                    time.sleep(0)
+                    if slowlocal:
+                        # Give some time for the kernel to process the packets
+                        time.sleep(0)
             except OSError,e:
                 # This except handles the entire While block on PURPOSE
                 # as an optimization (setting a try/except block is expensive)
@@ -379,8 +445,10 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
         # check incoming data packets
         if tun in rdrdy:
             try:
-                while 1:
+                for x in xrange(maxbatch):
                     packet = tread(tunfd,2000) # tun.read blocks until it gets 2k!
+                    if not packet:
+                        continue
                     #rt += 1
                     fwbuf.append(packet)
                     
@@ -395,12 +463,22 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
         if remote in rdrdy:
             try:
                 try:
-                    while 1:
+                    for x in xrange(maxbatch):
                         packet = rread(remote,2000)
+                        
                         #rr += 1
                         
                         if crypto_mode:
                             packet = decrypt_(packet, crypter)
+                            if not packet:
+                                continue
+                        elif not packet:
+                            if not udp and packet == "":
+                                # Connection broken, try to reconnect (or just die)
+                                raise RuntimeError, "Connection broken"
+                            else:
+                                continue
+
                         bkbuf.append(packet)
                         
                         if not rnonblock or len(bkbuf) >= maxbkbuf:
@@ -421,9 +499,242 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
                 elif not udp:
                     # in UDP mode, we ignore errors - packet loss man...
                     raise
-                #traceback.print_exc(file=sys.stderr)
+                traceback.print_exc(file=sys.stderr)
+
+        if bwlimit:
+            tnow = tget()
+            delta = tnow - lastbwtime
+            if delta > 0.001:
+                delta = int(bwlimit * delta)
+                if delta > 0:
+                    bwfree = min(bwfree+delta, maxbwfree)
+                    lastbwtime = tnow
         
         #print >>sys.stderr, "rr:%d\twr:%d\trt:%d\twt:%d" % (rr,wr,rt,wt)
 
+def udp_connect(TERMINATE, local_addr, local_port, peer_addr, peer_port):
+    rsock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
+    retrydelay = 1.0
+    for i in xrange(30):
+        # TERMINATE is a array. An item can be added to TERMINATE, from
+        # outside this function to force termination of the loop
+        if TERMINATE:
+            raise OSError, "Killed"
+        try:
+            rsock.bind((local_addr, local_port))
+            break
+        except socket.error:
+            # wait a while, retry
+            print >>sys.stderr, "%s: Could not bind. Retrying in a sec..." % (time.strftime('%c'),)
+            time.sleep(min(30.0,retrydelay))
+            retrydelay *= 1.1
+    else:
+        rsock.bind((local_addr, local_port))
+    print >>sys.stderr, "Listening UDP at: %s:%d" % (local_addr, local_port)
+    print >>sys.stderr, "Connecting UDP to: %s:%d" % (peer_addr, peer_port)
+    rsock.connect((peer_addr, peer_port))
+    return rsock
 
+def udp_handshake(TERMINATE, rsock):
+    endme = False
+    def keepalive():
+        while not endme and not TERMINATE:
+            try:
+                rsock.send('')
+            except:
+                pass
+            time.sleep(1)
+        try:
+            rsock.send('')
+        except:
+            pass
+    keepalive_thread = threading.Thread(target=keepalive)
+    keepalive_thread.start()
+    retrydelay = 1.0
+    for i in xrange(30):
+        if TERMINATE:
+            raise OSError, "Killed"
+        try:
+            heartbeat = rsock.recv(10)
+            break
+        except:
+            time.sleep(min(30.0,retrydelay))
+            retrydelay *= 1.1
+    else:
+        heartbeat = rsock.recv(10)
+    endme = True
+    keepalive_thread.join()
+
+def udp_establish(TERMINATE, local_addr, local_port, peer_addr, peer_port):
+    rsock = udp_connect(TERMINATE, local_addr, local_port, peer_addr,
+            peer_port)
+    udp_handshake(TERMINATE, rsock)
+    return rsock 
+
+def tcp_connect(TERMINATE, stop, rsock, peer_addr, peer_port):
+    sock = None
+    retrydelay = 1.0
+    # The peer has a firewall that prevents a response to the connect, we 
+    # will be forever blocked in the connect, so we put a reasonable timeout.
+    sock.settimeout(10) 
+    # We wait for 
+    for i in xrange(30):
+        if stop:
+            break
+        if TERMINATE:
+            raise OSError, "Killed"
+        try:
+            rsock.connect((peer_addr, peer_port))
+            sock = rsock
+            break
+        except socket.error:
+            # wait a while, retry
+            print >>sys.stderr, "%s: Could not connect. Retrying in a sec..." % (time.strftime('%c'),)
+            time.sleep(min(30.0,retrydelay))
+            retrydelay *= 1.1
+    else:
+        rsock.connect((peer_addr, peer_port))
+        sock = rsock
+    if sock:
+        sock.settimeout(0) 
+    return sock
+
+def tcp_listen(TERMINATE, stop, lsock, local_addr, local_port):
+    sock = None
+    retrydelay = 1.0
+    # We try to bind to the local virtual interface. 
+    # It might not exist yet so we wait in a loop.
+    for i in xrange(30):
+        if stop:
+            break
+        if TERMINATE:
+            raise OSError, "Killed"
+        try:
+            lsock.bind((local_addr, local_port))
+            break
+        except socket.error:
+            # wait a while, retry
+            print >>sys.stderr, "%s: Could not bind. Retrying in a sec..." % (time.strftime('%c'),)
+            time.sleep(min(30.0,retrydelay))
+            retrydelay *= 1.1
+    else:
+        lsock.bind((local_addr, local_port))
+
+    # Now we wait until the other side connects. 
+    # The other side might not be ready yet, so we also wait in a loop for timeouts.
+    timeout = 1
+    lsock.listen(1)
+    for i in xrange(30):
+        if TERMINATE:
+            raise OSError, "Killed"
+        rlist, wlist, xlist = select.select([lsock], [], [], timeout)
+        if stop:
+            break
+        if lsock in rlist:
+            sock,raddr = lsock.accept()
+            break
+        timeout += 5
+    return sock
+
+def tcp_handshake(TERMINATE, sock, listen, dice):
+    # we are going to use a barrier algorithm to decide wich side listen.
+    # each side will "roll a dice" and send the resulting value to the other 
+    # side. 
+    result = None
+    sock.settimeout(5)
+    for i in xrange(100):
+        if TERMINATE:
+            raise OSError, "Killed"
+        try:
+            hand = dice.read()
+            sock.send(hand)
+            peer_hand = sock.recv(1)
+            if hand < peer_hand:
+                if listen:
+                    result = sock
+                break   
+            elif hand > peer_hand:
+                if not listen:
+                    result = sock
+                break
+            else:
+                dice.release()
+                dice.throw()
+        except socket.error:
+            dice.release()
+            break
+    if result:
+        sock.settimeout(0)
+    return result
+
+def tcp_establish(TERMINATE, local_addr, local_port, peer_addr, peer_port):
+    def listen(stop, result, lsock, dice):
+        lsock = tcp_listen(TERMINATE, stop, lsock, local_addr, local_port)
+        if lsock:
+            lsock = tcp_handshake(TERMINATE, lsock, True, dice)
+            if lsock:
+                stop[0] = True
+                result[0] = lsock
+
+    def connect(stop, result, rsock, dice):
+        rsock = tcp_connect(TERMINATE, stop, rsock, peer_addr, peer_port)
+        if rsock:
+            rsock = tcp_handshake(TERMINATE, rsock, False, dice)
+            if sock:
+                stop[0] = True
+                result[0] = rsock
+    
+    dice = Dice()
+    dice.throw()
+    stop = []
+    result = []
+    lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+    rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+    connect_thread = threading.Thread(target=connect, args=(stop, result, rsock, dice))
+    listen_thread = threading.Thread(target=listen, args=(stop, result, lsock, dice))
+    connect_thread.start()
+    listen_thread.start()
+    connect_thread.join()
+    listen_thread.join()
+    if not result:
+        raise OSError, "Error: tcp_establish could not establish connection."
+    sock = result[0]
+    if sock == lsock:
+        rsock.close()
+    else:
+        lsock.close()
+    return sock
+
+class Dice(object):
+    def __init__(self):
+        self._condition = threading.Condition(threading.Lock())
+        self._readers = 0
+        self._value = None
+
+    def read(self):
+        self._condition.acquire()
+        try:
+            self._readers += 1
+        finally:
+            self._condition.release()
+        return self._value
+
+    def release(self):
+        self._condition.acquire()
+        try:
+            if self._readers > 0:
+                self._readers -= 1
+            if self._readers == 0:
+                self._condition.notifyAll()
+        finally:
+            self._condition.release()
+
+    def throw(self):
+        self._condition.acquire()
+        try:
+            while self._readers > 0:
+                self._condition.wait()
+            self._value = random.randint(1, 6)
+        finally:
+            self._condition.release()