Fix routing: only consider gateway addresses for routes_here (ie: the GW must belong...
[nepi.git] / src / nepi / testbeds / planetlab / node.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 from constants import TESTBED_ID
5 import plcapi
6 import operator
7 import rspawn
8 import time
9 import os
10 import collections
11 import cStringIO
12
13 from nepi.util import server
14
15 class Node(object):
16     BASEFILTERS = {
17         # Map Node attribute to plcapi filter name
18         'hostname' : 'hostname',
19     }
20     
21     TAGFILTERS = {
22         # Map Node attribute to (<tag name>, <plcapi filter expression>)
23         #   There are replacements that are applied with string formatting,
24         #   so '%' has to be escaped as '%%'.
25         'architecture' : ('arch','value'),
26         'operating_system' : ('fcdistro','value'),
27         'pl_distro' : ('pldistro','value'),
28         'min_reliability' : ('reliability%(timeframe)s', ']value'),
29         'max_reliability' : ('reliability%(timeframe)s', '[value'),
30         'min_bandwidth' : ('bw%(timeframe)s', ']value'),
31         'max_bandwidth' : ('bw%(timeframe)s', '[value'),
32     }    
33     
34     DEPENDS_PIDFILE = '/tmp/nepi-depends.pid'
35     DEPENDS_LOGFILE = '/tmp/nepi-depends.log'
36     
37     def __init__(self, api=None):
38         if not api:
39             api = plcapi.PLCAPI()
40         self._api = api
41         
42         # Attributes
43         self.hostname = None
44         self.architecture = None
45         self.operating_system = None
46         self.pl_distro = None
47         self.site = None
48         self.emulation = None
49         self.min_reliability = None
50         self.max_reliability = None
51         self.min_bandwidth = None
52         self.max_bandwidth = None
53         self.min_num_external_ifaces = None
54         self.max_num_external_ifaces = None
55         self.timeframe = 'm'
56         
57         # Applications and routes add requirements to connected nodes
58         self.required_packages = set()
59         self.required_vsys = set()
60         self.pythonpath = []
61         self.env = collections.defaultdict(list)
62         
63         # Testbed-derived attributes
64         self.slicename = None
65         self.ident_path = None
66         self.server_key = None
67         self.home_path = None
68         
69         # Those are filled when an actual node is allocated
70         self._node_id = None
71     
72     @property
73     def _nepi_testbed_environment_setup(self):
74         command = cStringIO.StringIO()
75         command.write('export PYTHONPATH=$PYTHONPATH:%s' % (
76             ':'.join(["${HOME}/"+server.shell_escape(s) for s in self.pythonpath])
77         ))
78         command.write(' ; export PATH=$PATH:%s' % (
79             ':'.join(["${HOME}/"+server.shell_escape(s) for s in self.pythonpath])
80         ))
81         if self.env:
82             for envkey, envvals in self.env.iteritems():
83                 for envval in envvals:
84                     command.write(' ; export %s=%s' % (envkey, envval))
85         return command.getvalue()
86     
87     def build_filters(self, target_filters, filter_map):
88         for attr, tag in filter_map.iteritems():
89             value = getattr(self, attr, None)
90             if value is not None:
91                 target_filters[tag] = value
92         return target_filters
93     
94     @property
95     def applicable_filters(self):
96         has = lambda att : getattr(self,att,None) is not None
97         return (
98             filter(has, self.BASEFILTERS.iterkeys())
99             + filter(has, self.TAGFILTERS.iterkeys())
100         )
101     
102     def find_candidates(self, filter_slice_id=None):
103         fields = ('node_id',)
104         replacements = {'timeframe':self.timeframe}
105         
106         # get initial candidates (no tag filters)
107         basefilters = self.build_filters({}, self.BASEFILTERS)
108         if filter_slice_id:
109             basefilters['|slice_ids'] = (filter_slice_id,)
110         
111         # keyword-only "pseudofilters"
112         extra = {}
113         if self.site:
114             extra['peer'] = self.site
115             
116         candidates = set(map(operator.itemgetter('node_id'), 
117             self._api.GetNodes(filters=basefilters, fields=fields, **extra)))
118         
119         # filter by tag, one tag at a time
120         applicable = self.applicable_filters
121         for tagfilter in self.TAGFILTERS.iteritems():
122             attr, (tagname, expr) = tagfilter
123             
124             # don't bother if there's no filter defined
125             if attr in applicable:
126                 tagfilter = basefilters.copy()
127                 tagfilter['tagname'] = tagname % replacements
128                 tagfilter[expr % replacements] = getattr(self,attr)
129                 tagfilter['node_id'] = list(candidates)
130                 
131                 candidates &= set(map(operator.itemgetter('node_id'),
132                     self._api.GetNodeTags(filters=tagfilter, fields=fields)))
133         
134         # filter by vsys tags - special case since it doesn't follow
135         # the usual semantics
136         if self.required_vsys:
137             newcandidates = collections.defaultdict(set)
138             
139             vsys_tags = self._api.GetNodeTags(
140                 tagname='vsys', 
141                 node_id = list(candidates), 
142                 fields = ['node_id','value'])
143             
144             vsys_tags = map(
145                 operator.itemgetter(['node_id','value']),
146                 vsys_tags)
147             
148             required_vsys = self.required_vsys
149             for node_id, value in vsys_tags:
150                 if value in required_vsys:
151                     newcandidates[value].add(node_id)
152             
153             # take only those that have all the required vsys tags
154             newcandidates = reduce(
155                 lambda accum, new : accum & new,
156                 newcandidates.itervalues(),
157                 candidates)
158         
159         # filter by iface count
160         if self.min_num_external_ifaces is not None or self.max_num_external_ifaces is not None:
161             # fetch interfaces for all, in one go
162             filters = basefilters.copy()
163             filters['node_id'] = list(candidates)
164             ifaces = dict(map(operator.itemgetter('node_id','interface_ids'),
165                 self._api.GetNodes(filters=basefilters, fields=('node_id','interface_ids')) ))
166             
167             # filter candidates by interface count
168             if self.min_num_external_ifaces is not None and self.max_num_external_ifaces is not None:
169                 predicate = ( lambda node_id : 
170                     self.min_num_external_ifaces <= len(ifaces.get(node_id,())) <= self.max_num_external_ifaces )
171             elif self.min_num_external_ifaces is not None:
172                 predicate = ( lambda node_id : 
173                     self.min_num_external_ifaces <= len(ifaces.get(node_id,())) )
174             else:
175                 predicate = ( lambda node_id : 
176                     len(ifaces.get(node_id,())) <= self.max_num_external_ifaces )
177             
178             candidates = set(filter(predicate, candidates))
179             
180         return candidates
181
182     def assign_node_id(self, node_id):
183         self._node_id = node_id
184         self.fetch_node_info()
185     
186     def fetch_node_info(self):
187         info = self._api.GetNodes(self._node_id)[0]
188         tags = dict( (t['tagname'],t['value'])
189                      for t in self._api.GetNodeTags(node_id=self._node_id, fields=('tagname','value')) )
190
191         self.min_num_external_ifaces = None
192         self.max_num_external_ifaces = None
193         self.timeframe = 'm'
194         
195         replacements = {'timeframe':self.timeframe}
196         for attr, tag in self.BASEFILTERS.iteritems():
197             if tag in info:
198                 value = info[tag]
199                 setattr(self, attr, value)
200         for attr, (tag,_) in self.TAGFILTERS.iteritems():
201             tag = tag % replacements
202             if tag in tags:
203                 value = tags[tag]
204                 setattr(self, attr, value)
205         
206         if 'peer_id' in info:
207             self.site = self._api.peer_map[info['peer_id']]
208         
209         if 'interface_ids' in info:
210             self.min_num_external_ifaces = \
211             self.max_num_external_ifaces = len(info['interface_ids'])
212         
213         if 'ssh_rsa_key' in info:
214             self.server_key = info['ssh_rsa_key']
215
216     def validate(self):
217         if self.home_path is None:
218             raise AssertionError, "Misconfigured node: missing home path"
219         if self.ident_path is None or not os.access(self.ident_path, os.R_OK):
220             raise AssertionError, "Misconfigured node: missing slice SSH key"
221         if self.slicename is None:
222             raise AssertionError, "Misconfigured node: unspecified slice"
223
224     def install_dependencies(self):
225         if self.required_packages:
226             # TODO: make dependant on the experiment somehow...
227             pidfile = self.DEPENDS_PIDFILE
228             logfile = self.DEPENDS_LOGFILE
229             
230             # Start process in a "daemonized" way, using nohup and heavy
231             # stdin/out redirection to avoid connection issues
232             (out,err),proc = rspawn.remote_spawn(
233                 "yum -y install %(packages)s" % {
234                     'packages' : ' '.join(self.required_packages),
235                 },
236                 pidfile = pidfile,
237                 stdout = logfile,
238                 stderr = rspawn.STDOUT,
239                 
240                 host = self.hostname,
241                 port = None,
242                 user = self.slicename,
243                 agent = None,
244                 ident_key = self.ident_path,
245                 server_key = self.server_key,
246                 sudo = True
247                 )
248             
249             if proc.wait():
250                 raise RuntimeError, "Failed to set up application: %s %s" % (out,err,)
251     
252     def wait_dependencies(self, pidprobe=1, probe=0.5, pidmax=10, probemax=10):
253         if self.required_packages:
254             pidfile = self.DEPENDS_PIDFILE
255             
256             # get PID
257             pid = ppid = None
258             for probenum in xrange(pidmax):
259                 pidtuple = rspawn.remote_check_pid(
260                     pidfile = pidfile,
261                     host = self.hostname,
262                     port = None,
263                     user = self.slicename,
264                     agent = None,
265                     ident_key = self.ident_path,
266                     server_key = self.server_key
267                     )
268                 if pidtuple:
269                     pid, ppid = pidtuple
270                     break
271                 else:
272                     time.sleep(pidprobe)
273             else:
274                 raise RuntimeError, "Failed to obtain pidfile for dependency installer"
275         
276             # wait for it to finish
277             while rspawn.RUNNING is rspawn.remote_status(
278                     pid, ppid,
279                     host = self.hostname,
280                     port = None,
281                     user = self.slicename,
282                     agent = None,
283                     ident_key = self.ident_path,
284                     server_key = self.server_key
285                     ):
286                 time.sleep(probe)
287                 probe = min(probemax, 1.5*probe)
288         
289     def is_alive(self):
290         # Make sure all the paths are created where 
291         # they have to be created for deployment
292         (out,err),proc = server.popen_ssh_command(
293             "echo 'ALIVE'",
294             host = self.hostname,
295             port = None,
296             user = self.slicename,
297             agent = None,
298             ident_key = self.ident_path,
299             server_key = self.server_key
300             )
301         
302         if proc.wait():
303             return False
304         elif not err and out.strip() == 'ALIVE':
305             return True
306         else:
307             return False
308     
309
310     def configure_routes(self, routes, devs):
311         """
312         Add the specified routes to the node's routing table
313         """
314         rules = []
315         
316         for route in routes:
317             for dev in devs:
318                 if dev.routes_here(route):
319                     # Schedule rule
320                     dest, prefix, nexthop = route
321                     rules.append(
322                         "add %s%s gw %s %s" % (
323                             dest,
324                             (("/%d" % (prefix,)) if prefix and prefix != 32 else ""),
325                             nexthop,
326                             dev.if_name,
327                         )
328                     )
329                     
330                     # Stop checking
331                     break
332             else:
333                 raise RuntimeError, "Route %s cannot be bound to any virtual interface " \
334                     "- PL can only handle rules over virtual interfaces. Candidates are: %s" % (route,devs)
335         
336         (out,err),proc = server.popen_ssh_command(
337             "( sudo -S bash -c 'cat /vsys/vroute.out >&2' & ) ; sudo -S bash -c 'cat > /vsys/vroute.in' ; sleep 0.1" % dict(
338                 home = server.shell_escape(self.home_path)),
339             host = self.hostname,
340             port = None,
341             user = self.slicename,
342             agent = None,
343             ident_key = self.ident_path,
344             server_key = self.server_key,
345             stdin = '\n'.join(rules)
346             )
347         
348         if proc.wait() or err:
349             raise RuntimeError, "Could not set routes (%s) errors: %s%s" % (rules,out,err)
350         
351         
352