Refactor tunnelling code to put VPN channel protocol stuff in nepi.util (so that...
[nepi.git] / src / nepi / util / tunchannel.py
1 import select
2 import sys
3 import os
4 import struct
5 import socket
6
def ipfmt(ip):
    """Render a hex-encoded IPv4 address as a dotted quad.

    Example: ``ipfmt("c0a80001") == "192.168.0.1"``.

    Parses each pair of hex digits with ``int(..., 16)`` instead of the
    ``"hex"`` string codec, which only exists on Python 2.

    Raises:
        ValueError: if *ip* has an odd number of characters or contains
            non-hex digits.
    """
    if len(ip) % 2:
        raise ValueError("odd-length hex string: %r" % (ip,))
    return '.'.join(str(int(ip[i:i + 2], 16)) for i in range(0, len(ip), 2))
10
# EtherType, spelled as four lower-case hex digits, mapped to the short label
# that formatPacket prints in front of non-IP Ethernet frames.
tagtype = dict([
    ('0800', 'ipv4'),
    ('0806', 'arp'),
    ('86dd', 'ipv6'),
    ('8863', 'PPPoE discover'),
    ('8864', 'PPPoE'),
    ('8870', 'jumbo'),
])
def etherProto(packet):
    """Return the 2-byte (raw, network order) EtherType of an Ethernet frame.

    The result is used as the protocol field of a TUN packet-information
    header, so it must be raw bytes, not hex text.

    For an 802.1Q-tagged frame (TPID 0x8100 at bytes 12-13) the inner
    EtherType at bytes 16-17 is returned. Frames too short to carry the
    relevant field fall back to IPv4 ("\\x08\\x00").
    """
    if len(packet) < 14:
        # No room for dst MAC + src MAC + EtherType.
        return "\x08\x00"
    ethertype = packet[12:14]
    if ethertype == "\x81\x00":
        # Tagged: 2-byte TPID + 2-byte TCI, then the real EtherType.
        inner = packet[16:18]
        return inner if len(inner) == 2 else "\x08\x00"
    return ethertype
def formatPacket(packet, ether_mode):
    """Return a one-line, hex-dump style description of *packet* for logging.

    In ether mode, a frame that etherStrip() reduces to an IPv4 payload is
    described as IP. Any other frame is described by its Ethernet header
    fields, prefixed with a label from ``tagtype`` (or 'eth' when unknown).

    IP output is ``"ip "`` followed by '-'-separated hex header fields,
    dotted-quad source/destination addresses, options and payload. Input too
    short to parse yields a string starting with ``"malformed"``.
    """
    if ether_mode:
        stripped_packet = etherStrip(packet)
        if stripped_packet:
            packet = stripped_packet
        else:
            hexpkt = packet.encode("hex")
            if len(hexpkt) < 28:
                # Shorter than dst MAC + src MAC + EtherType (14 bytes).
                # (Already hex: do not encode a second time.)
                return "malformed eth " + hexpkt
            if hexpkt[24:28] == "8100":
                # 802.1Q tagged: the 4-byte tag precedes the real EtherType.
                ethertype = tagtype.get(hexpkt[32:36], 'eth')
                fields = (
                    hexpkt[0:12],   # MAC dest
                    hexpkt[12:24],  # MAC src
                    hexpkt[24:32],  # VLAN tag
                    hexpkt[32:36],  # Ethertype/len
                    hexpkt[36:],    # Payload
                )
            else:
                # untagged
                ethertype = tagtype.get(hexpkt[24:28], 'eth')
                fields = (
                    hexpkt[0:12],   # MAC dest
                    hexpkt[12:24],  # MAC src
                    hexpkt[24:28],  # Ethertype/len
                    hexpkt[28:],    # Payload
                )
            return ethertype + " " + '-'.join(fields)

    hexpkt = packet.encode("hex")
    if len(hexpkt) < 40:
        # Shorter than the 20-byte minimum IPv4 header.
        return "malformed ip " + hexpkt
    # IHL (second hex digit) counts 32-bit words, i.e. 8 hex digits each.
    # Clamp to the 20-byte minimum so a bogus IHL < 5 cannot overlap the
    # address fields.
    header_end = max(int(hexpkt[1], 16) * 8, 40)
    return "ip " + '-'.join((
        hexpkt[0:1],            # version
        hexpkt[1:2],            # header length
        hexpkt[2:4],            # diffserv/ECN
        hexpkt[4:8],            # total length
        hexpkt[8:12],           # ident
        hexpkt[12:16],          # flags/fragment offs
        hexpkt[16:18],          # ttl
        hexpkt[18:20],          # ip-proto
        hexpkt[20:24],          # checksum
        ipfmt(hexpkt[24:32]),   # src-ip
        ipfmt(hexpkt[32:40]),   # dst-ip
        hexpkt[40:header_end],  # options (empty when IHL == 5)
        hexpkt[header_end:],    # payload
    ))
78
def packetReady(buf, ether_mode):
    """Tell whether *buf* already holds one complete packet.

    Buffers shorter than 4 bytes are never ready. In ether mode any longer
    buffer is ready. Otherwise bytes 2-3 are read as a big-endian total-length
    field (IPv4 header layout), and the buffer is ready once it holds at least
    that many bytes.

    NOTE(review): the length lookup assumes IPv4 headers; an IPv6 packet keeps
    its length at a different offset.
    """
    if len(buf) < 4:
        return False
    if ether_mode:
        return True
    (total_length,) = struct.unpack('!H', buf[2:4])
    return len(buf) >= total_length
88
def pullPacket(buf, ether_mode):
    """Split the leading packet off *buf* and return ``(packet, rest)``.

    In ether mode the whole buffer is treated as one frame and *rest* is "".
    Otherwise the packet length is the big-endian total-length field at
    bytes 2-3. Call packetReady() first; a buffer shorter than 4 bytes makes
    struct raise.
    """
    if not ether_mode:
        (total_length,) = struct.unpack('!H', buf[2:4])
        return buf[:total_length], buf[total_length:]
    return buf, ""
96
def etherStrip(buf):
    """Return the IPv4 payload carried by Ethernet frame *buf*, or "".

    Recognises untagged frames (EtherType 0x0800 at bytes 12-13) and 802.1Q
    tagged frames (TPID 0x8100 at bytes 12-13, EtherType 0x0800 at 16-17).
    Any other frame, or one too short to hold its header, yields "".
    """
    if len(buf) < 14:
        return ""
    ethertype = buf[12:14]
    # Exact comparison: a substring test (`in`) would also accept an empty or
    # 1-byte slice from a truncated frame.
    if ethertype == '\x81\x00' and buf[16:18] == '\x08\x00':
        # tagged ethernet frame
        return buf[18:]
    if ethertype == '\x08\x00':
        # untagged ethernet frame
        return buf[14:]
    return ""
108
def etherWrap(packet):
    """Wrap an IPv4 *packet* in a minimal Ethernet frame.

    Layout: 12 zero bytes (placeholder dst + src MAC), EtherType 0x0800,
    the payload, then 4 zero bytes in the trailing CRC position.
    """
    zero_macs = "\x00" * 12   # bogus dst and src MAC, 6 bytes each
    ipv4_type = "\x08\x00"    # IPv4
    fake_crc = "\x00" * 4     # bogus crc
    return "".join((zero_macs, ipv4_type, packet, fake_crc))
116
def piStrip(buf):
    """Drop the 4-byte packet-information prefix from *buf*.

    Buffers shorter than 4 bytes are returned unchanged.
    """
    return buf[4:] if len(buf) >= 4 else buf
122     
def piWrap(buf, ether_mode):
    """Prepend a 4-byte packet-information header to *buf*.

    The header is 16 zero flag bits followed by a 16-bit protocol. In ether
    mode the protocol comes from etherProto(buf); otherwise it is IPv4
    (0x0800).
    """
    proto = etherProto(buf) if ether_mode else "\x08\x00"
    flags = "\x00\x00"
    return "".join((flags, proto, buf))
133
def encrypt(packet, crypter):
    """Pad *packet* to a whole number of cipher blocks, then encrypt it.

    Padding is PKCS#7-style: 1..block_size bytes, each equal to the pad
    length. A full extra block is added when the input is already aligned.
    *crypter* must expose ``block_size`` and ``encrypt()``.
    """
    block = crypter.block_size
    pad_len = block - (len(packet) % block)
    padded = packet + chr(pad_len) * pad_len
    return crypter.encrypt(padded)
141
def decrypt(packet, crypter):
    """Decrypt *packet* and strip the PKCS#7-style padding added by encrypt().

    Args:
        packet: ciphertext string.
        crypter: object exposing ``decrypt()`` and ``block_size``.

    Returns:
        The plaintext with its trailing pad bytes removed.

    Raises:
        RuntimeError: "Truncated packet" when the decrypted data is empty or
            its padding is inconsistent. Inconsistent means a pad length
            outside 1..block_size, or pad bytes that are not all equal to
            that length.
    """
    packet = crypter.decrypt(packet)

    # An empty result cannot carry padding; report it like any other bad pad
    # instead of letting packet[-1] raise IndexError.
    if not packet:
        raise RuntimeError("Truncated packet")

    # un-pad: every pad byte must hold the pad length.
    padding = ord(packet[-1])
    if (not (0 < padding <= crypter.block_size)
            or packet[-padding:] != packet[-1] * padding):
        # wrong padding
        raise RuntimeError("Truncated packet")
    return packet[:-padding]
154
155
def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, stderr=sys.stderr):
    """Shuttle packets between a tun/tap device and a remote channel.

    Data read from *tun* is optionally encrypted and written to *remote*.
    Data read from *remote* is optionally decrypted and written to *tun*.
    Each forwarded packet is logged to *stderr* ('>' outbound, '<' inbound).

    Args:
        tun: object with ``fileno()`` for the tun/tap device.
        remote: object with ``fileno()`` for the peer channel.
        with_pi: the device uses a 4-byte packet-information header. It is
            stripped on read and re-added (piWrap) on write.
        ether_mode: traffic consists of Ethernet frames rather than raw IP
            packets.
        cipher_key: if truthy, traffic is AES-encrypted with the SHA-256
            digest of this key. If crypto setup fails, the traceback is
            printed and traffic falls back to PLAINTEXT.
        udp: *remote* is datagram-oriented. Errors while reading, writing or
            decrypting on it drop the packet instead of propagating.
        TERMINATE: the loop runs while this is falsy. It is re-checked every
            iteration, and select() waits at most 1 s per iteration.
        stderr: stream for diagnostic output.

    Returns when TERMINATE becomes truthy, when select() reports an
    exceptional condition on either descriptor, or when a stream (non-UDP)
    peer closes the connection.
    """
    crypto_mode = False
    crypter = None
    try:
        if cipher_key:
            import Crypto.Cipher.AES
            import hashlib

            hashed_key = hashlib.sha256(cipher_key).digest()
            # NOTE(review): ECB leaks repeated-block patterns. It is kept
            # because the peer must use the same mode.
            crypter = Crypto.Cipher.AES.new(
                hashed_key,
                Crypto.Cipher.AES.MODE_ECB)
            crypto_mode = True
    except Exception:
        # Best-effort: continue unencrypted (announced below).
        # NOTE(review): a requested key silently degrading to plaintext is a
        # security risk; consider failing hard instead.
        import traceback
        traceback.print_exc(file=stderr)
        crypto_mode = False
        crypter = None

    if crypto_mode:
        print >>stderr, "Packets are transmitted in CIPHER"
    else:
        print >>stderr, "Packets are transmitted in PLAINTEXT"

    # Limited frame parsing, to preserve packet boundaries.
    # Which is needed, since /dev/net/tun is unbuffered
    fwbuf = ""  # tun -> remote, pending
    bkbuf = ""  # remote -> tun, pending
    while not TERMINATE:
        wset = []
        if packetReady(bkbuf, ether_mode):
            wset.append(tun)
        if packetReady(fwbuf, ether_mode):
            wset.append(remote)
        rdrdy, wrdy, errs = select.select((tun, remote), wset, (tun, remote), 1)

        # check for errors
        if errs:
            break

        # check to see if we can write
        if remote in wrdy and packetReady(fwbuf, ether_mode):
            packet, fwbuf = pullPacket(fwbuf, ether_mode)
            try:
                if crypto_mode:
                    enpacket = encrypt(packet, crypter)
                else:
                    enpacket = packet
                os.write(remote.fileno(), enpacket)
            except Exception:
                if not udp:
                    # in UDP mode, we ignore errors - packet loss man...
                    raise
            print >>stderr, '>', formatPacket(packet, ether_mode)
        if tun in wrdy and packetReady(bkbuf, ether_mode):
            packet, bkbuf = pullPacket(bkbuf, ether_mode)
            formatted = formatPacket(packet, ether_mode)
            if with_pi:
                packet = piWrap(packet, ether_mode)
            os.write(tun.fileno(), packet)
            print >>stderr, '<', formatted

        # check incoming data packets
        if tun in rdrdy:
            packet = os.read(tun.fileno(), 2000)  # tun.read blocks until it gets 2k!
            if with_pi:
                packet = piStrip(packet)
            fwbuf += packet
        if remote in rdrdy:
            try:
                packet = os.read(remote.fileno(), 2000)  # remote.read blocks until it gets 2k!
                if not packet and not udp:
                    # A zero-byte read on a stream channel means the peer
                    # closed it. select() would keep reporting it readable,
                    # so stop instead of spinning forever.
                    break
                if crypto_mode:
                    packet = decrypt(packet, crypter)
            except Exception:
                if not udp:
                    raise
                # UDP: drop this datagram. Do not fall through to bkbuf,
                # because `packet` is stale or still encrypted here.
            else:
                bkbuf += packet
233