Added support for suspending and resuming traffic on PlanetLab TAP/TUN interfaces.
authorAlina Quereilhac <alina.quereilhac@inria.fr>
Wed, 28 Mar 2012 20:16:07 +0000 (22:16 +0200)
committerAlina Quereilhac <alina.quereilhac@inria.fr>
Wed, 28 Mar 2012 20:16:07 +0000 (22:16 +0200)
src/nepi/testbeds/planetlab/execute.py
src/nepi/testbeds/planetlab/interfaces.py
src/nepi/testbeds/planetlab/scripts/tun_connect.py
src/nepi/testbeds/planetlab/tunproto.py
src/nepi/util/tunchannel.py
test/util/tunchannel.py [new file with mode: 0644]
tunbench.py

index 65ee9c4..a815158 100644 (file)
@@ -531,6 +531,12 @@ class TestbedController(testbed_impl.TestbedController):
         # TODO: take on account schedule time for the task
         element = self._elements[guid]
         if element:
+            if name == "up":
+                if value == True:
+                    element.if_up()
+                else:
+                    element.if_down()
+
             try:
                 setattr(element, name, value)
             except:
index 7823900..60295e2 100644 (file)
@@ -199,6 +199,14 @@ class TunIface(object):
         if self.peer_proto_impl:
             return self.peer_proto_impl.if_name
 
+    def if_up(self):
+        if self.peer_proto_impl:
+            return self.peer_proto_impl.if_up()
+
+    def if_down(self):
+        if self.peer_proto_impl:
+            return self.peer_proto_impl.if_down()
+
     def routes_here(self, route):
         """
         Returns True if the route should be attached to this interface
index df864af..0812d7c 100644 (file)
@@ -514,6 +514,7 @@ def pl_vif_stop(tun_path, tun_name):
 
 def tun_fwd(tun, remote, reconnect = None, accept_local = None, accept_remote = None, slowlocal = True, bwlimit = None):
     global TERMINATE
+    global SUSPEND
     
     tunqueue = options.vif_txqueuelen or 1000
     tunkqueue = 500
@@ -526,6 +527,7 @@ def tun_fwd(tun, remote, reconnect = None, accept_local = None, accept_remote =
         cipher_key = options.cipher_key,
         udp = options.protocol == 'udp',
         TERMINATE = TERMINATE,
+        SUSPEND = SUSPEND,
         stderr = None,
         reconnect = reconnect,
         tunqueue = tunqueue,
@@ -668,6 +670,20 @@ def _finalize(sig,frame):
     TERMINATE.append(None)
 signal.signal(signal.SIGTERM, _finalize)
 
+# SIGUSR1 suspends forwading, SIGUSR2 resumes forwarding
+SUSPEND = []
+def _suspend(sig,frame):
+    global SUSPEND
+    if not SUSPEND:
+        SUSPEND.append(None)
+signal.signal(signal.SIGUSR1, _suspend)
+
+def _resume(sig,frame):
+    global SUSPEND
+    if SUSPEND:
+        SUSPEND.remove(None)
+signal.signal(signal.SIGUSR2, _resume)
+
 try:
     tcpdump = None
     reconnect = None
index be09ee1..0874ac0 100644 (file)
@@ -35,6 +35,10 @@ class TunProtoBase(object):
         self._ppid = None
         self._if_name = None
 
+        self._pointopoint = None
+        self._netprefix = None
+        self._address = None
+
         # Logging
         self._logger = logging.getLogger('nepi.testbeds.planetlab')
     
@@ -193,18 +197,18 @@ class TunProtoBase(object):
         
         local_port = self.port
         local_cap  = local.capture
-        local_addr = local.address
-        local_mask = local.netprefix
+        self._address = local_addr = local.address
+        self._netprefix = local_mask = local.netprefix
         local_snat = local.snat
         local_txq  = local.txqueuelen
-        local_p2p  = local.pointopoint
+        self._pointopoint = local_p2p  = local.pointopoint
         local_cipher=local.tun_cipher
         local_mcast= local.multicast
         local_bwlim= local.bwlimit
         local_mcastfwd = local.multicast_forwarder
         
         if not local_p2p and hasattr(peer, 'address'):
-            local_p2p = peer.address
+            self._pointopoint = local_p2p = peer.address
 
         if check_proto != peer_proto:
             raise RuntimeError, "Peering protocol mismatch: %s != %s" % (check_proto, peer_proto)
@@ -552,7 +556,46 @@ class TunProtoBase(object):
                         timeout = 60,
                         err_on_timeout = False
                         )
-                    proc.wait()    
+                    proc.wait()
+
+    def if_down(self):
+        # TODO!!! need to set the vif down with vsys/vif_down.in ... which 
+        # doesn't currently work.
+        local = self.local()
+        
+        if local:
+            (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
+                "sudo -S bash -c 'kill -s USR1 %d'" % (self._pid,),
+                host = local.node.hostname,
+                port = None,
+                user = local.node.slicename,
+                agent = None,
+                ident_key = local.node.ident_path,
+                server_key = local.node.server_key,
+                timeout = 60,
+                err_on_timeout = False
+                )
+            proc.wait()
+
+    def if_up(self):
+        # TODO!!! need to set the vif up with vsys/vif_up.in ... which 
+        # doesn't currently work.
+        local = self.local()
+        
+        if local:
+            (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
+                "sudo -S bash -c 'kill -s USR2 %d'" % (self._pid,),
+                host = local.node.hostname,
+                port = None,
+                user = local.node.slicename,
+                agent = None,
+                ident_key = local.node.ident_path,
+                server_key = local.node.server_key,
+                timeout = 60,
+                err_on_timeout = False
+                )
+            proc.wait()    
+
     _TRACEMAP = {
         # tracename : (remotename, localname)
         'packets' : ('capture','capture'),
index 735d08c..5e2e3ca 100644 (file)
@@ -26,6 +26,7 @@ tagtype = {
     '8864' : 'PPPoE',
     '86dd' : 'ipv6',
 }
+
 def etherProto(packet, len=len):
     if len(packet) > 14:
         if packet[12] == "\x81" and packet[13] == "\x00":
@@ -36,6 +37,7 @@ def etherProto(packet, len=len):
             return packet[12:14]
     # default: ip
     return "\x08\x00"
+
 def formatPacket(packet, ether_mode):
     if ether_mode:
         stripped_packet = etherStrip(packet)
@@ -194,9 +196,13 @@ def nonblock(fd):
         # Just ignore
         return False
 
-def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr=sys.stderr, reconnect=None, rwrite=None, rread=None, tunqueue=1000, tunkqueue=1000,
-        cipher='AES', accept_local=None, accept_remote=None, slowlocal=True, queueclass=None, bwlimit=None,
-        len=len, max=max, min=min, buffer=buffer, OSError=OSError, select=select.select, selecterror=select.error, os=os, socket=socket,
+def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, SUSPEND,
+        stderr = sys.stderr, reconnect = None, rwrite = None, rread = None,
+        tunqueue = 1000, tunkqueue = 1000, cipher = 'AES', accept_local = None, 
+        accept_remote = None, slowlocal = True, queueclass = None, 
+        bwlimit = None, len = len, max = max, min = min, buffer = buffer,
+        OSError = OSError, select = select.select, selecterror = select.error, 
+        os = os, socket = socket,
         retrycodes=(os.errno.EWOULDBLOCK, os.errno.EAGAIN, os.errno.EINTR) ):
     crypto_mode = False
     crypter = None
@@ -343,6 +349,11 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
     
     
     while not TERMINATE:
+        # The SUSPEND flag has been set. This means we need to wait on
+        # the SUSPEND condition until it is released.
+        while SUSPEND:
+            time.sleep(0.5)
+
         wset = []
         if packetReady(bkbuf):
             wset.append(tun)
@@ -368,7 +379,12 @@ def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr
                 continue
             else:
                 traceback.print_exc(file=sys.stderr)
-                raise
+                # If the SUSPEND flag has been set, then the TUN will be in a bad
+                # state and the select error should be ignores.
+                if SUSPEND:
+                    continue
+                else:
+                    raise
 
         # check for errors
         if errs:
diff --git a/test/util/tunchannel.py b/test/util/tunchannel.py
new file mode 100644 (file)
index 0000000..7bd3176
--- /dev/null
@@ -0,0 +1,75 @@
+#!/usr/bin/env python
+
+from nepi.util import tunchannel
+import socket
+import time
+import threading
+import unittest
+
+class TunnChannelTestCase(unittest.TestCase):
+    def test_send_suspend_terminate(self):
+        def tun_fwd(local, remote, TERMINATE, SUSPEND, STOPPED):
+            tunchannel.tun_fwd(local, remote, True, True, None, True,
+                TERMINATE, SUSPEND, None)
+            STOPPED.append(None)
+    
+        TERMINATE = []
+        SUSPEND = []
+        STOPPED = []
+    
+        s1, s2 = socket.socketpair()
+        s3, s4 = socket.socketpair()
+        s4.settimeout(2.0)
+
+        t = threading.Thread(target=tun_fwd, args=[s2, s3, TERMINATE, SUSPEND, STOPPED])
+        t.start()
+
+        txt = "0000|received"
+        s1.send(txt)
+        rtxt = s4.recv(len(txt))
+
+        self.assertTrue(rtxt == txt[4:])
+        
+        # Let's try to suspend execution now
+        cond = threading.Condition()
+        SUSPEND.insert(0, cond)
+
+        txt = "0000|suspended"
+        s1.send(txt)
+        
+        rtxt = "timeout"
+        try:
+            rtxt = s4.recv(len(txt))
+        except socket.timeout:
+            pass
+                    
+        self.assertTrue(rtxt == "timeout")
+
+        # Let's see if we can resume and receive the message
+        cond = SUSPEND[0]
+        SUSPEND.remove(cond)
+        cond.acquire()
+        cond.notify()
+        cond.release()
+
+        rtxt = s4.recv(len(txt))
+        self.assertTrue(rtxt == txt[4:])
+              
+        # Stop forwarding         
+        TERMINATE.append(None)
+
+        txt = "0000|never received"
+        s1.send(txt)
+        
+        rtxt = "timeout"
+        try:
+            rtxt = s4.recv(len(txt))
+        except socket.timeout:
+            pass
+                    
+        self.assertTrue(rtxt == "timeout")
+        self.assertTrue(STOPPED)
+
+if __name__ == '__main__':
+    unittest.main()
+
index 9a3e6e7..63d842f 100644 (file)
@@ -38,14 +38,18 @@ def test(cipher, passphrase, plr=None, queuemodule=None):
    else:
         queueclass = None
    TERMINATE = []
+   SUSPEND = []
+
    def stopme():
        time.sleep(100)
        TERMINATE.append(None)
+
    t = threading.Thread(target=stopme)
    t.start()
-   tunchannel.tun_fwd(tun, remote, True, True, passphrase, True, TERMINATE, None, tunkqueue=500,
-        rwrite = rwrite, rread = rread, cipher=cipher, queueclass=queueclass,
-        accept_local = accept, accept_remote = accept)
+   tunchannel.tun_fwd(tun, remote, True, True, passphrase, True, TERMINATE,
+            SUSPEND, None, tunkqueue=500, rwrite = rwrite, rread = rread, 
+            cipher=cipher, queueclass=queueclass, accept_local = accept,
+            accept_remote = accept)
 
 # Swallow exceptions on decryption
 def decrypt(packet, crypter, super=tunchannel.decrypt):