Various fixes:
[nepi.git] / src / nepi / core / testbed_impl.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 from nepi.core import execute
5 from nepi.core.metadata import Metadata, Parallel
6 from nepi.util import validation
7 from nepi.util.constants import TIME_NOW, \
8         ApplicationStatus as AS, \
9         TestbedStatus as TS, \
10         CONNECTION_DELAY
11 from nepi.util.parallel import ParallelRun
12
13 import collections
14 import copy
15 import logging
16
class TestbedController(execute.TestbedController):
    """
    Metadata-driven testbed controller.

    The ``defer_*`` methods validate and record the experiment construction
    instructions (element creation, attribute values, connections, traces,
    addresses and routes). The ``do_*``, ``start`` and ``stop`` methods later
    execute them by invoking the functions declared by each element factory,
    following the orders given by the testbed metadata.
    """

    def __init__(self, testbed_id, testbed_version):
        """
        Initializes the controller state and loads the testbed metadata.

        Raises RuntimeError if the version reported by the metadata differs
        from ``testbed_version``.
        """
        super(TestbedController, self).__init__(testbed_id, testbed_version)
        self._status = TS.STATUS_ZERO
        # testbed attributes for validation
        self._attributes = None
        # element factories for validation
        self._factories = dict()

        # experiment construction instructions
        self._create = dict()
        self._create_set = dict()
        self._factory_set = dict()
        self._connect = dict()
        self._cross_connect = dict()
        self._add_trace = dict()
        self._add_address = dict()
        self._add_route = dict()
        self._configure = dict()

        # log of set operations
        self._setlog = dict()
        # last set operations
        self._set = dict()

        # testbed element instances
        self._elements = dict()

        self._metadata = Metadata(self._testbed_id)
        if self._metadata.testbed_version != testbed_version:
            raise RuntimeError("Bad testbed version on testbed %s. Asked for %s, got %s" % \
                    (testbed_id, testbed_version, self._metadata.testbed_version))
        for factory in self._metadata.build_factories():
            self._factories[factory.factory_id] = factory
        self._attributes = self._metadata.testbed_attributes()
        self._root_directory = None

        # Logging
        self._logger = logging.getLogger("nepi.core.testbed_impl")

    @property
    def root_directory(self):
        """Value of the "rootDirectory" attribute; None until do_setup()."""
        return self._root_directory

    @property
    def guids(self):
        """GUIDs of every element registered through defer_create()."""
        return self._create.keys()

    @property
    def elements(self):
        """Dictionary holding the testbed element instances."""
        return self._elements

    def defer_configure(self, name, value):
        """
        Validates and records a testbed-level attribute value.

        Raises AttributeError if the attribute or the value is invalid.
        """
        self._validate_testbed_attribute(name)
        self._validate_testbed_value(name, value)
        self._attributes.set_attribute_value(name, value)
        self._configure[name] = value

    def defer_create(self, guid, factory_id):
        """
        Registers the element ``guid`` as an instance of ``factory_id``.

        Raises AttributeError if the factory is unknown or the guid is
        already registered.
        """
        self._validate_factory_id(factory_id)
        self._validate_not_guid(guid)
        self._create[guid] = factory_id

    def defer_create_set(self, guid, name, value):
        """Validates and records a creation-time attribute value of a box."""
        self._validate_guid(guid)
        self._validate_box_attribute(guid, name)
        self._validate_box_value(guid, name, value)
        self._create_set.setdefault(guid, dict())[name] = value

    def defer_factory_set(self, guid, name, value):
        """Validates and records a factory attribute value for ``guid``."""
        self._validate_guid(guid)
        self._validate_factory_attribute(guid, name)
        self._validate_factory_value(guid, name, value)
        self._factory_set.setdefault(guid, dict())[name] = value

    def defer_connect(self, guid1, connector_type_name1, guid2, 
            connector_type_name2):
        """
        Records a connection between two elements of this testbed.

        The connection is stored in both directions (guid1 -> guid2 and
        guid2 -> guid1) so it can be looked up from either end.
        """
        self._validate_guid(guid1)
        self._validate_guid(guid2)
        factory1 = self._get_factory(guid1)
        factory_id2 = self._create[guid2]
        connector_type = factory1.connector_type(connector_type_name1)
        # NOTE(review): the result of can_connect() is discarded here; confirm
        # whether it signals failure by raising or by its return value.
        connector_type.can_connect(self._testbed_id, factory_id2, 
                connector_type_name2, False)
        self._validate_connection(guid1, connector_type_name1, guid2, 
            connector_type_name2)

        forward = self._connect.setdefault(guid1, dict())
        forward.setdefault(connector_type_name1, dict())[guid2] = \
                connector_type_name2
        backward = self._connect.setdefault(guid2, dict())
        backward.setdefault(connector_type_name2, dict())[guid1] = \
                connector_type_name1

    def defer_cross_connect(self, guid, connector_type_name, cross_guid, 
            cross_testbed_guid, cross_testbed_id, cross_factory_id, 
            cross_connector_type_name):
        """
        Records a connection from ``guid`` to an element of another testbed.

        A single cross connection is kept per (guid, connector type); it is
        stored as the tuple (cross_guid, cross_testbed_guid, cross_testbed_id,
        cross_factory_id, cross_connector_type_name).
        """
        self._validate_guid(guid)
        factory = self._get_factory(guid)
        connector_type = factory.connector_type(connector_type_name)
        # NOTE(review): the result of can_connect() is discarded here; confirm
        # whether it signals failure by raising or by its return value.
        connector_type.can_connect(cross_testbed_id, cross_factory_id, 
                cross_connector_type_name, True)
        self._validate_connection(guid, connector_type_name, cross_guid, 
            cross_connector_type_name)

        self._cross_connect.setdefault(guid, dict())[connector_type_name] = \
                (cross_guid, cross_testbed_guid, cross_testbed_id, 
                cross_factory_id, cross_connector_type_name)

    def defer_add_trace(self, guid, trace_name):
        """Validates and records a trace to be enabled on ``guid``."""
        self._validate_guid(guid)
        self._validate_trace(guid, trace_name)
        self._add_trace.setdefault(guid, list()).append(trace_name)

    def defer_add_address(self, guid, address, netprefix, broadcast):
        """Validates and records an address to be added to ``guid``."""
        self._validate_guid(guid)
        self._validate_allow_addresses(guid)
        self._add_address.setdefault(guid, list()).append(
                (address, netprefix, broadcast))

    def defer_add_route(self, guid, destination, netprefix, nexthop, metric = 0):
        """Validates and records a routing entry to be added to ``guid``."""
        self._validate_guid(guid)
        self._validate_allow_routes(guid)
        self._add_route.setdefault(guid, list()).append(
                (destination, netprefix, nexthop, metric))

    def do_setup(self):
        """Reads the "rootDirectory" attribute and marks the testbed as set up."""
        self._root_directory = self._attributes.\
            get_attribute_value("rootDirectory")
        self._status = TS.STATUS_SETUP

    def do_create(self):
        """
        Runs every factory's create function, in metadata create order, and
        then applies the creation-time attribute values of each element.
        """
        def set_params(self, guid):
            for name, value in self._get_parameters(guid).items():
                self.set(guid, name, value)

        self._do_in_factory_order(
            'create_function',
            self._metadata.create_order,
            postaction = set_params )
        self._status = TS.STATUS_CREATED

    def _do_connect(self, init = True):
        """
        Executes the connection code of every recorded connection.

        ``init`` selects between the "init" and the "completion" connection
        code of the connector type. Entries whose code returns
        CONNECTION_DELAY are kept and retried on the next pass; the loop ends
        once no entry is left pending.
        """
        unconnected = copy.deepcopy(self._connect)

        while unconnected:
            # Entries are removed while walking, so iterate over snapshots.
            for guid1, connections in list(unconnected.items()):
                factory1 = self._get_factory(guid1)
                for connector_type_name1, connections2 in \
                        list(connections.items()):
                    connector_type1 = factory1.connector_type(connector_type_name1)
                    for guid2, connector_type_name2 in \
                            list(connections2.items()):
                        factory_id2 = self._create[guid2]
                        # Both directions of each pair are stored; a direction
                        # for which the connector type returns no code is
                        # simply dropped without doing anything.
                        if init:
                            connect_code = connector_type1.connect_to_init_code(
                                    self._testbed_id, factory_id2, 
                                    connector_type_name2,
                                    False)
                        else:
                            connect_code = connector_type1.connect_to_compl_code(
                                    self._testbed_id, factory_id2, 
                                    connector_type_name2,
                                    False)
                        delay = None
                        if connect_code:
                            delay = connect_code(self, guid1, guid2)

                        if delay is not CONNECTION_DELAY:
                            del connections2[guid2]
                    if not connections2:
                        del connections[connector_type_name1]
                if not connections:
                    del unconnected[guid1]

    def do_connect_init(self):
        """Runs the "init" phase of the recorded connections."""
        self._do_connect()

    def do_connect_compl(self):
        """Runs the "completion" phase of the connections and updates the status."""
        self._do_connect(init = False)
        self._status = TS.STATUS_CONNECTED

    def _do_in_factory_order(self, action, order, postaction = None, poststep = None):
        """
        Invokes a factory function on every element, factory by factory.

        action:     name of the factory attribute holding the function to
                    call as ``function(self, guid)``; factories where it is
                    false are skipped.
        order:      iterable of factory ids; an item may be a Parallel()
                    wrapper, in which case its elements are processed through
                    a ParallelRun with the wrapper's ``maxthreads``.
        postaction: optional callable invoked as ``postaction(self, guid)``
                    right after the action on each element.
        poststep:   optional callable invoked as ``poststep(self, guid)`` on
                    each element once the action ran on all elements of the
                    factory.
        """
        logger = self._logger

        # group guids (elements) by factory_id
        guids = collections.defaultdict(list)
        for guid, factory_id in self._create.items():
            guids[factory_id].append(guid)

        # process elements following the factory_id order
        for factory_id in order:
            # Create a parallel runner if we're given a Parallel() wrapper
            runner = None
            if isinstance(factory_id, Parallel):
                runner = ParallelRun(factory_id.maxthreads)
                factory_id = factory_id.factory

            # omit the factories that have no element to process
            if factory_id not in guids:
                continue

            factory = self._factories[factory_id]
            factory_action = getattr(factory, action)
            if not factory_action:
                continue

            # Bound as a default argument so that a queued task keeps using
            # the function of the factory it was scheduled for.
            def perform_action(guid, factory_action = factory_action):
                factory_action(self, guid)
                if postaction:
                    postaction(self, guid)

            # perform the action on all elements, in parallel if so requested
            if runner:
                logger.debug("Starting parallel %s", action)
                runner.start()

            for guid in guids[factory_id]:
                if runner:
                    logger.debug("Scheduling %s on %s", action, guid)
                    runner.put(perform_action, guid)
                else:
                    logger.debug("Performing %s on %s", action, guid)
                    perform_action(guid)

            # sync
            if runner:
                runner.sync()

            # post hook
            if poststep:
                for guid in guids[factory_id]:
                    if runner:
                        logger.debug("Scheduling post-%s on %s", action, guid)
                        runner.put(poststep, self, guid)
                    else:
                        logger.debug("Performing post-%s on %s", action, guid)
                        poststep(self, guid)

            # sync
            if runner:
                runner.join()
                logger.debug("Finished parallel %s", action)

    @staticmethod
    def do_poststep_preconfigure(self, guid):
        """
        Dummy hook for implementations interested in two-phase configuration.

        Declared static but takes the controller explicitly, because
        _do_in_factory_order calls it as ``poststep(self, guid)``.
        """
        pass

    def do_preconfigure(self):
        """Runs the preconfigure functions in metadata preconfigure order."""
        self._do_in_factory_order(
            'preconfigure_function',
            self._metadata.preconfigure_order,
            poststep = self.do_poststep_preconfigure )

    @staticmethod
    def do_poststep_configure(self, guid):
        """
        Dummy hook for implementations interested in two-phase configuration.

        Declared static but takes the controller explicitly, because
        _do_in_factory_order calls it as ``poststep(self, guid)``.
        """
        pass

    def do_configure(self):
        """Runs the configure functions in metadata order and updates the status."""
        self._do_in_factory_order(
            'configure_function',
            self._metadata.configure_order,
            poststep = self.do_poststep_configure )
        self._status = TS.STATUS_CONFIGURED

    def do_prestart(self):
        """Runs the prestart functions in metadata prestart order."""
        self._do_in_factory_order(
            'prestart_function',
            self._metadata.prestart_order )

    def _do_cross_connect(self, cross_data, init = True):
        """
        Executes the connection code of every recorded cross connection.

        ``cross_data`` is indexed as ``cross_data[cross_testbed_guid][cross_guid]``
        and the resulting item is handed to the connection code. ``init``
        selects between the "init" and the "completion" connection code.
        """
        for guid, cross_connections in self._cross_connect.items():
            factory = self._get_factory(guid)
            for connector_type_name, cross_connection in \
                    cross_connections.items():
                connector_type = factory.connector_type(connector_type_name)
                (cross_guid, cross_testbed_guid, cross_testbed_id,
                    cross_factory_id, cross_connector_type_name) = cross_connection
                if init:
                    connect_code = connector_type.connect_to_init_code(
                        cross_testbed_id, cross_factory_id, 
                        cross_connector_type_name,
                        True)
                else:
                    connect_code = connector_type.connect_to_compl_code(
                        cross_testbed_id, cross_factory_id, 
                        cross_connector_type_name,
                        True)
                if connect_code:
                    elem_cross_data = cross_data[cross_testbed_guid][cross_guid]
                    connect_code(self, guid, elem_cross_data)

    def do_cross_connect_init(self, cross_data):
        """Runs the "init" phase of the recorded cross connections."""
        self._do_cross_connect(cross_data)

    def do_cross_connect_compl(self, cross_data):
        """Runs the "completion" phase of the cross connections and updates the status."""
        self._do_cross_connect(cross_data, init = False)
        self._status = TS.STATUS_CROSS_CONNECTED

    def set(self, guid, name, value, time = TIME_NOW):
        """
        Validates and records an attribute value of a box.

        The value is stored both as the last value set and in the log of set
        operations, which is keyed by ``time``.
        """
        self._validate_guid(guid)
        self._validate_box_attribute(guid, name)
        self._validate_box_value(guid, name, value)
        self._validate_modify_box_value(guid, name)
        self._setlog.setdefault(guid, dict()).setdefault(time, dict())[name] = value
        self._set.setdefault(guid, dict())[name] = value

    def get(self, guid, name, time = TIME_NOW):
        """
        gets an attribute from box definitions if available. 

        The last value given to set() takes precedence, then the value given
        to defer_create_set(), and finally the factory default value.

        Raises RuntimeError if the GUID wasn't created through the
        defer_create interface, and AttributeError if the attribute
        doesn't exist for the element type.
        """
        self._validate_guid(guid)
        self._validate_box_attribute(guid, name)
        if guid in self._set and name in self._set[guid]:
            return self._set[guid][name]
        if guid in self._create_set and name in self._create_set[guid]:
            return self._create_set[guid][name]
        # if nothing else found, returns the factory default value
        factory = self._get_factory(guid)
        return factory.box_attributes.get_attribute_value(name)

    def get_route(self, guid, index, attribute):
        """
        returns information given to defer_add_route.

        ``attribute`` is one of 'Destination', 'NetPrefix', 'NextHop' or
        'Metric'.

        Raises AttributeError if an invalid attribute is requested
            or if the indexed routing rule does not exist.

        Raises KeyError if the GUID has not been seen by
            defer_add_route
        """
        # Same order as the tuples stored by defer_add_route
        ATTRIBUTES = ['Destination', 'NetPrefix', 'NextHop', 'Metric']

        if attribute not in ATTRIBUTES:
            raise AttributeError("Attribute %r invalid for routes of %r" % (attribute, guid))

        attribute_index = ATTRIBUTES.index(attribute)

        routes = self._add_route.get(guid)
        if not routes:
            raise KeyError("GUID %r not found in %s" % (guid, self._testbed_id))

        index = int(index)
        if not (0 <= index < len(routes)):
            raise AttributeError("GUID %r at %s does not have a routing entry #%s" % (
                guid, self._testbed_id, index))

        return routes[index][attribute_index]

    def get_address(self, guid, index, attribute='Address'):
        """
        returns information given to defer_add_address

        ``attribute`` is one of 'Address', 'NetPrefix' or 'Broadcast'.

        Raises AttributeError if an invalid attribute is requested
            or if the indexed address does not exist.

        Raises KeyError if the GUID has not been seen by
            defer_add_address
        """
        # Same order as the tuples stored by defer_add_address
        ATTRIBUTES = ['Address', 'NetPrefix', 'Broadcast']

        if attribute not in ATTRIBUTES:
            raise AttributeError("Attribute %r invalid for addresses of %r" % (attribute, guid))

        attribute_index = ATTRIBUTES.index(attribute)

        addresses = self._add_address.get(guid)
        if not addresses:
            raise KeyError("GUID %r not found in %s" % (guid, self._testbed_id))

        index = int(index)
        if not (0 <= index < len(addresses)):
            raise AttributeError("GUID %r at %s does not have an address #%s" % (
                guid, self._testbed_id, index))

        return addresses[index][attribute_index]

    def get_attribute_list(self, guid, filter_flags = None, exclude = False):
        """Returns the box attribute list of the factory of ``guid``."""
        factory = self._get_factory(guid)
        return factory.box_attributes.get_attribute_list(filter_flags, exclude)

    def get_factory_id(self, guid):
        """Returns the factory id of the element ``guid``."""
        factory = self._get_factory(guid)
        return factory.factory_id

    def start(self, time = TIME_NOW):
        """Runs the start functions in metadata start order and updates the status."""
        self._do_in_factory_order(
            'start_function',
            self._metadata.start_order )
        self._status = TS.STATUS_STARTED

    #action: NotImplementedError

    def stop(self, time = TIME_NOW):
        """Runs the stop functions in reverse start order and updates the status."""
        self._do_in_factory_order(
            'stop_function',
            reversed(self._metadata.start_order) )
        self._status = TS.STATUS_STOPPED

    def status(self, guid = None):
        """
        Returns the testbed status when no guid is given; otherwise the
        result of the factory's status function for that element, or
        AS.STATUS_UNDETERMINED if the factory declares none.
        """
        if not guid:
            return self._status
        self._validate_guid(guid)
        factory = self._get_factory(guid)
        status_function = factory.status_function
        if status_function:
            return status_function(self, guid)
        return AS.STATUS_UNDETERMINED

    def trace(self, guid, trace_id, attribute='value'):
        """
        Returns trace information for ``trace_id`` of element ``guid``.

        attribute 'value' returns the contents of the trace file, 'path'
        returns the trace file path, and anything else returns None.
        """
        if attribute == 'value':
            fd = open("%s" % self.trace_filepath(guid, trace_id), "r")
            # make sure the file is closed even if reading fails
            try:
                content = fd.read()
            finally:
                fd.close()
        elif attribute == 'path':
            content = self.trace_filepath(guid, trace_id)
        else:
            content = None
        return content

    def traces_info(self):
        """
        Returns, for every recorded trace, a dictionary with the deployment
        "host" and "user" attributes and the trace "filepath", indexed as
        ``traces_info[guid][trace_id]``.
        """
        traces_info = dict()
        host = self._attributes.get_attribute_value("deployment_host")
        user = self._attributes.get_attribute_value("deployment_user")
        for guid, trace_list in self._add_trace.items():
            traces_info[guid] = dict()
            for trace_id in trace_list:
                filepath = self.trace(guid, trace_id, attribute = "path")
                traces_info[guid][trace_id] = dict(
                    host = host,
                    user = user,
                    filepath = filepath)
        return traces_info

    def trace_filepath(self, guid, trace_id):
        """
        Return a trace's file path, for TestbedController's default 
        implementation of trace()
        """
        raise NotImplementedError

    #shutdown: NotImplementedError

    def get_connected(self, guid, connector_type_name, 
            other_connector_type_name):
        """searchs the connected elements for the specific connector_type_name 
        pair"""
        # connections recorded for that specific connector of guid, if any
        connections = self._connect.get(guid, dict()).get(connector_type_name)
        if connections is None:
            return []
        return [otr_guid
                for otr_guid, otr_connector_type_name in connections.items()
                if otr_connector_type_name == other_connector_type_name]

    def _get_connection_count(self, guid, connection_type_name):
        """
        Returns how many connections (local plus cross-testbed) are recorded
        for the given connector of ``guid``.
        """
        count = 0
        cross_count = 0
        if guid in self._connect and connection_type_name in \
                self._connect[guid]:
            count = len(self._connect[guid][connection_type_name])
        if guid in self._cross_connect and connection_type_name in \
                self._cross_connect[guid]:
            # defer_cross_connect keeps a single cross-connection tuple per
            # connector, so it counts as one connection (not as the number
            # of fields in the tuple).
            cross_count = 1
        return count + cross_count

    def _get_traces(self, guid):
        """Returns the trace names recorded for ``guid`` (empty if none)."""
        return [] if guid not in self._add_trace else self._add_trace[guid]

    def _get_parameters(self, guid):
        """Returns the creation-time attribute values of ``guid`` (empty if none)."""
        return dict() if guid not in self._create_set else \
                self._create_set[guid]

    def _get_factory(self, guid):
        """Returns the factory of ``guid``; raises KeyError if it is unknown."""
        factory_id = self._create[guid]
        return self._factories[factory_id]

    def _get_factory_id(self, guid):
        """ Returns the factory ID of the (perhaps not yet) created object """
        return self._create.get(guid, None)

    def _validate_guid(self, guid):
        """Raises RuntimeError unless ``guid`` was registered by defer_create."""
        if not guid in self._create:
            raise RuntimeError("Element guid %s doesn't exist" % (guid,))

    def _validate_not_guid(self, guid):
        """Raises AttributeError if ``guid`` is already registered."""
        if guid in self._create:
            raise AttributeError("Cannot add elements with the same guid: %s" %
                    (guid,))

    def _validate_factory_id(self, factory_id):
        """Raises AttributeError if ``factory_id`` is not a known factory."""
        if factory_id not in self._factories:
            raise AttributeError("Invalid element type %s for testbed version %s" %
                    (factory_id, self._testbed_version))

    def _validate_testbed_attribute(self, name):
        """Raises AttributeError if the testbed has no attribute ``name``."""
        if not self._attributes.has_attribute(name):
            raise AttributeError("Invalid testbed attribute %s for testbed" % \
                    name)

    def _validate_testbed_value(self, name, value):
        """Raises AttributeError if ``value`` is invalid for testbed attribute ``name``."""
        if not self._attributes.is_attribute_value_valid(name, value):
            raise AttributeError("Invalid value %s for testbed attribute %s" % \
                (value, name))

    def _validate_box_attribute(self, guid, name):
        """Raises AttributeError if the box of ``guid`` has no attribute ``name``."""
        factory = self._get_factory(guid)
        if not factory.box_attributes.has_attribute(name):
            raise AttributeError("Invalid attribute %s for element type %s" %
                    (name, factory.factory_id))

    def _validate_box_value(self, guid, name, value):
        """Raises AttributeError if ``value`` is invalid for box attribute ``name``."""
        factory = self._get_factory(guid)
        if not factory.box_attributes.is_attribute_value_valid(name, value):
            raise AttributeError("Invalid value %s for attribute %s" % \
                (value, name))

    def _validate_factory_attribute(self, guid, name):
        """Raises AttributeError if the factory of ``guid`` has no attribute ``name``."""
        factory = self._get_factory(guid)
        if not factory.has_attribute(name):
            raise AttributeError("Invalid attribute %s for element type %s" %
                    (name, factory.factory_id))

    def _validate_factory_value(self, guid, name, value):
        """Raises AttributeError if ``value`` is invalid for factory attribute ``name``."""
        factory = self._get_factory(guid)
        if not factory.is_attribute_value_valid(name, value):
            raise AttributeError("Invalid value %s for attribute %s" % \
                (value, name))

    def _validate_trace(self, guid, trace_name):
        """Raises RuntimeError if the factory of ``guid`` has no trace ``trace_name``."""
        factory = self._get_factory(guid)
        if not trace_name in factory.traces_list:
            raise RuntimeError("Element type '%s' has no trace '%s'" %
                    (factory.factory_id, trace_name))

    def _validate_allow_addresses(self, guid):
        """
        Raises RuntimeError if the element type doesn't support addresses, or
        if the element already holds "maxAddresses" addresses.
        """
        factory = self._get_factory(guid)
        if not factory.allow_addresses:
            raise RuntimeError("Element type '%s' doesn't support addresses" %
                    factory.factory_id)
        attr_name = "maxAddresses"
        # a creation-time value overrides the factory default
        if guid in self._create_set and attr_name in self._create_set[guid]:
            max_addresses = self._create_set[guid][attr_name]
        else:
            max_addresses = factory.box_attributes.get_attribute_value(attr_name)
        if guid in self._add_address:
            count_addresses = len(self._add_address[guid])
            if max_addresses == count_addresses:
                raise RuntimeError(
                    "Element guid %s of type '%s' can't accept more addresses" %
                        (guid, factory.factory_id))

    def _validate_allow_routes(self, guid):
        """Raises RuntimeError if the element type doesn't support routes."""
        factory = self._get_factory(guid)
        if not factory.allow_routes:
            raise RuntimeError("Element type '%s' doesn't support routes" %
                    factory.factory_id)

    def _validate_connection(self, guid1, connector_type_name1, guid2, 
            connector_type_name2, cross = False):
        """
        Raises AttributeError if guid1 is connected to itself, or if the
        connector of guid1 already holds its maximum number of connections.
        A connection that is already recorded is accepted silently.
        """
        # can't connect with self
        if guid1 == guid2:
            raise AttributeError("Can't connect guid %s to self" % \
                (guid1,))
        # the connection is already done, so ignore
        connected = self.get_connected(guid1, connector_type_name1, 
                connector_type_name2)
        if guid2 in connected:
            return
        count1 = self._get_connection_count(guid1, connector_type_name1)
        factory1 = self._get_factory(guid1)
        connector_type1 = factory1.connector_type(connector_type_name1)
        if count1 == connector_type1.max:
            raise AttributeError("Connector %s is full for guid %s" % \
                (connector_type_name1, guid1))

    def _validate_modify_box_value(self, guid, name):
        """
        Raises AttributeError when the testbed status is past
        TS.STATUS_STARTED and the attribute is flagged as execution
        read-only or execution immutable.
        """
        factory = self._get_factory(guid)
        if self._status > TS.STATUS_STARTED and \
                (factory.box_attributes.is_attribute_exec_read_only(name) or \
                factory.box_attributes.is_attribute_exec_immutable(name)):
            raise AttributeError("Attribute %s can only be modified during experiment design" % name)
636