TUN/TAP filters, initial version, with tests.
[nepi.git] / src / nepi / util / tunchannel_impl.py
1 import os
2 import sys
3 import random
4 import threading
5 import socket
6 import select
7 import weakref
8 import time
9
10 from tunchannel import tun_fwd
11
12 class TunChannel(object):
13     """
14     Helper box class that implements most of the required boilerplate
15     for tunnelling cross connections.
16     
17     The class implements a threaded forwarder that runs in the
18     testbed controller process. It takes several parameters that
19     can be given by directly setting attributes:
20     
21         tun_port/addr/proto/cipher: information about the local endpoint.
22             The addresses here should be externally-reachable,
23             since when listening or when using the UDP protocol,
24             connections to this address/port will be attempted
25             by remote endpoitns.
26         
27         peer_port/addr/proto/cipher: information about the remote endpoint.
28             Usually, you set these when the cross connection 
29             initializer/completion functions are invoked (both).
30         
31         tun_key: the agreed upon encryption key.
32         
33         listen: if set to True (and in TCP mode), it marks a
34             listening endpoint. Be certain that any TCP connection
35             is made between a listening and a non-listening
36             endpoint, or it won't work.
37         
38         with_pi: set if the incoming packet stream (see tun_socket)
39             contains PI headers - if so, they will be stripped.
40         
41         ethernet_mode: set if the incoming packet stream is
42             composed of ethernet frames (as opposed of IP packets).
43         
44         udp: set to use UDP datagrams instead of TCP connections.
45         
46         tun_socket: a socket or file object that can be read
47             from and written to. Packets will be read when available,
48             remote packets will be forwarded as writes.
49             A socket should be of type SOCK_SEQPACKET (or SOCK_DGRAM
50             if not possible), a file object should preserve packet
51             boundaries (ie, a pipe or TUN/TAP device file descriptor).
52         
53         trace_target: a file object where trace output will be sent.
54             It cannot be changed after launch.
55             By default, it's sys.stderr
56     """
57     
58     def __init__(self):
59         # Some operational attributes
60         self.listen = False
61         self.ethernet_mode = True
62         self.with_pi = False
63         
64         # These get initialized when the channel is configured
65         # They're part of the TUN standard attribute set
66         self.tun_port = None
67         self.tun_addr = None
68         self.tun_cipher = None
69         
70         # These get initialized when the channel is connected to its peer
71         self.peer_proto = None
72         self.peer_addr = None
73         self.peer_port = None
74         self.peer_cipher = None
75         
76         # These get initialized when the channel is connected to its iface
77         self.tun_socket = None
78
79         # same as peer proto, but for execute-time standard attribute lookups
80         self.tun_proto = None 
81         
82         # some state
83         self.prepared = False
84         self._terminate = [] # terminate signaller
85         self._exc = [] # exception store, to relay exceptions from the forwarder thread
86         self._connected = threading.Event()
87         self._forwarder_thread = None
88         
89         # trace to stderr
90         self.stderr = sys.stderr
91         
92         # Generate an initial random cryptographic key to use for tunnelling
93         # Upon connection, both endpoints will agree on a common one based on
94         # this one.
95         self.tun_key = ( ''.join(map(chr, [ 
96                     r.getrandbits(8) 
97                     for i in xrange(32) 
98                     for r in (random.SystemRandom(),) ])
99                 ).encode("base64").strip() )        
100         
101
102     def __str__(self):
103         return "%s<%s %s:%s %s %s:%s %s>" % (
104             self.__class__.__name__,
105             self.tun_proto, 
106             self.tun_addr, self.tun_port,
107             self.peer_proto, 
108             self.peer_addr, self.peer_port,
109             self.tun_cipher,
110         )
111
112     def Prepare(self):
113         if self.tun_proto:
114             udp = self.tun_proto == "udp"
115             if not udp and self.listen and not self._forwarder_thread:
116                 if self.listen or (self.peer_addr and self.peer_port and self.peer_proto):
117                     self._launch()
118     
119     def Setup(self):
120         if self.tun_proto:
121             if not self._forwarder_thread:
122                 self._launch()
123     
124     def Cleanup(self):
125         if self._forwarder_thread:
126             self.Kill()
127
128     def Wait(self):
129         if self._forwarder_thread:
130             self._connected.wait()
131             for exc in self._exc:
132                 # Relay exception
133                 eTyp, eVal, eLoc = exc
134                 raise eTyp, eVal, eLoc
135
136     def Kill(self):    
137         if self._forwarder_thread:
138             if not self._terminate:
139                 self._terminate.append(None)
140             self._forwarder_thread.join()
141
142     def _launch(self):
143         # Launch forwarder thread with a weak reference
144         # to self, so that we don't create any strong cycles
145         # and automatic refcounting works as expected
146         self._forwarder_thread = threading.Thread(
147             target = self._forwarder,
148             args = (weakref.ref(self),) )
149         self._forwarder_thread.start()
150
151     @staticmethod
152     def _forwarder(weak_self):
153         try:
154             weak_self().__forwarder(weak_self)
155         except:
156             self = weak_self()
157             
158             # store exception and wake up anyone waiting
159             self._exc.append(sys.exc_info())
160             self._connected.set()
161     
162     @staticmethod
163     def __forwarder(weak_self):
164         # grab strong reference
165         self = weak_self()
166         if not self:
167             return
168         
169         peer_port = self.peer_port
170         peer_addr = self.peer_addr
171         peer_proto= self.peer_proto
172         peer_cipher=self.peer_cipher
173
174         local_port = self.tun_port
175         local_addr = self.tun_addr
176         local_proto = self.tun_proto
177         local_cipher= self.tun_cipher
178         
179         stderr = self.stderr
180         ether_mode = self.ethernet_mode
181         with_pi = self.with_pi
182         
183         if local_proto != peer_proto:
184             raise RuntimeError, "Peering protocol mismatch: %s != %s" % (local_proto, peer_proto)
185
186         if local_cipher != peer_cipher:
187             raise RuntimeError, "Peering cipher mismatch: %s != %s" % (local_cipher, peer_cipher)
188         
189         udp = local_proto == 'udp'
190         listen = self.listen
191
192         if (udp or not listen) and (not peer_port or not peer_addr):
193             raise RuntimeError, "Misconfigured peer for: %s" % (self,)
194
195         if (udp or listen) and (not local_port or not local_addr):
196             raise RuntimeError, "Misconfigured TUN: %s" % (self,)
197         
198         TERMINATE = self._terminate
199         cipher_key = self.tun_key
200         tun = self.tun_socket
201         
202         if not tun:
203             raise RuntimeError, "Unconnected TUN channel %s" % (self,)
204         
205         if udp:
206             # listen on udp port
207             rsock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
208             for i in xrange(30):
209                 try:
210                     rsock.bind((local_addr,local_port))
211                     break
212                 except socket.error:
213                     # wait a while, retry
214                     time.sleep(1)
215             else:
216                 rsock.bind((local_addr,local_port))
217             rsock.connect((peer_addr,peer_port))
218             remote = os.fdopen(rsock.fileno(), 'r+b', 0)
219         elif listen:
220             # accept tcp connections
221             lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
222             for i in xrange(30):
223                 try:
224                     lsock.bind((local_addr,local_port))
225                     break
226                 except socket.error:
227                     # wait a while, retry
228                     time.sleep(1)
229             else:
230                 lsock.bind((local_addr,local_port))
231             lsock.listen(1)
232             rsock,raddr = lsock.accept()
233             remote = os.fdopen(rsock.fileno(), 'r+b', 0)
234         else:
235             # connect to tcp server
236             rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
237             for i in xrange(30):
238                 try:
239                     rsock.connect((peer_addr,peer_port))
240                     break
241                 except socket.error:
242                     # wait a while, retry
243                     time.sleep(1)
244             else:
245                 rsock.connect((peer_addr,peer_port))
246             remote = os.fdopen(rsock.fileno(), 'r+b', 0)
247         
248         # notify that we're ready
249         self._connected.set()
250         
251         # drop strong reference
252         del self
253         
254         tun_fwd(tun, remote,
255             with_pi = with_pi, 
256             ether_mode = ether_mode, 
257             cipher_key = cipher_key, 
258             udp = udp, 
259             TERMINATE = TERMINATE,
260             stderr = stderr,
261             cipher = local_cipher
262         )
263         
264         tun.close()
265         remote.close()
266
267
268 def create_tunchannel(testbed_instance, guid, devnull = []):
269     """
270     TunChannel factory for metadata.
271     By default, silences traceing.
272     
273     You can override the created element's attributes if you will.
274     """
275     if not devnull:
276         # just so it's not open if not needed
277         devnull.append(open("/dev/null","w"))
278     element = TunChannel()
279     element.stderr = devnull[0] # silence tracing
280     testbed_instance._elements[guid] = element
281
282 def preconfigure_tunchannel(testbed_instance, guid):
283     """
284     TunChannel preconfiguration.
285     
286     It initiates the forwarder thread for listening tcp channels.
287     
288     Takes the public address from the operating system, so it should be adequate
289     for most situations when the TunChannel forwarder thread runs in the same
290     process as the testbed controller.
291     """
292     element = testbed_instance._elements[guid]
293     
294     # Find external interface, if any
295     public_addr = os.popen(
296         "/sbin/ifconfig "
297         "| grep $(ip route | grep default | awk '{print $3}' "
298                 "| awk -F. '{print $1\"[.]\"$2}') "
299         "| head -1 | awk '{print $2}' "
300         "| awk -F : '{print $2}'").read().rstrip()
301     element.tun_addr = public_addr
302
303     # Set standard TUN attributes
304     if not element.tun_port and element.tun_addr:
305         element.tun_port = 15000 + int(guid)
306
307     # First-phase setup
308     if element.peer_proto:
309         # cross tun
310         if not element.tun_addr or not element.tun_port:
311             listening = True
312         elif not element.peer_addr or not element.peer_port:
313             listening = True
314         else:
315             # both have addresses...
316             # ...the one with the lesser address listens
317             listening = element.tun_addr < element.peer_addr
318         element.listen = listening
319         element.Prepare()
320
321 def postconfigure_tunchannel(testbed_instance, guid):
322     """
323     TunChannel preconfiguration.
324     
325     Initiates the forwarder thread for connecting tcp channels or 
326     udp channels in general.
327     
328     Should be adequate for most implementations.
329     """
330     element = testbed_instance._elements[guid]
331     
332     # Second-phase setup
333     element.Setup()
334
335
336 def crossconnect_tunchannel_peer_init(proto, testbed_instance, tun_guid, peer_data,
337         preconfigure_tunchannel = preconfigure_tunchannel):
338     """
339     Cross-connection initialization.
340     Should be adequate for most implementations.
341     
342     For use in metadata, bind the first "proto" argument with the connector type. Eg:
343     
344         conn_init = functools.partial(crossconnect_tunchannel_peer_init, "tcp")
345     
346     If you don't use the stock preconfigure function, specify your own as a keyword argument.
347     """
348     tun = testbed_instance._elements[tun_guid]
349     tun.peer_addr = peer_data.get("tun_addr")
350     tun.peer_proto = peer_data.get("tun_proto") or proto
351     tun.peer_port = peer_data.get("tun_port")
352     tun.peer_cipher = peer_data.get("tun_cipher")
353     tun.tun_key = min(tun.tun_key, peer_data.get("tun_key"))
354     tun.tun_proto = proto
355     
356     preconfigure_tunchannel(testbed_instance, tun_guid)
357
358 def crossconnect_tunchannel_peer_compl(proto, testbed_instance, tun_guid, peer_data,
359         postconfigure_tunchannel = postconfigure_tunchannel):
360     """
361     Cross-connection completion.
362     Should be adequeate for most implementations.
363     
364     For use in metadata, bind the first "proto" argument with the connector type. Eg:
365     
366         conn_init = functools.partial(crossconnect_tunchannel_peer_compl, "tcp")
367     
368     If you don't use the stock postconfigure function, specify your own as a keyword argument.
369     """
370     # refresh (refreshable) attributes for second-phase
371     tun = testbed_instance._elements[tun_guid]
372     tun.peer_addr = peer_data.get("tun_addr")
373     tun.peer_proto = peer_data.get("tun_proto") or proto
374     tun.peer_port = peer_data.get("tun_port")
375     tun.peer_cipher = peer_data.get("tun_cipher")
376     
377     postconfigure_tunchannel(testbed_instance, tun_guid)
378
379     
380
381 def wait_tunchannel(testbed_instance, guid):
382     """
383     Wait for the channel forwarder to be up and running.
384     
385     Useful as a pre-start function to assure proper startup synchronization,
386     be certain to start TunChannels before applications that might require them.
387     """
388     element = testbed_instance.elements[guid]
389     element.Wait()
390