tun_connect fix: forgot to pass queueclass to tun_fwd (oops)
[nepi.git] / src / nepi / testbeds / netns / execute.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 from constants import TESTBED_ID, TESTBED_VERSION
5 from nepi.core import testbed_impl
6 from nepi.util.constants import TIME_NOW
7 import os
8 import fcntl
9 import threading
10
11 class TestbedController(testbed_impl.TestbedController):
12     from nepi.util.tunchannel_impl import TunChannel
13     
14     class HostLock(object):
15         # This class is used as a lock to prevent concurrency issues with more
16         # than one instance of netns running in the same machine. Both in 
17         # different processes or different threads.
18         taken = False
19         processcond = threading.Condition()
20         
21         def __init__(self, lockfile):
22             processcond = self.__class__.processcond
23             
24             processcond.acquire()
25             try:
26                 # It's not reentrant
27                 while self.__class__.taken:
28                     processcond.wait()
29                 self.__class__.taken = True
30             finally:
31                 processcond.release()
32             
33             self.lockfile = lockfile
34             fcntl.flock(self.lockfile, fcntl.LOCK_EX)
35         
36         def __del__(self):
37             processcond = self.__class__.processcond
38             
39             processcond.acquire()
40             try:
41                 assert self.__class__.taken, "HostLock unlocked without being locked!"
42
43                 fcntl.flock(self.lockfile, fcntl.LOCK_UN)
44                 
45                 # It's not reentrant
46                 self.__class__.taken = False
47                 processcond.notify()
48             finally:
49                 processcond.release()
50     
51     def __init__(self):
52         super(TestbedController, self).__init__(TESTBED_ID, TESTBED_VERSION)
53         self._netns = None
54         self._home_directory = None
55         self._traces = dict()
56         self._netns_lock = open("/tmp/nepi-netns-lock","a")
57     
58     def _lock(self):
59         return self.HostLock(self._netns_lock)
60
61     @property
62     def home_directory(self):
63         return self._home_directory
64
65     @property
66     def netns(self):
67         return self._netns
68
69     def do_setup(self):
70         self._home_directory = self._attributes.\
71             get_attribute_value("homeDirectory")
72         # create home...
73         home = os.path.normpath(self.home_directory)
74         if not os.path.exists(home):
75             os.makedirs(home, 0755)
76
77         self._netns = self._load_netns_module()
78         super(TestbedController, self).do_setup()
79     
80     def do_create(self):
81         lock = self._lock()
82         super(TestbedController, self).do_create()    
83
84     def set(self, guid, name, value, time = TIME_NOW):
85         super(TestbedController, self).set(guid, name, value, time)
86         # TODO: take on account schedule time for the task 
87         factory_id = self._create[guid]
88         factory = self._factories[factory_id]
89         if factory.box_attributes.is_attribute_metadata(name):
90             return
91         element = self._elements.get(guid)
92         if element:
93             setattr(element, name, value)
94
95     def get(self, guid, name, time = TIME_NOW):
96         value = super(TestbedController, self).get(guid, name, time)
97         # TODO: take on account schedule time for the task
98         factory_id = self._create[guid]
99         factory = self._factories[factory_id]
100         if factory.box_attributes.is_attribute_metadata(name):
101             return value
102         element = self._elements.get(guid)
103         try:
104             return getattr(element, name)
105         except (KeyError, AttributeError):
106             return value
107
108     def action(self, time, guid, action):
109         raise NotImplementedError
110
111     def shutdown(self):
112         for guid, traces in self._traces.iteritems():
113             for trace_id, (trace, filename) in traces.iteritems():
114                 if hasattr(trace, "close"):
115                     trace.close()
116         for guid, element in self._elements.iteritems():
117             if isinstance(element, self.TunChannel):
118                 element.Cleanup()
119             else:
120                 factory_id = self._create[guid]
121                 if factory_id == "Node":
122                     element.destroy()
123         self._elements.clear()
124
125     def trace_filepath(self, guid, trace_id, filename = None):
126         if not filename:
127             (trace, filename) = self._traces[guid][trace_id]
128         return os.path.join(self.home_directory, filename)
129
130     def follow_trace(self, guid, trace_id, trace, filename):
131         if not guid in self._traces:
132             self._traces[guid] = dict()
133         self._traces[guid][trace_id] = (trace, filename)
134
135     def _load_netns_module(self):
136         # TODO: Do something with the configuration!!!
137         import sys
138         __import__("netns")
139         netns_mod = sys.modules["netns"]
140         # enable debug
141         enable_debug = self._attributes.get_attribute_value("enableDebug")
142         if enable_debug:
143             netns_mod.environ.set_log_level(netns_mod.environ.LOG_DEBUG)
144         return netns_mod
145