Fixing ec unit tests
[nepi.git] / src / neco / execution / ec.py
index 429f5ca..0b15c3a 100644 (file)
@@ -2,15 +2,16 @@ import logging
 import os
 import sys
 import time
+import threading
 
 from neco.util import guid
 from neco.util.timefuncs import strfnow, strfdiff, strfvalid 
 from neco.execution.resource import ResourceFactory
-from neco.execution.scheduler import HeapScheduler, Task
+from neco.execution.scheduler import HeapScheduler, Task, TaskStatus
 from neco.util.parallel import ParallelRun
 
 class ExperimentController(object):
-    def __init__(self, root_dir = "/tmp", loglevel = 'error'):
+    def __init__(self, root_dir = "/tmp", loglevel = 'error'): 
         super(ExperimentController, self).__init__()
         # root directory to store files
         self._root_dir = root_dir
@@ -24,6 +25,9 @@ class ExperimentController(object):
         # Scheduler
         self._scheduler = HeapScheduler()
 
+        # Tasks
+        self._tasks = dict()
+
         # Event processing thread
         self._stop = False
         self._cond = threading.Condition()
@@ -34,7 +38,10 @@ class ExperimentController(object):
         self._logger = logging.getLogger("neco.execution.ec")
         self._logger.setLevel(getattr(logging, loglevel.upper()))
 
-    def resource(self, guid):
+    def get_task(self, tid):
+        return self._tasks.get(tid)
+
+    def get_resource(self, guid):
         return self._resources.get(guid)
 
     @property
@@ -54,26 +61,26 @@ class ExperimentController(object):
         return guid
 
     def get_attributes(self, guid):
-        rm = self._resources[guid]
+        rm = self.get_resource(guid)
         return rm.get_attributes()
 
     def get_filters(self, guid):
-        rm = self._resources[guid]
+        rm = self.get_resource(guid)
         return rm.get_filters()
 
     def register_connection(self, guid1, guid2):
-        rm1 = self._resources[guid1]
-        rm2 = self._resources[guid2]
+        rm1 = self.get_resource(guid1)
+        rm2 = self.get_resource(guid2)
 
         rm1.connect(guid2)
         rm2.connect(guid1)
 
     def discover_resource(self, guid, filters):
-        rm = self._resources[guid]
+        rm = self.get_resource(guid)
         return rm.discover(filters)
 
     def provision_resource(self, guid, filters):
-        rm = self._resources[guid]
+        rm = self.get_resource(guid)
         return rm.provision(filters)
 
     def register_start(self, group1, time, after_status, group2):
@@ -84,7 +91,7 @@ class ExperimentController(object):
 
         for guid1 in group1:
             for guid2 in group2:
-                rm = self._resources(guid1)
+                rm = self.get_resource(guid)
                 rm.start_after(time, after_status, guid2)
 
     def register_stop(self, group1, time, after_status, group2):
@@ -95,7 +102,7 @@ class ExperimentController(object):
 
         for guid1 in group1:
             for guid2 in group2:
-                rm = self._resources(guid1)
+                rm = self.get_resource(guid)
                 rm.stop_after(time, after_status, guid2)
 
     def register_set(self, name, value, group1, time, after_status, group2):
@@ -106,23 +113,23 @@ class ExperimentController(object):
 
         for guid1 in group1:
             for guid2 in group2:
-                rm = self._resources(guid1)
+                rm = self.get_resource(guid)
                 rm.set_after(name, value, time, after_status, guid2)
 
     def get(self, guid, name):
-        rm = self._resources(guid)
+        rm = self.get_resource(guid)
         return rm.get(name)
 
     def set(self, guid, name, value):
-        rm = self._resources(guid)
+        rm = self.get_resource(guid)
         return rm.set(name, value)
 
     def status(self, guid):
-        rm = self._resources(guid)
+        rm = self.get_resource(guid)
         return rm.status()
 
     def stop(self, guid):
-        rm = self._resources(guid)
+        rm = self.get_resource(guid)
         return rm.stop()
 
     def deploy(self, group = None, start_when_all_ready = True):
@@ -131,7 +138,7 @@ class ExperimentController(object):
 
         threads = []
         for guid in group:
-            rm = self._resources(guid1)
+            rm = self.get_resource(guid)
 
             kwargs = {'target': rm.deploy}
             if start_when_all_ready:
@@ -152,7 +159,7 @@ class ExperimentController(object):
 
         threads = []
         for guid in group:
-            rm = self._resources(guid1)
+            rm = self.get_resource(guid)
             thread = threading.Thread(target=rm.release)
             threads.append(thread)
             thread.start()
@@ -161,10 +168,16 @@ class ExperimentController(object):
             thread.join()
 
     def shutdown(self):
-        self._stop = False
         self.release()
+        
+        self._stop = True
+        self._cond.acquire()
+        self._cond.notify()
+        self._cond.release()
+        if self._thread.is_alive():
+           self._thread.join()
 
-    def schedule(self, date, callback):
+    def schedule(self, date, callback, track = False):
         """
             date    string containing execution time for the task.
                     It can be expressed as an absolute time, using
@@ -174,17 +187,24 @@ class ExperimentController(object):
             callback    code to be executed for the task. Must be a
                         Python function, and receives args and kwargs
                         as arguments.
+
+            track   if set to True, the task will be retrivable with
+                    the get_task() method
         """
         timestamp = strfvalid(date)
         
         task = Task(timestamp, callback)
         task = self._scheduler.schedule(task)
 
+        if track:
+            self._tasks[task.id] = task
+  
         # Notify condition to wake up the processing thread
         self._cond.acquire()
         self._cond.notify()
         self._cond.release()
-        return task
+
+        return task.id
      
     def _process(self):
         runner = ParallelRun(maxthreads = 50)
@@ -217,9 +237,23 @@ class ExperimentController(object):
                         self._cond.release()
                     else:
                         # Process tasks in parallel
-                        runner.put(task.callback)
+                        runner.put(self._execute, task)
         except:  
             import traceback
             err = traceback.format_exc()
             self._logger.error("Error while processing tasks in the EC: %s" % err)
+
+    def _execute(self, task):
+        # Invoke callback
+        task.status = TaskStatus.DONE
+
+        try:
+            task.result = task.callback()
+        except:
+            import traceback
+            err = traceback.format_exc()
+            self._logger.error("Error while executing event: %s" % err)
+
+            task.result = err
+            task.status = TaskStatus.ERROR
+