added get_errors() method
authorTony Mack <tmack@paris.CS.Princeton.EDU>
Tue, 21 Sep 2010 16:37:06 +0000 (12:37 -0400)
committerTony Mack <tmack@paris.CS.Princeton.EDU>
Tue, 21 Sep 2010 16:37:06 +0000 (12:37 -0400)
sfa/util/threadmanager.py

index 3d5dd03..3ec415f 100755 (executable)
@@ -1,8 +1,9 @@
 import threading
+import traceback
 import time
 from Queue import Queue
 
-def ThreadedMethod(callable, queue):
+def ThreadedMethod(callable, results, errors):
     """
     A function decorator that returns a running thread. The thread
     runs the specified callable and stores the result in the specified
@@ -12,10 +13,10 @@ def ThreadedMethod(callable, queue):
         class ThreadInstance(threading.Thread): 
             def run(self):
                 try:
-                    queue.put(callable(*args, **kwds))
-                except:
-                    # ignore errors
-                    pass
+                    results.put(callable(*args, **kwds))
+                except Exception, e:
+                    errors.put(traceback.format_exc())
+                    
         thread = ThreadInstance()
         thread.start()
         return thread
@@ -28,30 +29,48 @@ class ThreadManager:
     ThreadManager executes a callable in a thread and stores the result
     in a thread safe queue. 
     """
-    queue = Queue()
+    results = Queue()
+    errors = Queue()
     threads = []
 
     def run (self, method, *args, **kwds):
         """
         Execute a callable in a separate thread.    
         """
-        method = ThreadedMethod(method, self.queue)
+        method = ThreadedMethod(method, self.results, self.errors)
         thread = method(args, kwds)
         self.threads.append(thread)
 
     start = run
 
+    def join(self):
+        """
+        Wait for all threads to complete  
+        """
+        for thread in self.threads:
+            thread.join()
+
     def get_results(self):
         """
         Return a list of all the results so far. Blocks until 
         all threads are finished. 
         """
-        for thread in self.threads:
-            thread.join()
+        self.join()
         results = []
-        while not self.queue.empty():
-            results.append(self.queue.get())  
+        while not self.results.empty():
+            results.append(self.results.get())  
         return results
+
+    def get_errors(self):
+        """
+        Return a list of all errors. Blocks untill all threads are finished
+        """
+        self.join()
+        errors = []
+        while not self.errors.empty():
+            errors.append(self.errors.get())
+        return errors
+    
            
 if __name__ == '__main__':
 
@@ -62,10 +81,20 @@ if __name__ == '__main__':
             nums.append(i)
             time.sleep(sleep)
         return nums
+    def e(name, n, sleep=1):
+        nums = []
+        for i in range(n, n+3) + ['n', 'b']:
+            print "%s: 1 + %s:" % (name, i)
+            nums.append(i + 1)
+            time.sleep(sleep)
+        return nums      
 
     threads = ThreadManager()
     threads.run(f, "Thread1", 10, 2)
     threads.run(f, "Thread2", -10, 1)
+    threads.run(e, "Thread3", 19, 1)
 
     results = threads.get_results()
+    errors = threads.get_errors()
     print "Results:", results
+    print "Errors:", errors