Allow routing through /vsys/sliceip.
[nepi.git] / src / nepi / testbeds / planetlab / node.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 from constants import TESTBED_ID
5 import plcapi
6 import operator
7 import rspawn
8 import time
9 import os
10 import collections
11 import cStringIO
12 import resourcealloc
13 import socket
14 import sys
15 import logging
16 import ipaddr
17 import operator
18
19 from nepi.util import server
20 from nepi.util import parallel
21
22 import application
23
24 MAX_VROUTE_ROUTES = 5
25
26 class UnresponsiveNodeError(RuntimeError):
27     pass
28
29 def _castproperty(typ, propattr):
30     def _get(self):
31         return getattr(self, propattr)
32     def _set(self, value):
33         if value is not None or (isinstance(value, basestring) and not value):
34             value = typ(value)
35         return setattr(self, propattr, value)
36     def _del(self, value):
37         return delattr(self, propattr)
38     _get.__name__ = propattr + '_get'
39     _set.__name__ = propattr + '_set'
40     _del.__name__ = propattr + '_del'
41     return property(_get, _set, _del)
42
43 class Node(object):
44     BASEFILTERS = {
45         # Map Node attribute to plcapi filter name
46         'hostname' : 'hostname',
47     }
48     
49     TAGFILTERS = {
50         # Map Node attribute to (<tag name>, <plcapi filter expression>)
51         #   There are replacements that are applied with string formatting,
52         #   so '%' has to be escaped as '%%'.
53         'architecture' : ('arch','value'),
54         'operatingSystem' : ('fcdistro','value'),
55         'pl_distro' : ('pldistro','value'),
56         'city' : ('city','value'),
57         'country' : ('country','value'),
58         'region' : ('region','value'),
59         'minReliability' : ('reliability%(timeframe)s', ']value'),
60         'maxReliability' : ('reliability%(timeframe)s', '[value'),
61         'minBandwidth' : ('bw%(timeframe)s', ']value'),
62         'maxBandwidth' : ('bw%(timeframe)s', '[value'),
63         'minLoad' : ('load%(timeframe)s', ']value'),
64         'maxLoad' : ('load%(timeframe)s', '[value'),
65         'minCpu' : ('cpu%(timeframe)s', ']value'),
66         'maxCpu' : ('cpu%(timeframe)s', '[value'),
67     }    
68     
69     DEPENDS_PIDFILE = '/tmp/nepi-depends.pid'
70     DEPENDS_LOGFILE = '/tmp/nepi-depends.log'
71     RPM_FUSION_URL = 'http://download1.rpmfusion.org/free/fedora/rpmfusion-free-release-stable.noarch.rpm'
72     RPM_FUSION_URL_F12 = 'http://download1.rpmfusion.org/free/fedora/releases/12/Everything/x86_64/os/rpmfusion-free-release-12-1.noarch.rpm'
73     
74     minReliability = _castproperty(float, '_minReliability')
75     maxReliability = _castproperty(float, '_maxReliability')
76     minBandwidth = _castproperty(float, '_minBandwidth')
77     maxBandwidth = _castproperty(float, '_maxBandwidth')
78     minCpu = _castproperty(float, '_minCpu')
79     maxCpu = _castproperty(float, '_maxCpu')
80     minLoad = _castproperty(float, '_minLoad')
81     maxLoad = _castproperty(float, '_maxLoad')
82     
83     def __init__(self, api=None):
84         if not api:
85             api = plcapi.PLCAPI()
86         self._api = api
87         
88         # Attributes
89         self.hostname = None
90         self.architecture = None
91         self.operatingSystem = None
92         self.pl_distro = None
93         self.site = None
94         self.city = None
95         self.country = None
96         self.region = None
97         self.minReliability = None
98         self.maxReliability = None
99         self.minBandwidth = None
100         self.maxBandwidth = None
101         self.minCpu = None
102         self.maxCpu = None
103         self.minLoad = None
104         self.maxLoad = None
105         self.min_num_external_ifaces = None
106         self.max_num_external_ifaces = None
107         self.timeframe = 'm'
108         
109         # Applications and routes add requirements to connected nodes
110         self.required_packages = set()
111         self.required_vsys = set()
112         self.pythonpath = []
113         self.rpmFusion = False
114         self.env = collections.defaultdict(list)
115         
116         # Testbed-derived attributes
117         self.slicename = None
118         self.ident_path = None
119         self.server_key = None
120         self.home_path = None
121         self.enable_cleanup = False
122         
123         # Those are filled when an actual node is allocated
124         self._node_id = None
125         self._yum_dependencies = None
126         self._installed = False
127
128         # Logging
129         self._logger = logging.getLogger('nepi.testbeds.planetlab')
130     
131     def _nepi_testbed_environment_setup_get(self):
132         command = cStringIO.StringIO()
133         command.write('export PYTHONPATH=$PYTHONPATH:%s' % (
134             ':'.join(["${HOME}/"+server.shell_escape(s) for s in self.pythonpath])
135         ))
136         command.write(' ; export PATH=$PATH:%s' % (
137             ':'.join(["${HOME}/"+server.shell_escape(s) for s in self.pythonpath])
138         ))
139         if self.env:
140             for envkey, envvals in self.env.iteritems():
141                 for envval in envvals:
142                     command.write(' ; export %s=%s' % (envkey, envval))
143         return command.getvalue()
144     def _nepi_testbed_environment_setup_set(self, value):
145         pass
146     _nepi_testbed_environment_setup = property(
147         _nepi_testbed_environment_setup_get,
148         _nepi_testbed_environment_setup_set)
149     
150     def build_filters(self, target_filters, filter_map):
151         for attr, tag in filter_map.iteritems():
152             value = getattr(self, attr, None)
153             if value is not None:
154                 target_filters[tag] = value
155         return target_filters
156     
157     @property
158     def applicable_filters(self):
159         has = lambda att : getattr(self,att,None) is not None
160         return (
161             filter(has, self.BASEFILTERS.iterkeys())
162             + filter(has, self.TAGFILTERS.iterkeys())
163         )
164     
165     def find_candidates(self, filter_slice_id=None):
166         self._logger.info("Finding candidates for %s", self.make_filter_description())
167         
168         fields = ('node_id',)
169         replacements = {'timeframe':self.timeframe}
170         
171         # get initial candidates (no tag filters)
172         basefilters = self.build_filters({}, self.BASEFILTERS)
173         rootfilters = basefilters.copy()
174         if filter_slice_id:
175             basefilters['|slice_ids'] = (filter_slice_id,)
176         
177         # only pick healthy nodes
178         basefilters['run_level'] = 'boot'
179         basefilters['boot_state'] = 'boot'
180         basefilters['node_type'] = 'regular' # nepi can only handle regular nodes (for now)
181         basefilters['>last_contact'] = int(time.time()) - 5*3600 # allow 5h out of contact, for timezone discrepancies
182         
183         # keyword-only "pseudofilters"
184         extra = {}
185         if self.site:
186             extra['peer'] = self.site
187             
188         candidates = set(map(operator.itemgetter('node_id'), 
189             self._api.GetNodes(filters=basefilters, fields=fields, **extra)))
190         
191         # filter by tag, one tag at a time
192         applicable = self.applicable_filters
193         for tagfilter in self.TAGFILTERS.iteritems():
194             attr, (tagname, expr) = tagfilter
195             
196             # don't bother if there's no filter defined
197             if attr in applicable:
198                 tagfilter = rootfilters.copy()
199                 tagfilter['tagname'] = tagname % replacements
200                 tagfilter[expr % replacements] = getattr(self,attr)
201                 tagfilter['node_id'] = list(candidates)
202                 
203                 candidates &= set(map(operator.itemgetter('node_id'),
204                     self._api.GetNodeTags(filters=tagfilter, fields=fields)))
205         
206         # filter by vsys tags - special case since it doesn't follow
207         # the usual semantics
208         if self.required_vsys:
209             newcandidates = collections.defaultdict(set)
210             
211             vsys_tags = self._api.GetNodeTags(
212                 tagname='vsys', 
213                 node_id = list(candidates), 
214                 fields = ['node_id','value'])
215             
216             vsys_tags = map(
217                 operator.itemgetter(['node_id','value']),
218                 vsys_tags)
219             
220             required_vsys = self.required_vsys
221             for node_id, value in vsys_tags:
222                 if value in required_vsys:
223                     newcandidates[value].add(node_id)
224             
225             # take only those that have all the required vsys tags
226             newcandidates = reduce(
227                 lambda accum, new : accum & new,
228                 newcandidates.itervalues(),
229                 candidates)
230         
231         # filter by iface count
232         if self.min_num_external_ifaces is not None or self.max_num_external_ifaces is not None:
233             # fetch interfaces for all, in one go
234             filters = basefilters.copy()
235             filters['node_id'] = list(candidates)
236             ifaces = dict(map(operator.itemgetter('node_id','interface_ids'),
237                 self._api.GetNodes(filters=basefilters, fields=('node_id','interface_ids')) ))
238             
239             # filter candidates by interface count
240             if self.min_num_external_ifaces is not None and self.max_num_external_ifaces is not None:
241                 predicate = ( lambda node_id : 
242                     self.min_num_external_ifaces <= len(ifaces.get(node_id,())) <= self.max_num_external_ifaces )
243             elif self.min_num_external_ifaces is not None:
244                 predicate = ( lambda node_id : 
245                     self.min_num_external_ifaces <= len(ifaces.get(node_id,())) )
246             else:
247                 predicate = ( lambda node_id : 
248                     len(ifaces.get(node_id,())) <= self.max_num_external_ifaces )
249             
250             candidates = set(filter(predicate, candidates))
251         
252         # make sure hostnames are resolvable
253         if candidates:
254             self._logger.info("  Found %s candidates. Checking for reachability...", len(candidates))
255             
256             hostnames = dict(map(operator.itemgetter('node_id','hostname'),
257                 self._api.GetNodes(list(candidates), ['node_id','hostname'])
258             ))
259             def resolvable(node_id):
260                 try:
261                     addr = socket.gethostbyname(hostnames[node_id])
262                     return addr is not None
263                 except:
264                     return False
265             candidates = set(parallel.pfilter(resolvable, candidates,
266                 maxthreads = 16))
267
268             self._logger.info("  Found %s reachable candidates.", len(candidates))
269             
270         return candidates
271     
272     def make_filter_description(self):
273         """
274         Makes a human-readable description of filtering conditions
275         for find_candidates.
276         """
277         
278         # get initial candidates (no tag filters)
279         filters = self.build_filters({}, self.BASEFILTERS)
280         
281         # keyword-only "pseudofilters"
282         if self.site:
283             filters['peer'] = self.site
284             
285         # filter by tag, one tag at a time
286         applicable = self.applicable_filters
287         for tagfilter in self.TAGFILTERS.iteritems():
288             attr, (tagname, expr) = tagfilter
289             
290             # don't bother if there's no filter defined
291             if attr in applicable:
292                 filters[attr] = getattr(self,attr)
293         
294         # filter by vsys tags - special case since it doesn't follow
295         # the usual semantics
296         if self.required_vsys:
297             filters['vsys'] = ','.join(list(self.required_vsys))
298         
299         # filter by iface count
300         if self.min_num_external_ifaces is not None or self.max_num_external_ifaces is not None:
301             filters['num_ifaces'] = '-'.join([
302                 str(self.min_num_external_ifaces or '0'),
303                 str(self.max_num_external_ifaces or 'inf')
304             ])
305             
306         return '; '.join(map('%s: %s'.__mod__,filters.iteritems()))
307
308     def assign_node_id(self, node_id):
309         self._node_id = node_id
310         self.fetch_node_info()
311     
312     def unassign_node(self):
313         self._node_id = None
314         self.__dict__.update(self.__orig_attrs)
315     
316     def fetch_node_info(self):
317         orig_attrs = {}
318         
319         info = self._api.GetNodes(self._node_id)[0]
320         tags = dict( (t['tagname'],t['value'])
321                      for t in self._api.GetNodeTags(node_id=self._node_id, fields=('tagname','value')) )
322
323         orig_attrs['min_num_external_ifaces'] = self.min_num_external_ifaces
324         orig_attrs['max_num_external_ifaces'] = self.max_num_external_ifaces
325         self.min_num_external_ifaces = None
326         self.max_num_external_ifaces = None
327         self.timeframe = 'm'
328         
329         replacements = {'timeframe':self.timeframe}
330         for attr, tag in self.BASEFILTERS.iteritems():
331             if tag in info:
332                 value = info[tag]
333                 if hasattr(self, attr):
334                     orig_attrs[attr] = getattr(self, attr)
335                 setattr(self, attr, value)
336         for attr, (tag,_) in self.TAGFILTERS.iteritems():
337             tag = tag % replacements
338             if tag in tags:
339                 value = tags[tag]
340                 if hasattr(self, attr):
341                     orig_attrs[attr] = getattr(self, attr)
342                 setattr(self, attr, value)
343         
344         if 'peer_id' in info:
345             orig_attrs['site'] = self.site
346             self.site = self._api.peer_map[info['peer_id']]
347         
348         if 'interface_ids' in info:
349             self.min_num_external_ifaces = \
350             self.max_num_external_ifaces = len(info['interface_ids'])
351         
352         if 'ssh_rsa_key' in info:
353             orig_attrs['server_key'] = self.server_key
354             self.server_key = info['ssh_rsa_key']
355         
356         self.__orig_attrs = orig_attrs
357
358     def validate(self):
359         if self.home_path is None:
360             raise AssertionError, "Misconfigured node: missing home path"
361         if self.ident_path is None or not os.access(self.ident_path, os.R_OK):
362             raise AssertionError, "Misconfigured node: missing slice SSH key"
363         if self.slicename is None:
364             raise AssertionError, "Misconfigured node: unspecified slice"
365
366     def recover(self):
367         # Mark dependencies installed
368         self._installed = True
369         
370         # Clear load attributes, they impair re-discovery
371         self.minReliability = \
372         self.maxReliability = \
373         self.minBandwidth = \
374         self.maxBandwidth = \
375         self.minCpu = \
376         self.maxCpu = \
377         self.minLoad = \
378         self.maxLoad = None
379
380     def install_dependencies(self):
381         if self.required_packages and not self._installed:
382             # If we need rpmfusion, we must install the repo definition and the gpg keys
383             if self.rpmFusion:
384                 if self.operatingSystem == 'f12':
385                     # Fedora 12 requires a different rpmfusion package
386                     RPM_FUSION_URL = self.RPM_FUSION_URL_F12
387                 else:
388                     # This one works for f13+
389                     RPM_FUSION_URL = self.RPM_FUSION_URL
390                     
391                 rpmFusion = (
392                   '( rpm -q $(rpm -q -p %(RPM_FUSION_URL)s) || rpm -i %(RPM_FUSION_URL)s ) &&'
393                 ) % {
394                     'RPM_FUSION_URL' : RPM_FUSION_URL
395                 }
396             else:
397                 rpmFusion = ''
398             
399             if rpmFusion:
400                 (out,err),proc = server.popen_ssh_command(
401                     rpmFusion,
402                     host = self.hostname,
403                     port = None,
404                     user = self.slicename,
405                     agent = None,
406                     ident_key = self.ident_path,
407                     server_key = self.server_key
408                     )
409                 
410                 if proc.wait():
411                     raise RuntimeError, "Failed to set up application: %s %s" % (out,err,)
412             
413             # Launch p2p yum dependency installer
414             self._yum_dependencies.async_setup()
415     
416     def wait_provisioning(self, timeout = 20*60):
417         # Wait for the p2p installer
418         sleeptime = 1.0
419         totaltime = 0.0
420         while not self.is_alive():
421             time.sleep(sleeptime)
422             totaltime += sleeptime
423             sleeptime = min(30.0, sleeptime*1.5)
424             
425             if totaltime > timeout:
426                 # PlanetLab has a 15' delay on configuration propagation
427                 # If we're above that delay, the unresponsiveness is not due
428                 # to this delay.
429                 raise UnresponsiveNodeError, "Unresponsive host %s" % (self.hostname,)
430         
431         # Ensure the node is clean (no apps running that could interfere with operations)
432         if self.enable_cleanup:
433             self.do_cleanup()
434     
435     def wait_dependencies(self, pidprobe=1, probe=0.5, pidmax=10, probemax=10):
436         # Wait for the p2p installer
437         if self._yum_dependencies and not self._installed:
438             self._yum_dependencies.async_setup_wait()
439             self._installed = True
440         
441     def is_alive(self):
442         # Make sure all the paths are created where 
443         # they have to be created for deployment
444         (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
445             "echo 'ALIVE'",
446             host = self.hostname,
447             port = None,
448             user = self.slicename,
449             agent = None,
450             ident_key = self.ident_path,
451             server_key = self.server_key
452             )
453         
454         if proc.wait():
455             return False
456         elif not err and out.strip() == 'ALIVE':
457             return True
458         else:
459             return False
460     
461     def destroy(self):
462         if self.enable_cleanup:
463             self.do_cleanup()
464     
465     def do_cleanup(self):
466         self._logger.info("Cleaning up %s", self.hostname)
467
468         (out,err),proc = server.popen_ssh_command(
469             # Some apps need two kills
470             "sudo -S killall -u %(slicename)s ; "
471             "sudo -S killall -u %(slicename)s ; "
472             "sudo -S killall python tcpdump ; "
473             "sudo -S kill $(ps -N T -o pid --no-heading | sort) ; "
474             "sudo -S killall -u root ; "
475             "sudo -S killall -u root " % {
476                 'slicename' : self.slicename ,
477             },
478             host = self.hostname,
479             port = None,
480             user = self.slicename,
481             agent = None,
482             ident_key = self.ident_path,
483             server_key = self.server_key,
484             tty = True, # so that ps -N -T works as advertised...
485             )
486         proc.wait()
487     
488     def prepare_dependencies(self):
489         # Configure p2p yum dependency installer
490         if self.required_packages and not self._installed:
491             self._yum_dependencies = application.YumDependency(self._api)
492             self._yum_dependencies.node = self
493             self._yum_dependencies.home_path = "nepi-yumdep"
494             self._yum_dependencies.depends = ' '.join(self.required_packages)
495
496     def routing_method(self, routes, vsys_vnet):
497         """
498         There are two methods, vroute and sliceip.
499         
500         vroute:
501             Modifies the node's routing table directly, validating that the IP
502             range lies within the network given by the slice's vsys_vnet tag.
503             This method is the most scalable for very small routing tables
504             that need not modify other routes (including the default)
505         
506         sliceip:
507             Uses policy routing and iptables filters to create per-sliver
508             routing tables. It's the most flexible way, but it doesn't scale
509             as well since only 155 routing tables can be created this way.
510         
511         This method will return the most appropriate routing method, which will
512         prefer vroute for small routing tables.
513         """
514         
515         # For now, sliceip results in kernel panics
516         # so we HAVE to use vroute
517         return 'vroute'
518         
519         # We should not make the routing table grow too big
520         if len(routes) > MAX_VROUTE_ROUTES:
521             return 'sliceip'
522         
523         vsys_vnet = ipaddr.IPNetwork(vsys_vnet)
524         for route in routes:
525             dest, prefix, nexthop, metric = route
526             dest = ipaddr.IPNetwork("%s/%d" % (dest,prefix))
527             nexthop = ipaddr.IPAddress(nexthop)
528             if dest not in vsys_vnet or nexthop not in vsys_vnet:
529                 return 'sliceip'
530         
531         return 'vroute'
532     
533     def format_route(self, route, dev, method, action):
534         dest, prefix, nexthop, metric = route
535         if method == 'vroute':
536             return (
537                 "%s %s%s gw %s %s" % (
538                     action,
539                     dest,
540                     (("/%d" % (prefix,)) if prefix and prefix != 32 else ""),
541                     nexthop,
542                     dev,
543                 )
544             )
545         elif method == 'sliceip':
546             return (
547                 "route %s to %s%s via %s metric %s dev %s" % (
548                     action,
549                     dest,
550                     (("/%d" % (prefix,)) if prefix and prefix != 32 else ""),
551                     nexthop,
552                     metric or 1,
553                     dev,
554                 )
555             )
556         else:
557             raise AssertionError, "Unknown method"
558     
559     def _annotate_routes_with_devs(self, routes, devs, method):
560         dev_routes = []
561         for route in routes:
562             for dev in devs:
563                 if dev.routes_here(route):
564                     dev_routes.append(tuple(route) + (dev.if_name,))
565                     
566                     # Stop checking
567                     break
568             else:
569                 if method == 'sliceip':
570                     dev_routes.append(tuple(route) + ('eth0',))
571                 else:
572                     raise RuntimeError, "Route %s cannot be bound to any virtual interface " \
573                         "- PL can only handle rules over virtual interfaces. Candidates are: %s" % (route,devs)
574         return dev_routes
575     
576     def configure_routes(self, routes, devs, vsys_vnet):
577         """
578         Add the specified routes to the node's routing table
579         """
580         rules = []
581         method = self.routing_method(routes, vsys_vnet)
582         tdevs = set()
583         
584         # annotate routes with devices
585         dev_routes = self._annotate_routes_with_devs(routes, devs, method)
586         for route in dev_routes:
587             route, dev = route[:-1], route[-1]
588             
589             # Schedule rule
590             tdevs.add(dev)
591             rules.append(self.format_route(route, dev, method, 'add'))
592         
593         if method == 'sliceip':
594             rules = map('enable '.__add__, tdevs) + rules
595         
596         self._logger.info("Setting up routes for %s using %s", self.hostname, method)
597         self._logger.debug("Routes for %s:\n\t%s", self.hostname, '\n\t'.join(rules))
598         
599         self.apply_route_rules(rules, method)
600         
601         self._configured_routes = set(routes)
602         self._configured_devs = tdevs
603         self._configured_method = method
604     
605     def reconfigure_routes(self, routes, devs, vsys_vnet):
606         """
607         Updates the routes in the node's routing table to match
608         the given route list
609         """
610         method = self._configured_method
611         
612         dev_routes = self._annotate_routes_with_devs(routes, devs, method)
613
614         current = self._configured_routes
615         current_devs = self._configured_devs
616         
617         new = set(dev_routes)
618         new_devs = set(map(operator.itemgetter(-1), dev_routes))
619         
620         deletions = current - new
621         insertions = new - current
622         
623         dev_deletions = current_devs - new_devs
624         dev_insertions = new_devs - current_devs
625         
626         # Generate rules
627         rules = []
628         
629         # Rule deletions first
630         for route in deletions:
631             route, dev = route[:-1], route[-1]
632             rules.append(self.format_route(route, dev, method, 'del'))
633         
634         if method == 'sliceip':
635             # Dev deletions now
636             rules.extend(map('disable '.__add__, dev_deletions))
637
638             # Dev insertions now
639             rules.extend(map('enable '.__add__, dev_insertions))
640
641         # Rule insertions now
642         for route in insertions:
643             route, dev = route[:-1], dev[-1]
644             rules.append(self.format_route(route, dev, method, 'add'))
645         
646         # Apply
647         self.apply_route_rules(rules, method)
648         
649         self._configured_routes = dev_routes
650         self._configured_devs = new_devs
651         
652     def apply_route_rules(self, rules, method):
653         (out,err),proc = server.popen_ssh_command(
654             "( sudo -S bash -c 'cat /vsys/%(method)s.out >&2' & ) ; sudo -S bash -c 'cat > /vsys/%(method)s.in' ; sleep 0.5" % dict(
655                 home = server.shell_escape(self.home_path),
656                 method = method),
657             host = self.hostname,
658             port = None,
659             user = self.slicename,
660             agent = None,
661             ident_key = self.ident_path,
662             server_key = self.server_key,
663             stdin = '\n'.join(rules)
664             )
665         
666         if proc.wait() or err:
667             raise RuntimeError, "Could not set routes (%s) errors: %s%s" % (rules,out,err)
668         elif out or err:
669             logger.debug("%s said: %s%s", method, out, err)
670