9918fc31c3726d630a4999181c44c88e0b86ba80
[nepi.git] / src / nepi / testbeds / planetlab / scripts / tun_connect.py
1 import sys
2
3 import socket
4 import fcntl
5 import os
6 import os.path
7 import select
8
9 import struct
10 import ctypes
11 import optparse
12 import threading
13 import subprocess
14 import re
15 import functools
16 import time
17 import base64
18
# Fallback device path and name; both names are rebound from the parsed
# command line options further down in this script.
tun_path = '/dev/net/tun'
tun_name = 'tun0'

# Address the local host name resolves to. Used as the bind address for the
# tunnel sockets and as the SNAT source address.
_local_hostname = socket.gethostname()
hostaddr = socket.gethostbyname(_local_hostname)
22
# Command line definition. The only positional argument is the remote
# endpoint; when it is missing in TCP mode the script listens instead of
# connecting (see the connection code at the end of the file).
usage = "usage: %prog [options] <remote-endpoint>"

parser = optparse.OptionParser(usage=usage)

# --- device selection -------------------------------------------------------
parser.add_option(
    "-i", "--iface", dest="tun_name", metavar="DEVICE",
    default = "tun0",
    help = "TUN/TAP interface to tap into")
parser.add_option(
    "-d", "--tun-path", dest="tun_path", metavar="PATH",
    default = "/dev/net/tun",
    help = "TUN/TAP device file path or file descriptor number")
# --- transport ----------------------------------------------------------------
parser.add_option(
    "-p", "--port", dest="port", metavar="PORT", type="int",
    default = 15000,
    help = "Peering TCP port to connect or listen to.")
parser.add_option(
    "--pass-fd", dest="pass_fd", metavar="UNIX_SOCKET",
    default = None,
    help = "Path to a unix-domain socket to pass the TUN file descriptor to. "
           "If given, all other connectivity options are ignored, tun_connect will "
           "simply wait to be killed after passing the file descriptor, and it will be "
           "the receiver's responsability to handle the tunneling.")

# --- virtual interface set-up ---------------------------------------------------
# NOTE: the mode value is used verbatim as a key into MODEINFO; it is not
# validated here.
parser.add_option(
    "-m", "--mode", dest="mode", metavar="MODE",
    default = "none",
    help = 
        "Set mode. One of none, tun, tap, pl-tun, pl-tap. In any mode except none, a TUN/TAP will be created "
        "by using the proper interface (tunctl for tun/tap, /vsys/fd_tuntap.control for pl-tun/pl-tap), "
        "and it will be brought up (with ifconfig for tun/tap, with /vsys/vif_up for pl-tun/pl-tap). You have "
        "to specify an VIF_ADDRESS and VIF_MASK in any case (except for none).")
parser.add_option(
    "-A", "--vif-address", dest="vif_addr", metavar="VIF_ADDRESS",
    default = None,
    help = 
        "See mode. This specifies the VIF_ADDRESS, "
        "the IP address of the virtual interface.")
parser.add_option(
    "-M", "--vif-mask", dest="vif_mask", type="int", metavar="VIF_MASK", 
    default = None,
    help = 
        "See mode. This specifies the VIF_MASK, "
        "a number indicating the network type (ie: 24 for a C-class network).")
parser.add_option(
    "-S", "--vif-snat", dest="vif_snat", 
    action = "store_true",
    default = False,
    help = "See mode. This specifies whether SNAT will be enabled for the virtual interface. " )
parser.add_option(
    "-P", "--vif-pointopoint", dest="vif_pointopoint",  metavar="DST_ADDR",
    default = None,
    help = 
        "See mode. This specifies the remote endpoint's virtual address, "
        "for point-to-point routing configuration. "
        "Not supported by PlanetLab" )
parser.add_option(
    "-Q", "--vif-txqueuelen", dest="vif_txqueuelen", metavar="SIZE", type="int",
    default = None,
    help = 
        "See mode. This specifies the interface's transmission queue length. " )
# --- UDP transport and encryption -------------------------------------------------
parser.add_option(
    "-u", "--udp", dest="udp", metavar="PORT", type="int",
    default = None,
    help = 
        "Bind to the specified UDP port locally, and send UDP datagrams to the "
        "remote endpoint, creating a tunnel through UDP rather than TCP." )
parser.add_option(
    "-k", "--key", dest="cipher_key", metavar="KEY",
    default = None,
    help = 
        "Specify a symmetric encryption key with which to protect packets across "
        "the tunnel. python-crypto must be installed on the system." )

# Parsing happens at import time; `options` and `remaining_args` are read as
# module globals by the functions below.
(options, remaining_args) = parser.parse_args(sys.argv[1:])
98
99
# Numeric constants for talking to the TUN/TAP driver.
# NOTE(review): values presumably mirror the Linux kernel headers
# (if_ether.h, if_tun.h, if.h) - confirm against the target kernel.

# Ethernet protocol identifiers
ETH_P_ALL = 0x0003
ETH_P_IP = 0x0800

# ioctl request that binds an open tun file descriptor to an interface
TUNSETIFF = 0x400454ca

# Interface flags carried in the ifreq structure
IFF_TUN = 0x0001
IFF_TAP = 0x0002
IFF_NO_PI = 0x1000
IFF_VNET_HDR = 0x4000
TUN_PKT_STRIP = 0x0001

# Sizes, in bytes, of a hardware address, an interface name field and the
# whole ifreq structure as built by ifreq()
IFHWADDRLEN = 6
IFNAMSIZ = 16
IFREQ_SZ = 40

FIONREAD = 0x541b
112
def ifnam(x):
    """Return interface name x right-padded with NUL bytes to IFNAMSIZ.

    Names that are already IFNAMSIZ characters or longer come back unchanged.
    """
    return x.ljust(IFNAMSIZ, '\x00')
115
def ifreq(iface, flags):
    """Build the raw ifreq buffer handed to the TUNSETIFF ioctl.

    Layout produced:
      char[IFNAMSIZ] : interface name, NUL padded
      short          : flags (native byte order)
      <padding>      : NUL bytes up to IFREQ_SZ

    iface -- interface name
    flags -- integer flag word (e.g. IFF_TUN|IFF_NO_PI)

    Returns a string that is IFREQ_SZ bytes long whenever the name fits in
    IFNAMSIZ characters.
    """
    request = ifnam(iface) + struct.pack("H", flags)
    # Pad up to the full structure size. The operands used to be swapped,
    # which gave a negative repeat count and therefore no padding at all.
    request += '\x00' * (IFREQ_SZ - len(request))
    return request
124
def tunopen(tun_path, tun_name):
    """Return an unbuffered read/write file object for the tunnel device.

    tun_path -- either a decimal file descriptor number (as a string), which
                is wrapped as is, or the path of the tun control device,
                which is opened and bound to tun_name.
    tun_name -- interface name, used for the binding and for logging.
    """
    if tun_path.isdigit():
        # An already allocated descriptor was handed to us: just wrap it
        print >>sys.stderr, "Using tun:", tun_name, "fd", tun_path
        return os.fdopen(int(tun_path), 'r+b', 0)

    print >>sys.stderr, "Using tun:", tun_name, "at", tun_path
    device = open(tun_path, 'r+b', 0)

    # Attach the freshly opened descriptor to the named interface.
    # NOTE(review): IFF_TUN is requested even when tun_name names a tap
    # device - confirm this is intended.
    request = ifreq(tun_name, IFF_NO_PI|IFF_TUN)
    fcntl.ioctl(device.fileno(), TUNSETIFF, request)
    return device
139
def tunclose(tun_path, tun_name, tun):
    """Close the tunnel file object returned by tunopen().

    tun_path -- device path or decimal file descriptor string (unused, kept
                for interface compatibility with the MODEINFO table)
    tun_name -- interface name (unused)
    tun      -- file object to close

    The file object owns its descriptor in both cases (it was created with
    open() or os.fdopen()), so closing it is enough. Closing the raw
    descriptor first, as was done before for the fd case, made the following
    tun.close() fail with EBADF and risked closing an unrelated descriptor
    that had reused the same number.
    """
    tun.close()
148
def tuntap_alloc(kind, tun_path, tun_name):
    """Allocate a persistent tun/tap device by running tunctl.

    kind     -- "tun" or "tap"; "tun" adds the -n switch to tunctl
    tun_path -- returned untouched
    tun_name -- requested device name; when empty tunctl picks one

    Returns (tun_path, allocated_name), the name being parsed from tunctl's
    output. Raises RuntimeError when tunctl fails or its output is not
    understood.
    """
    cmd = ["tunctl"]
    if kind == "tun":
        cmd += ["-n"]
    if tun_name:
        cmd += ["-t", tun_name]

    tunctl = subprocess.Popen(cmd, stdout=subprocess.PIPE)
    output = tunctl.communicate()[0]
    if tunctl.wait() != 0:
        raise RuntimeError("Could not allocate %s device" % (kind,))

    found = re.search(r"Set '(?P<dev>(?:tun|tap)[0-9]*)' persistent and owned by .*", output, re.I)
    if found is None:
        raise RuntimeError("Could not allocate %s device - tunctl said: %s" % (kind, output))

    allocated = found.group("dev")
    print >>sys.stderr, "Allocated %s device: %s" % (kind, allocated)

    return tun_path, allocated
169
def tuntap_dealloc(tun_path, tun_name):
    """Delete device tun_name with `tunctl -d`.

    Failure is only reported as a warning on stderr; nothing is raised for a
    non-zero exit status.
    """
    tunctl = subprocess.Popen(["tunctl", "-d", tun_name], stdout=subprocess.PIPE)
    tunctl.communicate()
    if tunctl.wait() != 0:
        print >> sys.stderr, "WARNING: error deallocating %s device" % (tun_name,)
176
def nmask_to_dot_notation(mask):
    """Convert a prefix length into a dotted-quad netmask.

    mask -- number of leading one bits, 0 to 32 (24 -> "255.255.255.0")

    Raises ValueError when mask is outside 0..32.

    The previous hex()/decode("hex") round trip broke for mask 0 (odd-length
    hex string) and whenever hex() appended the long-integer "L" suffix.
    """
    if not 0 <= mask <= 32:
        raise ValueError("netmask length out of range: %r" % (mask,))
    bits = (0xFFFFFFFF << (32 - mask)) & 0xFFFFFFFF # 24 -> 0xFFFFFF00
    return socket.inet_ntoa(struct.pack("!I", bits))
183
def vif_start(tun_path, tun_name):
    """Bring interface tun_name up with ifconfig, optionally adding SNAT.

    Address, netmask, point-to-point peer, queue length and the SNAT switch
    all come from the global command line options. Raises RuntimeError when
    ifconfig or iptables exits with a non-zero status.
    """
    netmask = nmask_to_dot_notation(options.vif_mask)
    ifconfig_cmd = ["ifconfig", tun_name, options.vif_addr,
                    "netmask", netmask,
                    "-arp"]
    if options.vif_pointopoint:
        ifconfig_cmd += ["pointopoint", options.vif_pointopoint]
    if options.vif_txqueuelen is not None:
        ifconfig_cmd += ["txqueuelen", str(options.vif_txqueuelen)]
    ifconfig_cmd += ["up"]

    ifconfig = subprocess.Popen(ifconfig_cmd, stdout=subprocess.PIPE)
    ifconfig.communicate()
    if ifconfig.wait() != 0:
        raise RuntimeError("Error starting virtual interface")

    if not options.vif_snat:
        return

    # Source-NAT everything leaving the virtual network to the host address.
    # TODO: stop vif on error.
    #   Not strictly needed, since deallocating the tun/tap device
    #   forcibly stops it, but it would be tidier
    source_net = "%s/%d" % (options.vif_addr, options.vif_mask)
    iptables_cmd = ["iptables", "-t", "nat", "-A", "POSTROUTING",
                    "-s", source_net,
                    "-j", "SNAT",
                    "--to-source", hostaddr, "--random"]
    iptables = subprocess.Popen(iptables_cmd, stdout=subprocess.PIPE)
    iptables.communicate()
    if iptables.wait() != 0:
        raise RuntimeError("Error setting up SNAT")
211
def vif_stop(tun_path, tun_name):
    """Undo vif_start(): drop the SNAT rule (if enabled) and down the interface.

    Best effort: the iptables exit status is ignored and an ifconfig failure
    only produces a warning on stderr.
    """
    if options.vif_snat:
        # remove the SNAT rule added by vif_start
        source_net = "%s/%d" % (options.vif_addr, options.vif_mask)
        iptables_cmd = ["iptables", "-t", "nat", "-D", "POSTROUTING",
                        "-s", source_net,
                        "-j", "SNAT",
                        "--to-source", hostaddr, "--random"]
        iptables = subprocess.Popen(iptables_cmd, stdout=subprocess.PIPE)
        iptables.communicate()

    ifconfig = subprocess.Popen(["ifconfig", tun_name, "down"], stdout=subprocess.PIPE)
    ifconfig.communicate()
    if ifconfig.wait() != 0:
        print >>sys.stderr, "WARNING: error stopping virtual interface"
227     
228     
def pl_tuntap_alloc(kind, tun_path, tun_name):
    """Allocate a tun/tap device through the tunalloc.so helper library.

    kind     -- "tun" or "tap"; any other value raises KeyError
    tun_path -- ignored
    tun_name -- ignored; the helper reports the name it allocated

    Returns (fd, name) where fd is the value returned by tun_alloc() converted
    to a string (so tunopen() treats it as a descriptor number) and name is
    the interface name the helper wrote back.
    """
    tunalloc_so = ctypes.cdll.LoadLibrary("./tunalloc.so")
    # The callee writes the interface name into this buffer, so it has to be
    # real mutable memory: a Python string passed through c_char_p must never
    # be modified from C.
    c_tun_name = ctypes.create_string_buffer(IFNAMSIZ)
    kind = {"tun":IFF_TUN,
            "tap":IFF_TAP}[kind]
    fd = tunalloc_so.tun_alloc(kind, c_tun_name)
    name = c_tun_name.value
    return str(fd), name
237
def pl_vif_start(tun_path, tun_name):
    """Configure and bring up tun_name through the /vsys/vif_up pipe pair.

    The request is written to /vsys/vif_up.in, one value per line: interface
    name, address, mask, then the optional "snat=1" and "txqueuelen=N"
    settings taken from the command line options. Whatever comes back on
    /vsys/vif_up.out is echoed to stderr when it is not blank.
    """
    # Both ends are opened before anything is written.
    stdin = open("/vsys/vif_up.in","w")
    stdout = open("/vsys/vif_up.out","r")
    stdin.write(tun_name+"\n")
    stdin.write(options.vif_addr+"\n")
    stdin.write(str(options.vif_mask)+"\n")
    if options.vif_snat:
        stdin.write("snat=1\n")
    if options.vif_txqueuelen is not None:
        stdin.write("txqueuelen=%d\n" % (options.vif_txqueuelen,))
    # The request side is closed before the reply is read.
    stdin.close()
    out = stdout.read()
    stdout.close()
    if out.strip():
        print >>sys.stderr, out
253
254
255 def ipfmt(ip):
256     ipbytes = map(ord,ip.decode("hex"))
257     return '.'.join(map(str,ipbytes))
258
# Human readable names for the ethertypes formatPacket() knows about,
# keyed by the hex-encoded ethertype field.
tagtype = dict([
    ('0800', 'ipv4'),
    ('0806', 'arp'),
    ('86dd', 'ipv6'),
    ('8863', 'PPPoE discover'),
    ('8864', 'PPPoE'),
    ('8870', 'jumbo'),
])
def etherProto(packet):
    """Return the 2-byte ethertype of a raw ethernet frame.

    packet -- raw frame bytes

    For a VLAN tagged frame (TPID 0x8100 at offset 12) the bytes at offset
    16-18 are returned, otherwise the bytes at offset 12-14. Frames of 14
    bytes or fewer yield the IPv4 ethertype as a default.

    The frame must be inspected in its raw form: it used to be hex-encoded
    first, so the comparison against the raw TPID never matched and hex
    characters were returned where piWrap() needs two raw bytes.
    """
    if len(packet) > 14:
        if packet[12:14] == "\x81\x00":
            # tagged
            return packet[16:18]
        else:
            # untagged
            return packet[12:14]
    # default: ip
    return "\x08\x00"
def formatPacket(packet, ether_mode):
    """Return a one-line, hex based description of a packet for logging.

    packet     -- raw packet bytes
    ether_mode -- True when packet is an ethernet frame rather than a bare
                  IP packet

    In ether mode, frames that etherStrip() does not recognise as IPv4 are
    dumped as "<type> dst-src-[vlan-]ethertype-payload"; everything else is
    dumped field by field as an IP header followed by the payload. Inputs
    that are too short are reported as "malformed eth"/"malformed ip".
    """
    if ether_mode:
        stripped_packet = etherStrip(packet)
        if not stripped_packet:
            # Not IPv4 over ethernet: describe the frame itself.
            # All offsets below are in hex digits (two per byte).
            frame = packet.encode("hex")
            if len(frame) < 28:
                # frame is already hex; it used to be encoded a second time
                return "malformed eth " + frame
            if frame[24:28] == "8100":
                # tagged
                ethertype = tagtype.get(frame[32:36], 'eth')
                return ethertype + " " + ( '-'.join( (
                    frame[0:12], # MAC dest
                    frame[12:24], # MAC src
                    frame[24:32], # VLAN tag
                    frame[32:36], # Ethertype/len
                    frame[36:], # Payload
                ) ) )
            # untagged
            ethertype = tagtype.get(frame[24:28], 'eth')
            return ethertype + " " + ( '-'.join( (
                frame[0:12], # MAC dest
                frame[12:24], # MAC src
                frame[24:28], # Ethertype/len
                frame[28:], # Payload
            ) ) )
        packet = stripped_packet

    packet = packet.encode("hex")
    if len(packet) < 48:
        return "malformed ip " + packet

    # The second hex digit is the header length in 32-bit words; more than
    # 5 means the header carries options.
    has_options = int(packet[1],16) > 5
    return "ip " + ( '-'.join( (
        packet[0:1], #version
        packet[1:2], #header length
        packet[2:4], #diffserv/ECN
        packet[4:8], #total length
        packet[8:12], #ident
        packet[12:16], #flags/fragment offs
        packet[16:18], #ttl
        packet[18:20], #ip-proto
        packet[20:24], #checksum
        ipfmt(packet[24:32]), # src-ip
        ipfmt(packet[32:40]), # dst-ip
        packet[40:48] if has_options else "", # options
        packet[48:] if has_options else packet[40:], # payload
    ) ) )
326
def packetReady(buf, ether_mode):
    """Tell whether buf holds at least one complete packet.

    Fewer than 4 bytes is never enough. In ether mode anything longer counts
    as a packet; otherwise the big-endian 16-bit field at offset 2 (the IP
    total length) says how many bytes are needed.
    """
    if len(buf) < 4:
        return False
    if ether_mode:
        return True
    # "!" reads the field in network byte order directly
    needed = struct.unpack('!HH', buf[:4])[1]
    return len(buf) >= needed
336
def pullPacket(buf, ether_mode):
    """Split buf into (first packet, remaining bytes).

    In ether mode the whole buffer is one packet. Otherwise the packet length
    is the big-endian 16-bit field at offset 2 of the buffer.
    """
    if ether_mode:
        return buf, ""
    # "!" reads the field in network byte order directly
    size = struct.unpack('!HH', buf[:4])[1]
    head, tail = buf[:size], buf[size:]
    return head, tail
344
def etherStrip(buf):
    """Return the IPv4 payload of an ethernet frame, or "" if there is none.

    buf -- raw frame bytes

    Handles untagged frames (ethertype 0x0800 at offset 12) and VLAN tagged
    frames (TPID 0x8100 at offset 12, ethertype 0x0800 at offset 16). Frames
    shorter than 14 bytes or carrying anything but IPv4 yield "".

    Fixes: the tag identifier was spelled '\\x08\\x10' instead of
    '\\x81\\x00' (the value etherProto() and formatPacket() test for), and
    the inner ethertype was tested with substring `in`, which also accepted
    partial and empty slices.
    """
    if len(buf) < 14:
        return ""
    if buf[12:14] == '\x81\x00' and buf[16:18] == '\x08\x00':
        # tagged ethernet frame
        return buf[18:]
    elif buf[12:14] == '\x08\x00':
        # untagged ethernet frame
        return buf[14:]
    else:
        return ""
356
def etherWrap(packet):
    """Wrap an IP packet in a dummy ethernet frame.

    The frame has all-zero source and destination MACs, the IPv4 ethertype
    and four zero bytes appended as a placeholder checksum.
    """
    header = "\x00" * 12 + "\x08\x00" # bogus dst+src mac, then IPv4
    trailer = "\x00" * 4 # bogus crc
    return header + packet + trailer
364
def piStrip(buf):
    """Drop the leading 4-byte PI header from buf.

    Buffers shorter than 4 bytes are returned untouched.
    """
    return buf if len(buf) < 4 else buf[4:]
370     
def piWrap(buf, ether_mode):
    """Prepend a PI header (16 bits of zero flags + 16 bits protocol) to buf.

    In ether mode the protocol comes from etherProto(buf); otherwise it is
    the IPv4 ethertype.
    """
    proto = etherProto(buf) if ether_mode else "\x08\x00"
    header = "\x00\x00" + proto
    return header + buf
381
def encrypt(packet, crypter):
    """Pad packet to a whole number of cipher blocks and encrypt it.

    Between 1 and block_size bytes are always appended, each holding the
    padding length, so decrypt() can strip them again.
    """
    block = crypter.block_size
    pad_len = block - len(packet) % block
    padded = packet + chr(pad_len) * pad_len
    return crypter.encrypt(padded)
389
def decrypt(packet, crypter):
    """Decrypt packet and strip the padding added by encrypt().

    packet  -- ciphertext
    crypter -- cipher object providing decrypt() and block_size

    Returns the plaintext without padding. Raises RuntimeError
    ("Truncated packet") when the result is empty or its last byte is not a
    valid padding length (1..block_size).
    """
    # decrypt
    packet = crypter.decrypt(packet)

    # An empty result has no padding byte to inspect; report it the same way
    # as any other truncated packet instead of failing with IndexError.
    if not packet:
        raise RuntimeError("Truncated packet")

    # un-pad
    padding = ord(packet[-1])
    if not (0 < padding <= crypter.block_size):
        # wrong padding
        raise RuntimeError("Truncated packet")
    return packet[:-padding]
402
# Setting this to True makes tun_fwd() leave its loop on the next iteration
abortme = False
def tun_fwd(tun, remote):
    """Shuttle packets between the tunnel device and the remote endpoint.

    tun    -- file-like object for the TUN/TAP device (needs fileno())
    remote -- file-like object for the peer connection (needs fileno())

    Runs until `abortme` becomes true, select() reports an error condition
    on either descriptor or, in TCP mode, the peer closes the connection.
    Behaviour is driven by the global options: PI headers are added/stripped
    in the pl-* modes, packets are encrypted when a cipher key is given (and
    python-crypto can be loaded), and I/O errors on the remote side are
    ignored in UDP mode but propagate in TCP mode.
    """
    global abortme
    
    # in PL mode, we cannot strip PI structs
    # so we'll have to handle them
    with_pi = options.mode.startswith('pl-')
    ether_mode = tun_name.startswith('tap')
    
    crypto_mode = False
    crypter = None
    try:
        if options.cipher_key:
            import Crypto.Cipher.AES
            import hashlib
            
            hashed_key = hashlib.sha256(options.cipher_key).digest()
            # NOTE(review): ECB leaks equal plaintext blocks; kept as is for
            # wire compatibility with existing peers.
            crypter = Crypto.Cipher.AES.new(
                hashed_key, 
                Crypto.Cipher.AES.MODE_ECB)
            crypto_mode = True
    except:
        # Deliberate fallback: if the cipher cannot be set up, report it and
        # continue in plaintext.
        import traceback
        traceback.print_exc()
        crypto_mode = False
        crypter = None

    if crypto_mode:
        print >>sys.stderr, "Packets are transmitted in CIPHER"
    else:
        print >>sys.stderr, "Packets are transmitted in PLAINTEXT"
    
    # Limited frame parsing, to preserve packet boundaries.
    # Which is needed, since /dev/net/tun is unbuffered
    fwbuf = "" # data read from tun, waiting to go to remote
    bkbuf = "" # data read from remote, waiting to go to tun
    while not abortme:
        # only ask for writability where a whole packet is waiting
        wset = []
        if packetReady(bkbuf, ether_mode):
            wset.append(tun)
        if packetReady(fwbuf, ether_mode):
            wset.append(remote)
        rdrdy, wrdy, errs = select.select((tun,remote),wset,(tun,remote),1)
        
        # check for errors
        if errs:
            break
        
        # check to see if we can write
        if remote in wrdy and packetReady(fwbuf, ether_mode):
            packet, fwbuf = pullPacket(fwbuf, ether_mode)
            try:
                if crypto_mode:
                    enpacket = encrypt(packet, crypter)
                else:
                    enpacket = packet
                os.write(remote.fileno(), enpacket)
            except Exception:
                # Exception rather than a bare except, so that
                # KeyboardInterrupt/SystemExit still stop the loop
                if not options.udp:
                    # in UDP mode, we ignore errors - packet loss man...
                    raise
            print >>sys.stderr, '>', formatPacket(packet, ether_mode)
        if tun in wrdy and packetReady(bkbuf, ether_mode):
            packet, bkbuf = pullPacket(bkbuf, ether_mode)
            formatted = formatPacket(packet, ether_mode)
            if with_pi:
                packet = piWrap(packet, ether_mode)
            os.write(tun.fileno(), packet)
            print >>sys.stderr, '<', formatted
        
        # check incoming data packets
        if tun in rdrdy:
            packet = os.read(tun.fileno(),2000) # tun.read blocks until it gets 2k!
            if with_pi:
                packet = piStrip(packet)
            fwbuf += packet
        if remote in rdrdy:
            try:
                incoming = os.read(remote.fileno(),2000) # remote.read blocks until it gets 2k!
                if not incoming and not options.udp:
                    # A zero-length read on a stream means the peer closed
                    # the connection; without this the loop would spin on a
                    # permanently readable socket.
                    break
                if crypto_mode:
                    incoming = decrypt(incoming, crypter)
            except Exception:
                if not options.udp:
                    # in UDP mode, we ignore errors - packet loss man...
                    raise
            else:
                # Only queue data that was actually read and decoded. This
                # used to run after a swallowed error too, re-queueing a
                # stale or still-encrypted packet.
                bkbuf += incoming
488
489
490
def nop(tun_path, tun_name):
    """Do nothing; hand both arguments back as a tuple."""
    return (tun_path, tun_name)

def _mode(alloc, dealloc, start, stop):
    """Build one MODEINFO entry; every mode shares tunopen/tunclose."""
    return dict(alloc=alloc,
                tunopen=tunopen, tunclose=tunclose,
                dealloc=dealloc,
                start=start,
                stop=stop)

# Per-mode life cycle hooks, looked up with the --mode option value.
# Each hook is called as hook(tun_path, tun_name), except tunclose which
# also receives the open tun file object.
MODEINFO = {
    'none'   : _mode(nop, nop, nop, nop),
    'tun'    : _mode(functools.partial(tuntap_alloc, "tun"),
                     tuntap_dealloc, vif_start, vif_stop),
    'tap'    : _mode(functools.partial(tuntap_alloc, "tap"),
                     tuntap_dealloc, vif_start, vif_stop),
    'pl-tun' : _mode(functools.partial(pl_tuntap_alloc, "tun"),
                     nop, pl_vif_start, nop),
    'pl-tap' : _mode(functools.partial(pl_tuntap_alloc, "tap"),
                     nop, pl_vif_start, nop),
}
519     
# Replace the module-level defaults with what was given on the command line
tun_path = options.tun_path
tun_name = options.tun_name

# Raises KeyError for a mode that is not listed in MODEINFO
modeinfo = MODEINFO[options.mode]

# be careful to roll back stuff on exceptions
# Order: allocate the device, start the interface, open the device.
# alloc may return a different path/name than requested (e.g. the name
# chosen by tunctl, or a file descriptor number in the pl-* modes).
tun_path, tun_name = modeinfo['alloc'](tun_path, tun_name)
try:
    modeinfo['start'](tun_path, tun_name)
    try:
        tun = modeinfo['tunopen'](tun_path, tun_name)
    except:
        # opening failed after the interface was started: stop it again
        modeinfo['stop'](tun_path, tun_name)
        raise
except:
    # start or open failed: release the allocated device, then re-raise
    modeinfo['dealloc'](tun_path, tun_name)
    raise
537
538
# Establish the remote side in one of three ways (fd passing, UDP, TCP),
# run the forwarding loop, and always tear everything down afterwards.
try:
    if options.pass_fd:
        # The socket path may be given base64 encoded; only in that case are
        # environment variables expanded in the decoded path.
        if options.pass_fd.startswith("base64:"):
            options.pass_fd = base64.b64decode(
                options.pass_fd[len("base64:"):])
            options.pass_fd = os.path.expandvars(options.pass_fd)
        
        print >>sys.stderr, "Sending FD to: %r" % (options.pass_fd,)
        
        # send FD to whoever wants it
        import passfd
        
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
        try:
            sock.connect(options.pass_fd)
        except socket.error:
            # wait a while, retry
            # (a second failure propagates)
            time.sleep(1)
            sock.connect(options.pass_fd)
        passfd.sendfd(sock, tun.fileno(), '0')
        
        # just wait forever
        # This rebinds the module-level tun_fwd, so the call below idles
        # instead of forwarding packets.
        def tun_fwd(tun, remote):
            while True:
                time.sleep(1)
        remote = None
    elif options.udp:
        # connect to remote endpoint
        if remaining_args and not remaining_args[0].startswith('-'):
            print >>sys.stderr, "Listening at: %s:%d" % (hostaddr,options.udp)
            print >>sys.stderr, "Connecting to: %s:%d" % (remaining_args[0],options.port)
            rsock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
            # local end uses the --udp port, remote end uses --port
            rsock.bind((hostaddr,options.udp))
            rsock.connect((remaining_args[0],options.port))
        else:
            print >>sys.stderr, "Error: need a remote endpoint in UDP mode"
            raise AssertionError, "Error: need a remote endpoint in UDP mode"
        # Wrap the socket's descriptor in an unbuffered file object.
        # NOTE(review): rsock and remote share one descriptor - confirm
        # nothing closes it twice.
        remote = os.fdopen(rsock.fileno(), 'r+b', 0)
    else:
        # connect to remote endpoint
        # With a positional argument we connect to it; without one we listen
        # and accept a single connection.
        if remaining_args and not remaining_args[0].startswith('-'):
            print >>sys.stderr, "Connecting to: %s:%d" % (remaining_args[0],options.port)
            rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
            rsock.connect((remaining_args[0],options.port))
        else:
            print >>sys.stderr, "Listening at: %s:%d" % (hostaddr,options.port)
            lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
            lsock.bind((hostaddr,options.port))
            lsock.listen(1)
            rsock,raddr = lsock.accept()
        remote = os.fdopen(rsock.fileno(), 'r+b', 0)

    print >>sys.stderr, "Connected"

    tun_fwd(tun, remote)
finally:
    try:
        print >>sys.stderr, "Shutting down..."
    except:
        # In case sys.stderr is broken
        pass
    
    # tidy shutdown in every case - swallow exceptions
    # Reverse of the set-up order: close the device, stop the interface,
    # deallocate the device. Each step runs even if the previous one failed.
    try:
        modeinfo['tunclose'](tun_path, tun_name, tun)
    except:
        pass
        
    try:
        modeinfo['stop'](tun_path, tun_name)
    except:
        pass

    try:
        modeinfo['dealloc'](tun_path, tun_name)
    except:
        pass
616
617