Fixed shebangs in non-executable .py files
[nepi.git] / src / nepi / util / parallel.py
1 # -*- coding: utf-8 -*-
2
3 import threading
4 import Queue
5 import traceback
6 import sys
7 import os
8
# Cached processor count. Stays None until ParallelMap.__init__ is called
# without an explicit maxthreads and manages to read /proc/cpuinfo.
N_PROCS = None

# Idle WorkerThread objects handed back by ParallelMap.destroy() so that
# later ParallelMap instances can reuse them instead of creating new threads.
THREADCACHE = []
# PID of the process that owns THREADCACHE. ParallelMap empties the cache
# whenever os.getpid() no longer matches this value.
THREADCACHEPID = None
13
class WorkerThread(threading.Thread):
    """Reusable worker thread that executes tasks pulled from a queue.

    A task is a ``(callable, args, kwargs)`` tuple. Besides tasks, the queue
    may carry three control tokens:

    * ``None``: synchronization marker; the worker flags itself as done and
      keeps waiting for more work.
    * ``QUIT``: the worker flags itself as done and leaves its main loop.
    * ``REASSIGNED``: wake-up token put on the *previous* queue after the
      worker was pointed at another one, so the thread stops blocking on the
      old queue and re-reads ``self.queue``.

    The worker has no queue until attach() or detach() has been called, so
    one of them must run before start().
    """

    class QUIT:
        pass

    class REASSIGNED:
        pass

    # Class-level default so that waitdone() can be called on a worker that
    # has not processed any token yet (run() overrides it per instance).
    done = True

    def run(self):
        """Main loop: process tasks and control tokens until QUIT arrives."""
        while True:
            # Bind the queue once per iteration: attach()/detach() may swap
            # self.queue while a task is running, and task_done() has to be
            # reported to the queue the task actually came from.
            queue = self.queue
            task = queue.get()

            if task is self.REASSIGNED:
                # Only a wake-up call; loop around to pick up the new queue.
                continue
            if task is None or task is self.QUIT:
                self.done = True
                queue.task_done()
                if task is None:
                    continue
                break

            self.done = False
            # Same reasoning as for the queue: results and errors belong to
            # whoever owned the worker when the task was picked up.
            rvqueue = self.rvqueue
            delayed_exceptions = self.delayed_exceptions

            try:
                try:
                    fn, args, kwargs = task
                    rv = fn(*args, **kwargs)

                    if rvqueue is not None:
                        rvqueue.put(rv)
                except:
                    # Deliberately catch everything: a failing task must not
                    # kill the worker. The error is recorded *before*
                    # task_done() so that anyone returning from queue.join()
                    # is guaranteed to see it.
                    traceback.print_exc(file = sys.stderr)
                    delayed_exceptions.append(sys.exc_info())
            finally:
                queue.task_done()

    def waitdone(self):
        """Block while the queue still holds items and the worker is busy."""
        while not self.queue.empty() and not self.done:
            self.queue.join()

    def attach(self, queue, rvqueue, delayed_exceptions):
        """Point the worker at a new task queue, result queue and error list.

        If the thread is already running, it is first allowed to settle and
        is then woken up through its previous queue.
        """
        oldqueue = None
        if self.is_alive():
            self.waitdone()
            oldqueue = self.queue
        self.queue = queue
        self.rvqueue = rvqueue
        self.delayed_exceptions = delayed_exceptions
        if oldqueue is not None:
            oldqueue.put(self.REASSIGNED)

    def detach(self):
        """Move the worker to a fresh private queue with no result queue.

        For a running thread the previous queue is remembered in
        ``self.oldqueue``; detach_signal() must be called afterwards to wake
        the thread up.
        """
        if self.is_alive():
            self.waitdone()
            self.oldqueue = self.queue
        self.queue = Queue.Queue()
        self.rvqueue = None
        self.delayed_exceptions = []

    def detach_signal(self):
        """Wake a running thread blocked on the queue saved by detach()."""
        if self.is_alive():
            self.oldqueue.put(self.REASSIGNED)
            del self.oldqueue

    def quit(self):
        """Ask the worker to terminate and wait until the thread has ended."""
        self.queue.put(self.QUIT)
        self.join()
79
80 class ParallelMap(object):
81     def __init__(self, maxthreads = None, maxqueue = None, results = True):
82         global N_PROCS
83         global THREADCACHE
84         global THREADCACHEPID
85         
86         if maxthreads is None:
87             if N_PROCS is None:
88                 try:
89                     f = open("/proc/cpuinfo")
90                     try:
91                         N_PROCS = sum("processor" in l for l in f)
92                     finally:
93                         f.close()
94                 except:
95                     pass
96             maxthreads = N_PROCS
97         
98         if maxthreads is None:
99             maxthreads = 4
100         
101         self.queue = Queue.Queue(maxqueue or 0)
102
103         self.delayed_exceptions = []
104         
105         if results:
106             self.rvqueue = Queue.Queue()
107         else:
108             self.rvqueue = None
109         
110         # Check threadcache
111         if THREADCACHEPID is None or THREADCACHEPID != os.getpid():
112             del THREADCACHE[:]
113             THREADCACHEPID = os.getpid()
114     
115         self.workers = []
116         for x in xrange(maxthreads):
117             t = None
118             if THREADCACHE:
119                 try:
120                     t = THREADCACHE.pop()
121                 except:
122                     pass
123             if t is None:
124                 t = WorkerThread()
125                 t.setDaemon(True)
126             else:
127                 t.waitdone()
128             t.attach(self.queue, self.rvqueue, self.delayed_exceptions)
129             self.workers.append(t)
130     
131     def __del__(self):
132         self.destroy()
133     
134     def destroy(self):
135         # Check threadcache
136         global THREADCACHE
137         global THREADCACHEPID
138         if THREADCACHEPID is None or THREADCACHEPID != os.getpid():
139             del THREADCACHE[:]
140             THREADCACHEPID = os.getpid()
141
142         for worker in self.workers:
143             worker.waitdone()
144         for worker in self.workers:
145             worker.detach()
146         for worker in self.workers:
147             worker.detach_signal()
148         THREADCACHE.extend(self.workers)
149         del self.workers[:]
150         
151     def put(self, callable, *args, **kwargs):
152         self.queue.put((callable, args, kwargs))
153     
154     def put_nowait(self, callable, *args, **kwargs):
155         self.queue.put_nowait((callable, args, kwargs))
156
157     def start(self):
158         for thread in self.workers:
159             if not thread.isAlive():
160                 thread.start()
161     
162     def join(self):
163         for thread in self.workers:
164             # That's the sync signal
165             self.queue.put(None)
166             
167         self.queue.join()
168         for thread in self.workers:
169             thread.waitdone()
170         
171         if self.delayed_exceptions:
172             typ,val,loc = self.delayed_exceptions[0]
173             del self.delayed_exceptions[:]
174             raise typ,val,loc
175         
176         self.destroy()
177     
178     def sync(self):
179         self.queue.join()
180         if self.delayed_exceptions:
181             typ,val,loc = self.delayed_exceptions[0]
182             del self.delayed_exceptions[:]
183             raise typ,val,loc
184         
185     def __iter__(self):
186         if self.rvqueue is not None:
187             while True:
188                 try:
189                     yield self.rvqueue.get_nowait()
190                 except Queue.Empty:
191                     self.queue.join()
192                     try:
193                         yield self.rvqueue.get_nowait()
194                     except Queue.Empty:
195                         raise StopIteration
196             
197     
class ParallelFilter(ParallelMap):
    """ParallelMap variant that keeps only items accepted by a predicate.

    Each item submitted with put()/put_nowait() is passed to
    ``filter_condition`` by a worker; iterating over the object yields the
    items for which the predicate returned a true value.
    """

    class _FILTERED:
        # Marker returned by the workers in place of a rejected item.
        pass

    def __init__(self, filter_condition, maxthreads = None, maxqueue = None):
        # Result collection is always enabled: the filtered items are read
        # back from the result queue.
        super(ParallelFilter, self).__init__(maxthreads, maxqueue, True)
        self.filter_condition = filter_condition

    def __filter(self, x):
        # Task body run by the workers: the item itself, or the marker.
        return x if self.filter_condition(x) else self._FILTERED

    def put(self, what):
        """Queue ``what`` for evaluation; may block on a full queue."""
        super(ParallelFilter, self).put(self.__filter, what)

    def put_nowait(self, what):
        """Queue ``what`` for evaluation without blocking."""
        super(ParallelFilter, self).put_nowait(self.__filter, what)

    def __iter__(self):
        """Iterate over the accepted items, skipping the rejection marker."""
        rejected = self._FILTERED
        collected = super(ParallelFilter, self).__iter__()
        return (item for item in collected if item is not rejected)
222
class ParallelRun(ParallelMap):
    """ParallelMap variant that always collects the results of its tasks.

    put()/put_nowait() take a callable plus its arguments; the call is
    packed into a single ``(fn, args, kwargs)`` tuple and unpacked again by
    the worker.
    """

    def __run(self, x):
        # Task body run by the workers: unpack the call and perform it.
        fn, args, kwargs = x
        return fn(*args, **kwargs)

    def __init__(self, maxthreads = None, maxqueue = None):
        super(ParallelRun, self).__init__(maxthreads, maxqueue, True)

    def put(self, what, *args, **kwargs):
        """Queue ``what(*args, **kwargs)``; may block on a full queue."""
        super(ParallelRun, self).put(self.__run, (what, args, kwargs))

    def put_nowait(self, what, *args, **kwargs):
        """Queue ``what(*args, **kwargs)`` without blocking."""
        # Must dispatch through __run, like put(): this class has no
        # __filter method, so referencing it raised AttributeError.
        super(ParallelRun, self).put_nowait(self.__run, (what, args, kwargs))
236
237
def pmap(mapping, iterable, maxthreads = None, maxqueue = None):
    """Apply ``mapping`` to every element of ``iterable`` using worker threads.

    mapping    -- callable taking one element.
    iterable   -- elements to process.
    maxthreads -- number of worker threads (see ParallelMap).
    maxqueue   -- maximum size of the task queue (see ParallelMap).

    Returns the list of results in the order the workers produced them,
    which is not necessarily the order of ``iterable``. An exception raised
    by ``mapping`` is re-raised once all elements have been processed.
    """
    mapper = ParallelMap(
        maxthreads = maxthreads,
        maxqueue = maxqueue,
        results = True)
    mapper.start()
    for elem in iterable:
        # The task is mapping(elem); enqueueing elem alone would make the
        # workers try to call the element itself.
        mapper.put(mapping, elem)
    rv = list(mapper)
    mapper.join()
    return rv
249
def pfilter(condition, iterable, maxthreads = None, maxqueue = None):
    """Return the elements of ``iterable`` accepted by ``condition``.

    The predicate is evaluated by a ParallelFilter pool; ``maxthreads`` and
    ``maxqueue`` are passed on to it unchanged. The accepted elements are
    gathered into a list before the pool is joined.
    """
    pool = ParallelFilter(condition, maxthreads = maxthreads, maxqueue = maxqueue)
    pool.start()
    for candidate in iterable:
        pool.put(candidate)
    accepted = [kept for kept in pool]
    pool.join()
    return accepted
261