Added LICENSE
[nepi.git] / src / nepi / util / parallel.py
1 """
2     NEPI, a framework to manage network experiments
3     Copyright (C) 2013 INRIA
4
5     This program is free software: you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation, either version 3 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License
16     along with this program.  If not, see <http://www.gnu.org/licenses/>.
17
18 """
19
20 import threading
21 import Queue
22 import traceback
23 import sys
24 import os
25
# Cached processor count detected from /proc/cpuinfo by ParallelMap.
# Stays None until the first ParallelMap built without an explicit
# maxthreads computes it (or if /proc/cpuinfo could not be read).
N_PROCS = None

# Pool of detached WorkerThread instances that ParallelMap reuses instead of
# spawning new threads, plus the pid of the process that populated it.
# ParallelMap empties the pool whenever the current pid differs from
# THREADCACHEPID.
THREADCACHE = []
THREADCACHEPID = None
30
class WorkerThread(threading.Thread):
    """
    Thread that executes tasks pulled from a queue it has been attached to.

    A task is a ``(callable, args, kwargs)`` tuple. Three special values are
    also understood on the queue:

    * ``None``: synchronization marker; the worker flags itself as done
      and keeps waiting for more tasks.
    * ``WorkerThread.QUIT``: the worker flags itself as done and exits.
    * ``WorkerThread.REASSIGNED``: wake-up marker posted on the *previous*
      queue after ``self.queue`` has been replaced, so that a worker blocked
      on the old queue re-reads ``self.queue``.

    The attributes ``queue``, ``rvqueue`` and ``delayed_exceptions`` are set
    by attach() / detach(); one of them must be called before start().
    """

    class QUIT:
        pass
    class REASSIGNED:
        pass

    # True while the worker has no task in flight. Defined at class level so
    # waitdone() can be called on a worker that has not processed anything
    # yet without raising AttributeError.
    done = True

    def run(self):
        """
        Main loop: fetch tasks and run them until QUIT is received.

        Return values are pushed to ``self.rvqueue`` when it is not None.
        Exceptions raised by a task are printed to stderr and their
        ``sys.exc_info()`` triple is appended to ``self.delayed_exceptions``.
        """
        while True:
            # Bind the queue once per iteration: attach()/detach() may replace
            # self.queue while a task is running, and task_done() has to be
            # reported to the queue the task was actually taken from.
            queue = self.queue
            task = queue.get()
            if task is None:
                self.done = True
                queue.task_done()
                continue
            elif task is self.QUIT:
                self.done = True
                queue.task_done()
                break
            elif task is self.REASSIGNED:
                # Only a wake-up: loop again to pick up the new self.queue.
                continue
            else:
                self.done = False

            try:
                try:
                    func, args, kwargs = task
                    rv = func(*args, **kwargs)

                    if self.rvqueue is not None:
                        self.rvqueue.put(rv)
                finally:
                    queue.task_done()
            except:
                # Deliberately catch-all: a failing task must never kill the
                # worker, otherwise whoever joins the queue would block.
                traceback.print_exc(file = sys.stderr)
                self.delayed_exceptions.append(sys.exc_info())

    def waitdone(self):
        """Block until the current queue is empty or the worker is done."""
        while not self.queue.empty() and not self.done:
            self.queue.join()

    def attach(self, queue, rvqueue, delayed_exceptions):
        """
        Make the worker consume tasks from ``queue``.

        queue -- task queue to read from.
        rvqueue -- queue receiving task return values, or None to drop them.
        delayed_exceptions -- list receiving exc_info triples of failed tasks.

        If the thread is already running, pending work is waited for and the
        REASSIGNED marker is posted on the previous queue to wake it up.
        """
        oldqueue = None
        if self.is_alive():
            self.waitdone()
            oldqueue = self.queue
        self.queue = queue
        self.rvqueue = rvqueue
        self.delayed_exceptions = delayed_exceptions
        if oldqueue is not None:
            oldqueue.put(self.REASSIGNED)

    def detach(self):
        """
        Disconnect the worker from its current queues.

        The worker gets a fresh private queue, no result queue and a fresh
        exception list. If the thread is running, the previous queue is
        remembered in ``self.oldqueue`` so detach_signal() can wake it up.
        """
        if self.is_alive():
            self.waitdone()
            self.oldqueue = self.queue
        self.queue = Queue.Queue()
        self.rvqueue = None
        self.delayed_exceptions = []

    def detach_signal(self):
        """Post REASSIGNED on the queue remembered by detach(), if running."""
        if self.is_alive():
            self.oldqueue.put(self.REASSIGNED)
            del self.oldqueue

    def quit(self):
        """Ask the worker to exit and wait for the thread to finish."""
        self.queue.put(self.QUIT)
        self.join()
96
97 class ParallelMap(object):
98     def __init__(self, maxthreads = None, maxqueue = None, results = True):
99         global N_PROCS
100         global THREADCACHE
101         global THREADCACHEPID
102         
103         if maxthreads is None:
104             if N_PROCS is None:
105                 try:
106                     f = open("/proc/cpuinfo")
107                     try:
108                         N_PROCS = sum("processor" in l for l in f)
109                     finally:
110                         f.close()
111                 except:
112                     pass
113             maxthreads = N_PROCS
114         
115         if maxthreads is None:
116             maxthreads = 4
117         
118         self.queue = Queue.Queue(maxqueue or 0)
119
120         self.delayed_exceptions = []
121         
122         if results:
123             self.rvqueue = Queue.Queue()
124         else:
125             self.rvqueue = None
126         
127         # Check threadcache
128         if THREADCACHEPID is None or THREADCACHEPID != os.getpid():
129             del THREADCACHE[:]
130             THREADCACHEPID = os.getpid()
131     
132         self.workers = []
133         for x in xrange(maxthreads):
134             t = None
135             if THREADCACHE:
136                 try:
137                     t = THREADCACHE.pop()
138                 except:
139                     pass
140             if t is None:
141                 t = WorkerThread()
142                 t.setDaemon(True)
143             else:
144                 t.waitdone()
145             t.attach(self.queue, self.rvqueue, self.delayed_exceptions)
146             self.workers.append(t)
147     
148     def __del__(self):
149         self.destroy()
150     
151     def destroy(self):
152         # Check threadcache
153         global THREADCACHE
154         global THREADCACHEPID
155         if THREADCACHEPID is None or THREADCACHEPID != os.getpid():
156             del THREADCACHE[:]
157             THREADCACHEPID = os.getpid()
158
159         for worker in self.workers:
160             worker.waitdone()
161         for worker in self.workers:
162             worker.detach()
163         for worker in self.workers:
164             worker.detach_signal()
165         THREADCACHE.extend(self.workers)
166         del self.workers[:]
167         
168     def put(self, callable, *args, **kwargs):
169         self.queue.put((callable, args, kwargs))
170     
171     def put_nowait(self, callable, *args, **kwargs):
172         self.queue.put_nowait((callable, args, kwargs))
173
174     def start(self):
175         for thread in self.workers:
176             if not thread.isAlive():
177                 thread.start()
178     
179     def join(self):
180         for thread in self.workers:
181             # That's the sync signal
182             self.queue.put(None)
183             
184         self.queue.join()
185         for thread in self.workers:
186             thread.waitdone()
187         
188         if self.delayed_exceptions:
189             typ,val,loc = self.delayed_exceptions[0]
190             del self.delayed_exceptions[:]
191             raise typ,val,loc
192         
193         self.destroy()
194     
195     def sync(self):
196         self.queue.join()
197         if self.delayed_exceptions:
198             typ,val,loc = self.delayed_exceptions[0]
199             del self.delayed_exceptions[:]
200             raise typ,val,loc
201         
202     def __iter__(self):
203         if self.rvqueue is not None:
204             while True:
205                 try:
206                     yield self.rvqueue.get_nowait()
207                 except Queue.Empty:
208                     self.queue.join()
209                     try:
210                         yield self.rvqueue.get_nowait()
211                     except Queue.Empty:
212                         raise StopIteration
213             
214     
class ParallelFilter(ParallelMap):
    """
    ParallelMap variant that evaluates a predicate on queued items in
    parallel; iterating yields only the items the predicate accepted.
    """

    class _FILTERED:
        # Marker returned by workers for items rejected by the predicate.
        pass

    def __init__(self, filter_condition, maxthreads = None, maxqueue = None):
        super(ParallelFilter, self).__init__(maxthreads, maxqueue, True)
        self.filter_condition = filter_condition

    def __filter(self, x):
        # Runs in a worker: keep the item itself or replace it by the marker.
        return x if self.filter_condition(x) else self._FILTERED

    def put(self, what):
        """Queue the item ``what`` for evaluation."""
        super(ParallelFilter, self).put(self.__filter, what)

    def put_nowait(self, what):
        """Queue the item ``what`` for evaluation without blocking."""
        super(ParallelFilter, self).put_nowait(self.__filter, what)

    def __iter__(self):
        """Yield the evaluated items, skipping the rejected ones."""
        for item in super(ParallelFilter, self).__iter__():
            if item is self._FILTERED:
                continue
            yield item
239
class ParallelRun(ParallelMap):
    """
    ParallelMap variant whose put() / put_nowait() take the function to run
    as ``what`` followed by its positional and keyword arguments.

    The call is packed into a single ``(fn, args, kwargs)`` tuple before
    being queued, so keyword arguments meant for the function never clash
    with the parameter names of ParallelMap.put().
    """

    def __run(self, x):
        # Runs in a worker: unpack the call prepared by put() and execute it.
        fn, args, kwargs = x
        return fn(*args, **kwargs)

    def __init__(self, maxthreads = None, maxqueue = None):
        super(ParallelRun, self).__init__(maxthreads, maxqueue, True)

    def put(self, what, *args, **kwargs):
        """Queue ``what(*args, **kwargs)``; may block if the queue is full."""
        super(ParallelRun, self).put(self.__run, (what, args, kwargs))

    def put_nowait(self, what, *args, **kwargs):
        """Queue ``what(*args, **kwargs)`` without blocking."""
        # Must dispatch through __run: this class has no __filter method, so
        # referencing it raised AttributeError on every call.
        super(ParallelRun, self).put_nowait(self.__run, (what, args, kwargs))
253
254
def pmap(mapping, iterable, maxthreads = None, maxqueue = None):
    """
    Apply ``mapping`` to every element of ``iterable`` using worker threads.

    mapping -- callable taking one element and returning its result.
    iterable -- elements to process.
    maxthreads, maxqueue -- forwarded to ParallelMap.

    Returns the list of results in the order the workers produced them,
    which need not match the order of ``iterable``. The first exception
    raised by ``mapping``, if any, is re-raised once all elements have been
    processed.
    """
    mapper = ParallelMap(
        maxthreads = maxthreads,
        maxqueue = maxqueue,
        results = True)
    mapper.start()
    for elem in iterable:
        # Queue mapping(elem); queueing elem alone would make the workers
        # try to call the element itself and ignore mapping.
        mapper.put(mapping, elem)
    rv = list(mapper)
    mapper.join()
    return rv
266
def pfilter(condition, iterable, maxthreads = None, maxqueue = None):
    """
    Evaluate ``condition`` on every element of ``iterable`` using worker
    threads and return the list of elements for which it was true.

    maxthreads and maxqueue are forwarded to ParallelFilter.
    """
    pool = ParallelFilter(condition, maxthreads = maxthreads, maxqueue = maxqueue)
    pool.start()
    for item in iterable:
        pool.put(item)
    # Drain the accepted items first, then wait for and release the workers.
    accepted = list(pool)
    pool.join()
    return accepted
278