test support added
[nepi.git] / src / nepi / core / execute_impl.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 from nepi.core import execute
5 from nepi.core.attributes import Attribute
6 from nepi.core.metadata import Metadata
7 from nepi.util import validation
8 from nepi.util.constants import AF_INET, AF_INET6
9
# Time specifier meaning "now": the default `time` argument of
# TestbedInstance.start()/stop(), and the time TestbedInstance.do_create()
# passes to set() when applying the attributes recorded by create_set().
# NOTE(review): presumably a zero-second offset in the time format parsed by
# the subclasses' set() implementation - confirm.
TIME_NOW = "0s"
11
class TestbedInstance(execute.TestbedInstance):
    """Generic, metadata-driven testbed instance.

    The design-time calls (``create``, ``create_set``, ``connect``,
    ``cross_connect``, ``add_trace``, ``add_address``, ``add_route``) only
    validate and record the requested configuration. ``do_create``,
    ``do_connect``, ``do_cross_connect``, ``start``, ``stop`` and ``status``
    later replay that configuration through the factories built from the
    testbed ``Metadata``.

    ``do_configure``, ``set``, ``get``, ``action``, ``trace`` and
    ``shutdown`` raise NotImplementedError here and must be provided by a
    subclass.
    """

    def __init__(self, testbed_id, testbed_version, configuration):
        super(TestbedInstance, self).__init__(testbed_id, testbed_version,
                configuration)
        # factory_id -> factory, built from the testbed metadata below
        self._factories = dict()
        # guid -> element object.
        # NOTE(review): nothing in this class fills it; presumably the
        # factories' create functions do (they receive this instance) - confirm
        self._elements = dict()
        # guid -> factory_id
        self._create = dict()
        # guid -> {attribute name: value}
        self._set = dict()
        # guid -> {connector_type_name: {other guid: other connector_type_name}}
        # Both directions of every connection are stored.
        self._connect = dict()
        # guid -> {connector_type_name: (cross_guid, cross_testbed_id,
        #                                cross_factory_id,
        #                                cross_connector_type_name)}
        # At most one cross connection is kept per connector.
        self._cross_connect = dict()
        # guid -> [trace_id, ...]
        self._add_trace = dict()
        # guid -> [(family, address, netprefix, broadcast), ...]
        self._add_address = dict()
        # guid -> [(destination, netprefix, nexthop), ...]
        self._add_route = dict()

        self._metadata = Metadata(self._testbed_id, self._testbed_version)
        for factory in self._metadata.build_execute_factories():
            self._factories[factory.factory_id] = factory

    @property
    def elements(self):
        """Mapping of guid -> element object used by the ``do_*`` methods."""
        return self._elements

    def _get_factory(self, guid):
        """Return the factory of element ``guid``.

        Raises RuntimeError if ``guid`` was never registered with ``create``.
        """
        if guid not in self._create:
            raise RuntimeError("Element guid %d doesn't exist" % guid)
        return self._factories[self._create[guid]]

    def create(self, guid, factory_id):
        """Register element ``guid`` of type ``factory_id`` for creation.

        Raises RuntimeError for an unknown ``factory_id`` or a duplicate
        ``guid``.
        """
        if factory_id not in self._factories:
            raise RuntimeError("Invalid element type %s for testbed %s "
                    "version %s" % (factory_id, self._testbed_id,
                        self._testbed_version))
        if guid in self._create:
            raise RuntimeError("Cannot add elements with the same guid: %d" %
                    guid)
        self._create[guid] = factory_id

    def create_set(self, guid, name, value):
        """Record the creation-time value of attribute ``name`` for ``guid``.

        Raises RuntimeError if ``guid`` does not exist or its element type
        has no attribute ``name``.
        """
        factory = self._get_factory(guid)
        if not factory.has_attribute(name):
            raise RuntimeError("Invalid attribute %s for element type %s" %
                    (name, factory.factory_id))
        # NOTE(review): this writes the value onto the factory, which is
        # shared by every element of the same type; presumably it is done to
        # run the attribute's validation - confirm.
        factory.set_attribute_value(name, value)
        self._set.setdefault(guid, dict())[name] = value

    def connect(self, guid1, connector_type_name1, guid2,
            connector_type_name2):
        """Record a connection between two elements of this testbed.

        The connection is stored in both directions so that it can be looked
        up from either element (see ``get_connected``).
        """
        factory_id1 = self._create[guid1]
        factory_id2 = self._create[guid2]
        count = self._get_connection_count(guid1, connector_type_name1)
        factory1 = self._factories[factory_id1]
        connector_type = factory1.connector_type(connector_type_name1)
        # NOTE(review): the value returned by can_connect is ignored, so a
        # rejected connection is still recorded; confirm whether can_connect
        # raises on its own or whether its result should be enforced here.
        connector_type.can_connect(self._testbed_id, factory_id2,
                connector_type_name2, count)
        self._connect.setdefault(guid1, dict()).setdefault(
                connector_type_name1, dict())[guid2] = connector_type_name2
        self._connect.setdefault(guid2, dict()).setdefault(
                connector_type_name2, dict())[guid1] = connector_type_name1

    def cross_connect(self, guid, connector_type_name, cross_guid,
            cross_testbed_id, cross_factory_id, cross_connector_type_name):
        """Record a connection from ``guid`` to an element of another testbed.

        Only one cross connection is kept per connector: a later call for the
        same ``guid``/``connector_type_name`` replaces the earlier one.
        """
        factory_id = self._create[guid]
        count = self._get_connection_count(guid, connector_type_name)
        factory = self._factories[factory_id]
        connector_type = factory.connector_type(connector_type_name)
        # NOTE(review): as in connect(), the result of can_connect is ignored.
        connector_type.can_connect(cross_testbed_id, cross_factory_id,
                cross_connector_type_name, count, must_cross = True)
        # Look up the guid in the cross-connection table (not in the plain
        # connection table), so existing cross connections on other
        # connectors of this element are preserved.
        self._cross_connect.setdefault(guid, dict())[connector_type_name] = \
                (cross_guid, cross_testbed_id, cross_factory_id,
                        cross_connector_type_name)

    def add_trace(self, guid, trace_id):
        """Enable trace ``trace_id`` on element ``guid``.

        Raises RuntimeError if ``guid`` does not exist or its element type
        does not declare ``trace_id``.
        """
        factory = self._get_factory(guid)
        if not trace_id in factory.traces:
            raise RuntimeError("Element type '%s' has no trace '%s'" %
                    (factory.factory_id, trace_id))
        self._add_trace.setdefault(guid, list()).append(trace_id)

    def add_address(self, guid, family, address, netprefix, broadcast):
        """Record an address for element ``guid``.

        Raises RuntimeError if ``guid`` does not exist, its element type does
        not allow addresses, or it already holds the number of addresses
        given by the factory's "MaxAddresses" attribute.
        """
        factory = self._get_factory(guid)
        if not factory.allow_addresses:
            raise RuntimeError("Element type '%s' doesn't support addresses" %
                    factory.factory_id)
        max_addresses = factory.get_attribute_value("MaxAddresses")
        count_addresses = len(self._add_address.get(guid, ()))
        # ">=" (not "==") so the limit also holds for the first address and
        # can never be skipped past; a missing (None) limit means unlimited.
        if max_addresses is not None and count_addresses >= max_addresses:
            raise RuntimeError("Element guid %d of type '%s' can't accept "
                    "more addresses" % (guid, factory.factory_id))
        self._add_address.setdefault(guid, list()).append(
                (family, address, netprefix, broadcast))

    # Backwards-compatible alias for the original, misspelled method name.
    add_adddress = add_address

    def add_route(self, guid, destination, netprefix, nexthop):
        """Record a route for element ``guid``.

        Raises RuntimeError if ``guid`` does not exist or its element type
        does not allow routes.
        """
        factory = self._get_factory(guid)
        if not factory.allow_routes:
            raise RuntimeError("Element type '%s' doesn't support routes" %
                    factory.factory_id)
        self._add_route.setdefault(guid, list()).append(
                (destination, netprefix, nexthop))

    def do_create(self):
        """Create every registered element, grouped by element type.

        Types are processed in the order of the metadata's
        ``factories_order``; elements whose type is not listed there are not
        created. After each creation, the attributes recorded by
        ``create_set`` are applied through ``set`` at TIME_NOW.
        """
        guids = dict()
        for guid, factory_id in self._create.iteritems():
            guids.setdefault(factory_id, list()).append(guid)
        for factory_id in self._metadata.factories_order:
            if factory_id not in guids:
                continue
            factory = self._factories[factory_id]
            for guid in guids[factory_id]:
                parameters = self._set.get(guid, dict())
                factory.create_function(self, guid, parameters)
                for name, value in parameters.iteritems():
                    self.set(TIME_NOW, guid, name, value)

    def do_connect(self):
        """Run the connection code for every recorded intra-testbed link."""
        for guid1, connections in self._connect.iteritems():
            element1 = self._elements[guid1]
            factory1 = self._factories[self._create[guid1]]
            for connector_type_name1, connections2 in connections.iteritems():
                connector_type1 = factory1.connector_type(connector_type_name1)
                for guid2, connector_type_name2 in connections2.iteritems():
                    element2 = self._elements[guid2]
                    factory_id2 = self._create[guid2]
                    # Both halves of every stored connection pair are visited
                    # here. Connections are meant to be executed in the
                    # "From -> To" direction only; presumably code_to_connect
                    # returns nothing for the "To -> From" mirror - confirm.
                    code_to_connect = connector_type1.code_to_connect(
                            self._testbed_id, factory_id2,
                            connector_type_name2)
                    if code_to_connect:
                        code_to_connect(element1, element2)

    def do_configure(self):
        raise NotImplementedError

    def do_cross_connect(self):
        """Run the connection code for every recorded cross-testbed link.

        The code is called with the local element and the remote element's
        guid.
        """
        for guid, cross_connections in self._cross_connect.iteritems():
            element = self._elements[guid]
            factory = self._factories[self._create[guid]]
            for connector_type_name, cross_connection in \
                    cross_connections.iteritems():
                connector_type = factory.connector_type(connector_type_name)
                (cross_guid, cross_testbed_id, cross_factory_id,
                        cross_connector_type_name) = cross_connection
                # Same (testbed_id, factory_id, connector_type_name) lookup
                # as in do_connect.
                code_to_connect = connector_type.code_to_connect(
                        cross_testbed_id, cross_factory_id,
                        cross_connector_type_name)
                if code_to_connect:
                    code_to_connect(element, cross_guid)

    def set(self, time, guid, name, value):
        raise NotImplementedError

    def get(self, time, guid, name):
        raise NotImplementedError

    def start(self, time = TIME_NOW):
        """Call each element type's start function, if it has one.

        The start function receives this instance, the guid, the attributes
        recorded by ``create_set`` and the enabled traces. ``time`` is
        currently not used.
        """
        for guid, factory_id in self._create.iteritems():
            start_function = self._factories[factory_id].start_function
            if start_function:
                traces = self._add_trace.get(guid, [])
                parameters = self._set.get(guid, dict())
                start_function(self, guid, parameters, traces)

    def action(self, time, guid, action):
        raise NotImplementedError

    def stop(self, time = TIME_NOW):
        """Call each element type's stop function, if it has one.

        The stop function receives this instance, the guid and the enabled
        traces. ``time`` is currently not used.
        """
        for guid, factory_id in self._create.iteritems():
            stop_function = self._factories[factory_id].stop_function
            if stop_function:
                traces = self._add_trace.get(guid, [])
                stop_function(self, guid, traces)

    def status(self, guid):
        """Return the status of element ``guid``.

        The value comes from the element type's status function. None is
        returned when the type has no status function. Raises RuntimeError
        if ``guid`` does not exist.
        """
        status_function = self._get_factory(guid).status_function
        if status_function:
            return status_function(self, guid)
        return None

    def trace(self, guid, trace_id):
        raise NotImplementedError

    def shutdown(self):
        raise NotImplementedError

    def get_connected(self, guid, connector_type_name,
            other_connector_type_name):
        """Return the guids connected to ``guid``.

        Only connections made through ``connector_type_name`` on ``guid``
        whose other end uses ``other_connector_type_name`` are returned. An
        empty list is returned when there are none.
        """
        connections = self._connect.get(guid, {}).get(connector_type_name, {})
        return [otr_guid for otr_guid, otr_connector_type_name
                in connections.iteritems()
                if otr_connector_type_name == other_connector_type_name]

    def _get_connection_count(self, guid, connection_type_name):
        """Return how many connections, plain and cross-testbed, currently
        use connector ``connection_type_name`` of element ``guid``."""
        count = len(self._connect.get(guid, {}).get(connection_type_name, {}))
        # A cross connection is stored as a single tuple per connector, so it
        # counts as one connection (not len(tuple)).
        if connection_type_name in self._cross_connect.get(guid, {}):
            count += 1
        return count
260
261