test support added
authorAlina Quereilhac <alina.quereilhac@inria.fr>
Fri, 4 Mar 2011 17:27:45 +0000 (18:27 +0100)
committerAlina Quereilhac <alina.quereilhac@inria.fr>
Fri, 4 Mar 2011 17:27:45 +0000 (18:27 +0100)
16 files changed:
Makefile [new file with mode: 0644]
examples/design1.py [deleted file]
examples/execution1.py [deleted file]
setup.cfg [new file with mode: 0644]
setup.py [new file with mode: 0755]
src/nepi/core/design.py
src/nepi/core/execute.py
src/nepi/core/execute_impl.py [new file with mode: 0644]
src/nepi/testbeds/netns/execute.py
src/nepi/testbeds/netns/metadata_v01.py
src/nepi/util/environ.py [new file with mode: 0644]
src/nepi/util/parser/_xml.py
src/nepi/util/parser/base.py
test/testbeds/netns/design.py [new file with mode: 0755]
test/testbeds/netns/execute.py [new file with mode: 0755]
test/util/test_util.py [new file with mode: 0644]

diff --git a/Makefile b/Makefile
new file mode 100644 (file)
index 0000000..12570b2
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,64 @@
+SRCDIR      = $(CURDIR)/src
+TESTDIR     = $(CURDIR)/test
+TESTLIB     = $(TESTDIR)/lib
+BUILDDIR    = $(CURDIR)/build
+DISTDIR     = $(CURDIR)/dist
+
+# stupid distutils, it's broken in so many ways
+SUBBUILDDIR = $(shell python -c 'import distutils.util, sys; \
+             print "lib.%s-%s" % (distutils.util.get_platform(), \
+             sys.version[0:3])')
+PYTHON25 := $(shell python -c 'import sys; v = sys.version_info; \
+    print (1 if v[0] <= 2 and v[1] <= 5 else 0)')
+
+ifeq ($(PYTHON25),0)
+BUILDDIR := $(BUILDDIR)/$(SUBBUILDDIR)
+else
+BUILDDIR := $(BUILDDIR)/lib
+endif
+
+#PYPATH = $(BUILDDIR):$(TESTLIB):$(PYTHONPATH)
+PYPATH = "../nepi2/src:../nepi2/test/util:../netns/src"
+COVERAGE = $(or $(shell which coverage), $(shell which python-coverage), \
+          coverage)
+
+all:
+       ./setup.py build
+
+install: all
+       ./setup.py install
+
+test: all
+       retval=0; \
+              for i in `find "$(TESTDIR)" -perm -u+x -type f`; do \
+              echo $$i; \
+              TESTLIBPATH="$(TESTLIB)" PYTHONPATH="$(PYPATH)" $$i || retval=$$?; \
+              done; exit $$retval
+
+coverage: all
+       rm -f .coverage
+       for i in `find "$(TESTDIR)" -perm -u+x -type f`; do \
+               set -e; \
+               TESTLIBPATH="$(TESTLIB)" PYTHONPATH="$(PYPATH)" $(COVERAGE) -x $$i; \
+               done
+       $(COVERAGE) -c
+       $(COVERAGE) -r -m `find "$(BUILDDIR)" -name \\*.py -type f`
+       rm -f .coverage
+
+clean:
+       ./setup.py clean
+       rm -f `find -name \*.pyc` .coverage *.pcap
+
+distclean: clean
+       rm -rf "$(DISTDIR)"
+
+MANIFEST:
+       find . -path ./.hg\* -prune -o -path ./build -prune -o \
+               -name \*.pyc -prune -o -name \*.swp -prune -o \
+               -name MANIFEST -prune -o -type f -print | \
+               sed 's#^\./##' | sort > MANIFEST
+
+dist: MANIFEST
+       ./setup.py sdist
+
+.PHONY: all clean distclean dist test coverage install MANIFEST
diff --git a/examples/design1.py b/examples/design1.py
deleted file mode 100644 (file)
index 77373bd..0000000
+++ /dev/null
@@ -1,38 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-from nepi.core.design import ExperimentDescription, FactoriesProvider
-
-exp_desc = ExperimentDescription()
-testbed_version = "01"
-testbed_id = "netns"
-netns_provider = FactoriesProvider(testbed_id, testbed_version)
-netns_desc = exp_desc.add_testbed_description(netns_provider)
-
-node1 = netns_desc.create("Node")
-node2 = netns_desc.create("Node")
-iface1 = netns_desc.create("NodeInterface")
-iface1.set_attribute_value("up", True)
-node1.connector("devs").connect(iface1.connector("node"))
-ip1 = iface1.add_address()
-ip1.set_attribute_value("Address", "10.0.0.1")
-iface2 = netns_desc.create("NodeInterface")
-iface2.set_attribute_value("up", True)
-node2.connector("devs").connect(iface2.connector("node"))
-ip2 = iface2.add_address()
-ip2.set_attribute_value("Address", "10.0.0.2")
-switch = netns_desc.create("Switch")
-switch.set_attribute_value("up", True)
-iface1.connector("switch").connect(switch.connector("devs"))
-iface2.connector("switch").connect(switch.connector("devs"))
-app = netns_desc.create("Application")
-app.set_attribute_value("command", "ping -qc10 10.0.0.2")
-app.connector("node").connect(node1.connector("apps"))
-
-xml = exp_desc.to_xml()
-exp_desc2 = ExperimentDescription()
-exp_desc2.from_xml(xml)
-xml2 = exp_desc2.to_xml()
-assert xml == xml2
-
-
diff --git a/examples/execution1.py b/examples/execution1.py
deleted file mode 100644 (file)
index b1c1e55..0000000
+++ /dev/null
@@ -1,39 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-from nepi.core.design import AF_INET
-from nepi.testbeds import netns
-
-user = "alina"
-testbed_version = "01"
-config = netns.TestbedConfiguration()
-instance = netns.TestbedInstance(testbed_version, config)
-
-instance.create(2, "Node")
-instance.create(3, "Node")
-instance.create(4, "NodeInterface")
-instance.create_set(4, "up", True)
-instance.connect(2, "devs", 4, "node")
-instance.add_adddress(4, AF_INET, "10.0.0.1", 24, None)
-instance.create(5, "NodeInterface")
-instance.create_set(5, "up", True)
-instance.connect(3, "devs", 5, "node")
-instance.add_adddress(5, AF_INET, "10.0.0.2", 24, None)
-instance.create(6, "Switch")
-instance.create_set(6, "up", True)
-instance.connect(4, "switch", 6, "devs")
-instance.connect(5, "switch", 6, "devs")
-instance.create(7, "Application")
-instance.create_set(7, "command", "ping -qc10 10.0.0.2")
-instance.create_set(7, "user", user)
-instance.connect(7, "node", 2, "apps")
-
-instance.do_create()
-instance.do_connect()
-instance.do_configure()
-instance.start()
-import time
-time.sleep(5)
-instance.stop()
-instance.shutdown()
-
diff --git a/setup.cfg b/setup.cfg
new file mode 100644 (file)
index 0000000..e1b8322
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,2 @@
+[clean]
+all = 1
diff --git a/setup.py b/setup.py
new file mode 100755 (executable)
index 0000000..38d64e0
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,21 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8
+from distutils.core import setup, Extension, Command
+
+setup(
+        name        = "nepi",
+        version     = "0.1",
+        description = "High-level abstraction for running network experiments",
+#        long_description = longdesc,
+        author      = "Alina Quereilhac and Martín Ferrari",
+        url         = "http://yans.pl.sophia.inria.fr/code/hgwebdir.cgi/nepi/",
+        license     = "GPLv2",
+        platforms   = "Linux",
+        packages    = [
+            "nepi",
+            "nepi.testbeds",
+            "nepi.testbeds.netns",
+            "nepi.core",
+            "nepi.util" ],
+        package_dir = {"": "src"}
+    )
index 52d3ddf..edd2e03 100644 (file)
@@ -206,10 +206,6 @@ class Route(AttributesMap):
                 help = "Address for the next hop", 
                 type = Attribute.STRING,
                 validation_function = address_validation)
-        self.add_attribute(name = "Interface",
-                help = "Local interface address", 
-                type = Attribute.STRING,
-                validation_function = address_validation)
 
 class Box(AttributesMap):
     def __init__(self, guid, factory, testbed_guid, container = None):
index aa8b8af..a657c53 100644 (file)
@@ -1,7 +1,8 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-from nepi.core.attributes import AttributesMap
+from nepi.core.attributes import Attribute, AttributesMap
+from nepi.util import validation
 import sys
 
 class ConnectorType(object):
@@ -141,11 +142,25 @@ class Factory(AttributesMap):
         self._traces.append(trace_id)
 
 class TestbedConfiguration(AttributesMap):
-    pass
+    def __init__(self):
+        super(TestbedConfiguration, self).__init__()
+        self.add_attribute("HomeDirectory", 
+                "Path to the local directory where traces and other files \
+                        will be stored",
+                Attribute.STRING, False, None, None, "", 
+                validation.is_string)
 
 class TestbedInstance(object):
-    def __init__(self, testbed_version, configuration):
-        pass
+    def __init__(self, testbed_id, testbed_version, configuration):
+        self._testbed_id = testbed_id
+        self._testbed_version = testbed_version
+        self._configuration = configuration
+        self._home_directory = configuration.get_attribute_value(
+                "HomeDirectory")
+
+    @property
+    def home_directory(self):
+        return self._home_directory
 
     def create(self, guid, factory_id):
         """Instructs creation of element """
@@ -177,8 +192,7 @@ class TestbedInstance(object):
     def add_adddress(self, guid, family, address, netprefix, broadcast): 
         raise NotImplementedError
 
-    def add_route(self, guid, family, destination, netprefix, nexthop, 
-            interface):
+    def add_route(self, guid, destination, netprefix, nexthop):
         raise NotImplementedError
 
     def do_configure(self):
diff --git a/src/nepi/core/execute_impl.py b/src/nepi/core/execute_impl.py
new file mode 100644 (file)
index 0000000..61e0175
--- /dev/null
@@ -0,0 +1,261 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from nepi.core import execute
+from nepi.core.attributes import Attribute
+from nepi.core.metadata import Metadata
+from nepi.util import validation
+from nepi.util.constants import AF_INET, AF_INET6
+
+TIME_NOW = "0s"
+
+class TestbedInstance(execute.TestbedInstance):
+    def __init__(self, testbed_id, testbed_version, configuration):
+        super(TestbedInstance, self).__init__(testbed_id, testbed_version,
+                configuration)
+        self._factories = dict()
+        self._elements = dict()
+        self._create = dict()
+        self._set = dict()
+        self._connect = dict()
+        self._cross_connect = dict()
+        self._add_trace = dict()
+        self._add_address = dict()
+        self._add_route = dict()        
+
+        self._metadata = Metadata(self._testbed_id, self._testbed_version)
+        for factory in self._metadata.build_execute_factories():
+            self._factories[factory.factory_id] = factory
+
+    @property
+    def elements(self):
+        return self._elements
+
+    def create(self, guid, factory_id):
+        if factory_id not in self._factories:
+            raise RuntimeError("Invalid element type %s for Netns version %s" %
+                    (factory_id, self._testbed_version))
+        if guid in self._create:
+            raise RuntimeError("Cannot add elements with the same guid: %d" %
+                    guid)
+        self._create[guid] = factory_id
+
+    def create_set(self, guid, name, value):
+        if not guid in self._create:
+            raise RuntimeError("Element guid %d doesn't exist" % guid)
+        factory_id = self._create[guid]
+        factory = self._factories[factory_id]
+        if not factory.has_attribute(name):
+            raise RuntimeError("Invalid attribute %s for element type %s" %
+                    (name, factory_id))
+        factory.set_attribute_value(name, value)
+        if guid not in self._set:
+            self._set[guid] = dict()
+        self._set[guid][name] = value
+       
+    def connect(self, guid1, connector_type_name1, guid2, 
+            connector_type_name2):
+        factory_id1 = self._create[guid1]
+        factory_id2 = self._create[guid2]
+        count = self._get_connection_count(guid1, connector_type_name1)
+        factory1 = self._factories[factory_id1]
+        connector_type = factory1.connector_type(connector_type_name1)
+        connector_type.can_connect(self._testbed_id, factory_id2, 
+                connector_type_name2, count)
+        if not guid1 in self._connect:
+            self._connect[guid1] = dict()
+        if not connector_type_name1 in self._connect[guid1]:
+             self._connect[guid1][connector_type_name1] = dict()
+        self._connect[guid1][connector_type_name1][guid2] = \
+               connector_type_name2
+        if not guid2 in self._connect:
+            self._connect[guid2] = dict()
+        if not connector_type_name2 in self._connect[guid2]:
+             self._connect[guid2][connector_type_name2] = dict()
+        self._connect[guid2][connector_type_name2][guid1] = \
+                connector_type_name1
+
+    def cross_connect(self, guid, connector_type_name, cross_guid, 
+            cross_testbed_id, cross_factory_id, cross_connector_type_name):
+        factory_id = self._create[guid]
+        count = self._get_connection_count(guid, connector_type_name)
+        factory = self._factories[factory_id]
+        connector_type = factory.connector_type(connector_type_name)
+        connector_type.can_connect(cross_testbed_id, cross_factory_id, 
+                cross_connector_type_name, count, must_cross = True)
+        if not guid in self._connect:
+            self._cross_connect[guid] = dict()
+        if not connector_type_name in self._cross_connect[guid]:
+             self._cross_connect[guid][connector_type_name] = dict()
+        self._cross_connect[guid][connector_type_name] = \
+                (cross_guid, cross_testbed_id, cross_factory_id, 
+                        cross_connector_type_name)
+
+    def add_trace(self, guid, trace_id):
+        if not guid in self._create:
+            raise RuntimeError("Element guid %d doesn't exist" % guid)
+        factory_id = self._create[guid]
+        factory = self._factories[factory_id]
+        if not trace_id in factory.traces:
+            raise RuntimeError("Element type '%s' has no trace '%s'" %
+                    (factory_id, trace_id))
+        if not guid in self._add_trace:
+            self._add_trace[guid] = list()
+        self._add_trace[guid].append(trace_id)
+
+    def add_adddress(self, guid, family, address, netprefix, broadcast):
+        if not guid in self._create:
+            raise RuntimeError("Element guid %d doesn't exist" % guid)
+        factory_id = self._create[guid]
+        factory = self._factories[factory_id]
+        if not factory.allow_addresses:
+            raise RuntimeError("Element type '%s' doesn't support addresses" %
+                    factory_id)
+        max_addresses = factory.get_attribute_value("MaxAddresses")
+        if guid in self._add_address:
+            count_addresses = len(self._add_address[guid])
+            if max_addresses == count_addresses:
+                raise RuntimeError("Element guid %d of type '%s' can't accept \
+                        more addresses" % (guid, family_id))
+        else:
+            self._add_address[guid] = list()
+        self._add_address[guid].append((family, address, netprefix, broadcast))
+
+    def add_route(self, guid, destination, netprefix, nexthop):
+        if not guid in self._create:
+            raise RuntimeError("Element guid %d doesn't exist" % guid)
+        factory_id = self._create[guid]
+        factory = self._factories[factory_id]
+        if not factory.allow_routes:
+            raise RuntimeError("Element type '%s' doesn't support routes" %
+                    factory_id)
+        if not guid in self._add_route:
+            self._add_route[guid] = list()
+        self._add_route[guid].append((destination, netprefix, nexthop)) 
+
+    def do_create(self):
+        guids = dict()
+        for guid, factory_id in self._create.iteritems():
+            if not factory_id in guids:
+               guids[factory_id] = list()
+            guids[factory_id].append(guid)
+        for factory_id in self._metadata.factories_order:
+            if factory_id not in guids:
+                continue
+            factory = self._factories[factory_id]
+            for guid in guids[factory_id]:
+                parameters = dict() if guid not in self._set else \
+                        self._set[guid]
+                factory.create_function(self, guid, parameters)
+                for name, value in parameters.iteritems():
+                    self.set(TIME_NOW, guid, name, value)
+
+    def do_connect(self):
+        for guid1, connections in self._connect.iteritems():
+            element1 = self._elements[guid1]
+            factory_id1 = self._create[guid1]
+            factory1 = self._factories[factory_id1]
+            for connector_type_name1, connections2 in connections.iteritems():
+                connector_type1 = factory1.connector_type(connector_type_name1)
+                for guid2, connector_type_name2 in connections2.iteritems():
+                    element2 = self._elements[guid2]
+                    factory_id2 = self._create[guid2]
+                    # Connections are executed in a "From -> To" direction only
+                    # This explicitly ignores the "To -> From" (mirror) 
+                    # connections of every connection pair. 
+                    code_to_connect = connector_type1.code_to_connect(
+                            self._testbed_id, factory_id2, 
+                            connector_type_name2)
+                    if code_to_connect:
+                        code_to_connect(element1, element2)
+
+    def do_configure(self):
+        raise NotImplementedError
+
+    def do_cross_connect(self):
+        for guid, cross_connections in self._cross_connect.iteritems():
+            element = self._elements[guid]
+            factory_id = self._create[guid]
+            factory = self._factories[factory_id]
+            for connector_type_name, cross_connection in \
+                    cross_connections.iteritems():
+                connector_type = factory.connector_type(connector_type_name)
+                (cross_testbed_id, cross_factory_id, 
+                        cross_connector_type_name) = cross_connection
+                code_to_connect = connector_type.code_to_connect(
+                    cross_guid, cross_testbed_id, cross_factory_id, 
+                    cross_conector_type_name)
+                if code_to_connect:
+                    code_to_connect(element, cross_guid)       
+
+    def set(self, time, guid, name, value):
+        raise NotImplementedError
+
+    def get(self, time, guid, name):
+        raise NotImplementedError
+
+    def start(self, time = TIME_NOW):
+        for guid, factory_id in self._create.iteritems():
+            factory = self._factories[factory_id]
+            start_function = factory.start_function
+            if start_function:
+                traces = [] if guid not in self._add_trace else \
+                        self._add_trace[guid]
+                parameters = dict() if guid not in self._set else \
+                        self._set[guid]
+                start_function(self, guid, parameters, traces)
+
+    def action(self, time, guid, action):
+        raise NotImplementedError
+
+    def stop(self, time = TIME_NOW):
+        for guid, factory_id in self._create.iteritems():
+            factory = self._factories[factory_id]
+            stop_function = factory.stop_function
+            if stop_function:
+                traces = [] if guid not in self._add_trace else \
+                        self._add_trace[guid]
+                stop_function(self, guid, traces)
+
+    def status(self, guid):
+          for guid, factory_id in self._create.iteritems():
+            factory = self._factories[factory_id]
+            status_function = factory.status_function
+            if status_function:
+                result = status_function(self, guid)
+
+    def trace(self, guid, trace_id):
+        raise NotImplementedError
+
+    def shutdown(self):
+        raise NotImplementedError
+
+    def get_connected(self, guid, connector_type_name, 
+            other_connector_type_name):
+        """searchs the connected elements for the specific connector_type_name 
+        pair"""
+        if guid not in self._connect:
+            return []
+        # all connections for all connectors for guid
+        all_connections = self._connect[guid]
+        if connector_type_name not in all_connections:
+            return []
+        # all connections for the specific connector
+        connections = all_connections[connector_type_name]
+        specific_connections = [otr_guid for otr_guid, otr_connector_type_name \
+                in connections.iteritems() if \
+                otr_connector_type_name == other_connector_type_name]
+        return specific_connections
+
+    def _get_connection_count(self, guid, connection_type_name):
+        count = 0
+        cross_count = 0
+        if guid in self._connect and connection_type_name in \
+                self._connect[guid]:
+            count = len(self._connect[guid][connection_type_name])
+        if guid in self._cross_connect and connection_type_name in \
+                self._cross_connect[guid]:
+            cross_count = len(self._cross_connect[guid][connection_type_name])
+        return count + cross_count
+
+
index 11c622f..ba1f28c 100644 (file)
@@ -3,10 +3,12 @@
 
 from constants import TESTBED_ID
 from nepi.core import execute
+from nepi.core import execute_impl
 from nepi.core.attributes import Attribute
 from nepi.core.metadata import Metadata
 from nepi.util import validation
 from nepi.util.constants import AF_INET, AF_INET6
+import os
 
 class TestbedConfiguration(execute.TestbedConfiguration):
     def __init__(self):
@@ -14,274 +16,66 @@ class TestbedConfiguration(execute.TestbedConfiguration):
         self.add_attribute("EnableDebug", "Enable netns debug output", 
                 Attribute.BOOL, False, None, None, False, validation.is_bool)
 
-class TestbedInstance(execute.TestbedInstance):
+class TestbedInstance(execute_impl.TestbedInstance):
     def __init__(self, testbed_version, configuration):
-        self._configuration = configuration
-        self._testbed_id = TESTBED_ID
-        self._testbed_version = testbed_version
-        self._factories = dict()
-        self._elements = dict()
-        self._create = dict()
-        self._set = dict()
-        self._connect = dict()
-        self._cross_connect = dict()
-        self._add_trace = dict()
-        self._add_address = dict()
-        self._add_route = dict()        
-
-        self._metadata = Metadata(self._testbed_id, self._testbed_version)
-        for factory in self._metadata.build_execute_factories():
-            self._factories[factory.factory_id] = factory
-
+        super(TestbedInstance, self).__init__(TESTBED_ID, testbed_version, 
+                configuration)
         self._netns = self._load_netns_module(configuration)
-
-    @property
-    def elements(self):
-        return self._elements
+        self._traces = dict()
 
     @property
     def netns(self):
         return self._netns
 
-    def create(self, guid, factory_id):
-        if factory_id not in self._factories:
-            raise RuntimeError("Invalid element type %s for Netns version %s" %
-                    (factory_id, self._testbed_version))
-        if guid in self._create:
-            raise RuntimeError("Cannot add elements with the same guid: %d" %
-                    guid)
-        self._create[guid] = factory_id
-
-    def create_set(self, guid, name, value):
-        if not guid in self._create:
-            raise RuntimeError("Element guid %d doesn't exist" % guid)
-        factory_id = self._create[guid]
-        factory = self._factories[factory_id]
-        if not factory.has_attribute(name):
-            raise RuntimeError("Invalid attribute %s for element type %s" %
-                    (name, factory_id))
-        factory.set_attribute_value(name, value)
-        if guid not in self._set:
-            self._set[guid] = dict()
-        self._set[guid][name] = value
-       
-    def connect(self, guid1, connector_type_name1, guid2, 
-            connector_type_name2):
-        factory_id1 = self._create[guid1]
-        factory_id2 = self._create[guid2]
-        count = self._get_connection_count(guid1, connector_type_name1)
-        factory1 = self._factories[factory_id1]
-        connector_type = factory1.connector_type(connector_type_name1)
-        connector_type.can_connect(self._testbed_id, factory_id2, 
-                connector_type_name2, count)
-        if not guid1 in self._connect:
-            self._connect[guid1] = dict()
-        if not connector_type_name1 in self._connect[guid1]:
-             self._connect[guid1][connector_type_name1] = dict()
-        self._connect[guid1][connector_type_name1][guid2] = \
-               connector_type_name2
-        if not guid2 in self._connect:
-            self._connect[guid2] = dict()
-        if not connector_type_name2 in self._connect[guid2]:
-             self._connect[guid2][connector_type_name2] = dict()
-        self._connect[guid2][connector_type_name2][guid1] = \
-                connector_type_name1
-
-    def cross_connect(self, guid, connector_type_name, cross_guid, 
-            cross_testbed_id, cross_factory_id, cross_connector_type_name):
-        factory_id = self._create[guid]
-        count = self._get_connection_count(guid, connector_type_name)
-        factory = self._factories[factory_id]
-        connector_type = factory.connector_type(connector_type_name)
-        connector_type.can_connect(cross_testbed_id, cross_factory_id, 
-                cross_connector_type_name, count, must_cross = True)
-        if not guid in self._connect:
-            self._cross_connect[guid] = dict()
-        if not connector_type_name in self._cross_connect[guid]:
-             self._cross_connect[guid][connector_type_name] = dict()
-        self._cross_connect[guid][connector_type_name] = \
-                (cross_guid, cross_testbed_id, cross_factory_id, 
-                        cross_connector_type_name)
-
-    def add_trace(self, guid, trace_id):
-        if not guid in self._create:
-            raise RuntimeError("Element guid %d doesn't exist" % guid)
-        factory_id = self._create[guid]
-        factory = self._factories[factory_id]
-        if not trace_id in factory_traces:
-            raise RuntimeError("Element type %s doesn't support trace %s" %
-                    (factory_id, trace_id))
-        if not guid in self._add_trace:
-            self._add_trace[guid] = list()
-        self._add_trace[guid].append(trace_id)
-
-    def add_adddress(self, guid, family, address, netprefix, broadcast):
-        if not guid in self._create:
-            raise RuntimeError("Element guid %d doesn't exist" % guid)
-        factory_id = self._create[guid]
-        factory = self._factories[factory_id]
-        if not factory.allow_addresses:
-            raise RuntimeError("Element type %s doesn't support addresses" %
-                    factory_id)
-        max_addresses = factory.get_attribute_value("MaxAddresses")
-        if guid in self._add_address:
-            count_addresses = len(self._add_address[guid])
-            if max_addresses == count_addresses:
-                raise RuntimeError("Element guid %d of type %s can't accept \
-                        more addresses" % (guid, family_id))
-        else:
-            self._add_address[guid] = list()
-        self._add_address[guid].append((family, address, netprefix, broadcast))
-
-    def add_route(self, guid, family, destination, netprefix, nexthop, 
-            interface):
-        if not guid in self._create:
-            raise RuntimeError("Element guid %d doesn't exist" % guid)
-        factory_id = self._create[guid]
-        factory = self._factories[factory_id]
-        if not factory.allow_routes:
-            raise RuntimeError("Element type %s doesn't support routes" %
-                    factory_id)
-        if not guid in self._add_route:
-            self._add_route[guid] = list()
-        self._add_route[guid].append((family, destination, netprefix, nexthop,
-            interface)) 
-
-    def do_create(self):
-        guids = dict()
-        for guid, factory_id in self._create.iteritems():
-            if not factory_id in guids:
-               guids[factory_id] = list()
-            guids[factory_id].append(guid)
-        for factory_id in self._metadata.factories_order:
-            if factory_id not in guids:
-                continue
-            factory = self._factories[factory_id]
-            for guid in guids[factory_id]:
-                parameters = dict() if guid not in self._set else \
-                        self._set[guid]
-                factory.create_function(self, guid, parameters)
-                element = self._elements[guid]
-                if element:
-                    for name, value in parameters.iteritems():
-                        setattr(element, name, value)
-
-    def do_connect(self):
-        for guid1, connections in self._connect.iteritems():
-            element1 = self._elements[guid1]
-            factory_id1 = self._create[guid1]
-            factory1 = self._factories[factory_id1]
-            for connector_type_name1, connections2 in connections.iteritems():
-                connector_type1 = factory1.connector_type(connector_type_name1)
-                for guid2, connector_type_name2 in connections2.iteritems():
-                    element2 = self._elements[guid2]
-                    factory_id2 = self._create[guid2]
-                    # Connections are executed in a "From -> To" direction only
-                    # This explicitly ignores the "To -> From" (mirror) 
-                    # connections of every connection pair. 
-                    code_to_connect = connector_type1.code_to_connect(
-                            self._testbed_id, factory_id2, 
-                            connector_type_name2)
-                    if code_to_connect:
-                        code_to_connect(element1, element2)
-
     def do_configure(self):
         # TODO: add traces!
-        # add addressess
+        # configure addressess
         for guid, addresses in self._add_address.iteritems():
             element = self._elements[guid]
             for address in addresses:
                 (family, address, netprefix, broadcast) = address
                 if family == AF_INET:
                     element.add_v4_address(address, netprefix)
-        # add routes
+        # configure routes
         for guid, routes in self._add_route.iteritems():
             element = self._elements[guid]
             for route in routes:
-                # TODO: family and interface not used!!!!!
-                (family, destination, netprefix, nexthop, interfaces) = routes
+                (destination, netprefix, nexthop) = route
                 element.add_route(prefix = destination, prefix_len = netprefix,
                         nexthop = nexthop)
 
-    def do_cross_connect(self):
-        for guid, cross_connections in self._cross_connect.iteritems():
-            element = self._elements[guid]
-            factory_id = self._create[guid]
-            factory = self._factories[factory_id]
-            for connector_type_name, cross_connection in \
-                    cross_connections.iteritems():
-                connector_type = factory.connector_type(connector_type_name)
-                (cross_testbed_id, cross_factory_id, 
-                        cross_connector_type_name) = cross_connection
-                code_to_connect = connector_type.code_to_connect(
-                    cross_guid, cross_testbed_id, cross_factory_id, 
-                    cross_conector_type_name)
-                if code_to_connect:
-                    code_to_connect(element, cross_guid)       
-
     def set(self, time, guid, name, value):
         # TODO: take on account schedule time for the task 
         element = self._elements[guid]
-        setattr(element, name, value)
+        if element:
+            setattr(element, name, value)
 
     def get(self, time, guid, name):
-        # TODO: take on account schedule time for the task 
+        # TODO: take on account schedule time for the task
         element = self._elements[guid]
         return getattr(element, name)
 
-    def start(self, time = "0s"):
-        for guid, factory_id in self._create.iteritems():
-            factory = self._factories[factory_id]
-            start_function = factory.start_function
-            if start_function:
-                traces = [] if guid not in self._add_trace else \
-                        self._add_trace[guid]
-                parameters = dict() if guid not in self._set else \
-                        self._set[guid]
-                start_function(self, guid, parameters, traces)
-
     def action(self, time, guid, action):
         raise NotImplementedError
 
-    def stop(self, time = "0s"):
-        for guid, factory_id in self._create.iteritems():
-            factory = self._factories[factory_id]
-            stop_function = factory.stop_function
-            if stop_function:
-                traces = [] if guid not in self._add_trace else \
-                        self._add_trace[guid]
-                stop_function(self, guid, traces)
-
-    def status(self, guid):
-          for guid, factory_id in self._create.iteritems():
-            factory = self._factories[factory_id]
-            status_function = factory.status_function
-            if status_function:
-                result = status_function(self, guid)
-
     def trace(self, guid, trace_id):
-        raise NotImplementedError
+        f = open(self.trace_filename(guid, trace_id), "r")
+        content = f.read()
+        f.close()
+        return content
 
     def shutdown(self):
+        for trace in self._traces.values():
+            trace.close()
         for element in self._elements.values():
             element.destroy()
 
-    def get_connected(self, guid, connector_type_name, 
-            other_connector_type_name):
-        """searchs the connected elements for the specific connector_type_name 
-        pair"""
-        if guid not in self._connect:
-            return []
-        # all connections for all connectors for guid
-        all_connections = self._connect[guid]
-        if connector_type_name not in all_connections:
-            return []
-        # all connections for the specific connector
-        connections = all_connections[connector_type_name]
-        specific_connections = [otr_guid for otr_guid, otr_connector_type_name \
-                in connections.iteritems() if \
-                otr_connector_type_name == other_connector_type_name]
-        return specific_connections
+    def trace_filename(self, guid, trace_id):
+        # TODO: Need to be defined inside a home!!!! with and experiment id_code
+        return os.path.join(self.home_directory, "%d_%s" % (guid, trace_id))
+
+    def follow_trace(self, trace_id, trace):
+        self._traces[trace_id] = trace
 
     def _load_netns_module(self, configuration):
         # TODO: Do something with the configuration!!!
@@ -294,15 +88,3 @@ class TestbedInstance(execute.TestbedInstance):
             netns_mod.environ.set_log_level(netns_mod.environ.LOG_DEBUG)
         return netns_mod
 
-    def _get_connection_count(self, guid, connection_type_name):
-        count = 0
-        cross_count = 0
-        if guid in self._connect and connection_type_name in \
-                self._connect[guid]:
-            count = len(self._connect[guid][connection_type_name])
-        if guid in self._cross_connect and connection_type_name in \
-                self._cross_connect[guid]:
-            cross_count = len(self._cross_connect[guid][connection_type_name])
-        return count + cross_count
-
-
index e371277..e5c27e9 100644 (file)
@@ -8,7 +8,7 @@ from nepi.util import validation
 from nepi.util.constants import AF_INET
 
 NODE = "Node"
-P2PIFACE = "P2PInterface"
+P2PIFACE = "P2PNodeInterface"
 TAPIFACE = "TapNodeInterface"
 NODEIFACE = "NodeInterface"
 SWITCH = "Switch"
@@ -99,11 +99,13 @@ def start_application(testbed_instance, guid, parameters, traces):
     command = parameters["command"]
     stdout = stderr = None
     if "stdout" in traces:
-        filename = "%d_%s" % (guid, "stdout")
+        filename = testbed_instance.trace_filename(guid, "stdout")
         stdout = open(filename, "wb")
+        testbed_instance.follow_trace("stdout", stdout)
     if "stderr" in traces:
-        filename = "%d_%s" % (guid, "stderr")
+        filename = testbed_instance.trace_filename(guid, "stderr")
         stderr = open(filename, "wb")
+        testbed_instance.follow_trace("stderr", stderr)
 
     node_guid = testbed_instance.get_connected(guid, "node", "apps")
     if len(node_guid) == 0:
@@ -357,11 +359,11 @@ attributes = dict({
 
 traces = dict({
     "stdout": dict({
-                "name": "StdoutTrace",
+                "name": "stdout",
                 "help": "Standard output stream"
               }),
     "stderr": dict({
-                "name": "StderrTrace",
+                "name": "stderr",
                 "help": "Application standard error",
         }) 
     })
diff --git a/src/nepi/util/environ.py b/src/nepi/util/environ.py
new file mode 100644 (file)
index 0000000..ced1239
--- /dev/null
@@ -0,0 +1,59 @@
+# vim:ts=4:sw=4:et:ai:sts=4
+
+import os, subprocess
+
+__all__ =  ["python", "ssh_path"]
+__all__ += ["rsh", "tcpdump_path", "sshd_path"]
+__all__ += ["execute", "backticks"]
+
+def find_bin(name, extra_path = None):
+    search = []
+    if "PATH" in os.environ:
+        search += os.environ["PATH"].split(":")
+    for pref in ("/", "/usr/", "/usr/local/"):
+        for d in ("bin", "sbin"):
+            search.append(pref + d)
+    if extra_path:
+        search += extra_path
+
+    for d in search:
+            try:
+                os.stat(d + "/" + name)
+                return d + "/" + name
+            except OSError, e:
+                if e.errno != os.errno.ENOENT:
+                    raise
+    return None
+
+def find_bin_or_die(name, extra_path = None):
+    r = find_bin(name)
+    if not r:
+        raise RuntimeError(("Cannot find `%s' command, impossible to " +
+                "continue.") % name)
+    return r
+
+ssh_path = find_bin_or_die("ssh")
+python_path = find_bin_or_die("python")
+
+# Optional tools
+rsh_path = find_bin("rsh")
+tcpdump_path = find_bin("tcpdump")
+sshd_path = find_bin("sshd")
+
+def execute(cmd):
+    # FIXME: create a global debug variable
+    #print "[pid %d]" % os.getpid(), " ".join(cmd)
+    null = open("/dev/null", "r+")
+    p = subprocess.Popen(cmd, stdout = null, stderr = subprocess.PIPE)
+    out, err = p.communicate()
+    if p.returncode != 0:
+        raise RuntimeError("Error executing `%s': %s" % (" ".join(cmd), err))
+
+def backticks(cmd):
+    p = subprocess.Popen(cmd, stdout = subprocess.PIPE,
+            stderr = subprocess.PIPE)
+    out, err = p.communicate()
+    if p.returncode != 0:
+        raise RuntimeError("Error executing `%s': %s" % (" ".join(cmd), err))
+    return out
+
index 925f392..291f891 100644 (file)
@@ -22,7 +22,6 @@ class XmlExperimentParser(ExperimentParser):
                 self.box_data_to_xml(doc, elements_tags, guid, data)
         doc.appendChild(exp_tag)
         xml = doc.toprettyxml(indent="    ", encoding="UTF-8")
-        print xml
         return xml
 
     def testbed_data_to_xml(self, doc, parent_tag, guid, data):
@@ -115,7 +114,7 @@ class XmlExperimentParser(ExperimentParser):
 
     def routes_data_to_xml(self, doc, parent_tag, guid, data):
         routes_tag = doc.createElement("routes") 
-        for (family, destination, netprefix, nexthop, interface) \
+        for (family, destination, netprefix, nexthop) \
                 in data.get_route_data(guid):
             route_tag = doc.createElement("route") 
             routes_tag.appendChild(route_tag)
@@ -123,7 +122,6 @@ class XmlExperimentParser(ExperimentParser):
             route_tag.setAttribute("Destination", str(destination))
             route_tag.setAttribute("NetPrefix", str(netprefix))
             route_tag.setAttribute("NextHop", str(nexthop))
-            route_tag.setAttribute("Interface", str(interface))
         if routes_tag.hasChildNodes():
             parent_tag.appendChild(routes_tag)
 
@@ -269,9 +267,8 @@ class XmlExperimentParser(ExperimentParser):
                 destination = str(route_tag.getAttribute("Destination"))
                 netprefix = int(route_tag.getAttribute("NetPrefix"))
                 nexthop = str(route_tag.getAttribute("NextHop"))
-                interface = str(route_tag.getAttribute("Interface"))
                 data.add_route_data(guid, family, destination, netprefix, 
-                        nexthop, interface)
+                        nexthop)
 
     def connections_data_from_xml(self, tag, guid, data):
         connections_tag_list = tag.getElementsByTagName("connections")
index 1cb17ed..64a7751 100644 (file)
@@ -85,8 +85,7 @@ class ExperimentData(object):
             address_data["Broadcast"] = broadcast
         addresses_data.append(address_data)
 
-    def add_route_data(self, guid, family, destination, netprefix, nexthop, 
-            interface):
+    def add_route_data(self, guid, family, destination, netprefix, nexthop): 
         data = self.data[guid]
         if not "routes" in data:
             data["routes"] = list()
@@ -95,8 +94,7 @@ class ExperimentData(object):
             "Family": family, 
             "Destination": destination,
             "NetPrefix": netprefix, 
-            "NextHop": nexthop, 
-            "Interface": Interface
+            "NextHop": nexthop 
             })
         routes_data.append(route_data)
 
@@ -175,9 +173,8 @@ class ExperimentData(object):
         return [(data["Family"],
                  data["Destination"],
                  data["NetPrefix"],
-                 data["NextHop"],
-                 data["Interface"]) \
-                 for data in routes_data]
+                 data["NextHop"]) \
+                         for data in routes_data]
 
 class ExperimentParser(object):
     def to_data(self, experiment_description):
@@ -253,9 +250,7 @@ class ExperimentParser(object):
              destination = route.get_attribute_value("Destination")
              netprefix = route.get_attribute_value("NetPrefix")
              nexthop = route.get_attribute_value("NextHop")
-             interface = route.get_attribute_value("Interface")
-             data.add_route_data(guid, family, destination, netprefix, nexthop, 
-                    interface)
+             data.add_route_data(guid, family, destination, netprefix, nexthop)
 
     def from_data(self, experiment_description, data):
         box_guids = list()
@@ -327,13 +322,12 @@ class ExperimentParser(object):
                 addr.set_attribute_value("Broadcast", broadcast)
 
     def routes_from_data(self, box, data):
-         for (family, destination, netprefix, nexthop, interface) \
-                in data.get_route_data(box.guid):
+         for (family, destination, netprefix, nexthop) \
+                 in data.get_route_data(box.guid):
             addr = box.add_route(family)
             addr.set_attribute_value("Destination", destination)
             addr.set_attribute_value("NetPrefix", netprefix)
             addr.set_attribute_value("NextHop", nexthop)
-            addr.set_attribute_value("Interface", interface)
 
     def connections_from_data(self, experiment_description, guids, data):
         for guid in guids:
diff --git a/test/testbeds/netns/design.py b/test/testbeds/netns/design.py
new file mode 100755 (executable)
index 0000000..36f2453
--- /dev/null
@@ -0,0 +1,46 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from nepi.core.design import ExperimentDescription, FactoriesProvider
+from nepi.core.design import AF_INET
+import os
+import shutil
+import test_util
+import unittest
+import uuid
+
+class NetnsDesignTestCase(unittest.TestCase):
+    @test_util.skipUnless(os.getuid() == 0, "Test requires root privileges")
+    def test_design_if(self):
+        exp_desc = ExperimentDescription()
+        testbed_version = "01"
+        testbed_id = "netns"
+        netns_provider = FactoriesProvider(testbed_id, testbed_version)
+        netns_desc = exp_desc.add_testbed_description(netns_provider)
+        node1 = netns_desc.create("Node")
+        node2 = netns_desc.create("Node")
+        iface1 = netns_desc.create("NodeInterface")
+        iface1.set_attribute_value("up", True)
+        node1.connector("devs").connect(iface1.connector("node"))
+        ip1 = iface1.add_address()
+        ip1.set_attribute_value("Address", "10.0.0.1")
+        iface2 = netns_desc.create("NodeInterface")
+        iface2.set_attribute_value("up", True)
+        node2.connector("devs").connect(iface2.connector("node"))
+        ip2 = iface2.add_address()
+        ip2.set_attribute_value("Address", "10.0.0.2")
+        switch = netns_desc.create("Switch")
+        switch.set_attribute_value("up", True)
+        iface1.connector("switch").connect(switch.connector("devs"))
+        iface2.connector("switch").connect(switch.connector("devs"))
+        app = netns_desc.create("Application")
+        app.set_attribute_value("command", "ping -qc10 10.0.0.2")
+        app.connector("node").connect(node1.connector("apps"))
+        xml = exp_desc.to_xml()
+        exp_desc2 = ExperimentDescription()
+        exp_desc2.from_xml(xml)
+        xml2 = exp_desc2.to_xml()
+        self.assertTrue(xml == xml2)
+        
+if __name__ == '__main__':
+    unittest.main()
diff --git a/test/testbeds/netns/execute.py b/test/testbeds/netns/execute.py
new file mode 100755 (executable)
index 0000000..208457a
--- /dev/null
@@ -0,0 +1,164 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import getpass
+from nepi.core.design import AF_INET
+from nepi.testbeds import netns
+import os
+import shutil
+import test_util
+import time
+import unittest
+import uuid
+
+class NetnsExecuteTestCase(unittest.TestCase):
+    def setUp(self):
+        self._home_dir = os.path.join(os.getenv("HOME"), ".nepi", 
+                str(uuid.uuid1()))
+        os.makedirs(self._home_dir)
+
+    @test_util.skipUnless(os.getuid() == 0, "Test requires root privileges")
+    def test_run_ping_if(self):
+        user = getpass.getuser()
+        testbed_version = "01"
+        config = netns.TestbedConfiguration()
+        config.set_attribute_value("HomeDirectory", self._home_dir)
+        instance = netns.TestbedInstance(testbed_version, config)
+        instance.create(2, "Node")
+        instance.create(3, "Node")
+        instance.create(4, "NodeInterface")
+        instance.create_set(4, "up", True)
+        instance.connect(2, "devs", 4, "node")
+        instance.add_adddress(4, AF_INET, "10.0.0.1", 24, None)
+        instance.create(5, "NodeInterface")
+        instance.create_set(5, "up", True)
+        instance.connect(3, "devs", 5, "node")
+        instance.add_adddress(5, AF_INET, "10.0.0.2", 24, None)
+        instance.create(6, "Switch")
+        instance.create_set(6, "up", True)
+        instance.connect(4, "switch", 6, "devs")
+        instance.connect(5, "switch", 6, "devs")
+        instance.create(7, "Application")
+        instance.create_set(7, "command", "ping -qc1 10.0.0.2")
+        instance.create_set(7, "user", user)
+        instance.add_trace(7, "stdout")
+        instance.connect(7, "node", 2, "apps")
+
+        instance.do_create()
+        instance.do_connect()
+        instance.do_configure()
+        instance.start()
+        time.sleep(2)
+        ping_result = instance.trace(7, "stdout")
+        comp_result = """PING 10.0.0.2 (10.0.0.2) 56(84) bytes of data.
+
+--- 10.0.0.2 ping statistics ---
+1 packets transmitted, 1 received, 0% packet loss, time 0ms
+"""
+        self.assertTrue(ping_result.startswith(comp_result))
+        instance.stop()
+        instance.shutdown()
+
+    @test_util.skipUnless(os.getuid() == 0, "Test requires root privileges")
+    def test_run_ping_p2pif(self):
+        user = getpass.getuser()
+        testbed_version = "01"
+        config = netns.TestbedConfiguration()
+        config.set_attribute_value("HomeDirectory", self._home_dir)
+        instance = netns.TestbedInstance(testbed_version, config)
+        instance.create(2, "Node")
+        instance.create(3, "Node")
+        instance.create(4, "P2PNodeInterface")
+        instance.create_set(4, "up", True)
+        instance.connect(2, "devs", 4, "node")
+        instance.add_adddress(4, AF_INET, "10.0.0.1", 24, None)
+        instance.create(5, "P2PNodeInterface")
+        instance.create_set(5, "up", True)
+        instance.connect(3, "devs", 5, "node")
+        instance.add_adddress(5, AF_INET, "10.0.0.2", 24, None)
+        instance.connect(4, "p2p", 5, "p2p")
+        instance.create(6, "Application")
+        instance.create_set(6, "command", "ping -qc1 10.0.0.2")
+        instance.create_set(6, "user", user)
+        instance.add_trace(6, "stdout")
+        instance.connect(6, "node", 2, "apps")
+
+        instance.do_create()
+        instance.do_connect()
+        instance.do_configure()
+        instance.start()
+        time.sleep(2)
+        ping_result = instance.trace(6, "stdout")
+        comp_result = """PING 10.0.0.2 (10.0.0.2) 56(84) bytes of data.
+
+--- 10.0.0.2 ping statistics ---
+1 packets transmitted, 1 received, 0% packet loss, time 0ms
+"""
+        self.assertTrue(ping_result.startswith(comp_result))
+        instance.stop()
+        instance.shutdown()
+
+    @test_util.skipUnless(os.getuid() == 0, "Test requires root privileges")
+    def test_run_ping_routing(self):
+        user = getpass.getuser()
+        testbed_version = "01"
+        config = netns.TestbedConfiguration()
+        config.set_attribute_value("HomeDirectory", self._home_dir)
+        instance = netns.TestbedInstance(testbed_version, config)
+        instance.create(2, "Node")
+        instance.create(3, "Node")
+        instance.create(4, "Node")
+        instance.create(5, "NodeInterface")
+        instance.create_set(5, "up", True)
+        instance.connect(2, "devs", 5, "node")
+        instance.add_adddress(5, AF_INET, "10.0.0.1", 24, None)
+        instance.create(6, "NodeInterface")
+        instance.create_set(6, "up", True)
+        instance.connect(3, "devs", 6, "node")
+        instance.add_adddress(6, AF_INET, "10.0.0.2", 24, None)
+        instance.create(7, "NodeInterface")
+        instance.create_set(7, "up", True)
+        instance.connect(3, "devs", 7, "node")
+        instance.add_adddress(7, AF_INET, "10.0.1.1", 24, None)
+        instance.create(8, "NodeInterface")
+        instance.create_set(8, "up", True)
+        instance.connect(4, "devs", 8, "node")
+        instance.add_adddress(8, AF_INET, "10.0.1.2", 24, None)
+        instance.create(9, "Switch")
+        instance.create_set(9, "up", True)
+        instance.connect(5, "switch", 9, "devs")
+        instance.connect(6, "switch", 9, "devs")
+        instance.create(10, "Switch")
+        instance.create_set(10, "up", True)
+        instance.connect(7, "switch", 10, "devs")
+        instance.connect(8, "switch", 10, "devs")
+        instance.create(11, "Application")
+        instance.create_set(11, "command", "ping -qc1 10.0.1.2")
+        instance.create_set(11, "user", user)
+        instance.add_trace(11, "stdout")
+        instance.connect(11, "node", 2, "apps")
+
+        instance.add_route(2, "10.0.1.0", 24, "10.0.0.2")
+        instance.add_route(4, "10.0.0.0", 24, "10.0.1.1")
+
+        instance.do_create()
+        instance.do_connect()
+        instance.do_configure()
+        instance.start()
+        time.sleep(2)
+        ping_result = instance.trace(11, "stdout")
+        comp_result = """PING 10.0.0.2 (10.0.0.2) 56(84) bytes of data.
+
+--- 10.0.0.2 ping statistics ---
+1 packets transmitted, 1 received, 0% packet loss, time 0ms
+"""
+        self.assertTrue(ping_result.startswith(comp_result))
+        instance.stop()
+        instance.shutdown()
+        
+    def tearDown(self):
+        shutil.rmtree(self._home_dir)
+
+if __name__ == '__main__':
+    unittest.main()
+
diff --git a/test/util/test_util.py b/test/util/test_util.py
new file mode 100644 (file)
index 0000000..0181c49
--- /dev/null
@@ -0,0 +1,104 @@
+#!/usr/bin/env python
+# vim:ts=4:sw=4:et:ai:sts=4
+
+import sys
+import nepi.util.environ
+
+# Unittest from Python 2.6 doesn't have these decorators
+def _bannerwrap(f, text):
+    name = f.__name__
+    def banner(*args, **kwargs):
+        sys.stderr.write("*** WARNING: Skipping test %s: `%s'\n" %
+                (name, text))
+        return None
+    return banner
+def skip(text):
+    return lambda f: _bannerwrap(f, text)
+def skipUnless(cond, text):
+    return (lambda f: _bannerwrap(f, text)) if not cond else lambda f: f
+def skipIf(cond, text):
+    return (lambda f: _bannerwrap(f, text)) if cond else lambda f: f
+
+# SSH stuff
+
+import os, os.path, re, signal, shutil, socket, subprocess, tempfile
+def gen_ssh_keypair(filename):
+    ssh_keygen = nepi.util.environ.find_bin_or_die("ssh-keygen")
+    args = [ssh_keygen, '-q', '-N', '', '-f', filename]
+    assert subprocess.Popen(args).wait() == 0
+    return filename, "%s.pub" % filename
+
+def add_key_to_agent(filename):
+    ssh_add = nepi.util.environ.find_bin_or_die("ssh-add")
+    args = [ssh_add, filename]
+    null = file("/dev/null", "w")
+    assert subprocess.Popen(args, stderr = null).wait() == 0
+    null.close()
+
+def get_free_port():
+    s = socket.socket()
+    s.bind(("127.0.0.1", 0))
+    port = s.getsockname()[1]
+    return port
+
+_SSH_CONF = """ListenAddress 127.0.0.1:%d
+Protocol 2
+HostKey %s
+UsePrivilegeSeparation no
+PubkeyAuthentication yes
+PasswordAuthentication no
+AuthorizedKeysFile %s
+UsePAM no
+AllowAgentForwarding yes
+PermitRootLogin yes
+StrictModes no
+PermitUserEnvironment yes
+"""
+
+def gen_sshd_config(filename, port, server_key, auth_keys):
+    conf = open(filename, "w")
+    text = _SSH_CONF % (port, server_key, auth_keys)
+    conf.write(text)
+    conf.close()
+    return filename
+
+def gen_auth_keys(pubkey, output, environ):
+    #opts = ['from="127.0.0.1/32"'] # fails in stupid yans setup
+    opts = []
+    for k, v in environ.items():
+        opts.append('environment="%s=%s"' % (k, v))
+
+    lines = file(pubkey).readlines()
+    pubkey = lines[0].split()[0:2]
+    out = file(output, "w")
+    out.write("%s %s %s\n" % (",".join(opts), pubkey[0], pubkey[1]))
+    out.close()
+    return output
+
+def start_ssh_agent():
+    ssh_agent = nepi.util.environ.find_bin_or_die("ssh-agent")
+    proc = subprocess.Popen([ssh_agent], stdout = subprocess.PIPE)
+    (out, foo) = proc.communicate()
+    assert proc.returncode == 0
+    d = {}
+    for l in out.split("\n"):
+        match = re.search("^(\w+)=([^ ;]+);.*", l)
+        if not match:
+            continue
+        k, v = match.groups()
+        os.environ[k] = v
+        d[k] = v
+    return d
+
+def stop_ssh_agent(data):
+    # No need to gather the pid, ssh-agent knows how to kill itself; after we
+    # had set up the environment
+    ssh_agent = nepi.util.environ.find_bin_or_die("ssh-agent")
+    null = file("/dev/null", "w")
+    proc = subprocess.Popen([ssh_agent, "-k"], stdout = null)
+    null.close()
+    assert proc.wait() == 0
+    for k in data:
+        del os.environ[k]
+
+