Adding authors and correcting licence information
[nepi.git] / src / nepi / util / parallel.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License as published by
7 #    the Free Software Foundation, either version 3 of the License, or
8 #    (at your option) any later version.
9 #
10 #    This program is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #    GNU General Public License for more details.
14 #
15 #    You should have received a copy of the GNU General Public License
16 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18
19 import threading
20 import Queue
21 import traceback
22 import sys
23 import os
24
# Cached default worker count; stays None until ParallelMap first needs it
# and reads it from /proc/cpuinfo.
N_PROCS = None

# Pool of idle WorkerThread objects handed back by ParallelMap.destroy(),
# together with the PID of the process that owns them (the pool is emptied
# whenever the current PID differs from the recorded one).
THREADCACHE, THREADCACHEPID = [], None
29
class WorkerThread(threading.Thread):
    """Reusable worker thread that executes tasks taken from a queue.

    A task is either a ``(callable, args, kwargs)`` tuple or one of the
    control tokens below:

    * ``None``       -- sync marker: flags the worker as done and keeps going.
    * ``QUIT``       -- flags the worker as done and terminates the thread.
    * ``REASSIGNED`` -- wake-up token; the worker simply loops and re-reads
      ``self.queue`` (used after the queue attribute has been swapped).

    The worker must be given its queues through attach() (or detach())
    before it is started.
    """

    class QUIT:
        pass

    class REASSIGNED:
        pass

    # Class-level default: a worker that has not processed anything yet is
    # considered done, so waitdone() neither fails with AttributeError nor
    # blocks when called on a thread that was never started.
    done = True

    def run(self):
        """Process tasks until a QUIT token is received."""
        while True:
            # Bind the queue locally so that task_done() is always issued on
            # the queue the task was fetched from, even if self.queue is
            # replaced (attach/detach) while the task is being processed.
            queue = self.queue
            task = queue.get()
            if task is None:
                self.done = True
                queue.task_done()
                continue
            elif task is self.QUIT:
                self.done = True
                queue.task_done()
                break
            elif task is self.REASSIGNED:
                # No task_done() here: the token only exists to unblock get().
                continue
            else:
                self.done = False

            try:
                try:
                    func, args, kwargs = task
                    rv = func(*args, **kwargs)

                    if self.rvqueue is not None:
                        self.rvqueue.put(rv)
                finally:
                    queue.task_done()
            except:
                # Deliberately broad: a failing task must not kill the
                # worker. The error is printed and stored for the owner.
                traceback.print_exc(file = sys.stderr)
                self.delayed_exceptions.append(sys.exc_info())

    def waitdone(self):
        """Join the current queue until it is empty or the worker is done."""
        while not self.queue.empty() and not self.done:
            self.queue.join()

    def attach(self, queue, rvqueue, delayed_exceptions):
        """Bind the worker to a task queue, result queue and error list.

        If the thread is already running, pending work is waited for first
        and a REASSIGNED token is pushed on the previous queue afterwards so
        a get() blocked on it returns and the new queue is picked up.
        """
        alive = self.is_alive()
        if alive:
            self.waitdone()
            oldqueue = self.queue
        self.queue = queue
        self.rvqueue = rvqueue
        self.delayed_exceptions = delayed_exceptions
        if alive:
            oldqueue.put(self.REASSIGNED)

    def detach(self):
        """Unbind the worker, giving it a private empty queue.

        For a running thread the previous queue is remembered in
        ``self.oldqueue`` so that detach_signal() can wake the thread up.
        """
        if self.is_alive():
            self.waitdone()
            self.oldqueue = self.queue
        self.queue = Queue.Queue()
        self.rvqueue = None
        self.delayed_exceptions = []

    def detach_signal(self):
        """Push REASSIGNED on the queue remembered by detach(), if any."""
        # pop() instead of attribute access + del: calling this without a
        # preceding detach() on a live thread must not raise AttributeError.
        oldqueue = self.__dict__.pop("oldqueue", None)
        if oldqueue is not None and self.is_alive():
            oldqueue.put(self.REASSIGNED)

    def quit(self):
        """Ask the thread to terminate and wait until it has exited."""
        self.queue.put(self.QUIT)
        self.join()
95
96 class ParallelMap(object):
97     def __init__(self, maxthreads = None, maxqueue = None, results = True):
98         global N_PROCS
99         global THREADCACHE
100         global THREADCACHEPID
101         
102         if maxthreads is None:
103             if N_PROCS is None:
104                 try:
105                     f = open("/proc/cpuinfo")
106                     try:
107                         N_PROCS = sum("processor" in l for l in f)
108                     finally:
109                         f.close()
110                 except:
111                     pass
112             maxthreads = N_PROCS
113         
114         if maxthreads is None:
115             maxthreads = 4
116         
117         self.queue = Queue.Queue(maxqueue or 0)
118
119         self.delayed_exceptions = []
120         
121         if results:
122             self.rvqueue = Queue.Queue()
123         else:
124             self.rvqueue = None
125         
126         # Check threadcache
127         if THREADCACHEPID is None or THREADCACHEPID != os.getpid():
128             del THREADCACHE[:]
129             THREADCACHEPID = os.getpid()
130     
131         self.workers = []
132         for x in xrange(maxthreads):
133             t = None
134             if THREADCACHE:
135                 try:
136                     t = THREADCACHE.pop()
137                 except:
138                     pass
139             if t is None:
140                 t = WorkerThread()
141                 t.setDaemon(True)
142             else:
143                 t.waitdone()
144             t.attach(self.queue, self.rvqueue, self.delayed_exceptions)
145             self.workers.append(t)
146     
147     def __del__(self):
148         self.destroy()
149     
150     def destroy(self):
151         # Check threadcache
152         global THREADCACHE
153         global THREADCACHEPID
154         if THREADCACHEPID is None or THREADCACHEPID != os.getpid():
155             del THREADCACHE[:]
156             THREADCACHEPID = os.getpid()
157
158         for worker in self.workers:
159             worker.waitdone()
160         for worker in self.workers:
161             worker.detach()
162         for worker in self.workers:
163             worker.detach_signal()
164         THREADCACHE.extend(self.workers)
165         del self.workers[:]
166         
167     def put(self, callable, *args, **kwargs):
168         self.queue.put((callable, args, kwargs))
169     
170     def put_nowait(self, callable, *args, **kwargs):
171         self.queue.put_nowait((callable, args, kwargs))
172
173     def start(self):
174         for thread in self.workers:
175             if not thread.isAlive():
176                 thread.start()
177     
178     def join(self):
179         for thread in self.workers:
180             # That's the sync signal
181             self.queue.put(None)
182             
183         self.queue.join()
184         for thread in self.workers:
185             thread.waitdone()
186         
187         if self.delayed_exceptions:
188             typ,val,loc = self.delayed_exceptions[0]
189             del self.delayed_exceptions[:]
190             raise typ,val,loc
191         
192         self.destroy()
193     
194     def sync(self):
195         self.queue.join()
196         if self.delayed_exceptions:
197             typ,val,loc = self.delayed_exceptions[0]
198             del self.delayed_exceptions[:]
199             raise typ,val,loc
200         
201     def __iter__(self):
202         if self.rvqueue is not None:
203             while True:
204                 try:
205                     yield self.rvqueue.get_nowait()
206                 except Queue.Empty:
207                     self.queue.join()
208                     try:
209                         yield self.rvqueue.get_nowait()
210                     except Queue.Empty:
211                         raise StopIteration
212             
213     
class ParallelFilter(ParallelMap):
    """ParallelMap variant that keeps only items accepted by a predicate.

    Each item given to put() is passed to ``filter_condition`` by a worker;
    iterating over the instance yields the items for which it returned a
    true value.
    """

    class _FILTERED:
        # Marker stored in the result queue for rejected items.
        pass

    def __init__(self, filter_condition, maxthreads = None, maxqueue = None):
        # Result collection is always enabled: the filter needs the results.
        super(ParallelFilter, self).__init__(maxthreads, maxqueue, True)
        self.filter_condition = filter_condition

    def __filter(self, x):
        # Task executed by the workers: the item itself or the marker.
        return x if self.filter_condition(x) else self._FILTERED

    def put(self, what):
        """Queue an item for evaluation; may block if the queue is full."""
        parent = super(ParallelFilter, self)
        parent.put(self.__filter, what)

    def put_nowait(self, what):
        """Queue an item for evaluation without blocking."""
        parent = super(ParallelFilter, self)
        parent.put_nowait(self.__filter, what)

    def __iter__(self):
        """Yield the accepted items, skipping the rejection markers."""
        for rv in super(ParallelFilter, self).__iter__():
            if rv is self._FILTERED:
                continue
            yield rv
238
class ParallelRun(ParallelMap):
    """ParallelMap variant that always collects results.

    Every task is queued as a single ``(fn, args, kwargs)`` triple that a
    worker unpacks and calls through the private ``__run`` helper.
    """

    def __run(self, x):
        # Task executed by the workers: unpack the triple and call it.
        fn, args, kwargs = x
        return fn(*args, **kwargs)

    def __init__(self, maxthreads = None, maxqueue = None):
        super(ParallelRun, self).__init__(maxthreads, maxqueue, True)

    def put(self, what, *args, **kwargs):
        """Queue ``what(*args, **kwargs)``; may block if the queue is full."""
        super(ParallelRun, self).put(self.__run, (what, args, kwargs))

    def put_nowait(self, what, *args, **kwargs):
        """Queue ``what(*args, **kwargs)`` without blocking."""
        # Must go through __run like put(): this class has no __filter
        # attribute, so referencing it raised AttributeError on every call.
        super(ParallelRun, self).put_nowait(self.__run, (what, args, kwargs))
252
253
def pmap(mapping, iterable, maxthreads = None, maxqueue = None):
    """Apply ``mapping`` to every element of ``iterable`` using worker threads.

    :param mapping: callable taking one element and returning its result.
    :param iterable: elements to process.
    :param maxthreads: forwarded to ParallelMap.
    :param maxqueue: forwarded to ParallelMap.
    :returns: list of results, in the order the workers produced them
        (results are read from a shared queue, so this is not necessarily
        the input order).
    :raises: whatever ParallelMap.join() re-raises from a failed task.
    """
    mapper = ParallelMap(
        maxthreads = maxthreads,
        maxqueue = maxqueue,
        results = True)
    mapper.start()
    for elem in iterable:
        # Queue the mapping applied to the element; queuing the bare
        # element would make the workers try to call the element itself.
        mapper.put(mapping, elem)
    rv = list(mapper)
    mapper.join()
    return rv
265
def pfilter(condition, iterable, maxthreads = None, maxqueue = None):
    """Return the elements of ``iterable`` accepted by ``condition``.

    The predicate is evaluated by a ParallelFilter pool; ``maxthreads`` and
    ``maxqueue`` are forwarded to it. The pool is joined before returning.
    """
    pool = ParallelFilter(condition, maxthreads = maxthreads, maxqueue = maxqueue)
    pool.start()
    for candidate in iterable:
        pool.put(candidate)
    accepted = list(pool)
    pool.join()
    return accepted
277