added get_errors() method
[sfa.git] / sfa / util / threadmanager.py
1 import threading
2 import traceback
3 import time
4 from Queue import Queue
5
6 def ThreadedMethod(callable, results, errors):
7     """
8     A function decorator that returns a running thread. The thread
9     runs the specified callable and stores the result in the specified
10     results queue
11     """
12     def wrapper(args, kwds):
13         class ThreadInstance(threading.Thread): 
14             def run(self):
15                 try:
16                     results.put(callable(*args, **kwds))
17                 except Exception, e:
18                     errors.put(traceback.format_exc())
19                     
20         thread = ThreadInstance()
21         thread.start()
22         return thread
23     return wrapper
24
25  
26
27 class ThreadManager:
28     """
29     ThreadManager executes a callable in a thread and stores the result
30     in a thread safe queue. 
31     """
32     results = Queue()
33     errors = Queue()
34     threads = []
35
36     def run (self, method, *args, **kwds):
37         """
38         Execute a callable in a separate thread.    
39         """
40         method = ThreadedMethod(method, self.results, self.errors)
41         thread = method(args, kwds)
42         self.threads.append(thread)
43
44     start = run
45
46     def join(self):
47         """
48         Wait for all threads to complete  
49         """
50         for thread in self.threads:
51             thread.join()
52
53     def get_results(self):
54         """
55         Return a list of all the results so far. Blocks until 
56         all threads are finished. 
57         """
58         self.join()
59         results = []
60         while not self.results.empty():
61             results.append(self.results.get())  
62         return results
63
64     def get_errors(self):
65         """
66         Return a list of all errors. Blocks untill all threads are finished
67         """
68         self.join()
69         errors = []
70         while not self.errors.empty():
71             errors.append(self.errors.get())
72         return errors
73     
74            
75 if __name__ == '__main__':
76
77     def f(name, n, sleep=1):
78         nums = []
79         for i in range(n, n+5):
80             print "%s: %s" % (name, i)
81             nums.append(i)
82             time.sleep(sleep)
83         return nums
84     def e(name, n, sleep=1):
85         nums = []
86         for i in range(n, n+3) + ['n', 'b']:
87             print "%s: 1 + %s:" % (name, i)
88             nums.append(i + 1)
89             time.sleep(sleep)
90         return nums      
91
92     threads = ThreadManager()
93     threads.run(f, "Thread1", 10, 2)
94     threads.run(f, "Thread2", -10, 1)
95     threads.run(e, "Thread3", 19, 1)
96
97     results = threads.get_results()
98     errors = threads.get_errors()
99     print "Results:", results
100     print "Errors:", errors