Synchronization fix for cross connections: wait for tunnels to be up before starting...
[nepi.git] / src / nepi / testbeds / planetlab / scripts / tun_connect.py
import sys

import socket
import fcntl
import os
import os.path
import select

import struct
import ctypes
import optparse
import threading
import subprocess
import re
import functools
import time
import base64
import binascii
18
19 import time
# Wall-clock timestamp on stderr as soon as the script starts
# (diagnostic; presumably used to measure how long tunnel setup takes).
print >>sys.stderr, time.time()

# Module-level defaults; both are rebound later from the parsed options and
# from whatever the selected mode's allocator returns.
tun_path = '/dev/net/tun'
tun_name = 'tun0'

# Local address used for binding/listening and as the SNAT source address.
hostaddr = socket.gethostbyname(socket.gethostname())
25
usage = "usage: %prog [options] <remote-endpoint>"

parser = optparse.OptionParser(usage=usage)

parser.add_option(
    "-i", "--iface", dest="tun_name", metavar="DEVICE",
    default = "tun0",
    help = "TUN/TAP interface to tap into")
parser.add_option(
    "-d", "--tun-path", dest="tun_path", metavar="PATH",
    default = "/dev/net/tun",
    help = "TUN/TAP device file path or file descriptor number")
parser.add_option(
    "-p", "--port", dest="port", metavar="PORT", type="int",
    default = 15000,
    help = "Peering TCP port to connect or listen to "
           "(in UDP mode, the remote endpoint's UDP port).")
parser.add_option(
    "--pass-fd", dest="pass_fd", metavar="UNIX_SOCKET",
    default = None,
    help = "Path to a unix-domain socket to pass the TUN file descriptor to "
           "(may also be given as base64:<encoded path>). "
           "If given, all other connectivity options are ignored, tun_connect will "
           "simply wait to be killed after passing the file descriptor, and it will be "
           "the receiver's responsibility to handle the tunneling.")

parser.add_option(
    "-m", "--mode", dest="mode", metavar="MODE",
    # Restrict to the modes MODEINFO knows about, so a typo is reported as a
    # usage error instead of a KeyError traceback.
    type = "choice",
    choices = ["none", "tun", "tap", "pl-tun", "pl-tap"],
    default = "none",
    help = 
        "Set mode. One of none, tun, tap, pl-tun, pl-tap. In any mode except none, a TUN/TAP will be created "
        "by using the proper interface (tunctl for tun/tap, /vsys/fd_tuntap.control for pl-tun/pl-tap), "
        "and it will be brought up (with ifconfig for tun/tap, with /vsys/vif_up for pl-tun/pl-tap). You have "
        "to specify a VIF_ADDRESS and VIF_MASK in any case (except for none).")
parser.add_option(
    "-A", "--vif-address", dest="vif_addr", metavar="VIF_ADDRESS",
    default = None,
    help = 
        "See mode. This specifies the VIF_ADDRESS, "
        "the IP address of the virtual interface.")
parser.add_option(
    "-M", "--vif-mask", dest="vif_mask", type="int", metavar="VIF_MASK", 
    default = None,
    help = 
        "See mode. This specifies the VIF_MASK, "
        "a number indicating the network type (e.g. 24 for a class C network).")
parser.add_option(
    "-S", "--vif-snat", dest="vif_snat", 
    action = "store_true",
    default = False,
    help = "See mode. This specifies whether SNAT will be enabled for the virtual interface. " )
parser.add_option(
    "-P", "--vif-pointopoint", dest="vif_pointopoint",  metavar="DST_ADDR",
    default = None,
    help = 
        "See mode. This specifies the remote endpoint's virtual address, "
        "for point-to-point routing configuration. "
        "Not supported by PlanetLab" )
parser.add_option(
    "-Q", "--vif-txqueuelen", dest="vif_txqueuelen", metavar="SIZE", type="int",
    default = None,
    help = 
        "See mode. This specifies the interface's transmission queue length. " )
parser.add_option(
    "-u", "--udp", dest="udp", metavar="PORT", type="int",
    default = None,
    help = 
        "Bind to the specified UDP port locally, and send UDP datagrams to the "
        "remote endpoint, creating a tunnel through UDP rather than TCP." )
parser.add_option(
    "-k", "--key", dest="cipher_key", metavar="KEY",
    default = None,
    help = 
        "Specify a symmetric encryption key with which to protect packets across "
        "the tunnel. python-crypto must be installed on the system." )

(options, remaining_args) = parser.parse_args(sys.argv[1:])

# Every mode except "none" configures the interface address (ifconfig or
# /vsys/vif_up), which fails later with an obscure TypeError when these are
# missing -- report it as a usage error up front instead.
if options.mode != "none" and (options.vif_addr is None or options.vif_mask is None):
    parser.error("--vif-address and --vif-mask are required in mode %s" % (options.mode,))
if options.vif_mask is not None and not 0 <= options.vif_mask <= 32:
    parser.error("--vif-mask must be between 0 and 32")
101
102
# Linux TUN/TAP, ethernet and ioctl constants.
# NOTE(review): the values appear to mirror <linux/if_ether.h>,
# <linux/if_tun.h> and <linux/if.h>; IFREQ_SZ (40) matches sizeof(struct
# ifreq) on 64-bit Linux -- confirm for 32-bit hosts.

# Ethernet protocol identifiers.
ETH_P_ALL = 0x0003
ETH_P_IP = 0x0800

# TUN/TAP ioctl request number and interface flags.
TUNSETIFF = 0x400454ca
IFF_TUN = 0x0001
IFF_TAP = 0x0002
IFF_NO_PI = 0x1000
IFF_VNET_HDR = 0x4000
TUN_PKT_STRIP = 0x0001

# Sizes used when building struct ifreq (see ifnam/ifreq).
IFHWADDRLEN = 6
IFNAMSIZ = 16
IFREQ_SZ = 40

# ioctl request number; presumably "bytes available to read" (unused here).
FIONREAD = 0x541b
115
def ifnam(x):
    """Return interface name *x* NUL-padded on the right to IFNAMSIZ chars.

    Names already IFNAMSIZ characters or longer are returned unchanged
    (neither padded nor truncated).
    """
    return x.ljust(IFNAMSIZ, '\x00')
118
def ifreq(iface, flags):
    """Build a struct ifreq buffer for the TUNSETIFF ioctl.

    Layout:
      char[IFNAMSIZ] : interface name, NUL padded (see ifnam)
      short          : flags, native byte order
      <padding>      : zero bytes up to IFREQ_SZ

    iface: interface name.
    flags: IFF_* flags to store in ifr_flags.
    Returns the packed buffer as a byte string.
    """
    buf = ifnam(iface) + struct.pack("H", flags)
    # zero-fill the remainder of the struct up to its full size
    buf += '\x00' * (IFREQ_SZ - len(buf))
    return buf
127
def tunopen(tun_path, tun_name):
    """Open the TUN/TAP device and return it as an unbuffered binary file.

    tun_path: either a file-descriptor number given as a string of digits
        (already open, e.g. as returned by pl_tuntap_alloc), which is simply
        wrapped with os.fdopen, or a device path such as /dev/net/tun.
    tun_name: interface to attach to when opening by path. Names starting
        with "tap" are attached as TAP devices and anything else as TUN --
        the same naming rule tun_fwd uses to choose ethernet mode. Packet
        information headers are disabled (IFF_NO_PI).
    Raises whatever open() / fcntl.ioctl() raise; the file is closed first
    if the ioctl fails.
    """
    if tun_path.isdigit():
        # open TUN fd
        print >>sys.stderr, "Using tun:", tun_name, "fd", tun_path
        return os.fdopen(int(tun_path), 'r+b', 0)

    # open TUN path
    print >>sys.stderr, "Using tun:", tun_name, "at", tun_path
    tun = open(tun_path, 'r+b', 0)

    # Bind the descriptor to the interface with the matching device type:
    # Linux's TUNSETIFF refuses (EINVAL) to attach an existing TAP device
    # (e.g. one created by tunctl in "tap" mode) with IFF_TUN, and vice versa.
    if tun_name.startswith('tap'):
        kind = IFF_TAP
    else:
        kind = IFF_TUN
    try:
        fcntl.ioctl(tun.fileno(), TUNSETIFF, ifreq(tun_name, IFF_NO_PI|kind))
    except:
        tun.close()
        raise

    return tun
142
def tunclose(tun_path, tun_name, tun):
    """Close the TUN/TAP file object returned by tunopen.

    In both the path and the file-descriptor case, *tun* owns its descriptor
    (tunopen wraps fds with os.fdopen), so a single tun.close() releases it.
    Closing the raw fd as well would make tun.close() fail with EBADF, or
    close an unrelated descriptor if the number had been reused meanwhile.

    tun_path, tun_name: accepted for interface compatibility with the other
        mode hooks; not needed to close the device.
    """
    tun.close()
151
def tuntap_alloc(kind, tun_path, tun_name):
    """Create a persistent TUN or TAP device with ``tunctl``.

    kind: "tun" (passes ``-n`` to tunctl); any other value creates a TAP.
    tun_path: passed through unchanged.
    tun_name: requested device name, or empty to let tunctl choose one.
    Returns (tun_path, device_name), where device_name is parsed from
    tunctl's output.
    Raises RuntimeError if tunctl exits non-zero or its output is not
    recognised.
    """
    command = ["tunctl"] + (["-n"] if kind == "tun" else [])
    if tun_name:
        command += ["-t", tun_name]
    proc = subprocess.Popen(command, stdout=subprocess.PIPE)
    output = proc.communicate()[0]
    if proc.returncode:
        raise RuntimeError("Could not allocate %s device" % (kind,))

    found = re.search(
        r"Set '(?P<dev>(?:tun|tap)[0-9]*)' persistent and owned by .*",
        output, re.I)
    if found is None:
        raise RuntimeError(
            "Could not allocate %s device - tunctl said: %s" % (kind, output))

    tun_name = found.group("dev")
    print >>sys.stderr, "Allocated %s device: %s" % (kind, tun_name)
    return tun_path, tun_name
172
def tuntap_dealloc(tun_path, tun_name):
    """Delete the persistent device *tun_name* with ``tunctl -d``.

    Best effort: a non-zero exit status only prints a warning on stderr.
    """
    proc = subprocess.Popen(["tunctl", "-d", tun_name], stdout=subprocess.PIPE)
    proc.communicate()
    if proc.returncode:
        print >> sys.stderr, "WARNING: error deallocating %s device" % (tun_name,)
179
def nmask_to_dot_notation(mask):
    """Convert a prefix length (0-32) to a dotted-quad netmask string.

    >>> nmask_to_dot_notation(24)
    '255.255.255.0'

    Uses integer arithmetic only. The previous hex()-based conversion broke
    on 32-bit Python 2, where hex() of a long carries an 'L' suffix, and for
    mask 0, which yields an odd-length hex string.

    Raises ValueError for masks outside 0-32.
    """
    if not 0 <= mask <= 32:
        raise ValueError("Invalid network mask length: %r" % (mask,))
    bits = (0xFFFFFFFF << (32 - mask)) & 0xFFFFFFFF # 24 -> 0xFFFFFF00
    return '.'.join(str((bits >> shift) & 0xFF) for shift in (24, 16, 8, 0))
186
def vif_start(tun_path, tun_name):
    """Configure and bring up *tun_name* with ifconfig, optionally adding SNAT.

    The interface gets options.vif_addr / options.vif_mask with ARP disabled,
    plus options.vif_pointopoint and options.vif_txqueuelen when set. When
    options.vif_snat is set, an iptables nat POSTROUTING rule rewrites the
    VIF subnet's source address to ``hostaddr``.
    Raises RuntimeError if either command exits with a non-zero status.
    """
    def run(command):
        # run *command*, discard its stdout, return its exit status
        proc = subprocess.Popen(command, stdout=subprocess.PIPE)
        proc.communicate()
        return proc.returncode

    ifconfig = ["ifconfig", tun_name, options.vif_addr,
                "netmask", nmask_to_dot_notation(options.vif_mask),
                "-arp"]
    if options.vif_pointopoint:
        ifconfig += ["pointopoint", options.vif_pointopoint]
    if options.vif_txqueuelen is not None:
        ifconfig += ["txqueuelen", str(options.vif_txqueuelen)]
    ifconfig.append("up")
    if run(ifconfig):
        raise RuntimeError("Error starting virtual interface")

    if not options.vif_snat:
        return

    # TODO: stop vif on error.
    #   Not so necessary since deallocating the tun/tap device
    #   will forcibly stop it, but it would be tidier
    snat = ["iptables", "-t", "nat", "-A", "POSTROUTING",
            "-s", "%s/%d" % (options.vif_addr, options.vif_mask),
            "-j", "SNAT",
            "--to-source", hostaddr, "--random"]
    if run(snat):
        raise RuntimeError("Error setting up SNAT")
214
def vif_stop(tun_path, tun_name):
    """Tear down the virtual interface configured by vif_start.

    Removes the SNAT rule (when options.vif_snat is set) and brings
    *tun_name* down with ifconfig. Best effort: a failing command only
    produces a warning on stderr, never an exception.
    """
    if options.vif_snat:
        # remove the SNAT rule: same rule spec as when it was added, with -D
        args = [ "iptables", "-t", "nat", "-D", "POSTROUTING", 
                 "-s", "%s/%d" % (options.vif_addr, options.vif_mask),
                 "-j", "SNAT",
                 "--to-source", hostaddr, "--random" ]
        proc = subprocess.Popen(args, stdout=subprocess.PIPE)
        proc.communicate()
        if proc.returncode:
            print >>sys.stderr, "WARNING: error removing SNAT rule"
    
    args = ["ifconfig", tun_name, "down"]
    proc = subprocess.Popen(args, stdout=subprocess.PIPE)
    proc.communicate()
    if proc.returncode:
        print >>sys.stderr, "WARNING: error stopping virtual interface"
230     
231     
def pl_tuntap_alloc(kind, tun_path, tun_name):
    """Allocate a TUN or TAP device through the helper library ./tunalloc.so.

    kind: "tun" or "tap" (anything else raises KeyError).
    tun_path, tun_name: ignored; the library picks the device.
    Returns (fd_string, device_name): the descriptor number as a string of
    digits, which tunopen treats as an already-open fd, and the device name
    the library wrote back.
    Raises RuntimeError if tun_alloc returns a negative descriptor.
    """
    tunalloc_so = ctypes.cdll.LoadLibrary("./tunalloc.so")
    flags = {"tun":IFF_TUN,
             "tap":IFF_TAP}[kind]
    # tun_alloc writes the device name into this buffer, so it must be a real
    # mutable ctypes buffer -- not a c_char_p pointing into an immutable
    # Python string, which would be corrupted in place.
    c_tun_name = ctypes.create_string_buffer(IFNAMSIZ)
    fd = tunalloc_so.tun_alloc(flags, c_tun_name)
    if fd < 0:
        # NOTE(review): assumes tun_alloc reports failure with a negative
        # return value; a negative fd is never usable in any case (str(-1)
        # would otherwise be treated as a path by tunopen).
        raise RuntimeError("Could not allocate %s device" % (kind,))
    return str(fd), c_tun_name.value
240
def pl_vif_start(tun_path, tun_name):
    """Configure and bring up *tun_name* through PlanetLab's /vsys/vif_up.

    Writes the interface name, options.vif_addr and options.vif_mask, plus
    optional "snat=1" and "txqueuelen=N" lines, to /vsys/vif_up.in. It then
    closes that file and echoes anything read from /vsys/vif_up.out to
    stderr. Both files are closed even if writing or reading fails.
    NOTE(review): the script's output is only printed, never checked for
    failure.
    """
    stdin = open("/vsys/vif_up.in","w")
    try:
        stdout = open("/vsys/vif_up.out","r")
        try:
            stdin.write(tun_name+"\n")
            stdin.write(options.vif_addr+"\n")
            stdin.write(str(options.vif_mask)+"\n")
            if options.vif_snat:
                stdin.write("snat=1\n")
            if options.vif_txqueuelen is not None:
                stdin.write("txqueuelen=%d\n" % (options.vif_txqueuelen,))
            # presumably the script reads its input until EOF, so close our
            # end before waiting for its reply
            stdin.close()
            out = stdout.read()
        finally:
            stdout.close()
    finally:
        stdin.close()
    if out.strip():
        print >>sys.stderr, out
256
257
def ipfmt(ip):
    """Format an IPv4 address given as 8 hex digits in dotted-quad notation.

    >>> ipfmt("0a000001")
    '10.0.0.1'

    binascii.unhexlify raises the same errors as the Python-2-only
    str.decode("hex") for odd-length or non-hex input. Iterating a bytearray
    yields ints on both Python 2 and 3.
    """
    return '.'.join(str(octet) for octet in bytearray(binascii.unhexlify(ip)))
261
# Ethertype as 4 lowercase hex digits -> short label, used by formatPacket
# to name non-IPv4 ethernet frames.
tagtype = dict([
    ('0806', 'arp'),
    ('0800', 'ipv4'),
    ('8870', 'jumbo'),
    ('8863', 'PPPoE discover'),
    ('8864', 'PPPoE'),
    ('86dd', 'ipv6'),
])
def etherProto(packet):
    """Return the 2-byte ethertype of the ethernet frame *packet*.

    For 802.1Q-tagged frames (TPID 0x8100 at bytes 12-13), returns the inner
    ethertype at bytes 16-17. Frames too short to carry an ethertype default
    to IPv4 (0x0800). Fields are compared and returned as raw bytes, which
    is what piWrap needs for the PI header's protocol field.
    """
    if len(packet) >= 18 and packet[12:14] == "\x81\x00":
        # tagged: skip the 4-byte VLAN tag
        return packet[16:18]
    if len(packet) >= 14:
        # untagged
        return packet[12:14]
    # default: ip
    return "\x08\x00"
def formatPacket(packet, ether_mode):
    """Return a one-line hex description of *packet* for the stderr log.

    In ether mode, frames that etherStrip recognises as carrying IPv4 are
    described by their IP packet. Other frames are shown as dash-separated
    hex fields (MAC dst, MAC src, [VLAN tag,] ethertype, payload), prefixed
    with the ``tagtype`` label, or "eth" when the ethertype is unknown.
    IP packets are split into their IPv4 header fields, with the source and
    destination addresses in dotted-quad form and the options sized from the
    IHL field.
    """
    if ether_mode:
        stripped_packet = etherStrip(packet)
        if not stripped_packet:
            packet = packet.encode("hex")
            if len(packet) < 28:
                # shorter than an untagged ethernet header (14 bytes)
                return "malformed eth " + packet
            if packet[24:28] == "8100":
                # tagged
                ethertype = tagtype.get(packet[32:36], 'eth')
                return ethertype + " " + '-'.join((
                    packet[0:12], # MAC dest
                    packet[12:24], # MAC src
                    packet[24:32], # VLAN tag
                    packet[32:36], # Ethertype/len
                    packet[36:], # Payload
                ))
            # untagged
            ethertype = tagtype.get(packet[24:28], 'eth')
            return ethertype + " " + '-'.join((
                packet[0:12], # MAC dest
                packet[12:24], # MAC src
                packet[24:28], # Ethertype/len
                packet[28:], # Payload
            ))
        packet = stripped_packet
    packet = packet.encode("hex")
    if len(packet) < 40:
        # shorter than a minimal (20-byte) IPv4 header
        return "malformed ip " + packet
    # IHL counts 32-bit words, i.e. 8 hex digits each; clamp to the minimum
    # of 5 so a bogus value cannot make options/payload overlap the header.
    header_end = max(int(packet[1], 16), 5) * 8
    return "ip " + '-'.join((
        packet[0:1], #version
        packet[1:2], #header length
        packet[2:4], #diffserv/ECN
        packet[4:8], #total length
        packet[8:12], #ident
        packet[12:16], #flags/fragment offs
        packet[16:18], #ttl
        packet[18:20], #ip-proto
        packet[20:24], #checksum
        ipfmt(packet[24:32]), # src-ip
        ipfmt(packet[32:40]), # dst-ip
        packet[40:header_end], # options (empty when IHL == 5)
        packet[header_end:], # payload
    ))
329
def packetReady(buf, ether_mode):
    """Return True when *buf* holds at least one complete packet.

    In ether mode any buffer of 4 or more bytes counts as complete.
    Otherwise *buf* is expected to start with an IPv4 header, and is complete
    once it holds as many bytes as the header's total-length field says.
    """
    if len(buf) < 4:
        return False
    if ether_mode:
        return True
    # bytes 2-3 of the IPv4 header: total length, network byte order
    (total_length,) = struct.unpack('!H', buf[2:4])
    return len(buf) >= total_length
339
def pullPacket(buf, ether_mode):
    """Split the first packet off *buf*, returning (packet, remainder).

    In ether mode the whole buffer is taken as one frame and the remainder is
    empty. Otherwise the packet length comes from the IPv4 total-length
    field (bytes 2-3, network byte order); struct.error is raised if *buf*
    is shorter than 4 bytes.
    """
    if ether_mode:
        return buf, ""
    (total_length,) = struct.unpack('!H', buf[2:4])
    return buf[:total_length], buf[total_length:]
347
def etherStrip(buf):
    """Return the IPv4 payload of ethernet frame *buf*, or "" if it has none.

    Recognises untagged frames (ethertype 0x0800 at bytes 12-13) and
    802.1Q-tagged frames (TPID 0x8100 at bytes 12-13, inner ethertype
    0x0800 at bytes 16-17). Anything else, including other protocols and
    frames too short for their header, yields "".
    """
    if len(buf) < 14:
        return ""
    if buf[12:14] == '\x81\x00' and buf[16:18] == '\x08\x00':
        # tagged ethernet frame
        return buf[18:]
    elif buf[12:14] == '\x08\x00':
        # untagged ethernet frame
        return buf[14:]
    else:
        return ""
359
def etherWrap(packet):
    """Wrap an IPv4 *packet* in a fake ethernet frame.

    Prepends zeroed destination and source MACs and ethertype 0x0800, and
    appends four zero bytes in place of a real CRC.
    """
    fake_macs = "\x00" * 12 # bogus dst and src mac
    ipv4_ethertype = "\x08\x00"
    fake_crc = "\x00" * 4
    return "".join((fake_macs, ipv4_ethertype, packet, fake_crc))
367
def piStrip(buf):
    """Drop the 4-byte packet-information header from the front of *buf*.

    Buffers shorter than 4 bytes are returned unchanged.
    """
    return buf if len(buf) < 4 else buf[4:]
373     
def piWrap(buf, ether_mode):
    """Prepend a 4-byte packet-information header to *buf*.

    The header is 16 bits of zero flags followed by the 16-bit protocol.
    In ether mode the protocol is taken from the frame (etherProto);
    otherwise it is IPv4 (0x0800).
    """
    proto = etherProto(buf) if ether_mode else "\x08\x00"
    return "\x00\x00" + proto + buf
384
def encrypt(packet, crypter):
    """Pad *packet* to a multiple of crypter.block_size and encrypt it.

    Padding is PKCS#7-style: 1..block_size bytes are always appended, each
    equal to the number of bytes added (decrypt strips them again).
    Returns whatever crypter.encrypt returns for the padded data.
    """
    pad_len = crypter.block_size - (len(packet) % crypter.block_size)
    return crypter.encrypt(packet + chr(pad_len) * pad_len)
392
def decrypt(packet, crypter):
    """Decrypt *packet* and strip the padding added by encrypt.

    The last decrypted byte gives the padding length. It must be between 1
    and crypter.block_size, otherwise RuntimeError("Truncated packet") is
    raised.
    """
    plain = crypter.decrypt(packet)
    pad_len = ord(plain[-1])
    if pad_len < 1 or pad_len > crypter.block_size:
        # wrong padding
        raise RuntimeError("Truncated packet")
    return plain[:-pad_len]
405
abortme = False
def tun_fwd(tun, remote):
    """Forward packets between the local *tun* device and the *remote* peer.

    tun, remote: unbuffered file objects, select()-able and read/written via
        os.read/os.write on their fileno().

    Packets read from *tun* are buffered in ``fwbuf`` and written to
    *remote*, encrypted when a cipher key is configured. Packets read from
    *remote* are decrypted, buffered in ``bkbuf`` and written to *tun*.
    Every forwarded packet is logged on stderr via formatPacket.

    Returns when the module-level ``abortme`` flag is set, when select()
    reports an exceptional condition on either end, or (TCP only) when the
    peer closes the connection. In TCP mode, I/O and crypto errors
    propagate. In UDP mode they are ignored and the affected datagram is
    dropped.
    """
    global abortme
    
    # in PL mode, we cannot strip PI structs
    # so we'll have to handle them
    with_pi = options.mode.startswith('pl-')
    # tap* devices carry ethernet frames, tun* devices raw IP packets
    ether_mode = tun_name.startswith('tap')
    
    crypto_mode = False
    crypter = None
    try:
        if options.cipher_key:
            import Crypto.Cipher.AES
            import hashlib
            
            # derive a 256-bit AES key from the user-supplied key string
            hashed_key = hashlib.sha256(options.cipher_key).digest()
            # NOTE(review): ECB mode encrypts identical blocks identically
            crypter = Crypto.Cipher.AES.new(
                hashed_key, 
                Crypto.Cipher.AES.MODE_ECB)
            crypto_mode = True
    except Exception:
        # NOTE(review): if a key was given but cannot be used (e.g.
        # python-crypto is missing), the tunnel falls back to PLAINTEXT;
        # only this traceback and the message below signal it.
        import traceback
        traceback.print_exc()
        crypto_mode = False
        crypter = None

    if crypto_mode:
        print >>sys.stderr, "Packets are transmitted in CIPHER"
    else:
        print >>sys.stderr, "Packets are transmitted in PLAINTEXT"
    
    # Limited frame parsing, to preserve packet boundaries.
    # Which is needed, since /dev/net/tun is unbuffered
    fwbuf = ""
    bkbuf = ""
    while not abortme:
        wset = []
        if packetReady(bkbuf, ether_mode):
            wset.append(tun)
        if packetReady(fwbuf, ether_mode):
            wset.append(remote)
        # 1s timeout so that abortme is re-checked periodically
        rdrdy, wrdy, errs = select.select((tun,remote),wset,(tun,remote),1)
        
        # check for errors
        if errs:
            break
        
        # check to see if we can write
        if remote in wrdy and packetReady(fwbuf, ether_mode):
            packet, fwbuf = pullPacket(fwbuf, ether_mode)
            try:
                if crypto_mode:
                    enpacket = encrypt(packet, crypter)
                else:
                    enpacket = packet
                os.write(remote.fileno(), enpacket)
            except Exception:
                if not options.udp:
                    # in UDP mode, we ignore errors - packet loss man...
                    raise
            print >>sys.stderr, '>', formatPacket(packet, ether_mode)
        if tun in wrdy and packetReady(bkbuf, ether_mode):
            packet, bkbuf = pullPacket(bkbuf, ether_mode)
            formatted = formatPacket(packet, ether_mode)
            if with_pi:
                packet = piWrap(packet, ether_mode)
            os.write(tun.fileno(), packet)
            print >>sys.stderr, '<', formatted
        
        # check incoming data packets
        if tun in rdrdy:
            packet = os.read(tun.fileno(),2000) # tun.read blocks until it gets 2k!
            if with_pi:
                packet = piStrip(packet)
            fwbuf += packet
        if remote in rdrdy:
            try:
                packet = os.read(remote.fileno(),2000) # remote.read blocks until it gets 2k!
                if not packet and not options.udp:
                    # TCP peer closed the connection: stop, instead of
                    # spinning on a descriptor that stays readable forever
                    break
                if crypto_mode:
                    packet = decrypt(packet, crypter)
            except Exception:
                if not options.udp:
                    raise
                # in UDP mode, we ignore errors - packet loss man...
                # Drop the datagram: ``packet`` would otherwise still hold the
                # previous packet (read failed) or ciphertext (decrypt failed).
                packet = ""
            bkbuf += packet
491
492
493
def nop(tun_path, tun_name):
    """No-op mode hook: return its arguments unchanged."""
    return (tun_path, tun_name)


def _mode_hooks(alloc, dealloc, start, stop):
    # every mode opens and closes the device the same way; only allocation
    # and VIF start/stop differ between modes
    return dict(alloc=alloc, dealloc=dealloc,
                start=start, stop=stop,
                tunopen=tunopen, tunclose=tunclose)


# Per-mode lifecycle hooks, used by the main script in the order
#   alloc -> start -> tunopen ... tunclose -> stop -> dealloc
MODEINFO = {
    'none'   : _mode_hooks(nop, nop, nop, nop),
    'tun'    : _mode_hooks(functools.partial(tuntap_alloc, "tun"),
                           tuntap_dealloc, vif_start, vif_stop),
    'tap'    : _mode_hooks(functools.partial(tuntap_alloc, "tap"),
                           tuntap_dealloc, vif_start, vif_stop),
    # PlanetLab: devices come from tunalloc.so and are brought up through
    # /vsys; nothing is stopped or deallocated explicitly
    'pl-tun' : _mode_hooks(functools.partial(pl_tuntap_alloc, "tun"),
                           nop, pl_vif_start, nop),
    'pl-tap' : _mode_hooks(functools.partial(pl_tuntap_alloc, "tap"),
                           nop, pl_vif_start, nop),
}
522     
# Start from the command-line values; the mode's allocator may replace both
# (tuntap_alloc returns the name tunctl chose, pl_tuntap_alloc returns a
# file-descriptor number as tun_path).
tun_path = options.tun_path
tun_name = options.tun_name

modeinfo = MODEINFO[options.mode]

# be careful to roll back stuff on exceptions:
#   alloc -> start -> tunopen
#   - if tunopen fails, stop the VIF, then fall through to dealloc
#   - if start (or the stop above) fails, dealloc the device
#   and re-raise in every case
tun_path, tun_name = modeinfo['alloc'](tun_path, tun_name)
try:
    modeinfo['start'](tun_path, tun_name)
    try:
        tun = modeinfo['tunopen'](tun_path, tun_name)
    except:
        modeinfo['stop'](tun_path, tun_name)
        raise
except:
    modeinfo['dealloc'](tun_path, tun_name)
    raise
540
541
# Main flow: hand the device off (--pass-fd) or connect to the remote end
# (UDP or TCP), forward packets until tun_fwd returns, then always tear
# down in reverse order of setup.
try:
    if options.pass_fd:
        # The socket path may be given base64-encoded ("base64:" prefix);
        # only such decoded paths get environment variables expanded.
        if options.pass_fd.startswith("base64:"):
            options.pass_fd = base64.b64decode(
                options.pass_fd[len("base64:"):])
            options.pass_fd = os.path.expandvars(options.pass_fd)
        
        print >>sys.stderr, "Sending FD to: %r" % (options.pass_fd,)
        
        # send FD to whoever wants it
        import passfd
        
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
        try:
            sock.connect(options.pass_fd)
        except socket.error:
            # wait a while, retry
            # (a single retry: a second failure propagates)
            print >>sys.stderr, "Could not connect. Retrying in a sec..."
            time.sleep(1)
            sock.connect(options.pass_fd)
        passfd.sendfd(sock, tun.fileno(), '0')
        
        # just wait forever
        # (rebinds the module-level tun_fwd so the common call below just
        # blocks; the receiver of the fd now handles the tunneling)
        def tun_fwd(tun, remote):
            while True:
                time.sleep(1)
        remote = None
        
        # NOTE(review): redundant re-import; this timestamp pairs with the one
        # printed at startup, presumably to measure setup time.
        import time
        print >>sys.stderr, time.time()
        
    elif options.udp:
        # connect to remote endpoint
        # UDP: bind local port --udp and connect to <remote-endpoint>:--port;
        # a remote endpoint is mandatory in this mode.
        if remaining_args and not remaining_args[0].startswith('-'):
            print >>sys.stderr, "Listening at: %s:%d" % (hostaddr,options.udp)
            print >>sys.stderr, "Connecting to: %s:%d" % (remaining_args[0],options.port)
            rsock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
            rsock.bind((hostaddr,options.udp))
            rsock.connect((remaining_args[0],options.port))
        else:
            print >>sys.stderr, "Error: need a remote endpoint in UDP mode"
            raise AssertionError, "Error: need a remote endpoint in UDP mode"
        remote = os.fdopen(rsock.fileno(), 'r+b', 0)
    else:
        # connect to remote endpoint
        # TCP: connect when a remote endpoint is given, otherwise listen on
        # hostaddr:--port and accept a single connection.
        if remaining_args and not remaining_args[0].startswith('-'):
            print >>sys.stderr, "Connecting to: %s:%d" % (remaining_args[0],options.port)
            rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
            rsock.connect((remaining_args[0],options.port))
        else:
            print >>sys.stderr, "Listening at: %s:%d" % (hostaddr,options.port)
            lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
            lsock.bind((hostaddr,options.port))
            lsock.listen(1)
            # NOTE(review): lsock is never closed after accept() and keeps
            # the listening port bound for the life of the process.
            rsock,raddr = lsock.accept()
        # wrap the socket's fd as an unbuffered file; rsock stays referenced
        # as a module global
        remote = os.fdopen(rsock.fileno(), 'r+b', 0)

    print >>sys.stderr, "Connected"

    tun_fwd(tun, remote)
finally:
    try:
        print >>sys.stderr, "Shutting down..."
    except:
        # In case sys.stderr is broken
        pass
    
    # tidy shutdown in every case - swallow exceptions
    # (reverse of setup: close the device, stop the VIF, deallocate)
    try:
        modeinfo['tunclose'](tun_path, tun_name, tun)
    except:
        pass
        
    try:
        modeinfo['stop'](tun_path, tun_name)
    except:
        pass

    try:
        modeinfo['dealloc'](tun_path, tun_name)
    except:
        pass
623         pass
624
625