19b0897885ecba7da832175745ba906864fe41b7
[nepi.git] / src / nepi / testbeds / planetlab / scripts / tun_connect.py
1 import sys
2
3 import socket
4 import fcntl
5 import os
6 import select
7
8 import struct
9 import ctypes
10 import optparse
11 import threading
12 import subprocess
13 import re
14 import functools
15 import time
16
# Defaults for the TUN device; the main program rebinds both from the options.
tun_name = 'tun0'
tun_path = '/dev/net/tun'
# Local address: bind address for the peering sockets and SNAT source address.
# NOTE(review): gethostbyname(gethostname()) may resolve to a loopback alias
# (e.g. 127.0.1.1) on some distributions -- confirm on the target nodes.
hostaddr = socket.gethostbyname(socket.gethostname())

usage = "usage: %prog [options] <remote-endpoint>"

parser = optparse.OptionParser(usage=usage)

parser.add_option(
    "-i", "--iface", dest="tun_name", metavar="DEVICE",
    default = "tun0",
    help = "TUN/TAP interface to tap into")
parser.add_option(
    "-d", "--tun-path", dest="tun_path", metavar="PATH",
    default = "/dev/net/tun",
    help = "TUN/TAP device file path or file descriptor number")
parser.add_option(
    "-p", "--port", dest="port", metavar="PORT", type="int",
    default = 15000,
    help = "Peering TCP port to connect or listen to.")
parser.add_option(
    "--pass-fd", dest="pass_fd", metavar="UNIX_SOCKET",
    default = None,
    help = "Path to a unix-domain socket to pass the TUN file descriptor to. "
           "If given, all other connectivity options are ignored, tun_connect will "
           "simply wait to be killed after passing the file descriptor, and it will be "
           "the receiver's responsibility to handle the tunneling.")

parser.add_option(
    "-m", "--mode", dest="mode", metavar="MODE",
    # Restrict to the modes listed in MODEINFO, so that a typo is reported by
    # optparse as a usage error instead of surfacing later as a KeyError.
    type = "choice", choices = ["none", "tun", "tap", "pl-tun", "pl-tap"],
    default = "none",
    help = 
        "Set mode. One of none, tun, tap, pl-tun, pl-tap. In any mode except none, a TUN/TAP will be created "
        "by using the proper interface (tunctl for tun/tap, /vsys/fd_tuntap.control for pl-tun/pl-tap), "
        "and it will be brought up (with ifconfig for tun/tap, with /vsys/vif_up for pl-tun/pl-tap). You have "
        "to specify a VIF_ADDRESS and VIF_MASK in any case (except for none).")
parser.add_option(
    "-A", "--vif-address", dest="vif_addr", metavar="VIF_ADDRESS",
    default = None,
    help = 
        "See mode. This specifies the VIF_ADDRESS, "
        "the IP address of the virtual interface.")
parser.add_option(
    "-M", "--vif-mask", dest="vif_mask", type="int", metavar="VIF_MASK", 
    default = None,
    help = 
        "See mode. This specifies the VIF_MASK, "
        "a number indicating the network type (e.g. 24 for a C-class network).")
parser.add_option(
    "-S", "--vif-snat", dest="vif_snat", 
    action = "store_true",
    default = False,
    help = "See mode. This specifies whether SNAT will be enabled for the virtual interface. " )
parser.add_option(
    "-P", "--vif-pointopoint", dest="vif_pointopoint",  metavar="DST_ADDR",
    default = None,
    help = 
        "See mode. This specifies the remote endpoint's virtual address, "
        "for point-to-point routing configuration. "
        "Not supported by PlanetLab" )
parser.add_option(
    "-Q", "--vif-txqueuelen", dest="vif_txqueuelen", metavar="SIZE", type="int",
    default = None,
    help = 
        "See mode. This specifies the interface's transmission queue length. " )
parser.add_option(
    "-u", "--udp", dest="udp", metavar="PORT", type="int",
    default = None,
    help = 
        "Bind to the specified UDP port locally, and send UDP datagrams to the "
        "remote endpoint, creating a tunnel through UDP rather than TCP." )
parser.add_option(
    "-k", "--key", dest="cipher_key", metavar="KEY",
    default = None,
    help = 
        "Specify a symmetric encryption key with which to protect packets across "
        "the tunnel. python-crypto must be installed on the system." )

(options, remaining_args) = parser.parse_args(sys.argv[1:])

# Every mode except "none" configures the interface from these two values
# (vif_start / pl_vif_start). Fail early with a usage error, instead of
# failing only after the device has already been allocated.
if options.mode != "none" and (options.vif_addr is None or options.vif_mask is None):
    parser.error("--vif-address and --vif-mask are required in mode %s" % (options.mode,))
96
97
# Linux TUN/TAP, ethernet and ioctl constants (values from the kernel headers).
ETH_P_ALL = 0x00000003
ETH_P_IP = 0x00000800
TUNSETIFF = 0x400454ca
IFF_NO_PI = 0x00001000
IFF_TAP = 0x00000002
IFF_TUN = 0x00000001
IFF_VNET_HDR = 0x00004000
TUN_PKT_STRIP = 0x00000001
IFHWADDRLEN = 0x00000006
IFNAMSIZ = 0x00000010
# Size the ifreq argument is padded to. This is sizeof(struct ifreq) on 64-bit
# Linux; a buffer longer than the kernel's struct is harmless.
IFREQ_SZ = 0x00000028
FIONREAD = 0x0000541b

def ifnam(x):
    """Return interface name *x* NUL-padded to IFNAMSIZ characters.

    Raises ValueError if the name does not fit. At most IFNAMSIZ-1 characters
    are allowed, leaving room for the terminating NUL.
    """
    if len(x) >= IFNAMSIZ:
        raise ValueError("interface name too long (max %d chars): %r"
                         % (IFNAMSIZ - 1, x))
    return x + '\x00' * (IFNAMSIZ - len(x))

def ifreq(iface, flags):
    """Build a struct ifreq buffer carrying *iface* and the short *flags*.

    Layout: char[IFNAMSIZ] name, unsigned short flags (native byte order),
    then zero padding up to IFREQ_SZ bytes in total.
    Raises ValueError if *iface* is too long (see ifnam).
    """
    # ifnam validates the length; encode() gives struct the byte string it
    # needs on Python 3 and is a no-op for ASCII str on Python 2.
    name = ifnam(iface).encode("ascii")
    layout = "%dsH%dx" % (IFNAMSIZ, IFREQ_SZ - IFNAMSIZ - 2)
    return struct.pack(layout, name, flags)
120     ifreq += '\x00' * (len(ifreq)-IFREQ_SZ)
121     return ifreq
122
123 def tunopen(tun_path, tun_name):
124     if tun_path.isdigit():
125         # open TUN fd
126         print >>sys.stderr, "Using tun:", tun_name, "fd", tun_path
127         tun = os.fdopen(int(tun_path), 'r+b', 0)
128     else:
129         # open TUN path
130         print >>sys.stderr, "Using tun:", tun_name, "at", tun_path
131         tun = open(tun_path, 'r+b', 0)
132
133         # bind file descriptor to the interface
134         fcntl.ioctl(tun.fileno(), TUNSETIFF, ifreq(tun_name, IFF_NO_PI|IFF_TUN))
135     
136     return tun
137
138 def tunclose(tun_path, tun_name, tun):
139     if tun_path.isdigit():
140         # close TUN fd
141         os.close(int(tun_path))
142         tun.close()
143     else:
144         # close TUN object
145         tun.close()
146
147 def tuntap_alloc(kind, tun_path, tun_name):
148     args = ["tunctl"]
149     if kind == "tun":
150         args.append("-n")
151     if tun_name:
152         args.append("-t")
153         args.append(tun_name)
154     proc = subprocess.Popen(args, stdout=subprocess.PIPE)
155     out,err = proc.communicate()
156     if proc.wait():
157         raise RuntimeError, "Could not allocate %s device" % (kind,)
158         
159     match = re.search(r"Set '(?P<dev>(?:tun|tap)[0-9]*)' persistent and owned by .*", out, re.I)
160     if not match:
161         raise RuntimeError, "Could not allocate %s device - tunctl said: %s" % (kind, out)
162     
163     tun_name = match.group("dev")
164     print >>sys.stderr, "Allocated %s device: %s" % (kind, tun_name)
165     
166     return tun_path, tun_name
167
168 def tuntap_dealloc(tun_path, tun_name):
169     args = ["tunctl", "-d", tun_name]
170     proc = subprocess.Popen(args, stdout=subprocess.PIPE)
171     out,err = proc.communicate()
172     if proc.wait():
173         print >> sys.stderr, "WARNING: error deallocating %s device" % (tun_name,)
174
175 def nmask_to_dot_notation(mask):
176     mask = hex(((1 << mask) - 1) << (32 - mask)) # 24 -> 0xFFFFFF00
177     mask = mask[2:] # strip 0x
178     mask = mask.decode("hex") # to bytes
179     mask = '.'.join(map(str,map(ord,mask))) # to 255.255.255.0
180     return mask
181
182 def vif_start(tun_path, tun_name):
183     args = ["ifconfig", tun_name, options.vif_addr, 
184             "netmask", nmask_to_dot_notation(options.vif_mask),
185             "-arp" ]
186     if options.vif_pointopoint:
187         args.extend(["pointopoint",options.vif_pointopoint])
188     if options.vif_txqueuelen is not None:
189         args.extend(["txqueuelen",str(options.vif_txqueuelen)])
190     args.append("up")
191     proc = subprocess.Popen(args, stdout=subprocess.PIPE)
192     out,err = proc.communicate()
193     if proc.wait():
194         raise RuntimeError, "Error starting virtual interface"
195     
196     if options.vif_snat:
197         # set up SNAT using iptables
198         # TODO: stop vif on error. 
199         #   Not so necessary since deallocating the tun/tap device
200         #   will forcibly stop it, but it would be tidier
201         args = [ "iptables", "-t", "nat", "-A", "POSTROUTING", 
202                  "-s", "%s/%d" % (options.vif_addr, options.vif_mask),
203                  "-j", "SNAT",
204                  "--to-source", hostaddr, "--random" ]
205         proc = subprocess.Popen(args, stdout=subprocess.PIPE)
206         out,err = proc.communicate()
207         if proc.wait():
208             raise RuntimeError, "Error setting up SNAT"
209
210 def vif_stop(tun_path, tun_name):
211     if options.vif_snat:
212         # set up SNAT using iptables
213         args = [ "iptables", "-t", "nat", "-D", "POSTROUTING", 
214                  "-s", "%s/%d" % (options.vif_addr, options.vif_mask),
215                  "-j", "SNAT",
216                  "--to-source", hostaddr, "--random" ]
217         proc = subprocess.Popen(args, stdout=subprocess.PIPE)
218         out,err = proc.communicate()
219     
220     args = ["ifconfig", tun_name, "down"]
221     proc = subprocess.Popen(args, stdout=subprocess.PIPE)
222     out,err = proc.communicate()
223     if proc.wait():
224         print >>sys.stderr, "WARNING: error stopping virtual interface"
225     
226     
227 def pl_tuntap_alloc(kind, tun_path, tun_name):
228     tunalloc_so = ctypes.cdll.LoadLibrary("./tunalloc.so")
229     c_tun_name = ctypes.c_char_p("\x00"*IFNAMSIZ) # the string will be mutated!
230     kind = {"tun":IFF_TUN,
231             "tap":IFF_TAP}[kind]
232     fd = tunalloc_so.tun_alloc(kind, c_tun_name)
233     name = c_tun_name.value
234     return str(fd), name
235
236 def pl_vif_start(tun_path, tun_name):
237     stdin = open("/vsys/vif_up.in","w")
238     stdout = open("/vsys/vif_up.out","r")
239     stdin.write(tun_name+"\n")
240     stdin.write(options.vif_addr+"\n")
241     stdin.write(str(options.vif_mask)+"\n")
242     if options.vif_snat:
243         stdin.write("snat=1\n")
244     if options.vif_txqueuelen is not None:
245         stdin.write("txqueuelen=%d\n" % (options.vif_txqueuelen,))
246     stdin.close()
247     out = stdout.read()
248     stdout.close()
249     if out.strip():
250         print >>sys.stderr, out
251
252
253 def ipfmt(ip):
254     ipbytes = map(ord,ip.decode("hex"))
255     return '.'.join(map(str,ipbytes))
256
257 tagtype = {
258     '0806' : 'arp ',
259     '0800' : 'ipv4 ',
260     '8870' : 'jumbo ',
261     '8863' : 'PPPoE discover ',
262     '8864' : 'PPPoE ',
263 }
264 def etherProto(packet):
265     packet = packet.encode("hex")
266     if len(packet) > 14:
267         if packet[12:14] == "\x81\x00":
268             # tagged
269             return packet[16:18]
270         else:
271             # untagged
272             return packet[12:14]
273     # default: ip
274     return "\x08\x00"
275 def formatPacket(packet, ether_mode):
276     if ether_mode:
277         stripped_packet = etherStrip(packet)
278         if not stripped_packet:
279             packet = packet.encode("hex")
280             if len(packet) < 28:
281                 return "malformed eth " + packet.encode("hex")
282             else:
283                 if packet[24:28] == "8100":
284                     # tagged
285                     ethertype = tagtype.get(packet[32:36], 'eth')
286                     return ethertype + " " + ( '-'.join( (
287                         packet[0:12], # MAC dest
288                         packet[12:24], # MAC src
289                         packet[24:32], # VLAN tag
290                         packet[32:36], # Ethertype/len
291                         packet[36:], # Payload
292                     ) ) )
293                 else:
294                     # untagged
295                     ethertype = tagtype.get(packet[24:28], 'eth')
296                     return ethertype + " " + ( '-'.join( (
297                         packet[0:12], # MAC dest
298                         packet[12:24], # MAC src
299                         packet[24:28], # Ethertype/len
300                         packet[28:], # Payload
301                     ) ) )
302         else:
303             packet = stripped_packet
304     packet = packet.encode("hex")
305     if len(packet) < 48:
306         return "malformed ip " + packet
307     else:
308         return "ip " + ( '-'.join( (
309             packet[0:1], #version
310             packet[1:2], #header length
311             packet[2:4], #diffserv/ECN
312             packet[4:8], #total length
313             packet[8:12], #ident
314             packet[12:16], #flags/fragment offs
315             packet[16:18], #ttl
316             packet[18:20], #ip-proto
317             packet[20:24], #checksum
318             ipfmt(packet[24:32]), # src-ip
319             ipfmt(packet[32:40]), # dst-ip
320             packet[40:48] if (int(packet[1],16) > 5) else "", # options
321             packet[48:] if (int(packet[1],16) > 5) else packet[40:], # payload
322         ) ) )
323
324 def packetReady(buf, ether_mode):
325     if len(buf) < 4:
326         return False
327     elif ether_mode:
328         return True
329     else:
330         _,totallen = struct.unpack('HH',buf[:4])
331         totallen = socket.htons(totallen)
332         return len(buf) >= totallen
333
334 def pullPacket(buf, ether_mode):
335     if ether_mode:
336         return buf, ""
337     else:
338         _,totallen = struct.unpack('HH',buf[:4])
339         totallen = socket.htons(totallen)
340         return buf[:totallen], buf[totallen:]
341
342 def etherStrip(buf):
343     if len(buf) < 14:
344         return ""
345     if buf[12:14] == '\x08\x10' and buf[16:18] == '\x08\x00':
346         # tagged ethernet frame
347         return buf[18:-4]
348     elif buf[12:14] == '\x08\x00':
349         # untagged ethernet frame
350         return buf[14:-4]
351     else:
352         return ""
353
354 def etherWrap(packet):
355     return (
356         "\x00"*6*2 # bogus src and dst mac
357         +"\x08\x00" # IPv4
358         +packet # payload
359         +"\x00"*4 # bogus crc
360     )
361
362 def piStrip(buf):
363     if len(buf) < 4:
364         return buf
365     else:
366         return buf[4:]
367     
368 def piWrap(buf, ether_mode):
369     if ether_mode:
370         proto = etherProto(buf)
371     else:
372         proto = "\x08\x00"
373     return (
374         "\x00\x00" # PI: 16 bits flags
375         +proto # 16 bits proto
376         +buf
377     )
378
379 def encrypt(packet, crypter):
380     # pad
381     padding = crypter.block_size - len(packet) % crypter.block_size
382     packet += chr(padding) * padding
383     
384     # encrypt
385     return crypter.encrypt(packet)
386
387 def decrypt(packet, crypter):
388     # decrypt
389     packet = crypter.decrypt(packet)
390     
391     # un-pad
392     padding = ord(packet[-1])
393     if not (0 < padding <= crypter.block_size):
394         # wrong padding
395         raise RuntimeError, "Truncated packet"
396     packet = packet[:-padding]
397     
398     return packet
399
400 abortme = False
401 def tun_fwd(tun, remote):
402     global abortme
403     
404     # in PL mode, we cannot strip PI structs
405     # so we'll have to handle them
406     with_pi = options.mode.startswith('pl-')
407     ether_mode = tun_name.startswith('tap')
408     
409     crypto_mode = False
410     try:
411         if options.cipher_key:
412             import Crypto.Cipher.AES
413             import hashlib
414             
415             hashed_key = hashlib.sha256(options.cipher_key).digest()
416             crypter = Crypto.Cipher.AES.new(
417                 hashed_key, 
418                 Crypto.Cipher.AES.MODE_ECB)
419             crypto_mode = True
420     except:
421         import traceback
422         traceback.print_exc()
423         crypto_mode = False
424         crypter = None
425
426     if crypto_mode:
427         print >>sys.stderr, "Packets are transmitted in CIPHER"
428     else:
429         print >>sys.stderr, "Packets are transmitted in PLAINTEXT"
430     
431     # Limited frame parsing, to preserve packet boundaries.
432     # Which is needed, since /dev/net/tun is unbuffered
433     fwbuf = ""
434     bkbuf = ""
435     while not abortme:
436         wset = []
437         if packetReady(bkbuf, ether_mode):
438             wset.append(tun)
439         if packetReady(fwbuf, ether_mode):
440             wset.append(remote)
441         rdrdy, wrdy, errs = select.select((tun,remote),wset,(tun,remote),1)
442         
443         # check for errors
444         if errs:
445             break
446         
447         # check to see if we can write
448         if remote in wrdy and packetReady(fwbuf, ether_mode):
449             packet, fwbuf = pullPacket(fwbuf, ether_mode)
450             try:
451                 if crypto_mode:
452                     enpacket = encrypt(packet, crypter)
453                 else:
454                     enpacket = packet
455                 os.write(remote.fileno(), enpacket)
456             except:
457                 if not options.udp:
458                     # in UDP mode, we ignore errors - packet loss man...
459                     raise
460             print >>sys.stderr, '>', formatPacket(packet, ether_mode)
461         if tun in wrdy and packetReady(bkbuf, ether_mode):
462             packet, bkbuf = pullPacket(bkbuf, ether_mode)
463             formatted = formatPacket(packet, ether_mode)
464             if with_pi:
465                 packet = piWrap(packet, ether_mode)
466             os.write(tun.fileno(), packet)
467             print >>sys.stderr, '<', formatted
468         
469         # check incoming data packets
470         if tun in rdrdy:
471             packet = os.read(tun.fileno(),2000) # tun.read blocks until it gets 2k!
472             if with_pi:
473                 packet = piStrip(packet)
474             fwbuf += packet
475         if remote in rdrdy:
476             try:
477                 packet = os.read(remote.fileno(),2000) # remote.read blocks until it gets 2k!
478                 if crypto_mode:
479                     packet = decrypt(packet, crypter)
480             except:
481                 if not options.udp:
482                     # in UDP mode, we ignore errors - packet loss man...
483                     raise
484             bkbuf += packet
485
486
487
488 nop = lambda tun_path, tun_name : (tun_path, tun_name)
489 MODEINFO = {
490     'none' : dict(alloc=nop,
491                   tunopen=tunopen, tunclose=tunclose,
492                   dealloc=nop,
493                   start=nop,
494                   stop=nop),
495     'tun'  : dict(alloc=functools.partial(tuntap_alloc, "tun"),
496                   tunopen=tunopen, tunclose=tunclose,
497                   dealloc=tuntap_dealloc,
498                   start=vif_start,
499                   stop=vif_stop),
500     'tap'  : dict(alloc=functools.partial(tuntap_alloc, "tap"),
501                   tunopen=tunopen, tunclose=tunclose,
502                   dealloc=tuntap_dealloc,
503                   start=vif_start,
504                   stop=vif_stop),
505     'pl-tun'  : dict(alloc=functools.partial(pl_tuntap_alloc, "tun"),
506                   tunopen=tunopen, tunclose=tunclose,
507                   dealloc=nop,
508                   start=pl_vif_start,
509                   stop=nop),
510     'pl-tap'  : dict(alloc=functools.partial(pl_tuntap_alloc, "tap"),
511                   tunopen=tunopen, tunclose=tunclose,
512                   dealloc=nop,
513                   start=pl_vif_start,
514                   stop=nop),
515 }
516     
517 tun_path = options.tun_path
518 tun_name = options.tun_name
519
520 modeinfo = MODEINFO[options.mode]
521
522 # be careful to roll back stuff on exceptions
523 tun_path, tun_name = modeinfo['alloc'](tun_path, tun_name)
524 try:
525     modeinfo['start'](tun_path, tun_name)
526     try:
527         tun = modeinfo['tunopen'](tun_path, tun_name)
528     except:
529         modeinfo['stop'](tun_path, tun_name)
530         raise
531 except:
532     modeinfo['dealloc'](tun_path, tun_name)
533     raise
534
535
536 try:
537     if options.pass_fd:
538         # send FD to whoever wants it
539         import passfd
540         
541         sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
542         sock.connect(options.pass_fd)
543         passfd.sendfd(sock, tun.fileno(), '0')
544         
545         # just wait forever
546         def tun_fwd(tun, remote):
547             while True:
548                 time.sleep(1)
549     elif options.udp:
550         # connect to remote endpoint
551         if remaining_args and not remaining_args[0].startswith('-'):
552             print >>sys.stderr, "Listening at: %s:%d" % (hostaddr,options.udp)
553             print >>sys.stderr, "Connecting to: %s:%d" % (remaining_args[0],options.port)
554             rsock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
555             rsock.bind((hostaddr,options.udp))
556             rsock.connect((remaining_args[0],options.port))
557         else:
558             print >>sys.stderr, "Error: need a remote endpoint in UDP mode"
559             raise AssertionError, "Error: need a remote endpoint in UDP mode"
560         remote = os.fdopen(rsock.fileno(), 'r+b', 0)
561     else:
562         # connect to remote endpoint
563         if remaining_args and not remaining_args[0].startswith('-'):
564             print >>sys.stderr, "Connecting to: %s:%d" % (remaining_args[0],options.port)
565             rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
566             rsock.connect((remaining_args[0],options.port))
567         else:
568             print >>sys.stderr, "Listening at: %s:%d" % (hostaddr,options.port)
569             lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
570             lsock.bind((hostaddr,options.port))
571             lsock.listen(1)
572             rsock,raddr = lsock.accept()
573         remote = os.fdopen(rsock.fileno(), 'r+b', 0)
574
575     print >>sys.stderr, "Connected"
576
577     tun_fwd(tun, remote)
578 finally:
579     try:
580         print >>sys.stderr, "Shutting down..."
581     except:
582         # In case sys.stderr is broken
583         pass
584     
585     # tidy shutdown in every case - swallow exceptions
586     try:
587         modeinfo['tunclose'](tun_path, tun_name, tun)
588     except:
589         pass
590         
591     try:
592         modeinfo['stop'](tun_path, tun_name)
593     except:
594         pass
595
596     try:
597         modeinfo['dealloc'](tun_path, tun_name)
598     except:
599         pass
600
601