Fix multicast forwarder when no router is present.
[nepi.git] / src / nepi / testbeds / planetlab / scripts / mcastfwd.py
1 import sys
2
3 import signal
4 import socket
5 import struct
6 import optparse
7 import threading
8 import subprocess
9 import re
10 import time
11 import collections
12 import os
13 import traceback
14 import logging
15
16 import ipaddr2
17
usage = "usage: %prog [options] <enabled-addresses>"

parser = optparse.OptionParser(usage=usage)

# Command-line options as (flag names, add_option keywords) pairs.  They are
# registered in this order, which is also the order --help lists them in.
_OPTION_SPECS = (
    (("-d", "--poll-delay"), dict(
        dest="poll_delay", metavar="SECONDS", type="float",
        default=1.0,
        help="Multicast subscription polling interval")),
    (("-D", "--refresh-delay"), dict(
        dest="refresh_delay", metavar="SECONDS", type="float",
        default=30.0,
        help="Full-refresh interval - time between full IGMP reports")),
    (("-p", "--fwd-path"), dict(
        dest="fwd_path", metavar="PATH",
        default="/var/run/mcastfwd",
        help="Path of the unix socket in which the program will listen for packets")),
    (("-r", "--router-path"), dict(
        dest="mrt_path", metavar="PATH",
        default="/var/run/mcastrt",
        help="Path of the unix socket in which the program will listen for routing changes")),
    (("-A", "--announce-only"), dict(
        dest="announce_only", action="store_true",
        default=False,
        help="If given, only group membership announcements will be made. "
             "Useful for non-router non-member multicast nodes.")),
    (("-R", "--no-router"), dict(
        dest="no_router", action="store_true",
        default=False,
        help="If given, only group membership announcements and forwarding to the default multicast egress will be made. "
             "Useful for non-router but member multicast nodes.")),
    (("-v", "--verbose"), dict(
        dest="verbose", action="store_true",
        default=False,
        help="Log more verbosely")),
)
for _names, _settings in _OPTION_SPECS:
    parser.add_option(*_names, **_settings)

# The positional arguments left over after option parsing are the enabled
# interface addresses; the rest of the script reads them as remaining_args.
options, remaining_args = parser.parse_args(sys.argv[1:])

# Log to stderr; --verbose lowers the threshold from WARNING to DEBUG.
if options.verbose:
    _log_level = logging.DEBUG
else:
    _log_level = logging.WARNING
logging.basicConfig(stream=sys.stderr, level=_log_level)
58
# Numeric constants for packet sockets and TUN/TAP device control.
# NOTE(review): the names and values look like they mirror Linux kernel
# header definitions, and none of them is referenced by the code visible in
# this file - confirm against the headers before relying on or removing them.

# Ethernet protocol selectors.
ETH_P_ALL = 0x0003
ETH_P_IP = 0x0800

# TUN/TAP ioctl request and interface flags.
TUNSETIFF = 0x400454ca
IFF_TUN = 0x0001
IFF_TAP = 0x0002
IFF_NO_PI = 0x1000
IFF_VNET_HDR = 0x4000
TUN_PKT_STRIP = 0x0001

# Interface request sizes, in bytes.
IFHWADDRLEN = 6
IFNAMSIZ = 16
IFREQ_SZ = 40

# ioctl request number.
FIONREAD = 0x541b
71
class IGMPThread(threading.Thread):
    """Announce multicast group membership for one virtual interface.

    The thread polls ``ip maddr show`` for the interface that owns
    ``vif_addr`` and for ``eth0``, and sends IGMP packets (built with
    ``ipaddr2.ipigmp``) through a raw socket bound to ``vif_addr`` whenever a
    group appears or disappears.  Once more than ``options.refresh_delay``
    seconds have passed since the last full report, every current group is
    reported again.
    """

    # "<index>: <ifname>: ..." heading lines of ``ip addr show``.
    _HEADING_RE = re.compile(r"\d+:\s*([-a-zA-Z0-9_]+):.*")
    # "inet a.b.c.d..." address lines of ``ip addr show``.
    _ADDR_RE = re.compile(r"\s*inet\s*(\d{1,3}[.]\d{1,3}[.]\d{1,3}[.]\d{1,3}).*")
    # "inet a.b.c.d" group lines of ``ip maddr show``.
    _MADDR_RE = re.compile(r"\s*inet\s*(\d{1,3}[.]\d{1,3}[.]\d{1,3}[.]\d{1,3})\s*")

    def __init__(self, vif_addr, *p, **kw):
        """Open the IGMP socket and resolve the interface owning vif_addr.

        vif_addr -- dotted-quad address of the interface to announce on;
                    surrounding whitespace is ignored.
        Remaining arguments are passed on to threading.Thread.

        Raises RuntimeError when no interface listed by ``ip addr show``
        carries vif_addr.
        """
        super(IGMPThread, self).__init__(*p, **kw)

        vif_addr = vif_addr.strip()
        self.vif_addr = vif_addr
        self.igmp_socket = socket.socket(
            socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP)
        self.igmp_socket.setsockopt(
            socket.IPPROTO_IP, socket.IP_MULTICAST_IF,
            socket.inet_aton(self.vif_addr))
        self.igmp_socket.setsockopt(socket.IPPROTO_IP, socket.IP_HDRINCL, 1)
        self.igmp_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 1)
        self._stop = False
        self.setDaemon(True)

        # Find tun name
        self.tun_name = self._find_interface(vif_addr)

    @classmethod
    def _find_interface(cls, vif_addr):
        """Return the name of the interface that has address vif_addr."""
        devnull = open('/dev/null', 'r+b')
        try:
            proc = subprocess.Popen(
                ['ip', 'addr', 'show'],
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                stdin=devnull)
            # communicate() drains the pipe and reaps the child process.
            output = proc.communicate()[0]
        finally:
            devnull.close()

        ifname = None
        for line in output.splitlines():
            match = cls._HEADING_RE.match(line)
            if match:
                # Remember the interface the following address lines belong to.
                ifname = match.group(1)
                continue
            match = cls._ADDR_RE.match(line)
            if match and match.group(1) == vif_addr:
                return ifname
        raise RuntimeError("Could not find interface for %s" % (vif_addr,))

    def _list_groups(self, ifname, devnull):
        """Return the set of IPv4 groups ``ip maddr show`` lists for ifname."""
        proc = subprocess.Popen(
            ['ip', 'maddr', 'show', ifname],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            stdin=devnull)
        output = proc.communicate()[0]
        groups = set()
        for line in output.splitlines():
            match = self._MADDR_RE.match(line)
            if match:
                groups.add(match.group(1))
        return groups

    def run(self):
        """Poll subscriptions and send IGMP packets until stop() is called."""
        devnull = open('/dev/null', 'r+b')
        try:
            cur_maddr = set()
            lastfullrefresh = time.time()
            while not self._stop:
                # Get current subscriptions @ vif
                new_maddr = self._list_groups(self.tun_name, devnull)

                # Get current subscriptions @ eth0 (default on PL),
                # they should be considered "universal" subscriptions, so
                # they are added to the vif's own rather than replacing them.
                new_maddr |= self._list_groups('eth0', devnull)

                # Every now and then, send a full report; otherwise only
                # report the groups that appeared since the last poll.
                now = time.time()
                report_new = new_maddr
                if (now - lastfullrefresh) <= options.refresh_delay:
                    report_new = report_new - cur_maddr
                else:
                    lastfullrefresh = now

                # Report subscriptions (IGMP type 0x16, sent to the group)
                for grp in report_new:
                    print >>sys.stderr, "JOINING", grp
                    igmpp = ipaddr2.ipigmp(
                        self.vif_addr, grp, 1, 0x16, 0, grp,
                        noipcksum=True)
                    try:
                        self.igmp_socket.sendto(igmpp, 0, (grp, 0))
                    except Exception:
                        traceback.print_exc(file=sys.stderr)

                # Notify group leave (IGMP type 0x17, sent to 224.0.0.2)
                for grp in cur_maddr - new_maddr:
                    print >>sys.stderr, "LEAVING", grp
                    igmpp = ipaddr2.ipigmp(
                        self.vif_addr, '224.0.0.2', 1, 0x17, 0, grp,
                        noipcksum=True)
                    try:
                        self.igmp_socket.sendto(igmpp, 0, ('224.0.0.2', 0))
                    except Exception:
                        traceback.print_exc(file=sys.stderr)

                cur_maddr = new_maddr

                time.sleep(options.poll_delay)
        finally:
            devnull.close()

    def stop(self):
        """Ask run() to finish and wait a bounded time for it to do so."""
        self._stop = True
        self.join(1+5*options.poll_delay)
174
175
class FWDThread(threading.Thread):
    """Forward multicast packets received on the ``options.fwd_path`` socket.

    Each datagram read from the unix socket is treated as a 4-byte "parent"
    address followed by an IP packet.  The packet's source and destination
    addresses (bytes 12..19) are looked up in ``rt_cache``; when a route is
    known the packet is re-sent through the raw socket of every eligible vif
    and through the default socket.  Without a route the packet is either
    queued and announced to the router socket, or - when there is no router
    socket - sent through the default socket only.
    """

    def __init__(self, rt_cache, router_socket, vifs, *p, **kw):
        """Bind the input socket and open one raw socket per egress.

        rt_cache      -- dict mapping origin+group (8 bytes) to
                         (ttl thresholds, parent address); only read here.
        router_socket -- connected router socket, or None/false when no
                         router is attached.
        vifs          -- dict mapping vif index to the vif tuple whose item 4
                         is looked up in fwd_sockets; only read here.
        Remaining arguments are passed on to threading.Thread.
        """
        super(FWDThread, self).__init__(*p, **kw)

        self.in_socket = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
        self.in_socket.bind(options.fwd_path)

        self.pending = collections.deque()
        self.maxpending = 1000
        self.rt_cache = rt_cache
        self.router_socket = router_socket
        self.vifs = vifs

        # prepare forwarding sockets, keyed by packed interface address
        self.fwd_sockets = {}
        for fwd_target in remaining_args:
            # Strip like IGMPThread does, so both accept the same arguments.
            fwd_target = socket.inet_aton(fwd_target.strip())
            self.fwd_sockets[fwd_target] = self._open_fwd_socket(fwd_target)

        # we always forward to eth0
        # In PL, we cannot join the multicast routers in eth0,
        # that would bring a lot of trouble. But we can
        # listen there for subscriptions and forward interesting
        # packets, partially joining the mbone
        # TODO: IGMP messages from eth0 should be selectively
        #       replicated in all vifs to propagate external
        #       subscriptions. It is complex though.
        fwd_target = '\x00'*4
        self.fwd_sockets[fwd_target] = self._open_fwd_socket(fwd_target)

        self._stop = False
        self.setDaemon(True)

    @staticmethod
    def _open_fwd_socket(ifaddr):
        """Return a raw header-included socket using ifaddr for multicast."""
        fwd_socket = socket.socket(
            socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_RAW)
        fwd_socket.setsockopt(socket.IPPROTO_IP, socket.IP_HDRINCL, 1)
        fwd_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, ifaddr)
        fwd_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 1)
        return fwd_socket

    def run(self):
        """Receive, route and forward packets until stop() is called."""
        # Names used in the loop are bound to locals once, as the original
        # implementation did.
        in_socket = self.in_socket
        rt_cache = self.rt_cache
        vifs = self.vifs
        router_socket = self.router_socket
        pending = self.pending
        fwd_sockets = self.fwd_sockets
        len_ = len
        ord_ = ord
        buffer_ = buffer
        inet_ntoa = socket.inet_ntoa
        noent = (None, None)
        any_parent = '\x00\x00\x00\x00'
        def_socket = fwd_sockets[any_parent]
        # Number of queued packets still to be retried in the current round.
        npending = 0
        in_socket.settimeout(options.poll_delay)

        while not self._stop:
            # Get packet: either the next queued one of a retry round, or a
            # fresh one from the socket.
            try:
                if pending and npending:
                    # Take from the left: unrouted packets are re-queued on
                    # the right, so popping from the right would retry the
                    # same packet over and over and starve the others.
                    packet = pending.popleft()
                    npending -= 1
                else:
                    packet = in_socket.recv(2000)
            except socket.timeout:
                # Idle: schedule one retry round over what is queued now.
                if pending and not npending:
                    npending = len_(pending)
                continue
            # 4 bytes of parent address plus at least 20 bytes of packet.
            if not packet or len_(packet) < 24:
                continue

            fullpacket = packet
            parent = packet[:4]
            packet = buffer_(packet, 4)

            # Byte 9 of the packet is compared as a protocol marker.
            if packet[9] == '\x02':
                # IGMP packet? It's for mrouted
                # (processing then continues with the route lookup below)
                if router_socket:
                    router_socket.send(packet)
            elif packet[9] == '\x00':
                # LOOPING packet, discard
                continue

            # To-Do: PIM asserts

            # Get route
            addrinfo = packet[12:20]
            fwd_targets, rparent = rt_cache.get(addrinfo, noent)

            if fwd_targets is not None and (rparent == any_parent or rparent == parent):
                # Forward to vifs
                ttl = ord_(packet[8])
                tgt_group = (inet_ntoa(addrinfo[4:]), 0)
                print >>sys.stderr, map(socket.inet_ntoa, (parent, addrinfo[:4], addrinfo[4:])), "-> ttl", ttl,
                nfwd_targets = len_(fwd_targets)
                for vifi, vif in vifs.iteritems():
                    if vifi >= nfwd_targets:
                        continue
                    # A zero threshold disables the vif; otherwise the
                    # packet's TTL must exceed the threshold.
                    ttl_thresh = ord_(fwd_targets[vifi])
                    if ttl_thresh <= 0 or ttl <= ttl_thresh:
                        continue
                    if vif[4] not in fwd_sockets:
                        continue
                    try:
                        print >>sys.stderr, socket.inet_ntoa(vif[4]),
                        fwd_sockets[vif[4]].sendto(packet, 0, tgt_group)
                    except Exception:
                        # Best effort: one failing egress must not stop
                        # the others.
                        logging.debug("forwarding to vif failed", exc_info=True)

                # Forward to eth0
                try:
                    print >>sys.stderr, 'default',
                    def_socket.sendto(packet, 0, tgt_group)
                except Exception:
                    logging.debug("forwarding to default egress failed", exc_info=True)

                print >>sys.stderr, "."
            elif router_socket:
                # Mark pending (packets beyond maxpending are dropped)
                if len_(pending) < self.maxpending:
                    print >>sys.stderr, map(socket.inet_ntoa, (parent, addrinfo[:4], addrinfo[4:])), "-> ?"

                    pending.append(fullpacket)

                    # Notify mrouted by forwarding it with protocol 0
                    router_socket.send(''.join(
                        (packet[:9], '\x00', packet[10:])))
            else:
                # No router attached: forward to eth0 only
                ttl = ord_(packet[8])
                tgt_group = (inet_ntoa(addrinfo[4:]), 0)

                try:
                    print >>sys.stderr, map(socket.inet_ntoa, (parent, addrinfo[:4], addrinfo[4:])), "-> ttl", ttl, 'default'
                    def_socket.sendto(packet, 0, tgt_group)
                except Exception:
                    logging.debug("forwarding to default egress failed", exc_info=True)

    def stop(self):
        """Ask run() to finish and wait a bounded time for it to do so."""
        self._stop = True
        self.join(1+5*options.poll_delay)
318
319
class RouterThread(threading.Thread):
    """Apply routing commands read from the router socket.

    The socket delivers commands framed as two native unsigned ints
    (operation code, payload length) followed by the payload.  Recognised
    operations add or delete entries in the shared ``vifs`` and ``rt_cache``
    dictionaries; anything else is reported and ignored.
    """

    def __init__(self, rt_cache, router_socket, vifs, *p, **kw):
        """Store the shared tables and the connected router socket.

        rt_cache      -- dict updated with origin+group -> (ttls, parent addr).
        router_socket -- connected socket the commands are read from.
        vifs          -- dict updated with vif index -> unpacked vif tuple.
        Remaining arguments are passed on to threading.Thread.
        """
        super(RouterThread, self).__init__(*p, **kw)

        self.rt_cache = rt_cache
        self.vifs = vifs
        self.router_socket = router_socket

        self._stop = False
        self.setDaemon(True)

    def run(self):
        """Read and execute commands until stop() or the peer disconnects."""
        rt_cache = self.rt_cache
        vifs = self.vifs
        addr_vifs = {}
        router_socket = self.router_socket
        router_socket.settimeout(options.poll_delay)
        len_ = len
        buffer_ = buffer

        MRT_BASE        = 200
        MRT_ADD_VIF     = MRT_BASE+2    # Add a virtual interface
        MRT_DEL_VIF     = MRT_BASE+3    # Delete a virtual interface
        MRT_ADD_MFC     = MRT_BASE+4    # Add a multicast forwarding entry
        MRT_DEL_MFC     = MRT_BASE+5    # Delete a multicast forwarding entry

        def cmdhdr(cmd, unpack=struct.unpack, buffer=buffer):
            # Split a frame into (operation, payload length, rest of frame).
            op, dlen = unpack('II', buffer(cmd, 0, 8))
            cmd = buffer(cmd, 8)
            return op, dlen, cmd

        def vifctl(data, unpack=struct.unpack):
            # vifi, flags, threshold, rate_limit, lcl_addr, rmt_addr
            return unpack('HBBI4s4s', data)

        def mfcctl(data, unpack=struct.unpack):
            # origin, mcastgrp, parent, ttls, pkt_cnt, byte_cnt, wrong_if, expire
            return unpack('4s4sH32sIIIi', data)

        def parent_address(parent):
            # Local address of the parent vif, or the all-zero address when
            # the vif is unknown.
            if parent in vifs:
                return vifs[parent][4]
            return '\x00\x00\x00\x00'

        def add_vif(cmd):
            vifi = vifctl(cmd)
            vifs[vifi[0]] = vifi
            addr_vifs[vifi[4]] = vifi[0]
            print >>sys.stderr, "Added VIF", vifi

        def del_vif(cmd):
            vifi = vifctl(cmd)
            # Use the stored entry: it holds the address registered on add.
            vifi = vifs[vifi[0]]
            del addr_vifs[vifi[4]]
            del vifs[vifi[0]]
            print >>sys.stderr, "Removed VIF", vifi

        def add_mfc(cmd):
            # Decode the handler's own argument instead of relying on a
            # variable of the enclosing loop.
            origin, mcastgrp, parent, ttls, pkt_cnt, byte_cnt, wrong_if, expire = mfcctl(cmd)
            parent_addr = parent_address(parent)
            addrinfo = origin + mcastgrp
            rt_cache[addrinfo] = (ttls, parent_addr)
            print >>sys.stderr, "Added RT", '-'.join(map(socket.inet_ntoa,(parent_addr,origin,mcastgrp))), map(ord,ttls)

        def del_mfc(cmd):
            origin, mcastgrp, parent, ttls, pkt_cnt, byte_cnt, wrong_if, expire = mfcctl(cmd)
            parent_addr = parent_address(parent)
            addrinfo = origin + mcastgrp
            del rt_cache[addrinfo]
            print >>sys.stderr, "Removed RT", '-'.join(map(socket.inet_ntoa,(parent_addr,origin,mcastgrp)))

        commands = {
            MRT_ADD_VIF : add_vif,
            MRT_DEL_VIF : del_vif,
            MRT_ADD_MFC : add_mfc,
            MRT_DEL_MFC : del_mfc,
        }

        # Bytes received but not yet consumed.  Everything read is appended
        # here exactly once, so partial frames are kept and complete frames
        # are never replayed.
        buf = ""

        while not self._stop:
            if len_(buf) < 8 or len_(buf) < (cmdhdr(buf)[1]+8):
                # No complete command buffered: get more data
                try:
                    chunk = router_socket.recv(2000)
                except socket.timeout:
                    continue
                if not chunk:
                    print >>sys.stderr, "PLRT CONNECTION BROKEN"
                    TERMINATE.append(None)
                    break
                buf += chunk
                if len_(buf) < 8:
                    continue

            op, dlen, payload = cmdhdr(buf)
            if len_(payload) < dlen:
                # Header complete but payload still partial: keep reading.
                continue

            data = buffer_(payload, 0, dlen)
            # Whatever follows this command stays buffered for the next turn.
            buf = buf[8+dlen:]

            print >>sys.stderr, "COMMAND", op, "DATA", dlen

            if op in commands:
                try:
                    commands[op](data)
                except Exception:
                    # A malformed or inconsistent command must not kill the
                    # thread; report it and carry on.
                    traceback.print_exc(file=sys.stderr)
            else:
                print >>sys.stderr, "IGNORING UNKNOWN COMMAND", op

    def stop(self):
        """Ask run() to finish and wait a bounded time for it to do so."""
        self._stop = True
        self.join(1+5*options.poll_delay)
436
437
438
# One membership-announcing thread per enabled address given on the command
# line; they are created here and started further down.
igmp_threads = [IGMPThread(vif_addr) for vif_addr in remaining_args]

# Tables shared between the forwarder and router threads:
#   rt_cache: origin+group (8 bytes) -> (ttl thresholds, parent address)
#   vifs:     vif index -> unpacked vif tuple
rt_cache = {}
vifs = {}
445
# Termination flag: the script keeps running while this list is empty.
# A list is used so that it can be flagged by mutation, without rebinding.
TERMINATE = []


def _finalize(sig, frame):
    """Signal handler: request shutdown by making TERMINATE non-empty."""
    TERMINATE.append(None)


signal.signal(signal.SIGTERM, _finalize)
452
453
# Operating mode, derived once from the command-line flags:
#   --announce-only: IGMP announcements only (no forwarder, no router)
#   --no-router:     announcements plus forwarding to the default egress
use_forwarder = not options.announce_only
use_router = use_forwarder and not options.no_router

try:
    if use_router:
        # Wait for the router process to connect before starting anything.
        router_socket = socket.socket(socket.AF_UNIX, socket.SOCK_SEQPACKET)
        router_socket.bind(options.mrt_path)
        router_socket.listen(0)
        router_remote_socket, router_remote_addr = router_socket.accept()
        router_thread = RouterThread(rt_cache, router_remote_socket, vifs)
    else:
        router_remote_socket = None
        router_thread = None

    if use_forwarder:
        fwd_thread = FWDThread(rt_cache, router_remote_socket, vifs)

    for thread in igmp_threads:
        thread.start()

    if use_forwarder:
        fwd_thread.start()
    if use_router:
        router_thread.start()

    # Idle until SIGTERM or a broken router connection flags termination.
    while not TERMINATE:
        time.sleep(30)
finally:
    # Remove the unix socket files.  A path that was never created, or that
    # is already gone, is not an error.
    for path in (options.mrt_path, options.fwd_path):
        try:
            os.remove(path)
        except OSError:
            pass
489
490