Enhanced multicast support: generate IGMP messages for join/leave
[nepi.git] / src / nepi / util / tunchannel.py
1 import select
2 import sys
3 import os
4 import struct
5 import socket
6 import threading
7 import errno
8 import fcntl
9 import traceback
10 import functools
11 import collections
12 import ctypes
13 import time
14
15 def ipfmt(ip):
16     ipbytes = map(ord,ip.decode("hex"))
17     return '.'.join(map(str,ipbytes))
18
# Display names for the EtherType values formatPacket knows about,
# keyed by the lowercase hex rendering of the 16-bit EtherType field.
tagtype = {
    '0800' : 'ipv4',
    '0806' : 'arp',
    '86dd' : 'ipv6',
    '8863' : 'PPPoE discover',
    '8864' : 'PPPoE',
    '8870' : 'jumbo',
}
def etherProto(packet, len=len):
    """
    Return the 2-byte EtherType field of an ethernet frame.

    For frames whose bytes 12-13 are 0x81 0x00 (VLAN tagged) the field
    is read after the 4-byte tag; otherwise it is read right after the
    two MAC addresses. Frames of 14 bytes or fewer yield the IPv4
    EtherType as a default.
    """
    if len(packet) <= 14:
        # default: ip
        return "\x08\x00"

    if packet[12] == "\x81" and packet[13] == "\x00":
        # tagged: skip the VLAN tag
        return packet[16:18]

    # untagged
    return packet[12:14]
def formatPacket(packet, ether_mode):
    """
    Build a human-readable, hex-based one-line dump of a packet.

    packet: the raw packet (or ethernet frame when ether_mode is true).
    ether_mode: when true, the packet is first passed through etherStrip;
        if that yields an IP payload the payload is dumped as IP,
        otherwise the ethernet fields themselves are dumped.

    Returns a string starting with a protocol label ("ip", an entry of
    tagtype, "eth", "malformed eth" or "malformed ip") followed by the
    dash-separated hex fields.
    """
    if ether_mode:
        stripped_packet = etherStrip(packet)
        if not stripped_packet:
            # Not something etherStrip recognizes as IPv4-over-ethernet:
            # dump the ethernet header fields instead.
            packet = packet.encode("hex")
            if len(packet) < 28:
                # packet is already hex at this point; encoding it again
                # would print the hex digits of the hex digits.
                return "malformed eth " + packet
            else:
                if packet[24:28] == "8100":
                    # tagged
                    ethertype = tagtype.get(packet[32:36], 'eth')
                    return ethertype + " " + ( '-'.join( (
                        packet[0:12], # MAC dest
                        packet[12:24], # MAC src
                        packet[24:32], # VLAN tag
                        packet[32:36], # Ethertype/len
                        packet[36:], # Payload
                    ) ) )
                else:
                    # untagged
                    ethertype = tagtype.get(packet[24:28], 'eth')
                    return ethertype + " " + ( '-'.join( (
                        packet[0:12], # MAC dest
                        packet[12:24], # MAC src
                        packet[24:28], # Ethertype/len
                        packet[28:], # Payload
                    ) ) )
        else:
            packet = stripped_packet

    # From here on, packet is treated as a bare IPv4 packet; all offsets
    # below are in hex digits (two per byte).
    packet = packet.encode("hex")
    if len(packet) < 48:
        return "malformed ip " + packet
    else:
        # A header length above 5 words means an options field is present
        has_options = int(packet[1],16) > 5
        return "ip " + ( '-'.join( (
            packet[0:1], #version
            packet[1:2], #header length
            packet[2:4], #diffserv/ECN
            packet[4:8], #total length
            packet[8:12], #ident
            packet[12:16], #flags/fragment offs
            packet[16:18], #ttl
            packet[18:20], #ip-proto
            packet[20:24], #checksum
            ipfmt(packet[24:32]), # src-ip
            ipfmt(packet[32:40]), # dst-ip
            packet[40:48] if has_options else "", # options
            packet[48:] if has_options else packet[40:], # payload
        ) ) )
85
def _packetReady(buf, ether_mode=False, len=len):
    """
    Tell whether the chunk at the head of the queue holds a full packet.

    buf: queue of received chunks (needs indexing, len, clear, append).
    ether_mode: when true, any head chunk of at least 4 bytes counts as
        a packet; otherwise the 16-bit length stored in bytes 2-3 of the
        head chunk is compared against the chunk's size.

    If the head chunk is incomplete and more chunks are queued, all
    chunks are concatenated into a single one (in place) and the check
    is repeated on the merged chunk.
    """
    if not buf:
        return False

    while True:
        head = buf[0]
        if len(head) < 4:
            # Not even enough bytes to read the length field
            complete = False
        elif ether_mode:
            complete = True
        else:
            _,totallen = struct.unpack('HH',head[:4])
            complete = len(head) >= socket.htons(totallen)

        if complete or len(buf) <= 1:
            return complete

        # Incomplete, but there is more data queued: coalesce and retry
        merged = ''.join(buf)
        buf.clear()
        buf.append(merged)
107
def _pullPacket(buf, ether_mode=False, len=len):
    """
    Remove and return the next packet from the head of the queue.

    buf: queue of received chunks (needs indexing, item assignment and
        popleft). The caller is expected to have checked that a packet
        is available (see _packetReady).
    ether_mode: when true, each chunk is one frame and is popped as-is.

    Otherwise the 16-bit length stored in bytes 2-3 of the head chunk
    says how many bytes belong to the packet. If the head chunk holds
    more than that (several packets coalesced together), only the first
    packet is cut off and the remainder stays at the head of the queue.
    """
    if ether_mode:
        return buf.popleft()

    head = buf[0]
    _,totallen = struct.unpack('HH',head[:4])
    totallen = socket.htons(totallen)
    # Split only when there is data beyond the first packet. A declared
    # length of zero would split off nothing and never make progress, so
    # such a chunk is handed over whole instead.
    if 0 < totallen < len(head):
        rv = head[:totallen]
        buf[0] = head[totallen:]
    else:
        rv = buf.popleft()
    return rv
120
def etherStrip(buf):
    """
    Strip the ethernet header from a frame carrying IPv4.

    Returns the payload following the header for untagged frames with
    EtherType 0x0800, and for 802.1Q tagged frames (TPID 0x8100) whose
    inner EtherType is 0x0800. Returns "" for anything else, including
    frames shorter than 14 bytes.
    """
    if len(buf) < 14:
        return ""
    if buf[12:14] == '\x81\x00' and buf[16:18] == '\x08\x00':
        # tagged ethernet frame: 12 bytes MAC + 4 bytes VLAN tag + 2 bytes type
        return buf[18:]
    elif buf[12:14] == '\x08\x00':
        # untagged ethernet frame: 12 bytes MAC + 2 bytes type
        return buf[14:]
    else:
        return ""
132
def etherWrap(packet):
    """
    Wrap an IPv4 packet in a dummy ethernet frame.

    The frame uses all-zero source and destination MAC addresses, the
    IPv4 EtherType, and four zero bytes in place of the CRC.
    """
    header = "\x00" * 12 + "\x08\x00" # bogus dst+src mac, then IPv4
    trailer = "\x00" * 4 # bogus crc
    return ''.join([ header, packet, trailer ])
140
def piStrip(buf, len=len):
    """
    Drop the 4-byte PI header from data read off the TUN/TAP device.

    Returns a buffer object viewing buf from offset 4 onwards (no copy).
    Data shorter than 4 bytes is returned untouched.
    """
    if len(buf) >= 4:
        return buffer(buf,4)
    return buf
146     
def piWrap(buf, ether_mode, etherProto=etherProto):
    """
    Prepend a 4-byte PI header (16 bits of zero flags followed by the
    16-bit protocol) to a packet about to be written to the device.

    In ether_mode the protocol is taken from the frame via etherProto;
    otherwise IPv4 is assumed.
    """
    proto = etherProto(buf) if ether_mode else "\x08\x00"
    flags = "\x00\x00"
    return ''.join([ flags, proto, buf ])
157
158 _padmap = [ chr(padding) * padding for padding in xrange(127) ]
159 del padding
160
def encrypt(packet, crypter, len=len, padmap=_padmap):
    """
    Pad a packet up to a multiple of the cipher block size and encrypt it.

    The padding consists of N bytes each holding the value N, where N is
    between 1 and crypter.block_size; a packet that is already block
    aligned receives a full extra block. The result is whatever
    crypter.encrypt returns for the padded data.
    """
    blocksize = crypter.block_size
    padlen = blocksize - len(packet) % blocksize
    packet += padmap[padlen]
    return crypter.encrypt(packet)
168
def decrypt(packet, crypter, ord=ord):
    """
    Decrypt a packet and remove the padding added by encrypt.

    An empty packet is returned unchanged. Otherwise the last decrypted
    byte tells how many padding bytes to drop.

    Raises RuntimeError if that byte is not within 1..crypter.block_size.
    """
    if not packet:
        return packet

    plain = crypter.decrypt(packet)

    # un-pad
    padlen = ord(plain[-1])
    if padlen <= 0 or padlen > crypter.block_size:
        # wrong padding
        raise RuntimeError("Truncated packet")
    return plain[:-padlen]
182
def nonblock(fd):
    """
    Switch a file descriptor (or object accepted by fcntl) to
    non-blocking mode by adding O_NONBLOCK to its status flags.

    Returns True on success. On any failure the traceback is printed to
    sys.stderr and False is returned; nothing is raised.
    """
    try:
        flags = fcntl.fcntl(fd, fcntl.F_GETFL)
        fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
    except:
        # Just report and carry on in blocking mode
        traceback.print_exc(file=sys.stderr)
        return False
    else:
        return True
193
194 def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr=sys.stderr, reconnect=None, rwrite=None, rread=None, tunqueue=1000, tunkqueue=1000,
195         cipher='AES', accept_local=None, accept_remote=None, slowlocal=True, queueclass=None, bwlimit=None,
196         len=len, max=max, min=min, OSError=OSError, select=select.select, selecterror=select.error, os=os, socket=socket,
197         retrycodes=(os.errno.EWOULDBLOCK, os.errno.EAGAIN, os.errno.EINTR) ):
198     crypto_mode = False
199     crypter = None
200     try:
201         if cipher_key and cipher:
202             import Crypto.Cipher
203             import hashlib
204             __import__('Crypto.Cipher.'+cipher)
205             
206             ciphername = cipher
207             cipher = getattr(Crypto.Cipher, cipher)
208             hashed_key = hashlib.sha256(cipher_key).digest()
209             if getattr(cipher, 'key_size'):
210                 hashed_key = hashed_key[:cipher.key_size]
211             elif ciphername == 'DES3':
212                 hashed_key = hashed_key[:24]
213             crypter = cipher.new(
214                 hashed_key, 
215                 cipher.MODE_ECB)
216             crypto_mode = True
217     except:
218         traceback.print_exc(file=sys.stderr)
219         crypto_mode = False
220         crypter = None
221
222     if stderr is not None:
223         if crypto_mode:
224             print >>stderr, "Packets are transmitted in CIPHER"
225         else:
226             print >>stderr, "Packets are transmitted in PLAINTEXT"
227     
228     if hasattr(remote, 'fileno'):
229         remote_fd = remote.fileno()
230         if rwrite is None:
231             def rwrite(remote, packet, os_write=os.write):
232                 return os_write(remote_fd, packet)
233         if rread is None:
234             def rread(remote, maxlen, os_read=os.read):
235                 return os_read(remote_fd, maxlen)
236     
237     rnonblock = nonblock(remote)
238     tnonblock = nonblock(tun)
239     
240     # Pick up TUN/TAP writing method
241     if with_pi:
242         try:
243             import iovec
244             
245             # We have iovec, so we can skip PI injection
246             # and use iovec which does it natively
247             if ether_mode:
248                 twrite = iovec.ethpiwrite
249                 tread = iovec.piread2
250             else:
251                 twrite = iovec.ippiwrite
252                 tread = iovec.piread2
253         except ImportError:
254             # We have to inject PI headers pythonically
255             def twrite(fd, packet, oswrite=os.write, piWrap=piWrap, ether_mode=ether_mode):
256                 return oswrite(fd, piWrap(packet, ether_mode))
257             
258             # For reading, we strip PI headers with buffer slicing and that's it
259             def tread(fd, maxlen, osread=os.read, piStrip=piStrip):
260                 return piStrip(osread(fd, maxlen))
261     else:
262         # No need to inject PI headers
263         twrite = os.write
264         tread = os.read
265     
266     encrypt_ = encrypt
267     decrypt_ = decrypt
268     xrange_ = xrange
269
270     if accept_local is not None:
271         def tread(fd, maxlen, _tread=tread, accept=accept_local):
272             packet = _tread(fd, maxlen)
273             if accept(packet, 0):
274                 return packet
275             else:
276                 return None
277
278     if accept_remote is not None:
279         if crypto_mode:
280             def decrypt_(packet, crypter, decrypt_=decrypt_, accept=accept_remote):
281                 packet = decrypt_(packet, crypter)
282                 if accept(packet, 1):
283                     return packet
284                 else:
285                     return None
286         else:
287             def rread(fd, maxlen, _rread=rread, accept=accept_remote):
288                 packet = _rread(fd, maxlen)
289                 if accept(packet, 1):
290                     return packet
291                 else:
292                     return None
293     
294     maxbkbuf = maxfwbuf = max(10,tunqueue-tunkqueue)
295     tunhurry = max(0,maxbkbuf/2)
296     
297     if queueclass is None:
298         queueclass = collections.deque
299         maxbatch = 2000
300         maxtbatch = 50
301     else:
302         maxfwbuf = maxbkbuf = 2000000000
303         maxbatch = 50
304         maxtbatch = 30
305         tunhurry = 30
306     
307     fwbuf = queueclass()
308     bkbuf = queueclass()
309     nfwbuf = 0
310     nbkbuf = 0
311     if ether_mode or udp:
312         packetReady = bool
313         pullPacket = queueclass.popleft
314         reschedule = queueclass.appendleft
315     else:
316         packetReady = _packetReady
317         pullPacket = _pullPacket
318         reschedule = queueclass.appendleft
319     tunfd = tun.fileno()
320     os_read = os.read
321     os_write = os.write
322     
323     tget = time.time
324     maxbwfree = bwfree = 1500 * tunqueue
325     lastbwtime = tget()
326     
327     remoteok = True
328     
329     while not TERMINATE:
330         wset = []
331         if packetReady(bkbuf):
332             wset.append(tun)
333         if remoteok and packetReady(fwbuf) and (not bwlimit or bwfree > 0):
334             wset.append(remote)
335         
336         rset = []
337         if len(fwbuf) < maxfwbuf:
338             rset.append(tun)
339         if remoteok and len(bkbuf) < maxbkbuf:
340             rset.append(remote)
341         
342         if remoteok:
343             eset = (tun,remote)
344         else:
345             eset = (tun,)
346         
347         try:
348             rdrdy, wrdy, errs = select(rset,wset,eset,1)
349         except selecterror, e:
350             if e.args[0] == errno.EINTR:
351                 # just retry
352                 continue
353
354         # check for errors
355         if errs:
356             if reconnect is not None and remote in errs and tun not in errs:
357                 remote = reconnect()
358                 if hasattr(remote, 'fileno'):
359                     remote_fd = remote.fileno()
360             elif udp and remote in errs and tun not in errs:
361                 # In UDP mode, those are always transient errors
362                 # Usually, an error will imply a read-ready socket
363                 # that will raise an "Connection refused" error, so
364                 # disable read-readiness just for now, and retry
365                 # the select
366                 remoteok = False
367                 continue
368             else:
369                 break
370         else:
371             remoteok = True
372         
373         # check to see if we can write
374         #rr = wr = rt = wt = 0
375         if remote in wrdy:
376             sent = 0
377             try:
378                 try:
379                     for x in xrange(maxbatch):
380                         packet = pullPacket(fwbuf)
381
382                         if crypto_mode:
383                             packet = encrypt_(packet, crypter)
384                         
385                         sent += rwrite(remote, packet)
386                         #wr += 1
387                         
388                         if not rnonblock or not packetReady(fwbuf):
389                             break
390                 except OSError,e:
391                     # This except handles the entire While block on PURPOSE
392                     # as an optimization (setting a try/except block is expensive)
393                     # The only operation that can raise this exception is rwrite
394                     if e.errno in retrycodes:
395                         # re-schedule packet
396                         reschedule(fwbuf, packet)
397                     else:
398                         raise
399             except:
400                 if reconnect is not None:
401                     # in UDP mode, sometimes connected sockets can return a connection refused.
402                     # Give the caller a chance to reconnect
403                     remote = reconnect()
404                     if hasattr(remote, 'fileno'):
405                         remote_fd = remote.fileno()
406                 elif not udp:
407                     # in UDP mode, we ignore errors - packet loss man...
408                     raise
409                 #traceback.print_exc(file=sys.stderr)
410             
411             if bwlimit:
412                 bwfree -= sent
413         if tun in wrdy:
414             try:
415                 for x in xrange(maxtbatch):
416                     packet = pullPacket(bkbuf)
417                     twrite(tunfd, packet)
418                     #wt += 1
419                     
420                     # Do not inject packets into the TUN faster than they arrive, unless we're falling
421                     # behind. TUN devices discard packets if their queue is full (tunkqueue), but they
422                     # don't block either (they're always ready to write), so if we flood the device 
423                     # we'll have high packet loss.
424                     if not tnonblock or (slowlocal and len(bkbuf) < tunhurry) or not packetReady(bkbuf):
425                         break
426                 else:
427                     if slowlocal:
428                         # Give some time for the kernel to process the packets
429                         time.sleep(0)
430             except OSError,e:
431                 # This except handles the entire While block on PURPOSE
432                 # as an optimization (setting a try/except block is expensive)
433                 # The only operation that can raise this exception is os_write
434                 if e.errno in retrycodes:
435                     # re-schedule packet
436                     reschedule(bkbuf, packet)
437                 else:
438                     raise
439         
440         # check incoming data packets
441         if tun in rdrdy:
442             try:
443                 for x in xrange(maxbatch):
444                     packet = tread(tunfd,2000) # tun.read blocks until it gets 2k!
445                     if not packet:
446                         continue
447                     #rt += 1
448                     fwbuf.append(packet)
449                     
450                     if not tnonblock or len(fwbuf) >= maxfwbuf:
451                         break
452             except OSError,e:
453                 # This except handles the entire While block on PURPOSE
454                 # as an optimization (setting a try/except block is expensive)
455                 # The only operation that can raise this exception is os_read
456                 if e.errno not in retrycodes:
457                     raise
458         if remote in rdrdy:
459             try:
460                 try:
461                     for x in xrange(maxbatch):
462                         packet = rread(remote,2000)
463                         #rr += 1
464                         
465                         if crypto_mode:
466                             packet = decrypt_(packet, crypter)
467                             if not packet:
468                                 continue
469                         elif not packet:
470                             if not udp and packet == "":
471                                 # Connection broken, try to reconnect (or just die)
472                                 raise RuntimeError, "Connection broken"
473                             else:
474                                 continue
475
476                         bkbuf.append(packet)
477                         
478                         if not rnonblock or len(bkbuf) >= maxbkbuf:
479                             break
480                 except OSError,e:
481                     # This except handles the entire While block on PURPOSE
482                     # as an optimization (setting a try/except block is expensive)
483                     # The only operation that can raise this exception is rread
484                     if e.errno not in retrycodes:
485                         raise
486             except Exception, e:
487                 if reconnect is not None:
488                     # in UDP mode, sometimes connected sockets can return a connection refused
489                     # on read. Give the caller a chance to reconnect
490                     remote = reconnect()
491                     if hasattr(remote, 'fileno'):
492                         remote_fd = remote.fileno()
493                 elif not udp:
494                     # in UDP mode, we ignore errors - packet loss man...
495                     raise
496                 traceback.print_exc(file=sys.stderr)
497
498         if bwlimit:
499             tnow = tget()
500             delta = tnow - lastbwtime
501             if delta > 0.001:
502                 delta = int(bwlimit * delta)
503                 if delta > 0:
504                     bwfree = min(bwfree+delta, maxbwfree)
505                     lastbwtime = tnow
506         
507         #print >>sys.stderr, "rr:%d\twr:%d\trt:%d\twt:%d" % (rr,wr,rt,wt)
508
509
def udp_handshake(TERMINATE, rsock):
    """
    Wait until the peer on rsock shows signs of life.

    While waiting, a background thread sends an empty datagram through
    rsock about once per second. The calling thread tries to receive up
    to 30 times, sleeping between failed attempts (starting at 1 second
    and growing by 10% each time); after that, one last attempt is made
    whose failure is propagated to the caller.

    TERMINATE: object tested for truth; when it becomes true the wait is
        aborted.
    rsock: socket-like object providing send(data) and recv(size).

    Raises OSError("Killed") if TERMINATE is set, or whatever the final
    recv raises. The keepalive thread is always stopped and joined before
    returning or raising.
    """
    stop = threading.Event()

    def keepalive():
        while not stop.isSet() and not TERMINATE:
            try:
                rsock.send('')
            except Exception:
                # best effort: the peer may simply not be there yet
                pass
            # Sleeps up to a second, but wakes up right away when stopped
            stop.wait(1)
        # One last keepalive so the peer's own recv gets a chance to complete
        try:
            rsock.send('')
        except Exception:
            pass

    keepalive_thread = threading.Thread(target=keepalive)
    keepalive_thread.start()
    try:
        retrydelay = 1.0
        for i in range(30):
            if TERMINATE:
                raise OSError("Killed")
            try:
                rsock.recv(10)
                break
            except Exception:
                time.sleep(min(30.0,retrydelay))
                retrydelay *= 1.1
        else:
            # Out of retries: let a failure here reach the caller
            rsock.recv(10)
    finally:
        # Never leave the keepalive thread running behind, whatever happened
        stop.set()
        keepalive_thread.join()
539