Use repr for showing invalid values, helps detecting type mismatches
[nepi.git] / src / nepi / core / testbed_impl.py
index 336da50..fd947c0 100644 (file)
@@ -2,15 +2,17 @@
 # -*- coding: utf-8 -*-
 
 from nepi.core import execute
-from nepi.core.metadata import Metadata
+from nepi.core.metadata import Metadata, Parallel
 from nepi.util import validation
 from nepi.util.constants import TIME_NOW, \
         ApplicationStatus as AS, \
         TestbedStatus as TS, \
         CONNECTION_DELAY
+from nepi.util.parallel import ParallelRun
 
 import collections
 import copy
+import logging
 
 class TestbedController(execute.TestbedController):
     def __init__(self, testbed_id, testbed_version):
@@ -40,11 +42,17 @@ class TestbedController(execute.TestbedController):
         # testbed element instances
         self._elements = dict()
 
-        self._metadata = Metadata(self._testbed_id, self._testbed_version)
+        self._metadata = Metadata(self._testbed_id)
+        if self._metadata.testbed_version != testbed_version:
+            raise RuntimeError("Bad testbed version on testbed %s. Asked for %s, got %s" % \
+                    (testbed_id, testbed_version, self._metadata.testbed_version))
         for factory in self._metadata.build_factories():
             self._factories[factory.factory_id] = factory
         self._attributes = self._metadata.testbed_attributes()
         self._root_directory = None
+        
+        # Logging
+        self._logger = logging.getLogger("nepi.core.testbed_impl")
     
     @property
     def root_directory(self):
@@ -209,25 +217,68 @@ class TestbedController(execute.TestbedController):
         self._status = TS.STATUS_CONNECTED
 
     def _do_in_factory_order(self, action, order, postaction = None, poststep = None):
+        logger = self._logger
+        
         guids = collections.defaultdict(list)
         # order guids (elements) according to factory_id
         for guid, factory_id in self._create.iteritems():
             guids[factory_id].append(guid)
+        
         # configure elements following the factory_id order
         for factory_id in order:
+            # Create a parallel runner if we're given a Parallel() wrapper
+            runner = None
+            if isinstance(factory_id, Parallel):
+                runner = ParallelRun(factory_id.maxthreads)
+                factory_id = factory_id.factory
+            
             # omit the factories that have no element to create
             if factory_id not in guids:
                 continue
+            
+            # configure action
             factory = self._factories[factory_id]
-            if not getattr(factory, action):
+            if isinstance(action, basestring) and not getattr(factory, action):
                 continue
-            for guid in guids[factory_id]:
-                getattr(factory, action)(self, guid)
+            def perform_action(guid):
+                if isinstance(action, basestring):
+                    getattr(factory, action)(self, guid)
+                else:
+                    action(self, guid)
                 if postaction:
                     postaction(self, guid)
+
+            # perform the action on all elements, in parallel if so requested
+            if runner:
+                logger.debug("Starting parallel %s", action)
+                runner.start()
+
+            for guid in guids[factory_id]:
+                if runner:
+                    logger.debug("Scheduling %s on %s", action, guid)
+                    runner.put(perform_action, guid)
+                else:
+                    logger.debug("Performing %s on %s", action, guid)
+                    perform_action(guid)
+
+            # sync
+            if runner:
+                runner.sync()
+            
+            # post hook
             if poststep:
                 for guid in guids[factory_id]:
-                    poststep(self, guid)
+                    if runner:
+                        logger.debug("Scheduling post-%s on %s", action, guid)
+                        runner.put(poststep, self, guid)
+                    else:
+                        logger.debug("Performing post-%s on %s", action, guid)
+                        poststep(self, guid)
+
+            # sync
+            if runner:
+                runner.join()
+                logger.debug("Finished parallel %s", action)
 
     @staticmethod
     def do_poststep_preconfigure(self, guid):
@@ -505,7 +556,7 @@ class TestbedController(execute.TestbedController):
 
     def _validate_testbed_value(self, name, value):
         if not self._attributes.is_attribute_value_valid(name, value):
-            raise AttributeError("Invalid value %s for testbed attribute %s" % \
+            raise AttributeError("Invalid value %r for testbed attribute %s" % \
                 (value, name))
 
     def _validate_box_attribute(self, guid, name):
@@ -517,7 +568,7 @@ class TestbedController(execute.TestbedController):
     def _validate_box_value(self, guid, name, value):
         factory = self._get_factory(guid)
         if not factory.box_attributes.is_attribute_value_valid(name, value):
-            raise AttributeError("Invalid value %s for attribute %s" % \
+            raise AttributeError("Invalid value %r for attribute %s" % \
                 (value, name))
 
     def _validate_factory_attribute(self, guid, name):
@@ -529,7 +580,7 @@ class TestbedController(execute.TestbedController):
     def _validate_factory_value(self, guid, name, value):
         factory = self._get_factory(guid)
         if not factory.is_attribute_value_valid(name, value):
-            raise AttributeError("Invalid value %s for attribute %s" % \
+            raise AttributeError("Invalid value %r for attribute %s" % \
                 (value, name))
 
     def _validate_trace(self, guid, trace_name):