fix import
[sfa.git] / sfa / util / threadmanager.py
1 import threading
2 import traceback
3 import time
4 from Queue import Queue
5 from sfa.util.sfalogging import logger
6
7 def ThreadedMethod(callable, results, errors):
8     """
9     A function decorator that returns a running thread. The thread
10     runs the specified callable and stores the result in the specified
11     results queue
12     """
13     def wrapper(args, kwds):
14         class ThreadInstance(threading.Thread): 
15             def run(self):
16                 try:
17                     results.put(callable(*args, **kwds))
18                 except Exception, e:
19                     logger.log_exc('ThreadManager: Error in thread: ')
20                     errors.put(traceback.format_exc())
21                     
22         thread = ThreadInstance()
23         thread.start()
24         return thread
25     return wrapper
26
27  
28
29 class ThreadManager:
30     """
31     ThreadManager executes a callable in a thread and stores the result
32     in a thread safe queue. 
33     """
34     results = Queue()
35     errors = Queue()
36     threads = []
37
38     def run (self, method, *args, **kwds):
39         """
40         Execute a callable in a separate thread.    
41         """
42         method = ThreadedMethod(method, self.results, self.errors)
43         thread = method(args, kwds)
44         self.threads.append(thread)
45
46     start = run
47
48     def join(self):
49         """
50         Wait for all threads to complete  
51         """
52         for thread in self.threads:
53             thread.join()
54
55     def get_results(self, lenient=True):
56         """
57         Return a list of all the results so far. Blocks until 
58         all threads are finished. 
59         If lienent is set to false the error queue will be checked before 
60         the response is returned. If there are errors in the queue an SFA Fault will 
61         be raised.   
62         """
63         self.join()
64         results = []
65         if not lenient:
66             errors = self.get_errors()
67             if errors: 
68                 raise Exception(errors[0])
69
70         while not self.results.empty():
71             results.append(self.results.get())  
72         return results
73
74     def get_errors(self):
75         """
76         Return a list of all errors. Blocks untill all threads are finished
77         """
78         self.join()
79         errors = []
80         while not self.errors.empty():
81             errors.append(self.errors.get())
82         return errors
83
84     def get_return_value(self):
85         """
86         Get the value that should be returuned to the client. If there are errors then the
87         first error is returned. If there are no errors, then the first result is returned  
88         """
89     
90            
91 if __name__ == '__main__':
92
93     def f(name, n, sleep=1):
94         nums = []
95         for i in range(n, n+5):
96             print "%s: %s" % (name, i)
97             nums.append(i)
98             time.sleep(sleep)
99         return nums
100     def e(name, n, sleep=1):
101         nums = []
102         for i in range(n, n+3) + ['n', 'b']:
103             print "%s: 1 + %s:" % (name, i)
104             nums.append(i + 1)
105             time.sleep(sleep)
106         return nums      
107
108     threads = ThreadManager()
109     threads.run(f, "Thread1", 10, 2)
110     threads.run(f, "Thread2", -10, 1)
111     threads.run(e, "Thread3", 19, 1)
112
113     #results = threads.get_results()
114     #errors = threads.get_errors()
115     #print "Results:", results
116     #print "Errors:", errors
117     results_xlenient = threads.get_results(lenient=False)
118