Merge with HEAD, close aly's branch.
[nepi.git] / src / nepi / util / tunchannel.py
1 import select
2 import sys
3 import os
4 import struct
5 import socket
6 import threading
7 import traceback
8 import errno
9 import fcntl
10 import random
11 import traceback
12 import functools
13 import collections
14 import ctypes
15 import time
16
17 def ipfmt(ip):
18     ipbytes = map(ord,ip.decode("hex"))
19     return '.'.join(map(str,ipbytes))
20
21 tagtype = {
22     '0806' : 'arp',
23     '0800' : 'ipv4',
24     '8870' : 'jumbo',
25     '8863' : 'PPPoE discover',
26     '8864' : 'PPPoE',
27     '86dd' : 'ipv6',
28 }
29 def etherProto(packet, len=len):
30     if len(packet) > 14:
31         if packet[12] == "\x81" and packet[13] == "\x00":
32             # tagged
33             return packet[16:18]
34         else:
35             # untagged
36             return packet[12:14]
37     # default: ip
38     return "\x08\x00"
39 def formatPacket(packet, ether_mode):
40     if ether_mode:
41         stripped_packet = etherStrip(packet)
42         if not stripped_packet:
43             packet = packet.encode("hex")
44             if len(packet) < 28:
45                 return "malformed eth " + packet.encode("hex")
46             else:
47                 if packet[24:28] == "8100":
48                     # tagged
49                     ethertype = tagtype.get(packet[32:36], 'eth')
50                     return ethertype + " " + ( '-'.join( (
51                         packet[0:12], # MAC dest
52                         packet[12:24], # MAC src
53                         packet[24:32], # VLAN tag
54                         packet[32:36], # Ethertype/len
55                         packet[36:], # Payload
56                     ) ) )
57                 else:
58                     # untagged
59                     ethertype = tagtype.get(packet[24:28], 'eth')
60                     return ethertype + " " + ( '-'.join( (
61                         packet[0:12], # MAC dest
62                         packet[12:24], # MAC src
63                         packet[24:28], # Ethertype/len
64                         packet[28:], # Payload
65                     ) ) )
66         else:
67             packet = stripped_packet
68     packet = packet.encode("hex")
69     if len(packet) < 48:
70         return "malformed ip " + packet
71     else:
72         return "ip " + ( '-'.join( (
73             packet[0:1], #version
74             packet[1:2], #header length
75             packet[2:4], #diffserv/ECN
76             packet[4:8], #total length
77             packet[8:12], #ident
78             packet[12:16], #flags/fragment offs
79             packet[16:18], #ttl
80             packet[18:20], #ip-proto
81             packet[20:24], #checksum
82             ipfmt(packet[24:32]), # src-ip
83             ipfmt(packet[32:40]), # dst-ip
84             packet[40:48] if (int(packet[1],16) > 5) else "", # options
85             packet[48:] if (int(packet[1],16) > 5) else packet[40:], # payload
86         ) ) )
87
88 def _packetReady(buf, ether_mode=False, len=len):
89     if not buf:
90         return False
91         
92     rv = False
93     while not rv:
94         if len(buf[0]) < 4:
95             rv = False
96         elif ether_mode:
97             rv = True
98         else:
99             _,totallen = struct.unpack('HH',buf[0][:4])
100             totallen = socket.htons(totallen)
101             rv = len(buf[0]) >= totallen
102         if not rv and len(buf) > 1:
103             nbuf = ''.join(buf)
104             buf.clear()
105             buf.append(nbuf)
106         else:
107             return rv
108     return rv
109
110 def _pullPacket(buf, ether_mode=False, len=len):
111     if ether_mode:
112         return buf.popleft()
113     else:
114         _,totallen = struct.unpack('HH',buf[0][:4])
115         totallen = socket.htons(totallen)
116         if len(buf[0]) < totallen:
117             rv = buf[0][:totallen]
118             buf[0] = buf[0][totallen:]
119         else:
120             rv = buf.popleft()
121         return rv
122
123 def etherStrip(buf, buffer=buffer, len=len):
124     if len(buf) < 14:
125         return ""
126     if buf[12:14] == '\x08\x10' and buf[16:18] == '\x08\x00':
127         # tagged ethernet frame
128         return buffer(buf, 18)
129     elif buf[12:14] == '\x08\x00':
130         # untagged ethernet frame
131         return buffer(buf, 14)
132     else:
133         return ""
134
135 def etherWrap(packet):
136     return ''.join((
137         "\x00"*6*2 # bogus src and dst mac
138         +"\x08\x00", # IPv4
139         packet, # payload
140         "\x00"*4, # bogus crc
141     ))
142
143 def piStrip(buf, len=len):
144     if len(buf) < 4:
145         return buf
146     else:
147         return buffer(buf,4)
148     
149 def piWrap(buf, ether_mode, etherProto=etherProto):
150     if ether_mode:
151         proto = etherProto(buf)
152     else:
153         proto = "\x08\x00"
154     return ''.join((
155         "\x00\x00", # PI: 16 bits flags
156         proto, # 16 bits proto
157         buf,
158     ))
159
160 _padmap = [ chr(padding) * padding for padding in xrange(127) ]
161 del padding
162
163 def encrypt(packet, crypter, len=len, padmap=_padmap):
164     # pad
165     padding = crypter.block_size - len(packet) % crypter.block_size
166     packet += padmap[padding]
167     
168     # encrypt
169     return crypter.encrypt(packet)
170
171 def decrypt(packet, crypter, ord=ord):
172     if packet:
173         # decrypt
174         packet = crypter.decrypt(packet)
175         
176         # un-pad
177         padding = ord(packet[-1])
178         if not (0 < padding <= crypter.block_size):
179             # wrong padding
180             raise RuntimeError, "Truncated packet"
181         packet = packet[:-padding]
182     
183     return packet
184
185 def nonblock(fd):
186     try:
187         fl = fcntl.fcntl(fd, fcntl.F_GETFL)
188         fl |= os.O_NONBLOCK
189         fcntl.fcntl(fd, fcntl.F_SETFL, fl)
190         return True
191     except:
192         traceback.print_exc(file=sys.stderr)
193         # Just ignore
194         return False
195
196 def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr=sys.stderr, reconnect=None, rwrite=None, rread=None, tunqueue=1000, tunkqueue=1000,
197         cipher='AES', accept_local=None, accept_remote=None, slowlocal=True, queueclass=None, bwlimit=None,
198         len=len, max=max, min=min, OSError=OSError, select=select.select, selecterror=select.error, os=os, socket=socket,
199         retrycodes=(os.errno.EWOULDBLOCK, os.errno.EAGAIN, os.errno.EINTR) ):
200     crypto_mode = False
201     crypter = None
202
203     try:
204         if cipher_key and cipher:
205             import Crypto.Cipher
206             import hashlib
207             __import__('Crypto.Cipher.'+cipher)
208             
209             ciphername = cipher
210             cipher = getattr(Crypto.Cipher, cipher)
211             hashed_key = hashlib.sha256(cipher_key).digest()
212             if getattr(cipher, 'key_size'):
213                 hashed_key = hashed_key[:cipher.key_size]
214             elif ciphername == 'DES3':
215                 hashed_key = hashed_key[:24]
216             crypter = cipher.new(
217                 hashed_key, 
218                 cipher.MODE_ECB)
219             crypto_mode = True
220     except:
221         traceback.print_exc(file=sys.stderr)
222         crypto_mode = False
223         crypter = None
224
225     if stderr is not None:
226         if crypto_mode:
227             print >>stderr, "Packets are transmitted in CIPHER"
228         else:
229             print >>stderr, "Packets are transmitted in PLAINTEXT"
230     
231     if hasattr(remote, 'fileno'):
232         remote_fd = remote.fileno()
233         if rwrite is None:
234             def rwrite(remote, packet, os_write=os.write):
235                 return os_write(remote_fd, packet)
236         if rread is None:
237             def rread(remote, maxlen, os_read=os.read):
238                 return os_read(remote_fd, maxlen)
239  
240     rnonblock = nonblock(remote)
241     tnonblock = nonblock(tun)
242     
243     # Pick up TUN/TAP writing method
244     if with_pi:
245         try:
246             import iovec
247             
248             # We have iovec, so we can skip PI injection
249             # and use iovec which does it natively
250             if ether_mode:
251                 twrite = iovec.ethpiwrite
252                 tread = iovec.piread2
253             else:
254                 twrite = iovec.ippiwrite
255                 tread = iovec.piread2
256         except ImportError:
257             # We have to inject PI headers pythonically
258             def twrite(fd, packet, oswrite=os.write, piWrap=piWrap, ether_mode=ether_mode):
259                 return oswrite(fd, piWrap(packet, ether_mode))
260             
261             # For reading, we strip PI headers with buffer slicing and that's it
262             def tread(fd, maxlen, osread=os.read, piStrip=piStrip):
263                 return piStrip(osread(fd, maxlen))
264     else:
265         # No need to inject PI headers
266         twrite = os.write
267         tread = os.read
268     
269     encrypt_ = encrypt
270     decrypt_ = decrypt
271     xrange_ = xrange
272
273     if accept_local is not None:
274         def tread(fd, maxlen, _tread=tread, accept=accept_local):
275             packet = _tread(fd, maxlen)
276             if accept(packet, 0):
277                 return packet
278             else:
279                 return None
280
281     if accept_remote is not None:
282         if crypto_mode:
283             def decrypt_(packet, crypter, decrypt_=decrypt_, accept=accept_remote):
284                 packet = decrypt_(packet, crypter)
285                 if accept(packet, 1):
286                     return packet
287                 else:
288                     return None
289         else:
290             def rread(fd, maxlen, _rread=rread, accept=accept_remote):
291                 packet = _rread(fd, maxlen)
292                 if accept(packet, 1):
293                     return packet
294                 else:
295                     return None
296     
297     maxbkbuf = maxfwbuf = max(10,tunqueue-tunkqueue)
298     tunhurry = max(0,maxbkbuf/2)
299     
300     if queueclass is None:
301         queueclass = collections.deque
302         maxbatch = 2000
303         maxtbatch = 50
304     else:
305         maxfwbuf = maxbkbuf = 2000000000
306         maxbatch = 50
307         maxtbatch = 30
308         tunhurry = 30
309     
310     fwbuf = queueclass()
311     bkbuf = queueclass()
312     nfwbuf = 0
313     nbkbuf = 0
314     if ether_mode or udp:
315         packetReady = bool
316         pullPacket = queueclass.popleft
317         reschedule = queueclass.appendleft
318     else:
319         packetReady = _packetReady
320         pullPacket = _pullPacket
321         reschedule = queueclass.appendleft
322     tunfd = tun.fileno()
323     os_read = os.read
324     os_write = os.write
325     
326     tget = time.time
327     maxbwfree = bwfree = 1500 * tunqueue
328     lastbwtime = tget()
329     
330     remoteok = True
331     
332     
333     while not TERMINATE:
334         wset = []
335         if packetReady(bkbuf):
336             wset.append(tun)
337         if remoteok and packetReady(fwbuf) and (not bwlimit or bwfree > 0):
338             wset.append(remote)
339         
340         rset = []
341         if len(fwbuf) < maxfwbuf:
342             rset.append(tun)
343         if remoteok and len(bkbuf) < maxbkbuf:
344             rset.append(remote)
345         
346         if remoteok:
347             eset = (tun,remote)
348         else:
349             eset = (tun,)
350         
351         try:
352             rdrdy, wrdy, errs = select(rset,wset,eset,1)
353         except selecterror, e:
354             if e.args[0] == errno.EINTR:
355                 # just retry
356                 continue
357             else:
358                 traceback.print_exc(file=sys.stderr)
359                 raise
360
361         # check for errors
362         if errs:
363             if reconnect is not None and remote in errs and tun not in errs:
364                 remote = reconnect()
365                 if hasattr(remote, 'fileno'):
366                     remote_fd = remote.fileno()
367             elif udp and remote in errs and tun not in errs:
368                 # In UDP mode, those are always transient errors
369                 # Usually, an error will imply a read-ready socket
370                 # that will raise an "Connection refused" error, so
371                 # disable read-readiness just for now, and retry
372                 # the select
373                 remoteok = False
374                 continue
375             else:
376                 break
377         else:
378             remoteok = True
379         
380         # check to see if we can write
381         #rr = wr = rt = wt = 0
382         if remote in wrdy:
383             sent = 0
384             try:
385                 try:
386                     for x in xrange(maxbatch):
387                         packet = pullPacket(fwbuf)
388
389                         if crypto_mode:
390                             packet = encrypt_(packet, crypter)
391                         
392                         sent += rwrite(remote, packet)
393                         #wr += 1
394                         
395                         if not rnonblock or not packetReady(fwbuf):
396                             break
397                 except OSError,e:
398                     # This except handles the entire While block on PURPOSE
399                     # as an optimization (setting a try/except block is expensive)
400                     # The only operation that can raise this exception is rwrite
401                     if e.errno in retrycodes:
402                         # re-schedule packet
403                         reschedule(fwbuf, packet)
404                     else:
405                         raise
406             except:
407                 if reconnect is not None:
408                     # in UDP mode, sometimes connected sockets can return a connection refused.
409                     # Give the caller a chance to reconnect
410                     remote = reconnect()
411                     if hasattr(remote, 'fileno'):
412                         remote_fd = remote.fileno()
413                 elif not udp:
414                     # in UDP mode, we ignore errors - packet loss man...
415                     raise
416                 #traceback.print_exc(file=sys.stderr)
417             
418             if bwlimit:
419                 bwfree -= sent
420         if tun in wrdy:
421             try:
422                 for x in xrange(maxtbatch):
423                     packet = pullPacket(bkbuf)
424                     twrite(tunfd, packet)
425                     #wt += 1
426                     
427                     # Do not inject packets into the TUN faster than they arrive, unless we're falling
428                     # behind. TUN devices discard packets if their queue is full (tunkqueue), but they
429                     # don't block either (they're always ready to write), so if we flood the device 
430                     # we'll have high packet loss.
431                     if not tnonblock or (slowlocal and len(bkbuf) < tunhurry) or not packetReady(bkbuf):
432                         break
433                 else:
434                     if slowlocal:
435                         # Give some time for the kernel to process the packets
436                         time.sleep(0)
437             except OSError,e:
438                 # This except handles the entire While block on PURPOSE
439                 # as an optimization (setting a try/except block is expensive)
440                 # The only operation that can raise this exception is os_write
441                 if e.errno in retrycodes:
442                     # re-schedule packet
443                     reschedule(bkbuf, packet)
444                 else:
445                     raise
446         
447         # check incoming data packets
448         if tun in rdrdy:
449             try:
450                 for x in xrange(maxbatch):
451                     packet = tread(tunfd,2000) # tun.read blocks until it gets 2k!
452                     if not packet:
453                         continue
454                     #rt += 1
455                     fwbuf.append(packet)
456                     
457                     if not tnonblock or len(fwbuf) >= maxfwbuf:
458                         break
459             except OSError,e:
460                 # This except handles the entire While block on PURPOSE
461                 # as an optimization (setting a try/except block is expensive)
462                 # The only operation that can raise this exception is os_read
463                 if e.errno not in retrycodes:
464                     raise
465         if remote in rdrdy:
466             try:
467                 try:
468                     for x in xrange(maxbatch):
469                         packet = rread(remote,2000)
470                         
471                         #rr += 1
472                         
473                         if crypto_mode:
474                             packet = decrypt_(packet, crypter)
475                             if not packet:
476                                 continue
477                         elif not packet:
478                             if not udp and packet == "":
479                                 # Connection broken, try to reconnect (or just die)
480                                 raise RuntimeError, "Connection broken"
481                             else:
482                                 continue
483
484                         bkbuf.append(packet)
485                         
486                         if not rnonblock or len(bkbuf) >= maxbkbuf:
487                             break
488                 except OSError,e:
489                     # This except handles the entire While block on PURPOSE
490                     # as an optimization (setting a try/except block is expensive)
491                     # The only operation that can raise this exception is rread
492                     if e.errno not in retrycodes:
493                         raise
494             except Exception, e:
495                 if reconnect is not None:
496                     # in UDP mode, sometimes connected sockets can return a connection refused
497                     # on read. Give the caller a chance to reconnect
498                     remote = reconnect()
499                     if hasattr(remote, 'fileno'):
500                         remote_fd = remote.fileno()
501                 elif not udp:
502                     # in UDP mode, we ignore errors - packet loss man...
503                     raise
504                 traceback.print_exc(file=sys.stderr)
505
506         if bwlimit:
507             tnow = tget()
508             delta = tnow - lastbwtime
509             if delta > 0.001:
510                 delta = int(bwlimit * delta)
511                 if delta > 0:
512                     bwfree = min(bwfree+delta, maxbwfree)
513                     lastbwtime = tnow
514         
515         #print >>sys.stderr, "rr:%d\twr:%d\trt:%d\twt:%d" % (rr,wr,rt,wt)
516
517 def udp_connect(TERMINATE, local_addr, local_port, peer_addr, peer_port):
518     rsock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
519     retrydelay = 1.0
520     for i in xrange(30):
521         # TERMINATE is a array. An item can be added to TERMINATE, from
522         # outside this function to force termination of the loop
523         if TERMINATE:
524             raise OSError, "Killed"
525         try:
526             rsock.bind((local_addr, local_port))
527             break
528         except socket.error:
529             # wait a while, retry
530             print >>sys.stderr, "%s: Could not bind. Retrying in a sec..." % (time.strftime('%c'),)
531             time.sleep(min(30.0,retrydelay))
532             retrydelay *= 1.1
533     else:
534         rsock.bind((local_addr, local_port))
535     print >>sys.stderr, "Listening UDP at: %s:%d" % (local_addr, local_port)
536     print >>sys.stderr, "Connecting UDP to: %s:%d" % (peer_addr, peer_port)
537     rsock.connect((peer_addr, peer_port))
538     return rsock
539
540 def udp_handshake(TERMINATE, rsock):
541     endme = False
542     def keepalive():
543         while not endme and not TERMINATE:
544             try:
545                 rsock.send('')
546             except:
547                 pass
548             time.sleep(1)
549         try:
550             rsock.send('')
551         except:
552             pass
553     keepalive_thread = threading.Thread(target=keepalive)
554     keepalive_thread.start()
555     retrydelay = 1.0
556     for i in xrange(30):
557         if TERMINATE:
558             raise OSError, "Killed"
559         try:
560             heartbeat = rsock.recv(10)
561             break
562         except:
563             time.sleep(min(30.0,retrydelay))
564             retrydelay *= 1.1
565     else:
566         heartbeat = rsock.recv(10)
567     endme = True
568     keepalive_thread.join()
569
570 def udp_establish(TERMINATE, local_addr, local_port, peer_addr, peer_port):
571     rsock = udp_connect(TERMINATE, local_addr, local_port, peer_addr,
572             peer_port)
573     udp_handshake(TERMINATE, rsock)
574     return rsock 
575
576 def tcp_connect(TERMINATE, stop, rsock, peer_addr, peer_port):
577     sock = None
578     retrydelay = 1.0
579     # The peer has a firewall that prevents a response to the connect, we 
580     # will be forever blocked in the connect, so we put a reasonable timeout.
581     rsock.settimeout(10) 
582     # We wait for 
583     for i in xrange(30):
584         if stop:
585             break
586         if TERMINATE:
587             raise OSError, "Killed"
588         try:
589             rsock.connect((peer_addr, peer_port))
590             sock = rsock
591             break
592         except socket.error:
593             # wait a while, retry
594             print >>sys.stderr, "%s: Could not connect. Retrying in a sec..." % (time.strftime('%c'),)
595             time.sleep(min(30.0,retrydelay))
596             retrydelay *= 1.1
597     else:
598         rsock.connect((peer_addr, peer_port))
599         sock = rsock
600     if sock:
601         print >>sys.stderr, "tcp_connect: TCP sock connected to remote %s:%s" % (peer_addr, peer_port)
602         sock.settimeout(0) 
603     return sock
604
605 def tcp_listen(TERMINATE, stop, lsock, local_addr, local_port):
606     sock = None
607     retrydelay = 1.0
608     # We try to bind to the local virtual interface. 
609     # It might not exist yet so we wait in a loop.
610     for i in xrange(30):
611         if stop:
612             break
613         if TERMINATE:
614             raise OSError, "Killed"
615         try:
616             lsock.bind((local_addr, local_port))
617             break
618         except socket.error:
619             # wait a while, retry
620             print >>sys.stderr, "%s: Could not bind. Retrying in a sec..." % (time.strftime('%c'),)
621             time.sleep(min(30.0,retrydelay))
622             retrydelay *= 1.1
623     else:
624         lsock.bind((local_addr, local_port))
625
626     print >>sys.stderr, "tcp_listen: TCP sock listening in local sock %s:%s" % (local_addr, local_port)
627     # Now we wait until the other side connects. 
628     # The other side might not be ready yet, so we also wait in a loop for timeouts.
629     timeout = 1
630     lsock.listen(1)
631     for i in xrange(30):
632         if TERMINATE:
633             raise OSError, "Killed"
634         rlist, wlist, xlist = select.select([lsock], [], [], timeout)
635         if stop:
636             break
637         if lsock in rlist:
638             sock,raddr = lsock.accept()
639             print >>sys.stderr, "tcp_listen: TCP connection accepted in local sock %s:%s" % (local_addr, local_port)
640             break
641         timeout += 5
642     return sock
643
644 def tcp_handshake(rsock, listen, hand):
645     # we are going to use a barrier algorithm to decide wich side listen.
646     # each side will "roll a dice" and send the resulting value to the other 
647     # side. 
648     win = False
649     rsock.settimeout(10)
650     try:
651         rsock.send(hand)
652         peer_hand = rsock.recv(1)
653         print >>sys.stderr, "tcp_handshake: hand %s, peer_hand %s" % (hand, peer_hand)
654         if hand < peer_hand:
655             if listen:
656                 win = True
657         elif hand > peer_hand:
658             if not listen:
659                 win = True
660     except socket.timeout:
661         pass
662     rsock.settimeout(0)
663     return win
664
665 def tcp_establish(TERMINATE, local_addr, local_port, peer_addr, peer_port):
666     def listen(stop, hand, lsock, lresult):
667         win = False
668         rsock = tcp_listen(TERMINATE, stop, lsock, local_addr, local_port)
669         if rsock:
670             win = tcp_handshake(rsock, True, hand)
671             stop.append(True)
672         lresult.append((win, rsock))
673
674     def connect(stop, hand, rsock, rresult):
675         win = False
676         rsock = tcp_connect(TERMINATE, stop, rsock, peer_addr, peer_port)
677         if rsock:
678             win = tcp_handshake(rsock, False, hand)
679             stop.append(True)
680         rresult.append((win, rsock))
681   
682     end = False
683     sock = None
684     for i in xrange(0, 50):
685         if end:
686             break
687         if TERMINATE:
688             raise OSError, "Killed"
689         hand = str(random.randint(1, 6))
690         stop = []
691         lresult = []
692         rresult = []
693         lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
694         rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
695         listen_thread = threading.Thread(target=listen, args=(stop, hand, lsock, lresult))
696         connect_thread = threading.Thread(target=connect, args=(stop, hand, rsock, rresult))
697         connect_thread.start()
698         listen_thread.start()
699         connect_thread.join()
700         listen_thread.join()
701         (lwin, lrsock) = lresult[0]
702         (rwin, rrsock) = rresult[0]
703         if not lrsock or not rrsock:
704             if not lrsock:
705                 sock = rrsock
706             if not rrsock:
707                 sock = lrsock
708             end = True
709         # both socket are connected
710         else:
711            if lwin:
712                 sock = lrsock
713                 end = True
714            elif rwin: 
715                 sock = rrsock
716                 end = True
717
718     if not sock:
719         raise OSError, "Error: tcp_establish could not establish connection."
720     return sock
721
722