- TAP interfaces
[nepi.git] / src / nepi / testbeds / planetlab / scripts / tun_connect.py
1 import sys
2
3 import socket
4 import fcntl
5 import os
6 import select
7
8 import struct
9 import ctypes
10 import optparse
11 import threading
12 import subprocess
13 import re
14 import functools
15
# Module-level defaults. The main script overwrites tun_name/tun_path from the
# parsed options and again with whatever the mode's alloc hook returns;
# tun_fwd() reads tun_name as a global to choose TUN vs TAP framing.
tun_name = 'tun0'
tun_path = '/dev/net/tun'
# Local address used to bind sockets and as the SNAT source address.
hostaddr = socket.gethostbyname(socket.gethostname())

usage = "usage: %prog [options] <remote-endpoint>"

# Valid --mode values; these must match the keys of MODEINFO below.
MODES = ("none", "tun", "tap", "pl-tun", "pl-tap")

parser = optparse.OptionParser(usage=usage)

parser.add_option(
    "-i", "--iface", dest="tun_name", metavar="DEVICE",
    default="tun0",
    help="TUN/TAP interface to tap into")
parser.add_option(
    "-d", "--tun-path", dest="tun_path", metavar="PATH",
    default="/dev/net/tun",
    help="TUN/TAP device file path or file descriptor number")
parser.add_option(
    "-p", "--port", dest="port", metavar="PORT", type="int",
    default=15000,
    help="Peering TCP port to connect or listen to.")

# Declared as a "choice" option so optparse rejects unknown modes with a
# usage error instead of a KeyError later on.
parser.add_option(
    "-m", "--mode", dest="mode", metavar="MODE",
    type="choice", choices=list(MODES),
    default="none",
    help=
        "Set mode. One of none, tun, tap, pl-tun, pl-tap. In any mode except none, a TUN/TAP will be created "
        "by using the proper interface (tunctl for tun/tap, /vsys/fd_tuntap.control for pl-tun/pl-tap), "
        "and it will be brought up (with ifconfig for tun/tap, with /vsys/vif_up for pl-tun/pl-tap). You have "
        "to specify an VIF_ADDRESS and VIF_MASK in any case (except for none).")
parser.add_option(
    "-A", "--vif-address", dest="vif_addr", metavar="VIF_ADDRESS",
    default=None,
    help=
        "See mode. This specifies the VIF_ADDRESS, "
        "the IP address of the virtual interface.")
parser.add_option(
    "-M", "--vif-mask", dest="vif_mask", type="int", metavar="VIF_MASK",
    default=None,
    help=
        "See mode. This specifies the VIF_MASK, "
        "a number indicating the network type (ie: 24 for a C-class network).")
parser.add_option(
    "-S", "--vif-snat", dest="vif_snat",
    action="store_true",
    default=False,
    help="See mode. This specifies whether SNAT will be enabled for the virtual interface. " )
parser.add_option(
    "-P", "--vif-pointopoint", dest="vif_pointopoint", metavar="DST_ADDR",
    default=None,
    help=
        "See mode. This specifies the remote endpoint's virtual address, "
        "for point-to-point routing configuration. "
        "Not supported by PlanetLab" )
parser.add_option(
    "-Q", "--vif-txqueuelen", dest="vif_txqueuelen", metavar="SIZE", type="int",
    default=None,
    help=
        "See mode. This specifies the interface's transmission queue length. " )
parser.add_option(
    "-u", "--udp", dest="udp", metavar="PORT", type="int",
    default=None,
    help=
        "Bind to the specified UDP port locally, and send UDP datagrams to the "
        "remote endpoint, creating a tunnel through UDP rather than TCP." )
parser.add_option(
    "-k", "--key", dest="cipher_key", metavar="KEY",
    default=None,
    help=
        "Specify a symmetric encryption key with which to protect packets across "
        "the tunnel. python-crypto must be installed on the system." )

(options, remaining_args) = parser.parse_args(sys.argv[1:])
88
89
# Linux TUN/TAP and networking constants used by this script.

# Ethernet protocol numbers
ETH_P_ALL = 0x0003
ETH_P_IP = 0x0800

# ioctl request binding a /dev/net/tun descriptor to an interface (see tunopen)
TUNSETIFF = 0x400454ca

# Flag bits placed in the ifreq handed to TUNSETIFF
IFF_TUN = 0x0001
IFF_TAP = 0x0002
IFF_NO_PI = 0x1000
IFF_VNET_HDR = 0x4000

TUN_PKT_STRIP = 0x0001

# Sizes describing the ifreq structure built by ifnam()/ifreq()
IFHWADDRLEN = 6
IFNAMSIZ = 16
IFREQ_SZ = 40

FIONREAD = 0x541b
102
def ifnam(x):
    """Return interface name *x* right-padded with NUL characters to IFNAMSIZ.

    Names that are already IFNAMSIZ characters or longer come back unchanged.
    """
    return x.ljust(IFNAMSIZ, '\x00')
105
def ifreq(iface, flags):
    """Build the ``struct ifreq`` buffer passed to the TUNSETIFF ioctl.

    Layout:
      char[IFNAMSIZ] : interface name, NUL padded (ifnam)
      short          : flags, native byte order
      <padding>      : zero bytes up to IFREQ_SZ

    :param iface: interface name, e.g. ``"tun0"``.
    :param flags: ``IFF_*`` flag bits.
    :returns: a string IFREQ_SZ bytes long (for names that fit in IFNAMSIZ).
    """
    buf = ifnam(iface) + struct.pack("H", flags)
    # TUNSETIFF takes a whole struct ifreq, so zero-pad up to IFREQ_SZ.
    # The count is IFREQ_SZ - len(buf); the reverse order is always negative
    # and would silently add no padding at all.
    buf += '\x00' * (IFREQ_SZ - len(buf))
    return buf
114
def tunopen(tun_path, tun_name):
    """Open the TUN/TAP device and return it as an unbuffered file object.

    If *tun_path* is all digits, it is an already-open file descriptor (as
    returned by pl_tuntap_alloc) and is wrapped as-is. Otherwise *tun_path*
    is opened and bound to interface *tun_name* with TUNSETIFF, without
    packet-info headers (IFF_NO_PI).

    Interfaces whose name starts with ``tap`` are attached in TAP mode and
    all others in TUN mode. This is the same naming convention tun_fwd uses
    to pick its framing.

    :raises IOError: if the device cannot be opened or bound; the file is
        closed before the error propagates.
    """
    if tun_path.isdigit():
        # open TUN fd
        print >>sys.stderr, "Using tun:", tun_name, "fd", tun_path
        tun = os.fdopen(int(tun_path), 'r+b', 0)
    else:
        # open TUN path
        print >>sys.stderr, "Using tun:", tun_name, "at", tun_path
        tun = open(tun_path, 'r+b', 0)

        # The device-type flag must match the interface being attached.
        # 'tap' mode allocates tap* interfaces, and binding one with IFF_TUN
        # is rejected.
        if tun_name.startswith('tap'):
            kind = IFF_TAP
        else:
            kind = IFF_TUN

        # bind file descriptor to the interface
        try:
            fcntl.ioctl(tun.fileno(), TUNSETIFF, ifreq(tun_name, IFF_NO_PI | kind))
        except Exception:
            # don't leak the descriptor when binding fails
            tun.close()
            raise

    return tun
129
def tunclose(tun_path, tun_name, tun):
    """Close the TUN/TAP file object returned by tunopen.

    In both modes (numeric fd or device path), *tun* owns the underlying
    descriptor, so a single ``tun.close()`` releases it. Calling os.close()
    on the raw fd first and then tun.close() would close the same fd twice.
    The second close fails with EBADF, or closes an unrelated descriptor
    that reused the number in between.

    *tun_path* and *tun_name* are unused. They are kept so that every mode
    hook in MODEINFO has the same signature.
    """
    tun.close()
138
def tuntap_alloc(kind, tun_path, tun_name):
    """Create a persistent TUN or TAP interface using ``tunctl``.

    :param kind: ``"tun"`` (adds ``-n`` to the tunctl command) or ``"tap"``.
    :param tun_path: returned unchanged.
    :param tun_name: requested interface name, passed with ``-t``; when
        empty, tunctl picks the name.
    :returns: ``(tun_path, name)``, where *name* is the device tunctl says
        it set persistent.
    :raises RuntimeError: if tunctl exits non-zero or its output does not
        name a tun/tap device.
    """
    cmd = ["tunctl"]
    if kind == "tun":
        cmd.append("-n")
    if tun_name:
        cmd.extend(["-t", tun_name])

    tunctl = subprocess.Popen(cmd, stdout=subprocess.PIPE)
    output = tunctl.communicate()[0]
    if tunctl.wait():
        raise RuntimeError("Could not allocate %s device" % (kind,))

    found = re.search(
        r"Set '(?P<dev>(?:tun|tap)[0-9]*)' persistent and owned by .*",
        output, re.I)
    if found is None:
        raise RuntimeError(
            "Could not allocate %s device - tunctl said: %s" % (kind, output))

    allocated = found.group("dev")
    print >>sys.stderr, "Allocated %s device: %s" % (kind, allocated)
    return tun_path, allocated
159
def tuntap_dealloc(tun_path, tun_name):
    """Remove the persistent interface *tun_name* with ``tunctl -d``.

    This runs during cleanup, so a failing tunctl only produces a warning on
    stderr. *tun_path* is unused.
    """
    tunctl = subprocess.Popen(["tunctl", "-d", tun_name],
                              stdout=subprocess.PIPE)
    tunctl.communicate()
    if tunctl.wait() != 0:
        print >>sys.stderr, "WARNING: error deallocating %s device" % (tun_name,)
166
def nmask_to_dot_notation(mask):
    """Convert a prefix length into a dotted-quad netmask.

    For example, 24 becomes ``'255.255.255.0'`` and 0 becomes ``'0.0.0.0'``.

    :param mask: prefix length, an int in ``0..32``.
    :raises ValueError: if *mask* is outside ``0..32``.
    """
    if not 0 <= mask <= 32:
        raise ValueError("netmask prefix length out of range: %r" % (mask,))
    bits = ((1 << mask) - 1) << (32 - mask)  # 24 -> 0xFFFFFF00
    # Pull out the four octets arithmetically. A hex-string round trip breaks
    # for mask 0 ('0x0' has an odd digit count) and needs the Py2-only codec.
    return '.'.join([str((bits >> shift) & 0xFF) for shift in (24, 16, 8, 0)])
173
def vif_start(tun_path, tun_name):
    """Configure and bring up interface *tun_name* with ifconfig.

    Reads the global ``options``:
      - ``vif_addr`` and ``vif_mask`` give the address and netmask;
      - ``vif_pointopoint`` and ``vif_txqueuelen`` are optional.

    ARP is disabled on the interface. If ``options.vif_snat`` is set, an
    iptables nat/POSTROUTING SNAT rule is appended. It rewrites sources in
    the VIF subnet to ``hostaddr``.

    :raises RuntimeError: if ifconfig or iptables fail. When the SNAT step
        fails, the interface is first brought back down (best effort).
    """
    args = ["ifconfig", tun_name, options.vif_addr,
            "netmask", nmask_to_dot_notation(options.vif_mask),
            "-arp" ]
    if options.vif_pointopoint:
        args.extend(["pointopoint", options.vif_pointopoint])
    if options.vif_txqueuelen is not None:
        args.extend(["txqueuelen", str(options.vif_txqueuelen)])
    args.append("up")
    proc = subprocess.Popen(args, stdout=subprocess.PIPE)
    proc.communicate()
    if proc.wait():
        raise RuntimeError("Error starting virtual interface")

    if options.vif_snat:
        # set up SNAT using iptables
        args = [ "iptables", "-t", "nat", "-A", "POSTROUTING",
                 "-s", "%s/%d" % (options.vif_addr, options.vif_mask),
                 "-j", "SNAT",
                 "--to-source", hostaddr, "--random" ]
        proc = subprocess.Popen(args, stdout=subprocess.PIPE)
        proc.communicate()
        if proc.wait():
            # Don't leave the interface up on failure. When start fails, the
            # main script only runs the dealloc hook, never the stop hook.
            down = subprocess.Popen(["ifconfig", tun_name, "down"],
                                    stdout=subprocess.PIPE)
            down.communicate()
            raise RuntimeError("Error setting up SNAT")
201
def vif_stop(tun_path, tun_name):
    """Undo vif_start: drop the SNAT rule (if enabled) and bring *tun_name* down.

    Reads the global ``options``. This runs during cleanup, so failures are
    reported as warnings on stderr instead of being raised.
    """
    if options.vif_snat:
        # tear down SNAT: delete (-D) the rule vif_start appended (-A)
        args = [ "iptables", "-t", "nat", "-D", "POSTROUTING",
                 "-s", "%s/%d" % (options.vif_addr, options.vif_mask),
                 "-j", "SNAT",
                 "--to-source", hostaddr, "--random" ]
        proc = subprocess.Popen(args, stdout=subprocess.PIPE)
        proc.communicate()
        if proc.wait():
            print >>sys.stderr, "WARNING: error removing SNAT rule"

    args = ["ifconfig", tun_name, "down"]
    proc = subprocess.Popen(args, stdout=subprocess.PIPE)
    proc.communicate()
    if proc.wait():
        print >>sys.stderr, "WARNING: error stopping virtual interface"
217     
218     
def pl_tuntap_alloc(kind, tun_path, tun_name):
    """Allocate a PlanetLab TUN/TAP device through ``./tunalloc.so``.

    Calls the helper's ``tun_alloc(flags, name_buf)``. As used here, it
    returns a file descriptor and writes the chosen interface name into
    *name_buf*.

    :param kind: ``"tun"`` or ``"tap"``.
    :param tun_path: ignored; the helper chooses the device.
    :param tun_name: ignored; the helper chooses the device.
    :returns: ``(str(fd), name)``. The fd is returned as a digit string,
        which tunopen/tunclose treat as an already-open descriptor.
    :raises RuntimeError: if the helper returns a negative value.
    """
    tunalloc_so = ctypes.cdll.LoadLibrary("./tunalloc.so")
    # The helper writes the interface name into this buffer, so it must be a
    # genuinely mutable C buffer. A c_char_p over a Python string would make
    # it overwrite an immutable (possibly shared) str object.
    c_tun_name = ctypes.create_string_buffer(IFNAMSIZ)
    flags = {"tun": IFF_TUN,
             "tap": IFF_TAP}[kind]
    fd = tunalloc_so.tun_alloc(flags, c_tun_name)
    # NOTE(review): assumes tun_alloc signals failure with a negative return,
    # as fd-returning C calls conventionally do -- confirm against tunalloc.c.
    if fd < 0:
        raise RuntimeError("Could not allocate %s device" % (kind,))
    name = c_tun_name.value
    return str(fd), name
227
def pl_vif_start(tun_path, tun_name):
    """Bring up a PlanetLab virtual interface through vsys ``vif_up``.

    Sends the request to ``/vsys/vif_up.in``, one item per line:
      - interface name;
      - address;
      - prefix length;
      - optional ``snat=1`` and ``txqueuelen=N`` lines.

    Anything the script writes back on ``/vsys/vif_up.out`` is echoed to
    stderr. Reads the global ``options``; ``vif_pointopoint`` is not sent.
    """
    request = open("/vsys/vif_up.in", "w")
    reply = open("/vsys/vif_up.out", "r")

    def send(line):
        request.write(line + "\n")

    send(tun_name)
    send(options.vif_addr)
    send(str(options.vif_mask))
    if options.vif_snat:
        send("snat=1")
    if options.vif_txqueuelen is not None:
        send("txqueuelen=%d" % (options.vif_txqueuelen,))
    request.close()

    response = reply.read()
    reply.close()
    if response.strip():
        print >>sys.stderr, response
243
244
def ipfmt(ip):
    """Format an address given as hex digits as a dotted quad.

    For example, ``'c0a80001'`` becomes ``'192.168.0.1'``.

    :param ip: hex string with two digits per octet (formatPacket passes
        the 8-digit source/destination address fields).
    :raises ValueError: on an odd number of digits or non-hex characters.
    """
    if len(ip) % 2:
        raise ValueError("odd-length hex address: %r" % (ip,))
    # Parse each two-digit octet directly rather than relying on the
    # Python-2-only 'hex' str codec.
    return '.'.join([str(int(ip[i:i + 2], 16)) for i in range(0, len(ip), 2)])
248
# Ethertype (4 hex digits, as produced by hex-encoding a frame) -> label
# used by formatPacket. Labels keep their trailing space; formatPacket falls
# back to 'eth' for unlisted types.
tagtype = dict([
    ('0800', 'ipv4 '),
    ('0806', 'arp '),
    ('8863', 'PPPoE discover '),
    ('8864', 'PPPoE '),
    ('8870', 'jumbo '),
])
def etherProto(packet):
    """Return the 2-byte ethertype of Ethernet frame *packet* (raw bytes).

    For 802.1Q-tagged frames (TPID ``\\x81\\x00`` at offset 12), the
    encapsulated ethertype at offset 16 is returned. Frames too short to
    hold the field yield ``"\\x08\\x00"`` (IPv4).

    The frame must be examined as raw bytes. Hex-encoding it first means the
    comparison against ``"\\x81\\x00"`` can never match, and byte offsets
    then select the wrong characters.
    """
    ethertype = packet[12:14]
    if ethertype == "\x81\x00":
        # tagged: the real ethertype follows the 4-byte 802.1Q tag
        ethertype = packet[16:18]
    if len(ethertype) == 2:
        return ethertype
    # default: ip
    return "\x08\x00"
def formatPacket(packet, ether_mode):
    """Return a one-line, human-readable hex dump of *packet* for logging.

    In *ether_mode*, the Ethernet framing is stripped first (etherStrip).
    Frames that don't yield an IPv4 payload are dumped as Ethernet fields:
    MACs, optional VLAN tag, ethertype and payload. IPv4 packets are split
    into their header fields, with addresses shown as dotted quads. Any
    header options, as sized by the IHL field, are shown separately from
    the payload.
    """
    if ether_mode:
        stripped_packet = etherStrip(packet)
        if not stripped_packet:
            packet = packet.encode("hex")
            if len(packet) < 28:
                # already hex-encoded above; don't encode it a second time
                return "malformed eth " + packet
            else:
                if packet[24:28] == "8100":
                    # tagged
                    ethertype = tagtype.get(packet[32:36], 'eth')
                    return ethertype + " " + ( '-'.join( (
                        packet[0:12], # MAC dest
                        packet[12:24], # MAC src
                        packet[24:32], # VLAN tag
                        packet[32:36], # Ethertype/len
                        packet[36:], # Payload
                    ) ) )
                else:
                    # untagged
                    ethertype = tagtype.get(packet[24:28], 'eth')
                    return ethertype + " " + ( '-'.join( (
                        packet[0:12], # MAC dest
                        packet[12:24], # MAC src
                        packet[24:28], # Ethertype/len
                        packet[28:], # Payload
                    ) ) )
        else:
            packet = stripped_packet
    packet = packet.encode("hex")
    # The fixed IPv4 header is 20 bytes = 40 hex digits.
    if len(packet) < 40:
        return "malformed ip " + packet
    # IHL (second nibble) counts 32-bit words, i.e. 8 hex digits each.
    # Everything between the fixed header and IHL*4 bytes is options.
    # Bogus IHL values below 5 are treated as "no options".
    header_end = max(int(packet[1], 16), 5) * 8
    return "ip " + ( '-'.join( (
        packet[0:1], #version
        packet[1:2], #header length
        packet[2:4], #diffserv/ECN
        packet[4:8], #total length
        packet[8:12], #ident
        packet[12:16], #flags/fragment offs
        packet[16:18], #ttl
        packet[18:20], #ip-proto
        packet[20:24], #checksum
        ipfmt(packet[24:32]), # src-ip
        ipfmt(packet[32:40]), # dst-ip
        packet[40:header_end], # options
        packet[header_end:], # payload
    ) ) )
315
def packetReady(buf, ether_mode):
    """Tell whether *buf* starts with at least one complete packet.

    In ether (TAP) mode, any buffer of 4 or more bytes counts as ready.
    Otherwise *buf* must hold the whole IPv4 packet announced by its
    total-length field (bytes 2-3, network byte order).
    """
    if len(buf) < 4:
        return False
    if ether_mode:
        return True
    totallen = struct.unpack('!H', buf[2:4])[0]
    return len(buf) >= totallen
325
def pullPacket(buf, ether_mode):
    """Split the first packet off *buf* and return ``(packet, rest)``.

    In ether (TAP) mode, the whole buffer is taken as a single frame.
    Otherwise the packet length comes from the IPv4 total-length field
    (bytes 2-3, network byte order). Callers check packetReady() first.
    """
    if ether_mode:
        return buf, ""
    cut = struct.unpack('!H', buf[2:4])[0]
    return buf[:cut], buf[cut:]
333
def etherStrip(buf):
    """Return the IPv4 payload of Ethernet frame *buf*, or ``""`` if none.

    Two frame shapes are recognised:
      - untagged: ethertype 0x0800 at offset 12;
      - 802.1Q tagged: TPID 0x8100 at offset 12 and ethertype 0x0800 at
        offset 16.

    The final 4 bytes are dropped as the frame check sequence, matching the
    placeholder CRC appended by etherWrap.
    NOTE(review): frames read from a Linux TAP device may not carry an FCS,
    in which case this trims 4 real payload bytes -- confirm.
    """
    if len(buf) < 14:
        return ""
    if buf[12:14] == '\x81\x00' and buf[16:18] == '\x08\x00':
        # tagged ethernet frame (802.1Q TPID is 0x8100)
        return buf[18:-4]
    elif buf[12:14] == '\x08\x00':
        # untagged ethernet frame
        return buf[14:-4]
    else:
        return ""
345
def etherWrap(packet):
    """Wrap IPv4 *packet* in a minimal Ethernet frame.

    Both MAC addresses are zero, the ethertype is IPv4 (0x0800), and four
    zero bytes stand in for the CRC.
    """
    macs = "\x00" * 12       # bogus dst + src MAC
    ethertype = "\x08\x00"   # IPv4
    crc = "\x00" * 4         # bogus crc
    return "".join((macs, ethertype, packet, crc))
353
def piStrip(buf):
    """Drop the 4-byte packet-information (PI) header from the front of *buf*.

    Buffers shorter than the header are returned untouched.
    """
    return buf if len(buf) < 4 else buf[4:]
359     
def piWrap(buf, ether_mode):
    """Prepend a 4-byte packet-information (PI) header to *buf*.

    The header is 16 zero flag bits followed by the 16-bit protocol. In
    *ether_mode* the protocol is read from the frame (etherProto);
    otherwise it is IPv4.
    """
    proto = etherProto(buf) if ether_mode else "\x08\x00"
    flags = "\x00\x00"
    return flags + proto + buf
370
def encrypt(packet, crypter):
    """Pad *packet* PKCS#7-style to ``crypter.block_size`` and encrypt it.

    Between 1 and ``block_size`` bytes are always appended, each holding the
    pad length, so decrypt() can remove them unambiguously.
    """
    block = crypter.block_size
    pad_len = block - (len(packet) % block)
    return crypter.encrypt(packet + chr(pad_len) * pad_len)
378
def decrypt(packet, crypter):
    """Decrypt *packet* and strip the padding added by encrypt().

    :raises RuntimeError: if the plaintext is empty or its padding is
        invalid. Invalid means the pad length is outside ``1..block_size``,
        or the pad bytes don't all equal the pad length. This happens with
        truncated or corrupted packets.
    """
    # decrypt
    packet = crypter.decrypt(packet)

    # un-pad: check every padding byte, not just the last one, so that garbage
    # is rejected instead of being forwarded with a few bytes chopped off
    if not packet:
        raise RuntimeError("Truncated packet")
    padding = ord(packet[-1])
    if not (0 < padding <= crypter.block_size) or \
            packet[-padding:] != packet[-1] * padding:
        # wrong padding
        raise RuntimeError("Truncated packet")

    return packet[:-padding]
391
# Setting this to True makes tun_fwd's loop exit on its next iteration
# (its select() call times out after one second).
abortme = False


def _write_all(fd, data):
    """Write all of *data* to file descriptor *fd*, retrying short writes.

    os.write() on a stream socket may accept only part of the buffer.
    Dropping the rest would desynchronise the packet stream.
    """
    while data:
        written = os.write(fd, data)
        data = data[written:]


def tun_fwd(tun, remote):
    """Shuttle packets between the TUN/TAP device *tun* and *remote*.

    Both arguments are unbuffered file objects (from tunopen and the main
    script). Data read from *tun* is optionally encrypted and sent to
    *remote*. Data from *remote* is optionally decrypted and written to
    *tun*. Reads the module globals ``options``, ``tun_name`` and
    ``abortme``.

    The loop ends when any of these happens:
      - ``abortme`` becomes true;
      - select() reports an exceptional condition;
      - a TCP peer closes the connection.

    In TCP mode, I/O and crypto errors propagate. In UDP mode they drop the
    packet and forwarding continues.
    """
    global abortme

    # in PL mode, we cannot strip PI structs
    # so we'll have to handle them
    with_pi = options.mode.startswith('pl-')
    ether_mode = tun_name.startswith('tap')

    crypto_mode = False
    crypter = None
    try:
        if options.cipher_key:
            import Crypto.Cipher.AES
            import hashlib

            hashed_key = hashlib.sha256(options.cipher_key).digest()
            # NOTE(review): ECB mode leaks patterns of identical plaintext
            # blocks, and nothing authenticates the packets. It is kept for
            # wire compatibility with peers running this same script.
            crypter = Crypto.Cipher.AES.new(
                hashed_key,
                Crypto.Cipher.AES.MODE_ECB)
            crypto_mode = True
    except Exception:
        # Best effort: fall back to plaintext (announced right below), e.g.
        # when python-crypto is not installed.
        import traceback
        traceback.print_exc()
        crypto_mode = False
        crypter = None

    if crypto_mode:
        print >>sys.stderr, "Packets are transmitted in CIPHER"
    else:
        print >>sys.stderr, "Packets are transmitted in PLAINTEXT"

    # Limited frame parsing, to preserve packet boundaries.
    # Which is needed, since /dev/net/tun is unbuffered
    # NOTE(review): ether (TAP) frames carry no length field here, so frames
    # that pile up in one buffer are forwarded together in a single write.
    fwbuf = ""
    bkbuf = ""
    tun_fd = tun.fileno()
    remote_fd = remote.fileno()
    while not abortme:
        wset = []
        if packetReady(bkbuf, ether_mode):
            wset.append(tun)
        if packetReady(fwbuf, ether_mode):
            wset.append(remote)
        rdrdy, wrdy, errs = select.select((tun,remote),wset,(tun,remote),1)

        # check for errors
        if errs:
            break

        # check to see if we can write
        if remote in wrdy and packetReady(fwbuf, ether_mode):
            packet, fwbuf = pullPacket(fwbuf, ether_mode)
            try:
                if crypto_mode:
                    enpacket = encrypt(packet, crypter)
                else:
                    enpacket = packet
                _write_all(remote_fd, enpacket)
            except Exception:
                if not options.udp:
                    raise
                # in UDP mode, we ignore errors - packet loss man...
            print >>sys.stderr, '>', formatPacket(packet, ether_mode)
        if tun in wrdy and packetReady(bkbuf, ether_mode):
            packet, bkbuf = pullPacket(bkbuf, ether_mode)
            formatted = formatPacket(packet, ether_mode)
            if with_pi:
                packet = piWrap(packet, ether_mode)
            os.write(tun_fd, packet)
            print >>sys.stderr, '<', formatted

        # check incoming data packets
        if tun in rdrdy:
            packet = os.read(tun_fd, 2000) # tun.read blocks until it gets 2k!
            if with_pi:
                packet = piStrip(packet)
            fwbuf += packet
        if remote in rdrdy:
            try:
                packet = os.read(remote_fd, 2000) # remote.read blocks until it gets 2k!
                if not packet and not options.udp:
                    # TCP peer closed the connection. The socket stays readable
                    # and os.read keeps returning "", so stop, don't spin.
                    print >>sys.stderr, "Remote endpoint closed the connection"
                    break
                if crypto_mode:
                    packet = decrypt(packet, crypter)
            except Exception:
                if not options.udp:
                    raise
                # In UDP mode, we ignore errors - packet loss man...
                # Drop the datagram: `packet` is stale or still ciphertext
                # here and must not be injected into the tun device.
            else:
                bkbuf += packet
477
478
479
def nop(tun_path, tun_name):
    """Mode hook that does nothing and hands back its arguments."""
    return tun_path, tun_name


def _modehooks(alloc, dealloc, start, stop):
    # Every mode opens and closes the device the same way (tunopen/tunclose).
    # Only the allocation and interface start/stop hooks differ.
    return dict(alloc=alloc, tunopen=tunopen, tunclose=tunclose,
                dealloc=dealloc, start=start, stop=stop)


# Per-mode hooks used by the main script:
#   alloc(tun_path, tun_name) -> (tun_path, tun_name)   create the device
#   start / stop(tun_path, tun_name)                    configure / unconfigure it
#   tunopen / tunclose                                  open / close it
#   dealloc(tun_path, tun_name)                         destroy it
MODEINFO = {
    'none': _modehooks(alloc=nop, dealloc=nop, start=nop, stop=nop),
    'tun': _modehooks(alloc=functools.partial(tuntap_alloc, "tun"),
                      dealloc=tuntap_dealloc,
                      start=vif_start,
                      stop=vif_stop),
    'tap': _modehooks(alloc=functools.partial(tuntap_alloc, "tap"),
                      dealloc=tuntap_dealloc,
                      start=vif_start,
                      stop=vif_stop),
    'pl-tun': _modehooks(alloc=functools.partial(pl_tuntap_alloc, "tun"),
                         dealloc=nop,
                         start=pl_vif_start,
                         stop=nop),
    'pl-tap': _modehooks(alloc=functools.partial(pl_tuntap_alloc, "tap"),
                         dealloc=nop,
                         start=pl_vif_start,
                         stop=nop),
}
508     
tun_path = options.tun_path
tun_name = options.tun_name

# Validate the command line before allocating anything, so that bad options
# produce a usage error instead of a traceback halfway through setup.
if options.mode not in MODEINFO:
    parser.error("unknown mode: %s" % (options.mode,))
if options.mode != 'none':
    if options.vif_addr is None or options.vif_mask is None:
        parser.error("mode %s requires --vif-address and --vif-mask"
                     % (options.mode,))
    if not 0 <= options.vif_mask <= 32:
        parser.error("--vif-mask must be between 0 and 32")

# The first positional argument (if any) is the remote endpoint.
has_remote = bool(remaining_args) and not remaining_args[0].startswith('-')
if options.udp and not has_remote:
    parser.error("need a remote endpoint in UDP mode")

modeinfo = MODEINFO[options.mode]

# be careful to roll back stuff on exceptions
# (bare excepts on purpose: roll back even on KeyboardInterrupt, then re-raise)
tun_path, tun_name = modeinfo['alloc'](tun_path, tun_name)
try:
    modeinfo['start'](tun_path, tun_name)
    try:
        tun = modeinfo['tunopen'](tun_path, tun_name)
    except:
        modeinfo['stop'](tun_path, tun_name)
        raise
except:
    modeinfo['dealloc'](tun_path, tun_name)
    raise


try:
    if options.udp:
        # connect to remote endpoint
        print >>sys.stderr, "Listening at: %s:%d" % (hostaddr,options.udp)
        print >>sys.stderr, "Connecting to: %s:%d" % (remaining_args[0],options.port)
        rsock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
        rsock.bind((hostaddr,options.udp))
        rsock.connect((remaining_args[0],options.port))
    elif has_remote:
        # connect to remote endpoint
        print >>sys.stderr, "Connecting to: %s:%d" % (remaining_args[0],options.port)
        rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        rsock.connect((remaining_args[0],options.port))
    else:
        # no endpoint given: wait for the peer to connect to us
        print >>sys.stderr, "Listening at: %s:%d" % (hostaddr,options.port)
        lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        lsock.bind((hostaddr,options.port))
        lsock.listen(1)
        rsock,raddr = lsock.accept()
        # Only one peer is ever served, so release the listening port.
        lsock.close()
    remote = os.fdopen(rsock.fileno(), 'r+b', 0)

    print >>sys.stderr, "Connected"

    tun_fwd(tun, remote)
finally:
    try:
        print >>sys.stderr, "Shutting down..."
    except:
        # In case sys.stderr is broken
        pass

    # tidy shutdown in every case - swallow exceptions so that every
    # cleanup step gets its chance to run
    try:
        modeinfo['tunclose'](tun_path, tun_name, tun)
    except:
        pass

    try:
        modeinfo['stop'](tun_path, tun_name)
    except:
        pass

    try:
        modeinfo['dealloc'](tun_path, tun_name)
    except:
        pass
580
581