refactored remove_slivers()
[sfa.git] / sfa / util / threadmanager.py
old mode 100755 (executable)
new mode 100644 (file)
index 3ec415f..b47b818
@@ -2,6 +2,7 @@ import threading
 import traceback
 import time
 from Queue import Queue
+from sfa.util.sfalogging import logger
 
 def ThreadedMethod(callable, results, errors):
     """
@@ -15,6 +16,7 @@ def ThreadedMethod(callable, results, errors):
                 try:
                     results.put(callable(*args, **kwds))
                 except Exception, e:
+                    logger.log_exc('ThreadManager: Error in thread: ')
                     errors.put(traceback.format_exc())
                     
         thread = ThreadInstance()
@@ -29,9 +31,11 @@ class ThreadManager:
     ThreadManager executes a callable in a thread and stores the result
     in a thread safe queue. 
     """
-    results = Queue()
-    errors = Queue()
-    threads = []
+
+    def __init__(self):
+        self.results = Queue()
+        self.errors = Queue()
+        self.threads = []
 
     def run (self, method, *args, **kwds):
         """
@@ -50,13 +54,21 @@ class ThreadManager:
         for thread in self.threads:
             thread.join()
 
-    def get_results(self):
+    def get_results(self, lenient=True):
         """
         Return a list of all the results so far. Blocks until 
         all threads are finished. 
+        If lienent is set to false the error queue will be checked before 
+        the response is returned. If there are errors in the queue an SFA Fault will 
+        be raised.   
         """
         self.join()
         results = []
+        if not lenient:
+            errors = self.get_errors()
+            if errors: 
+                raise Exception(errors[0])
+
         while not self.results.empty():
             results.append(self.results.get())  
         return results
@@ -70,6 +82,12 @@ class ThreadManager:
         while not self.errors.empty():
             errors.append(self.errors.get())
         return errors
+
+    def get_return_value(self):
+        """
+        Get the value that should be returuned to the client. If there are errors then the
+        first error is returned. If there are no errors, then the first result is returned  
+        """
     
            
 if __name__ == '__main__':
@@ -94,7 +112,9 @@ if __name__ == '__main__':
     threads.run(f, "Thread2", -10, 1)
     threads.run(e, "Thread3", 19, 1)
 
-    results = threads.get_results()
-    errors = threads.get_errors()
-    print "Results:", results
-    print "Errors:", errors
+    #results = threads.get_results()
+    #errors = threads.get_errors()
+    #print "Results:", results
+    #print "Errors:", errors
+    results_xlenient = threads.get_results(lenient=False)
+