Merge TCP handshake stuff
[nepi.git] / src / nepi / util / tunchannel.py
index fc627df..61cb2e3 100644 (file)
@@ -119,15 +119,15 @@ def _pullPacket(buf, ether_mode=False, len=len):
             rv = buf.popleft()
         return rv
 
-def etherStrip(buf):
+def etherStrip(buf, buffer=buffer, len=len):
     if len(buf) < 14:
         return ""
     if buf[12:14] == '\x08\x10' and buf[16:18] == '\x08\x00':
         # tagged ethernet frame
-        return buf[18:]
+        return buffer(buf, 18)
     elif buf[12:14] == '\x08\x00':
         # untagged ethernet frame
-        return buf[14:]
+        return buffer(buf, 14)
     else:
         return ""
 
@@ -576,7 +576,7 @@ def tcp_connect(TERMINATE, stop, rsock, peer_addr, peer_port):
     retrydelay = 1.0
     # The peer has a firewall that prevents a response to the connect, we 
     # will be forever blocked in the connect, so we put a reasonable timeout.
-    sock.settimeout(10) 
+    rsock.settimeout(10) 
     # We wait for 
     for i in xrange(30):
         if stop:
@@ -596,6 +596,7 @@ def tcp_connect(TERMINATE, stop, rsock, peer_addr, peer_port):
         rsock.connect((peer_addr, peer_port))
         sock = rsock
     if sock:
+        print >>sys.stderr, "tcp_connect: TCP sock connected to remote %s:%s" % (peer_addr, peer_port)
         sock.settimeout(0) 
     return sock
 
@@ -620,6 +621,7 @@ def tcp_listen(TERMINATE, stop, lsock, local_addr, local_port):
     else:
         lsock.bind((local_addr, local_port))
 
+    print >>sys.stderr, "tcp_listen: TCP sock listening in local sock %s:%s" % (local_addr, local_port)
     # Now we wait until the other side connects. 
     # The other side might not be ready yet, so we also wait in a loop for timeouts.
     timeout = 1
@@ -632,109 +634,85 @@ def tcp_listen(TERMINATE, stop, lsock, local_addr, local_port):
             break
         if lsock in rlist:
             sock,raddr = lsock.accept()
+            print >>sys.stderr, "tcp_listen: TCP connection accepted in local sock %s:%s" % (local_addr, local_port)
             break
         timeout += 5
     return sock
 
-def tcp_handshake(TERMINATE, sock, listen, dice):
+def tcp_handshake(rsock, listen, hand):
     # we are going to use a barrier algorithm to decide wich side listen.
     # each side will "roll a dice" and send the resulting value to the other 
     # side. 
-    result = None
-    sock.settimeout(5)
-    for i in xrange(100):
-        if TERMINATE:
-            raise OSError, "Killed"
-        try:
-            hand = dice.read()
-            sock.send(hand)
-            peer_hand = sock.recv(1)
-            if hand < peer_hand:
-                if listen:
-                    result = sock
-                break   
-            elif hand > peer_hand:
-                if not listen:
-                    result = sock
-                break
-            else:
-                dice.release()
-                dice.throw()
-        except socket.error:
-            dice.release()
-            break
-    if result:
-        sock.settimeout(0)
-    return result
+    win = False
+    rsock.settimeout(10)
+    try:
+        rsock.send(hand)
+        peer_hand = rsock.recv(1)
+        print >>sys.stderr, "tcp_handshake: hand %s, peer_hand %s" % (hand, peer_hand)
+        if hand < peer_hand:
+            if listen:
+                win = True
+        elif hand > peer_hand:
+            if not listen:
+                win = True
+    except socket.timeout:
+        pass
+    rsock.settimeout(0)
+    return win
 
 def tcp_establish(TERMINATE, local_addr, local_port, peer_addr, peer_port):
-    def listen(stop, result, lsock, dice):
-        lsock = tcp_listen(TERMINATE, stop, lsock, local_addr, local_port)
-        if lsock:
-            lsock = tcp_handshake(TERMINATE, lsock, True, dice)
-            if lsock:
-                stop[0] = True
-                result[0] = lsock
-
-    def connect(stop, result, rsock, dice):
+    def listen(stop, hand, lsock, lresult):
+        win = False
+        rsock = tcp_listen(TERMINATE, stop, lsock, local_addr, local_port)
+        if rsock:
+            win = tcp_handshake(rsock, True, hand)
+            stop.append(True)
+        lresult.append((win, rsock))
+
+    def connect(stop, hand, rsock, rresult):
+        win = False
         rsock = tcp_connect(TERMINATE, stop, rsock, peer_addr, peer_port)
         if rsock:
-            rsock = tcp_handshake(TERMINATE, rsock, False, dice)
-            if sock:
-                stop[0] = True
-                result[0] = rsock
-    
-    dice = Dice()
-    dice.throw()
-    stop = []
-    result = []
-    lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
-    rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
-    connect_thread = threading.Thread(target=connect, args=(stop, result, rsock, dice))
-    listen_thread = threading.Thread(target=listen, args=(stop, result, lsock, dice))
-    connect_thread.start()
-    listen_thread.start()
-    connect_thread.join()
-    listen_thread.join()
-    if not result:
+            win = tcp_handshake(rsock, False, hand)
+            stop.append(True)
+        rresult.append((win, rsock))
+  
+    end = False
+    sock = None
+    while not end:
+        if TERMINATE:
+            raise OSError, "Killed"
+        hand = str(random.randint(1, 6))
+        stop = []
+        lresult = []
+        rresult = []
+        lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+        rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+        listen_thread = threading.Thread(target=listen, args=(stop, hand, lsock, lresult))
+        connect_thread = threading.Thread(target=connect, args=(stop, hand, rsock, rresult))
+        connect_thread.start()
+        listen_thread.start()
+        connect_thread.join()
+        listen_thread.join()
+        (lwin, lrsock) = lresult[0]
+        (rwin, rrsock) = rresult[0]
+        if not lrsock or not rrsock:
+            if not lrsock:
+                sock = rrsock
+            if not rrsock:
+                sock = lrsock
+            end = True
+        # both socket are connected
+        else:
+           if lwin:
+                sock = lrsock
+                end = True
+           elif rwin: 
+                sock = rrsock
+                end = True
+
+    if not sock:
         raise OSError, "Error: tcp_establish could not establish connection."
-    sock = result[0]
-    if sock == lsock:
-        rsock.close()
-    else:
-        lsock.close()
     return sock
 
-class Dice(object):
-    def __init__(self):
-        self._condition = threading.Condition(threading.Lock())
-        self._readers = 0
-        self._value = None
-
-    def read(self):
-        self._condition.acquire()
-        try:
-            self._readers += 1
-        finally:
-            self._condition.release()
-        return self._value
-
-    def release(self):
-        self._condition.acquire()
-        try:
-            if self._readers > 0:
-                self._readers -= 1
-            if self._readers == 0:
-                self._condition.notifyAll()
-        finally:
-            self._condition.release()
-
-    def throw(self):
-        self._condition.acquire()
-        try:
-            while self._readers > 0:
-                self._condition.wait()
-            self._value = random.randint(1, 6)
-        finally:
-            self._condition.release()