Attribute flags changed to bit flag system
[nepi.git] / src / nepi / core / testbed_impl.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 from nepi.core import execute
5 from nepi.core.metadata import Metadata
6 from nepi.util import validation
7 from nepi.util.constants import AF_INET, AF_INET6, STATUS_UNDETERMINED
8
# Time value used for operations that should apply right away (presumably
# "0 seconds from now"). It is the default time of start()/stop() and the
# time under which do_create() logs creation-time attribute values.
TIME_NOW = "0s"
10
class TestbedInstance(execute.TestbedInstance):
    """Generic testbed instance that records experiment construction
    instructions (element creation, attribute sets, connections, traces,
    addresses, routes) and replays them through the execute factories
    described by the testbed Metadata.

    do_setup, do_configure, get, action, trace and shutdown are left to
    subclasses (they raise NotImplementedError here).
    """

    def __init__(self, testbed_id, testbed_version):
        super(TestbedInstance, self).__init__(testbed_id, testbed_version)
        # set to True by start(); afterwards set() rejects design-only
        # attributes
        self._started = False
        # testbed attributes for validation
        self._attributes = None
        # element factories for validation, keyed by factory_id
        self._factories = dict()

        # experiment construction instructions
        # guid -> factory_id
        self._create = dict()
        # guid -> {attribute name: value}
        self._create_set = dict()
        # guid -> {connector_type_name: {other guid: other connector name}}
        self._connect = dict()
        # guid -> {connector_type_name: (cross_guid, cross_testbed_id,
        #                                cross_factory_id,
        #                                cross_connector_type_name)}
        self._cross_connect = dict()
        # guid -> [trace_id, ...]
        self._add_trace = dict()
        # guid -> [(family, address, netprefix, broadcast), ...]
        self._add_address = dict()
        # guid -> [(destination, netprefix, nexthop), ...]
        self._add_route = dict()
        # testbed attribute name -> value
        self._configure = dict()

        # log of set operations: guid -> {time: {attribute name: value}}
        self._set = dict()
        # log of actions
        self._actions = dict()

        # testbed element instances keyed by guid; read by do_connect() and
        # do_cross_connect() (presumably filled by the factories' create
        # functions -- confirm in subclasses)
        self._elements = dict()

        self._metadata = Metadata(self._testbed_id, self._testbed_version)
        for factory in self._metadata.build_execute_factories():
            self._factories[factory.factory_id] = factory
        self._attributes = self._metadata.testbed_attributes()

    @property
    def guids(self):
        """Guids of every element registered through create()."""
        return self._create.keys()

    @property
    def elements(self):
        """Mapping guid -> element instance."""
        return self._elements

    def configure(self, name, value):
        """Set testbed attribute ``name`` to ``value``.

        Raises RuntimeError if the testbed has no such attribute.
        """
        if not self._attributes.has_attribute(name):
            raise RuntimeError("Invalid attribute %s for testbed" % name)
        # set_attribute_value is relied upon to validate the value
        self._attributes.set_attribute_value(name, value)
        self._configure[name] = value

    def create(self, guid, factory_id):
        """Register element ``guid`` to be built by factory ``factory_id``.

        Raises RuntimeError if the factory is unknown for this testbed or
        if ``guid`` is already registered.
        """
        if factory_id not in self._factories:
            raise RuntimeError(
                    "Invalid element type %s for testbed %s version %s" %
                    (factory_id, self._testbed_id, self._testbed_version))
        if guid in self._create:
            raise RuntimeError("Cannot add elements with the same guid: %d" %
                    guid)
        self._create[guid] = factory_id

    def create_set(self, guid, name, value):
        """Record creation-time attribute ``name`` = ``value`` for ``guid``.

        Raises RuntimeError if ``guid`` is unknown or its element type has
        no such attribute.
        """
        factory_id, factory = self._get_element_factory(guid)
        if not factory.has_attribute(name):
            raise RuntimeError("Invalid attribute %s for element type %s" %
                    (name, factory_id))
        # NOTE(review): the value is stored on the factory, which is shared
        # by all elements of this type (presumably only for validation).
        factory.set_attribute_value(name, value)
        self._create_set.setdefault(guid, dict())[name] = value

    def connect(self, guid1, connector_type_name1, guid2,
            connector_type_name2):
        """Record a connection between two elements of this testbed.

        The connection is stored in both directions so it can be looked up
        from either end. Raises RuntimeError if either guid is unknown.
        """
        _, factory1 = self._get_element_factory(guid1)
        factory_id2, _ = self._get_element_factory(guid2)
        count = self._get_connection_count(guid1, connector_type_name1)
        connector_type = factory1.connector_type(connector_type_name1)
        # NOTE(review): the return value of can_connect is not checked;
        # presumably it raises on invalid connections -- confirm.
        connector_type.can_connect(self._testbed_id, factory_id2,
                connector_type_name2, count)
        self._connect.setdefault(guid1, dict()).setdefault(
                connector_type_name1, dict())[guid2] = connector_type_name2
        self._connect.setdefault(guid2, dict()).setdefault(
                connector_type_name2, dict())[guid1] = connector_type_name1

    def cross_connect(self, guid, connector_type_name, cross_guid,
            cross_testbed_id, cross_factory_id, cross_connector_type_name):
        """Record a connection from local element ``guid`` to element
        ``cross_guid`` of another testbed.

        Only one cross connection is kept per local connector; a later call
        for the same connector replaces the earlier one. Raises
        RuntimeError if ``guid`` is unknown.
        """
        _, factory = self._get_element_factory(guid)
        count = self._get_connection_count(guid, connector_type_name)
        connector_type = factory.connector_type(connector_type_name)
        connector_type.can_connect(cross_testbed_id, cross_factory_id,
                cross_connector_type_name, count, must_cross = True)
        self._cross_connect.setdefault(guid, dict())[connector_type_name] = \
                (cross_guid, cross_testbed_id, cross_factory_id,
                        cross_connector_type_name)

    def add_trace(self, guid, trace_id):
        """Enable trace ``trace_id`` on element ``guid``.

        Raises RuntimeError if ``guid`` is unknown or its element type does
        not provide that trace.
        """
        factory_id, factory = self._get_element_factory(guid)
        if not trace_id in factory.traces:
            raise RuntimeError("Element type '%s' has no trace '%s'" %
                    (factory_id, trace_id))
        self._add_trace.setdefault(guid, list()).append(trace_id)

    # NOTE: the historical "adddress" spelling is kept; renaming the method
    # would break existing callers.
    def add_adddress(self, guid, family, address, netprefix, broadcast):
        """Record an address for element ``guid``.

        Raises RuntimeError if ``guid`` is unknown, its element type does
        not accept addresses, or it already holds ``MaxAddresses``
        addresses.
        """
        factory_id, factory = self._get_element_factory(guid)
        if not factory.allow_addresses:
            raise RuntimeError("Element type '%s' doesn't support addresses" %
                    factory_id)
        max_addresses = factory.get_attribute_value("MaxAddresses")
        count_addresses = len(self._add_address.get(guid, ()))
        # a None limit means "unbounded" (the old equality test never
        # matched None either)
        if max_addresses is not None and count_addresses >= max_addresses:
            raise RuntimeError("Element guid %d of type '%s' can't accept "
                    "more addresses" % (guid, factory_id))
        self._add_address.setdefault(guid, list()).append(
                (family, address, netprefix, broadcast))

    def add_route(self, guid, destination, netprefix, nexthop):
        """Record a route for element ``guid``.

        Raises RuntimeError if ``guid`` is unknown or its element type does
        not accept routes.
        """
        factory_id, factory = self._get_element_factory(guid)
        if not factory.allow_routes:
            raise RuntimeError("Element type '%s' doesn't support routes" %
                    factory_id)
        self._add_route.setdefault(guid, list()).append(
                (destination, netprefix, nexthop))

    def do_setup(self):
        raise NotImplementedError

    def do_create(self):
        """Create every registered element, following the factory order
        given by the metadata, and log the creation-time attribute values
        as set operations at TIME_NOW."""
        # group guids (elements) by factory_id
        guids = dict()
        for guid, factory_id in self._create.iteritems():
            guids.setdefault(factory_id, list()).append(guid)
        # create elements following the factory_id order
        # NOTE(review): factories missing from factories_order are skipped.
        for factory_id in self._metadata.factories_order:
            # omit the factories that have no element to create
            if factory_id not in guids:
                continue
            factory = self._factories[factory_id]
            for guid in guids[factory_id]:
                parameters = self._create_set.get(guid, dict())
                factory.create_function(self, guid, parameters)
                for name, value in parameters.iteritems():
                    self.set(TIME_NOW, guid, name, value)

    def do_connect(self):
        """Run the connection code for every recorded local connection."""
        for guid1, connections in self._connect.iteritems():
            element1 = self._elements[guid1]
            factory1 = self._factories[self._create[guid1]]
            for connector_type_name1, connections2 in connections.iteritems():
                connector_type1 = factory1.connector_type(connector_type_name1)
                for guid2, connector_type_name2 in connections2.iteritems():
                    element2 = self._elements[guid2]
                    factory_id2 = self._create[guid2]
                    # Connections are executed in a "From -> To" direction only
                    # This explicitly ignores the "To -> From" (mirror)
                    # connections of every connection pair.
                    code_to_connect = connector_type1.code_to_connect(
                            self._testbed_id, factory_id2,
                            connector_type_name2)
                    if code_to_connect:
                        code_to_connect(element1, element2)

    def do_configure(self):
        raise NotImplementedError

    def do_cross_connect(self):
        """Run the connection code for every recorded cross connection,
        passing the local element and the remote guid."""
        for guid, cross_connections in self._cross_connect.iteritems():
            element = self._elements[guid]
            factory = self._factories[self._create[guid]]
            for connector_type_name, cross_connection in \
                    cross_connections.iteritems():
                connector_type = factory.connector_type(connector_type_name)
                # 4-tuple as stored by cross_connect()
                (cross_guid, cross_testbed_id, cross_factory_id,
                        cross_connector_type_name) = cross_connection
                # same argument list as the local lookup in do_connect()
                code_to_connect = connector_type.code_to_connect(
                        cross_testbed_id, cross_factory_id,
                        cross_connector_type_name)
                if code_to_connect:
                    code_to_connect(element, cross_guid)

    def set(self, time, guid, name, value):
        """Set attribute ``name`` of element ``guid`` to ``value`` and log
        the operation under ``time``.

        Raises RuntimeError if ``guid`` is unknown, the attribute does not
        exist, or it is design-only and the experiment has started.
        """
        factory_id, factory = self._get_element_factory(guid)
        if not factory.has_attribute(name):
            raise RuntimeError("Invalid attribute %s for element type %s" %
                    (name, factory_id))
        if self._started and factory.is_attribute_design_only(name):
            raise RuntimeError(
                    "Attribute %s can only be modified during experiment design"
                    % name)
        factory.set_attribute_value(name, value)
        self._set.setdefault(guid, dict()).setdefault(
                time, dict())[name] = value

    def get(self, time, guid, name):
        raise NotImplementedError

    def start(self, time = TIME_NOW):
        """Call the start function of every element whose factory has one,
        then mark the instance as started. ``time`` is currently unused."""
        for guid, factory_id in self._create.iteritems():
            factory = self._factories[factory_id]
            start_function = factory.start_function
            if start_function:
                traces = self._add_trace.get(guid, [])
                parameters = self._create_set.get(guid, dict())
                start_function(self, guid, parameters, traces)
        self._started = True

    def action(self, time, guid, action):
        raise NotImplementedError

    def stop(self, time = TIME_NOW):
        """Call the stop function of every element whose factory has one.
        ``time`` is currently unused."""
        for guid, factory_id in self._create.iteritems():
            factory = self._factories[factory_id]
            stop_function = factory.stop_function
            if stop_function:
                traces = self._add_trace.get(guid, [])
                stop_function(self, guid, traces)

    def status(self, guid):
        """Return the status of element ``guid``.

        Returns STATUS_UNDETERMINED when the element's factory defines no
        status function. Raises RuntimeError if ``guid`` is unknown.
        """
        _, factory = self._get_element_factory(guid)
        status_function = factory.status_function
        if status_function:
            return status_function(self, guid)
        return STATUS_UNDETERMINED

    def trace(self, guid, trace_id):
        raise NotImplementedError

    def shutdown(self):
        raise NotImplementedError

    def get_connected(self, guid, connector_type_name,
            other_connector_type_name):
        """Search the elements connected to ``guid`` for the specific
        connector pair.

        Returns the guids connected through ``guid``'s connector
        ``connector_type_name`` whose peer connector is
        ``other_connector_type_name``, or an empty list if there are none.
        """
        connections = self._connect.get(guid, dict()).get(
                connector_type_name, dict())
        return [otr_guid for otr_guid, otr_connector_type_name
                in connections.iteritems()
                if otr_connector_type_name == other_connector_type_name]

    def _get_element_factory(self, guid):
        """Return the (factory_id, factory) pair of registered element
        ``guid``; raise RuntimeError if ``guid`` was never created."""
        if not guid in self._create:
            raise RuntimeError("Element guid %d doesn't exist" % guid)
        factory_id = self._create[guid]
        return factory_id, self._factories[factory_id]

    def _get_connection_count(self, guid, connection_type_name):
        """Return how many connections (local plus cross-testbed) ``guid``
        already has on connector ``connection_type_name``."""
        count = len(self._connect.get(guid, dict()).get(
                connection_type_name, dict()))
        # cross_connect() stores one tuple per connector, so a present entry
        # is a single connection (len() of the tuple would wrongly give 4)
        if connection_type_name in self._cross_connect.get(guid, dict()):
            count += 1
        return count
306