Lots of cross-connection fixes, TUN synchronization, etc
[nepi.git] / src / nepi / testbeds / planetlab / scripts / tun_connect.py
index 01f32af..8dff5a2 100644 (file)
@@ -5,6 +5,7 @@ import fcntl
 import os
 import os.path
 import select
+import signal
 
 import struct
 import ctypes
@@ -16,9 +17,6 @@ import functools
 import time
 import base64
 
-import time
-print >>sys.stderr, time.time()
-
 tun_name = 'tun0'
 tun_path = '/dev/net/tun'
 hostaddr = socket.gethostbyname(socket.gethostname())
@@ -254,6 +252,16 @@ def pl_vif_start(tun_path, tun_name):
     if out.strip():
         print >>sys.stderr, out
 
+def pl_vif_stop(tun_path, tun_name):
+    stdin = open("/vsys/vif_down.in","w")
+    stdout = open("/vsys/vif_down.out","r")
+    stdin.write(tun_name+"\n")
+    stdin.close()
+    out = stdout.read()
+    stdout.close()
+    if out.strip():
+        print >>sys.stderr, out
+
 
 def ipfmt(ip):
     ipbytes = map(ord,ip.decode("hex"))
@@ -403,9 +411,8 @@ def decrypt(packet, crypter):
     
     return packet
 
-abortme = False
 def tun_fwd(tun, remote):
-    global abortme
+    global TERMINATE
     
     # in PL mode, we cannot strip PI structs
     # so we'll have to handle them
@@ -438,7 +445,7 @@ def tun_fwd(tun, remote):
     # Which is needed, since /dev/net/tun is unbuffered
     fwbuf = ""
     bkbuf = ""
-    while not abortme:
+    while not TERMINATE:
         wset = []
         if packetReady(bkbuf, ether_mode):
             wset.append(tun)
@@ -512,12 +519,12 @@ MODEINFO = {
                   tunopen=tunopen, tunclose=tunclose,
                   dealloc=nop,
                   start=pl_vif_start,
-                  stop=nop),
+                  stop=pl_vif_stop),
     'pl-tap'  : dict(alloc=functools.partial(pl_tuntap_alloc, "tap"),
                   tunopen=tunopen, tunclose=tunclose,
                   dealloc=nop,
                   start=pl_vif_start,
-                  stop=nop),
+                  stop=pl_vif_stop),
 }
     
 tun_path = options.tun_path
@@ -539,7 +546,16 @@ except:
     raise
 
 
+# Trak SIGTERM, and set global termination flag instead of dying
+TERMINATE = False
+def _finalize(sig,frame):
+    global TERMINATE
+    TERMINATE = True
+signal.signal(signal.SIGTERM, _finalize)
+
 try:
+    tcpdump = None
+    
     if options.pass_fd:
         if options.pass_fd.startswith("base64:"):
             options.pass_fd = base64.b64decode(
@@ -552,24 +568,29 @@ try:
         import passfd
         
         sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
-        try:
-            sock.connect(options.pass_fd)
-        except socket.error:
-            # wait a while, retry
-            print >>sys.stderr, "Could not connect. Retrying in a sec..."
-            time.sleep(1)
+        for i in xrange(30):
+            try:
+                sock.connect(options.pass_fd)
+                break
+            except socket.error:
+                # wait a while, retry
+                print >>sys.stderr, "Could not connect. Retrying in a sec..."
+                time.sleep(1)
+        else:
             sock.connect(options.pass_fd)
         passfd.sendfd(sock, tun.fileno(), '0')
         
+        # Launch a tcpdump subprocess, to capture and dump packets,
+        # we will not be able to capture them ourselves.
+        # Make sure to catch sigterm and kill the tcpdump as well
+        tcpdump = subprocess.Popen(
+            ["tcpdump","-l","-n","-i",tun_name])
+        
         # just wait forever
         def tun_fwd(tun, remote):
-            while True:
+            while not TERMINATE:
                 time.sleep(1)
         remote = None
-        
-        import time
-        print >>sys.stderr, time.time()
-        
     elif options.udp:
         # connect to remote endpoint
         if remaining_args and not remaining_args[0].startswith('-'):
@@ -587,7 +608,16 @@ try:
         if remaining_args and not remaining_args[0].startswith('-'):
             print >>sys.stderr, "Connecting to: %s:%d" % (remaining_args[0],options.port)
             rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
-            rsock.connect((remaining_args[0],options.port))
+            for i in xrange(30):
+                try:
+                    rsock.connect((remaining_args[0],options.port))
+                    break
+                except socket.error:
+                    # wait a while, retry
+                    print >>sys.stderr, "Could not connect. Retrying in a sec..."
+                    time.sleep(1)
+            else:
+                rsock.connect((remaining_args[0],options.port))
         else:
             print >>sys.stderr, "Listening at: %s:%d" % (hostaddr,options.port)
             lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
@@ -599,6 +629,10 @@ try:
     print >>sys.stderr, "Connected"
 
     tun_fwd(tun, remote)
+
+    if tcpdump:
+        os.kill(tcpdump.pid, signal.SIGTERM)
+        tcpdump.wait()
 finally:
     try:
         print >>sys.stderr, "Shutting down..."
@@ -608,18 +642,19 @@ finally:
     
     # tidy shutdown in every case - swallow exceptions
     try:
-        modeinfo['tunclose'](tun_path, tun_name, tun)
+        modeinfo['stop'](tun_path, tun_name)
     except:
         pass
-        
+
     try:
-        modeinfo['stop'](tun_path, tun_name)
+        modeinfo['tunclose'](tun_path, tun_name, tun)
     except:
         pass
-
+        
     try:
         modeinfo['dealloc'](tun_path, tun_name)
     except:
         pass
-
+    
+    print >>sys.stderr, "TERMINATED GRACEFULLY"