Fixed nasty concurrency bug in EC
[nepi.git] / src / nepi / util / parallel.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License as published by
7 #    the Free Software Foundation, either version 3 of the License, or
8 #    (at your option) any later version.
9 #
10 #    This program is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #    GNU General Public License for more details.
14 #
15 #    You should have received a copy of the GNU General Public License
16 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18 # Author: Claudio Freire <claudio-daniel.freire@inria.fr>
19 #
20
21 import threading
22 import Queue
23 import traceback
24 import sys
25 import os
26
27 N_PROCS = None
28
29 THREADCACHE = []
30 THREADCACHEPID = None
31
class WorkerThread(threading.Thread):
    """
    Reusable worker thread that executes tasks pulled from a queue.

    Items read from the queue are either ``(callable, args, kwargs)``
    tuples or one of three control values:

    * ``None`` -- sync marker: the worker flags itself as done and goes
      back to waiting for more items.
    * ``QUIT`` -- the worker flags itself as done and leaves its loop.
    * ``REASSIGNED`` -- wake-up marker posted on the *previous* queue after
      ``attach``/``detach_signal``, so that a worker blocked on the old
      queue loops around and starts reading from the new ``self.queue``.

    ``run`` reads ``self.queue``, ``self.rvqueue`` and
    ``self.delayed_exceptions``, which are only assigned by ``attach`` and
    ``detach``; one of them has to be called before the thread is started.
    """
    class QUIT:
        pass
    class REASSIGNED:
        pass
    
    def run(self):
        """
        Worker loop: fetch items from the current queue until QUIT arrives.

        Return values are pushed to ``self.rvqueue`` when it is not None.
        Any exception raised by a task is printed to stderr and its
        ``sys.exc_info()`` triple is appended to ``self.delayed_exceptions``.
        """
        while True:
            # Bind the queue once per iteration: attach()/detach() may
            # replace self.queue while we are blocked or busy, and every
            # task_done() has to go to the queue the item came from.
            queue = self.queue
            task = queue.get()
            if task is None:
                self.done = True
                queue.task_done()
                continue
            elif task is self.QUIT:
                self.done = True
                queue.task_done()
                break
            elif task is self.REASSIGNED:
                # Acknowledge the marker so the queue we are leaving does
                # not keep a dangling unfinished item (which would block
                # any later join() on it forever).
                queue.task_done()
                continue
            else:
                self.done = False
            
            try:
                try:
                    func, args, kwargs = task
                    rv = func(*args, **kwargs)
                    
                    if self.rvqueue is not None:
                        self.rvqueue.put(rv)
                except:
                    # Deliberately catches everything: a failing task must
                    # not kill the worker.  The failure is recorded BEFORE
                    # task_done() below, so whoever is woken up by
                    # queue.join() is guaranteed to see it.
                    traceback.print_exc(file = sys.stderr)
                    self.delayed_exceptions.append(sys.exc_info())
            finally:
                queue.task_done()
    
    def waitdone(self):
        """
        Join the current queue repeatedly for as long as it is non-empty
        and the worker has not flagged itself as done.
        """
        while not self.queue.empty() and not self.done:
            self.queue.join()
    
    def attach(self, queue, rvqueue, delayed_exceptions):
        """
        Point the worker at a new task queue, result queue and error list.

        If the thread is already running, pending work is waited for first
        and a REASSIGNED marker is posted on the previous queue afterwards
        to wake the worker up.
        """
        oldqueue = None
        if self.is_alive():
            self.waitdone()
            oldqueue = self.queue
        self.queue = queue
        self.rvqueue = rvqueue
        self.delayed_exceptions = delayed_exceptions
        if oldqueue is not None:
            oldqueue.put(self.REASSIGNED)
    
    def detach(self):
        """
        Give the worker a fresh private queue, no result queue and an empty
        error list.

        For a running thread the previous queue is remembered in
        ``self.oldqueue``; ``detach_signal`` must be called afterwards to
        wake the worker up.
        """
        if self.is_alive():
            self.waitdone()
            self.oldqueue = self.queue
        self.queue = Queue.Queue()
        self.rvqueue = None
        self.delayed_exceptions = []
    
    def detach_signal(self):
        """
        Post the REASSIGNED marker on the queue remembered by ``detach``
        and forget that queue.  Does nothing if the thread is not running.
        """
        if self.is_alive():
            self.oldqueue.put(self.REASSIGNED)
            del self.oldqueue
        
    def quit(self):
        """
        Post the QUIT marker on the current queue and wait for the thread
        to terminate.
        """
        self.queue.put(self.QUIT)
        self.join()
97
98 class ParallelMap(object):
99     def __init__(self, maxthreads = None, maxqueue = None, results = True):
100         global N_PROCS
101         global THREADCACHE
102         global THREADCACHEPID
103         
104         if maxthreads is None:
105             if N_PROCS is None:
106                 try:
107                     f = open("/proc/cpuinfo")
108                     try:
109                         N_PROCS = sum("processor" in l for l in f)
110                     finally:
111                         f.close()
112                 except:
113                     pass
114             maxthreads = N_PROCS
115         
116         if maxthreads is None:
117             maxthreads = 4
118         
119         self.queue = Queue.Queue(maxqueue or 0)
120
121         self.delayed_exceptions = []
122         
123         if results:
124             self.rvqueue = Queue.Queue()
125         else:
126             self.rvqueue = None
127         
128         # Check threadcache
129         if THREADCACHEPID is None or THREADCACHEPID != os.getpid():
130             del THREADCACHE[:]
131             THREADCACHEPID = os.getpid()
132     
133         self.workers = []
134         for x in xrange(maxthreads):
135             t = None
136             if THREADCACHE:
137                 try:
138                     t = THREADCACHE.pop()
139                 except:
140                     pass
141             if t is None:
142                 t = WorkerThread()
143                 t.setDaemon(True)
144             else:
145                 t.waitdone()
146             t.attach(self.queue, self.rvqueue, self.delayed_exceptions)
147             self.workers.append(t)
148     
149     def __del__(self):
150         self.destroy()
151     
152     def destroy(self):
153         # Check threadcache
154         global THREADCACHE
155         global THREADCACHEPID
156         if THREADCACHEPID is None or THREADCACHEPID != os.getpid():
157             del THREADCACHE[:]
158             THREADCACHEPID = os.getpid()
159
160         for worker in self.workers:
161             worker.waitdone()
162         for worker in self.workers:
163             worker.detach()
164         for worker in self.workers:
165             worker.detach_signal()
166         THREADCACHE.extend(self.workers)
167         del self.workers[:]
168         
169     def put(self, callable, *args, **kwargs):
170         self.queue.put((callable, args, kwargs))
171     
172     def put_nowait(self, callable, *args, **kwargs):
173         self.queue.put_nowait((callable, args, kwargs))
174
175     def start(self):
176         for thread in self.workers:
177             if not thread.isAlive():
178                 thread.start()
179     
180     def join(self):
181         for thread in self.workers:
182             # That's the sync signal
183             self.queue.put(None)
184             
185         self.queue.join()
186         for thread in self.workers:
187             thread.waitdone()
188         
189         if self.delayed_exceptions:
190             typ,val,loc = self.delayed_exceptions[0]
191             del self.delayed_exceptions[:]
192             raise typ,val,loc
193         
194         self.destroy()
195     
196     def sync(self):
197         self.queue.join()
198         if self.delayed_exceptions:
199             typ,val,loc = self.delayed_exceptions[0]
200             del self.delayed_exceptions[:]
201             raise typ,val,loc
202         
203     def __iter__(self):
204         if self.rvqueue is not None:
205             while True:
206                 try:
207                     yield self.rvqueue.get_nowait()
208                 except Queue.Empty:
209                     self.queue.join()
210                     try:
211                         yield self.rvqueue.get_nowait()
212                     except Queue.Empty:
213                         raise StopIteration
214             
215     
class ParallelFilter(ParallelMap):
    """
    ParallelMap variant that evaluates ``filter_condition`` on every queued
    element and, when iterated, yields only the elements for which the
    condition was true.
    """
    class _FILTERED:
        # Marker returned by workers for rejected elements
        pass
    
    def __init__(self, filter_condition, maxthreads = None, maxqueue = None):
        super(ParallelFilter, self).__init__(
            maxthreads = maxthreads,
            maxqueue = maxqueue,
            results = True)
        self.filter_condition = filter_condition

    def __filter(self, x):
        # Runs in a worker: hand back the element itself or the marker
        accepted = self.filter_condition(x)
        return x if accepted else self._FILTERED
    
    def put(self, what):
        """Queue element ``what`` for evaluation."""
        evaluate = self.__filter
        super(ParallelFilter, self).put(evaluate, what)
    
    def put_nowait(self, what):
        """Queue element ``what`` for evaluation without blocking."""
        evaluate = self.__filter
        super(ParallelFilter, self).put_nowait(evaluate, what)
        
    def __iter__(self):
        """Yield the accepted elements, skipping the rejection marker."""
        for outcome in super(ParallelFilter, self).__iter__():
            if outcome is self._FILTERED:
                continue
            yield outcome
240
class ParallelRun(ParallelMap):
    """
    ParallelMap variant whose ``put``/``put_nowait`` bundle the callable
    and its arguments into a single ``(fn, args, kwargs)`` tuple that is
    unpacked and invoked by a worker.  Results are always collected.
    """
    def __run(self, x):
        # Runs in a worker: unpack the bundle and perform the call
        fn, args, kwargs = x
        return fn(*args, **kwargs)
    
    def __init__(self, maxthreads = None, maxqueue = None):
        super(ParallelRun, self).__init__(maxthreads, maxqueue, True)

    def put(self, what, *args, **kwargs):
        """Queue ``what(*args, **kwargs)`` for execution."""
        super(ParallelRun, self).put(self.__run, (what, args, kwargs))
    
    def put_nowait(self, what, *args, **kwargs):
        """Queue ``what(*args, **kwargs)`` for execution without blocking."""
        # Must dispatch through __run like put() does; this class has no
        # __filter method, so referencing it raised AttributeError.
        super(ParallelRun, self).put_nowait(self.__run, (what, args, kwargs))
254
255
def pmap(mapping, iterable, maxthreads = None, maxqueue = None):
    """
    Apply ``mapping`` to every element of ``iterable`` using a pool of
    worker threads and return the list of results.

    Results are collected as the workers produce them, so their order may
    differ from the order of ``iterable``.  ``maxthreads`` and ``maxqueue``
    are passed on to ParallelMap.  Exceptions recorded by the workers are
    re-raised by the final ``join()``.
    """
    mapper = ParallelMap(
        maxthreads = maxthreads,
        maxqueue = maxqueue,
        results = True)
    mapper.start()
    for elem in iterable:
        # Queue mapping(elem); queuing elem alone would make the workers
        # call the element itself and ignore ``mapping`` altogether.
        mapper.put(mapping, elem)
    rv = list(mapper)
    mapper.join()
    return rv
267
def pfilter(condition, iterable, maxthreads = None, maxqueue = None):
    """
    Evaluate ``condition`` on every element of ``iterable`` through a
    ParallelFilter and return the list of elements it yields.

    ``maxthreads`` and ``maxqueue`` are passed on to ParallelFilter.
    """
    pool = ParallelFilter(condition, maxthreads = maxthreads, maxqueue = maxqueue)
    pool.start()
    for candidate in iterable:
        pool.put(candidate)
    accepted = [hit for hit in pool]
    pool.join()
    return accepted
279