# src/nepi/util/parallel.py

import threading
import Queue
import traceback
import sys
import os

# Number of processors on this host, detected lazily from /proc/cpuinfo
N_PROCS = None

# Per-process cache of idle WorkerThreads, reused across ParallelMap instances
THREADCACHE = []
THREADCACHEPID = None
class WorkerThread(threading.Thread):
    class QUIT:
        pass
    class REASSIGNED:
        pass
    
    def run(self):
        while True:
            task = self.queue.get()
            if task is None:
                # Sync marker: flag the thread as idle, keep serving the queue
                self.done = True
                self.queue.task_done()
                continue
            elif task is self.QUIT:
                self.done = True
                self.queue.task_done()
                break
            elif task is self.REASSIGNED:
                # Wake-up marker sent when the thread is moved to another queue
                continue
            else:
                self.done = False
            
            try:
                try:
                    callable, args, kwargs = task
                    rv = callable(*args, **kwargs)
                    
                    if self.rvqueue is not None:
                        self.rvqueue.put(rv)
                finally:
                    self.queue.task_done()
            except:
                # Report the failure but keep the worker alive; the exception
                # is re-raised later by ParallelMap.join() or sync()
                traceback.print_exc(file = sys.stderr)
                self.delayed_exceptions.append(sys.exc_info())

    def waitdone(self):
        # Block until the current queue is drained or the thread goes idle
        while not self.queue.empty() and not self.done:
            self.queue.join()
    
    def attach(self, queue, rvqueue, delayed_exceptions):
        # Point this (possibly cached) thread at a new set of queues
        if self.isAlive():
            self.waitdone()
            oldqueue = self.queue
        self.queue = queue
        self.rvqueue = rvqueue
        self.delayed_exceptions = delayed_exceptions
        if self.isAlive():
            # Wake the thread if it is still blocked on the old queue
            oldqueue.put(self.REASSIGNED)
    
    def detach(self):
        # Disconnect from the current ParallelMap while keeping the thread alive
        if self.isAlive():
            self.waitdone()
            self.oldqueue = self.queue
        self.queue = Queue.Queue()
        self.rvqueue = None
        self.delayed_exceptions = []
    
    def detach_signal(self):
        if self.isAlive():
            self.oldqueue.put(self.REASSIGNED)
            del self.oldqueue
        
    def quit(self):
        self.queue.put(self.QUIT)
        self.join()

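# How the pieces fit together: ParallelMap feeds its workers through a shared
# Queue of (callable, args, kwargs) tuples. A None item is a sync marker (the
# worker flags itself idle but keeps consuming), QUIT terminates the worker,
# and REASSIGNED merely wakes a worker that has been attached to a different
# queue. Detached workers are parked in the module-level THREADCACHE and
# reused by later ParallelMap instances created in the same process.
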
class ParallelMap(object):
    def __init__(self, maxthreads = None, maxqueue = None, results = True):
        global N_PROCS
        global THREADCACHE
        global THREADCACHEPID
        
        if maxthreads is None:
            # Default to one worker per processor listed in /proc/cpuinfo
            if N_PROCS is None:
                try:
                    f = open("/proc/cpuinfo")
                    try:
                        N_PROCS = sum("processor" in l for l in f)
                    finally:
                        f.close()
                except:
                    pass
            maxthreads = N_PROCS
        
        if maxthreads is None:
            maxthreads = 4
        
        self.queue = Queue.Queue(maxqueue or 0)

        self.delayed_exceptions = []
        
        if results:
            self.rvqueue = Queue.Queue()
        else:
            self.rvqueue = None
        
        # Discard the thread cache if it was filled in another process
        # (eg. before a fork)
        if THREADCACHEPID is None or THREADCACHEPID != os.getpid():
            del THREADCACHE[:]
            THREADCACHEPID = os.getpid()
    
        self.workers = []
        for x in xrange(maxthreads):
            t = None
            if THREADCACHE:
                # Reuse an idle cached worker when one is available
                try:
                    t = THREADCACHE.pop()
                except:
                    pass
            if t is None:
                t = WorkerThread()
                t.setDaemon(True)
            else:
                t.waitdone()
            t.attach(self.queue, self.rvqueue, self.delayed_exceptions)
            self.workers.append(t)
    
    def __del__(self):
        self.destroy()
    
    def destroy(self):
        # Discard the thread cache if it belongs to another process
        global THREADCACHE
        global THREADCACHEPID
        if THREADCACHEPID is None or THREADCACHEPID != os.getpid():
            del THREADCACHE[:]
            THREADCACHEPID = os.getpid()

        # Detach the workers and return them to the cache for reuse
        for worker in self.workers:
            worker.waitdone()
        for worker in self.workers:
            worker.detach()
        for worker in self.workers:
            worker.detach_signal()
        THREADCACHE.extend(self.workers)
        del self.workers[:]
        
    def put(self, callable, *args, **kwargs):
        self.queue.put((callable, args, kwargs))
    
    def put_nowait(self, callable, *args, **kwargs):
        self.queue.put_nowait((callable, args, kwargs))

    def start(self):
        for thread in self.workers:
            if not thread.isAlive():
                thread.start()
    
    def join(self):
        # One None per worker marks the end of the task stream
        for thread in self.workers:
            self.queue.put(None)
            
        self.queue.join()
        for thread in self.workers:
            thread.waitdone()
        
        # Re-raise the first exception recorded by a task, if any
        if self.delayed_exceptions:
            typ, val, loc = self.delayed_exceptions[0]
            del self.delayed_exceptions[:]
            raise typ, val, loc
        
        self.destroy()
    
    def sync(self):
        self.queue.join()
        if self.delayed_exceptions:
            typ, val, loc = self.delayed_exceptions[0]
            del self.delayed_exceptions[:]
            raise typ, val, loc
        
    def __iter__(self):
        # Yield results as they become available, in completion order
        if self.rvqueue is not None:
            while True:
                try:
                    yield self.rvqueue.get_nowait()
                except Queue.Empty:
                    self.queue.join()
                    try:
                        yield self.rvqueue.get_nowait()
                    except Queue.Empty:
                        raise StopIteration

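# Illustrative usage sketch ("work" and "items" below are placeholder names):
#
#   mapper = ParallelMap(maxthreads = 4)
#   mapper.start()
#   for item in items:
#       mapper.put(work, item)    # work(item) runs on a worker thread
#   results = list(mapper)        # results arrive in completion order
#   mapper.join()                 # re-raises the first task exception, if any
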
class ParallelFilter(ParallelMap):
    class _FILTERED:
        pass
    
    def __filter(self, x):
        if self.filter_condition(x):
            return x
        else:
            return self._FILTERED
    
    def __init__(self, filter_condition, maxthreads = None, maxqueue = None):
        super(ParallelFilter, self).__init__(maxthreads, maxqueue, True)
        self.filter_condition = filter_condition

    def put(self, what):
        super(ParallelFilter, self).put(self.__filter, what)
    
    def put_nowait(self, what):
        super(ParallelFilter, self).put_nowait(self.__filter, what)
        
    def __iter__(self):
        for rv in super(ParallelFilter, self).__iter__():
            if rv is not self._FILTERED:
                yield rv

class ParallelRun(ParallelMap):
    def __run(self, x):
        fn, args, kwargs = x
        return fn(*args, **kwargs)
    
    def __init__(self, maxthreads = None, maxqueue = None):
        super(ParallelRun, self).__init__(maxthreads, maxqueue, True)

    def put(self, what, *args, **kwargs):
        super(ParallelRun, self).put(self.__run, (what, args, kwargs))
    
    def put_nowait(self, what, *args, **kwargs):
        super(ParallelRun, self).put_nowait(self.__run, (what, args, kwargs))


def pmap(mapping, iterable, maxthreads = None, maxqueue = None):
    mapper = ParallelMap(
        maxthreads = maxthreads,
        maxqueue = maxqueue,
        results = True)
    mapper.start()
    for elem in iterable:
        # Schedule mapping(elem) on a worker thread
        mapper.put(mapping, elem)
    rv = list(mapper)
    mapper.join()
    return rv

def pfilter(condition, iterable, maxthreads = None, maxqueue = None):
    filtrer = ParallelFilter(
        condition,
        maxthreads = maxthreads,
        maxqueue = maxqueue)
    filtrer.start()
    for elem in iterable:
        filtrer.put(elem)
    rv = list(filtrer)
    filtrer.join()
    return rv
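
# Minimal usage sketch for the helper functions (the lambdas and inputs are
# placeholders; results come back in completion order, not input order):
if __name__ == "__main__":
    print pmap(lambda x: x * x, range(10))
    print pfilter(lambda x: x % 2 == 0, range(10))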