Working (and easy to use) multicast forwarding
[nepi.git] / src / nepi / testbeds / planetlab / scripts / mcastfwd.py
1 import sys
2
3 import signal
4 import socket
5 import struct
6 import optparse
7 import threading
8 import subprocess
9 import re
10 import time
11 import collections
12 import os
13 import traceback
14
15 import ipaddr2
16
usage = "usage: %prog [options] <enabled-addresses>"

parser = optparse.OptionParser(usage=usage)

# Timing of the IGMP announcement threads.
parser.add_option("-d", "--poll-delay",
                  dest="poll_delay", metavar="SECONDS", type="float",
                  default=1.0,
                  help="Multicast subscription polling interval")
parser.add_option("-D", "--refresh-delay",
                  dest="refresh_delay", metavar="SECONDS", type="float",
                  default=30.0,
                  help="Full-refresh interval - time between full IGMP reports")

# Unix sockets this program listens on.
parser.add_option("-p", "--fwd-path",
                  dest="fwd_path", metavar="PATH",
                  default="/var/run/mcastfwd",
                  help="Path of the unix socket in which the program will listen for packets")
parser.add_option("-r", "--router-path",
                  dest="mrt_path", metavar="PATH",
                  default="/var/run/mcastrt",
                  help="Path of the unix socket in which the program will listen for routing changes")

# Mode selection.
parser.add_option("-A", "--announce-only",
                  dest="announce_only", action="store_true",
                  default=False,
                  help="If given, only group membership announcements will be made. Useful for non-router multicast nodes.")

# Positional arguments are the enabled (local) interface addresses.
options, remaining_args = parser.parse_args(sys.argv[1:])

# Linux networking constants (ethertypes, tun/tap ioctl numbers and flags,
# ifreq sizes).
# NOTE(review): none of these appear to be referenced elsewhere in this script.
ETH_P_ALL = 0x0003
ETH_P_IP = 0x0800
TUNSETIFF = 0x400454CA
IFF_NO_PI = 0x1000
IFF_TAP = 0x0002
IFF_TUN = 0x0001
IFF_VNET_HDR = 0x4000
TUN_PKT_STRIP = 0x0001
IFHWADDRLEN = 6
IFNAMSIZ = 16
IFREQ_SZ = 40
FIONREAD = 0x541B
56
class IGMPThread(threading.Thread):
    """Announce the multicast groups subscribed on one local interface.

    The interface is the one that owns ``vif_addr`` according to
    ``ip addr show``.  Every ``options.poll_delay`` seconds the thread reads
    that interface's subscriptions with ``ip maddr show`` and sends an
    IGMPv2 membership report (type 0x16) for every newly joined group and a
    leave message (type 0x17) for every dropped one.  Once more than
    ``options.refresh_delay`` seconds have passed since the last full report,
    every current group is reported again.

    Raises RuntimeError from the constructor when no interface carries
    ``vif_addr``.
    """

    # IGMP message types sent by this thread.
    IGMP_V2_REPORT = 0x16
    IGMP_LEAVE = 0x17
    # Destination address used for every message sent by this thread.
    ALL_ROUTERS = '224.0.0.2'

    # "<index>: <ifname>[@<peer>]: ..." heading lines of ``ip addr show``.
    # Interface names may contain '-' or '.' (e.g. tap123-0), which a plain
    # \w+ would not match, silently attributing the address to the previous
    # interface.
    _HEADING_RE = re.compile(r"\d+:\s*([^:@\s]+)[:@]")
    # "inet a.b.c.d..." lines of ``ip addr show`` and ``ip maddr show``.
    _INET_RE = re.compile(r"\s*inet\s*(\d{1,3}[.]\d{1,3}[.]\d{1,3}[.]\d{1,3})")

    def __init__(self, vif_addr, *p, **kw):
        super(IGMPThread, self).__init__(*p, **kw)

        vif_addr = vif_addr.strip()
        self.vif_addr = vif_addr

        # Resolve the interface before opening the raw socket, so an unknown
        # address does not leave a socket behind.
        tun_name = self._find_interface(
            vif_addr, self._command_output(['ip', 'addr', 'show']).splitlines())
        if tun_name is None:
            raise RuntimeError("Could not find interface for %s" % (vif_addr,))
        self.tun_name = tun_name

        self.igmp_socket = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP)
        self.igmp_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF,
            socket.inet_aton(vif_addr))
        # ipaddr2.ipigmp() is given the IP source/destination itself, so the
        # packets we send carry their own IP header.
        self.igmp_socket.setsockopt(socket.IPPROTO_IP, socket.IP_HDRINCL, 1)
        self.igmp_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 1)
        self._stop = False
        self.setDaemon(True)

    @staticmethod
    def _command_output(argv):
        """Run ``argv`` with stdin from /dev/null; return stdout+stderr.

        The child is always waited for, and the /dev/null handle closed.
        """
        devnull = open('/dev/null', 'r+b')
        try:
            proc = subprocess.Popen(argv,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                stdin=devnull)
            return proc.communicate()[0]
        finally:
            devnull.close()

    @classmethod
    def _find_interface(cls, vif_addr, lines):
        """Return the interface owning ``vif_addr`` in ``ip addr show`` lines.

        Returns None when no ``inet`` line under any heading matches.
        """
        iface = None
        for line in lines:
            match = cls._HEADING_RE.match(line)
            if match:
                iface = match.group(1)
                continue
            match = cls._INET_RE.match(line)
            if match and match.group(1) == vif_addr:
                return iface
        return None

    def _send_igmp(self, igmp_type, grp):
        """Send one IGMP message of ``igmp_type`` about group ``grp``.

        Send errors are logged to stderr and otherwise ignored.
        """
        # NOTE(review): RFC 2236 addresses v2 membership reports to the
        # reported group itself (only leaves go to 224.0.0.2); both are sent
        # to 224.0.0.2 here - confirm the peer router relies on this before
        # changing it.
        igmpp = ipaddr2.ipigmp(
            self.vif_addr, self.ALL_ROUTERS, 1, igmp_type, 0, grp,
            noipcksum=True)
        try:
            self.igmp_socket.sendto(igmpp, 0, (self.ALL_ROUTERS, 0))
        except Exception:
            traceback.print_exc(file=sys.stderr)

    def run(self):
        cur_maddr = set()
        lastfullrefresh = time.time()
        while not self._stop:
            # Get current subscriptions
            try:
                output = self._command_output(['ip', 'maddr', 'show', self.tun_name])
            except (OSError, IOError):
                # Skip this round instead of reporting every group as left
                # (or letting the thread die).
                traceback.print_exc(file=sys.stderr)
                time.sleep(options.poll_delay)
                continue
            new_maddr = set()
            for line in output.splitlines():
                match = self._INET_RE.match(line)
                if match:
                    new_maddr.add(match.group(1))

            # Every refresh_delay seconds, report all groups, not only new ones.
            now = time.time()
            if (now - lastfullrefresh) <= options.refresh_delay:
                joined = new_maddr - cur_maddr
            else:
                joined = new_maddr
                lastfullrefresh = now

            for grp in joined:
                sys.stderr.write("JOINING %s\n" % (grp,))
                self._send_igmp(self.IGMP_V2_REPORT, grp)

            for grp in cur_maddr - new_maddr:
                sys.stderr.write("LEAVING %s\n" % (grp,))
                self._send_igmp(self.IGMP_LEAVE, grp)

            cur_maddr = new_maddr
            time.sleep(options.poll_delay)

    def stop(self):
        """Ask the thread to exit and wait up to 1 + 5*poll_delay seconds."""
        self._stop = True
        self.join(1 + 5 * options.poll_delay)
146
147
class FWDThread(threading.Thread):
    """Forward multicast packets according to the router-supplied routes.

    Datagrams arrive on a unix socket bound at ``options.fwd_path``.  Each is
    a 4-byte parent address followed by a raw IPv4 packet; the parent is
    matched against the first 4 bytes of the *rt_cache* keys.

    - IGMP packets (protocol 2) are handed to the router over
      *router_socket*.
    - Other packets are looked up in *rt_cache* and re-sent through one raw
      socket per address in ``remaining_args``.
    - Packets without a route are queued (at most ``maxpending``) and
      announced to the router as a copy whose protocol byte is 0.  They are
      retried whenever the input socket has been idle for
      ``options.poll_delay`` seconds.
    """

    # Parent-address wildcard tried when no route matches the real parent.
    ANY_ADDR = '\x00\x00\x00\x00'

    def __init__(self, rt_cache, router_socket, vifs, *p, **kw):
        super(FWDThread, self).__init__(*p, **kw)

        self.in_socket = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
        self.in_socket.bind(options.fwd_path)

        self.pending = collections.deque()
        self.maxpending = 1000
        self.rt_cache = rt_cache
        self.router_socket = router_socket
        self.vifs = vifs

        # One raw, header-including socket per enabled address, keyed by the
        # packed (inet_aton) address so it can be matched against the vifs'
        # local addresses.
        self.fwd_sockets = {}
        for fwd_target in remaining_args:
            fwd_target = socket.inet_aton(fwd_target.strip())
            fwd_socket = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_RAW)
            fwd_socket.setsockopt(socket.IPPROTO_IP, socket.IP_HDRINCL, 1)
            fwd_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, fwd_target)
            fwd_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 1)
            self.fwd_sockets[fwd_target] = fwd_socket

        self._stop = False
        self.setDaemon(True)

    @staticmethod
    def _select_vifs(ttl, thresholds, vifs):
        """Return the vif tuples a packet with IP TTL ``ttl`` goes out of.

        ``thresholds`` holds one byte per vif index (the route's ``ttls``).
        A vif is selected when its threshold is non-zero and lower than
        ``ttl``, and the vif is still present in ``vifs``.
        """
        selected = []
        for vifi, thresh in enumerate(thresholds):
            thresh = ord(thresh)
            if 0 < thresh < ttl:
                vif = vifs.get(vifi)
                if vif is not None:
                    selected.append(vif)
        return selected

    def run(self):
        in_socket = self.in_socket
        rt_cache = self.rt_cache
        vifs = self.vifs
        router_socket = self.router_socket
        pending = self.pending
        fwd_sockets = self.fwd_sockets
        any_addr = self.ANY_ADDR
        in_socket.settimeout(options.poll_delay)
        # Queued packets still to retry in the current retry pass.
        npending = 0

        while not self._stop:
            # Get packet: replay the queue after an idle period, else read.
            try:
                if pending and npending:
                    # FIFO: a packet that is still unroutable is re-queued at
                    # the right end, so each queued packet is retried once per
                    # pass instead of the newest one being retried repeatedly.
                    packet = pending.popleft()
                    npending -= 1
                else:
                    packet = in_socket.recv(2000)
            except socket.timeout:
                if pending and not npending:
                    npending = len(pending)
                continue
            # 4-byte parent + minimal 20-byte IPv4 header.
            if not packet or len(packet) < 24:
                continue

            fullpacket = packet
            parent = buffer(packet, 0, 4)
            packet = buffer(packet, 4)

            proto = packet[9]
            if proto == '\x00':
                # Protocol 0 marks the no-route notifications sent to the
                # router below: a looping packet, discard.
                continue
            if proto == '\x02':
                # IGMP is for the router only.  Do not also queue it as an
                # unrouted packet: presumably no route is ever installed for
                # these groups, so it would sit in the pending queue for good
                # and eventually crowd out real traffic.
                router_socket.send(packet)
                continue

            # To-Do: PIM asserts

            # Route key: parent address + source address + group address.
            addrinfo = buffer(packet, 12, 8)
            fwd_targets = rt_cache.get(parent + addrinfo)
            if fwd_targets is None:
                fwd_targets = rt_cache.get(any_addr + str(addrinfo))

            group = socket.inet_ntoa(addrinfo[4:])
            if fwd_targets is not None:
                # Forward.  sendto() needs the dotted-quad group, not the
                # packed 4 bytes.
                # NOTE(review): the packet is re-sent with its TTL unchanged
                # (no decrement) - confirm this is intended.
                ttl = ord(packet[8])
                target = (group, 0)
                trace = [group, '->']
                for vif in self._select_vifs(ttl, fwd_targets, vifs):
                    fwd_socket = fwd_sockets.get(vif[4])
                    if fwd_socket is None:
                        continue
                    trace.append(socket.inet_ntoa(vif[4]))
                    try:
                        fwd_socket.sendto(packet, 0, target)
                    except socket.error:
                        traceback.print_exc(file=sys.stderr)
                trace.append('.')
                sys.stderr.write(' '.join(trace) + '\n')
            elif len(pending) < self.maxpending:
                # Mark pending
                sys.stderr.write("%s -> ?\n" % (group,))
                pending.append(fullpacket)

                # Notify mrouted by forwarding it with protocol 0
                router_socket.send(''.join(
                    (packet[:9], '\x00', packet[10:])))

    def stop(self):
        """Ask the thread to exit and wait up to 1 + 5*poll_delay seconds."""
        self._stop = True
        self.join(1 + 5 * options.poll_delay)
250
251
class RouterThread(threading.Thread):
    """Apply the router's interface and route commands to the shared tables.

    Commands arrive on *router_socket*.  Each is framed as two native
    unsigned ints ``(op, length)`` followed by ``length`` payload bytes.  The
    payload formats are the ``_VIFCTL_FMT`` / ``_MFCCTL_FMT`` struct layouts
    below.

    - VIF commands update *vifs*, mapping vif index to the unpacked vifctl
      tuple.
    - MFC commands update *rt_cache*, mapping parent address + origin + group
      (4 bytes each) to the route's per-vif TTL thresholds.

    When the connection closes, ``TERMINATE`` is appended to and the thread
    exits.
    """

    MRT_BASE = 200
    MRT_ADD_VIF = MRT_BASE + 2    # Add a virtual interface
    MRT_DEL_VIF = MRT_BASE + 3    # Delete a virtual interface
    MRT_ADD_MFC = MRT_BASE + 4    # Add a multicast forwarding entry
    MRT_DEL_MFC = MRT_BASE + 5    # Delete a multicast forwarding entry

    # Native layouts of the command header and payloads.
    _HEADER_FMT = 'II'            # op, length
    _HEADER_LEN = struct.calcsize('II')
    # vifi, flags, threshold, rate_limit, lcl_addr, rmt_addr
    _VIFCTL_FMT = 'HBBI4s4s'
    # origin, mcastgrp, parent, ttls, pkt_cnt, byte_cnt, wrong_if, expire
    _MFCCTL_FMT = '4s4sH32sIIIi'
    # Parent address used when the route's parent vif is unknown.
    ANY_ADDR = '\x00\x00\x00\x00'

    def __init__(self, rt_cache, router_socket, vifs, *p, **kw):
        super(RouterThread, self).__init__(*p, **kw)

        self.rt_cache = rt_cache
        self.vifs = vifs
        self.router_socket = router_socket

        self._stop = False
        self.setDaemon(True)

    @classmethod
    def _split_command(cls, buf):
        """Split one framed command off the front of ``buf``.

        Returns ``(op, data, rest)``, or None while ``buf`` does not yet
        hold a complete header plus payload.
        """
        hlen = cls._HEADER_LEN
        if len(buf) < hlen:
            return None
        op, dlen = struct.unpack(cls._HEADER_FMT, buf[:hlen])
        end = hlen + dlen
        if len(buf) < end:
            return None
        return op, buf[hlen:end], buf[end:]

    def run(self):
        rt_cache = self.rt_cache
        vifs = self.vifs
        addr_vifs = {}
        router_socket = self.router_socket
        router_socket.settimeout(options.poll_delay)
        log = sys.stderr.write
        vifctl_fmt = self._VIFCTL_FMT
        mfcctl_fmt = self._MFCCTL_FMT
        any_addr = self.ANY_ADDR

        def mfc_entry(data):
            # -> ((parent_addr, origin, mcastgrp), ttls) for an mfcctl payload.
            origin, mcastgrp, parent, ttls = struct.unpack(mfcctl_fmt, data)[:4]
            parent_vif = vifs.get(parent)
            if parent_vif is None:
                parent_addr = any_addr
            else:
                parent_addr = parent_vif[4]
            return (parent_addr, origin, mcastgrp), ttls

        def add_vif(data):
            vif = struct.unpack(vifctl_fmt, data)
            vifs[vif[0]] = vif
            addr_vifs[vif[4]] = vif[0]
            log("Added VIF %s\n" % (vif,))

        def del_vif(data):
            vif = vifs[struct.unpack(vifctl_fmt, data)[0]]
            del addr_vifs[vif[4]]
            del vifs[vif[0]]
            log("Removed VIF %s\n" % (vif,))

        def add_mfc(data):
            addrs, ttls = mfc_entry(data)
            rt_cache[''.join(addrs)] = ttls
            log("Added RT %s\n" % '-'.join(map(socket.inet_ntoa, addrs)))

        def del_mfc(data):
            addrs, ttls = mfc_entry(data)
            del rt_cache[''.join(addrs)]
            log("Removed RT %s\n" % '-'.join(map(socket.inet_ntoa, addrs)))

        commands = {
            self.MRT_ADD_VIF: add_vif,
            self.MRT_DEL_VIF: del_vif,
            self.MRT_ADD_MFC: add_mfc,
            self.MRT_DEL_MFC: del_mfc,
        }

        # Bytes received but not yet consumed as complete commands.
        buf = ''
        while not self._stop:
            command = self._split_command(buf)
            if command is None:
                # Need more bytes: read the next chunk and append it.
                try:
                    chunk = router_socket.recv(2000)
                except socket.timeout:
                    continue
                if not chunk:
                    log("PLRT CONNECTION BROKEN\n")
                    TERMINATE.append(None)
                    break
                buf += chunk
                continue

            op, data, buf = command
            log("COMMAND %s DATA %s\n" % (op, len(data)))

            handler = commands.get(op)
            if handler is None:
                log("IGNORING UNKNOWN COMMAND %s\n" % (op,))
                continue
            try:
                handler(data)
            except Exception:
                traceback.print_exc(file=sys.stderr)

    def stop(self):
        """Ask the thread to exit and wait up to 1 + 5*poll_delay seconds."""
        self._stop = True
        self.join(1 + 5 * options.poll_delay)
368
369
370
igmp_threads = []
for vif_addr in remaining_args:
    igmp_threads.append(IGMPThread(vif_addr))

rt_cache = {}
vifs = {}

# Non-empty once the process should exit: filled by the SIGTERM handler
# below or by RouterThread when the router connection breaks.
TERMINATE = []


def _finalize(sig, frame):
    """SIGTERM handler: ask the main loop to exit."""
    TERMINATE.append(None)


signal.signal(signal.SIGTERM, _finalize)


try:
    if not options.announce_only:
        # Wait for the router to connect before starting the threads.
        router_socket = socket.socket(socket.AF_UNIX, socket.SOCK_SEQPACKET)
        router_socket.bind(options.mrt_path)
        router_socket.listen(0)
        router_remote_socket, router_remote_addr = router_socket.accept()

        fwd_thread = FWDThread(rt_cache, router_remote_socket, vifs)
        router_thread = RouterThread(rt_cache, router_remote_socket, vifs)

    for thread in igmp_threads:
        thread.start()

    if not options.announce_only:
        fwd_thread.start()
        router_thread.start()

    # Poll frequently: RouterThread only appends to TERMINATE, which does
    # not wake this sleep, so a long sleep delays shutdown.
    while not TERMINATE:
        time.sleep(1)
finally:
    # Only a router instance binds these socket paths.  In announce-only
    # mode they may belong to another running instance, so leave them alone.
    if not options.announce_only:
        for path in (options.mrt_path, options.fwd_path):
            if os.path.exists(path):
                try:
                    os.remove(path)
                except OSError:
                    pass
416
417