Design + Execution first netns prototype working
[nepi.git] / src / nepi / testbeds / netns / execute.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 from constants import TESTBED_ID
5 from nepi.core import execute
6 from nepi.core.attributes import Attribute
7 from nepi.core.metadata import Metadata
8 from nepi.util import validation
9 from nepi.util.constants import AF_INET, AF_INET6
10
class TestbedConfiguration(execute.TestbedConfiguration):
    """Testbed configuration for netns that registers "EnableDebug"."""

    def __init__(self):
        super(TestbedConfiguration, self).__init__()
        # Positional arguments handed to add_attribute unchanged; the
        # attribute is declared as Attribute.BOOL and is associated with
        # validation.is_bool.
        debug_attribute = (
            "EnableDebug",
            "Enable netns debug output",
            Attribute.BOOL,
            False,
            None,
            None,
            False,
            validation.is_bool,
        )
        self.add_attribute(*debug_attribute)
16
class TestbedInstance(execute.TestbedInstance):
    """Netns testbed instance.

    The public declaration methods (create, create_set, connect,
    cross_connect, add_trace, add_address, add_route) only validate their
    arguments and record them in per-guid dictionaries.  The do_* methods
    later walk those records and apply them through the factories obtained
    from the testbed metadata.
    """

    def __init__(self, testbed_version, configuration):
        """Set up the empty bookkeeping, load factories and the netns module.

        testbed_version -- version handed to Metadata and used in messages.
        configuration -- object queried for "EnableDebug" through
            get_attribute_value when the netns module is loaded.
        """
        # NOTE(review): the base class initializer is not invoked here --
        # confirm that this is intended.
        self._configuration = configuration
        self._testbed_id = TESTBED_ID
        self._testbed_version = testbed_version
        self._factories = dict()
        # guid -> element, expected to be filled in by the factories'
        # create functions (do_create reads it right after calling them)
        self._elements = dict()
        # guid -> factory_id
        self._create = dict()
        # guid -> {attribute name: value}
        self._set = dict()
        # guid -> {connector type name: {other guid: other connector name}}
        self._connect = dict()
        # guid -> {connector type name: (cross_guid, cross_testbed_id,
        #                                cross_factory_id,
        #                                cross_connector_type_name)}
        self._cross_connect = dict()
        # guid -> [trace_id, ...]
        self._add_trace = dict()
        # guid -> [(family, address, netprefix, broadcast), ...]
        self._add_address = dict()
        # guid -> [(family, destination, netprefix, nexthop, interface), ...]
        self._add_route = dict()

        self._metadata = Metadata(self._testbed_id, self._testbed_version)
        for factory in self._metadata.build_execute_factories():
            self._factories[factory.factory_id] = factory

        self._netns = self._load_netns_module(configuration)

    @property
    def elements(self):
        """Dictionary mapping guid to the element stored for it."""
        return self._elements

    @property
    def netns(self):
        """The netns module returned by _load_netns_module."""
        return self._netns

    def _registered_factory_id(self, guid):
        """Return the factory id recorded for guid.

        Raises RuntimeError when guid was never passed to create().
        """
        if guid not in self._create:
            raise RuntimeError("Element guid %d doesn't exist" % guid)
        return self._create[guid]

    def create(self, guid, factory_id):
        """Record that element guid must be built by factory factory_id.

        Raises RuntimeError when factory_id is unknown or guid is already
        registered.
        """
        if factory_id not in self._factories:
            raise RuntimeError("Invalid element type %s for Netns version %s" %
                    (factory_id, self._testbed_version))
        if guid in self._create:
            raise RuntimeError("Cannot add elements with the same guid: %d" %
                    guid)
        self._create[guid] = factory_id

    def create_set(self, guid, name, value):
        """Record the value of attribute name for element guid.

        Raises RuntimeError when guid is unknown or its factory has no
        attribute called name.
        """
        factory_id = self._registered_factory_id(guid)
        factory = self._factories[factory_id]
        if not factory.has_attribute(name):
            raise RuntimeError("Invalid attribute %s for element type %s" %
                    (name, factory_id))
        # The value is also pushed to the factory itself, not only recorded
        # for the element.
        factory.set_attribute_value(name, value)
        self._set.setdefault(guid, dict())[name] = value

    def connect(self, guid1, connector_type_name1, guid2,
            connector_type_name2):
        """Record a connection between two elements, in both directions.

        A KeyError is raised when either guid was not registered with
        create().
        """
        factory_id1 = self._create[guid1]
        factory_id2 = self._create[guid2]
        count = self._get_connection_count(guid1, connector_type_name1)
        factory1 = self._factories[factory_id1]
        connector_type = factory1.connector_type(connector_type_name1)
        # NOTE(review): the result of can_connect is not inspected here --
        # presumably it raises on an invalid connection; confirm.
        connector_type.can_connect(self._testbed_id, factory_id2,
                connector_type_name2, count)
        # Store the connection and its mirror so that either end can look
        # up its peers.
        forward = self._connect.setdefault(guid1, dict())
        forward.setdefault(connector_type_name1, dict())[guid2] = \
                connector_type_name2
        backward = self._connect.setdefault(guid2, dict())
        backward.setdefault(connector_type_name2, dict())[guid1] = \
                connector_type_name1

    def cross_connect(self, guid, connector_type_name, cross_guid,
            cross_testbed_id, cross_factory_id, cross_connector_type_name):
        """Record a connection from guid to an element of another testbed.

        One cross connection is kept per (guid, connector_type_name); a new
        call for the same pair replaces the previous record.  A KeyError is
        raised when guid was not registered with create().
        """
        factory_id = self._create[guid]
        count = self._get_connection_count(guid, connector_type_name)
        factory = self._factories[factory_id]
        connector_type = factory.connector_type(connector_type_name)
        # NOTE(review): the result of can_connect is not inspected here --
        # presumably it raises on an invalid connection; confirm.
        connector_type.can_connect(cross_testbed_id, cross_factory_id,
                cross_connector_type_name, count, must_cross=True)
        # The membership test must be made against _cross_connect (the
        # original tested _connect, losing or failing to create entries).
        self._cross_connect.setdefault(guid, dict())[connector_type_name] = \
                (cross_guid, cross_testbed_id, cross_factory_id,
                        cross_connector_type_name)

    def add_trace(self, guid, trace_id):
        """Record that trace trace_id is enabled for element guid.

        Raises RuntimeError when guid is unknown or the trace is not
        supported by the element's factory.
        """
        factory_id = self._registered_factory_id(guid)
        factory = self._factories[factory_id]
        # NOTE(review): assumes the factory exposes its supported trace ids
        # as `traces`; the original referenced the undefined name
        # `factory_traces` and always failed with NameError.
        if trace_id not in factory.traces:
            raise RuntimeError("Element type %s doesn't support trace %s" %
                    (factory_id, trace_id))
        self._add_trace.setdefault(guid, list()).append(trace_id)

    def add_address(self, guid, family, address, netprefix, broadcast):
        """Record an address to be configured on element guid.

        Raises RuntimeError when guid is unknown, when the element type
        does not allow addresses, or when the number of recorded addresses
        already equals the factory's "MaxAddresses" value.
        """
        factory_id = self._registered_factory_id(guid)
        factory = self._factories[factory_id]
        if not factory.allow_addresses:
            raise RuntimeError("Element type %s doesn't support addresses" %
                    factory_id)
        max_addresses = factory.get_attribute_value("MaxAddresses")
        recorded = self._add_address.setdefault(guid, list())
        # An empty list can only hit the limit when max_addresses is 0;
        # the original accepted the first address unconditionally.
        if recorded and max_addresses == len(recorded):
            raise RuntimeError(
                    "Element guid %d of type %s can't accept more addresses" %
                    (guid, factory_id))
        recorded.append((family, address, netprefix, broadcast))

    # Backward-compatible alias for the original, misspelled method name.
    add_adddress = add_address

    def add_route(self, guid, family, destination, netprefix, nexthop,
            interface):
        """Record a route to be configured on element guid.

        Raises RuntimeError when guid is unknown or the element type does
        not allow routes.
        """
        factory_id = self._registered_factory_id(guid)
        factory = self._factories[factory_id]
        if not factory.allow_routes:
            raise RuntimeError("Element type %s doesn't support routes" %
                    factory_id)
        self._add_route.setdefault(guid, list()).append(
                (family, destination, netprefix, nexthop, interface))

    def do_create(self):
        """Build every registered element, following the metadata order."""
        # Group the guids by factory so they can be created in the order
        # given by the metadata.
        guids = dict()
        for guid, factory_id in self._create.iteritems():
            guids.setdefault(factory_id, list()).append(guid)
        for factory_id in self._metadata.factories_order:
            if factory_id not in guids:
                continue
            factory = self._factories[factory_id]
            for guid in guids[factory_id]:
                parameters = self._set.get(guid, dict())
                factory.create_function(self, guid, parameters)
                # The create function is expected to have stored the new
                # element under its guid.
                element = self._elements[guid]
                if element:
                    for name, value in parameters.iteritems():
                        setattr(element, name, value)

    def do_connect(self):
        """Run the connection code for every recorded connection."""
        for guid1, connections in self._connect.iteritems():
            element1 = self._elements[guid1]
            factory_id1 = self._create[guid1]
            factory1 = self._factories[factory_id1]
            for connector_type_name1, connections2 in connections.iteritems():
                connector_type1 = factory1.connector_type(connector_type_name1)
                for guid2, connector_type_name2 in connections2.iteritems():
                    element2 = self._elements[guid2]
                    factory_id2 = self._create[guid2]
                    # Connections are executed in a "From -> To" direction
                    # only.  The "To -> From" (mirror) entry of every pair
                    # is skipped because no connection code is returned
                    # for it.
                    code_to_connect = connector_type1.code_to_connect(
                            self._testbed_id, factory_id2,
                            connector_type_name2)
                    if code_to_connect:
                        code_to_connect(element1, element2)

    def do_configure(self):
        """Apply the recorded addresses and routes to their elements."""
        # TODO: add traces!
        # add addresses
        for guid, addresses in self._add_address.iteritems():
            element = self._elements[guid]
            for family, address, netprefix, broadcast in addresses:
                # Only AF_INET addresses are applied; entries of any other
                # family are skipped and broadcast is not used.
                if family == AF_INET:
                    element.add_v4_address(address, netprefix)
        # add routes
        for guid, routes in self._add_route.iteritems():
            element = self._elements[guid]
            for route in routes:
                # TODO: family and interface not used!!!!!
                (family, destination, netprefix, nexthop, interface) = route
                element.add_route(prefix=destination, prefix_len=netprefix,
                        nexthop=nexthop)

    def do_cross_connect(self):
        """Run the connection code for every recorded cross connection."""
        for guid, cross_connections in self._cross_connect.iteritems():
            element = self._elements[guid]
            factory_id = self._create[guid]
            factory = self._factories[factory_id]
            for connector_type_name, cross_connection in \
                    cross_connections.iteritems():
                connector_type = factory.connector_type(connector_type_name)
                (cross_guid, cross_testbed_id, cross_factory_id,
                        cross_connector_type_name) = cross_connection
                # Same three-argument lookup that do_connect performs.
                code_to_connect = connector_type.code_to_connect(
                        cross_testbed_id, cross_factory_id,
                        cross_connector_type_name)
                if code_to_connect:
                    code_to_connect(element, cross_guid)

    def set(self, time, guid, name, value):
        """Set attribute name of the element stored for guid to value."""
        # TODO: take on account schedule time for the task
        element = self._elements[guid]
        setattr(element, name, value)

    def get(self, time, guid, name):
        """Return attribute name of the element stored for guid."""
        # TODO: take on account schedule time for the task
        element = self._elements[guid]
        return getattr(element, name)

    def start(self, time = "0s"):
        """Call the start function of every element whose factory has one."""
        for guid, factory_id in self._create.iteritems():
            start_function = self._factories[factory_id].start_function
            if start_function:
                traces = self._add_trace.get(guid, [])
                parameters = self._set.get(guid, dict())
                start_function(self, guid, parameters, traces)

    def action(self, time, guid, action):
        """Not implemented for this testbed."""
        raise NotImplementedError

    def stop(self, time = "0s"):
        """Call the stop function of every element whose factory has one."""
        for guid, factory_id in self._create.iteritems():
            stop_function = self._factories[factory_id].stop_function
            if stop_function:
                traces = self._add_trace.get(guid, [])
                stop_function(self, guid, traces)

    def status(self, guid):
        """Return the status reported for element guid.

        The value comes from the factory's status function; None is
        returned when the factory does not define one.  Raises RuntimeError
        when guid is unknown.
        """
        factory_id = self._registered_factory_id(guid)
        status_function = self._factories[factory_id].status_function
        if status_function:
            return status_function(self, guid)
        return None

    def trace(self, guid, trace_id):
        """Not implemented for this testbed."""
        raise NotImplementedError

    def shutdown(self):
        """Call destroy() on every stored element."""
        for element in self._elements.values():
            element.destroy()

    def get_connected(self, guid, connector_type_name,
            other_connector_type_name):
        """Return the guids connected to guid through the given connectors.

        Only peers attached to connector_type_name on this side and to
        other_connector_type_name on their side are returned; the result is
        an empty list when nothing matches.
        """
        # all connections for the specific connector of guid
        connections = self._connect.get(guid, dict()).get(
                connector_type_name, dict())
        return [otr_guid for otr_guid, otr_connector_type_name
                in connections.iteritems()
                if otr_connector_type_name == other_connector_type_name]

    def _load_netns_module(self, configuration):
        """Import the netns module and return it.

        Debug logging is switched on in the module when the configuration
        reports a true "EnableDebug" value.
        """
        # TODO: Do something with the configuration!!!
        # For a top-level module name __import__ returns the module itself.
        netns_mod = __import__("netns")
        # enable debug
        if configuration.get_attribute_value("EnableDebug"):
            netns_mod.environ.set_log_level(netns_mod.environ.LOG_DEBUG)
        return netns_mod

    def _get_connection_count(self, guid, connection_type_name):
        """Return how many connections are recorded for a connector.

        Regular connections are counted one per peer.  A cross connection
        counts as one, because a single record is kept per connector.
        """
        count = len(self._connect.get(guid, dict()).get(
                connection_type_name, dict()))
        cross_count = 0
        if connection_type_name in self._cross_connect.get(guid, dict()):
            cross_count = 1
        return count + cross_count
307
308