Merge TCP handshake stuff
[nepi.git] / src / nepi / util / tunchannel.py
index d464b8c..61cb2e3 100644 (file)
@@ -6,6 +6,7 @@ import socket
 import threading
 import errno
 import fcntl
+import random
 import traceback
 import functools
 import collections
@@ -197,6 +198,7 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
         retrycodes=(os.errno.EWOULDBLOCK, os.errno.EAGAIN, os.errno.EINTR) ):
     crypto_mode = False
     crypter = None
+
     try:
         if cipher_key and cipher:
             import Crypto.Cipher
@@ -233,7 +235,7 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
         if rread is None:
             def rread(remote, maxlen, os_read=os.read):
                 return os_read(remote_fd, maxlen)
-    
     rnonblock = nonblock(remote)
     tnonblock = nonblock(tun)
     
@@ -326,6 +328,7 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
     
     remoteok = True
     
+    
     while not TERMINATE:
         wset = []
         if packetReady(bkbuf):
@@ -350,6 +353,8 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
             if e.args[0] == errno.EINTR:
                 # just retry
                 continue
+            else:
+                raise
 
         # check for errors
         if errs:
@@ -460,6 +465,7 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
                 try:
                     for x in xrange(maxbatch):
                         packet = rread(remote,2000)
+                        
                         #rr += 1
                         
                         if crypto_mode:
@@ -506,6 +512,28 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
         
         #print >>sys.stderr, "rr:%d\twr:%d\trt:%d\twt:%d" % (rr,wr,rt,wt)
 
+def udp_connect(TERMINATE, local_addr, local_port, peer_addr, peer_port):
+    rsock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
+    retrydelay = 1.0
+    for i in xrange(30):
+        # TERMINATE is a array. An item can be added to TERMINATE, from
+        # outside this function to force termination of the loop
+        if TERMINATE:
+            raise OSError, "Killed"
+        try:
+            rsock.bind((local_addr, local_port))
+            break
+        except socket.error:
+            # wait a while, retry
+            print >>sys.stderr, "%s: Could not bind. Retrying in a sec..." % (time.strftime('%c'),)
+            time.sleep(min(30.0,retrydelay))
+            retrydelay *= 1.1
+    else:
+        rsock.bind((local_addr, local_port))
+    print >>sys.stderr, "Listening UDP at: %s:%d" % (local_addr, local_port)
+    print >>sys.stderr, "Connecting UDP to: %s:%d" % (peer_addr, peer_port)
+    rsock.connect((peer_addr, peer_port))
+    return rsock
 
 def udp_handshake(TERMINATE, rsock):
     endme = False
@@ -537,3 +565,154 @@ def udp_handshake(TERMINATE, rsock):
     endme = True
     keepalive_thread.join()
 
+def udp_establish(TERMINATE, local_addr, local_port, peer_addr, peer_port):
+    rsock = udp_connect(TERMINATE, local_addr, local_port, peer_addr,
+            peer_port)
+    udp_handshake(TERMINATE, rsock)
+    return rsock 
+
+def tcp_connect(TERMINATE, stop, rsock, peer_addr, peer_port):
+    sock = None
+    retrydelay = 1.0
+    # The peer has a firewall that prevents a response to the connect, we 
+    # will be forever blocked in the connect, so we put a reasonable timeout.
+    rsock.settimeout(10) 
+    # We wait for 
+    for i in xrange(30):
+        if stop:
+            break
+        if TERMINATE:
+            raise OSError, "Killed"
+        try:
+            rsock.connect((peer_addr, peer_port))
+            sock = rsock
+            break
+        except socket.error:
+            # wait a while, retry
+            print >>sys.stderr, "%s: Could not connect. Retrying in a sec..." % (time.strftime('%c'),)
+            time.sleep(min(30.0,retrydelay))
+            retrydelay *= 1.1
+    else:
+        rsock.connect((peer_addr, peer_port))
+        sock = rsock
+    if sock:
+        print >>sys.stderr, "tcp_connect: TCP sock connected to remote %s:%s" % (peer_addr, peer_port)
+        sock.settimeout(0) 
+    return sock
+
+def tcp_listen(TERMINATE, stop, lsock, local_addr, local_port):
+    sock = None
+    retrydelay = 1.0
+    # We try to bind to the local virtual interface. 
+    # It might not exist yet so we wait in a loop.
+    for i in xrange(30):
+        if stop:
+            break
+        if TERMINATE:
+            raise OSError, "Killed"
+        try:
+            lsock.bind((local_addr, local_port))
+            break
+        except socket.error:
+            # wait a while, retry
+            print >>sys.stderr, "%s: Could not bind. Retrying in a sec..." % (time.strftime('%c'),)
+            time.sleep(min(30.0,retrydelay))
+            retrydelay *= 1.1
+    else:
+        lsock.bind((local_addr, local_port))
+
+    print >>sys.stderr, "tcp_listen: TCP sock listening in local sock %s:%s" % (local_addr, local_port)
+    # Now we wait until the other side connects. 
+    # The other side might not be ready yet, so we also wait in a loop for timeouts.
+    timeout = 1
+    lsock.listen(1)
+    for i in xrange(30):
+        if TERMINATE:
+            raise OSError, "Killed"
+        rlist, wlist, xlist = select.select([lsock], [], [], timeout)
+        if stop:
+            break
+        if lsock in rlist:
+            sock,raddr = lsock.accept()
+            print >>sys.stderr, "tcp_listen: TCP connection accepted in local sock %s:%s" % (local_addr, local_port)
+            break
+        timeout += 5
+    return sock
+
+def tcp_handshake(rsock, listen, hand):
+    # we are going to use a barrier algorithm to decide wich side listen.
+    # each side will "roll a dice" and send the resulting value to the other 
+    # side. 
+    win = False
+    rsock.settimeout(10)
+    try:
+        rsock.send(hand)
+        peer_hand = rsock.recv(1)
+        print >>sys.stderr, "tcp_handshake: hand %s, peer_hand %s" % (hand, peer_hand)
+        if hand < peer_hand:
+            if listen:
+                win = True
+        elif hand > peer_hand:
+            if not listen:
+                win = True
+    except socket.timeout:
+        pass
+    rsock.settimeout(0)
+    return win
+
+def tcp_establish(TERMINATE, local_addr, local_port, peer_addr, peer_port):
+    def listen(stop, hand, lsock, lresult):
+        win = False
+        rsock = tcp_listen(TERMINATE, stop, lsock, local_addr, local_port)
+        if rsock:
+            win = tcp_handshake(rsock, True, hand)
+            stop.append(True)
+        lresult.append((win, rsock))
+
+    def connect(stop, hand, rsock, rresult):
+        win = False
+        rsock = tcp_connect(TERMINATE, stop, rsock, peer_addr, peer_port)
+        if rsock:
+            win = tcp_handshake(rsock, False, hand)
+            stop.append(True)
+        rresult.append((win, rsock))
+  
+    end = False
+    sock = None
+    while not end:
+        if TERMINATE:
+            raise OSError, "Killed"
+        hand = str(random.randint(1, 6))
+        stop = []
+        lresult = []
+        rresult = []
+        lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+        rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+        listen_thread = threading.Thread(target=listen, args=(stop, hand, lsock, lresult))
+        connect_thread = threading.Thread(target=connect, args=(stop, hand, rsock, rresult))
+        connect_thread.start()
+        listen_thread.start()
+        connect_thread.join()
+        listen_thread.join()
+        (lwin, lrsock) = lresult[0]
+        (rwin, rrsock) = rresult[0]
+        if not lrsock or not rrsock:
+            if not lrsock:
+                sock = rrsock
+            if not rrsock:
+                sock = lrsock
+            end = True
+        # both socket are connected
+        else:
+           if lwin:
+                sock = lrsock
+                end = True
+           elif rwin: 
+                sock = rrsock
+                end = True
+
+    if not sock:
+        raise OSError, "Error: tcp_establish could not establish connection."
+    return sock
+
+