Fixing 'error: can't start new thread' bug in ParallelRun
[nepi.git] / src / nepi / util / parallel.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License as published by
7 #    the Free Software Foundation, either version 3 of the License, or
8 #    (at your option) any later version.
9 #
10 #    This program is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #    GNU General Public License for more details.
14 #
15 #    You should have received a copy of the GNU General Public License
16 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18 # Author: Claudio Freire <claudio-daniel.freire@inria.fr>
19 #
20
# TODO (A.Q.): the thread cache (THREADCACHE) was buggy and is not needed;
# remove the remaining commented-out code completely.
22
23 import threading
24 import Queue
25 import traceback
26 import sys
27 import os
28
# Cached processor count read from /proc/cpuinfo. It stays None until a
# ParallelMap is created without an explicit maxthreads and the file can be read.
N_PROCS = None
30
31 #THREADCACHE = []
32 #THREADCACHEPID = None
33
class WorkerThread(threading.Thread):
    """Thread that executes tasks taken from a task queue.

    The thread takes no constructor arguments: attach() has to be called
    before start() to give it a task queue, an optional results queue and a
    list in which failures are recorded.

    Items read from the task queue are one of:

    * a ``(callable, args, kwargs)`` tuple: the callable is invoked and, if a
      results queue is attached, its return value is put there;
    * ``None``: synchronization marker, flags the worker as done;
    * ``WorkerThread.QUIT``: flags the worker as done and ends the thread;
    * ``WorkerThread.REASSIGNED``: wake-up marker, makes the worker re-read
      ``self.queue`` so it notices that another queue was assigned.
    """

    class QUIT:
        pass

    class REASSIGNED:
        pass

    # Class-level defaults so that waitdone() and detach_signal() can be
    # called safely before run() or detach() have set these attributes.
    done = True
    oldqueue = None

    def run(self):
        """Process queue items until a QUIT marker is received."""
        while True:
            # Keep a reference to the queue the item came from: self.queue
            # may be replaced by attach()/detach() while we are blocked, and
            # task_done() must be accounted on the originating queue.
            queue = self.queue
            task = queue.get()

            if task is None:
                self.done = True
                queue.task_done()
                continue
            if task is self.QUIT:
                self.done = True
                queue.task_done()
                break
            if task is self.REASSIGNED:
                # Balance the put() made by attach()/detach_signal() so that
                # a later join() on the old queue does not block forever.
                queue.task_done()
                continue

            self.done = False
            try:
                func, args, kwargs = task
                rv = func(*args, **kwargs)

                if self.rvqueue is not None:
                    self.rvqueue.put(rv)
            except:
                # Deliberately catch everything: the failure is recorded so
                # the owner of the pool can re-raise it in its own thread.
                traceback.print_exc(file = sys.stderr)
                self.delayed_exceptions.append(sys.exc_info())
            finally:
                # Only signal completion once the failure (if any) has been
                # recorded; otherwise a thread woken up by Queue.join() could
                # check delayed_exceptions before the exception is stored.
                queue.task_done()

    def waitdone(self):
        """Block while the queue still holds items and the worker is not done.

        ``done`` is set to True by a None or QUIT marker and to False when a
        task is picked up; it is not set back to True when the task finishes.
        """
        while not self.queue.empty() and not self.done:
            self.queue.join()

    def attach(self, queue, rvqueue, delayed_exceptions):
        """Make the worker read tasks from `queue`.

        queue: task queue to read from.
        rvqueue: queue receiving the return values, or None to discard them.
        delayed_exceptions: list to which ``sys.exc_info()`` tuples of failed
            tasks are appended.

        If the thread is already running, it is waited for and then woken up
        through its previous queue so that it switches to the new one.
        """
        oldqueue = None
        if self.is_alive():
            self.waitdone()
            oldqueue = self.queue
        self.queue = queue
        self.rvqueue = rvqueue
        self.delayed_exceptions = delayed_exceptions
        if oldqueue is not None:
            oldqueue.put(self.REASSIGNED)

    def detach(self):
        """Give the worker a fresh private queue and drop its results queue.

        A running worker is not woken up here; its previous queue is kept in
        ``self.oldqueue`` until detach_signal() is called.
        """
        if self.is_alive():
            self.waitdone()
            self.oldqueue = self.queue
        self.queue = Queue.Queue()
        self.rvqueue = None
        self.delayed_exceptions = []

    def detach_signal(self):
        """Wake up a running worker blocked on the queue it was detached from.

        Does nothing if detach() did not remember a previous queue.
        """
        oldqueue, self.oldqueue = self.oldqueue, None
        if oldqueue is not None and self.is_alive():
            oldqueue.put(self.REASSIGNED)

    def quit(self):
        """Ask the worker to terminate and wait for it if it is running."""
        self.queue.put(self.QUIT)
        # Thread.join() raises RuntimeError on a thread that was never
        # started, so only wait for threads that are actually running.
        if self.is_alive():
            self.join()
99
100 class ParallelMap(object):
101     def __init__(self, maxthreads = None, maxqueue = None, results = True):
102         global N_PROCS
103         #global THREADCACHE
104         #global THREADCACHEPID
105         
106         if maxthreads is None:
107             if N_PROCS is None:
108                 try:
109                     f = open("/proc/cpuinfo")
110                     try:
111                         N_PROCS = sum("processor" in l for l in f)
112                     finally:
113                         f.close()
114                 except:
115                     pass
116             maxthreads = N_PROCS
117         
118         if maxthreads is None:
119             maxthreads = 4
120         
121         self.queue = Queue.Queue(maxqueue or 0)
122
123         self.delayed_exceptions = []
124         
125         if results:
126             self.rvqueue = Queue.Queue()
127         else:
128             self.rvqueue = None
129         
130         # Check threadcache
131         #if THREADCACHEPID is None or THREADCACHEPID != os.getpid():
132         #    del THREADCACHE[:]
133         #    THREADCACHEPID = os.getpid()
134     
135         self.workers = []
136         for x in xrange(maxthreads):
137             t = None
138             #if THREADCACHE:
139             #    try:
140             #        t = THREADCACHE.pop()
141             #    except:
142             #        pass
143             if t is None:
144                 t = WorkerThread()
145                 t.setDaemon(True)
146             else:
147                 t.waitdone()
148             t.attach(self.queue, self.rvqueue, self.delayed_exceptions)
149             self.workers.append(t)
150     
151     def __del__(self):
152         self.destroy()
153     
154     def destroy(self):
155         # Check threadcache
156         #global THREADCACHE
157         #global THREADCACHEPID
158         #if THREADCACHEPID is None or THREADCACHEPID != os.getpid():
159         #    del THREADCACHE[:]
160         #    THREADCACHEPID = os.getpid()
161
162         for worker in self.workers:
163             worker.waitdone()
164         for worker in self.workers:
165             worker.detach()
166         for worker in self.workers:
167             worker.detach_signal()
168         for worker in self.workers:
169             worker.quit()
170
171         # TO FIX:
172         # THREADCACHE.extend(self.workers)
173
174         del self.workers[:]
175         
176     def put(self, callable, *args, **kwargs):
177         self.queue.put((callable, args, kwargs))
178     
179     def put_nowait(self, callable, *args, **kwargs):
180         self.queue.put_nowait((callable, args, kwargs))
181
182     def start(self):
183         for thread in self.workers:
184             if not thread.isAlive():
185                 thread.start()
186     
187     def join(self):
188         for thread in self.workers:
189             # That's the sync signal
190             self.queue.put(None)
191             
192         self.queue.join()
193         for thread in self.workers:
194             thread.waitdone()
195         
196         if self.delayed_exceptions:
197             typ,val,loc = self.delayed_exceptions[0]
198             del self.delayed_exceptions[:]
199             raise typ,val,loc
200         
201         self.destroy()
202     
203     def sync(self):
204         self.queue.join()
205         if self.delayed_exceptions:
206             typ,val,loc = self.delayed_exceptions[0]
207             del self.delayed_exceptions[:]
208             raise typ,val,loc
209         
210     def __iter__(self):
211         if self.rvqueue is not None:
212             while True:
213                 try:
214                     yield self.rvqueue.get_nowait()
215                 except Queue.Empty:
216                     self.queue.join()
217                     try:
218                         yield self.rvqueue.get_nowait()
219                     except Queue.Empty:
220                         raise StopIteration
221             
222     
class ParallelFilter(ParallelMap):
    """ParallelMap that evaluates a condition on the queued items in parallel.

    Iterating over the pool yields only the items for which the condition
    returned a true value.
    """

    class _FILTERED:
        # Marker stored in place of the items rejected by the condition
        pass

    def __init__(self, filter_condition, maxthreads = None, maxqueue = None):
        """Create the pool; filter_condition is called once per queued item."""
        super(ParallelFilter, self).__init__(maxthreads, maxqueue, True)
        self.filter_condition = filter_condition

    def __filter(self, x):
        # Task run by the workers: the item itself when it is accepted,
        # the _FILTERED marker otherwise.
        if not self.filter_condition(x):
            return self._FILTERED
        return x

    def put(self, what):
        """Queue `what` for evaluation, blocking while the queue is full."""
        super(ParallelFilter, self).put(self.__filter, what)

    def put_nowait(self, what):
        """Queue `what` for evaluation without blocking."""
        super(ParallelFilter, self).put_nowait(self.__filter, what)

    def __iter__(self):
        """Yield the accepted items, skipping the rejected ones."""
        for rv in super(ParallelFilter, self).__iter__():
            if rv is self._FILTERED:
                continue
            yield rv
247
class ParallelRun(ParallelMap):
    """ParallelMap variant whose tasks are wrapped by a private runner.

    Each queued callable is packed with its arguments into a single tuple
    that is unpacked and invoked by __run in a worker thread. Results are
    always collected.
    """

    def __run(self, x):
        # x is the (callable, args, kwargs) tuple built by put()/put_nowait()
        fn, args, kwargs = x
        return fn(*args, **kwargs)

    def __init__(self, maxthreads = None, maxqueue = None):
        super(ParallelRun, self).__init__(maxthreads, maxqueue, True)

    def put(self, what, *args, **kwargs):
        """Queue what(*args, **kwargs), blocking while the queue is full."""
        super(ParallelRun, self).put(self.__run, (what, args, kwargs))

    def put_nowait(self, what, *args, **kwargs):
        """Queue what(*args, **kwargs) without blocking."""
        # Must use the same runner as put(): this class has no __filter
        # method, so referencing it raised AttributeError on every call.
        super(ParallelRun, self).put_nowait(self.__run, (what, args, kwargs))
261
262
def pmap(mapping, iterable, maxthreads = None, maxqueue = None):
    """Apply `mapping` to every element of `iterable` using worker threads.

    mapping: callable taking one element of `iterable`.
    iterable: elements to process.
    maxthreads, maxqueue: passed through to ParallelMap.

    Returns the list of results in the order the workers produced them,
    which is not necessarily the order of `iterable`. An exception raised
    by `mapping` is re-raised when the pool is joined.
    """
    mapper = ParallelMap(
        maxthreads = maxthreads,
        maxqueue = maxqueue,
        results = True)
    mapper.start()
    for elem in iterable:
        # Queue mapping(elem); the element itself is not the callable.
        mapper.put(mapping, elem)
    rv = list(mapper)
    mapper.join()
    return rv
274
def pfilter(condition, iterable, maxthreads = None, maxqueue = None):
    """Return the elements of `iterable` accepted by `condition`.

    The condition is evaluated through a ParallelFilter created with the
    given maxthreads and maxqueue. The accepted elements are collected
    before the pool is joined.
    """
    pool = ParallelFilter(condition, maxthreads = maxthreads, maxqueue = maxqueue)
    pool.start()
    for candidate in iterable:
        pool.put(candidate)
    accepted = [item for item in pool]
    pool.join()
    return accepted
286