Tunchannel fixes galore:
authorClaudio-Daniel Freire <claudio-daniel.freire@inria.fr>
Wed, 17 Aug 2011 13:11:08 +0000 (15:11 +0200)
committerClaudio-Daniel Freire <claudio-daniel.freire@inria.fr>
Wed, 17 Aug 2011 13:11:08 +0000 (15:11 +0200)
 * Fix lockups at shutdown that would keep the tun_connect.py script at 100% forever
 * Make sure UDP tunnels connect (at least at first), by performing a small handshake
 * Fix tunproto.py to effectively raise exceptions on misconnection (was swallowing them)
 * Fix tunproto.py to effectively wait for listening/connected endpoints
 * Filter fixes: ingress filters should be applied after decryption

src/nepi/testbeds/planetlab/scripts/tun_connect.py
src/nepi/testbeds/planetlab/tunproto.py
src/nepi/util/tunchannel.py

index 3e987ec..28356f7 100644 (file)
@@ -251,11 +251,11 @@ def tunopen(tun_path, tun_name):
     return tun
 
 def tunclose(tun_path, tun_name, tun):
-    if tun_path.isdigit():
+    if tun_path and tun_path.isdigit():
         # close TUN fd
         os.close(int(tun_path))
         tun.close()
-    else:
+    elif tun:
         # close TUN object
         tun.close()
 
@@ -359,7 +359,7 @@ def pl_tuntap_namealloc(kind, tun_path, tun_name):
     global _name_reservation
     # Serialize access
     lockfile = open("/tmp/nepi-tun-connect.lock", "a")
-    _name_reservation = lock = HostLock(lockfile)
+    lock = HostLock(lockfile)
     
     # We need to do this, fd_tuntap is the only one who can
     # tell us our slice id (this script runs as root, so no uid),
@@ -382,6 +382,8 @@ def pl_tuntap_namealloc(kind, tun_path, tun_name):
     else:
         raise RuntimeError, "Could not assign interface name"
     
+    _name_reservation = lock
+    
     return None, name
 
 def pl_vif_start(tun_path, tun_name):
@@ -389,7 +391,6 @@ def pl_vif_start(tun_path, tun_name):
 
     out = []
     def outreader():
-        stdout = open("/vsys/vif_up.out","r")
         out.append(stdout.read())
         stdout.close()
         time.sleep(1)
@@ -400,6 +401,7 @@ def pl_vif_start(tun_path, tun_name):
     _name_reservation = None
     
     stdin = open("/vsys/vif_up.in","w")
+    stdout = open("/vsys/vif_up.out","r")
 
     t = threading.Thread(target=outreader)
     t.start()
@@ -428,7 +430,6 @@ def pl_vif_start(tun_path, tun_name):
 def pl_vif_stop(tun_path, tun_name):
     out = []
     def outreader():
-        stdout = open("/vsys/vif_down.out","r")
         out.append(stdout.read())
         stdout.close()
         
@@ -449,6 +450,7 @@ def pl_vif_stop(tun_path, tun_name):
     lock = HostLock(lockfile)
 
     stdin = open("/vsys/vif_down.in","w")
+    stdout = open("/vsys/vif_down.out","r")
     
     t = threading.Thread(target=outreader)
     t.start()
@@ -610,6 +612,8 @@ try:
         sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
         retrydelay = 1.0
         for i in xrange(30):
+            if TERMINATE:
+                raise OSError, "Killed"
             try:
                 sock.connect(options.pass_fd)
                 break
@@ -648,6 +652,8 @@ try:
             rsock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
             retrydelay = 1.0
             for i in xrange(30):
+                if TERMINATE:
+                    raise OSError, "Killed"
                 try:
                     rsock.bind((hostaddr,options.udp))
                     break
@@ -662,6 +668,37 @@ try:
         else:
             print >>sys.stderr, "Error: need a remote endpoint in UDP mode"
             raise AssertionError, "Error: need a remote endpoint in UDP mode"
+        
+        # Wait for other peer
+        endme = False
+        def keepalive():
+            while not endme and not TERMINATE:
+                try:
+                    rsock.send('')
+                except:
+                    pass
+                time.sleep(1)
+            try:
+                rsock.send('')
+            except:
+                pass
+        keepalive_thread = threading.Thread(target=keepalive)
+        keepalive_thread.start()
+        retrydelay = 1.0
+        for i in xrange(30):
+            if TERMINATE:
+                raise OSError, "Killed"
+            try:
+                heartbeat = rsock.recv(10)
+                break
+            except:
+                time.sleep(min(30.0,retrydelay))
+                retrydelay *= 1.1
+        else:
+            heartbeat = rsock.recv(10)
+        endme = True
+        keepalive_thread.join()
+        
         remote = os.fdopen(rsock.fileno(), 'r+b', 0)
     else:
         # connect to remote endpoint
@@ -670,6 +707,8 @@ try:
             rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
             retrydelay = 1.0
             for i in xrange(30):
+                if TERMINATE:
+                    raise OSError, "Killed"
                 try:
                     rsock.connect((remaining_args[0],options.port))
                     break
@@ -685,6 +724,8 @@ try:
             lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
             retrydelay = 1.0
             for i in xrange(30):
+                if TERMINATE:
+                    raise OSError, "Killed"
                 try:
                     lsock.bind((hostaddr,options.port))
                     break
@@ -699,13 +740,6 @@ try:
             rsock,raddr = lsock.accept()
         remote = os.fdopen(rsock.fileno(), 'r+b', 0)
 
-    if not options.no_capture:
-        # Launch a tcpdump subprocess, to capture and dump packets.
-        # Make sure to catch sigterm and kill the tcpdump as well
-        tcpdump = subprocess.Popen(
-            ["tcpdump","-l","-n","-i",tun_name, "-s", "4096"]
-            + ["-w",options.pcap_capture,"-U"] * bool(options.pcap_capture) )
-    
     if filter_init:
         filter_local, filter_remote = filter_init()
         
@@ -723,6 +757,13 @@ try:
         filter_thread.start()
     
     print >>sys.stderr, "Connected"
+
+    if not options.no_capture:
+        # Launch a tcpdump subprocess, to capture and dump packets.
+        # Make sure to catch sigterm and kill the tcpdump as well
+        tcpdump = subprocess.Popen(
+            ["tcpdump","-l","-n","-i",tun_name, "-s", "4096"]
+            + ["-w",options.pcap_capture,"-U"] * bool(options.pcap_capture) )
     
     # Try to give us high priority
     try:
@@ -769,6 +810,12 @@ try:
                 accept_remote = accept_packet,
                 slowlocal = False)
         
+        localthread = threading.Thread(target=localside)
+        remotethread = threading.Thread(target=remoteside)
+        localthread.start()
+        remotethread.start()
+        localthread.join()
+        remotethread.join()
 
 finally:
     try:
index 3f003ba..ae0b069 100644 (file)
@@ -31,6 +31,7 @@ class TunProtoBase(object):
         
         self._launcher = None
         self._started = False
+        self._started_listening = False
         self._starting = False
         self._pid = None
         self._ppid = None
@@ -296,6 +297,7 @@ class TunProtoBase(object):
         # Tunnel should be still running in its node
         # Just check its pidfile and we're done
         self._started = True
+        self._started_listening = True
         self.checkpid()
     
     def _launch_and_wait(self, *p, **kw):
@@ -318,11 +320,13 @@ class TunProtoBase(object):
             time.sleep(1.0)
         
         # Wait for the connection to be established
+        retrytime = 2.0
         for spin in xrange(30):
             if self.status() != rspawn.RUNNING:
                 self._logger.warn("FAILED TO CONNECT! %s", self)
                 break
             
+            # Connected?
             (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
                 "cd %(home)s ; grep -c Connected capture" % dict(
                     home = server.shell_escape(self.home_path)),
@@ -333,16 +337,43 @@ class TunProtoBase(object):
                 ident_key = local.node.ident_path,
                 server_key = local.node.server_key
                 )
-            
-            if proc.wait():
-                break
-            
-            if out.strip() != '0':
+            proc.wait()
+
+            if out.strip() == '1':
                 break
+
+            # At least listening?
+            (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
+                "cd %(home)s ; grep -c Listening capture" % dict(
+                    home = server.shell_escape(self.home_path)),
+                host = local.node.hostname,
+                port = None,
+                user = local.node.slicename,
+                agent = None,
+                ident_key = local.node.ident_path,
+                server_key = local.node.server_key
+                )
+            proc.wait()
+
+            if out.strip() == '1':
+                self._started_listening = True
             
-            time.sleep(1.0)
+            time.sleep(min(30.0, retrytime))
+            retrytime *= 1.1
         else:
-            self._logger.warn("FAILED TO CONNECT! %s", self)
+            (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
+                "cat %(home)s/capture" % dict(
+                    home = server.shell_escape(self.home_path)),
+                host = local.node.hostname,
+                port = None,
+                user = local.node.slicename,
+                agent = None,
+                ident_key = local.node.ident_path,
+                server_key = local.node.server_key
+                )
+            proc.wait()
+
+            raise RuntimeError, "FAILED TO CONNECT %s: %s%s" % (self,out,err)
     
     @property
     def if_name(self):
@@ -419,12 +450,24 @@ class TunProtoBase(object):
     def async_launch_wait(self):
         if self._launcher:
             self._launcher.join()
-            if not self._started:
+
+            if self._launcher._exc:
+                exctyp,exval,exctrace = self._launcher._exc[0]
+                raise exctyp,exval,exctrace
+            elif not self._started:
+                raise RuntimeError, "Failed to launch TUN forwarder"
+        elif not self._started:
+            self.launch()
+
+    def async_launch_wait_listening(self):
+        if self._launcher:
+            for x in xrange(180):
                 if self._launcher._exc:
                     exctyp,exval,exctrace = self._launcher._exc[0]
                     raise exctyp,exval,exctrace
-                else:
-                    raise RuntimeError, "Failed to launch TUN forwarder"
+                elif self._started and self._started_listening:
+                    break
+                time.sleep(1)
         elif not self._started:
             self.launch()
 
@@ -473,7 +516,7 @@ class TunProtoBase(object):
                 )
             return status
     
-    def kill(self):
+    def kill(self, nowait = True):
         local = self.local()
         
         if not local:
@@ -493,7 +536,7 @@ class TunProtoBase(object):
                 ident_key = local.node.ident_path,
                 server_key = local.node.server_key,
                 sudo = True,
-                nowait = True
+                nowait = nowait
                 )
     
     def waitkill(self):
@@ -505,6 +548,8 @@ class TunProtoBase(object):
                 break
             time.sleep(interval)
             interval = min(30.0, interval * 1.1)
+        else:
+            self.kill(nowait=False)
 
         if self.if_name:
             for i in xrange(30):
@@ -674,7 +719,7 @@ class TunProtoTCP(TunProtoBase):
             # make sure our peer is ready
             peer = self.peer()
             if peer and peer.peer_proto_impl:
-                peer.peer_proto_impl.async_launch_wait()
+                peer.peer_proto_impl.async_launch_wait_listening()
             
             if not self._started:
                 self.async_launch('tcp', False)
index 5804832..d622170 100644 (file)
@@ -167,15 +167,16 @@ def encrypt(packet, crypter, len=len, padmap=_padmap):
     return crypter.encrypt(packet)
 
 def decrypt(packet, crypter, ord=ord):
-    # decrypt
-    packet = crypter.decrypt(packet)
-    
-    # un-pad
-    padding = ord(packet[-1])
-    if not (0 < padding <= crypter.block_size):
-        # wrong padding
-        raise RuntimeError, "Truncated packet"
-    packet = packet[:-padding]
+    if packet:
+        # decrypt
+        packet = crypter.decrypt(packet)
+        
+        # un-pad
+        padding = ord(packet[-1])
+        if not (0 < padding <= crypter.block_size):
+            # wrong padding
+            raise RuntimeError, "Truncated packet"
+        packet = packet[:-padding]
     
     return packet
 
@@ -261,6 +262,10 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
         twrite = os.write
         tread = os.read
     
+    encrypt_ = encrypt
+    decrypt_ = decrypt
+    xrange_ = xrange
+
     if accept_local is not None:
         def tread(fd, maxlen, _tread=tread, accept=accept_local):
             packet = _tread(fd, maxlen)
@@ -270,12 +275,20 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
                 return None
 
     if accept_remote is not None:
-        def rread(fd, maxlen, _rread=rread, accept=accept_remote):
-            packet = _rread(fd, maxlen)
-            if accept(packet, 1):
-                return packet
-            else:
-                return None
+        if crypto_mode:
+            def decrypt_(packet, crypter, decrypt_=decrypt_, accept=accept_remote):
+                packet = decrypt_(packet, crypter)
+                if accept(packet, 1):
+                    return packet
+                else:
+                    return None
+        else:
+            def rread(fd, maxlen, _rread=rread, accept=accept_remote):
+                packet = _rread(fd, maxlen)
+                if accept(packet, 1):
+                    return packet
+                else:
+                    return None
     
     # Limited frame parsing, to preserve packet boundaries.
     # Which is needed, since /dev/net/tun is unbuffered
@@ -296,23 +309,29 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
     tunfd = tun.fileno()
     os_read = os.read
     os_write = os.write
-    encrypt_ = encrypt
-    decrypt_ = decrypt
+    
+    remoteok = True
+    
     while not TERMINATE:
         wset = []
         if packetReady(bkbuf):
             wset.append(tun)
-        if packetReady(fwbuf):
+        if remoteok and packetReady(fwbuf):
             wset.append(remote)
         
         rset = []
         if len(fwbuf) < maxfwbuf:
             rset.append(tun)
-        if len(bkbuf) < maxbkbuf:
+        if remoteok and len(bkbuf) < maxbkbuf:
             rset.append(remote)
         
+        if remoteok:
+            eset = (tun,remote)
+        else:
+            eset = (tun,)
+        
         try:
-            rdrdy, wrdy, errs = select(rset,wset,(tun,remote),1)
+            rdrdy, wrdy, errs = select(rset,wset,eset,1)
         except selecterror, e:
             if e.args[0] == errno.EINTR:
                 # just retry
@@ -326,16 +345,23 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
                     remote_fd = remote.fileno()
             elif udp and remote in errs and tun not in errs:
                 # In UDP mode, those are always transient errors
-                pass
+                # Usually, an error will imply a read-ready socket
+                # that will raise an "Connection refused" error, so
+                # disable read-readiness just for now, and retry
+                # the select
+                remoteok = False
+                continue
             else:
                 break
+        else:
+            remoteok = True
         
         # check to see if we can write
         #rr = wr = rt = wt = 0
         if remote in wrdy:
             try:
                 try:
-                    while 1:
+                    for x in xrange(2000):
                         packet = pullPacket(fwbuf)
 
                         if crypto_mode:
@@ -396,9 +422,9 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
         # check incoming data packets
         if tun in rdrdy:
             try:
-                while 1:
+                for x in xrange(2000):
                     packet = tread(tunfd,2000) # tun.read blocks until it gets 2k!
-                    if packet is None:
+                    if not packet:
                         continue
                     #rt += 1
                     fwbuf.append(packet)
@@ -414,14 +440,21 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
         if remote in rdrdy:
             try:
                 try:
-                    while 1:
+                    for x in xrange(2000):
                         packet = rread(remote,2000)
-                        if packet is None:
-                            continue
                         #rr += 1
                         
                         if crypto_mode:
                             packet = decrypt_(packet, crypter)
+                            if not packet:
+                                continue
+                        elif not packet:
+                            if not udp and packet == "":
+                                # Connection broken, try to reconnect (or just die)
+                                raise RuntimeError, "Connection broken"
+                            else:
+                                continue
+
                         bkbuf.append(packet)
                         
                         if not rnonblock or len(bkbuf) >= maxbkbuf:
@@ -442,7 +475,7 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
                 elif not udp:
                     # in UDP mode, we ignore errors - packet loss man...
                     raise
-                #traceback.print_exc(file=sys.stderr)
+                traceback.print_exc(file=sys.stderr)
         
         #print >>sys.stderr, "rr:%d\twr:%d\trt:%d\twt:%d" % (rr,wr,rt,wt)