c63a9549e78cc418ec65f992753f0d6dbe625106
[nepi.git] / src / nepi / testbeds / planetlab / execute.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 from constants import TESTBED_ID, TESTBED_VERSION
5 from nepi.core import testbed_impl
6 from nepi.util.constants import TIME_NOW
7 from nepi.util.graphtools import mst
8 from nepi.util import ipaddr2
9 from nepi.util import environ
10 from nepi.util.parallel import ParallelRun
11 import sys
12 import os
13 import os.path
14 import time
15 import resourcealloc
16 import collections
17 import operator
18 import functools
19 import socket
20 import struct
21 import tempfile
22 import subprocess
23 import random
24 import shutil
25 import logging
26
class TempKeyError(Exception):
    """Raised when a temporary re-ciphered SSH key pair cannot be created
    (e.g. no passphrase could be obtained for the slice key)."""
    pass
29
class TestbedController(testbed_impl.TestbedController):
    """Testbed controller for the PlanetLab testbed.

    Handles resource discovery, slice provisioning, node wait/blacklisting
    and spanning-tree application deployment over PlanetLab slivers.
    """

    def __init__(self):
        super(TestbedController, self).__init__(TESTBED_ID, TESTBED_VERSION)
        # Filled in later by do_setup from the "homeDirectory" attribute
        self._home_directory = None
        self.slicename = None
        # trace_id -> trace file object; closed on shutdown()
        self._traces = dict()

        # NOTE(review): imports deferred to instantiation time, presumably
        # to avoid import cycles at module load - confirm
        import node, interfaces, application
        self._node = node
        self._interfaces = interfaces
        self._app = application
        
        # node_ids known to be unusable (persisted via _load/_save_blacklist)
        self._blacklist = set()
        # node_ids newly added to the slice this run; they get a longer
        # provisioning timeout in do_wait_nodes
        self._just_provisioned = set()
        
        self._load_blacklist()
        
        self._logger = logging.getLogger('nepi.testbeds.planetlab')
48
    @property
    def home_directory(self):
        # Read-only view of the home directory configured in do_setup
        return self._home_directory
52
    @property
    def plapi(self):
        """PLCAPI client, lazily created and cached on first access.

        Uses authUser/authString credentials when available; otherwise
        falls back to an anonymous connection.
        """
        if not hasattr(self, '_plapi'):
            import plcapi

            if self.authUser:
                self._plapi = plcapi.PLCAPI(
                    username = self.authUser,
                    password = self.authString,
                    hostname = self.plcHost,
                    urlpattern = self.plcUrl
                    )
            else:
                # anonymous access - may not be enough for much
                self._plapi = plcapi.PLCAPI()
        return self._plapi
69
70     @property
71     def slice_id(self):
72         if not hasattr(self, '_slice_id'):
73             slices = self.plapi.GetSlices(self.slicename, fields=('slice_id',))
74             if slices:
75                 self._slice_id = slices[0]['slice_id']
76             else:
77                 # If it wasn't found, don't remember this failure, keep trying
78                 return None
79         return self._slice_id
80     
81     def _load_blacklist(self):
82         blpath = environ.homepath('plblacklist')
83         
84         try:
85             bl = open(blpath, "r")
86         except:
87             self._blacklist = set()
88             return
89             
90         try:
91             self._blacklist = set(
92                 map(int,
93                     map(str.strip, bl.readlines())
94                 )
95             )
96         finally:
97             bl.close()
98     
99     def _save_blacklist(self):
100         blpath = environ.homepath('plblacklist')
101         bl = open(blpath, "w")
102         try:
103             bl.writelines(
104                 map('%s\n'.__mod__, self._blacklist))
105         finally:
106             bl.close()
107     
    def do_setup(self):
        """Copy testbed attributes into instance fields and set the log level,
        then delegate to the base implementation."""
        self._home_directory = self._attributes.\
            get_attribute_value("homeDirectory")
        self.slicename = self._attributes.\
            get_attribute_value("slice")
        self.authUser = self._attributes.\
            get_attribute_value("authUser")
        self.authString = self._attributes.\
            get_attribute_value("authPass")
        self.sliceSSHKey = self._attributes.\
            get_attribute_value("sliceSSHKey")
        # Asked for lazily (see _make_temp_private_key), never persisted
        self.sliceSSHKeyPass = None
        self.plcHost = self._attributes.\
            get_attribute_value("plcHost")
        self.plcUrl = self._attributes.\
            get_attribute_value("plcUrl")
        self.logLevel = self._attributes.\
            get_attribute_value("plLogLevel")
        self.tapPortBase = self._attributes.\
            get_attribute_value("tapPortBase")
        
        # plLogLevel must name a level of the logging module (e.g. "DEBUG")
        self._logger.setLevel(getattr(logging,self.logLevel))
        
        super(TestbedController, self).do_setup()
132
    def do_post_asynclaunch(self, guid):
        """Block until the asynchronously-launched dependency for guid
        finishes its setup (no-op for non-Dependency elements)."""
        # Dependencies were launched asynchronously,
        # so wait for them
        dep = self._elements[guid]
        if isinstance(dep, self._app.Dependency):
            dep.async_setup_wait()
    
    # Two-phase configuration for asynchronous launch:
    # the same wait is applied after both preconfigure and configure steps
    do_poststep_preconfigure = staticmethod(do_post_asynclaunch)
    do_poststep_configure = staticmethod(do_post_asynclaunch)
143
    def do_preconfigure(self):
        """Discover and provision nodes, retrying the whole cycle whenever
        provisioned nodes turn out unresponsive, then plan application
        deployment and run the base preconfiguration."""
        while True:
            # Perform resource discovery if we don't have
            # specific resources assigned yet
            self.do_resource_discovery()

            # Create PlanetLab slivers
            self.do_provisioning()
            
            try:
                # Wait for provisioning
                self.do_wait_nodes()
                
                # Okkey...
                break
            except self._node.UnresponsiveNodeError:
                # Oh... retry... (do_wait_nodes already blacklisted and
                # unassigned the dead nodes)
                pass
        
        # Plan application deployment
        self.do_spanning_deployment_plan()

        # Configure elements per XML data
        super(TestbedController, self).do_preconfigure()
168
169     def do_resource_discovery(self):
170         to_provision = self._to_provision = set()
171         
172         reserved = set(self._blacklist)
173         for guid, node in self._elements.iteritems():
174             if isinstance(node, self._node.Node) and node._node_id is not None:
175                 reserved.add(node._node_id)
176         
177         # Initial algo:
178         #   look for perfectly defined nodes
179         #   (ie: those with only one candidate)
180         for guid, node in self._elements.iteritems():
181             if isinstance(node, self._node.Node) and node._node_id is None:
182                 # Try existing nodes first
183                 # If we have only one candidate, simply use it
184                 candidates = node.find_candidates(
185                     filter_slice_id = self.slice_id)
186                 candidates -= reserved
187                 if len(candidates) == 1:
188                     node_id = iter(candidates).next()
189                     node.assign_node_id(node_id)
190                     reserved.add(node_id)
191                 elif not candidates:
192                     # Try again including unassigned nodes
193                     candidates = node.find_candidates()
194                     candidates -= reserved
195                     if len(candidates) > 1:
196                         continue
197                     if len(candidates) == 1:
198                         node_id = iter(candidates).next()
199                         node.assign_node_id(node_id)
200                         to_provision.add(node_id)
201                         reserved.add(node_id)
202                     elif not candidates:
203                         raise RuntimeError, "Cannot assign resources for node %s, no candidates sith %s" % (guid,
204                             node.make_filter_description())
205         
206         # Now do the backtracking search for a suitable solution
207         # First with existing slice nodes
208         reqs = []
209         nodes = []
210         for guid, node in self._elements.iteritems():
211             if isinstance(node, self._node.Node) and node._node_id is None:
212                 # Try existing nodes first
213                 # If we have only one candidate, simply use it
214                 candidates = node.find_candidates(
215                     filter_slice_id = self.slice_id)
216                 candidates -= reserved
217                 reqs.append(candidates)
218                 nodes.append(node)
219         
220         if nodes and reqs:
221             try:
222                 solution = resourcealloc.alloc(reqs)
223             except resourcealloc.ResourceAllocationError:
224                 # Failed, try again with all nodes
225                 reqs = []
226                 for node in nodes:
227                     candidates = node.find_candidates()
228                     candidates -= reserved
229                     reqs.append(candidates)
230                 
231                 solution = resourcealloc.alloc(reqs)
232                 to_provision.update(solution)
233             
234             # Do assign nodes
235             for node, node_id in zip(nodes, solution):
236                 node.assign_node_id(node_id)
237
    def do_provisioning(self):
        """Add the node ids selected for provisioning to the PlanetLab slice.

        Moves self._to_provision into self._just_provisioned so that
        do_wait_nodes grants freshly created slivers a longer timeout.
        """
        if self._to_provision:
            # Add new nodes to the slice
            cur_nodes = self.plapi.GetSlices(self.slicename, ['node_ids'])[0]['node_ids']
            new_nodes = list(set(cur_nodes) | self._to_provision)
            self.plapi.UpdateSlice(self.slicename, nodes=new_nodes)

        # cleanup
        self._just_provisioned = self._to_provision
        del self._to_provision
248     
    def do_wait_nodes(self):
        """Configure Node elements and wait for their slivers to come up.

        Freshly provisioned nodes get a 20-minute grace period, nodes that
        were already in the slice only 60 seconds. On timeout, every node
        that fails an is_alive() check is blacklisted and unassigned, the
        blacklist is saved (best-effort), and UnresponsiveNodeError is
        re-raised so do_preconfigure can retry the whole cycle.
        """
        for guid, node in self._elements.iteritems():
            if isinstance(node, self._node.Node):
                # Just inject configuration stuff
                node.home_path = "nepi-node-%s" % (guid,)
                node.ident_path = self.sliceSSHKey
                node.slicename = self.slicename
            
                # Show the magic
                self._logger.info("PlanetLab Node %s configured at %s", guid, node.hostname)
            
        try:
            for guid, node in self._elements.iteritems():
                if isinstance(node, self._node.Node):
                    self._logger.info("Waiting for Node %s configured at %s", guid, node.hostname)
                    
                    # New slivers need much more time to come up
                    node.wait_provisioning(
                        (20*60 if node._node_id in self._just_provisioned else 60)
                    )
                    
                    self._logger.info("READY Node %s at %s", guid, node.hostname)
                    
                    # Prepare dependency installer now
                    node.prepare_dependencies()
        except self._node.UnresponsiveNodeError:
            # 'node' here is the one whose wait_provisioning timed out
            self._logger.warn("UNRESPONSIVE Node %s", node.hostname)
            
            # Mark all dead nodes (which are unresponsive) on the blacklist
            # and re-raise
            for guid, node in self._elements.iteritems():
                if isinstance(node, self._node.Node):
                    if not node.is_alive():
                        self._logger.warn("Blacklisting %s for unresponsiveness", node.hostname)
                        self._blacklist.add(node._node_id)
                        node.unassign_node()
            
            try:
                self._save_blacklist()
            except:
                # not important... persistence is best-effort only
                import traceback
                traceback.print_exc()
            
            raise
294     
    def do_spanning_deployment_plan(self):
        """Group dependencies with identical build requirements and plan a
        spanning-tree deployment for each multi-node group.

        For every group deployed on more than one node, a minimum spanning
        tree (keyed by numeric IP distance, max branching 2, deterministic
        root = lowest hostname) decides which node builds locally and which
        fetch from a "master", authenticating with a temporary re-ciphered
        SSH key pair. Groups for which no temporary key can be made
        (TempKeyError) are skipped and deploy independently.
        """
        # Create application groups by collecting all applications
        # based on their hash - the hash should contain everything that
        # defines them and the platform they're built
        
        def dephash(app):
            return (
                frozenset((app.depends or "").split(' ')),
                frozenset((app.sources or "").split(' ')),
                app.build,
                app.install,
                app.node.architecture,
                app.node.operatingSystem,
                app.node.pl_distro,
            )
        
        depgroups = collections.defaultdict(list)
        
        for element in self._elements.itervalues():
            if isinstance(element, self._app.Dependency):
                depgroups[dephash(element)].append(element)
            elif isinstance(element, self._node.Node):
                # Nodes carry their own yum dependency pseudo-element
                deps = element._yum_dependencies
                if deps:
                    depgroups[dephash(deps)].append(deps)
        
        # Set up spanning deployment for those applications that
        # have been deployed in several nodes.
        for dh, group in depgroups.iteritems():
            if len(group) > 1:
                # Pick root (deterministically)
                root = min(group, key=lambda app:app.node.hostname)
                
                # Obtain all IPs in numeric format
                # (which means faster distance computations)
                for dep in group:
                    dep._ip = socket.gethostbyname(dep.node.hostname)
                    dep._ip_n = struct.unpack('!L', socket.inet_aton(dep._ip))[0]
                
                # Compute plan
                # NOTE: the plan is an iterator
                plan = mst.mst(
                    group,
                    lambda a,b : ipaddr2.ipdistn(a._ip_n, b._ip_n),
                    root = root,
                    maxbranching = 2)
                
                # Re-sign private key
                try:
                    tempprk, temppuk, tmppass = self._make_temp_private_key()
                except TempKeyError:
                    continue
                
                # Set up slaves
                plan = list(plan)
                for slave, master in plan:
                    slave.set_master(master)
                    slave.install_keys(tempprk, temppuk, tmppass)
                    
        # We don't need the user's passphrase anymore
        self.sliceSSHKeyPass = None
356     
    def _make_temp_private_key(self):
        """Create a temporary copy of the slice SSH key pair, re-ciphered
        with a random 256-bit passphrase, for automated chain deployment.

        The original key's passphrase is obtained through SSH_ASKPASS if
        not already known.

        Returns:
            (private_key, public_key, passphrase) - the key files are
            NamedTemporaryFile objects (auto-deleted when closed).

        Raises:
            TempKeyError: if no passphrase could be obtained.
            RuntimeError: if ssh-keygen reports an error.
        """
        # Get the user's key's passphrase
        if not self.sliceSSHKeyPass:
            if 'SSH_ASKPASS' in os.environ:
                proc = subprocess.Popen(
                    [ os.environ['SSH_ASKPASS'],
                      "Please type the passphrase for the %s SSH identity file. "
                      "The passphrase will be used to re-cipher the identity file with "
                      "a random 256-bit key for automated chain deployment on the "
                      "%s PlanetLab slice" % ( 
                        os.path.basename(self.sliceSSHKey), 
                        self.slicename
                    ) ],
                    # NOTE(review): this /dev/null fd is never explicitly closed
                    stdin = open("/dev/null"),
                    stdout = subprocess.PIPE,
                    stderr = subprocess.PIPE)
                out,err = proc.communicate()
                self.sliceSSHKeyPass = out.strip()
        
        if not self.sliceSSHKeyPass:
            raise TempKeyError
        
        # Create temporary key files
        prk = tempfile.NamedTemporaryFile(
            dir = self.root_directory,
            prefix = "pl_deploy_tmpk_",
            suffix = "")

        puk = tempfile.NamedTemporaryFile(
            dir = self.root_directory,
            prefix = "pl_deploy_tmpk_",
            suffix = ".pub")
            
        # Create secure 256-bits temporary passphrase
        # (32 random bytes from SystemRandom, hex-encoded)
        passphrase = ''.join(map(chr,[rng.randint(0,255) 
                                      for rng in (random.SystemRandom(),)
                                      for i in xrange(32)] )).encode("hex")
                
        # Copy keys (content and permission bits)
        oprk = open(self.sliceSSHKey, "rb")
        opuk = open(self.sliceSSHKey+".pub", "rb")
        shutil.copymode(oprk.name, prk.name)
        shutil.copymode(opuk.name, puk.name)
        shutil.copyfileobj(oprk, prk)
        shutil.copyfileobj(opuk, puk)
        prk.flush()
        puk.flush()
        oprk.close()
        opuk.close()
        
        # A descriptive comment
        comment = "%s#NEPI_INTERNAL@%s" % (self.authUser, self.slicename)
        
        # Recipher keys: change passphrase (-P old, -N new) in place (-f)
        proc = subprocess.Popen(
            ["ssh-keygen", "-p",
             "-f", prk.name,
             "-P", self.sliceSSHKeyPass,
             "-N", passphrase,
             "-C", comment ],
            stdout = subprocess.PIPE,
            stderr = subprocess.PIPE,
            stdin = subprocess.PIPE
        )
        out, err = proc.communicate()
        
        # NOTE(review): any stderr output is treated as fatal, even warnings
        if err:
            raise RuntimeError, "Problem generating keys: \n%s\n%r" % (
                out, err)
        
        prk.seek(0)
        puk.seek(0)
        
        # Change comment on public key (ssh-keygen -p does not rewrite
        # the .pub copy, so patch its last field manually)
        puklines = puk.readlines()
        puklines[0] = puklines[0].split(' ')
        puklines[0][-1] = comment+'\n'
        puklines[0] = ' '.join(puklines[0])
        puk.seek(0)
        puk.truncate()
        puk.writelines(puklines)
        del puklines
        puk.flush()
        
        return prk, puk, passphrase
442     
443     def set(self, guid, name, value, time = TIME_NOW):
444         super(TestbedController, self).set(guid, name, value, time)
445         # TODO: take on account schedule time for the task
446         element = self._elements[guid]
447         if element:
448             setattr(element, name, value)
449
450             if hasattr(element, 'refresh'):
451                 # invoke attribute refresh hook
452                 element.refresh()
453
    def get(self, guid, name, time = TIME_NOW):
        """Return attribute name of element guid, preferring the live
        element's attribute over the stored value."""
        value = super(TestbedController, self).get(guid, name, time)
        # TODO: take on account schedule time for the task
        # NOTE(review): factory_id/factory are never used below; the
        # self._create[guid] lookup may double as guid validation (it
        # raises KeyError for unknown guids) - confirm before removing
        factory_id = self._create[guid]
        factory = self._factories[factory_id]
        element = self._elements.get(guid)
        try:
            return getattr(element, name)
        except (KeyError, AttributeError):
            # element missing or attribute not set: fall back to stored value
            return value
464
465     def get_address(self, guid, index, attribute='Address'):
466         index = int(index)
467
468         # try the real stuff
469         iface = self._elements.get(guid)
470         if iface and index == 0:
471             if attribute == 'Address':
472                 return iface.address
473             elif attribute == 'NetPrefix':
474                 return iface.netprefix
475             elif attribute == 'Broadcast':
476                 return iface.broadcast
477
478         # if all else fails, query box
479         return super(TestbedController, self).get_address(guid, index, attribute)
480
    def action(self, time, guid, action):
        # Scheduled actions are not supported by this testbed
        raise NotImplementedError
483
484     def shutdown(self):
485         for trace in self._traces.itervalues():
486             trace.close()
487         
488         runner = ParallelRun(16)
489         runner.start()
490         for element in self._elements.itervalues():
491             # invoke cleanup hooks
492             if hasattr(element, 'cleanup'):
493                 runner.put(element.cleanup)
494         runner.join()
495         
496         runner = ParallelRun(16)
497         runner.start()
498         for element in self._elements.itervalues():
499             # invoke destroy hooks
500             if hasattr(element, 'destroy'):
501                 runner.put(element.destroy)
502         runner.join()
503         
504         self._elements.clear()
505         self._traces.clear()
506
507     def trace(self, guid, trace_id, attribute='value'):
508         app = self._elements[guid]
509
510         if attribute == 'value':
511             path = app.sync_trace(self.home_directory, trace_id)
512             if path:
513                 fd = open(path, "r")
514                 content = fd.read()
515                 fd.close()
516             else:
517                 content = None
518         elif attribute == 'path':
519             content = app.remote_trace_path(trace_id)
520         else:
521             content = None
522         return content
523
    def follow_trace(self, trace_id, trace):
        # Register a trace file object so shutdown() can close it later
        self._traces[trace_id] = trace
526     
527     def _make_generic(self, parameters, kind):
528         app = kind(self.plapi)
529
530         # Note: there is 1-to-1 correspondence between attribute names
531         #   If that changes, this has to change as well
532         for attr,val in parameters.iteritems():
533             setattr(app, attr, val)
534
535         return app
536
    # --- Element factories --------------------------------------------
    # Each factory forwards the element class to _make_generic, which
    # instantiates it with the PLC API handle and copies the parameters in.

    def _make_node(self, parameters):
        return self._make_generic(parameters, self._node.Node)

    def _make_node_iface(self, parameters):
        return self._make_generic(parameters, self._interfaces.NodeIface)

    def _make_tun_iface(self, parameters):
        return self._make_generic(parameters, self._interfaces.TunIface)

    def _make_tap_iface(self, parameters):
        return self._make_generic(parameters, self._interfaces.TapIface)

    def _make_netpipe(self, parameters):
        return self._make_generic(parameters, self._interfaces.NetPipe)

    def _make_internet(self, parameters):
        return self._make_generic(parameters, self._interfaces.Internet)

    def _make_application(self, parameters):
        return self._make_generic(parameters, self._app.Application)

    def _make_dependency(self, parameters):
        return self._make_generic(parameters, self._app.Dependency)

    def _make_nepi_dependency(self, parameters):
        return self._make_generic(parameters, self._app.NepiDependency)

    def _make_ns3_dependency(self, parameters):
        return self._make_generic(parameters, self._app.NS3Dependency)
566