should be able to upgrade from any version
[sfa.git] / sfa / server / threadmanager.py
1 import threading
2 import traceback
3 import time
4 from Queue import Queue
5 from sfa.util.sfalogging import logger
6
7 def ThreadedMethod(callable, results, errors):
8     """
9     A function decorator that returns a running thread. The thread
10     runs the specified callable and stores the result in the specified
11     results queue
12     """
13     def wrapper(args, kwds):
14         class ThreadInstance(threading.Thread): 
15             def run(self):
16                 try:
17                     results.put(callable(*args, **kwds))
18                 except Exception, e:
19                     logger.log_exc('ThreadManager: Error in thread: ')
20                     errors.put(traceback.format_exc())
21                     
22         thread = ThreadInstance()
23         thread.start()
24         return thread
25     return wrapper
26
27  
28
29 class ThreadManager:
30     """
31     ThreadManager executes a callable in a thread and stores the result
32     in a thread safe queue. 
33     """
34
35     def __init__(self):
36         self.results = Queue()
37         self.errors = Queue()
38         self.threads = []
39
40     def run (self, method, *args, **kwds):
41         """
42         Execute a callable in a separate thread.    
43         """
44         method = ThreadedMethod(method, self.results, self.errors)
45         thread = method(args, kwds)
46         self.threads.append(thread)
47
48     start = run
49
50     def join(self):
51         """
52         Wait for all threads to complete  
53         """
54         for thread in self.threads:
55             thread.join()
56
57     def get_results(self, lenient=True):
58         """
59         Return a list of all the results so far. Blocks until 
60         all threads are finished. 
61         If lienent is set to false the error queue will be checked before 
62         the response is returned. If there are errors in the queue an SFA Fault will 
63         be raised.   
64         """
65         self.join()
66         results = []
67         if not lenient:
68             errors = self.get_errors()
69             if errors: 
70                 raise Exception(errors[0])
71
72         while not self.results.empty():
73             results.append(self.results.get())  
74         return results
75
76     def get_errors(self):
77         """
78         Return a list of all errors. Blocks untill all threads are finished
79         """
80         self.join()
81         errors = []
82         while not self.errors.empty():
83             errors.append(self.errors.get())
84         return errors
85
86     def get_return_value(self):
87         """
88         Get the value that should be returuned to the client. If there are errors then the
89         first error is returned. If there are no errors, then the first result is returned  
90         """
91     
92            
93 if __name__ == '__main__':
94
95     def f(name, n, sleep=1):
96         nums = []
97         for i in range(n, n+5):
98             print "%s: %s" % (name, i)
99             nums.append(i)
100             time.sleep(sleep)
101         return nums
102     def e(name, n, sleep=1):
103         nums = []
104         for i in range(n, n+3) + ['n', 'b']:
105             print "%s: 1 + %s:" % (name, i)
106             nums.append(i + 1)
107             time.sleep(sleep)
108         return nums      
109
110     threads = ThreadManager()
111     threads.run(f, "Thread1", 10, 2)
112     threads.run(f, "Thread2", -10, 1)
113     threads.run(e, "Thread3", 19, 1)
114
115     #results = threads.get_results()
116     #errors = threads.get_errors()
117     #print "Results:", results
118     #print "Errors:", errors
119     results_xlenient = threads.get_results(lenient=False)
120