465a557e4980b111f582d3d11a562438f6e45759
[nepi.git] / src / nepi / core / testbed_impl.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 from nepi.core import execute
5 from nepi.core.metadata import Metadata, Parallel
6 from nepi.util import validation
7 from nepi.util.constants import TIME_NOW, \
8         ApplicationStatus as AS, \
9         TestbedStatus as TS, \
10         CONNECTION_DELAY
11 from nepi.util.parallel import ParallelRun
12
13 import collections
14 import copy
15 import logging
16
class TestbedController(execute.TestbedController):
    """Metadata-driven testbed controller.

    Experiment construction requests are recorded (and validated against
    the factories and testbed attributes built from ``Metadata``) by the
    ``defer_*`` methods, and later replayed by the ``do_*`` methods, which
    invoke the factory-provided functions in the orders the metadata
    declares.
    """

    def __init__(self, testbed_id, testbed_version):
        super(TestbedController, self).__init__(testbed_id, testbed_version)
        self._status = TS.STATUS_ZERO
        # testbed attributes for validation
        self._attributes = None
        # element factories for validation: factory_id -> factory
        self._factories = dict()

        # experiment construction instructions
        # guid -> factory_id
        self._create = dict()
        # guid -> {box attribute name: value}
        self._create_set = dict()
        # guid -> {factory attribute name: value}
        self._factory_set = dict()
        # guid -> {connector name: {other guid: other connector name}};
        # every pair is stored in both directions (see defer_connect)
        self._connect = dict()
        # guid -> {connector name: (cross_guid, cross_testbed_guid,
        #     cross_testbed_id, cross_factory_id, cross_connector_type_name)}
        self._cross_connect = dict()
        # guid -> [trace name, ...]
        self._add_trace = dict()
        # guid -> [(address, netprefix, broadcast), ...]
        self._add_address = dict()
        # guid -> [(destination, netprefix, nexthop, metric), ...]
        self._add_route = dict()
        # testbed attribute name -> value
        self._configure = dict()

        # log of set operations: guid -> {time: {name: value}}
        self._setlog = dict()
        # last set operations: guid -> {name: value}
        self._set = dict()

        # testbed element instances (not populated by this class itself)
        self._elements = dict()

        self._metadata = Metadata(self._testbed_id)
        if self._metadata.testbed_version != testbed_version:
            raise RuntimeError("Bad testbed version on testbed %s. Asked for %s, got %s" % \
                    (testbed_id, testbed_version, self._metadata.testbed_version))
        for factory in self._metadata.build_factories():
            self._factories[factory.factory_id] = factory
        self._attributes = self._metadata.testbed_attributes()
        # set from the "rootDirectory" testbed attribute in do_setup()
        self._root_directory = None

        # Logging
        self._logger = logging.getLogger("nepi.core.testbed_impl")

    @property
    def root_directory(self):
        """Value of the "rootDirectory" testbed attribute (None before do_setup)."""
        return self._root_directory

    @property
    def guids(self):
        """GUIDs of every element registered through defer_create."""
        return self._create.keys()

    @property
    def elements(self):
        """Dictionary of testbed element instances."""
        return self._elements

    def defer_configure(self, name, value):
        """Validate and set testbed attribute ``name`` to ``value``.

        Raises AttributeError if the attribute or the value is invalid.
        """
        self._validate_testbed_attribute(name)
        self._validate_testbed_value(name, value)
        self._attributes.set_attribute_value(name, value)
        self._configure[name] = value

    def defer_create(self, guid, factory_id):
        """Register element ``guid`` to be created by factory ``factory_id``.

        Raises AttributeError if the factory is unknown or the guid is
        already registered.
        """
        self._validate_factory_id(factory_id)
        self._validate_not_guid(guid)
        self._create[guid] = factory_id

    def defer_create_set(self, guid, name, value):
        """Record box attribute ``name`` = ``value`` for element ``guid``.

        These values are applied with set() right after creation (see
        do_create) and are returned by get() when no later set() exists.
        """
        self._validate_guid(guid)
        self._validate_box_attribute(guid, name)
        self._validate_box_value(guid, name, value)
        self._create_set.setdefault(guid, dict())[name] = value

    def defer_factory_set(self, guid, name, value):
        """Record factory attribute ``name`` = ``value`` for element ``guid``."""
        self._validate_guid(guid)
        self._validate_factory_attribute(guid, name)
        self._validate_factory_value(guid, name, value)
        self._factory_set.setdefault(guid, dict())[name] = value

    def defer_connect(self, guid1, connector_type_name1, guid2, 
            connector_type_name2):
        """Record a connection between two elements of this testbed.

        The connection is stored in both directions in ``self._connect``.
        """
        self._validate_guid(guid1)
        self._validate_guid(guid2)
        factory1 = self._get_factory(guid1)
        factory_id2 = self._create[guid2]
        connector_type = factory1.connector_type(connector_type_name1)
        # NOTE(review): the return value of can_connect is ignored here; if
        # it reports (rather than raises on) an incompatible connector, the
        # connection is accepted anyway -- confirm against connector.py.
        connector_type.can_connect(self._testbed_id, factory_id2, 
                connector_type_name2, False)
        self._validate_connection(guid1, connector_type_name1, guid2, 
            connector_type_name2)

        self._connect.setdefault(guid1, dict()).setdefault(
                connector_type_name1, dict())[guid2] = connector_type_name2
        self._connect.setdefault(guid2, dict()).setdefault(
                connector_type_name2, dict())[guid1] = connector_type_name1

    def defer_cross_connect(self, guid, connector_type_name, cross_guid, 
            cross_testbed_guid, cross_testbed_id, cross_factory_id, 
            cross_connector_type_name):
        """Record a connection from element ``guid`` to an element of
        another testbed.

        Only one cross connection is kept per (guid, connector): a later
        call for the same pair replaces the earlier one.
        """
        self._validate_guid(guid)
        factory = self._get_factory(guid)
        connector_type = factory.connector_type(connector_type_name)
        connector_type.can_connect(cross_testbed_id, cross_factory_id, 
                cross_connector_type_name, True)
        self._validate_connection(guid, connector_type_name, cross_guid, 
            cross_connector_type_name)

        self._cross_connect.setdefault(guid, dict())[connector_type_name] = \
                (cross_guid, cross_testbed_guid, cross_testbed_id, 
                cross_factory_id, cross_connector_type_name)

    def defer_add_trace(self, guid, trace_name):
        """Request trace ``trace_name`` for element ``guid``."""
        self._validate_guid(guid)
        self._validate_trace(guid, trace_name)
        self._add_trace.setdefault(guid, list()).append(trace_name)

    def defer_add_address(self, guid, address, netprefix, broadcast):
        """Record an address for element ``guid``.

        Raises RuntimeError if the element type does not support addresses
        or already holds "maxAddresses" addresses.
        """
        self._validate_guid(guid)
        self._validate_allow_addresses(guid)
        self._add_address.setdefault(guid, list()).append(
                (address, netprefix, broadcast))

    def defer_add_route(self, guid, destination, netprefix, nexthop, metric = 0):
        """Record a route for element ``guid``.

        Raises RuntimeError if the element type does not support routes.
        """
        self._validate_guid(guid)
        self._validate_allow_routes(guid)
        self._add_route.setdefault(guid, list()).append(
                (destination, netprefix, nexthop, metric))

    def do_setup(self):
        """Read the root directory attribute and enter STATUS_SETUP."""
        self._root_directory = self._attributes.\
            get_attribute_value("rootDirectory")
        self._status = TS.STATUS_SETUP

    def do_create(self):
        """Run every factory's create function (in metadata create order),
        applying the values recorded by defer_create_set right after each
        element is created."""
        def set_params(self, guid):
            parameters = self._get_parameters(guid)
            for name, value in parameters.iteritems():
                self.set(guid, name, value)

        self._do_in_factory_order(
            'create_function',
            self._metadata.create_order,
            postaction = set_params )
        self._status = TS.STATUS_CREATED

    def _do_connect(self, init = True):
        """Run the connect code of every recorded local connection.

        ``init`` selects the connector's init code (True) or completion
        code (False). A connect function returning CONNECTION_DELAY is
        retried on the next pass.
        NOTE(review): if a connect function keeps returning
        CONNECTION_DELAY this loops forever.
        """
        unconnected = copy.deepcopy(self._connect)

        while unconnected:
            # items() returns list copies (Python 2), so deleting entries
            # from ``unconnected`` while looping is safe
            for guid1, connections in unconnected.items():
                factory1 = self._get_factory(guid1)
                for connector_type_name1, connections2 in connections.items():
                    connector_type1 = factory1.connector_type(connector_type_name1)
                    for guid2, connector_type_name2 in connections2.items():
                        factory_id2 = self._create[guid2]
                        # Connections are executed in a "From -> To" direction only
                        # This explicitly ignores the "To -> From" (mirror) 
                        # connections of every connection pair.
                        if init:
                            connect_code = connector_type1.connect_to_init_code(
                                    self._testbed_id, factory_id2, 
                                    connector_type_name2,
                                    False)
                        else:
                            connect_code = connector_type1.connect_to_compl_code(
                                    self._testbed_id, factory_id2, 
                                    connector_type_name2,
                                    False)
                        delay = None
                        if connect_code:
                            delay = connect_code(self, guid1, guid2)

                        if delay is not CONNECTION_DELAY:
                            del unconnected[guid1][connector_type_name1][guid2]
                    if not unconnected[guid1][connector_type_name1]:
                        del unconnected[guid1][connector_type_name1]
                if not unconnected[guid1]:
                    del unconnected[guid1]

    def do_connect_init(self):
        """Run the init phase of all local connections."""
        self._do_connect()

    def do_connect_compl(self):
        """Run the completion phase of all local connections."""
        self._do_connect(init = False)
        self._status = TS.STATUS_CONNECTED

    def _do_in_factory_order(self, action, order, postaction = None, poststep = None):
        """Invoke a factory function on every element, factory by factory.

        :param action: name of the factory attribute holding the function
            to call as ``function(self, guid)``, e.g. 'create_function'.
            Factories whose attribute is empty are skipped entirely.
        :param order: iterable of factory ids, or ``Parallel`` wrappers
            requesting that the factory's elements be processed with a
            ``ParallelRun`` of ``maxthreads`` workers.
        :param postaction: optional ``postaction(self, guid)`` called right
            after the action for that same guid.
        :param poststep: optional ``poststep(self, guid)`` called for each
            guid of a factory only after the action (and postaction) has
            finished for all of that factory's elements.
        """
        logger = self._logger

        # order guids (elements) according to factory_id
        guids = collections.defaultdict(list)
        for guid, factory_id in self._create.iteritems():
            guids[factory_id].append(guid)

        # configure elements following the factory_id order
        for factory_id in order:
            parallel = isinstance(factory_id, Parallel)
            maxthreads = None
            if parallel:
                maxthreads = factory_id.maxthreads
                factory_id = factory_id.factory

            # omit the factories that have no element to create
            if factory_id not in guids:
                continue

            # configure action
            factory = self._factories[factory_id]
            action_function = getattr(factory, action)
            if not action_function:
                continue

            # bind action_function now so the closure does not depend on
            # later rebinding of the loop variables
            def perform_action(guid, action_function = action_function):
                action_function(self, guid)
                if postaction:
                    postaction(self, guid)

            factory_guids = guids[factory_id]

            if not parallel:
                for guid in factory_guids:
                    logger.debug("Performing %s on %s", action, guid)
                    perform_action(guid)
                # post hook
                if poststep:
                    for guid in factory_guids:
                        logger.debug("Performing post-%s on %s", action, guid)
                        poststep(self, guid)
                continue

            # perform the action on all elements in parallel
            logger.debug("Starting parallel %s", action)
            runner = ParallelRun(maxthreads)
            runner.start()
            for guid in factory_guids:
                logger.debug("Scheduling %s on %s", action, guid)
                runner.put(perform_action, guid)
            # Wait for every action before scheduling post-steps: sharing
            # one queue would let a worker run the post-step of an element
            # whose action is still executing on another worker.
            runner.join()

            # post hook, on a fresh runner
            if poststep:
                runner = ParallelRun(maxthreads)
                runner.start()
                for guid in factory_guids:
                    logger.debug("Scheduling post-%s on %s", action, guid)
                    runner.put(poststep, self, guid)
                runner.join()
            logger.debug("Finished parallel %s", action)

    @staticmethod
    def do_poststep_preconfigure(self, guid):
        # dummy hook for implementations interested in
        # two-phase configuration
        pass

    def do_preconfigure(self):
        """Run every factory's preconfigure function, then the
        do_poststep_preconfigure hook, in metadata preconfigure order."""
        self._do_in_factory_order(
            'preconfigure_function',
            self._metadata.preconfigure_order,
            poststep = self.do_poststep_preconfigure )

    @staticmethod
    def do_poststep_configure(self, guid):
        # dummy hook for implementations interested in
        # two-phase configuration
        pass

    def do_configure(self):
        """Run every factory's configure function, then the
        do_poststep_configure hook, in metadata configure order."""
        self._do_in_factory_order(
            'configure_function',
            self._metadata.configure_order,
            poststep = self.do_poststep_configure )
        self._status = TS.STATUS_CONFIGURED

    def do_prestart(self):
        """Run every factory's prestart function in metadata prestart order."""
        self._do_in_factory_order(
            'prestart_function',
            self._metadata.prestart_order )

    def _do_cross_connect(self, cross_data, init = True):
        """Run the connect code of every recorded cross-testbed connection.

        :param cross_data: ``cross_data[cross_testbed_guid][cross_guid]`` is
            handed to the connect code for the matching remote element.
        :param init: select init code (True) or completion code (False).
        """
        for guid, cross_connections in self._cross_connect.iteritems():
            factory = self._get_factory(guid)
            for connector_type_name, cross_connection in \
                    cross_connections.iteritems():
                connector_type = factory.connector_type(connector_type_name)
                (cross_guid, cross_testbed_guid, cross_testbed_id,
                    cross_factory_id, cross_connector_type_name) = cross_connection
                if init:
                    connect_code = connector_type.connect_to_init_code(
                        cross_testbed_id, cross_factory_id, 
                        cross_connector_type_name,
                        True)
                else:
                    connect_code = connector_type.connect_to_compl_code(
                        cross_testbed_id, cross_factory_id, 
                        cross_connector_type_name,
                        True)
                if connect_code:
                    elem_cross_data = cross_data[cross_testbed_guid][cross_guid]
                    connect_code(self, guid, elem_cross_data)

    def do_cross_connect_init(self, cross_data):
        """Run the init phase of all cross-testbed connections."""
        self._do_cross_connect(cross_data)

    def do_cross_connect_compl(self, cross_data):
        """Run the completion phase of all cross-testbed connections."""
        self._do_cross_connect(cross_data, init = False)
        self._status = TS.STATUS_CROSS_CONNECTED

    def set(self, guid, name, value, time = TIME_NOW):
        """Set box attribute ``name`` of element ``guid`` to ``value``.

        The value is stored as the latest value and also logged under
        ``time`` in ``self._setlog``.
        """
        self._validate_guid(guid)
        self._validate_box_attribute(guid, name)
        self._validate_box_value(guid, name, value)
        self._validate_modify_box_value(guid, name)
        if guid not in self._set:
            self._set[guid] = dict()
            self._setlog[guid] = dict()
        self._setlog[guid].setdefault(time, dict())[name] = value
        self._set[guid][name] = value

    def get(self, guid, name, time = TIME_NOW):
        """
        gets an attribute from box definitions if available.

        Looks up, in order: the last value given to set(), the value given
        to defer_create_set(), and finally the factory default value.

        Raises RuntimeError if the GUID wasn't created through the
        defer_create interface, and AttributeError if the attribute isn't
        available for the element's factory.
        """
        self._validate_guid(guid)
        self._validate_box_attribute(guid, name)
        if guid in self._set and name in self._set[guid]:
            return self._set[guid][name]
        if guid in self._create_set and name in self._create_set[guid]:
            return self._create_set[guid][name]
        # if nothing else found, returns the factory default value
        factory = self._get_factory(guid)
        return factory.box_attributes.get_attribute_value(name)

    def get_route(self, guid, index, attribute):
        """
        returns information given to defer_add_route.

        Valid attributes are 'Destination', 'NetPrefix', 'NextHop' and
        'Metric'.

        Raises AttributeError if an invalid attribute is requested
            or if the indexed routing rule does not exist.

        Raises KeyError if the GUID has not been seen by
            defer_add_route
        """
        # same order as the tuples stored by defer_add_route
        ATTRIBUTES = ['Destination', 'NetPrefix', 'NextHop', 'Metric']

        if attribute not in ATTRIBUTES:
            raise AttributeError("Attribute %r invalid for routes of %r" % (attribute, guid))

        attribute_index = ATTRIBUTES.index(attribute)

        routes = self._add_route.get(guid)
        if not routes:
            raise KeyError("GUID %r not found in %s" % (guid, self._testbed_id))

        index = int(index)
        if not (0 <= index < len(routes)):
            raise AttributeError("GUID %r at %s does not have a routing entry #%s" % (
                guid, self._testbed_id, index))

        return routes[index][attribute_index]

    def get_address(self, guid, index, attribute='Address'):
        """
        returns information given to defer_add_address

        Valid attributes are 'Address', 'NetPrefix' and 'Broadcast'.

        Raises AttributeError if an invalid attribute is requested
            or if the indexed address does not exist.

        Raises KeyError if the GUID has not been seen by
            defer_add_address
        """
        # same order as the tuples stored by defer_add_address
        ATTRIBUTES = ['Address', 'NetPrefix', 'Broadcast']

        if attribute not in ATTRIBUTES:
            raise AttributeError("Attribute %r invalid for addresses of %r" % (attribute, guid))

        attribute_index = ATTRIBUTES.index(attribute)

        addresses = self._add_address.get(guid)
        if not addresses:
            raise KeyError("GUID %r not found in %s" % (guid, self._testbed_id))

        index = int(index)
        if not (0 <= index < len(addresses)):
            raise AttributeError("GUID %r at %s does not have an address #%s" % (
                guid, self._testbed_id, index))

        return addresses[index][attribute_index]

    def get_attribute_list(self, guid, filter_flags = None, exclude = False):
        """Return the box attribute list of element ``guid``'s factory."""
        factory = self._get_factory(guid)
        return factory.box_attributes.get_attribute_list(filter_flags, exclude)

    def get_factory_id(self, guid):
        """Return the factory id of element ``guid``."""
        factory = self._get_factory(guid)
        return factory.factory_id

    def start(self, time = TIME_NOW):
        """Run every factory's start function in metadata start order."""
        self._do_in_factory_order(
            'start_function',
            self._metadata.start_order )
        self._status = TS.STATUS_STARTED

    #action: NotImplementedError

    def stop(self, time = TIME_NOW):
        """Run every factory's stop function in reverse start order."""
        self._do_in_factory_order(
            'stop_function',
            reversed(self._metadata.start_order) )
        self._status = TS.STATUS_STOPPED

    def status(self, guid = None):
        """Return the testbed status, or the status of element ``guid``.

        For an element, the factory's status_function is used when present;
        otherwise AS.STATUS_UNDETERMINED is returned.
        """
        # compare against None so that a falsy guid (e.g. 0) is still
        # treated as an element id
        if guid is None:
            return self._status
        self._validate_guid(guid)
        factory = self._get_factory(guid)
        status_function = factory.status_function
        if status_function:
            return status_function(self, guid)
        return AS.STATUS_UNDETERMINED

    def trace(self, guid, trace_id, attribute='value'):
        """Return trace information for element ``guid``.

        ``attribute`` 'value' returns the trace file contents, 'path' the
        path from trace_filepath(); any other attribute yields None.
        """
        if attribute == 'value':
            fd = open(self.trace_filepath(guid, trace_id), "r")
            try:
                content = fd.read()
            finally:
                # close the file even if read() fails
                fd.close()
        elif attribute == 'path':
            content = self.trace_filepath(guid, trace_id)
        else:
            content = None
        return content

    def traces_info(self):
        """Return {guid: {trace_id: {"host", "user", "filepath"}}} for every
        trace requested through defer_add_trace."""
        traces_info = dict()
        host = self._attributes.get_attribute_value("deployment_host")
        user = self._attributes.get_attribute_value("deployment_user")
        for guid, trace_list in self._add_trace.iteritems():
            traces_info[guid] = dict()
            for trace_id in trace_list:
                filepath = self.trace(guid, trace_id, attribute = "path")
                traces_info[guid][trace_id] = {
                    "host": host,
                    "user": user,
                    "filepath": filepath,
                }
        return traces_info

    def trace_filepath(self, guid, trace_id):
        """
        Return a trace's file path, for TestbedController's default 
        implementation of trace()
        """
        raise NotImplementedError

    #shutdown: NotImplementedError

    def get_connected(self, guid, connector_type_name, 
            other_connector_type_name):
        """searchs the connected elements for the specific connector_type_name 
        pair

        Returns the list of guids connected to ``guid``'s connector
        ``connector_type_name`` through their connector
        ``other_connector_type_name`` (local connections only).
        """
        if guid not in self._connect:
            return []
        # all connections for all connectors for guid
        all_connections = self._connect[guid]
        if connector_type_name not in all_connections:
            return []
        # all connections for the specific connector
        connections = all_connections[connector_type_name]
        specific_connections = [otr_guid for otr_guid, otr_connector_type_name \
                in connections.iteritems() if \
                otr_connector_type_name == other_connector_type_name]
        return specific_connections

    def _get_connection_count(self, guid, connection_type_name):
        """Number of local plus cross connections on a connector of ``guid``."""
        count = 0
        cross_count = 0
        if guid in self._connect and connection_type_name in \
                self._connect[guid]:
            count = len(self._connect[guid][connection_type_name])
        if guid in self._cross_connect and connection_type_name in \
                self._cross_connect[guid]:
            # each (guid, connector) holds a single cross-connection tuple;
            # len() of that tuple would count its fields, not connections
            cross_count = 1
        return count + cross_count

    def _get_traces(self, guid):
        """Traces requested for ``guid`` (empty list if none)."""
        return [] if guid not in self._add_trace else self._add_trace[guid]

    def _get_parameters(self, guid):
        """Values given to defer_create_set for ``guid`` (empty dict if none)."""
        return dict() if guid not in self._create_set else \
                self._create_set[guid]

    def _get_factory(self, guid):
        """Factory of element ``guid``; raises KeyError for unknown guids."""
        factory_id = self._create[guid]
        return self._factories[factory_id]

    def _get_factory_id(self, guid):
        """ Returns the factory ID of the (perhaps not yet) created object """
        return self._create.get(guid, None)

    def _validate_guid(self, guid):
        if not guid in self._create:
            # %s rather than %d so that non-integer guids still produce
            # this RuntimeError instead of a formatting TypeError
            raise RuntimeError("Element guid %s doesn't exist" % guid)

    def _validate_not_guid(self, guid):
        if guid in self._create:
            raise AttributeError("Cannot add elements with the same guid: %s" %
                    guid)

    def _validate_factory_id(self, factory_id):
        if factory_id not in self._factories:
            raise AttributeError("Invalid element type %s for testbed version %s" %
                    (factory_id, self._testbed_version))

    def _validate_testbed_attribute(self, name):
        if not self._attributes.has_attribute(name):
            raise AttributeError("Invalid testbed attribute %s for testbed" % \
                    name)

    def _validate_testbed_value(self, name, value):
        if not self._attributes.is_attribute_value_valid(name, value):
            raise AttributeError("Invalid value %s for testbed attribute %s" % \
                (value, name))

    def _validate_box_attribute(self, guid, name):
        factory = self._get_factory(guid)
        if not factory.box_attributes.has_attribute(name):
            raise AttributeError("Invalid attribute %s for element type %s" %
                    (name, factory.factory_id))

    def _validate_box_value(self, guid, name, value):
        factory = self._get_factory(guid)
        if not factory.box_attributes.is_attribute_value_valid(name, value):
            raise AttributeError("Invalid value %s for attribute %s" % \
                (value, name))

    def _validate_factory_attribute(self, guid, name):
        factory = self._get_factory(guid)
        if not factory.has_attribute(name):
            raise AttributeError("Invalid attribute %s for element type %s" %
                    (name, factory.factory_id))

    def _validate_factory_value(self, guid, name, value):
        factory = self._get_factory(guid)
        if not factory.is_attribute_value_valid(name, value):
            raise AttributeError("Invalid value %s for attribute %s" % \
                (value, name))

    def _validate_trace(self, guid, trace_name):
        factory = self._get_factory(guid)
        if not trace_name in factory.traces_list:
            raise RuntimeError("Element type '%s' has no trace '%s'" %
                    (factory.factory_id, trace_name))

    def _validate_allow_addresses(self, guid):
        """Raise RuntimeError unless ``guid`` can take one more address.

        The limit is the "maxAddresses" value given to defer_create_set,
        falling back to the factory default.
        """
        factory = self._get_factory(guid)
        if not factory.allow_addresses:
            raise RuntimeError("Element type '%s' doesn't support addresses" %
                    factory.factory_id)
        attr_name = "maxAddresses"
        if guid in self._create_set and attr_name in self._create_set[guid]:
            max_addresses = self._create_set[guid][attr_name]
        else:
            max_addresses = factory.box_attributes.get_attribute_value(attr_name)
        if guid in self._add_address:
            count_addresses = len(self._add_address[guid])
            if max_addresses == count_addresses:
                raise RuntimeError("Element guid %s of type '%s' can't accept "
                        "more addresses" % (guid, factory.factory_id))

    def _validate_allow_routes(self, guid):
        factory = self._get_factory(guid)
        if not factory.allow_routes:
            raise RuntimeError("Element type '%s' doesn't support routes" %
                    factory.factory_id)

    def _validate_connection(self, guid1, connector_type_name1, guid2, 
            connector_type_name2, cross = False):
        """Raise AttributeError for self-connections or a full connector.

        An already-recorded identical connection is accepted silently.
        ``cross`` is accepted for interface compatibility and unused.
        """
        # can't connect with self
        if guid1 == guid2:
            raise AttributeError("Can't connect guid %s to self" % \
                (guid1))
        # the connection is already done, so ignore
        connected = self.get_connected(guid1, connector_type_name1, 
                connector_type_name2)
        if guid2 in connected:
            return
        count1 = self._get_connection_count(guid1, connector_type_name1)
        factory1 = self._get_factory(guid1)
        connector_type1 = factory1.connector_type(connector_type_name1)
        if count1 == connector_type1.max:
            raise AttributeError("Connector %s is full for guid %s" % \
                (connector_type_name1, guid1))

    def _validate_modify_box_value(self, guid, name):
        """Reject changes to exec read-only/immutable attributes once the
        testbed status is past STATUS_STARTED."""
        factory = self._get_factory(guid)
        if self._status > TS.STATUS_STARTED and \
                (factory.box_attributes.is_attribute_exec_read_only(name) or \
                factory.box_attributes.is_attribute_exec_immutable(name)):
            raise AttributeError("Attribute %s can only be modified during experiment design" % name)