Capture packets on FD-based tuns, by spawning a tcpdump in the background
[nepi.git] / src / nepi / testbeds / planetlab / scripts / tun_connect.py
1 import sys
2
3 import socket
4 import fcntl
5 import os
6 import os.path
7 import select
8 import signal
9
10 import struct
11 import ctypes
12 import optparse
13 import threading
14 import subprocess
15 import re
16 import functools
17 import time
18 import base64
19
# Default interface settings; overwritten from the parsed options before use.
tun_name = 'tun0'
tun_path = '/dev/net/tun'
# Local address used for UDP binding, TCP listening and the SNAT rule.
hostaddr = socket.gethostbyname(socket.gethostname())

usage = "usage: %prog [options] <remote-endpoint>"

parser = optparse.OptionParser(usage=usage)

parser.add_option(
    "-i", "--iface", dest="tun_name", metavar="DEVICE",
    default = "tun0",
    help = "TUN/TAP interface to tap into")
parser.add_option(
    "-d", "--tun-path", dest="tun_path", metavar="PATH",
    default = "/dev/net/tun",
    help = "TUN/TAP device file path or file descriptor number")
parser.add_option(
    "-p", "--port", dest="port", metavar="PORT", type="int",
    default = 15000,
    help = "Peering TCP port to connect or listen to.")
parser.add_option(
    "--pass-fd", dest="pass_fd", metavar="UNIX_SOCKET",
    default = None,
    help = "Path to a unix-domain socket to pass the TUN file descriptor to. "
           "If given, all other connectivity options are ignored, tun_connect will "
           "simply wait to be killed after passing the file descriptor, and it will be "
           "the receiver's responsibility to handle the tunneling.")

# --mode is validated here so that an unknown mode is reported as a usage
# error instead of surfacing later as a KeyError in the mode-handler lookup.
parser.add_option(
    "-m", "--mode", dest="mode", metavar="MODE",
    type = "choice",
    choices = ["none", "tun", "tap", "pl-tun", "pl-tap"],
    default = "none",
    help = 
        "Set mode. One of none, tun, tap, pl-tun, pl-tap. In any mode except none, a TUN/TAP will be created "
        "by using the proper interface (tunctl for tun/tap, /vsys/fd_tuntap.control for pl-tun/pl-tap), "
        "and it will be brought up (with ifconfig for tun/tap, with /vsys/vif_up for pl-tun/pl-tap). You have "
        "to specify a VIF_ADDRESS and VIF_MASK in any case (except for none).")
parser.add_option(
    "-A", "--vif-address", dest="vif_addr", metavar="VIF_ADDRESS",
    default = None,
    help = 
        "See mode. This specifies the VIF_ADDRESS, "
        "the IP address of the virtual interface.")
parser.add_option(
    "-M", "--vif-mask", dest="vif_mask", type="int", metavar="VIF_MASK", 
    default = None,
    help = 
        "See mode. This specifies the VIF_MASK, "
        "a number indicating the network type (ie: 24 for a C-class network).")
parser.add_option(
    "-S", "--vif-snat", dest="vif_snat", 
    action = "store_true",
    default = False,
    help = "See mode. This specifies whether SNAT will be enabled for the virtual interface. " )
parser.add_option(
    "-P", "--vif-pointopoint", dest="vif_pointopoint",  metavar="DST_ADDR",
    default = None,
    help = 
        "See mode. This specifies the remote endpoint's virtual address, "
        "for point-to-point routing configuration. "
        "Not supported by PlanetLab" )
parser.add_option(
    "-Q", "--vif-txqueuelen", dest="vif_txqueuelen", metavar="SIZE", type="int",
    default = None,
    help = 
        "See mode. This specifies the interface's transmission queue length. " )
parser.add_option(
    "-u", "--udp", dest="udp", metavar="PORT", type="int",
    default = None,
    help = 
        "Bind to the specified UDP port locally, and send UDP datagrams to the "
        "remote endpoint, creating a tunnel through UDP rather than TCP." )
parser.add_option(
    "-k", "--key", dest="cipher_key", metavar="KEY",
    default = None,
    help = 
        "Specify a symmetric encryption key with which to protect packets across "
        "the tunnel. python-crypto must be installed on the system." )

(options, remaining_args) = parser.parse_args(sys.argv[1:])
99
100
# Linux networking constants used by this script (values as in the kernel
# headers; several are kept for reference and unused in this file).

# ethertypes
ETH_P_ALL = 0x0003
ETH_P_IP = 0x0800

# TUN/TAP ioctl request and interface flags
TUNSETIFF = 0x400454ca
IFF_TUN = 0x0001
IFF_TAP = 0x0002
IFF_NO_PI = 0x1000
IFF_VNET_HDR = 0x4000
TUN_PKT_STRIP = 0x0001

# struct ifreq geometry (bytes)
IFHWADDRLEN = 6
IFNAMSIZ = 16
IFREQ_SZ = 40

FIONREAD = 0x541b
113
def ifnam(x):
    """Return interface name ``x`` right-padded with NULs to IFNAMSIZ chars.

    Names that are already IFNAMSIZ characters or longer come back unchanged.
    """
    return x.ljust(IFNAMSIZ, '\x00')
116
def ifreq(iface, flags):
    """Build a ``struct ifreq`` buffer for the TUNSETIFF ioctl.

    Layout:
      char[IFNAMSIZ] : interface name, NUL-padded (via ifnam)
      short          : flags, native byte order
      <padding>      : zero bytes up to IFREQ_SZ

    Returns a string of at least IFREQ_SZ bytes.
    """
    buf = ifnam(iface) + struct.pack("H", flags)
    # Zero-fill up to the full struct size so that the whole buffer handed
    # to the ioctl is initialised (the padding count is struct size minus
    # what has been written so far).
    buf += '\x00' * (IFREQ_SZ - len(buf))
    return buf
125
def tunopen(tun_path, tun_name):
    """Open the TUN/TAP device and return it as an unbuffered file object.

    If ``tun_path`` is all digits it is an already-open file descriptor
    (as returned by pl_tuntap_alloc) and is simply wrapped. Otherwise it is
    a device path (e.g. /dev/net/tun): it is opened and bound to
    ``tun_name`` with TUNSETIFF, without packet-information headers.
    """
    if tun_path.isdigit():
        # open TUN fd
        print >>sys.stderr, "Using tun:", tun_name, "fd", tun_path
        return os.fdopen(int(tun_path), 'r+b', 0)

    # open TUN path
    print >>sys.stderr, "Using tun:", tun_name, "at", tun_path
    tun = open(tun_path, 'r+b', 0)

    # The requested device type must match the interface (the kernel rejects
    # a TUN request for an existing TAP device and vice versa). TAP
    # interfaces are recognised by name, the same heuristic tun_fwd uses to
    # switch on Ethernet framing.
    if tun_name.startswith('tap'):
        kind = IFF_TAP
    else:
        kind = IFF_TUN
    try:
        # bind file descriptor to the interface
        fcntl.ioctl(tun.fileno(), TUNSETIFF, ifreq(tun_name, IFF_NO_PI|kind))
    except:
        # don't leak the device fd when binding fails
        tun.close()
        raise
    return tun
140
def tunclose(tun_path, tun_name, tun):
    """Close the TUN/TAP file object returned by tunopen.

    ``tun`` owns its descriptor in both cases: a numeric ``tun_path`` is
    wrapped with os.fdopen by tunopen, so closing the file object closes
    that fd. Closing the raw fd as well would close it twice. ``tun_path``
    and ``tun_name`` are accepted for signature compatibility with the
    other mode hooks.
    """
    tun.close()
149
def tuntap_alloc(kind, tun_path, tun_name):
    """Create a persistent TUN or TAP interface using ``tunctl``.

    ``kind`` is "tun" (adds ``-n``) or "tap"; a non-empty ``tun_name``
    requests that specific name via ``-t``. Returns ``(tun_path, name)``
    where ``tun_path`` is passed through and ``name`` is the interface
    tunctl reports. Raises RuntimeError if tunctl exits non-zero or its
    output cannot be parsed.
    """
    cmd = ["tunctl"]
    if kind == "tun":
        cmd.append("-n")
    if tun_name:
        cmd.extend(["-t", tun_name])

    tunctl = subprocess.Popen(cmd, stdout=subprocess.PIPE)
    output = tunctl.communicate()[0]
    if tunctl.wait():
        raise RuntimeError("Could not allocate %s device" % (kind,))

    found = re.search(
        r"Set '(?P<dev>(?:tun|tap)[0-9]*)' persistent and owned by .*",
        output, re.I)
    if found is None:
        raise RuntimeError(
            "Could not allocate %s device - tunctl said: %s" % (kind, output))

    allocated = found.group("dev")
    print >>sys.stderr, "Allocated %s device: %s" % (kind, allocated)
    return tun_path, allocated
170
def tuntap_dealloc(tun_path, tun_name):
    """Delete the persistent interface ``tun_name`` with ``tunctl -d``.

    A non-zero exit status only produces a warning on stderr; ``tun_path``
    is unused.
    """
    tunctl = subprocess.Popen(["tunctl", "-d", tun_name], stdout=subprocess.PIPE)
    tunctl.communicate()
    if tunctl.wait() != 0:
        print >>sys.stderr, "WARNING: error deallocating %s device" % (tun_name,)
177
def nmask_to_dot_notation(mask):
    """Convert a prefix length to a dotted-quad netmask (24 -> '255.255.255.0').

    The octets are extracted arithmetically, which works for every prefix
    from 0 to 32 whether the intermediate value is an int or a long. A
    hex()/decode('hex') conversion trips over Python 2's trailing 'L' on
    longs and over odd-length strings such as '0' for /0.

    Raises ValueError if ``mask`` is outside 0..32.
    """
    if not (0 <= mask <= 32):
        raise ValueError("invalid netmask length: %r" % (mask,))
    bits = ((1 << mask) - 1) << (32 - mask)  # 24 -> 0xFFFFFF00
    return '.'.join([str((bits >> shift) & 0xff) for shift in (24, 16, 8, 0)])
184
def vif_start(tun_path, tun_name):
    """Configure ``tun_name`` with ifconfig and bring it up; optionally add SNAT.

    Reads the address, mask, point-to-point peer, queue length and SNAT
    switch from the parsed ``options``. Raises RuntimeError if ifconfig or
    iptables exits non-zero.
    """
    ifconfig = ["ifconfig", tun_name, options.vif_addr,
                "netmask", nmask_to_dot_notation(options.vif_mask),
                "-arp"]
    if options.vif_pointopoint:
        ifconfig += ["pointopoint", options.vif_pointopoint]
    if options.vif_txqueuelen is not None:
        ifconfig += ["txqueuelen", str(options.vif_txqueuelen)]
    ifconfig.append("up")

    runner = subprocess.Popen(ifconfig, stdout=subprocess.PIPE)
    runner.communicate()
    if runner.wait():
        raise RuntimeError("Error starting virtual interface")

    if not options.vif_snat:
        return

    # set up SNAT using iptables
    # TODO: stop vif on error.
    #   Not so necessary since deallocating the tun/tap device
    #   will forcibly stop it, but it would be tidier
    snat_rule = ["iptables", "-t", "nat", "-A", "POSTROUTING",
                 "-s", "%s/%d" % (options.vif_addr, options.vif_mask),
                 "-j", "SNAT",
                 "--to-source", hostaddr, "--random"]
    runner = subprocess.Popen(snat_rule, stdout=subprocess.PIPE)
    runner.communicate()
    if runner.wait():
        raise RuntimeError("Error setting up SNAT")
212
def vif_stop(tun_path, tun_name):
    """Undo vif_start: drop the SNAT rule (when enabled) and take ``tun_name`` down.

    Best-effort teardown: a non-zero exit status from iptables or ifconfig
    is reported as a warning on stderr rather than raised.
    """
    if options.vif_snat:
        # remove the SNAT rule added by vif_start (same spec, -D instead of -A)
        args = [ "iptables", "-t", "nat", "-D", "POSTROUTING", 
                 "-s", "%s/%d" % (options.vif_addr, options.vif_mask),
                 "-j", "SNAT",
                 "--to-source", hostaddr, "--random" ]
        proc = subprocess.Popen(args, stdout=subprocess.PIPE)
        proc.communicate()
        if proc.wait():
            print >>sys.stderr, "WARNING: error removing SNAT rule"
    
    args = ["ifconfig", tun_name, "down"]
    proc = subprocess.Popen(args, stdout=subprocess.PIPE)
    proc.communicate()
    if proc.wait():
        print >>sys.stderr, "WARNING: error stopping virtual interface"
228     
229     
def pl_tuntap_alloc(kind, tun_path, tun_name):
    """Allocate a TUN/TAP device through the ``./tunalloc.so`` helper (PlanetLab).

    ``kind`` is "tun" or "tap". The requested ``tun_path``/``tun_name`` are
    ignored: the helper picks the name. Returns ``(str(fd), name)``; the fd
    is returned as a numeric string so tunopen wraps it instead of opening
    a path. Raises RuntimeError if the helper returns a negative fd.
    """
    tunalloc_so = ctypes.cdll.LoadLibrary("./tunalloc.so")
    # The helper writes the interface name into this buffer, so it must be a
    # real mutable ctypes buffer. A c_char_p would point into an immutable
    # Python string. The extra byte keeps room for a terminating NUL.
    c_tun_name = ctypes.create_string_buffer(IFNAMSIZ + 1)
    flags = {"tun":IFF_TUN,
             "tap":IFF_TAP}[kind]
    fd = tunalloc_so.tun_alloc(flags, c_tun_name)
    # NOTE(review): assumes tun_alloc signals failure with a negative return,
    # as is conventional for fd-returning C calls -- confirm against tunalloc.c
    if fd < 0:
        raise RuntimeError("Could not allocate %s device" % (kind,))
    return str(fd), c_tun_name.value
238
def pl_vif_start(tun_path, tun_name):
    """Bring ``tun_name`` up through the PlanetLab ``vif_up`` vsys script.

    Writes the interface name, address and mask (plus optional snat and
    txqueuelen settings) one per line to /vsys/vif_up.in. It then echoes any
    non-blank reply from /vsys/vif_up.out to stderr.
    """
    # both ends of the vsys channel are opened before anything is written
    vsys_in = open("/vsys/vif_up.in","w")
    vsys_out = open("/vsys/vif_up.out","r")
    for field in (tun_name, options.vif_addr, str(options.vif_mask)):
        vsys_in.write(field+"\n")
    if options.vif_snat:
        vsys_in.write("snat=1\n")
    if options.vif_txqueuelen is not None:
        vsys_in.write("txqueuelen=%d\n" % (options.vif_txqueuelen,))
    vsys_in.close()

    response = vsys_out.read()
    vsys_out.close()
    if response.strip():
        print >>sys.stderr, response
254
255
256 def ipfmt(ip):
257     ipbytes = map(ord,ip.decode("hex"))
258     return '.'.join(map(str,ipbytes))
259
# Ethertype as 4 lowercase hex digits -> label used when logging frames.
tagtype = dict([
    ('0800', 'ipv4'),
    ('0806', 'arp'),
    ('86dd', 'ipv6'),
    ('8863', 'PPPoE discover'),
    ('8864', 'PPPoE'),
    ('8870', 'jumbo'),
])

def etherProto(packet):
    """Return the 2-byte ethertype of an Ethernet frame (raw bytes).

    For an 802.1Q-tagged frame (TPID 0x8100 at offset 12) the inner
    ethertype at offset 16 is returned. Frames too short to carry the
    relevant field default to IPv4 ("\\x08\\x00").
    The header is inspected as raw bytes; hex-encoding it first would make
    both the offsets and the comparisons wrong.
    """
    if len(packet) < 14:
        # default: ip
        return "\x08\x00"
    if packet[12:14] == "\x81\x00":
        # tagged
        if len(packet) < 18:
            return "\x08\x00"
        return packet[16:18]
    # untagged
    return packet[12:14]

def formatPacket(packet, ether_mode):
    """Render a packet as a one-line, dash-separated hex summary for logging.

    In ether_mode the packet is an Ethernet frame. IPv4 frames recognised by
    etherStrip are formatted as IP packets; anything else is shown as
    Ethernet header fields plus payload, labelled via ``tagtype``. Outside
    ether_mode the packet is a raw IPv4 packet. Input too short to hold the
    relevant header is labelled "malformed eth" / "malformed ip".
    """
    if ether_mode:
        stripped_packet = etherStrip(packet)
        if not stripped_packet:
            packet = packet.encode("hex")
            if len(packet) < 28:
                # already hex-encoded above: don't encode it a second time
                return "malformed eth " + packet
            if packet[24:28] == "8100":
                # tagged
                ethertype = tagtype.get(packet[32:36], 'eth')
                return ethertype + " " + '-'.join((
                    packet[0:12], # MAC dest
                    packet[12:24], # MAC src
                    packet[24:32], # VLAN tag
                    packet[32:36], # Ethertype/len
                    packet[36:], # Payload
                ))
            # untagged
            ethertype = tagtype.get(packet[24:28], 'eth')
            return ethertype + " " + '-'.join((
                packet[0:12], # MAC dest
                packet[12:24], # MAC src
                packet[24:28], # Ethertype/len
                packet[28:], # Payload
            ))
        packet = stripped_packet

    packet = packet.encode("hex")
    # a minimal IPv4 header is 20 bytes = 40 hex digits
    if len(packet) < 40:
        return "malformed ip " + packet
    # IHL counts 32-bit words (8 hex digits each), so options run from the
    # end of the fixed header up to IHL*8. Values below 5 are treated as 5.
    header_end = max(int(packet[1], 16), 5) * 8
    return "ip " + '-'.join((
        packet[0:1], #version
        packet[1:2], #header length
        packet[2:4], #diffserv/ECN
        packet[4:8], #total length
        packet[8:12], #ident
        packet[12:16], #flags/fragment offs
        packet[16:18], #ttl
        packet[18:20], #ip-proto
        packet[20:24], #checksum
        ipfmt(packet[24:32]), # src-ip
        ipfmt(packet[32:40]), # dst-ip
        packet[40:header_end], # options
        packet[header_end:], # payload
    ))
327
def packetReady(buf, ether_mode):
    """Tell whether ``buf`` holds at least one complete packet.

    Anything of 4 bytes or more counts in ether_mode. Otherwise the buffer
    must contain the full IPv4 packet announced by the big-endian
    total-length field at bytes 2-3.
    """
    if len(buf) < 4:
        return False
    if ether_mode:
        return True
    (totallen,) = struct.unpack('!H', buf[2:4])
    return len(buf) >= totallen
337
def pullPacket(buf, ether_mode):
    """Split ``buf`` into ``(first_packet, remainder)``.

    In ether_mode the whole buffer is one frame and the remainder is "".
    Otherwise the length comes from the big-endian IPv4 total-length field
    at bytes 2-3.
    """
    if ether_mode:
        return buf, ""
    (totallen,) = struct.unpack('!H', buf[2:4])
    return buf[:totallen], buf[totallen:]
345
def etherStrip(buf):
    """Return the IPv4 payload of an Ethernet frame, or "" if it carries none.

    Handles untagged frames (ethertype 0x0800 at offset 12) and single
    802.1Q-tagged frames (TPID 0x8100 at offset 12, inner ethertype 0x0800
    at offset 16). Any other or truncated frame yields "".
    """
    if len(buf) < 14:
        return ""
    if buf[12:14] == '\x81\x00' and buf[16:18] == '\x08\x00':
        # tagged ethernet frame: skip the 4-byte VLAN tag as well
        return buf[18:]
    elif buf[12:14] == '\x08\x00':
        # untagged ethernet frame
        return buf[14:]
    else:
        return ""
357
def etherWrap(packet):
    """Wrap an IPv4 packet in a placeholder Ethernet frame.

    Both MAC addresses are zero, the ethertype is IPv4, and four zero bytes
    stand in for the CRC.
    """
    header = "\x00" * 12 + "\x08\x00"
    trailer = "\x00" * 4
    return header + packet + trailer
365
def piStrip(buf):
    """Drop the leading 4-byte packet-information header.

    Buffers shorter than 4 bytes are returned untouched.
    """
    return buf if len(buf) < 4 else buf[4:]
371     
def piWrap(buf, ether_mode):
    """Prepend a 4-byte packet-information header to ``buf``.

    The 16-bit flags are zero. The 16-bit proto is the frame's ethertype
    (via etherProto) in ether_mode, and IPv4 otherwise.
    """
    proto = etherProto(buf) if ether_mode else "\x08\x00"
    return "\x00\x00" + proto + buf
382
def encrypt(packet, crypter):
    """Pad ``packet`` to a whole number of cipher blocks, then encrypt it.

    The padding is n bytes of value n (1 <= n <= block_size), so
    block-aligned input still gains a full block. decrypt() removes it.
    """
    block = crypter.block_size
    pad_len = block - (len(packet) % block)
    return crypter.encrypt(packet + chr(pad_len) * pad_len)
390
def decrypt(packet, crypter):
    """Decrypt ``packet`` and strip the padding added by encrypt().

    Raises RuntimeError("Truncated packet") when the decrypted data is empty
    or its padding is malformed. Malformed means a length outside
    1..block_size or padding bytes that don't all equal that length.
    """
    # decrypt
    packet = crypter.decrypt(packet)
    if not packet:
        raise RuntimeError("Truncated packet")

    # un-pad
    padding = ord(packet[-1])
    if not (0 < padding <= crypter.block_size) \
            or packet[-padding:] != packet[-1] * padding:
        # wrong padding
        raise RuntimeError("Truncated packet")
    return packet[:-padding]
403
abortme = False
def tun_fwd(tun, remote):
    """Forward packets between the TUN/TAP device and the remote peer.

    Loops until the module-level ``abortme`` flag is set, select() reports
    an exceptional condition on either descriptor, or (TCP mode only) the
    peer closes the connection.

    tun    -- unbuffered file object for the TUN/TAP device
    remote -- unbuffered file object for the connected TCP/UDP socket

    If a cipher key was given but encryption cannot be set up, the error is
    raised instead of silently falling back to plaintext. In UDP mode, send,
    receive and decryption errors only drop the affected packet. In TCP
    mode they propagate.
    """
    global abortme
    
    # in PL mode, we cannot strip PI structs
    # so we'll have to handle them
    with_pi = options.mode.startswith('pl-')
    # TAP interfaces (recognised by name) carry Ethernet frames
    ether_mode = tun_name.startswith('tap')
    
    crypter = None
    if options.cipher_key:
        # Fail closed: encryption was requested, so never send plaintext.
        try:
            import Crypto.Cipher.AES
        except ImportError:
            print >>sys.stderr, "ERROR: a cipher key was given but python-crypto is not available"
            raise
        import hashlib
        
        hashed_key = hashlib.sha256(options.cipher_key).digest()
        # NOTE(review): ECB encrypts identical blocks identically, so packet
        # structure (e.g. repeated headers) leaks; changing the mode requires
        # updating both tunnel endpoints together.
        crypter = Crypto.Cipher.AES.new(
            hashed_key, 
            Crypto.Cipher.AES.MODE_ECB)
    crypto_mode = crypter is not None

    if crypto_mode:
        print >>sys.stderr, "Packets are transmitted in CIPHER"
    else:
        print >>sys.stderr, "Packets are transmitted in PLAINTEXT"
    
    # Limited frame parsing, to preserve packet boundaries.
    # Which is needed, since /dev/net/tun is unbuffered
    # NOTE(review): each os.read() from the peer is decrypted on its own,
    # which assumes reads line up with encrypted packets -- guaranteed for
    # UDP datagrams, apparently not for a TCP stream.
    fwbuf = ""   # tun -> remote
    bkbuf = ""   # remote -> tun
    while not abortme:
        wset = []
        if packetReady(bkbuf, ether_mode):
            wset.append(tun)
        if packetReady(fwbuf, ether_mode):
            wset.append(remote)
        rdrdy, wrdy, errs = select.select((tun,remote),wset,(tun,remote),1)
        
        # check for errors
        if errs:
            break
        
        # check to see if we can write
        if remote in wrdy and packetReady(fwbuf, ether_mode):
            packet, fwbuf = pullPacket(fwbuf, ether_mode)
            try:
                if crypto_mode:
                    enpacket = encrypt(packet, crypter)
                else:
                    enpacket = packet
                os.write(remote.fileno(), enpacket)
            except Exception:
                if not options.udp:
                    raise
                # in UDP mode, we ignore errors - packet loss man...
            print >>sys.stderr, '>', formatPacket(packet, ether_mode)
        if tun in wrdy and packetReady(bkbuf, ether_mode):
            packet, bkbuf = pullPacket(bkbuf, ether_mode)
            formatted = formatPacket(packet, ether_mode)
            if with_pi:
                packet = piWrap(packet, ether_mode)
            os.write(tun.fileno(), packet)
            print >>sys.stderr, '<', formatted
        
        # check incoming data packets
        if tun in rdrdy:
            packet = os.read(tun.fileno(),2000) # tun.read blocks until it gets 2k!
            if with_pi:
                packet = piStrip(packet)
            fwbuf += packet
        if remote in rdrdy:
            try:
                packet = os.read(remote.fileno(),2000) # remote.read blocks until it gets 2k!
                if not packet and not options.udp:
                    # TCP EOF: the peer is gone; stop instead of spinning on
                    # a descriptor that stays readable forever
                    print >>sys.stderr, "Remote endpoint closed the connection"
                    break
                if crypto_mode and packet:
                    packet = decrypt(packet, crypter)
            except Exception:
                if not options.udp:
                    raise
                # UDP: drop this datagram; never forward stale or
                # still-encrypted data into the tun
                continue
            bkbuf += packet
489
490
491
def nop(tun_path, tun_name):
    """Do-nothing mode hook: hand back ``(tun_path, tun_name)`` unchanged."""
    return (tun_path, tun_name)

def _mode_handlers(alloc, dealloc, start, stop):
    """Bundle one mode's lifecycle hooks; all modes open/close via tunopen/tunclose."""
    return dict(alloc=alloc, dealloc=dealloc,
                start=start, stop=stop,
                tunopen=tunopen, tunclose=tunclose)

# --mode value -> lifecycle hooks used by the setup and teardown code
MODEINFO = {
    'none': _mode_handlers(nop, nop, nop, nop),
    'tun': _mode_handlers(functools.partial(tuntap_alloc, "tun"),
                          tuntap_dealloc, vif_start, vif_stop),
    'tap': _mode_handlers(functools.partial(tuntap_alloc, "tap"),
                          tuntap_dealloc, vif_start, vif_stop),
    'pl-tun': _mode_handlers(functools.partial(pl_tuntap_alloc, "tun"),
                             nop, pl_vif_start, nop),
    'pl-tap': _mode_handlers(functools.partial(pl_tuntap_alloc, "tap"),
                             nop, pl_vif_start, nop),
}
520     
# Effective device settings start from the command line; the mode's alloc
# hook may replace both (e.g. with an fd number and an allocated name).
tun_path = options.tun_path
tun_name = options.tun_name

# KeyError here if options.mode is not a MODEINFO key.
modeinfo = MODEINFO[options.mode]

# be careful to roll back stuff on exceptions
#   - start fails   -> dealloc, re-raise
#   - tunopen fails -> stop, then dealloc, re-raise
tun_path, tun_name = modeinfo['alloc'](tun_path, tun_name)
try:
    modeinfo['start'](tun_path, tun_name)
    try:
        tun = modeinfo['tunopen'](tun_path, tun_name)
    except:
        # the interface was started, so stop it before the outer handler deallocates
        modeinfo['stop'](tun_path, tun_name)
        raise
except:
    modeinfo['dealloc'](tun_path, tun_name)
    raise
538
539
def _stop_tcpdump():
    """Terminate and reap the background tcpdump started in pass-fd mode.

    Safe to call repeatedly and when no tcpdump was started.
    """
    if tcpdump is None or tcpdump.poll() is not None:
        return
    try:
        os.kill(tcpdump.pid, signal.SIGTERM)
    except OSError:
        # it exited between poll() and kill()
        pass
    tcpdump.wait()

tcpdump = None
# Packet pump to run once connected; replaced by an idle loop in pass-fd mode.
forward = tun_fwd

try:
    if options.pass_fd:
        if options.pass_fd.startswith("base64:"):
            options.pass_fd = base64.b64decode(
                options.pass_fd[len("base64:"):])
            options.pass_fd = os.path.expandvars(options.pass_fd)
        
        print >>sys.stderr, "Sending FD to: %r" % (options.pass_fd,)
        
        # send FD to whoever wants it
        import passfd
        
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
        try:
            sock.connect(options.pass_fd)
        except socket.error:
            # wait a while, retry
            print >>sys.stderr, "Could not connect. Retrying in a sec..."
            time.sleep(1)
            sock.connect(options.pass_fd)
        passfd.sendfd(sock, tun.fileno(), '0')
        
        # Launch a tcpdump subprocess to capture and dump packets, since
        # the receiver now handles them and we cannot capture them ourselves.
        # SIGTERM must take the tcpdump down too (the finally clause below
        # covers every other way out).
        tcpdump = subprocess.Popen(
            ["tcpdump","-l","-n","-i",tun_name])
        
        def _finalize(sig,frame):
            _stop_tcpdump()
            if callable(_oldterm):
                _oldterm(sig,frame)
            else:
                sys.exit(0)
        _oldterm = signal.signal(signal.SIGTERM, _finalize)
            
        # just wait forever: the fd receiver does the forwarding
        def forward(tun, remote):
            while True:
                time.sleep(1)
        remote = None
    elif options.udp:
        # connect to remote endpoint
        if remaining_args and not remaining_args[0].startswith('-'):
            print >>sys.stderr, "Listening at: %s:%d" % (hostaddr,options.udp)
            print >>sys.stderr, "Connecting to: %s:%d" % (remaining_args[0],options.port)
            rsock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
            rsock.bind((hostaddr,options.udp))
            rsock.connect((remaining_args[0],options.port))
        else:
            print >>sys.stderr, "Error: need a remote endpoint in UDP mode"
            raise AssertionError("Error: need a remote endpoint in UDP mode")
        remote = os.fdopen(rsock.fileno(), 'r+b', 0)
    else:
        # connect to remote endpoint
        if remaining_args and not remaining_args[0].startswith('-'):
            print >>sys.stderr, "Connecting to: %s:%d" % (remaining_args[0],options.port)
            rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
            rsock.connect((remaining_args[0],options.port))
        else:
            print >>sys.stderr, "Listening at: %s:%d" % (hostaddr,options.port)
            lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
            lsock.bind((hostaddr,options.port))
            lsock.listen(1)
            rsock,raddr = lsock.accept()
        remote = os.fdopen(rsock.fileno(), 'r+b', 0)

    print >>sys.stderr, "Connected"

    forward(tun, remote)
finally:
    try:
        print >>sys.stderr, "Shutting down..."
    except:
        # In case sys.stderr is broken
        pass
    
    # tidy shutdown in every case - swallow exceptions
    # stop the capture before tearing the interface down
    try:
        _stop_tcpdump()
    except:
        pass

    try:
        modeinfo['tunclose'](tun_path, tun_name, tun)
    except:
        pass
        
    try:
        modeinfo['stop'](tun_path, tun_name)
    except:
        pass

    try:
        modeinfo['dealloc'](tun_path, tun_name)
    except:
        pass
639
640