ade1cccdc174fbfd26d8c1ab264600773fcf055f
[nepi.git] / src / nepi / core / testbed_impl.py
1 # -*- coding: utf-8 -*-
2
3 from nepi.core import execute
4 from nepi.core.metadata import Metadata, Parallel
5 from nepi.util import validation
6 from nepi.util.constants import TIME_NOW, \
7         ApplicationStatus as AS, \
8         TestbedStatus as TS, \
9         CONNECTION_DELAY
10 from nepi.util.parallel import ParallelRun
11
12 import collections
13 import copy
14 import logging
15
16 class TestbedController(execute.TestbedController):
17     def __init__(self, testbed_id, testbed_version):
18         super(TestbedController, self).__init__(testbed_id, testbed_version)
19         self._status = TS.STATUS_ZERO
20         # testbed attributes for validation
21         self._attributes = None
22         # element factories for validation
23         self._factories = dict()
24
25         # experiment construction instructions
26         self._create = dict()
27         self._create_set = dict()
28         self._factory_set = dict()
29         self._connect = dict()
30         self._cross_connect = dict()
31         self._add_trace = dict()
32         self._add_address = dict()
33         self._add_route = dict()
34         self._configure = dict()
35
36         # log of set operations
37         self._setlog = dict()
38         # last set operations
39         self._set = dict()
40
41         # testbed element instances
42         self._elements = dict()
43
44         self._metadata = Metadata(self._testbed_id)
45         if self._metadata.testbed_version != testbed_version:
46             raise RuntimeError("Bad testbed version on testbed %s. Asked for %s, got %s" % \
47                     (testbed_id, testbed_version, self._metadata.testbed_version))
48         for factory in self._metadata.build_factories():
49             self._factories[factory.factory_id] = factory
50         self._attributes = self._metadata.testbed_attributes()
51         self._root_directory = None
52         
53         # Logging
54         self._logger = logging.getLogger("nepi.core.testbed_impl")
55     
56     @property
57     def root_directory(self):
58         return self._root_directory
59
60     @property
61     def guids(self):
62         return self._create.keys()
63
64     @property
65     def elements(self):
66         return self._elements
67     
68     def defer_configure(self, name, value):
69         self._validate_testbed_attribute(name)
70         self._validate_testbed_value(name, value)
71         self._attributes.set_attribute_value(name, value)
72         self._configure[name] = value
73
74     def defer_create(self, guid, factory_id):
75         self._validate_factory_id(factory_id)
76         self._validate_not_guid(guid)
77         self._create[guid] = factory_id
78
79     def defer_create_set(self, guid, name, value):
80         self._validate_guid(guid)
81         self._validate_box_attribute(guid, name)
82         self._validate_box_value(guid, name, value)
83         if guid not in self._create_set:
84             self._create_set[guid] = dict()
85         self._create_set[guid][name] = value
86
87     def defer_factory_set(self, guid, name, value):
88         self._validate_guid(guid)
89         self._validate_factory_attribute(guid, name)
90         self._validate_factory_value(guid, name, value)
91         if guid not in self._factory_set:
92             self._factory_set[guid] = dict()
93         self._factory_set[guid][name] = value
94
95     def defer_connect(self, guid1, connector_type_name1, guid2, 
96             connector_type_name2):
97         self._validate_guid(guid1)
98         self._validate_guid(guid2)
99         factory1 = self._get_factory(guid1)
100         factory_id2 = self._create[guid2]
101         connector_type = factory1.connector_type(connector_type_name1)
102         connector_type.can_connect(self._testbed_id, factory_id2, 
103                 connector_type_name2, False)
104         self._validate_connection(guid1, connector_type_name1, guid2, 
105             connector_type_name2)
106
107         if not guid1 in self._connect:
108             self._connect[guid1] = dict()
109         if not connector_type_name1 in self._connect[guid1]:
110              self._connect[guid1][connector_type_name1] = dict()
111         self._connect[guid1][connector_type_name1][guid2] = \
112                connector_type_name2
113         if not guid2 in self._connect:
114             self._connect[guid2] = dict()
115         if not connector_type_name2 in self._connect[guid2]:
116              self._connect[guid2][connector_type_name2] = dict()
117         self._connect[guid2][connector_type_name2][guid1] = \
118                connector_type_name1
119
120     def defer_cross_connect(self, guid, connector_type_name, cross_guid, 
121             cross_testbed_guid, cross_testbed_id, cross_factory_id, 
122             cross_connector_type_name):
123         self._validate_guid(guid)
124         factory = self._get_factory(guid)
125         connector_type = factory.connector_type(connector_type_name)
126         connector_type.can_connect(cross_testbed_id, cross_factory_id, 
127                 cross_connector_type_name, True)
128         self._validate_connection(guid, connector_type_name, cross_guid, 
129             cross_connector_type_name)
130
131         if not guid in self._cross_connect:
132             self._cross_connect[guid] = dict()
133         if not connector_type_name in self._cross_connect[guid]:
134              self._cross_connect[guid][connector_type_name] = dict()
135         self._cross_connect[guid][connector_type_name] = \
136                 (cross_guid, cross_testbed_guid, cross_testbed_id, 
137                 cross_factory_id, cross_connector_type_name)
138
139     def defer_add_trace(self, guid, trace_name):
140         self._validate_guid(guid)
141         self._validate_trace(guid, trace_name)
142         if not guid in self._add_trace:
143             self._add_trace[guid] = list()
144         self._add_trace[guid].append(trace_name)
145
146     def defer_add_address(self, guid, address, netprefix, broadcast):
147         self._validate_guid(guid)
148         self._validate_allow_addresses(guid)
149         if guid not in self._add_address:
150             self._add_address[guid] = list()
151         self._add_address[guid].append((address, netprefix, broadcast))
152
153     def defer_add_route(self, guid, destination, netprefix, nexthop, metric = 0):
154         self._validate_guid(guid)
155         self._validate_allow_routes(guid)
156         if not guid in self._add_route:
157             self._add_route[guid] = list()
158         self._add_route[guid].append((destination, netprefix, nexthop, metric)) 
159
160     def do_setup(self):
161         self._root_directory = self._attributes.\
162             get_attribute_value("rootDirectory")
163         self._status = TS.STATUS_SETUP
164
165     def do_create(self):
166         def set_params(self, guid):
167             parameters = self._get_parameters(guid)
168             for name, value in parameters.iteritems():
169                 self.set(guid, name, value)
170             
171         self._do_in_factory_order(
172             'create_function',
173             self._metadata.create_order,
174             postaction = set_params )
175         self._status = TS.STATUS_CREATED
176
    def _do_connect(self, init = True):
        """Execute the connect code of every intra-testbed connection.

        Works on a deep copy of ``self._connect`` (the recorded connections
        are left untouched) and keeps sweeping over it until every entry has
        been removed. An entry is removed once its connect code ran and did
        not return CONNECTION_DELAY, or right away when the connector has no
        connect code; entries whose code returns CONNECTION_DELAY are retried
        on the next sweep.

        :param init: True runs each connector's connect_to_init_code, False
            runs its connect_to_compl_code.
        """
        unconnected = copy.deepcopy(self._connect)
        
        # NOTE(review): there is no retry limit -- if some connect code keeps
        # returning CONNECTION_DELAY this loop never terminates.
        while unconnected:
            # NOTE: entries are deleted from the dicts being iterated; this
            # relies on Python 2's dict.items() returning a list copy (the
            # same code would raise RuntimeError under Python 3).
            for guid1, connections in unconnected.items():
                factory1 = self._get_factory(guid1)
                for connector_type_name1, connections2 in connections.items():
                    connector_type1 = factory1.connector_type(connector_type_name1)
                    for guid2, connector_type_name2 in connections2.items():
                        factory_id2 = self._create[guid2]
                        # Connections are executed in a "From -> To" direction only
                        # This explicitly ignores the "To -> From" (mirror) 
                        # connections of every connection pair.
                        if init:
                            connect_code = connector_type1.connect_to_init_code(
                                    self._testbed_id, factory_id2, 
                                    connector_type_name2,
                                    False)
                        else:
                            connect_code = connector_type1.connect_to_compl_code(
                                    self._testbed_id, factory_id2, 
                                    connector_type_name2,
                                    False)
                        delay = None
                        if connect_code:
                            delay = connect_code(self, guid1, guid2)

                        # identity comparison: only the CONNECTION_DELAY object
                        # postpones the connection to the next sweep
                        if delay is not CONNECTION_DELAY:
                            del unconnected[guid1][connector_type_name1][guid2]
                    # prune connector entries that have no pending peers left
                    if not unconnected[guid1][connector_type_name1]:
                        del unconnected[guid1][connector_type_name1]
                # prune elements that have no pending connectors left
                if not unconnected[guid1]:
                    del unconnected[guid1]
210
211     def do_connect_init(self):
212         self._do_connect()
213
214     def do_connect_compl(self):
215         self._do_connect(init = False)
216         self._status = TS.STATUS_CONNECTED
217
    def _do_in_factory_order(self, action, order, postaction = None, poststep = None):
        """Apply ``action`` to every created element, one factory group at a time.

        Guids registered in ``self._create`` are grouped by factory_id and the
        groups are processed in the sequence given by ``order``. An entry of
        ``order`` may be a Parallel wrapper: that group is then dispatched
        through a ParallelRun with the wrapper's ``maxthreads``. Factory ids
        with no created elements are skipped.

        :param action: either the name (string) of a factory attribute holding
            a function, called as ``function(self, guid)`` -- factories whose
            attribute is falsy are skipped -- or a callable invoked as
            ``action(self, guid)``.
        :param order: iterable of factory ids and/or Parallel wrappers.
        :param postaction: optional ``postaction(self, guid)``, run right after
            ``action`` for each element.
        :param poststep: optional ``poststep(self, guid)``, run for every
            element of a group after ``action`` was applied to the whole group.
        """
        logger = self._logger
        
        guids = collections.defaultdict(list)
        # order guids (elements) according to factory_id
        for guid, factory_id in self._create.iteritems():
            guids[factory_id].append(guid)
        
        # configure elements following the factory_id order
        for factory_id in order:
            # Create a parallel runner if we're given a Parallel() wrapper
            runner = None
            if isinstance(factory_id, Parallel):
                runner = ParallelRun(factory_id.maxthreads)
                factory_id = factory_id.factory
            
            # omit the factories that have no element to create
            # (the runner created above is then simply never started)
            if factory_id not in guids:
                continue
            
            # configure action
            factory = self._factories[factory_id]
            # skip factories that define no function for this action
            if isinstance(action, basestring) and not getattr(factory, action):
                continue
            # closes over `factory`; assumes runner.sync()/join() below wait for
            # all queued calls before the loop rebinds it -- confirm in ParallelRun
            def perform_action(guid):
                if isinstance(action, basestring):
                    getattr(factory, action)(self, guid)
                else:
                    action(self, guid)
                if postaction:
                    postaction(self, guid)

            # perform the action on all elements, in parallel if so requested
            if runner:
                logger.debug("TestbedController: Starting parallel %s", action)
                runner.start()

            for guid in guids[factory_id]:
                if runner:
                    logger.debug("TestbedController: Scheduling %s on %s", action, guid)
                    runner.put(perform_action, guid)
                else:
                    logger.debug("TestbedController: Performing %s on %s", action, guid)
                    perform_action(guid)

            # sync
            if runner:
                runner.sync()
            
            # post hook
            if poststep:
                for guid in guids[factory_id]:
                    if runner:
                        logger.debug("TestbedController: Scheduling post-%s on %s", action, guid)
                        runner.put(poststep, self, guid)
                    else:
                        logger.debug("TestbedController: Performing post-%s on %s", action, guid)
                        poststep(self, guid)

            # sync
            if runner:
                runner.join()
                logger.debug("TestbedController: Finished parallel %s", action)
281
282     @staticmethod
283     def do_poststep_preconfigure(self, guid):
284         # dummy hook for implementations interested in
285         # two-phase configuration
286         pass
287
288     def do_preconfigure(self):
289         self._do_in_factory_order(
290             'preconfigure_function',
291             self._metadata.preconfigure_order,
292             poststep = self.do_poststep_preconfigure )
293
294     @staticmethod
295     def do_poststep_configure(self, guid):
296         # dummy hook for implementations interested in
297         # two-phase configuration
298         pass
299
300     def do_configure(self):
301         self._do_in_factory_order(
302             'configure_function',
303             self._metadata.configure_order,
304             poststep = self.do_poststep_configure )
305         self._status = TS.STATUS_CONFIGURED
306
307     def do_prestart(self):
308         self._do_in_factory_order(
309             'prestart_function',
310             self._metadata.prestart_order )
311
312     def _do_cross_connect(self, cross_data, init = True):
313         for guid, cross_connections in self._cross_connect.iteritems():
314             factory = self._get_factory(guid)
315             for connector_type_name, cross_connection in \
316                     cross_connections.iteritems():
317                 connector_type = factory.connector_type(connector_type_name)
318                 (cross_guid, cross_testbed_guid, cross_testbed_id,
319                     cross_factory_id, cross_connector_type_name) = cross_connection
320                 if init:
321                     connect_code = connector_type.connect_to_init_code(
322                         cross_testbed_id, cross_factory_id, 
323                         cross_connector_type_name,
324                         True)
325                 else:
326                     connect_code = connector_type.connect_to_compl_code(
327                         cross_testbed_id, cross_factory_id, 
328                         cross_connector_type_name,
329                         True)
330                 if connect_code:
331                     if hasattr(connect_code, "func"):
332                         func_name = connect_code.func.__name__
333                     elif hasattr(connect_code, "__name__"):
334                         func_name = connect_code.__name__
335                     else:
336                         func_name = repr(connect_code)
337                     self._logger.debug("Cross-connect - guid: %d, connect_code: %s " % (
338                         guid, func_name))
339                     elem_cross_data = cross_data[cross_testbed_guid][cross_guid]
340                     connect_code(self, guid, elem_cross_data)       
341
342     def do_cross_connect_init(self, cross_data):
343         self._do_cross_connect(cross_data)
344
345     def do_cross_connect_compl(self, cross_data):
346         self._do_cross_connect(cross_data, init = False)
347         self._status = TS.STATUS_CROSS_CONNECTED
348
349     def set(self, guid, name, value, time = TIME_NOW):
350         self._validate_guid(guid)
351         self._validate_box_attribute(guid, name)
352         self._validate_box_value(guid, name, value)
353         self._validate_modify_box_value(guid, name)
354         if guid not in self._set:
355             self._set[guid] = dict()
356             self._setlog[guid] = dict()
357         if time not in self._setlog[guid]:
358             self._setlog[guid][time] = dict()
359         self._setlog[guid][time][name] = value
360         self._set[guid][name] = value
361
362     def get(self, guid, name, time = TIME_NOW):
363         """
364         gets an attribute from box definitions if available. 
365         Throws KeyError if the GUID wasn't created
366         through the defer_create interface, and AttributeError if the
367         attribute isn't available (doesn't exist or is design-only)
368         """
369         self._validate_guid(guid)
370         self._validate_box_attribute(guid, name)
371         if guid in self._set and name in self._set[guid]:
372             return self._set[guid][name]
373         if guid in self._create_set and name in self._create_set[guid]:
374             return self._create_set[guid][name]
375         # if nothing else found, returns the factory default value
376         factory = self._get_factory(guid)
377         return factory.box_attributes.get_attribute_value(name)
378
379     def get_route(self, guid, index, attribute):
380         """
381         returns information given to defer_add_route.
382         
383         Raises AttributeError if an invalid attribute is requested
384             or if the indexed routing rule does not exist.
385         
386         Raises KeyError if the GUID has not been seen by
387             defer_add_route
388         """
389         ATTRIBUTES = ['Destination', 'NetPrefix', 'NextHop']
390         
391         if attribute not in ATTRIBUTES:
392             raise AttributeError, "Attribute %r invalid for addresses of %r" % (attribute, guid)
393         
394         attribute_index = ATTRIBUTES.index(attribute)
395         
396         routes = self._add_route.get(guid)
397         if not routes:
398             raise KeyError, "GUID %r not found in %s" % (guid, self._testbed_id)
399        
400         index = int(index)
401         if not (0 <= index < len(addresses)):
402             raise AttributeError, "GUID %r at %s does not have a routing entry #%s" % (
403                 guid, self._testbed_id, index)
404         
405         return routes[index][attribute_index]
406
407     def get_address(self, guid, index, attribute='Address'):
408         """
409         returns information given to defer_add_address
410         
411         Raises AttributeError if an invalid attribute is requested
412             or if the indexed routing rule does not exist.
413         
414         Raises KeyError if the GUID has not been seen by
415             defer_add_address
416         """
417         ATTRIBUTES = ['Address', 'NetPrefix', 'Broadcast']
418         
419         if attribute not in ATTRIBUTES:
420             raise AttributeError, "Attribute %r invalid for addresses of %r" % (attribute, guid)
421         
422         attribute_index = ATTRIBUTES.index(attribute)
423         
424         addresses = self._add_address.get(guid)
425         if not addresses:
426             raise KeyError, "GUID %r not found in %s" % (guid, self._testbed_id)
427         
428         index = int(index)
429         if not (0 <= index < len(addresses)):
430             raise AttributeError, "GUID %r at %s does not have an address #%s" % (
431                 guid, self._testbed_id, index)
432         
433         return addresses[index][attribute_index]
434
435     def get_attribute_list(self, guid, filter_flags = None, exclude = False):
436         factory = self._get_factory(guid)
437         attribute_list = list()
438         return factory.box_attributes.get_attribute_list(filter_flags, exclude)
439
440     def get_factory_id(self, guid):
441         factory = self._get_factory(guid)
442         return factory.factory_id
443
444     def start(self, time = TIME_NOW):
445         self._do_in_factory_order(
446             'start_function',
447             self._metadata.start_order )
448         self._status = TS.STATUS_STARTED
449
450     #action: NotImplementedError
451
452     def stop(self, time = TIME_NOW):
453         self._do_in_factory_order(
454             'stop_function',
455             reversed(self._metadata.start_order) )
456         self._status = TS.STATUS_STOPPED
457
458     def status(self, guid = None):
459         if not guid:
460             return self._status
461         self._validate_guid(guid)
462         factory = self._get_factory(guid)
463         status_function = factory.status_function
464         if status_function:
465             return status_function(self, guid)
466         return AS.STATUS_UNDETERMINED
467     
468     def testbed_status(self):
469         return self._status
470
471     def trace(self, guid, trace_id, attribute='value'):
472         if attribute == 'value':
473             fd = open("%s" % self.trace_filepath(guid, trace_id), "r")
474             content = fd.read()
475             fd.close()
476         elif attribute == 'path':
477             content = self.trace_filepath(guid, trace_id)
478         elif attribute == 'filename':
479             content = self.trace_filename(guid, trace_id)
480         else:
481             content = None
482         return content
483
484     def traces_info(self):
485         traces_info = dict()
486         host = self._attributes.get_attribute_value("deployment_host")
487         user = self._attributes.get_attribute_value("deployment_user")
488         for guid, trace_list in self._add_trace.iteritems(): 
489             traces_info[guid] = dict()
490             for trace_id in trace_list:
491                 traces_info[guid][trace_id] = dict()
492                 filepath = self.trace(guid, trace_id, attribute = "path")
493                 traces_info[guid][trace_id]["host"] = host
494                 traces_info[guid][trace_id]["user"] = user
495                 traces_info[guid][trace_id]["filepath"] = filepath
496         return traces_info
497
498     def trace_filepath(self, guid, trace_id):
499         """
500         Return a trace's file path, for TestbedController's default 
501         implementation of trace()
502         """
503         raise NotImplementedError
504
505     def trace_filename(self, guid, trace_id):
506         """
507         Return a trace's file name, for TestbedController's default 
508         implementation of trace()
509         """
510         raise NotImplementedError
511
512     #shutdown: NotImplementedError
513
514     def get_connected(self, guid, connector_type_name, 
515             other_connector_type_name):
516         """searchs the connected elements for the specific connector_type_name 
517         pair"""
518         if guid not in self._connect:
519             return []
520         # all connections for all connectors for guid
521         all_connections = self._connect[guid]
522         if connector_type_name not in all_connections:
523             return []
524         # all connections for the specific connector
525         connections = all_connections[connector_type_name]
526         specific_connections = [otr_guid for otr_guid, otr_connector_type_name \
527                 in connections.iteritems() if \
528                 otr_connector_type_name == other_connector_type_name]
529         return specific_connections
530
531     def _get_connection_count(self, guid, connection_type_name):
532         count = 0
533         cross_count = 0
534         if guid in self._connect and connection_type_name in \
535                 self._connect[guid]:
536             count = len(self._connect[guid][connection_type_name])
537         if guid in self._cross_connect and connection_type_name in \
538                 self._cross_connect[guid]:
539             cross_count = len(self._cross_connect[guid][connection_type_name])
540         return count + cross_count
541
542     def _get_traces(self, guid):
543         return [] if guid not in self._add_trace else self._add_trace[guid]
544
545     def _get_parameters(self, guid):
546         return dict() if guid not in self._create_set else \
547                 self._create_set[guid]
548
549     def _get_factory(self, guid):
550         factory_id = self._create[guid]
551         return self._factories[factory_id]
552
553     def _get_factory_id(self, guid):
554         """ Returns the factory ID of the (perhaps not yet) created object """
555         return self._create.get(guid, None)
556
557     def _validate_guid(self, guid):
558         if not guid in self._create:
559             raise RuntimeError("Element guid %d doesn't exist" % guid)
560
561     def _validate_not_guid(self, guid):
562         if guid in self._create:
563             raise AttributeError("Cannot add elements with the same guid: %d" %
564                     guid)
565
566     def _validate_factory_id(self, factory_id):
567         if factory_id not in self._factories:
568             raise AttributeError("Invalid element type %s for testbed version %s" %
569                     (factory_id, self._testbed_version))
570
571     def _validate_testbed_attribute(self, name):
572         if not self._attributes.has_attribute(name):
573             raise AttributeError("Invalid testbed attribute %s for testbed" % \
574                     name)
575
576     def _validate_testbed_value(self, name, value):
577         if not self._attributes.is_attribute_value_valid(name, value):
578             raise AttributeError("Invalid value %r for testbed attribute %s" % \
579                 (value, name))
580
581     def _validate_box_attribute(self, guid, name):
582         factory = self._get_factory(guid)
583         if not factory.box_attributes.has_attribute(name):
584             raise AttributeError("Invalid attribute %s for element type %s" %
585                     (name, factory.factory_id))
586
587     def _validate_box_value(self, guid, name, value):
588         factory = self._get_factory(guid)
589         if not factory.box_attributes.is_attribute_value_valid(name, value):
590             raise AttributeError("Invalid value %r for attribute %s" % \
591                 (value, name))
592
593     def _validate_factory_attribute(self, guid, name):
594         factory = self._get_factory(guid)
595         if not factory.has_attribute(name):
596             raise AttributeError("Invalid attribute %s for element type %s" %
597                     (name, factory.factory_id))
598
599     def _validate_factory_value(self, guid, name, value):
600         factory = self._get_factory(guid)
601         if not factory.is_attribute_value_valid(name, value):
602             raise AttributeError("Invalid value %r for attribute %s" % \
603                 (value, name))
604
605     def _validate_trace(self, guid, trace_name):
606         factory = self._get_factory(guid)
607         if not trace_name in factory.traces_list:
608             raise RuntimeError("Element type '%s' has no trace '%s'" %
609                     (factory.factory_id, trace_name))
610
611     def _validate_allow_addresses(self, guid):
612         factory = self._get_factory(guid)
613         if not factory.allow_addresses:
614             raise RuntimeError("Element type '%s' doesn't support addresses" %
615                     factory.factory_id)
616         attr_name = "maxAddresses"
617         if guid in self._create_set and attr_name in self._create_set[guid]:
618             max_addresses = self._create_set[guid][attr_name]
619         else:
620             factory = self._get_factory(guid)
621             max_addresses = factory.box_attributes.get_attribute_value(attr_name)
622         if guid in self._add_address:
623             count_addresses = len(self._add_address[guid])
624             if max_addresses == count_addresses:
625                 raise RuntimeError("Element guid %d of type '%s' can't accept \
626                         more addresses" % (guid, factory.factory_id))
627
628     def _validate_allow_routes(self, guid):
629         factory = self._get_factory(guid)
630         if not factory.allow_routes:
631             raise RuntimeError("Element type '%s' doesn't support routes" %
632                     factory.factory_id)
633
634     def _validate_connection(self, guid1, connector_type_name1, guid2, 
635             connector_type_name2, cross = False):
636         # can't connect with self
637         if guid1 == guid2:
638             raise AttributeError("Can't connect guid %d to self" % \
639                 (guid1))
640         # the connection is already done, so ignore
641         connected = self.get_connected(guid1, connector_type_name1, 
642                 connector_type_name2)
643         if guid2 in connected:
644             return
645         count1 = self._get_connection_count(guid1, connector_type_name1)
646         factory1 = self._get_factory(guid1)
647         connector_type1 = factory1.connector_type(connector_type_name1)
648         if count1 == connector_type1.max:
649             raise AttributeError("Connector %s is full for guid %d" % \
650                 (connector_type_name1, guid1))
651
652     def _validate_modify_box_value(self, guid, name):
653         factory = self._get_factory(guid)
654         if self._status > TS.STATUS_STARTED and \
655                 (factory.box_attributes.is_attribute_exec_read_only(name) or \
656                 factory.box_attributes.is_attribute_exec_immutable(name)):
657             raise AttributeError("Attribute %s can only be modified during experiment design" % name)
658