Merge with HEAD, close aly's branch.
[nepi.git] / src / nepi / testbeds / planetlab / scripts / mcastfwd.py
1 import sys
2
3 import signal
4 import socket
5 import struct
6 import optparse
7 import threading
8 import subprocess
9 import re
10 import time
11 import collections
12 import os
13 import traceback
14 import logging
15
16 import ipaddr2
17
18 usage = "usage: %prog [options] <enabled-addresses>"
19
20 parser = optparse.OptionParser(usage=usage)
21
22 parser.add_option(
23     "-d", "--poll-delay", dest="poll_delay", metavar="SECONDS", type="float",
24     default = 1.0,
25     help = "Multicast subscription polling interval")
26 parser.add_option(
27     "-D", "--refresh-delay", dest="refresh_delay", metavar="SECONDS", type="float",
28     default = 30.0,
29     help = "Full-refresh interval - time between full IGMP reports")
30 parser.add_option(
31     "-p", "--fwd-path", dest="fwd_path", metavar="PATH", 
32     default = "/var/run/mcastfwd",
33     help = "Path of the unix socket in which the program will listen for packets")
34 parser.add_option(
35     "-r", "--router-path", dest="mrt_path", metavar="PATH", 
36     default = "/var/run/mcastrt",
37     help = "Path of the unix socket in which the program will listen for routing changes")
38 parser.add_option(
39     "-A", "--announce-only", dest="announce_only", action="store_true",
40     default = False,
41     help = "If given, only group membership announcements will be made. Useful for non-router multicast nodes.")
42 parser.add_option(
43     "-v", "--verbose", dest="verbose", action="store_true",
44     default = False,
45     help = "Path of the unix socket in which the program will listen for routing changes")
46
47 (options, remaining_args) = parser.parse_args(sys.argv[1:])
48
# Route log records to stderr; --verbose lowers the root threshold from
# WARNING to DEBUG.
logging.basicConfig(
    level=(logging.DEBUG if options.verbose else logging.WARNING),
    stream=sys.stderr)
52
# Low-level Linux networking constants. Names match the kernel headers'
# identifiers — values presumably mirror them (verify before relying on
# them). None of these are referenced elsewhere in this script.

# Ethernet protocol numbers
ETH_P_ALL = 0x0003
ETH_P_IP = 0x0800

# TUN/TAP ioctl request and interface flags
TUNSETIFF = 0x400454CA
IFF_TUN = 0x0001
IFF_TAP = 0x0002
IFF_NO_PI = 0x1000
IFF_VNET_HDR = 0x4000
TUN_PKT_STRIP = 0x0001

# Interface-request related sizes (bytes)
IFHWADDRLEN = 6
IFNAMSIZ = 16
IFREQ_SZ = 40

# ioctl request: bytes available for reading
FIONREAD = 0x541B
65
class IGMPThread(threading.Thread):
    """Announce the multicast group memberships of one local interface.

    Every ``options.poll_delay`` seconds the thread lists the interface's
    multicast addresses with ``ip maddr show <iface>``, sends an IGMP report
    for each newly joined group and an IGMP leave (to 224.0.0.2) for each
    group that disappeared. Once more than ``options.refresh_delay`` seconds
    have elapsed since the last full refresh, a report is sent for *all*
    current groups instead of only the new ones.
    """

    def __init__(self, vif_addr, *p, **kw):
        """Create an announcer bound to the interface that owns ``vif_addr``.

        :param vif_addr: dotted-quad IPv4 address of a local interface;
            surrounding whitespace is stripped.
        :raises RuntimeError: if no interface listed by ``ip addr show``
            carries ``vif_addr``.
        """
        super(IGMPThread, self).__init__(*p, **kw)

        vif_addr = vif_addr.strip()
        self.vif_addr = vif_addr
        # Raw IGMP socket with IP_HDRINCL: the caller supplies the IP header
        # (presumably built by ipaddr2.ipigmp — note its noipcksum flag).
        self.igmp_socket = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP)
        self.igmp_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF,
            socket.inet_aton(self.vif_addr) )
        self.igmp_socket.setsockopt(socket.IPPROTO_IP, socket.IP_HDRINCL, 1)
        self.igmp_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 1)
        self._stop = False
        self.setDaemon(True)

        try:
            self.tun_name = self._find_interface(vif_addr)
        except Exception:
            # Do not leak the raw socket when construction fails.
            self.igmp_socket.close()
            raise

    @staticmethod
    def _find_interface(vif_addr):
        """Return the name of the interface that has ``vif_addr`` assigned.

        Parses ``ip addr show``: heading lines (``N: name: ...``) set the
        current interface, ``inet A.B.C.D`` lines list its addresses.

        :raises RuntimeError: if the address is not found.
        """
        devnull = open('/dev/null','r+b')
        try:
            proc = subprocess.Popen(['ip','addr','show'],
                stdout = subprocess.PIPE,
                stderr = subprocess.STDOUT,
                stdin = devnull)
            # Read all output so the child is always reaped, even when the
            # address matches early.
            output = proc.communicate()[0]
        finally:
            devnull.close()

        heading = re.compile(r"\d+:\s*([-a-zA-Z0-9_]+):.*")
        addr = re.compile(r"\s*inet\s*(\d{1,3}[.]\d{1,3}[.]\d{1,3}[.]\d{1,3}).*")
        tun_name = None
        for line in output.splitlines():
            match = heading.match(line)
            if match:
                tun_name = match.group(1)
                continue
            match = addr.match(line)
            if match and match.group(1) == vif_addr:
                return tun_name
        # The former ``raise RuntimeError, msg, vif_addr`` form treated
        # vif_addr as a traceback and raised TypeError instead.
        raise RuntimeError("Could not find interface for %s" % (vif_addr,))

    def run(self):
        """Poll subscriptions and send IGMP joins/leaves until stopped."""
        devnull = open('/dev/null','r+b')
        try:
            maddr_re = re.compile(r"\s*inet\s*(\d{1,3}[.]\d{1,3}[.]\d{1,3}[.]\d{1,3})\s*")
            cur_maddr = set()
            lastfullrefresh = time.time()
            while not self._stop:
                # Get current subscriptions
                proc = subprocess.Popen(['ip','maddr','show',self.tun_name],
                    stdout = subprocess.PIPE,
                    stderr = subprocess.STDOUT,
                    stdin = devnull)
                new_maddr = set()
                for line in proc.stdout:
                    match = maddr_re.match(line)
                    if match:
                        new_maddr.add(match.group(1))
                proc.wait()

                # Usually report only new groups; every refresh_delay
                # seconds, report all of them.
                now = time.time()
                report_new = new_maddr
                if (now - lastfullrefresh) <= options.refresh_delay:
                    report_new = report_new - cur_maddr
                else:
                    lastfullrefresh = now

                # Report subscriptions (IGMP type 0x16, sent to the group)
                for grp in report_new:
                    print >>sys.stderr, "JOINING", grp
                    igmpp = ipaddr2.ipigmp(
                        self.vif_addr, grp, 1, 0x16, 0, grp,
                        noipcksum=True)
                    try:
                        self.igmp_socket.sendto(igmpp, 0, (grp,0))
                    except Exception:
                        traceback.print_exc(file=sys.stderr)

                # Notify group leave (IGMP type 0x17, sent to 224.0.0.2)
                for grp in cur_maddr - new_maddr:
                    print >>sys.stderr, "LEAVING", grp
                    igmpp = ipaddr2.ipigmp(
                        self.vif_addr, '224.0.0.2', 1, 0x17, 0, grp,
                        noipcksum=True)
                    try:
                        self.igmp_socket.sendto(igmpp, 0, ('224.0.0.2',0))
                    except Exception:
                        traceback.print_exc(file=sys.stderr)

                cur_maddr = new_maddr

                time.sleep(options.poll_delay)
        finally:
            devnull.close()

    def stop(self):
        """Request the polling loop to exit and wait a bounded time for it."""
        self._stop = True
        self.join(1+5*options.poll_delay)
155
156
class FWDThread(threading.Thread):
    """Forward multicast datagrams using the routes learned from the router.

    Datagrams arrive on a unix datagram socket bound at ``options.fwd_path``.
    Each one starts with the 4-byte local address of the VIF it came in on,
    followed by the raw IP datagram. Packets with a matching entry in
    ``rt_cache`` are re-sent on every VIF whose TTL threshold they exceed.
    Packets without a route are queued in ``pending`` (bounded by
    ``maxpending``) and announced to the router with the IP protocol byte
    set to 0; the queue is retried after each receive timeout.
    """

    def __init__(self, rt_cache, router_socket, vifs, *p, **kw):
        """
        :param rt_cache: dict shared with RouterThread mapping the packed
            ``origin + group`` addresses (8 bytes) to ``(ttls, parent_addr)``.
        :param router_socket: connected socket to the router; IGMP packets
            and cache-miss notifications are sent on it.
        :param vifs: dict shared with RouterThread mapping VIF index to the
            unpacked vifctl tuple (element 4 is the local address).
        """
        super(FWDThread, self).__init__(*p, **kw)

        self.in_socket = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
        self.in_socket.bind(options.fwd_path)

        # Packets still waiting for a route (oldest on the left).
        self.pending = collections.deque()
        self.maxpending = 1000
        self.rt_cache = rt_cache
        self.router_socket = router_socket
        self.vifs = vifs
        # One raw sending socket per enabled local address, keyed by the
        # packed (inet_aton) address so it can be matched against vif[4].
        self.fwd_sockets = {}
        for fwd_target in remaining_args:
            # Strip like IGMPThread does for the same addresses.
            fwd_target = socket.inet_aton(fwd_target.strip())
            fwd_socket = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_RAW)
            fwd_socket.setsockopt(socket.IPPROTO_IP, socket.IP_HDRINCL, 1)
            fwd_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, fwd_target)
            fwd_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 1)
            self.fwd_sockets[fwd_target] = fwd_socket

        self._stop = False
        self.setDaemon(True)

    def run(self):
        """Receive, route and forward packets until :meth:`stop` is called."""
        in_socket = self.in_socket
        rt_cache = self.rt_cache
        vifs = self.vifs
        router_socket = self.router_socket
        len_ = len
        ord_ = ord
        pending = self.pending
        in_socket.settimeout(options.poll_delay)
        buffer_ = buffer
        fwd_sockets = self.fwd_sockets
        # Pending packets left to retry in the current sweep.
        npending = 0
        noent = (None,None)

        while not self._stop:
            # Get packet: finish the current retry sweep before reading more
            try:
                if pending and npending:
                    # FIFO: unroutable packets are re-appended on the right,
                    # so taking from the left retries every queued packet
                    # once per sweep. pop() kept re-taking the same packet.
                    packet = pending.popleft()
                    npending -= 1
                else:
                    packet = in_socket.recv(2000)
            except socket.timeout:
                # Idle: start a retry sweep over everything still pending
                if pending and not npending:
                    npending = len_(pending)
                continue
            # 4-byte parent address + 20-byte minimal IPv4 header
            if not packet or len_(packet) < 24:
                continue

            fullpacket = packet
            parent = packet[:4]
            packet = buffer_(packet,4)

            if packet[9] == '\x02':
                # IGMP is control traffic for the router, not data to forward;
                # letting it fall through only produced bogus pending entries.
                router_socket.send(packet)
                continue
            elif packet[9] == '\x00':
                # Protocol 0 is the marker used below for cache-miss
                # notifications; presumably one looping back. Discard.
                continue

            # To-Do: PIM asserts

            # Get route: source + destination addresses (IP header bytes 12..19)
            addrinfo = packet[12:20]
            fwd_targets, rparent = rt_cache.get(addrinfo, noent)

            if fwd_targets is not None and (rparent == '\x00\x00\x00\x00' or rparent == parent):
                # Forward on every VIF whose TTL threshold the packet exceeds
                ttl = ord_(packet[8])
                tgt_group = (socket.inet_ntoa(addrinfo[4:]),0)
                print >>sys.stderr, map(socket.inet_ntoa, (parent, addrinfo[:4], addrinfo[4:])), "-> ttl", ttl,
                nfwd_targets = len_(fwd_targets)
                # items() snapshot: RouterThread adds/removes VIFs concurrently,
                # and iterating the live dict could raise RuntimeError.
                for vifi, vif in vifs.items():
                    if vifi < nfwd_targets:
                        ttl_thresh = ord_(fwd_targets[vifi])
                        if ttl_thresh > 0 and ttl > ttl_thresh:
                            if vif[4] in fwd_sockets:
                                print >>sys.stderr, socket.inet_ntoa(vif[4]),
                                try:
                                    fwd_sockets[vif[4]].sendto(packet, 0, tgt_group)
                                except socket.error:
                                    # A failure on one VIF must not stop the
                                    # forwarder or the remaining VIFs.
                                    traceback.print_exc(file=sys.stderr)
                print >>sys.stderr, "."
            else:
                # No route yet, or wrong incoming VIF: queue it and tell the router
                if len_(pending) < self.maxpending:
                    print >>sys.stderr, map(socket.inet_ntoa, (parent, addrinfo[:4], addrinfo[4:])), "-> ?"

                    pending.append(fullpacket)

                    # Notify mrouted by forwarding it with protocol 0
                    router_socket.send(''.join(
                        (packet[:9],'\x00',packet[10:]) ))

    def stop(self):
        """Request the forwarding loop to exit and wait a bounded time for it."""
        self._stop = True
        self.join(1+5*options.poll_delay)
259
260
class RouterThread(threading.Thread):
    """Apply the router's multicast routing commands to the shared tables.

    Commands are read from ``router_socket``. Each command is an 8-byte
    header of two native unsigned ints ``(op, dlen)`` followed by ``dlen``
    payload bytes: a packed ``vifctl`` (VIF add/delete) or ``mfcctl``
    (forwarding-entry add/delete). VIF commands update ``vifs``; MFC
    commands update ``rt_cache``, which FWDThread consults when forwarding.
    """

    def __init__(self, rt_cache, router_socket, vifs, *p, **kw):
        """
        :param rt_cache: dict shared with FWDThread; receives
            ``origin + group`` -> ``(ttls, parent_addr)`` entries.
        :param router_socket: connected socket on which commands arrive.
        :param vifs: dict shared with FWDThread; VIF index -> vifctl tuple.
        """
        super(RouterThread, self).__init__(*p, **kw)

        self.rt_cache = rt_cache
        self.vifs = vifs
        self.router_socket = router_socket

        self._stop = False
        self.setDaemon(True)

    def run(self):
        """Read and apply commands until stopped or the peer disconnects.

        An empty read means the peer closed the connection; the thread then
        appends to the module-level ``TERMINATE`` list and exits.
        """
        rt_cache = self.rt_cache
        vifs = self.vifs
        addr_vifs = {}
        router_socket = self.router_socket
        router_socket.settimeout(options.poll_delay)
        len_ = len
        buffer_ = buffer

        # Bytes received but not yet consumed as complete commands
        # (a str, or a buffer view into one).
        buf = ""

        MRT_BASE        = 200
        MRT_ADD_VIF     = MRT_BASE+2    # Add a virtual interface
        MRT_DEL_VIF     = MRT_BASE+3    # Delete a virtual interface
        MRT_ADD_MFC     = MRT_BASE+4    # Add a multicast forwarding entry
        MRT_DEL_MFC     = MRT_BASE+5    # Delete a multicast forwarding entry

        def cmdhdr(cmd, unpack=struct.unpack, buffer=buffer):
            # Split a command into (op, payload length, view after the header)
            op,dlen = unpack('II', buffer(cmd,0,8))
            cmd = buffer(cmd,8)
            return op,dlen,cmd
        def vifctl(data, unpack=struct.unpack):
            # vifi, flags, threshold, rate_limit, lcl_addr, rmt_addr
            return unpack('HBBI4s4s', data)
        def mfcctl(data, unpack=struct.unpack):
            # origin, mcastgrp, parent, ttls (32 bytes), pkt_cnt, byte_cnt,
            # wrong_if, expire
            return unpack('4s4sH32sIIIi', data)
        def parent_address(parent):
            # Local address of the expected incoming VIF; all-zeros when the
            # VIF is unknown, which FWDThread treats as "any parent".
            if parent in vifs:
                return vifs[parent][4]
            return '\x00\x00\x00\x00'

        def add_vif(cmd):
            vifi = vifctl(cmd)
            vifs[vifi[0]] = vifi
            addr_vifs[vifi[4]] = vifi[0]
            print >>sys.stderr, "Added VIF", vifi
        def del_vif(cmd):
            vifi = vifctl(cmd)
            vifi = vifs[vifi[0]]
            del addr_vifs[vifi[4]]
            del vifs[vifi[0]]
            print >>sys.stderr, "Removed VIF", vifi
        def add_mfc(cmd):
            # Unpack the argument itself (previously read the enclosing
            # loop's ``data`` variable by closure).
            origin,mcastgrp,parent,ttls,pkt_cnt,byte_cnt,wrong_if,expire = mfcctl(cmd)
            parent_addr = parent_address(parent)
            rt_cache[origin + mcastgrp] = (ttls, parent_addr)
            print >>sys.stderr, "Added RT", '-'.join(map(socket.inet_ntoa,(parent_addr,origin,mcastgrp))), map(ord,ttls)
        def del_mfc(cmd):
            origin,mcastgrp,parent,ttls,pkt_cnt,byte_cnt,wrong_if,expire = mfcctl(cmd)
            parent_addr = parent_address(parent)
            del rt_cache[origin + mcastgrp]
            print >>sys.stderr, "Removed RT", '-'.join(map(socket.inet_ntoa,(parent_addr,origin,mcastgrp)))

        commands = {
            MRT_ADD_VIF : add_vif,
            MRT_DEL_VIF : del_vif,
            MRT_ADD_MFC : add_mfc,
            MRT_DEL_MFC : del_mfc,
        }

        while not self._stop:
            # Accumulate bytes until buf holds at least one whole command.
            # Partial reads are kept (they used to be dropped), and a stale
            # previous chunk is never re-appended.
            if len_(buf) < 8 or len_(buf) < (cmdhdr(buf)[1]+8):
                try:
                    chunk = router_socket.recv(2000)
                except socket.timeout:
                    continue
                if not chunk:
                    print >>sys.stderr, "PLRT CONNECTION BROKEN"
                    TERMINATE.append(None)
                    break
                # str/buffer + str yields a new str
                buf = buf + chunk
                continue

            op,dlen,data = cmdhdr(buf)
            # Whatever follows this command stays buffered for the next pass
            buf = buffer_(data, dlen)
            data = buffer_(data, 0, dlen)

            print >>sys.stderr, "COMMAND", op, "DATA", dlen

            if op in commands:
                try:
                    commands[op](data)
                except Exception:
                    # Malformed or inconsistent commands are logged, not fatal
                    traceback.print_exc(file=sys.stderr)
            else:
                print >>sys.stderr, "IGNORING UNKNOWN COMMAND", op

    def stop(self):
        """Request the command loop to exit and wait a bounded time for it."""
        self._stop = True
        self.join(1+5*options.poll_delay)
377
378
379
# One membership announcer per enabled local address (started later, in the
# main section).
igmp_threads = [IGMPThread(vif_addr) for vif_addr in remaining_args]
383
# Routing state shared by FWDThread (reader) and RouterThread (writer).
rt_cache = {}
vifs = {}

# Shutdown request list: the SIGTERM handler and RouterThread append to it,
# and the main loop exits once it is non-empty. (Previously assigned twice.)
TERMINATE = []
def _finalize(sig,frame):
    """SIGTERM handler: request shutdown of the main loop.

    Mutates TERMINATE in place, so no ``global`` declaration is needed.
    """
    TERMINATE.append(None)
signal.signal(signal.SIGTERM, _finalize)
393
394
try:
    if not options.announce_only:
        # Block until the routing daemon connects on the router socket.
        router_socket = socket.socket(socket.AF_UNIX, socket.SOCK_SEQPACKET)
        router_socket.bind(options.mrt_path)
        router_socket.listen(0)
        router_remote_socket, router_remote_addr = router_socket.accept()

        fwd_thread = FWDThread(rt_cache, router_remote_socket, vifs)
        router_thread = RouterThread(rt_cache, router_remote_socket, vifs)

    for thread in igmp_threads:
        thread.start()

    if not options.announce_only:
        fwd_thread.start()
        router_thread.start()

    # Short naps so a shutdown requested by a worker thread (e.g. the router
    # connection breaking) is noticed within about a second, not up to 30.
    while not TERMINATE:
        time.sleep(1)
finally:
    # Announce-only mode never binds these paths; removing them would delete
    # sockets belonging to a router instance configured with the same paths.
    if not options.announce_only:
        for _sock_path in (options.mrt_path, options.fwd_path):
            if os.path.exists(_sock_path):
                try:
                    os.remove(_sock_path)
                except OSError:
                    # Best-effort cleanup
                    pass
425
426