Merge with aly's
[nepi.git] / src / nepi / util / parallel.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import threading
5 import Queue
6 import traceback
7 import sys
8 import os
9
10 N_PROCS = None
11
12 THREADCACHE = []
13 THREADCACHEPID = None
14
15 class WorkerThread(threading.Thread):
16     class QUIT:
17         pass
18     class REASSIGNED:
19         pass
20     
21     def run(self):
22         while True:
23             task = self.queue.get()
24             if task is None:
25                 self.done = True
26                 self.queue.task_done()
27                 continue
28             elif task is self.QUIT:
29                 self.done = True
30                 self.queue.task_done()
31                 break
32             elif task is self.REASSIGNED:
33                 continue
34             else:
35                 self.done = False
36             
37             try:
38                 try:
39                     callable, args, kwargs = task
40                     rv = callable(*args, **kwargs)
41                     
42                     if self.rvqueue is not None:
43                         self.rvqueue.put(rv)
44                 finally:
45                     self.queue.task_done()
46             except:
47                 traceback.print_exc(file = sys.stderr)
48                 self.delayed_exceptions.append(sys.exc_info())
49     
50     def waitdone(self):
51         while not self.queue.empty() and not self.done:
52             self.queue.join()
53     
54     def attach(self, queue, rvqueue, delayed_exceptions):
55         if self.isAlive():
56             self.waitdone()
57             oldqueue = self.queue
58         self.queue = queue
59         self.rvqueue = rvqueue
60         self.delayed_exceptions = delayed_exceptions
61         if self.isAlive():
62             oldqueue.put(self.REASSIGNED)
63     
64     def detach(self):
65         if self.isAlive():
66             self.waitdone()
67             self.oldqueue = self.queue
68         self.queue = Queue.Queue()
69         self.rvqueue = None
70         self.delayed_exceptions = []
71     
72     def detach_signal(self):
73         if self.isAlive():
74             self.oldqueue.put(self.REASSIGNED)
75             del self.oldqueue
76         
77     def quit(self):
78         self.queue.put(self.QUIT)
79         self.join()
80
81 class ParallelMap(object):
82     def __init__(self, maxthreads = None, maxqueue = None, results = True):
83         global N_PROCS
84         global THREADCACHE
85         global THREADCACHEPID
86         
87         if maxthreads is None:
88             if N_PROCS is None:
89                 try:
90                     f = open("/proc/cpuinfo")
91                     try:
92                         N_PROCS = sum("processor" in l for l in f)
93                     finally:
94                         f.close()
95                 except:
96                     pass
97             maxthreads = N_PROCS
98         
99         if maxthreads is None:
100             maxthreads = 4
101         
102         self.queue = Queue.Queue(maxqueue or 0)
103
104         self.delayed_exceptions = []
105         
106         if results:
107             self.rvqueue = Queue.Queue()
108         else:
109             self.rvqueue = None
110         
111         # Check threadcache
112         if THREADCACHEPID is None or THREADCACHEPID != os.getpid():
113             del THREADCACHE[:]
114             THREADCACHEPID = os.getpid()
115     
116         self.workers = []
117         for x in xrange(maxthreads):
118             t = None
119             if THREADCACHE:
120                 try:
121                     t = THREADCACHE.pop()
122                 except:
123                     pass
124             if t is None:
125                 t = WorkerThread()
126                 t.setDaemon(True)
127             else:
128                 t.waitdone()
129             t.attach(self.queue, self.rvqueue, self.delayed_exceptions)
130             self.workers.append(t)
131     
132     def __del__(self):
133         self.destroy()
134     
135     def destroy(self):
136         # Check threadcache
137         global THREADCACHE
138         global THREADCACHEPID
139         if THREADCACHEPID is None or THREADCACHEPID != os.getpid():
140             del THREADCACHE[:]
141             THREADCACHEPID = os.getpid()
142
143         for worker in self.workers:
144             worker.waitdone()
145         for worker in self.workers:
146             worker.detach()
147         for worker in self.workers:
148             worker.detach_signal()
149         THREADCACHE.extend(self.workers)
150         del self.workers[:]
151         
152     def put(self, callable, *args, **kwargs):
153         self.queue.put((callable, args, kwargs))
154     
155     def put_nowait(self, callable, *args, **kwargs):
156         self.queue.put_nowait((callable, args, kwargs))
157
158     def start(self):
159         for thread in self.workers:
160             if not thread.isAlive():
161                 thread.start()
162     
163     def join(self):
164         for thread in self.workers:
165             # That's the sync signal
166             self.queue.put(None)
167             
168         self.queue.join()
169         for thread in self.workers:
170             thread.waitdone()
171         
172         if self.delayed_exceptions:
173             typ,val,loc = self.delayed_exceptions[0]
174             del self.delayed_exceptions[:]
175             raise typ,val,loc
176         
177         self.destroy()
178     
179     def sync(self):
180         self.queue.join()
181         if self.delayed_exceptions:
182             typ,val,loc = self.delayed_exceptions[0]
183             del self.delayed_exceptions[:]
184             raise typ,val,loc
185         
186     def __iter__(self):
187         if self.rvqueue is not None:
188             while True:
189                 try:
190                     yield self.rvqueue.get_nowait()
191                 except Queue.Empty:
192                     self.queue.join()
193                     try:
194                         yield self.rvqueue.get_nowait()
195                     except Queue.Empty:
196                         raise StopIteration
197             
198     
199 class ParallelFilter(ParallelMap):
200     class _FILTERED:
201         pass
202     
203     def __filter(self, x):
204         if self.filter_condition(x):
205             return x
206         else:
207             return self._FILTERED
208     
209     def __init__(self, filter_condition, maxthreads = None, maxqueue = None):
210         super(ParallelFilter, self).__init__(maxthreads, maxqueue, True)
211         self.filter_condition = filter_condition
212
213     def put(self, what):
214         super(ParallelFilter, self).put(self.__filter, what)
215     
216     def put_nowait(self, what):
217         super(ParallelFilter, self).put_nowait(self.__filter, what)
218         
219     def __iter__(self):
220         for rv in super(ParallelFilter, self).__iter__():
221             if rv is not self._FILTERED:
222                 yield rv
223
224 class ParallelRun(ParallelMap):
225     def __run(self, x):
226         fn, args, kwargs = x
227         return fn(*args, **kwargs)
228     
229     def __init__(self, maxthreads = None, maxqueue = None):
230         super(ParallelRun, self).__init__(maxthreads, maxqueue, True)
231
232     def put(self, what, *args, **kwargs):
233         super(ParallelRun, self).put(self.__run, (what, args, kwargs))
234     
235     def put_nowait(self, what, *args, **kwargs):
236         super(ParallelRun, self).put_nowait(self.__filter, (what, args, kwargs))
237
238
239 def pmap(mapping, iterable, maxthreads = None, maxqueue = None):
240     mapper = ParallelMap(
241         maxthreads = maxthreads,
242         maxqueue = maxqueue,
243         results = True)
244     mapper.start()
245     for elem in iterable:
246         mapper.put(elem)
247     rv = list(mapper)
248     mapper.join()
249     return rv
250
251 def pfilter(condition, iterable, maxthreads = None, maxqueue = None):
252     filtrer = ParallelFilter(
253         condition,
254         maxthreads = maxthreads,
255         maxqueue = maxqueue)
256     filtrer.start()
257     for elem in iterable:
258         filtrer.put(elem)
259     rv = list(filtrer)
260     filtrer.join()
261     return rv
262