FIX support for tap and pl-tap modes.
[nepi.git] / src / nepi / testbeds / planetlab / scripts / tun_connect.py
1 import sys
2
3 import socket
4 import fcntl
5 import os
6 import select
7
8 import struct
9 import ctypes
10 import optparse
11 import threading
12 import subprocess
13 import re
14 import functools
15
# Fallback device name and device node. Both names are reassigned from the
# parsed command-line options before the tunnel is set up.
tun_name = 'tun0'
tun_path = '/dev/net/tun'

# Address obtained by resolving this machine's own host name.
hostaddr = socket.gethostbyname(socket.gethostname())

usage = "usage: %prog [options] <remote-endpoint>"
parser = optparse.OptionParser(usage=usage)

# --- device selection and peering port ---
parser.add_option("-i", "--iface",
    dest="tun_name", metavar="DEVICE", default="tun0",
    help="TUN/TAP interface to tap into")
parser.add_option("-d", "--tun-path",
    dest="tun_path", metavar="PATH", default="/dev/net/tun",
    help="TUN/TAP device file path or file descriptor number")
parser.add_option("-p", "--port",
    dest="port", metavar="PORT", type="int", default=15000,
    help="Peering TCP port to connect or listen to.")

# --- operating mode ---
parser.add_option("-m", "--mode",
    dest="mode", metavar="MODE", default="none",
    help=(
        "Set mode. One of none, tun, tap, pl-tun, pl-tap. In any mode except none, a TUN/TAP will be created "
        "by using the proper interface (tunctl for tun/tap, /vsys/fd_tuntap.control for pl-tun/pl-tap), "
        "and it will be brought up (with ifconfig for tun/tap, with /vsys/vif_up for pl-tun/pl-tap). You have "
        "to specify an VIF_ADDRESS and VIF_MASK in any case (except for none)."))

# --- virtual interface configuration (see the mode option) ---
parser.add_option("-A", "--vif-address",
    dest="vif_addr", metavar="VIF_ADDRESS", default=None,
    help=(
        "See mode. This specifies the VIF_ADDRESS, "
        "the IP address of the virtual interface."))
parser.add_option("-M", "--vif-mask",
    dest="vif_mask", metavar="VIF_MASK", type="int", default=None,
    help=(
        "See mode. This specifies the VIF_MASK, "
        "a number indicating the network type (ie: 24 for a C-class network)."))
parser.add_option("-S", "--vif-snat",
    dest="vif_snat", action="store_true", default=False,
    help="See mode. This specifies whether SNAT will be enabled for the virtual interface. ")
parser.add_option("-P", "--vif-pointopoint",
    dest="vif_pointopoint", metavar="DST_ADDR", default=None,
    help=(
        "See mode. This specifies the remote endpoint's virtual address, "
        "for point-to-point routing configuration. "
        "Not supported by PlanetLab"))
parser.add_option("-Q", "--vif-txqueuelen",
    dest="vif_txqueuelen", metavar="SIZE", type="int", default=None,
    help="See mode. This specifies the interface's transmission queue length. ")

# Everything that is not an option (the optional remote endpoint) is left
# in remaining_args.
(options, remaining_args) = parser.parse_args(sys.argv[1:])
76
77
# Ethernet protocol identifiers.
ETH_P_ALL = 0x0003
ETH_P_IP = 0x0800

# TUNSETIFF is the ioctl request used to bind a descriptor to an interface.
TUNSETIFF = 0x400454ca
FIONREAD = 0x541b

# Interface flags passed along with TUNSETIFF.
IFF_TUN = 0x0001
IFF_TAP = 0x0002
IFF_NO_PI = 0x1000
IFF_VNET_HDR = 0x4000
TUN_PKT_STRIP = 0x0001

# Sizes used when building the ifreq buffer.
IFHWADDRLEN = 6
IFNAMSIZ = 16
IFREQ_SZ = 40
90
def ifnam(x):
    """Return interface name x right-padded with NULs to IFNAMSIZ characters.

    A name that is already IFNAMSIZ characters or longer is returned as is.
    """
    return x.ljust(IFNAMSIZ, '\x00')
93
def ifreq(iface, flags):
    """Build the ifreq buffer handed to the TUNSETIFF ioctl.

    Layout:
      char[IFNAMSIZ] : interface name (NUL padded)
      short          : flags
      <padding>      : zero bytes up to IFREQ_SZ

    iface -- interface name
    flags -- interface flags (IFF_* values or'ed together)

    Returns a string of IFREQ_SZ characters.
    """
    request = ifnam(iface) + struct.pack("H", flags)
    # Zero-fill the rest of the structure. The operands used to be swapped,
    # which produced a negative repeat count and therefore no padding at all.
    request += '\x00' * (IFREQ_SZ - len(request))
    return request
102
def tunopen(tun_path, tun_name):
    """Open the TUN/TAP device and return it as an unbuffered file object.

    tun_path -- either a device node path, or a string of digits naming an
                already-open file descriptor
    tun_name -- interface name; names starting with 'tap' are bound as TAP
                devices, anything else as TUN

    When tun_path is a descriptor number it is only wrapped: no ioctl is
    issued. When it is a path, the node is opened and bound to tun_name
    with TUNSETIFF (without packet-information headers).
    """
    if tun_path.isdigit():
        # wrap the inherited TUN fd
        sys.stderr.write("Using tun: %s fd %s\n" % (tun_name, tun_path))
        return os.fdopen(int(tun_path), 'r+b', 0)

    # open TUN path
    sys.stderr.write("Using tun: %s at %s\n" % (tun_name, tun_path))
    tun = open(tun_path, 'r+b', 0)
    try:
        # Pick the device type the same way tun_fwd decides on ethernet
        # mode, so a tap interface is not requested as a tun one.
        if tun_name.startswith('tap'):
            kind = IFF_TAP
        else:
            kind = IFF_TUN

        # bind file descriptor to the interface
        fcntl.ioctl(tun.fileno(), TUNSETIFF, ifreq(tun_name, IFF_NO_PI | kind))
    except:
        # do not leak the device node if binding fails
        tun.close()
        raise

    return tun
117
def tunclose(tun_path, tun_name, tun):
    """Close the file object returned by tunopen.

    tun_path -- path or descriptor number that was given to tunopen
    tun_name -- interface name (unused)
    tun      -- file object to close

    The file object owns its descriptor in both cases (it was created with
    open() or os.fdopen()), so closing the object is all that is needed.
    Closing the raw descriptor first, as was done for the descriptor-number
    case, made the subsequent close of the file object fail.
    """
    tun.close()
126
def tuntap_alloc(kind, tun_path, tun_name):
    """Allocate a device by running tunctl and return (tun_path, device name).

    kind     -- "tun" adds the -n switch to the tunctl command line
    tun_path -- passed through unchanged
    tun_name -- requested device name; when empty, no -t switch is given

    The device name is parsed out of tunctl's standard output.

    Raises RuntimeError when tunctl exits with a non-zero status or when
    its output does not contain the expected message.
    """
    cmd = ["tunctl"]
    if kind == "tun":
        cmd.append("-n")
    if tun_name:
        cmd.extend(["-t", tun_name])

    tunctl = subprocess.Popen(cmd, stdout=subprocess.PIPE)
    output = tunctl.communicate()[0]
    if tunctl.wait() != 0:
        raise RuntimeError("Could not allocate %s device" % (kind,))

    found = re.search(r"Set '(?P<dev>(?:tun|tap)[0-9]*)' persistent and owned by .*", output, re.I)
    if found is None:
        raise RuntimeError("Could not allocate %s device - tunctl said: %s" % (kind, output))

    tun_name = found.group("dev")
    sys.stderr.write("Allocated %s device: %s\n" % (kind, tun_name))

    return tun_path, tun_name
147
def tuntap_dealloc(tun_path, tun_name):
    """Delete device tun_name with "tunctl -d".

    A non-zero exit status is only reported as a warning on stderr.
    """
    tunctl = subprocess.Popen(["tunctl", "-d", tun_name], stdout=subprocess.PIPE)
    tunctl.communicate()
    if tunctl.wait() != 0:
        sys.stderr.write("WARNING: error deallocating %s device\n" % (tun_name,))
154
def nmask_to_dot_notation(mask):
    """Convert a prefix length into a dotted-quad netmask.

    mask -- number of leading one bits, 0 to 32 (24 -> "255.255.255.0")

    Raises ValueError when mask is outside 0..32.

    The mask is computed arithmetically instead of going through hex() and
    a hex decode: that round trip broke for 0 (odd-length digit string) and
    whenever hex() appended an "L" suffix for a long integer.
    """
    if not 0 <= mask <= 32:
        raise ValueError("netmask length must be between 0 and 32, got %r" % (mask,))
    bits = (0xFFFFFFFF << (32 - mask)) & 0xFFFFFFFF  # 24 -> 0xFFFFFF00
    return '.'.join(str((bits >> shift) & 0xFF) for shift in (24, 16, 8, 0))
161
def vif_start(tun_path, tun_name):
    """Configure and bring up interface tun_name with ifconfig.

    Address, netmask, point-to-point peer, queue length and SNAT settings
    are all taken from the parsed command-line options. ARP is disabled on
    the interface. When SNAT is requested, a POSTROUTING rule is appended
    with iptables.

    Raises RuntimeError when ifconfig or iptables exits with a non-zero
    status.
    """
    ifconfig_cmd = ["ifconfig", tun_name, options.vif_addr,
                    "netmask", nmask_to_dot_notation(options.vif_mask),
                    "-arp"]
    if options.vif_pointopoint:
        ifconfig_cmd += ["pointopoint", options.vif_pointopoint]
    if options.vif_txqueuelen is not None:
        ifconfig_cmd += ["txqueuelen", str(options.vif_txqueuelen)]
    ifconfig_cmd.append("up")

    ifconfig = subprocess.Popen(ifconfig_cmd, stdout=subprocess.PIPE)
    ifconfig.communicate()
    if ifconfig.wait() != 0:
        raise RuntimeError("Error starting virtual interface")

    if not options.vif_snat:
        return

    # set up SNAT using iptables
    # TODO: stop vif on error.
    #   Not so necessary since deallocating the tun/tap device
    #   will forcibly stop it, but it would be tidier
    source_net = "%s/%d" % (options.vif_addr, options.vif_mask)
    iptables_cmd = ["iptables", "-t", "nat", "-A", "POSTROUTING",
                    "-s", source_net,
                    "-j", "SNAT",
                    "--to-source", hostaddr, "--random"]
    iptables = subprocess.Popen(iptables_cmd, stdout=subprocess.PIPE)
    iptables.communicate()
    if iptables.wait() != 0:
        raise RuntimeError("Error setting up SNAT")
189
def vif_stop(tun_path, tun_name):
    """Undo vif_start: drop the SNAT rule, then bring tun_name down.

    The exit status of iptables is not checked. A failing ifconfig is only
    reported as a warning on stderr.
    """
    if options.vif_snat:
        # remove the SNAT rule using iptables
        source_net = "%s/%d" % (options.vif_addr, options.vif_mask)
        iptables_cmd = ["iptables", "-t", "nat", "-D", "POSTROUTING",
                        "-s", source_net,
                        "-j", "SNAT",
                        "--to-source", hostaddr, "--random"]
        iptables = subprocess.Popen(iptables_cmd, stdout=subprocess.PIPE)
        iptables.communicate()

    ifconfig = subprocess.Popen(["ifconfig", tun_name, "down"], stdout=subprocess.PIPE)
    ifconfig.communicate()
    if ifconfig.wait() != 0:
        sys.stderr.write("WARNING: error stopping virtual interface\n")
205     
206     
def pl_tuntap_alloc(kind, tun_path, tun_name):
    """Allocate a device through ./tunalloc.so.

    kind     -- "tun" or "tap" (anything else raises KeyError)
    tun_path -- ignored
    tun_name -- ignored; the library chooses the name

    Returns (descriptor number as a string, interface name), i.e. a pair
    that tunopen accepts as (tun_path, tun_name).

    Raises RuntimeError when the library hands back a negative descriptor.
    """
    tunalloc_so = ctypes.cdll.LoadLibrary("./tunalloc.so")

    # tun_alloc writes the interface name into this buffer, so it has to be
    # real mutable memory. Python strings are immutable and must not be
    # handed to C code for writing. The capacity matches what the previous
    # c_char_p string provided (IFNAMSIZ characters plus the terminator).
    name_buf = ctypes.create_string_buffer(IFNAMSIZ + 1)

    iff_kind = {"tun": IFF_TUN,
                "tap": IFF_TAP}[kind]
    fd = tunalloc_so.tun_alloc(iff_kind, name_buf)

    # NOTE(review): tun_alloc's error convention is not visible from here;
    # a negative value can never be a usable descriptor, so fail early
    # instead of passing it on to tunopen.
    if fd < 0:
        raise RuntimeError("Could not allocate %s device" % (kind,))

    return str(fd), name_buf.value
215
def pl_vif_start(tun_path, tun_name):
    """Bring up tun_name through the /vsys/vif_up interface.

    One request line each is written for the interface name, the address
    and the mask, followed by optional "snat=1" and "txqueuelen=N" lines
    taken from the command-line options. The input side is closed before
    the reply is read; a non-blank reply is echoed to stderr.
    """
    request = [tun_name, options.vif_addr, str(options.vif_mask)]
    if options.vif_snat:
        request.append("snat=1")
    if options.vif_txqueuelen is not None:
        request.append("txqueuelen=%d" % (options.vif_txqueuelen,))

    # Open order is kept as is: input side first, then output side.
    vsys_in = open("/vsys/vif_up.in", "w")
    vsys_out = open("/vsys/vif_up.out", "r")
    for line in request:
        vsys_in.write(line + "\n")
    vsys_in.close()

    reply = vsys_out.read()
    vsys_out.close()
    if reply.strip():
        sys.stderr.write("%s\n" % (reply,))
231
232
def ipfmt(ip):
    """Format a hex-digit string as dotted decimal octets.

    ip -- string of hex digits, two per octet ("c0a80001" -> "192.168.0.1")

    Raises ValueError for an odd number of digits or a non-hex digit.

    Each octet is parsed directly with int(), which avoids building an
    intermediate byte string through the Python-2-only "hex" codec.
    """
    if len(ip) % 2:
        raise ValueError("odd-length hex string: %r" % (ip,))
    return '.'.join(str(int(ip[i:i + 2], 16)) for i in range(0, len(ip), 2))
236
# Labels used when printing ethernet frames, keyed by the EtherType field
# written as four lowercase hex digits.
tagtype = dict([
    ('0806', 'arp '),
    ('0800', 'ipv4 '),
    ('8870', 'jumbo '),
    ('8863', 'PPPoE discover '),
    ('8864', 'PPPoE '),
])
def etherProto(packet):
    """Return the two-byte EtherType of a raw ethernet frame.

    packet -- raw frame bytes

    For a frame carrying a VLAN tag (type field 0x8100) the EtherType that
    follows the tag is returned. Frames too short to contain the field
    yield the IPv4 EtherType ("\\x08\\x00").

    The frame used to be hex-encoded before being sliced and compared with
    raw byte strings, so the tag test could never match and the result was
    two hex characters rather than the two-byte field piWrap needs.
    """
    if len(packet) >= 14:
        if packet[12:14] == "\x81\x00":
            # tagged: the real type sits after the 4-byte VLAN tag
            if len(packet) >= 18:
                return packet[16:18]
        else:
            # untagged
            return packet[12:14]
    # default: ip
    return "\x08\x00"
def formatPacket(packet, ether_mode):
    """Render a packet as a dash-separated hex dump for the stderr trace.

    packet     -- raw packet bytes
    ether_mode -- when true, packet is an ethernet frame; frames that
                  etherStrip recognises as IPv4 are shown as IP packets,
                  anything else is shown as an ethernet frame

    Returns a string starting with "ip", "malformed ip", "malformed eth"
    or an ethernet type label.
    """
    if ether_mode:
        stripped_packet = etherStrip(packet)
        if not stripped_packet:
            frame = packet.encode("hex")
            if len(frame) < 28:
                # frame is already hex; it used to be encoded a second time
                return "malformed eth " + frame

            if frame[24:28] == "8100":
                # tagged
                type_field = frame[32:36]
                fields = (
                    frame[0:12], # MAC dest
                    frame[12:24], # MAC src
                    frame[24:32], # VLAN tag
                    type_field, # Ethertype/len
                    frame[36:], # Payload
                )
            else:
                # untagged
                type_field = frame[24:28]
                fields = (
                    frame[0:12], # MAC dest
                    frame[12:24], # MAC src
                    type_field, # Ethertype/len
                    frame[28:], # Payload
                )
            return tagtype.get(type_field, 'eth') + " " + '-'.join(fields)

        packet = stripped_packet

    packet = packet.encode("hex")
    # 40 hex digits = the 20 bytes of a minimal IPv4 header. The previous
    # limit of 48 flagged valid packets with less than 4 bytes of payload.
    if len(packet) < 40:
        return "malformed ip " + packet

    has_options = int(packet[1], 16) > 5
    if has_options:
        ip_options, payload = packet[40:48], packet[48:]
    else:
        ip_options, payload = "", packet[40:]

    return "ip " + '-'.join((
        packet[0:1], #version
        packet[1:2], #header length
        packet[2:4], #diffserv/ECN
        packet[4:8], #total length
        packet[8:12], #ident
        packet[12:16], #flags/fragment offs
        packet[16:18], #ttl
        packet[18:20], #ip-proto
        packet[20:24], #checksum
        ipfmt(packet[24:32]), # src-ip
        ipfmt(packet[32:40]), # dst-ip
        ip_options, # options
        payload, # payload
    ))
303
def packetReady(buf, ether_mode):
    """Tell whether buf holds at least one complete packet.

    Fewer than 4 bytes is never enough. In ether mode any 4 or more bytes
    count as a packet. Otherwise the second 16-bit word of the header is
    read in network byte order as the total length and compared against
    the amount of buffered data.
    """
    if len(buf) < 4:
        return False
    if ether_mode:
        return True
    total_length = struct.unpack('!HH', buf[:4])[1]
    return len(buf) >= total_length
313
def pullPacket(buf, ether_mode):
    """Split buf into (first packet, remaining data).

    In ether mode the whole buffer is taken as one packet. Otherwise the
    cut is made at the total length found in the second 16-bit word of the
    header (network byte order).
    """
    if ether_mode:
        return buf, ""
    total_length = struct.unpack('!HH', buf[:4])[1]
    head = buf[:total_length]
    tail = buf[total_length:]
    return head, tail
321
def etherStrip(buf):
    """Return the IPv4 payload of an ethernet frame, or "" if there is none.

    buf -- raw frame bytes

    The header (including a VLAN tag, if present) and the last 4 bytes of
    the frame are removed. Frames shorter than 14 bytes, or whose
    EtherType is not IPv4 (0x0800), yield "".

    The VLAN tag marker is 0x8100; it was previously spelled '\\x08\\x10',
    so tagged frames were never recognised.
    """
    if len(buf) < 14:
        return ""
    if buf[12:14] == '\x81\x00' and buf[16:18] == '\x08\x00':
        # tagged ethernet frame
        return buf[18:-4]
    elif buf[12:14] == '\x08\x00':
        # untagged ethernet frame
        return buf[14:-4]
    else:
        return ""
333
def etherWrap(packet):
    """Wrap an IPv4 packet in a dummy ethernet frame.

    Both MAC addresses and the trailing 4-byte checksum are all zeros; the
    EtherType is IPv4.
    """
    mac_pair = "\x00" * 12      # bogus src and dst mac
    ethertype = "\x08\x00"      # IPv4
    crc = "\x00" * 4            # bogus crc
    return mac_pair + ethertype + packet + crc
341
def piStrip(buf):
    """Drop the 4-byte packet-information header from the front of buf.

    Buffers shorter than 4 bytes are returned untouched.
    """
    return buf[4:] if len(buf) >= 4 else buf
347     
def piWrap(buf, ether_mode):
    """Prefix buf with a packet-information header.

    The header is 16 zero flag bits followed by a 16-bit protocol: the
    frame's own EtherType in ether mode, IPv4 otherwise.
    """
    proto = etherProto(buf) if ether_mode else "\x08\x00"
    pi_flags = "\x00\x00"
    return pi_flags + proto + buf
358
# Setting this to True makes tun_fwd leave its loop (checked once per pass).
abortme = False
def tun_fwd(tun, remote):
    """Shuttle packets between the tun device and the remote peer.

    tun    -- file object for the TUN/TAP device
    remote -- file object for the connection to the peer

    Data read from tun is queued for remote and vice versa; queued data is
    written out one packet at a time as reported by packetReady and
    pullPacket, and every forwarded packet is traced on stderr.

    The loop ends when abortme becomes true, when select reports an error
    condition, or when either side reports end-of-stream.
    """
    global abortme

    # in PL mode, we cannot strip PI structs
    # so we'll have to handle them
    with_pi = options.mode.startswith('pl-')
    ether_mode = tun_name.startswith('tap')

    # Limited frame parsing, to preserve packet boundaries.
    # Which is needed, since /dev/net/tun is unbuffered
    fwbuf = ""
    bkbuf = ""
    while not abortme:
        # only ask for writability where a whole packet is waiting
        wset = []
        if packetReady(bkbuf, ether_mode):
            wset.append(tun)
        if packetReady(fwbuf, ether_mode):
            wset.append(remote)
        # 1 second timeout so abortme is re-checked regularly
        rdrdy, wrdy, errs = select.select((tun,remote),wset,(tun,remote),1)

        # check for errors
        if errs:
            break

        # check to see if we can write
        if remote in wrdy and packetReady(fwbuf, ether_mode):
            packet, fwbuf = pullPacket(fwbuf, ether_mode)
            os.write(remote.fileno(), packet)
            sys.stderr.write("> %s\n" % (formatPacket(packet, ether_mode),))
        if tun in wrdy and packetReady(bkbuf, ether_mode):
            packet, bkbuf = pullPacket(bkbuf, ether_mode)
            # format before the PI header is added
            formatted = formatPacket(packet, ether_mode)
            if with_pi:
                packet = piWrap(packet, ether_mode)
            os.write(tun.fileno(), packet)
            sys.stderr.write("< %s\n" % (formatted,))

        # check incoming data packets
        if tun in rdrdy:
            packet = os.read(tun.fileno(),2000) # tun.read blocks until it gets 2k!
            if not packet:
                # end-of-stream: nothing more will ever arrive
                break
            if with_pi:
                packet = piStrip(packet)
            fwbuf += packet
        if remote in rdrdy:
            packet = os.read(remote.fileno(),2000) # remote.read blocks until it gets 2k!
            if not packet:
                # The peer closed the connection. Without this check the
                # descriptor stays readable forever and the loop spins.
                break
            bkbuf += packet
406
407
408
def nop(tun_path, tun_name):
    """Do nothing and hand both arguments back as a tuple."""
    return (tun_path, tun_name)

# Per-mode table of the callables used to allocate, start, open, close,
# stop and deallocate the device.
MODEINFO = {
    'none': {
        'alloc': nop,
        'tunopen': tunopen,
        'tunclose': tunclose,
        'dealloc': nop,
        'start': nop,
        'stop': nop,
    },
    'tun': {
        'alloc': functools.partial(tuntap_alloc, "tun"),
        'tunopen': tunopen,
        'tunclose': tunclose,
        'dealloc': tuntap_dealloc,
        'start': vif_start,
        'stop': vif_stop,
    },
    'tap': {
        'alloc': functools.partial(tuntap_alloc, "tap"),
        'tunopen': tunopen,
        'tunclose': tunclose,
        'dealloc': tuntap_dealloc,
        'start': vif_start,
        'stop': vif_stop,
    },
    'pl-tun': {
        'alloc': functools.partial(pl_tuntap_alloc, "tun"),
        'tunopen': tunopen,
        'tunclose': tunclose,
        'dealloc': nop,
        'start': pl_vif_start,
        'stop': nop,
    },
    'pl-tap': {
        'alloc': functools.partial(pl_tuntap_alloc, "tap"),
        'tunopen': tunopen,
        'tunclose': tunclose,
        'dealloc': nop,
        'start': pl_vif_start,
        'stop': nop,
    },
}
437     
tun_path = options.tun_path
tun_name = options.tun_name

# Reject bad invocations before anything is allocated, with a usage message
# instead of a KeyError/TypeError traceback.
if options.mode not in MODEINFO:
    parser.error("Unknown mode %r, expected one of: %s"
                 % (options.mode, ", ".join(sorted(MODEINFO))))
if options.mode != 'none' and (options.vif_addr is None or options.vif_mask is None):
    parser.error("Mode %s requires both --vif-address and --vif-mask"
                 % (options.mode,))

modeinfo = MODEINFO[options.mode]

# be careful to roll back stuff on exceptions
tun_path, tun_name = modeinfo['alloc'](tun_path, tun_name)
try:
    modeinfo['start'](tun_path, tun_name)
    try:
        tun = modeinfo['tunopen'](tun_path, tun_name)
    except:
        modeinfo['stop'](tun_path, tun_name)
        raise
except:
    modeinfo['dealloc'](tun_path, tun_name)
    raise


try:
    # connect to remote endpoint
    if remaining_args and not remaining_args[0].startswith('-'):
        sys.stderr.write("Connecting to: %s:%d\n" % (remaining_args[0],options.port))
        rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        rsock.connect((remaining_args[0],options.port))
    else:
        sys.stderr.write("Listening at: %s:%d\n" % (hostaddr,options.port))
        lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        try:
            lsock.bind((hostaddr,options.port))
            lsock.listen(1)
            rsock,raddr = lsock.accept()
        finally:
            # only one peer is ever served; release the listening port
            lsock.close()
    remote = os.fdopen(rsock.fileno(), 'r+b', 0)

    sys.stderr.write("Connected\n")

    tun_fwd(tun, remote)
finally:
    try:
        sys.stderr.write("Shutting down...\n")
    except:
        # In case sys.stderr is broken
        pass

    # tidy shutdown in every case - swallow exceptions
    try:
        modeinfo['tunclose'](tun_path, tun_name, tun)
    except:
        pass

    try:
        modeinfo['stop'](tun_path, tun_name)
    except:
        pass

    try:
        modeinfo['dealloc'](tun_path, tun_name)
    except:
        pass
496
497