Added TCP-handshake for TunChannel and tun_connect.py
[nepi.git] / src / nepi / util / tunchannel_impl.py
index 1a38aef..c358429 100644 (file)
@@ -7,7 +7,7 @@ import select
 import weakref
 import time
 
-from tunchannel import tun_fwd, udp_handshake
+from tunchannel import tun_fwd, udp_establish, tcp_establish
 
 class TunChannel(object):
     """
@@ -30,19 +30,12 @@ class TunChannel(object):
         
         tun_key: the agreed upon encryption key.
         
-        listen: if set to True (and in TCP mode), it marks a
-            listening endpoint. Be certain that any TCP connection
-            is made between a listening and a non-listening
-            endpoint, or it won't work.
-        
         with_pi: set if the incoming packet stream (see tun_socket)
             contains PI headers - if so, they will be stripped.
         
         ethernet_mode: set if the incoming packet stream is
             composed of ethernet frames (as opposed of IP packets).
         
-        udp: set to use UDP datagrams instead of TCP connections.
-        
         tun_socket: a socket or file object that can be read
             from and written to. Packets will be read when available,
             remote packets will be forwarded as writes.
@@ -57,7 +50,6 @@ class TunChannel(object):
     
     def __init__(self):
         # Some operational attributes
-        self.listen = False
         self.ethernet_mode = True
         self.with_pi = False
         
@@ -85,7 +77,7 @@ class TunChannel(object):
         self._exc = [] # exception store, to relay exceptions from the forwarder thread
         self._connected = threading.Event()
         self._forwarder_thread = None
-        
+       
         # trace to stderr
         self.stderr = sys.stderr
         
@@ -109,23 +101,20 @@ class TunChannel(object):
             self.tun_cipher,
         )
 
-    def Prepare(self):
-        if self.tun_proto:
-            udp = self.tun_proto == "udp"
-            if not udp and self.listen and not self._forwarder_thread:
-                if self.listen or (self.peer_addr and self.peer_port and self.peer_proto):
-                    self._launch()
-    
-    def Setup(self):
+    def launch(self):
+        # self.tun_proto is only set if the channel is connected
+        # launch has to be a no-op in unconnected channels because
+        # it is called at configuration time, which for cross connections
+        # happens before connection.
         if self.tun_proto:
             if not self._forwarder_thread:
                 self._launch()
     
-    def Cleanup(self):
+    def cleanup(self):
         if self._forwarder_thread:
-            self.Kill()
+            self.kill()
 
-    def Wait(self):
+    def wait(self):
         if self._forwarder_thread:
             self._connected.wait()
             for exc in self._exc:
@@ -133,7 +122,7 @@ class TunChannel(object):
                 eTyp, eVal, eLoc = exc
                 raise eTyp, eVal, eLoc
 
-    def Kill(self):    
+    def kill(self):    
         if self._forwarder_thread:
             if not self._terminate:
                 self._terminate.append(None)
@@ -186,66 +175,31 @@ class TunChannel(object):
         if local_cipher != peer_cipher:
             raise RuntimeError, "Peering cipher mismatch: %s != %s" % (local_cipher, peer_cipher)
         
-        udp = local_proto == 'udp'
-        listen = self.listen
-
-        if (udp or not listen) and (not peer_port or not peer_addr):
+        if not peer_port or not peer_addr:
             raise RuntimeError, "Misconfigured peer for: %s" % (self,)
 
-        if (udp or listen) and (not local_port or not local_addr):
+        if not local_port or not local_addr:
             raise RuntimeError, "Misconfigured TUN: %s" % (self,)
         
         TERMINATE = self._terminate
         cipher_key = self.tun_key
         tun = self.tun_socket
-        
+        udp = local_proto == 'udp'
+
         if not tun:
             raise RuntimeError, "Unconnected TUN channel %s" % (self,)
-        
-        if udp:
-            # listen on udp port
-            rsock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
-            for i in xrange(30):
-                try:
-                    rsock.bind((local_addr,local_port))
-                    break
-                except socket.error:
-                    # wait a while, retry
-                    time.sleep(1)
-            else:
-                rsock.bind((local_addr,local_port))
-            rsock.connect((peer_addr,peer_port))
-            udp_handshake(TERMINATE, rsock)
+
+        if local_proto == 'udp':
+            rsock = udp_establish(TERMINATE, local_addr, local_port, 
+                    peer_addr, peer_port)
             remote = os.fdopen(rsock.fileno(), 'r+b', 0)
-        elif listen:
-            # accept tcp connections
-            lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
-            for i in xrange(30):
-                try:
-                    lsock.bind((local_addr,local_port))
-                    break
-                except socket.error:
-                    # wait a while, retry
-                    time.sleep(1)
-            else:
-                lsock.bind((local_addr,local_port))
-            lsock.listen(1)
-            rsock,raddr = lsock.accept()
+        elif local_proto == 'tcp':
+            rsock = tcp_establish(TERMINATE, local_addr, local_port,
+                    peer_addr, peer_port)
             remote = os.fdopen(rsock.fileno(), 'r+b', 0)
         else:
-            # connect to tcp server
-            rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
-            for i in xrange(30):
-                try:
-                    rsock.connect((peer_addr,peer_port))
-                    break
-                except socket.error:
-                    # wait a while, retry
-                    time.sleep(1)
-            else:
-                rsock.connect((peer_addr,peer_port))
-            remote = os.fdopen(rsock.fileno(), 'r+b', 0)
-        
+            raise RuntimeError, "Bad protocol for %s: %r" % (self,local_proto)
+
         # notify that we're ready
         self._connected.set()
         
@@ -305,20 +259,6 @@ def preconfigure_tunchannel(testbed_instance, guid):
     if not element.tun_port and element.tun_addr:
         element.tun_port = 15000 + int(guid)
 
-    # First-phase setup
-    if element.peer_proto:
-        # cross tun
-        if not element.tun_addr or not element.tun_port:
-            listening = True
-        elif not element.peer_addr or not element.peer_port:
-            listening = True
-        else:
-            # both have addresses...
-            # ...the one with the lesser address listens
-            listening = element.tun_addr < element.peer_addr
-        element.listen = listening
-        element.Prepare()
-
 def postconfigure_tunchannel(testbed_instance, guid):
     """
     TunChannel preconfiguration.
@@ -329,10 +269,8 @@ def postconfigure_tunchannel(testbed_instance, guid):
     Should be adequate for most implementations.
     """
     element = testbed_instance._elements[guid]
-    
-    # Second-phase setup
-    element.Setup()
-
+   
+    element.launch()
 
 def crossconnect_tunchannel_peer_init(proto, testbed_instance, tun_guid, peer_data,
         preconfigure_tunchannel = preconfigure_tunchannel):
@@ -353,7 +291,7 @@ def crossconnect_tunchannel_peer_init(proto, testbed_instance, tun_guid, peer_da
     tun.peer_cipher = peer_data.get("tun_cipher")
     tun.tun_key = min(tun.tun_key, peer_data.get("tun_key"))
     tun.tun_proto = proto
-    
+  
     preconfigure_tunchannel(testbed_instance, tun_guid)
 
 def crossconnect_tunchannel_peer_compl(proto, testbed_instance, tun_guid, peer_data,
@@ -374,12 +312,10 @@ def crossconnect_tunchannel_peer_compl(proto, testbed_instance, tun_guid, peer_d
     tun.peer_proto = peer_data.get("tun_proto") or proto
     tun.peer_port = peer_data.get("tun_port")
     tun.peer_cipher = peer_data.get("tun_cipher")
-    
+   
     postconfigure_tunchannel(testbed_instance, tun_guid)
 
-    
-
-def wait_tunchannel(testbed_instance, guid):
+def prestart_tunchannel(testbed_instance, guid):
     """
     Wait for the channel forwarder to be up and running.
     
@@ -387,5 +323,5 @@ def wait_tunchannel(testbed_instance, guid):
     be certain to start TunChannels before applications that might require them.
     """
     element = testbed_instance.elements[guid]
-    element.Wait()
+    element.wait()