import threading
import errno
import fcntl
+import random
import traceback
import functools
import collections
return crypter.encrypt(packet)
def decrypt(packet, crypter, ord=ord):
- # decrypt
- packet = crypter.decrypt(packet)
-
- # un-pad
- padding = ord(packet[-1])
- if not (0 < padding <= crypter.block_size):
- # wrong padding
- raise RuntimeError, "Truncated packet"
- packet = packet[:-padding]
+ if packet:
+ # decrypt
+ packet = crypter.decrypt(packet)
+
+ # un-pad
+ padding = ord(packet[-1])
+ if not (0 < padding <= crypter.block_size):
+ # wrong padding
+ raise RuntimeError, "Truncated packet"
+ packet = packet[:-padding]
return packet
return False
def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr=sys.stderr, reconnect=None, rwrite=None, rread=None, tunqueue=1000, tunkqueue=1000,
- cipher='AES',
- len=len, max=max, OSError=OSError, select=select.select, selecterror=select.error, os=os, socket=socket,
+ cipher='AES', accept_local=None, accept_remote=None, slowlocal=True, queueclass=None, bwlimit=None,
+ len=len, max=max, min=min, OSError=OSError, select=select.select, selecterror=select.error, os=os, socket=socket,
retrycodes=(os.errno.EWOULDBLOCK, os.errno.EAGAIN, os.errno.EINTR) ):
crypto_mode = False
+ crypter = None
+
try:
- if cipher_key:
+ if cipher_key and cipher:
import Crypto.Cipher
import hashlib
__import__('Crypto.Cipher.'+cipher)
if rread is None:
def rread(remote, maxlen, os_read=os.read):
return os_read(remote_fd, maxlen)
-
+
rnonblock = nonblock(remote)
tnonblock = nonblock(tun)
twrite = os.write
tread = os.read
- # Limited frame parsing, to preserve packet boundaries.
- # Which is needed, since /dev/net/tun is unbuffered
+ encrypt_ = encrypt
+ decrypt_ = decrypt
+ xrange_ = xrange
+
+ if accept_local is not None:
+ def tread(fd, maxlen, _tread=tread, accept=accept_local):
+ packet = _tread(fd, maxlen)
+ if accept(packet, 0):
+ return packet
+ else:
+ return None
+
+ if accept_remote is not None:
+ if crypto_mode:
+ def decrypt_(packet, crypter, decrypt_=decrypt_, accept=accept_remote):
+ packet = decrypt_(packet, crypter)
+ if accept(packet, 1):
+ return packet
+ else:
+ return None
+ else:
+ def rread(fd, maxlen, _rread=rread, accept=accept_remote):
+ packet = _rread(fd, maxlen)
+ if accept(packet, 1):
+ return packet
+ else:
+ return None
+
maxbkbuf = maxfwbuf = max(10,tunqueue-tunkqueue)
tunhurry = max(0,maxbkbuf/2)
- fwbuf = collections.deque()
- bkbuf = collections.deque()
+
+ if queueclass is None:
+ queueclass = collections.deque
+ maxbatch = 2000
+ maxtbatch = 50
+ else:
+ maxfwbuf = maxbkbuf = 2000000000
+ maxbatch = 50
+ maxtbatch = 30
+ tunhurry = 30
+
+ fwbuf = queueclass()
+ bkbuf = queueclass()
nfwbuf = 0
nbkbuf = 0
- if ether_mode:
+ if ether_mode or udp:
packetReady = bool
- pullPacket = collections.deque.popleft
- reschedule = collections.deque.appendleft
+ pullPacket = queueclass.popleft
+ reschedule = queueclass.appendleft
else:
packetReady = _packetReady
pullPacket = _pullPacket
- reschedule = collections.deque.appendleft
+ reschedule = queueclass.appendleft
tunfd = tun.fileno()
os_read = os.read
os_write = os.write
- encrypt_ = encrypt
- decrypt_ = decrypt
+
+ tget = time.time
+ maxbwfree = bwfree = 1500 * tunqueue
+ lastbwtime = tget()
+
+ remoteok = True
+
+
while not TERMINATE:
wset = []
if packetReady(bkbuf):
wset.append(tun)
- if packetReady(fwbuf):
+ if remoteok and packetReady(fwbuf) and (not bwlimit or bwfree > 0):
wset.append(remote)
rset = []
if len(fwbuf) < maxfwbuf:
rset.append(tun)
- if len(bkbuf) < maxbkbuf:
+ if remoteok and len(bkbuf) < maxbkbuf:
rset.append(remote)
+ if remoteok:
+ eset = (tun,remote)
+ else:
+ eset = (tun,)
+
try:
- rdrdy, wrdy, errs = select(rset,wset,(tun,remote),1)
+ rdrdy, wrdy, errs = select(rset,wset,eset,1)
except selecterror, e:
if e.args[0] == errno.EINTR:
# just retry
continue
+ else:
+ raise
# check for errors
if errs:
remote_fd = remote.fileno()
elif udp and remote in errs and tun not in errs:
# In UDP mode, those are always transient errors
- pass
+ # Usually, an error will imply a read-ready socket
+ # that will raise an "Connection refused" error, so
+ # disable read-readiness just for now, and retry
+ # the select
+ remoteok = False
+ continue
else:
break
+ else:
+ remoteok = True
# check to see if we can write
#rr = wr = rt = wt = 0
if remote in wrdy:
+ sent = 0
try:
try:
- while 1:
+ for x in xrange(maxbatch):
packet = pullPacket(fwbuf)
if crypto_mode:
packet = encrypt_(packet, crypter)
- rwrite(remote, packet)
+ sent += rwrite(remote, packet)
#wr += 1
if not rnonblock or not packetReady(fwbuf):
# in UDP mode, we ignore errors - packet loss man...
raise
#traceback.print_exc(file=sys.stderr)
+
+ if bwlimit:
+ bwfree -= sent
if tun in wrdy:
try:
- for x in xrange(50):
+ for x in xrange(maxtbatch):
packet = pullPacket(bkbuf)
twrite(tunfd, packet)
#wt += 1
# behind. TUN devices discard packets if their queue is full (tunkqueue), but they
# don't block either (they're always ready to write), so if we flood the device
# we'll have high packet loss.
- if not tnonblock or len(bkbuf) < tunhurry or not packetReady(bkbuf):
+ if not tnonblock or (slowlocal and len(bkbuf) < tunhurry) or not packetReady(bkbuf):
break
else:
- # Give some time for the kernel to process the packets
- time.sleep(0)
+ if slowlocal:
+ # Give some time for the kernel to process the packets
+ time.sleep(0)
except OSError,e:
# This except handles the entire While block on PURPOSE
# as an optimization (setting a try/except block is expensive)
# check incoming data packets
if tun in rdrdy:
try:
- while 1:
+ for x in xrange(maxbatch):
packet = tread(tunfd,2000) # tun.read blocks until it gets 2k!
+ if not packet:
+ continue
#rt += 1
fwbuf.append(packet)
if remote in rdrdy:
try:
try:
- while 1:
+ for x in xrange(maxbatch):
packet = rread(remote,2000)
+
#rr += 1
if crypto_mode:
packet = decrypt_(packet, crypter)
+ if not packet:
+ continue
+ elif not packet:
+ if not udp and packet == "":
+ # Connection broken, try to reconnect (or just die)
+ raise RuntimeError, "Connection broken"
+ else:
+ continue
+
bkbuf.append(packet)
if not rnonblock or len(bkbuf) >= maxbkbuf:
elif not udp:
# in UDP mode, we ignore errors - packet loss man...
raise
- #traceback.print_exc(file=sys.stderr)
+ traceback.print_exc(file=sys.stderr)
+
+ if bwlimit:
+ tnow = tget()
+ delta = tnow - lastbwtime
+ if delta > 0.001:
+ delta = int(bwlimit * delta)
+ if delta > 0:
+ bwfree = min(bwfree+delta, maxbwfree)
+ lastbwtime = tnow
#print >>sys.stderr, "rr:%d\twr:%d\trt:%d\twt:%d" % (rr,wr,rt,wt)
+def udp_connect(TERMINATE, local_addr, local_port, peer_addr, peer_port):
+ rsock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
+ retrydelay = 1.0
+ for i in xrange(30):
+ # TERMINATE is a array. An item can be added to TERMINATE, from
+ # outside this function to force termination of the loop
+ if TERMINATE:
+ raise OSError, "Killed"
+ try:
+ rsock.bind((local_addr, local_port))
+ break
+ except socket.error:
+ # wait a while, retry
+ print >>sys.stderr, "%s: Could not bind. Retrying in a sec..." % (time.strftime('%c'),)
+ time.sleep(min(30.0,retrydelay))
+ retrydelay *= 1.1
+ else:
+ rsock.bind((local_addr, local_port))
+ print >>sys.stderr, "Listening UDP at: %s:%d" % (local_addr, local_port)
+ print >>sys.stderr, "Connecting UDP to: %s:%d" % (peer_addr, peer_port)
+ rsock.connect((peer_addr, peer_port))
+ return rsock
+def udp_handshake(TERMINATE, rsock):
+ endme = False
+ def keepalive():
+ while not endme and not TERMINATE:
+ try:
+ rsock.send('')
+ except:
+ pass
+ time.sleep(1)
+ try:
+ rsock.send('')
+ except:
+ pass
+ keepalive_thread = threading.Thread(target=keepalive)
+ keepalive_thread.start()
+ retrydelay = 1.0
+ for i in xrange(30):
+ if TERMINATE:
+ raise OSError, "Killed"
+ try:
+ heartbeat = rsock.recv(10)
+ break
+ except:
+ time.sleep(min(30.0,retrydelay))
+ retrydelay *= 1.1
+ else:
+ heartbeat = rsock.recv(10)
+ endme = True
+ keepalive_thread.join()
+
+def udp_establish(TERMINATE, local_addr, local_port, peer_addr, peer_port):
+ rsock = udp_connect(TERMINATE, local_addr, local_port, peer_addr,
+ peer_port)
+ udp_handshake(TERMINATE, rsock)
+ return rsock
+
+def tcp_connect(TERMINATE, stop, rsock, peer_addr, peer_port):
+ sock = None
+ retrydelay = 1.0
+ # The peer has a firewall that prevents a response to the connect, we
+ # will be forever blocked in the connect, so we put a reasonable timeout.
+ sock.settimeout(10)
+ # We wait for
+ for i in xrange(30):
+ if stop:
+ break
+ if TERMINATE:
+ raise OSError, "Killed"
+ try:
+ rsock.connect((peer_addr, peer_port))
+ sock = rsock
+ break
+ except socket.error:
+ # wait a while, retry
+ print >>sys.stderr, "%s: Could not connect. Retrying in a sec..." % (time.strftime('%c'),)
+ time.sleep(min(30.0,retrydelay))
+ retrydelay *= 1.1
+ else:
+ rsock.connect((peer_addr, peer_port))
+ sock = rsock
+ if sock:
+ sock.settimeout(0)
+ return sock
+
+def tcp_listen(TERMINATE, stop, lsock, local_addr, local_port):
+ sock = None
+ retrydelay = 1.0
+ # We try to bind to the local virtual interface.
+ # It might not exist yet so we wait in a loop.
+ for i in xrange(30):
+ if stop:
+ break
+ if TERMINATE:
+ raise OSError, "Killed"
+ try:
+ lsock.bind((local_addr, local_port))
+ break
+ except socket.error:
+ # wait a while, retry
+ print >>sys.stderr, "%s: Could not bind. Retrying in a sec..." % (time.strftime('%c'),)
+ time.sleep(min(30.0,retrydelay))
+ retrydelay *= 1.1
+ else:
+ lsock.bind((local_addr, local_port))
+
+ # Now we wait until the other side connects.
+ # The other side might not be ready yet, so we also wait in a loop for timeouts.
+ timeout = 1
+ lsock.listen(1)
+ for i in xrange(30):
+ if TERMINATE:
+ raise OSError, "Killed"
+ rlist, wlist, xlist = select.select([lsock], [], [], timeout)
+ if stop:
+ break
+ if lsock in rlist:
+ sock,raddr = lsock.accept()
+ break
+ timeout += 5
+ return sock
+
+def tcp_handshake(TERMINATE, sock, listen, dice):
+ # we are going to use a barrier algorithm to decide wich side listen.
+ # each side will "roll a dice" and send the resulting value to the other
+ # side.
+ result = None
+ sock.settimeout(5)
+ for i in xrange(100):
+ if TERMINATE:
+ raise OSError, "Killed"
+ try:
+ hand = dice.read()
+ sock.send(hand)
+ peer_hand = sock.recv(1)
+ if hand < peer_hand:
+ if listen:
+ result = sock
+ break
+ elif hand > peer_hand:
+ if not listen:
+ result = sock
+ break
+ else:
+ dice.release()
+ dice.throw()
+ except socket.error:
+ dice.release()
+ break
+ if result:
+ sock.settimeout(0)
+ return result
+
+def tcp_establish(TERMINATE, local_addr, local_port, peer_addr, peer_port):
+ def listen(stop, result, lsock, dice):
+ lsock = tcp_listen(TERMINATE, stop, lsock, local_addr, local_port)
+ if lsock:
+ lsock = tcp_handshake(TERMINATE, lsock, True, dice)
+ if lsock:
+ stop[0] = True
+ result[0] = lsock
+
+ def connect(stop, result, rsock, dice):
+ rsock = tcp_connect(TERMINATE, stop, rsock, peer_addr, peer_port)
+ if rsock:
+ rsock = tcp_handshake(TERMINATE, rsock, False, dice)
+ if sock:
+ stop[0] = True
+ result[0] = rsock
+
+ dice = Dice()
+ dice.throw()
+ stop = []
+ result = []
+ lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+ rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+ connect_thread = threading.Thread(target=connect, args=(stop, result, rsock, dice))
+ listen_thread = threading.Thread(target=listen, args=(stop, result, lsock, dice))
+ connect_thread.start()
+ listen_thread.start()
+ connect_thread.join()
+ listen_thread.join()
+ if not result:
+ raise OSError, "Error: tcp_establish could not establish connection."
+ sock = result[0]
+ if sock == lsock:
+ rsock.close()
+ else:
+ lsock.close()
+ return sock
+
+class Dice(object):
+ def __init__(self):
+ self._condition = threading.Condition(threading.Lock())
+ self._readers = 0
+ self._value = None
+
+ def read(self):
+ self._condition.acquire()
+ try:
+ self._readers += 1
+ finally:
+ self._condition.release()
+ return self._value
+
+ def release(self):
+ self._condition.acquire()
+ try:
+ if self._readers > 0:
+ self._readers -= 1
+ if self._readers == 0:
+ self._condition.notifyAll()
+ finally:
+ self._condition.release()
+
+ def throw(self):
+ self._condition.acquire()
+ try:
+ while self._readers > 0:
+ self._condition.wait()
+ self._value = random.randint(1, 6)
+ finally:
+ self._condition.release()