Working (and easy to use) multicast forwarding
[nepi.git] / src / nepi / testbeds / planetlab / node.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 from constants import TESTBED_ID
5 import plcapi
6 import operator
7 import rspawn
8 import time
9 import os
10 import collections
11 import cStringIO
12 import resourcealloc
13 import socket
14 import sys
15 import logging
16 import ipaddr
17 import operator
18
19 from nepi.util import server
20 from nepi.util import parallel
21
22 import application
23
24 MAX_VROUTE_ROUTES = 5
25
26 class UnresponsiveNodeError(RuntimeError):
27     pass
28
29 def _castproperty(typ, propattr):
30     def _get(self):
31         return getattr(self, propattr)
32     def _set(self, value):
33         if value is not None or (isinstance(value, basestring) and not value):
34             value = typ(value)
35         return setattr(self, propattr, value)
36     def _del(self, value):
37         return delattr(self, propattr)
38     _get.__name__ = propattr + '_get'
39     _set.__name__ = propattr + '_set'
40     _del.__name__ = propattr + '_del'
41     return property(_get, _set, _del)
42
43 class Node(object):
44     BASEFILTERS = {
45         # Map Node attribute to plcapi filter name
46         'hostname' : 'hostname',
47     }
48     
49     TAGFILTERS = {
50         # Map Node attribute to (<tag name>, <plcapi filter expression>)
51         #   There are replacements that are applied with string formatting,
52         #   so '%' has to be escaped as '%%'.
53         'architecture' : ('arch','value'),
54         'operatingSystem' : ('fcdistro','value'),
55         'pl_distro' : ('pldistro','value'),
56         'city' : ('city','value'),
57         'country' : ('country','value'),
58         'region' : ('region','value'),
59         'minReliability' : ('reliability%(timeframe)s', ']value'),
60         'maxReliability' : ('reliability%(timeframe)s', '[value'),
61         'minBandwidth' : ('bw%(timeframe)s', ']value'),
62         'maxBandwidth' : ('bw%(timeframe)s', '[value'),
63         'minLoad' : ('load%(timeframe)s', ']value'),
64         'maxLoad' : ('load%(timeframe)s', '[value'),
65         'minCpu' : ('cpu%(timeframe)s', ']value'),
66         'maxCpu' : ('cpu%(timeframe)s', '[value'),
67     }    
68     
69     DEPENDS_PIDFILE = '/tmp/nepi-depends.pid'
70     DEPENDS_LOGFILE = '/tmp/nepi-depends.log'
71     RPM_FUSION_URL = 'http://download1.rpmfusion.org/free/fedora/rpmfusion-free-release-stable.noarch.rpm'
72     RPM_FUSION_URL_F12 = 'http://download1.rpmfusion.org/free/fedora/releases/12/Everything/x86_64/os/rpmfusion-free-release-12-1.noarch.rpm'
73     
74     minReliability = _castproperty(float, '_minReliability')
75     maxReliability = _castproperty(float, '_maxReliability')
76     minBandwidth = _castproperty(float, '_minBandwidth')
77     maxBandwidth = _castproperty(float, '_maxBandwidth')
78     minCpu = _castproperty(float, '_minCpu')
79     maxCpu = _castproperty(float, '_maxCpu')
80     minLoad = _castproperty(float, '_minLoad')
81     maxLoad = _castproperty(float, '_maxLoad')
82     
83     def __init__(self, api=None):
84         if not api:
85             api = plcapi.PLCAPI()
86         self._api = api
87         
88         # Attributes
89         self.hostname = None
90         self.architecture = None
91         self.operatingSystem = None
92         self.pl_distro = None
93         self.site = None
94         self.city = None
95         self.country = None
96         self.region = None
97         self.minReliability = None
98         self.maxReliability = None
99         self.minBandwidth = None
100         self.maxBandwidth = None
101         self.minCpu = None
102         self.maxCpu = None
103         self.minLoad = None
104         self.maxLoad = None
105         self.min_num_external_ifaces = None
106         self.max_num_external_ifaces = None
107         self.timeframe = 'm'
108         
109         # Applications and routes add requirements to connected nodes
110         self.required_packages = set()
111         self.required_vsys = set()
112         self.pythonpath = []
113         self.rpmFusion = False
114         self.env = collections.defaultdict(list)
115         
116         # Some special applications - initialized when connected
117         self.multicast_forwarder = None
118         
119         # Testbed-derived attributes
120         self.slicename = None
121         self.ident_path = None
122         self.server_key = None
123         self.home_path = None
124         self.enable_cleanup = False
125         
126         # Those are filled when an actual node is allocated
127         self._node_id = None
128         self._yum_dependencies = None
129         self._installed = False
130
131         # Logging
132         self._logger = logging.getLogger('nepi.testbeds.planetlab')
133     
134     def _nepi_testbed_environment_setup_get(self):
135         command = cStringIO.StringIO()
136         command.write('export PYTHONPATH=$PYTHONPATH:%s' % (
137             ':'.join(["${HOME}/"+server.shell_escape(s) for s in self.pythonpath])
138         ))
139         command.write(' ; export PATH=$PATH:%s' % (
140             ':'.join(["${HOME}/"+server.shell_escape(s) for s in self.pythonpath])
141         ))
142         if self.env:
143             for envkey, envvals in self.env.iteritems():
144                 for envval in envvals:
145                     command.write(' ; export %s=%s' % (envkey, envval))
146         return command.getvalue()
147     def _nepi_testbed_environment_setup_set(self, value):
148         pass
149     _nepi_testbed_environment_setup = property(
150         _nepi_testbed_environment_setup_get,
151         _nepi_testbed_environment_setup_set)
152     
153     def build_filters(self, target_filters, filter_map):
154         for attr, tag in filter_map.iteritems():
155             value = getattr(self, attr, None)
156             if value is not None:
157                 target_filters[tag] = value
158         return target_filters
159     
160     @property
161     def applicable_filters(self):
162         has = lambda att : getattr(self,att,None) is not None
163         return (
164             filter(has, self.BASEFILTERS.iterkeys())
165             + filter(has, self.TAGFILTERS.iterkeys())
166         )
167     
168     def find_candidates(self, filter_slice_id=None):
169         self._logger.info("Finding candidates for %s", self.make_filter_description())
170         
171         fields = ('node_id',)
172         replacements = {'timeframe':self.timeframe}
173         
174         # get initial candidates (no tag filters)
175         basefilters = self.build_filters({}, self.BASEFILTERS)
176         rootfilters = basefilters.copy()
177         if filter_slice_id:
178             basefilters['|slice_ids'] = (filter_slice_id,)
179         
180         # only pick healthy nodes
181         basefilters['run_level'] = 'boot'
182         basefilters['boot_state'] = 'boot'
183         basefilters['node_type'] = 'regular' # nepi can only handle regular nodes (for now)
184         basefilters['>last_contact'] = int(time.time()) - 5*3600 # allow 5h out of contact, for timezone discrepancies
185         
186         # keyword-only "pseudofilters"
187         extra = {}
188         if self.site:
189             extra['peer'] = self.site
190             
191         candidates = set(map(operator.itemgetter('node_id'), 
192             self._api.GetNodes(filters=basefilters, fields=fields, **extra)))
193         
194         # filter by tag, one tag at a time
195         applicable = self.applicable_filters
196         for tagfilter in self.TAGFILTERS.iteritems():
197             attr, (tagname, expr) = tagfilter
198             
199             # don't bother if there's no filter defined
200             if attr in applicable:
201                 tagfilter = rootfilters.copy()
202                 tagfilter['tagname'] = tagname % replacements
203                 tagfilter[expr % replacements] = getattr(self,attr)
204                 tagfilter['node_id'] = list(candidates)
205                 
206                 candidates &= set(map(operator.itemgetter('node_id'),
207                     self._api.GetNodeTags(filters=tagfilter, fields=fields)))
208         
209         # filter by vsys tags - special case since it doesn't follow
210         # the usual semantics
211         if self.required_vsys:
212             newcandidates = collections.defaultdict(set)
213             
214             vsys_tags = self._api.GetNodeTags(
215                 tagname='vsys', 
216                 node_id = list(candidates), 
217                 fields = ['node_id','value'])
218             
219             vsys_tags = map(
220                 operator.itemgetter(['node_id','value']),
221                 vsys_tags)
222             
223             required_vsys = self.required_vsys
224             for node_id, value in vsys_tags:
225                 if value in required_vsys:
226                     newcandidates[value].add(node_id)
227             
228             # take only those that have all the required vsys tags
229             newcandidates = reduce(
230                 lambda accum, new : accum & new,
231                 newcandidates.itervalues(),
232                 candidates)
233         
234         # filter by iface count
235         if self.min_num_external_ifaces is not None or self.max_num_external_ifaces is not None:
236             # fetch interfaces for all, in one go
237             filters = basefilters.copy()
238             filters['node_id'] = list(candidates)
239             ifaces = dict(map(operator.itemgetter('node_id','interface_ids'),
240                 self._api.GetNodes(filters=basefilters, fields=('node_id','interface_ids')) ))
241             
242             # filter candidates by interface count
243             if self.min_num_external_ifaces is not None and self.max_num_external_ifaces is not None:
244                 predicate = ( lambda node_id : 
245                     self.min_num_external_ifaces <= len(ifaces.get(node_id,())) <= self.max_num_external_ifaces )
246             elif self.min_num_external_ifaces is not None:
247                 predicate = ( lambda node_id : 
248                     self.min_num_external_ifaces <= len(ifaces.get(node_id,())) )
249             else:
250                 predicate = ( lambda node_id : 
251                     len(ifaces.get(node_id,())) <= self.max_num_external_ifaces )
252             
253             candidates = set(filter(predicate, candidates))
254         
255         # make sure hostnames are resolvable
256         if candidates:
257             self._logger.info("  Found %s candidates. Checking for reachability...", len(candidates))
258             
259             hostnames = dict(map(operator.itemgetter('node_id','hostname'),
260                 self._api.GetNodes(list(candidates), ['node_id','hostname'])
261             ))
262             def resolvable(node_id):
263                 try:
264                     addr = socket.gethostbyname(hostnames[node_id])
265                     return addr is not None
266                 except:
267                     return False
268             candidates = set(parallel.pfilter(resolvable, candidates,
269                 maxthreads = 16))
270
271             self._logger.info("  Found %s reachable candidates.", len(candidates))
272             
273         return candidates
274     
275     def make_filter_description(self):
276         """
277         Makes a human-readable description of filtering conditions
278         for find_candidates.
279         """
280         
281         # get initial candidates (no tag filters)
282         filters = self.build_filters({}, self.BASEFILTERS)
283         
284         # keyword-only "pseudofilters"
285         if self.site:
286             filters['peer'] = self.site
287             
288         # filter by tag, one tag at a time
289         applicable = self.applicable_filters
290         for tagfilter in self.TAGFILTERS.iteritems():
291             attr, (tagname, expr) = tagfilter
292             
293             # don't bother if there's no filter defined
294             if attr in applicable:
295                 filters[attr] = getattr(self,attr)
296         
297         # filter by vsys tags - special case since it doesn't follow
298         # the usual semantics
299         if self.required_vsys:
300             filters['vsys'] = ','.join(list(self.required_vsys))
301         
302         # filter by iface count
303         if self.min_num_external_ifaces is not None or self.max_num_external_ifaces is not None:
304             filters['num_ifaces'] = '-'.join([
305                 str(self.min_num_external_ifaces or '0'),
306                 str(self.max_num_external_ifaces or 'inf')
307             ])
308             
309         return '; '.join(map('%s: %s'.__mod__,filters.iteritems()))
310
311     def assign_node_id(self, node_id):
312         self._node_id = node_id
313         self.fetch_node_info()
314     
315     def unassign_node(self):
316         self._node_id = None
317         self.__dict__.update(self.__orig_attrs)
318     
319     def fetch_node_info(self):
320         orig_attrs = {}
321         
322         info = self._api.GetNodes(self._node_id)[0]
323         tags = dict( (t['tagname'],t['value'])
324                      for t in self._api.GetNodeTags(node_id=self._node_id, fields=('tagname','value')) )
325
326         orig_attrs['min_num_external_ifaces'] = self.min_num_external_ifaces
327         orig_attrs['max_num_external_ifaces'] = self.max_num_external_ifaces
328         self.min_num_external_ifaces = None
329         self.max_num_external_ifaces = None
330         self.timeframe = 'm'
331         
332         replacements = {'timeframe':self.timeframe}
333         for attr, tag in self.BASEFILTERS.iteritems():
334             if tag in info:
335                 value = info[tag]
336                 if hasattr(self, attr):
337                     orig_attrs[attr] = getattr(self, attr)
338                 setattr(self, attr, value)
339         for attr, (tag,_) in self.TAGFILTERS.iteritems():
340             tag = tag % replacements
341             if tag in tags:
342                 value = tags[tag]
343                 if hasattr(self, attr):
344                     orig_attrs[attr] = getattr(self, attr)
345                 setattr(self, attr, value)
346         
347         if 'peer_id' in info:
348             orig_attrs['site'] = self.site
349             self.site = self._api.peer_map[info['peer_id']]
350         
351         if 'interface_ids' in info:
352             self.min_num_external_ifaces = \
353             self.max_num_external_ifaces = len(info['interface_ids'])
354         
355         if 'ssh_rsa_key' in info:
356             orig_attrs['server_key'] = self.server_key
357             self.server_key = info['ssh_rsa_key']
358         
359         self.__orig_attrs = orig_attrs
360
361     def validate(self):
362         if self.home_path is None:
363             raise AssertionError, "Misconfigured node: missing home path"
364         if self.ident_path is None or not os.access(self.ident_path, os.R_OK):
365             raise AssertionError, "Misconfigured node: missing slice SSH key"
366         if self.slicename is None:
367             raise AssertionError, "Misconfigured node: unspecified slice"
368
369     def recover(self):
370         # Mark dependencies installed
371         self._installed = True
372         
373         # Clear load attributes, they impair re-discovery
374         self.minReliability = \
375         self.maxReliability = \
376         self.minBandwidth = \
377         self.maxBandwidth = \
378         self.minCpu = \
379         self.maxCpu = \
380         self.minLoad = \
381         self.maxLoad = None
382
383     def install_dependencies(self):
384         if self.required_packages and not self._installed:
385             # If we need rpmfusion, we must install the repo definition and the gpg keys
386             if self.rpmFusion:
387                 if self.operatingSystem == 'f12':
388                     # Fedora 12 requires a different rpmfusion package
389                     RPM_FUSION_URL = self.RPM_FUSION_URL_F12
390                 else:
391                     # This one works for f13+
392                     RPM_FUSION_URL = self.RPM_FUSION_URL
393                     
394                 rpmFusion = (
395                   '( rpm -q $(rpm -q -p %(RPM_FUSION_URL)s) || rpm -i %(RPM_FUSION_URL)s ) &&'
396                 ) % {
397                     'RPM_FUSION_URL' : RPM_FUSION_URL
398                 }
399             else:
400                 rpmFusion = ''
401             
402             if rpmFusion:
403                 (out,err),proc = server.popen_ssh_command(
404                     rpmFusion,
405                     host = self.hostname,
406                     port = None,
407                     user = self.slicename,
408                     agent = None,
409                     ident_key = self.ident_path,
410                     server_key = self.server_key,
411                     timeout = 600,
412                     )
413                 
414                 if proc.wait():
415                     raise RuntimeError, "Failed to set up application: %s %s" % (out,err,)
416             
417             # Launch p2p yum dependency installer
418             self._yum_dependencies.async_setup()
419     
420     def wait_provisioning(self, timeout = 20*60):
421         # Wait for the p2p installer
422         sleeptime = 1.0
423         totaltime = 0.0
424         while not self.is_alive():
425             time.sleep(sleeptime)
426             totaltime += sleeptime
427             sleeptime = min(30.0, sleeptime*1.5)
428             
429             if totaltime > timeout:
430                 # PlanetLab has a 15' delay on configuration propagation
431                 # If we're above that delay, the unresponsiveness is not due
432                 # to this delay.
433                 raise UnresponsiveNodeError, "Unresponsive host %s" % (self.hostname,)
434         
435         # Ensure the node is clean (no apps running that could interfere with operations)
436         if self.enable_cleanup:
437             self.do_cleanup()
438     
439     def wait_dependencies(self, pidprobe=1, probe=0.5, pidmax=10, probemax=10):
440         # Wait for the p2p installer
441         if self._yum_dependencies and not self._installed:
442             self._yum_dependencies.async_setup_wait()
443             self._installed = True
444         
445     def is_alive(self):
446         # Make sure all the paths are created where 
447         # they have to be created for deployment
448         (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
449             "echo 'ALIVE'",
450             host = self.hostname,
451             port = None,
452             user = self.slicename,
453             agent = None,
454             ident_key = self.ident_path,
455             server_key = self.server_key,
456             timeout = 60,
457             err_on_timeout = False
458             )
459         
460         if proc.wait():
461             return False
462         elif not err and out.strip() == 'ALIVE':
463             return True
464         else:
465             return False
466     
467     def destroy(self):
468         if self.enable_cleanup:
469             self.do_cleanup()
470     
471     def do_cleanup(self):
472         if self.testbed().recovering:
473             # WOW - not now
474             return
475             
476         self._logger.info("Cleaning up %s", self.hostname)
477         
478         cmds = [
479             "sudo -S killall python tcpdump || /bin/true ; "
480             "sudo -S killall python tcpdump || /bin/true ; "
481             "sudo -S kill $(ps -N -T -o pid --no-heading | grep -v $PPID | sort) || /bin/true ",
482             "sudo -S killall -u %(slicename)s || /bin/true ",
483             "sudo -S killall -u root || /bin/true ",
484             "sudo -S killall -u %(slicename)s || /bin/true ",
485             "sudo -S killall -u root || /bin/true ",
486         ]
487
488         for cmd in cmds:
489             (out,err),proc = server.popen_ssh_command(
490                 # Some apps need two kills
491                 cmd % {
492                     'slicename' : self.slicename ,
493                 },
494                 host = self.hostname,
495                 port = None,
496                 user = self.slicename,
497                 agent = None,
498                 ident_key = self.ident_path,
499                 server_key = self.server_key,
500                 tty = True, # so that ps -N -T works as advertised...
501                 timeout = 60,
502                 retry = 3
503                 )
504             proc.wait()
505     
506     def prepare_dependencies(self):
507         # Configure p2p yum dependency installer
508         if self.required_packages and not self._installed:
509             self._yum_dependencies = application.YumDependency(self._api)
510             self._yum_dependencies.node = self
511             self._yum_dependencies.home_path = "nepi-yumdep"
512             self._yum_dependencies.depends = ' '.join(self.required_packages)
513
514     def routing_method(self, routes, vsys_vnet):
515         """
516         There are two methods, vroute and sliceip.
517         
518         vroute:
519             Modifies the node's routing table directly, validating that the IP
520             range lies within the network given by the slice's vsys_vnet tag.
521             This method is the most scalable for very small routing tables
522             that need not modify other routes (including the default)
523         
524         sliceip:
525             Uses policy routing and iptables filters to create per-sliver
526             routing tables. It's the most flexible way, but it doesn't scale
527             as well since only 155 routing tables can be created this way.
528         
529         This method will return the most appropriate routing method, which will
530         prefer vroute for small routing tables.
531         """
532         
533         # For now, sliceip results in kernel panics
534         # so we HAVE to use vroute
535         return 'vroute'
536         
537         # We should not make the routing table grow too big
538         if len(routes) > MAX_VROUTE_ROUTES:
539             return 'sliceip'
540         
541         vsys_vnet = ipaddr.IPNetwork(vsys_vnet)
542         for route in routes:
543             dest, prefix, nexthop, metric = route
544             dest = ipaddr.IPNetwork("%s/%d" % (dest,prefix))
545             nexthop = ipaddr.IPAddress(nexthop)
546             if dest not in vsys_vnet or nexthop not in vsys_vnet:
547                 return 'sliceip'
548         
549         return 'vroute'
550     
551     def format_route(self, route, dev, method, action):
552         dest, prefix, nexthop, metric = route
553         if method == 'vroute':
554             return (
555                 "%s %s%s gw %s %s" % (
556                     action,
557                     dest,
558                     (("/%d" % (prefix,)) if prefix and prefix != 32 else ""),
559                     nexthop,
560                     dev,
561                 )
562             )
563         elif method == 'sliceip':
564             return (
565                 "route %s to %s%s via %s metric %s dev %s" % (
566                     action,
567                     dest,
568                     (("/%d" % (prefix,)) if prefix and prefix != 32 else ""),
569                     nexthop,
570                     metric or 1,
571                     dev,
572                 )
573             )
574         else:
575             raise AssertionError, "Unknown method"
576     
577     def _annotate_routes_with_devs(self, routes, devs, method):
578         dev_routes = []
579         for route in routes:
580             for dev in devs:
581                 if dev.routes_here(route):
582                     dev_routes.append(tuple(route) + (dev.if_name,))
583                     
584                     # Stop checking
585                     break
586             else:
587                 if method == 'sliceip':
588                     dev_routes.append(tuple(route) + ('eth0',))
589                 else:
590                     raise RuntimeError, "Route %s cannot be bound to any virtual interface " \
591                         "- PL can only handle rules over virtual interfaces. Candidates are: %s" % (route,devs)
592         return dev_routes
593     
594     def configure_routes(self, routes, devs, vsys_vnet):
595         """
596         Add the specified routes to the node's routing table
597         """
598         rules = []
599         method = self.routing_method(routes, vsys_vnet)
600         tdevs = set()
601         
602         # annotate routes with devices
603         dev_routes = self._annotate_routes_with_devs(routes, devs, method)
604         for route in dev_routes:
605             route, dev = route[:-1], route[-1]
606             
607             # Schedule rule
608             tdevs.add(dev)
609             rules.append(self.format_route(route, dev, method, 'add'))
610         
611         if method == 'sliceip':
612             rules = map('enable '.__add__, tdevs) + rules
613         
614         self._logger.info("Setting up routes for %s using %s", self.hostname, method)
615         self._logger.debug("Routes for %s:\n\t%s", self.hostname, '\n\t'.join(rules))
616         
617         self.apply_route_rules(rules, method)
618         
619         self._configured_routes = set(routes)
620         self._configured_devs = tdevs
621         self._configured_method = method
622     
623     def reconfigure_routes(self, routes, devs, vsys_vnet):
624         """
625         Updates the routes in the node's routing table to match
626         the given route list
627         """
628         method = self._configured_method
629         
630         dev_routes = self._annotate_routes_with_devs(routes, devs, method)
631
632         current = self._configured_routes
633         current_devs = self._configured_devs
634         
635         new = set(dev_routes)
636         new_devs = set(map(operator.itemgetter(-1), dev_routes))
637         
638         deletions = current - new
639         insertions = new - current
640         
641         dev_deletions = current_devs - new_devs
642         dev_insertions = new_devs - current_devs
643         
644         # Generate rules
645         rules = []
646         
647         # Rule deletions first
648         for route in deletions:
649             route, dev = route[:-1], route[-1]
650             rules.append(self.format_route(route, dev, method, 'del'))
651         
652         if method == 'sliceip':
653             # Dev deletions now
654             rules.extend(map('disable '.__add__, dev_deletions))
655
656             # Dev insertions now
657             rules.extend(map('enable '.__add__, dev_insertions))
658
659         # Rule insertions now
660         for route in insertions:
661             route, dev = route[:-1], dev[-1]
662             rules.append(self.format_route(route, dev, method, 'add'))
663         
664         # Apply
665         self.apply_route_rules(rules, method)
666         
667         self._configured_routes = dev_routes
668         self._configured_devs = new_devs
669         
670     def apply_route_rules(self, rules, method):
671         (out,err),proc = server.popen_ssh_command(
672             "( sudo -S bash -c 'cat /vsys/%(method)s.out >&2' & ) ; sudo -S bash -c 'cat > /vsys/%(method)s.in' ; sleep 0.5" % dict(
673                 home = server.shell_escape(self.home_path),
674                 method = method),
675             host = self.hostname,
676             port = None,
677             user = self.slicename,
678             agent = None,
679             ident_key = self.ident_path,
680             server_key = self.server_key,
681             stdin = '\n'.join(rules),
682             timeout = 300
683             )
684         
685         if proc.wait() or err:
686             raise RuntimeError, "Could not set routes (%s) errors: %s%s" % (rules,out,err)
687         elif out or err:
688             logger.debug("%s said: %s%s", method, out, err)
689