4d7d6e9691bb6d2bf123189ad6f13d8805fe3e01
[nepi.git] / src / nepi / testbeds / ns3 / execute.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 from constants import TESTBED_ID
5 from nepi.core import testbed_impl
6 from nepi.util.constants import AF_INET, AF_INET6
7 import os
8
class TestbedInstance(testbed_impl.TestbedInstance):
    """Execution-side testbed instance for the ns-3 simulator.

    Loads the ns-3 Python bindings during setup, applies recorded addresses
    and routes to the created elements during configuration, and pushes
    attribute changes onto live ns-3 objects.
    """

    def __init__(self, testbed_version):
        super(TestbedInstance, self).__init__(TESTBED_ID, testbed_version)
        # ns-3 bindings module; loaded in do_setup().
        self._ns3 = None
        # Directory trace filenames are resolved against; set in do_setup().
        self._home_directory = None
        # guid -> {trace_id: filename}, populated by follow_trace().
        self._traces = dict()

    @property
    def home_directory(self):
        """Directory against which trace filenames are resolved."""
        return self._home_directory

    @property
    def ns3(self):
        """The loaded ns-3 bindings module (None until do_setup runs)."""
        return self._ns3

    def do_setup(self):
        """Read the ``homeDirectory`` attribute and load the ns-3 bindings."""
        self._home_directory = self._attributes.get_attribute_value(
            "homeDirectory")
        self._ns3 = self._load_ns3_module()

    def do_configure(self):
        """Apply recorded addresses and routes to their elements.

        Only AF_INET addresses are applied; entries of any other family are
        skipped.
        """
        # TODO: add traces!
        # configure addresses
        for guid, addresses in self._add_address.iteritems():
            element = self._elements[guid]
            for (family, address, netprefix, _broadcast) in addresses:
                # NOTE(review): AF_INET6 entries are silently ignored here.
                if family == AF_INET:
                    element.add_v4_address(address, netprefix)
        # configure routes
        for guid, routes in self._add_route.iteritems():
            element = self._elements[guid]
            for (destination, netprefix, nexthop) in routes:
                element.add_route(prefix=destination, prefix_len=netprefix,
                                  nexthop=nexthop)

    def set(self, time, guid, name, value):
        """Record an attribute change and apply it to the live ns-3 object."""
        super(TestbedInstance, self).set(time, guid, name, value)
        # NOTE(review): assumes the base class maps guid -> factory_id in
        # self._create (previously misspelled as self._crerate).
        factory_id = self._create[guid]
        element = self._elements[guid]
        # NOTE(review): the helper is called by a distinct name in case the
        # base class keeps an instance attribute called `_set`, which would
        # shadow a method of that name -- confirm in testbed_impl.
        self._set_attribute_value(element, factory_id, name, value)

    def get(self, time, guid, name):
        """Not implemented for this testbed yet."""
        # TODO: take into account the schedule time for the task, e.g.
        # getattr(self._elements[guid], name)
        raise NotImplementedError

    def action(self, time, guid, action):
        """Not implemented for this testbed yet."""
        raise NotImplementedError

    def trace(self, guid, trace_id):
        """Return the full contents of the trace file for (guid, trace_id)."""
        with open(self.trace_filename(guid, trace_id), "r") as fd:
            return fd.read()

    def shutdown(self):
        """Call destroy() on every created element."""
        for element in self._elements.values():
            element.destroy()

    def trace_filename(self, guid, trace_id):
        """Return the path of a followed trace, joined onto home_directory.

        Raises KeyError if follow_trace() was never called for the pair.
        """
        # TODO: Need to be defined inside a home!!!! with an experiment id_code
        # Read from the same mapping that follow_trace() writes.
        filename = self._traces[guid][trace_id]
        return os.path.join(self.home_directory, filename)

    def follow_trace(self, guid, trace_id, filename):
        """Remember that `trace_id` of element `guid` is written to `filename`."""
        self._traces.setdefault(guid, dict())[trace_id] = filename

    def _set_attribute_value(self, element, factory_id, name, value):
        """Set ns-3 attribute `name` on `element` from a Python value.

        The value is stringified (booleans as lowercase "true"/"false") and
        deserialized through the checker that the TypeId registered for
        `factory_id` declares for that attribute.

        Raises:
            RuntimeError: if `factory_id` declares no attribute called `name`.
        """
        type_id = self.ns3.TypeId.LookupByName(factory_id)
        index = None
        for idx in range(type_id.GetAttributeN()):
            if name == type_id.GetAttributeName(idx):
                index = idx
                break
        if index is None:
            raise RuntimeError("Type %s has no ns-3 attribute %r"
                               % (factory_id, name))
        checker = type_id.GetAttributeChecker(index)
        ns3_value = checker.Create()
        # Check for bool *before* converting to str; otherwise the
        # isinstance test can never succeed.
        if isinstance(value, bool):
            str_value = str(value).lower()
        else:
            str_value = str(value)
        ns3_value.DeserializeFromString(str_value, checker)
        element.Set(name, ns3_value)

    # Backward-compatible alias for the previous helper name.
    _set = _set_attribute_value

    def _load_ns3_module(self):
        """Import the ns-3 Python bindings and bind global ns-3 settings.

        Testbed attributes used:
          ns3Bindings   -- directory searched first for the ``ns3`` module
          ns3Library    -- shared library preloaded with RTLD_GLOBAL, if set
          SimulatorImplementationType, ChecksumEnabled -- bound as ns-3
                           GlobalValues when truthy

        Returns the loaded ``ns3`` module object.
        """
        import ctypes
        import imp
        import sys

        bindings = self._attributes.get_attribute_value("ns3Bindings")
        libfile = self._attributes.get_attribute_value("ns3Library")
        simu_impl_type = self._attributes.get_attribute_value(
            "SimulatorImplementationType")
        checksum = self._attributes.get_attribute_value("ChecksumEnabled")

        if libfile:
            # RTLD_GLOBAL exposes the library's symbols to libraries loaded
            # afterwards (presumably the bindings extension).
            ctypes.CDLL(libfile, ctypes.RTLD_GLOBAL)

        path = [os.path.dirname(__file__)] + sys.path
        if bindings:
            path = [bindings] + path

        module_info = imp.find_module('ns3', path)
        try:
            mod = imp.load_module('ns3', *module_info)
        finally:
            # find_module may return an open file; the caller must close it.
            if module_info[0] is not None:
                module_info[0].close()

        if simu_impl_type:
            value = mod.StringValue(simu_impl_type)
            mod.GlobalValue.Bind("SimulatorImplementationType", value)
        if checksum:
            value = mod.BooleanValue(checksum)
            mod.GlobalValue.Bind("ChecksumEnabled", value)
        return mod
125