Avoid flapping subscriptions
[nepi.git] / src / nepi / testbeds / planetlab / scripts / mcastfwd.py
1 import sys
2
3 import signal
4 import socket
5 import struct
6 import optparse
7 import threading
8 import subprocess
9 import re
10 import time
11 import collections
12 import os
13 import traceback
14 import logging
15
16 import ipaddr2
17
usage = "usage: %prog [options] <enabled-addresses>"

# Command-line interface. The positional arguments (see the usage string)
# end up in ``remaining_args``; everything else lands in ``options``.
parser = optparse.OptionParser(usage=usage)

# Timing knobs
parser.add_option(
    "-d", "--poll-delay",
    dest="poll_delay",
    metavar="SECONDS",
    type="float",
    default=1.0,
    help="Multicast subscription polling interval",
)
parser.add_option(
    "-D", "--refresh-delay",
    dest="refresh_delay",
    metavar="SECONDS",
    type="float",
    default=30.0,
    help="Full-refresh interval - time between full IGMP reports",
)

# Unix socket locations
parser.add_option(
    "-p", "--fwd-path",
    dest="fwd_path",
    metavar="PATH",
    default="/var/run/mcastfwd",
    help="Path of the unix socket in which the program will listen for packets",
)
parser.add_option(
    "-r", "--router-path",
    dest="mrt_path",
    metavar="PATH",
    default="/var/run/mcastrt",
    help="Path of the unix socket in which the program will listen for routing changes",
)

# Operating mode switches
parser.add_option(
    "-A", "--announce-only",
    dest="announce_only",
    action="store_true",
    default=False,
    help="If given, only group membership announcements will be made. "
         "Useful for non-router non-member multicast nodes.",
)
parser.add_option(
    "-R", "--no-router",
    dest="no_router",
    action="store_true",
    default=False,
    help="If given, only group membership announcements and forwarding to the default multicast egress will be made. "
         "Useful for non-router but member multicast nodes.",
)

# Diagnostics
parser.add_option(
    "-v", "--verbose",
    dest="verbose",
    action="store_true",
    default=False,
    help="Log more verbosely",
)

options, remaining_args = parser.parse_args(sys.argv[1:])

# DEBUG output only when --verbose was given; warnings and above otherwise.
logging.basicConfig(
    stream=sys.stderr,
    level=(logging.DEBUG if options.verbose else logging.WARNING),
)
58
# Low-level networking constants.
# NOTE(review): the names suggest values copied from the Linux ethernet /
# tun-tap / ioctl headers - confirm against the target kernel headers.
# None of them is referenced in the visible part of this file.

# Ethernet protocol selectors
ETH_P_ALL = 0x0003
ETH_P_IP = 0x0800

# tun/tap ioctl request and interface flags
TUNSETIFF = 0x400454CA
IFF_TUN = 0x0001
IFF_TAP = 0x0002
IFF_NO_PI = 0x1000
IFF_VNET_HDR = 0x4000
TUN_PKT_STRIP = 0x0001

# Sizes (presumably of the hardware address, interface name and ifreq struct)
IFHWADDRLEN = 6
IFNAMSIZ = 16
IFREQ_SZ = 40

# ioctl request number
FIONREAD = 0x541B
71
class IGMPThread(threading.Thread):
    """Announce the multicast subscriptions of one virtual interface.

    The thread periodically polls ``ip maddr show`` for the interface that
    owns ``vif_addr`` and for ``eth0``, sends an IGMP packet of type 0x16
    (logged as "JOINING") for every group that needs reporting and one of
    type 0x17 (logged as "LEAVING") for every group that disappeared since
    the previous poll.
    """

    # ``ip addr show`` output: "<index>: <name>: ..." interface headings and
    # "inet <address>..." IPv4 address lines.
    _HEADING_RE = re.compile(r"\d+:\s*([-a-zA-Z0-9_]+):.*")
    _ADDR_RE = re.compile(r"\s*inet\s*(\d{1,3}[.]\d{1,3}[.]\d{1,3}[.]\d{1,3}).*")
    # ``ip maddr show`` output: "inet <group>" lines.
    _MADDR_RE = re.compile(r"\s*inet\s*(\d{1,3}[.]\d{1,3}[.]\d{1,3}[.]\d{1,3})\s*")

    def __init__(self, vif_addr, *p, **kw):
        """Open the raw IGMP socket and resolve the interface of ``vif_addr``.

        vif_addr -- dotted-quad IPv4 address of the virtual interface
                    (surrounding whitespace is stripped).
        *p, **kw -- forwarded to ``threading.Thread``.

        Raises RuntimeError when no interface listed by ``ip addr show``
        carries ``vif_addr``.
        """
        super(IGMPThread, self).__init__(*p, **kw)

        vif_addr = vif_addr.strip()
        self.vif_addr = vif_addr
        self.igmp_socket = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP)
        self.igmp_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF,
            socket.inet_aton(self.vif_addr) )
        self.igmp_socket.setsockopt(socket.IPPROTO_IP, socket.IP_HDRINCL, 1)
        self.igmp_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 1)
        self._stop = False
        self.setDaemon(True)

        # Find tun name. communicate() drains the pipe and reaps the child,
        # so no zombie process or open descriptor is left behind.
        with open('/dev/null', 'r+b') as devnull:
            proc = subprocess.Popen(['ip', 'addr', 'show'],
                stdout = subprocess.PIPE,
                stderr = subprocess.STDOUT,
                stdin = devnull)
            output = proc.communicate()[0]

        tun_name = self._find_interface(output.splitlines(), vif_addr)
        if tun_name is None:
            # A single-argument raise: the old three-operand form passed
            # vif_addr as the traceback and died with a TypeError instead.
            raise RuntimeError("Could not find interface for %s" % (vif_addr,))
        self.tun_name = tun_name

    @classmethod
    def _find_interface(cls, lines, vif_addr):
        """Return the name of the interface owning ``vif_addr``, or None.

        lines    -- iterable of text lines in ``ip addr show`` format.
        vif_addr -- dotted-quad address to look for.
        """
        current = None
        for line in lines:
            match = cls._HEADING_RE.match(line)
            if match:
                # New interface section; remember its name.
                current = match.group(1)
                continue
            match = cls._ADDR_RE.match(line)
            if match and match.group(1) == vif_addr:
                return current
        return None

    @classmethod
    def _read_maddrs(cls, ifname, devnull):
        """Return the set of IPv4 groups ``ip maddr show <ifname>`` lists."""
        proc = subprocess.Popen(['ip', 'maddr', 'show', ifname],
            stdout = subprocess.PIPE,
            stderr = subprocess.STDOUT,
            stdin = devnull)
        groups = set()
        for line in proc.stdout:
            match = cls._MADDR_RE.match(line)
            if match:
                groups.add(match.group(1))
        proc.wait()
        return groups

    def run(self):
        """Poll subscriptions and send IGMP joins/leaves until stopped."""
        devnull = open('/dev/null', 'r+b')
        cur_maddr = set()
        lastfullrefresh = time.time()
        vif_addr_i = socket.inet_aton(self.vif_addr)
        try:
            while not self._stop:
                # A fresh socket per cycle carries the mirrored eth0
                # memberships joined below.
                # NOTE(review): the previous cycle's socket is only dropped
                # by rebinding this name - confirm that is intended.
                mirror_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

                # Get current subscriptions @ vif. An empty answer is
                # retried a few times so that a transiently empty listing
                # does not look like every group was left.
                new_maddr = set()
                for i in xrange(5):
                    new_maddr = self._read_maddrs(self.tun_name, devnull)
                    if new_maddr:
                        break

                # Get current subscriptions @ eth0 (default on PL),
                # they should be considered "universal" subscriptions.
                for i in xrange(5):
                    eth_maddr = self._read_maddrs('eth0', devnull)
                    if eth_maddr:
                        for maddr in eth_maddr:
                            try:
                                mirror_socket.setsockopt(
                                    socket.IPPROTO_IP,
                                    socket.IP_ADD_MEMBERSHIP,
                                    socket.inet_aton(maddr)+vif_addr_i )
                            except Exception:
                                traceback.print_exc(file=sys.stderr)
                        new_maddr.update(eth_maddr)
                        break

                # Every now and then, send a full report; in between only
                # groups that were not present in the previous poll.
                now = time.time()
                report_new = new_maddr
                if (now - lastfullrefresh) <= options.refresh_delay:
                    report_new = report_new - cur_maddr
                else:
                    lastfullrefresh = now

                # Report subscriptions
                for grp in report_new:
                    print >>sys.stderr, "JOINING", grp
                    igmpp = ipaddr2.ipigmp(
                        self.vif_addr, grp, 1, 0x16, 0, grp,
                        noipcksum=True)
                    try:
                        self.igmp_socket.sendto(igmpp, 0, (grp,0))
                    except Exception:
                        traceback.print_exc(file=sys.stderr)

                # Notify group leave (sent to 224.0.0.2)
                for grp in cur_maddr - new_maddr:
                    print >>sys.stderr, "LEAVING", grp
                    igmpp = ipaddr2.ipigmp(
                        self.vif_addr, '224.0.0.2', 1, 0x17, 0, grp,
                        noipcksum=True)
                    try:
                        self.igmp_socket.sendto(igmpp, 0, ('224.0.0.2',0))
                    except Exception:
                        traceback.print_exc(file=sys.stderr)

                cur_maddr = new_maddr

                time.sleep(options.poll_delay)
        finally:
            devnull.close()

    def stop(self):
        """Ask the polling loop to finish and wait a bounded time for it."""
        self._stop = True
        self.join(1+5*options.poll_delay)
193
194
class FWDThread(threading.Thread):
    """Forward multicast packets received on a unix socket to the vifs.

    Each datagram read from ``options.fwd_path`` is expected to be a 4-byte
    parent address followed by an IP packet. Packets with a cached route are
    re-sent through the raw socket of every eligible vif; packets without
    one are queued and the router socket (if any) is notified.
    """

    def __init__(self, rt_cache, router_socket, vifs, *p, **kw):
        """Bind the input socket and open one raw socket per enabled address.

        rt_cache      -- dict mapping (source + group, 8 raw bytes) to
                         (ttl thresholds, parent address); shared with the
                         router thread.
        router_socket -- connected socket to the routing daemon, or None.
        vifs          -- dict of vif index -> vif tuple; shared with the
                         router thread.
        *p, **kw      -- forwarded to ``threading.Thread``.
        """
        super(FWDThread, self).__init__(*p, **kw)

        self.in_socket = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
        self.in_socket.bind(options.fwd_path)

        self.pending = collections.deque()
        self.maxpending = 1000
        self.rt_cache = rt_cache
        self.router_socket = router_socket
        self.vifs = vifs

        # prepare forwarding sockets, keyed by packed local address
        self.fwd_sockets = {}
        for fwd_target in remaining_args:
            fwd_target = socket.inet_aton(fwd_target)
            fwd_socket = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_RAW)
            fwd_socket.setsockopt(socket.IPPROTO_IP, socket.IP_HDRINCL, 1)
            fwd_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, fwd_target)
            fwd_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 1)
            self.fwd_sockets[fwd_target] = fwd_socket

        self._stop = False
        self.setDaemon(True)

    def run(self):
        """Receive, route and forward packets until stopped."""
        # Local aliases keep name lookups cheap inside the packet loop.
        in_socket = self.in_socket
        rt_cache = self.rt_cache
        vifs = self.vifs
        router_socket = self.router_socket
        len_ = len
        ord_ = ord
        pending = self.pending
        in_socket.settimeout(options.poll_delay)
        buffer_ = buffer
        fwd_sockets = self.fwd_sockets
        npending = 0        # queued packets still to retry in this round
        npendingpop = 0     # consecutive retries without reading the socket
        noent = (None,None)
        verbose = options.verbose

        while not self._stop:
            # Get packet
            try:
                if pending and npending:
                    # Retry the OLDEST queued packet. Unroutable packets are
                    # re-appended on the right, so taking from the left
                    # gives every queued packet one retry per round instead
                    # of spinning on the most recent one.
                    packet = pending.popleft()
                    npending -= 1
                    npendingpop += 1
                    if npendingpop > 10:
                        # Don't hurry too much,
                        # we'll saturate the kernel's queue
                        time.sleep(0)
                else:
                    npendingpop = 0
                    packet = in_socket.recv(2000)
            except socket.timeout:
                # Idle: schedule one retry round for whatever is queued.
                if pending and not npending:
                    npending = len_(pending)
                continue
            # 4 bytes of parent address + at least 20 bytes of IP header
            if not packet or len_(packet) < 24:
                continue

            fullpacket = packet
            parent = packet[:4]
            packet = buffer_(packet,4)

            if packet[9] == '\x02':
                # IGMP packet? It's for mrouted
                # unless it's coming from it
                # NOTE: mrouted already picks it up when it's sent
                #       to the virtual interface. Injecting it would
                #       only duplicate it.
                continue
            elif packet[9] == '\x00':
                # LOOPING packet, discard
                continue

            # To-Do: PIM asserts

            # Get route (key: source address + destination group)
            addrinfo = packet[12:20]
            fwd_targets, rparent = rt_cache.get(addrinfo, noent)

            if fwd_targets is not None and (rparent == '\x00\x00\x00\x00' or rparent == parent):
                # Forward to vifs
                ttl = ord_(packet[8])
                tgt_group = (socket.inet_ntoa(addrinfo[4:]),0)
                if verbose:
                    print >>sys.stderr, map(socket.inet_ntoa, (parent, addrinfo[:4], addrinfo[4:])), "-> ttl", ttl,
                nfwd_targets = len_(fwd_targets)
                for vifi, vif in vifs.iteritems():
                    if vifi < nfwd_targets:
                        # A zero threshold disables the vif for this route.
                        ttl_thresh = ord_(fwd_targets[vifi])
                        if ttl_thresh > 0 and ttl > ttl_thresh:
                            if vif[4] in fwd_sockets:
                                try:
                                    if verbose:
                                        print >>sys.stderr, socket.inet_ntoa(vif[4]),
                                    fwd_socket = fwd_sockets[vif[4]]
                                    fwd_socket.sendto(packet, 0, tgt_group)
                                except Exception as e:
                                    print >>sys.stderr, "ERROR: forwarding packet:", str(e)

                if verbose:
                    print >>sys.stderr, "."
            elif router_socket:
                # Mark pending
                if len_(pending) < self.maxpending:
                    if verbose:
                        print >>sys.stderr, map(socket.inet_ntoa, (parent, addrinfo[:4], addrinfo[4:])), "-> ?"

                    pending.append(fullpacket)

                    # Notify mrouted by forwarding it with protocol 0.
                    # A failing or timed-out send must not take the whole
                    # forwarding thread down, so it is only logged.
                    try:
                        router_socket.send(''.join(
                            (packet[:9],'\x00',packet[10:]) ))
                    except Exception as e:
                        print >>sys.stderr, "ERROR: notifying router:", str(e)

    def stop(self):
        """Ask the forwarding loop to finish and wait a bounded time for it."""
        self._stop = True
        self.join(1+5*options.poll_delay)
324
325
class RouterThread(threading.Thread):
    """Apply routing commands received from the router socket.

    Commands are framed as two native unsigned ints (opcode, payload
    length) followed by the payload. Recognised opcodes add or remove
    entries in the shared ``vifs`` and ``rt_cache`` dictionaries.
    """

    def __init__(self, rt_cache, router_socket, vifs, *p, **kw):
        """Store the shared tables and the connected router socket.

        rt_cache      -- dict updated by the MFC add/delete commands.
        router_socket -- connected socket commands are read from.
        vifs          -- dict updated by the VIF add/delete commands.
        *p, **kw      -- forwarded to ``threading.Thread``.
        """
        super(RouterThread, self).__init__(*p, **kw)

        self.rt_cache = rt_cache
        self.vifs = vifs
        self.router_socket = router_socket

        self._stop = False
        self.setDaemon(True)

    @staticmethod
    def _split_command(buf):
        """Split the first complete command off ``buf``.

        Returns ``(op, payload, rest)`` or None while ``buf`` does not yet
        hold a full header plus payload.
        """
        hdr_len = struct.calcsize('II')
        if len(buf) < hdr_len:
            return None
        op, dlen = struct.unpack_from('II', buf)
        end = hdr_len + dlen
        if len(buf) < end:
            return None
        return op, buf[hdr_len:end], buf[end:]

    def run(self):
        """Read and dispatch routing commands until stopped or disconnected."""
        rt_cache = self.rt_cache
        vifs = self.vifs
        addr_vifs = {}
        router_socket = self.router_socket
        router_socket.settimeout(options.poll_delay)

        # Unconsumed bytes received so far
        buf = ""

        MRT_BASE        = 200
        MRT_ADD_VIF     = MRT_BASE+2    # Add a virtual interface
        MRT_DEL_VIF     = MRT_BASE+3    # Delete a virtual interface
        MRT_ADD_MFC     = MRT_BASE+4    # Add a multicast forwarding entry
        MRT_DEL_MFC     = MRT_BASE+5    # Delete a multicast forwarding entry

        def vifctl(data, unpack=struct.unpack):
            # vifi, flags, threshold, rate_limit, lcl_addr, rmt_addr
            return unpack('HBBI4s4s', data)
        def mfcctl(data, unpack=struct.unpack):
            # origin, mcastgrp, parent, ttls, pkt_cnt, byte_cnt, wrong_if, expire
            return unpack('4s4sH32sIIIi', data)

        def add_vif(cmd):
            vifi = vifctl(cmd)
            vifs[vifi[0]] = vifi
            addr_vifs[vifi[4]] = vifi[0]
            print >>sys.stderr, "Added VIF", vifi
        def del_vif(cmd):
            vifi = vifctl(cmd)
            vifi = vifs[vifi[0]]
            del addr_vifs[vifi[4]]
            del vifs[vifi[0]]
            print >>sys.stderr, "Removed VIF", vifi
        def add_mfc(cmd):
            # Parse the handler's own argument rather than a variable
            # captured from the enclosing loop.
            origin,mcastgrp,parent,ttls,pkt_cnt,byte_cnt,wrong_if,expire = mfcctl(cmd)
            if parent in vifs:
                parent_addr = vifs[parent][4]
            else:
                parent_addr = '\x00\x00\x00\x00'
            addrinfo = origin + mcastgrp
            rt_cache[addrinfo] = (ttls, parent_addr)
            print >>sys.stderr, "Added RT", '-'.join(map(socket.inet_ntoa,(parent_addr,origin,mcastgrp))), map(ord,ttls)
        def del_mfc(cmd):
            origin,mcastgrp,parent,ttls,pkt_cnt,byte_cnt,wrong_if,expire = mfcctl(cmd)
            if parent in vifs:
                parent_addr = vifs[parent][4]
            else:
                parent_addr = '\x00\x00\x00\x00'
            addrinfo = origin + mcastgrp
            del rt_cache[addrinfo]
            print >>sys.stderr, "Removed RT", '-'.join(map(socket.inet_ntoa,(parent_addr,origin,mcastgrp)))

        commands = {
            MRT_ADD_VIF : add_vif,
            MRT_DEL_VIF : del_vif,
            MRT_ADD_MFC : add_mfc,
            MRT_DEL_MFC : del_mfc,
        }

        while not self._stop:
            parsed = self._split_command(buf)
            if parsed is None:
                # Not enough bytes for a whole command yet: read more and
                # re-check. Partial data stays in buf instead of being
                # dropped, and nothing stale is ever appended twice.
                try:
                    chunk = router_socket.recv(2000)
                except socket.timeout:
                    continue
                if not chunk:
                    print >>sys.stderr, "PLRT CONNECTION BROKEN"
                    TERMINATE.append(None)
                    break
                buf += chunk
                continue

            op, data, buf = parsed

            print >>sys.stderr, "COMMAND", op, "DATA", len(data)

            if op in commands:
                try:
                    commands[op](data)
                except Exception:
                    # A malformed or inconsistent command is reported and
                    # skipped; the thread keeps serving later commands.
                    traceback.print_exc(file=sys.stderr)
            else:
                print >>sys.stderr, "IGNORING UNKNOWN COMMAND", op

    def stop(self):
        """Ask the command loop to finish and wait a bounded time for it."""
        self._stop = True
        self.join(1+5*options.poll_delay)
442
443
444
# One announcer thread per address given on the command line; addresses
# whose interface cannot be set up are reported and dropped.
igmp_threads = []
valid_vifs = []
for vif_addr in remaining_args:
    try:
        igmp_threads.append(IGMPThread(vif_addr))
        valid_vifs.append(vif_addr)
    except Exception:
        # Not a bare except: Ctrl-C / SystemExit during startup must still
        # abort the program instead of being reported as a bad interface.
        traceback.print_exc()
        print >>sys.stderr, "WARNING: could not listen on interface", vif_addr

# FWDThread reads remaining_args to decide which forwarding sockets to open.
remaining_args = valid_vifs

# Tables shared between the forwarding and router threads
rt_cache = {}
vifs = {}

# Becomes non-empty when the program should shut down
TERMINATE = []
def _finalize(sig,frame):
    """SIGTERM handler: flag the main loop to exit."""
    TERMINATE.append(None)
signal.signal(signal.SIGTERM, _finalize)


try:
    if not options.announce_only and not options.no_router:
        # Wait for the routing daemon to connect before starting anything.
        router_socket = socket.socket(socket.AF_UNIX, socket.SOCK_SEQPACKET)
        router_socket.bind(options.mrt_path)
        router_socket.listen(0)
        router_remote_socket, router_remote_addr = router_socket.accept()
        router_thread = RouterThread(rt_cache, router_remote_socket, vifs)
    else:
        router_remote_socket = None
        router_thread = None

    if not options.announce_only:
        fwd_thread = FWDThread(rt_cache, router_remote_socket, vifs)

    for thread in igmp_threads:
        thread.start()

    if not options.announce_only:
        fwd_thread.start()
    if not options.no_router and not options.announce_only:
        router_thread.start()

    while not TERMINATE:
        time.sleep(30)
finally:
    # Best-effort removal of the unix socket files; a missing or
    # unremovable path is not an error at shutdown.
    for _sock_path in (options.mrt_path, options.fwd_path):
        try:
            os.remove(_sock_path)
        except OSError:
            pass
503
504