SSH timeout. It tends to... hang. Whatevah...
[nepi.git] / src / nepi / testbeds / planetlab / node.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 from constants import TESTBED_ID
5 import plcapi
6 import operator
7 import rspawn
8 import time
9 import os
10 import collections
11 import cStringIO
12 import resourcealloc
13 import socket
14 import sys
15 import logging
16 import ipaddr
17 import operator
18
19 from nepi.util import server
20 from nepi.util import parallel
21
22 import application
23
24 MAX_VROUTE_ROUTES = 5
25
26 class UnresponsiveNodeError(RuntimeError):
27     pass
28
29 def _castproperty(typ, propattr):
30     def _get(self):
31         return getattr(self, propattr)
32     def _set(self, value):
33         if value is not None or (isinstance(value, basestring) and not value):
34             value = typ(value)
35         return setattr(self, propattr, value)
36     def _del(self, value):
37         return delattr(self, propattr)
38     _get.__name__ = propattr + '_get'
39     _set.__name__ = propattr + '_set'
40     _del.__name__ = propattr + '_del'
41     return property(_get, _set, _del)
42
43 class Node(object):
44     BASEFILTERS = {
45         # Map Node attribute to plcapi filter name
46         'hostname' : 'hostname',
47     }
48     
49     TAGFILTERS = {
50         # Map Node attribute to (<tag name>, <plcapi filter expression>)
51         #   There are replacements that are applied with string formatting,
52         #   so '%' has to be escaped as '%%'.
53         'architecture' : ('arch','value'),
54         'operatingSystem' : ('fcdistro','value'),
55         'pl_distro' : ('pldistro','value'),
56         'city' : ('city','value'),
57         'country' : ('country','value'),
58         'region' : ('region','value'),
59         'minReliability' : ('reliability%(timeframe)s', ']value'),
60         'maxReliability' : ('reliability%(timeframe)s', '[value'),
61         'minBandwidth' : ('bw%(timeframe)s', ']value'),
62         'maxBandwidth' : ('bw%(timeframe)s', '[value'),
63         'minLoad' : ('load%(timeframe)s', ']value'),
64         'maxLoad' : ('load%(timeframe)s', '[value'),
65         'minCpu' : ('cpu%(timeframe)s', ']value'),
66         'maxCpu' : ('cpu%(timeframe)s', '[value'),
67     }    
68     
69     DEPENDS_PIDFILE = '/tmp/nepi-depends.pid'
70     DEPENDS_LOGFILE = '/tmp/nepi-depends.log'
71     RPM_FUSION_URL = 'http://download1.rpmfusion.org/free/fedora/rpmfusion-free-release-stable.noarch.rpm'
72     RPM_FUSION_URL_F12 = 'http://download1.rpmfusion.org/free/fedora/releases/12/Everything/x86_64/os/rpmfusion-free-release-12-1.noarch.rpm'
73     
74     minReliability = _castproperty(float, '_minReliability')
75     maxReliability = _castproperty(float, '_maxReliability')
76     minBandwidth = _castproperty(float, '_minBandwidth')
77     maxBandwidth = _castproperty(float, '_maxBandwidth')
78     minCpu = _castproperty(float, '_minCpu')
79     maxCpu = _castproperty(float, '_maxCpu')
80     minLoad = _castproperty(float, '_minLoad')
81     maxLoad = _castproperty(float, '_maxLoad')
82     
83     def __init__(self, api=None):
84         if not api:
85             api = plcapi.PLCAPI()
86         self._api = api
87         
88         # Attributes
89         self.hostname = None
90         self.architecture = None
91         self.operatingSystem = None
92         self.pl_distro = None
93         self.site = None
94         self.city = None
95         self.country = None
96         self.region = None
97         self.minReliability = None
98         self.maxReliability = None
99         self.minBandwidth = None
100         self.maxBandwidth = None
101         self.minCpu = None
102         self.maxCpu = None
103         self.minLoad = None
104         self.maxLoad = None
105         self.min_num_external_ifaces = None
106         self.max_num_external_ifaces = None
107         self.timeframe = 'm'
108         
109         # Applications and routes add requirements to connected nodes
110         self.required_packages = set()
111         self.required_vsys = set()
112         self.pythonpath = []
113         self.rpmFusion = False
114         self.env = collections.defaultdict(list)
115         
116         # Testbed-derived attributes
117         self.slicename = None
118         self.ident_path = None
119         self.server_key = None
120         self.home_path = None
121         self.enable_cleanup = False
122         
123         # Those are filled when an actual node is allocated
124         self._node_id = None
125         self._yum_dependencies = None
126         self._installed = False
127
128         # Logging
129         self._logger = logging.getLogger('nepi.testbeds.planetlab')
130     
131     def _nepi_testbed_environment_setup_get(self):
132         command = cStringIO.StringIO()
133         command.write('export PYTHONPATH=$PYTHONPATH:%s' % (
134             ':'.join(["${HOME}/"+server.shell_escape(s) for s in self.pythonpath])
135         ))
136         command.write(' ; export PATH=$PATH:%s' % (
137             ':'.join(["${HOME}/"+server.shell_escape(s) for s in self.pythonpath])
138         ))
139         if self.env:
140             for envkey, envvals in self.env.iteritems():
141                 for envval in envvals:
142                     command.write(' ; export %s=%s' % (envkey, envval))
143         return command.getvalue()
144     def _nepi_testbed_environment_setup_set(self, value):
145         pass
146     _nepi_testbed_environment_setup = property(
147         _nepi_testbed_environment_setup_get,
148         _nepi_testbed_environment_setup_set)
149     
150     def build_filters(self, target_filters, filter_map):
151         for attr, tag in filter_map.iteritems():
152             value = getattr(self, attr, None)
153             if value is not None:
154                 target_filters[tag] = value
155         return target_filters
156     
157     @property
158     def applicable_filters(self):
159         has = lambda att : getattr(self,att,None) is not None
160         return (
161             filter(has, self.BASEFILTERS.iterkeys())
162             + filter(has, self.TAGFILTERS.iterkeys())
163         )
164     
165     def find_candidates(self, filter_slice_id=None):
166         self._logger.info("Finding candidates for %s", self.make_filter_description())
167         
168         fields = ('node_id',)
169         replacements = {'timeframe':self.timeframe}
170         
171         # get initial candidates (no tag filters)
172         basefilters = self.build_filters({}, self.BASEFILTERS)
173         rootfilters = basefilters.copy()
174         if filter_slice_id:
175             basefilters['|slice_ids'] = (filter_slice_id,)
176         
177         # only pick healthy nodes
178         basefilters['run_level'] = 'boot'
179         basefilters['boot_state'] = 'boot'
180         basefilters['node_type'] = 'regular' # nepi can only handle regular nodes (for now)
181         basefilters['>last_contact'] = int(time.time()) - 5*3600 # allow 5h out of contact, for timezone discrepancies
182         
183         # keyword-only "pseudofilters"
184         extra = {}
185         if self.site:
186             extra['peer'] = self.site
187             
188         candidates = set(map(operator.itemgetter('node_id'), 
189             self._api.GetNodes(filters=basefilters, fields=fields, **extra)))
190         
191         # filter by tag, one tag at a time
192         applicable = self.applicable_filters
193         for tagfilter in self.TAGFILTERS.iteritems():
194             attr, (tagname, expr) = tagfilter
195             
196             # don't bother if there's no filter defined
197             if attr in applicable:
198                 tagfilter = rootfilters.copy()
199                 tagfilter['tagname'] = tagname % replacements
200                 tagfilter[expr % replacements] = getattr(self,attr)
201                 tagfilter['node_id'] = list(candidates)
202                 
203                 candidates &= set(map(operator.itemgetter('node_id'),
204                     self._api.GetNodeTags(filters=tagfilter, fields=fields)))
205         
206         # filter by vsys tags - special case since it doesn't follow
207         # the usual semantics
208         if self.required_vsys:
209             newcandidates = collections.defaultdict(set)
210             
211             vsys_tags = self._api.GetNodeTags(
212                 tagname='vsys', 
213                 node_id = list(candidates), 
214                 fields = ['node_id','value'])
215             
216             vsys_tags = map(
217                 operator.itemgetter(['node_id','value']),
218                 vsys_tags)
219             
220             required_vsys = self.required_vsys
221             for node_id, value in vsys_tags:
222                 if value in required_vsys:
223                     newcandidates[value].add(node_id)
224             
225             # take only those that have all the required vsys tags
226             newcandidates = reduce(
227                 lambda accum, new : accum & new,
228                 newcandidates.itervalues(),
229                 candidates)
230         
231         # filter by iface count
232         if self.min_num_external_ifaces is not None or self.max_num_external_ifaces is not None:
233             # fetch interfaces for all, in one go
234             filters = basefilters.copy()
235             filters['node_id'] = list(candidates)
236             ifaces = dict(map(operator.itemgetter('node_id','interface_ids'),
237                 self._api.GetNodes(filters=basefilters, fields=('node_id','interface_ids')) ))
238             
239             # filter candidates by interface count
240             if self.min_num_external_ifaces is not None and self.max_num_external_ifaces is not None:
241                 predicate = ( lambda node_id : 
242                     self.min_num_external_ifaces <= len(ifaces.get(node_id,())) <= self.max_num_external_ifaces )
243             elif self.min_num_external_ifaces is not None:
244                 predicate = ( lambda node_id : 
245                     self.min_num_external_ifaces <= len(ifaces.get(node_id,())) )
246             else:
247                 predicate = ( lambda node_id : 
248                     len(ifaces.get(node_id,())) <= self.max_num_external_ifaces )
249             
250             candidates = set(filter(predicate, candidates))
251         
252         # make sure hostnames are resolvable
253         if candidates:
254             self._logger.info("  Found %s candidates. Checking for reachability...", len(candidates))
255             
256             hostnames = dict(map(operator.itemgetter('node_id','hostname'),
257                 self._api.GetNodes(list(candidates), ['node_id','hostname'])
258             ))
259             def resolvable(node_id):
260                 try:
261                     addr = socket.gethostbyname(hostnames[node_id])
262                     return addr is not None
263                 except:
264                     return False
265             candidates = set(parallel.pfilter(resolvable, candidates,
266                 maxthreads = 16))
267
268             self._logger.info("  Found %s reachable candidates.", len(candidates))
269             
270         return candidates
271     
272     def make_filter_description(self):
273         """
274         Makes a human-readable description of filtering conditions
275         for find_candidates.
276         """
277         
278         # get initial candidates (no tag filters)
279         filters = self.build_filters({}, self.BASEFILTERS)
280         
281         # keyword-only "pseudofilters"
282         if self.site:
283             filters['peer'] = self.site
284             
285         # filter by tag, one tag at a time
286         applicable = self.applicable_filters
287         for tagfilter in self.TAGFILTERS.iteritems():
288             attr, (tagname, expr) = tagfilter
289             
290             # don't bother if there's no filter defined
291             if attr in applicable:
292                 filters[attr] = getattr(self,attr)
293         
294         # filter by vsys tags - special case since it doesn't follow
295         # the usual semantics
296         if self.required_vsys:
297             filters['vsys'] = ','.join(list(self.required_vsys))
298         
299         # filter by iface count
300         if self.min_num_external_ifaces is not None or self.max_num_external_ifaces is not None:
301             filters['num_ifaces'] = '-'.join([
302                 str(self.min_num_external_ifaces or '0'),
303                 str(self.max_num_external_ifaces or 'inf')
304             ])
305             
306         return '; '.join(map('%s: %s'.__mod__,filters.iteritems()))
307
308     def assign_node_id(self, node_id):
309         self._node_id = node_id
310         self.fetch_node_info()
311     
312     def unassign_node(self):
313         self._node_id = None
314         self.__dict__.update(self.__orig_attrs)
315     
316     def fetch_node_info(self):
317         orig_attrs = {}
318         
319         info = self._api.GetNodes(self._node_id)[0]
320         tags = dict( (t['tagname'],t['value'])
321                      for t in self._api.GetNodeTags(node_id=self._node_id, fields=('tagname','value')) )
322
323         orig_attrs['min_num_external_ifaces'] = self.min_num_external_ifaces
324         orig_attrs['max_num_external_ifaces'] = self.max_num_external_ifaces
325         self.min_num_external_ifaces = None
326         self.max_num_external_ifaces = None
327         self.timeframe = 'm'
328         
329         replacements = {'timeframe':self.timeframe}
330         for attr, tag in self.BASEFILTERS.iteritems():
331             if tag in info:
332                 value = info[tag]
333                 if hasattr(self, attr):
334                     orig_attrs[attr] = getattr(self, attr)
335                 setattr(self, attr, value)
336         for attr, (tag,_) in self.TAGFILTERS.iteritems():
337             tag = tag % replacements
338             if tag in tags:
339                 value = tags[tag]
340                 if hasattr(self, attr):
341                     orig_attrs[attr] = getattr(self, attr)
342                 setattr(self, attr, value)
343         
344         if 'peer_id' in info:
345             orig_attrs['site'] = self.site
346             self.site = self._api.peer_map[info['peer_id']]
347         
348         if 'interface_ids' in info:
349             self.min_num_external_ifaces = \
350             self.max_num_external_ifaces = len(info['interface_ids'])
351         
352         if 'ssh_rsa_key' in info:
353             orig_attrs['server_key'] = self.server_key
354             self.server_key = info['ssh_rsa_key']
355         
356         self.__orig_attrs = orig_attrs
357
358     def validate(self):
359         if self.home_path is None:
360             raise AssertionError, "Misconfigured node: missing home path"
361         if self.ident_path is None or not os.access(self.ident_path, os.R_OK):
362             raise AssertionError, "Misconfigured node: missing slice SSH key"
363         if self.slicename is None:
364             raise AssertionError, "Misconfigured node: unspecified slice"
365
366     def recover(self):
367         # Mark dependencies installed
368         self._installed = True
369         
370         # Clear load attributes, they impair re-discovery
371         self.minReliability = \
372         self.maxReliability = \
373         self.minBandwidth = \
374         self.maxBandwidth = \
375         self.minCpu = \
376         self.maxCpu = \
377         self.minLoad = \
378         self.maxLoad = None
379
380     def install_dependencies(self):
381         if self.required_packages and not self._installed:
382             # If we need rpmfusion, we must install the repo definition and the gpg keys
383             if self.rpmFusion:
384                 if self.operatingSystem == 'f12':
385                     # Fedora 12 requires a different rpmfusion package
386                     RPM_FUSION_URL = self.RPM_FUSION_URL_F12
387                 else:
388                     # This one works for f13+
389                     RPM_FUSION_URL = self.RPM_FUSION_URL
390                     
391                 rpmFusion = (
392                   '( rpm -q $(rpm -q -p %(RPM_FUSION_URL)s) || rpm -i %(RPM_FUSION_URL)s ) &&'
393                 ) % {
394                     'RPM_FUSION_URL' : RPM_FUSION_URL
395                 }
396             else:
397                 rpmFusion = ''
398             
399             if rpmFusion:
400                 (out,err),proc = server.popen_ssh_command(
401                     rpmFusion,
402                     host = self.hostname,
403                     port = None,
404                     user = self.slicename,
405                     agent = None,
406                     ident_key = self.ident_path,
407                     server_key = self.server_key,
408                     timeout = 600,
409                     )
410                 
411                 if proc.wait():
412                     raise RuntimeError, "Failed to set up application: %s %s" % (out,err,)
413             
414             # Launch p2p yum dependency installer
415             self._yum_dependencies.async_setup()
416     
417     def wait_provisioning(self, timeout = 20*60):
418         # Wait for the p2p installer
419         sleeptime = 1.0
420         totaltime = 0.0
421         while not self.is_alive():
422             time.sleep(sleeptime)
423             totaltime += sleeptime
424             sleeptime = min(30.0, sleeptime*1.5)
425             
426             if totaltime > timeout:
427                 # PlanetLab has a 15' delay on configuration propagation
428                 # If we're above that delay, the unresponsiveness is not due
429                 # to this delay.
430                 raise UnresponsiveNodeError, "Unresponsive host %s" % (self.hostname,)
431         
432         # Ensure the node is clean (no apps running that could interfere with operations)
433         if self.enable_cleanup:
434             self.do_cleanup()
435     
436     def wait_dependencies(self, pidprobe=1, probe=0.5, pidmax=10, probemax=10):
437         # Wait for the p2p installer
438         if self._yum_dependencies and not self._installed:
439             self._yum_dependencies.async_setup_wait()
440             self._installed = True
441         
442     def is_alive(self):
443         # Make sure all the paths are created where 
444         # they have to be created for deployment
445         (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
446             "echo 'ALIVE'",
447             host = self.hostname,
448             port = None,
449             user = self.slicename,
450             agent = None,
451             ident_key = self.ident_path,
452             server_key = self.server_key,
453             timeout = 60,
454             err_on_timeout = False
455             )
456         
457         if proc.wait():
458             return False
459         elif not err and out.strip() == 'ALIVE':
460             return True
461         else:
462             return False
463     
464     def destroy(self):
465         if self.enable_cleanup:
466             self.do_cleanup()
467     
468     def do_cleanup(self):
469         if self.testbed().recovering:
470             # WOW - not now
471             return
472             
473         self._logger.info("Cleaning up %s", self.hostname)
474         
475         cmds = [
476             "sudo -S killall python tcpdump || /bin/true ; "
477             "sudo -S killall python tcpdump || /bin/true ; "
478             "sudo -S kill $(ps -N -T -o pid --no-heading | grep -v $PPID | sort) || /bin/true ",
479             "sudo -S killall -u %(slicename)s || /bin/true ",
480             "sudo -S killall -u root || /bin/true ",
481             "sudo -S killall -u %(slicename)s || /bin/true ",
482             "sudo -S killall -u root || /bin/true ",
483         ]
484
485         for cmd in cmds:
486             (out,err),proc = server.popen_ssh_command(
487                 # Some apps need two kills
488                 cmd % {
489                     'slicename' : self.slicename ,
490                 },
491                 host = self.hostname,
492                 port = None,
493                 user = self.slicename,
494                 agent = None,
495                 ident_key = self.ident_path,
496                 server_key = self.server_key,
497                 tty = True, # so that ps -N -T works as advertised...
498                 timeout = 60,
499                 retry = 3
500                 )
501             proc.wait()
502     
503     def prepare_dependencies(self):
504         # Configure p2p yum dependency installer
505         if self.required_packages and not self._installed:
506             self._yum_dependencies = application.YumDependency(self._api)
507             self._yum_dependencies.node = self
508             self._yum_dependencies.home_path = "nepi-yumdep"
509             self._yum_dependencies.depends = ' '.join(self.required_packages)
510
511     def routing_method(self, routes, vsys_vnet):
512         """
513         There are two methods, vroute and sliceip.
514         
515         vroute:
516             Modifies the node's routing table directly, validating that the IP
517             range lies within the network given by the slice's vsys_vnet tag.
518             This method is the most scalable for very small routing tables
519             that need not modify other routes (including the default)
520         
521         sliceip:
522             Uses policy routing and iptables filters to create per-sliver
523             routing tables. It's the most flexible way, but it doesn't scale
524             as well since only 155 routing tables can be created this way.
525         
526         This method will return the most appropriate routing method, which will
527         prefer vroute for small routing tables.
528         """
529         
530         # For now, sliceip results in kernel panics
531         # so we HAVE to use vroute
532         return 'vroute'
533         
534         # We should not make the routing table grow too big
535         if len(routes) > MAX_VROUTE_ROUTES:
536             return 'sliceip'
537         
538         vsys_vnet = ipaddr.IPNetwork(vsys_vnet)
539         for route in routes:
540             dest, prefix, nexthop, metric = route
541             dest = ipaddr.IPNetwork("%s/%d" % (dest,prefix))
542             nexthop = ipaddr.IPAddress(nexthop)
543             if dest not in vsys_vnet or nexthop not in vsys_vnet:
544                 return 'sliceip'
545         
546         return 'vroute'
547     
548     def format_route(self, route, dev, method, action):
549         dest, prefix, nexthop, metric = route
550         if method == 'vroute':
551             return (
552                 "%s %s%s gw %s %s" % (
553                     action,
554                     dest,
555                     (("/%d" % (prefix,)) if prefix and prefix != 32 else ""),
556                     nexthop,
557                     dev,
558                 )
559             )
560         elif method == 'sliceip':
561             return (
562                 "route %s to %s%s via %s metric %s dev %s" % (
563                     action,
564                     dest,
565                     (("/%d" % (prefix,)) if prefix and prefix != 32 else ""),
566                     nexthop,
567                     metric or 1,
568                     dev,
569                 )
570             )
571         else:
572             raise AssertionError, "Unknown method"
573     
574     def _annotate_routes_with_devs(self, routes, devs, method):
575         dev_routes = []
576         for route in routes:
577             for dev in devs:
578                 if dev.routes_here(route):
579                     dev_routes.append(tuple(route) + (dev.if_name,))
580                     
581                     # Stop checking
582                     break
583             else:
584                 if method == 'sliceip':
585                     dev_routes.append(tuple(route) + ('eth0',))
586                 else:
587                     raise RuntimeError, "Route %s cannot be bound to any virtual interface " \
588                         "- PL can only handle rules over virtual interfaces. Candidates are: %s" % (route,devs)
589         return dev_routes
590     
591     def configure_routes(self, routes, devs, vsys_vnet):
592         """
593         Add the specified routes to the node's routing table
594         """
595         rules = []
596         method = self.routing_method(routes, vsys_vnet)
597         tdevs = set()
598         
599         # annotate routes with devices
600         dev_routes = self._annotate_routes_with_devs(routes, devs, method)
601         for route in dev_routes:
602             route, dev = route[:-1], route[-1]
603             
604             # Schedule rule
605             tdevs.add(dev)
606             rules.append(self.format_route(route, dev, method, 'add'))
607         
608         if method == 'sliceip':
609             rules = map('enable '.__add__, tdevs) + rules
610         
611         self._logger.info("Setting up routes for %s using %s", self.hostname, method)
612         self._logger.debug("Routes for %s:\n\t%s", self.hostname, '\n\t'.join(rules))
613         
614         self.apply_route_rules(rules, method)
615         
616         self._configured_routes = set(routes)
617         self._configured_devs = tdevs
618         self._configured_method = method
619     
620     def reconfigure_routes(self, routes, devs, vsys_vnet):
621         """
622         Updates the routes in the node's routing table to match
623         the given route list
624         """
625         method = self._configured_method
626         
627         dev_routes = self._annotate_routes_with_devs(routes, devs, method)
628
629         current = self._configured_routes
630         current_devs = self._configured_devs
631         
632         new = set(dev_routes)
633         new_devs = set(map(operator.itemgetter(-1), dev_routes))
634         
635         deletions = current - new
636         insertions = new - current
637         
638         dev_deletions = current_devs - new_devs
639         dev_insertions = new_devs - current_devs
640         
641         # Generate rules
642         rules = []
643         
644         # Rule deletions first
645         for route in deletions:
646             route, dev = route[:-1], route[-1]
647             rules.append(self.format_route(route, dev, method, 'del'))
648         
649         if method == 'sliceip':
650             # Dev deletions now
651             rules.extend(map('disable '.__add__, dev_deletions))
652
653             # Dev insertions now
654             rules.extend(map('enable '.__add__, dev_insertions))
655
656         # Rule insertions now
657         for route in insertions:
658             route, dev = route[:-1], dev[-1]
659             rules.append(self.format_route(route, dev, method, 'add'))
660         
661         # Apply
662         self.apply_route_rules(rules, method)
663         
664         self._configured_routes = dev_routes
665         self._configured_devs = new_devs
666         
667     def apply_route_rules(self, rules, method):
668         (out,err),proc = server.popen_ssh_command(
669             "( sudo -S bash -c 'cat /vsys/%(method)s.out >&2' & ) ; sudo -S bash -c 'cat > /vsys/%(method)s.in' ; sleep 0.5" % dict(
670                 home = server.shell_escape(self.home_path),
671                 method = method),
672             host = self.hostname,
673             port = None,
674             user = self.slicename,
675             agent = None,
676             ident_key = self.ident_path,
677             server_key = self.server_key,
678             stdin = '\n'.join(rules),
679             timeout = 300
680             )
681         
682         if proc.wait() or err:
683             raise RuntimeError, "Could not set routes (%s) errors: %s%s" % (rules,out,err)
684         elif out or err:
685             logger.debug("%s said: %s%s", method, out, err)
686