2 NEPI, a framework to manage network experiments
3 Copyright (C) 2013 INRIA
5 This program is free software: you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation, either version 3 of the License, or
8 (at your option) any later version.
10 This program is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 GNU General Public License for more details.
15 You should have received a copy of the GNU General Public License
16 along with this program. If not, see <http://www.gnu.org/licenses/>.
31 class WorkerThread(threading.Thread):
39 task = self.queue.get()
42 self.queue.task_done()
44 elif task is self.QUIT:
46 self.queue.task_done()
48 elif task is self.REASSIGNED:
55 callable, args, kwargs = task
56 rv = callable(*args, **kwargs)
58 if self.rvqueue is not None:
61 self.queue.task_done()
63 traceback.print_exc(file = sys.stderr)
64 self.delayed_exceptions.append(sys.exc_info())
67 while not self.queue.empty() and not self.done:
70 def attach(self, queue, rvqueue, delayed_exceptions):
75 self.rvqueue = rvqueue
76 self.delayed_exceptions = delayed_exceptions
78 oldqueue.put(self.REASSIGNED)
83 self.oldqueue = self.queue
84 self.queue = Queue.Queue()
86 self.delayed_exceptions = []
88 def detach_signal(self):
90 self.oldqueue.put(self.REASSIGNED)
94 self.queue.put(self.QUIT)
97 class ParallelMap(object):
98 def __init__(self, maxthreads = None, maxqueue = None, results = True):
101 global THREADCACHEPID
103 if maxthreads is None:
106 f = open("/proc/cpuinfo")
108 N_PROCS = sum("processor" in l for l in f)
115 if maxthreads is None:
118 self.queue = Queue.Queue(maxqueue or 0)
120 self.delayed_exceptions = []
123 self.rvqueue = Queue.Queue()
128 if THREADCACHEPID is None or THREADCACHEPID != os.getpid():
130 THREADCACHEPID = os.getpid()
133 for x in xrange(maxthreads):
137 t = THREADCACHE.pop()
145 t.attach(self.queue, self.rvqueue, self.delayed_exceptions)
146 self.workers.append(t)
154 global THREADCACHEPID
155 if THREADCACHEPID is None or THREADCACHEPID != os.getpid():
157 THREADCACHEPID = os.getpid()
159 for worker in self.workers:
161 for worker in self.workers:
163 for worker in self.workers:
164 worker.detach_signal()
165 THREADCACHE.extend(self.workers)
168 def put(self, callable, *args, **kwargs):
169 self.queue.put((callable, args, kwargs))
171 def put_nowait(self, callable, *args, **kwargs):
172 self.queue.put_nowait((callable, args, kwargs))
175 for thread in self.workers:
176 if not thread.isAlive():
180 for thread in self.workers:
181 # That's the sync signal
185 for thread in self.workers:
188 if self.delayed_exceptions:
189 typ,val,loc = self.delayed_exceptions[0]
190 del self.delayed_exceptions[:]
197 if self.delayed_exceptions:
198 typ,val,loc = self.delayed_exceptions[0]
199 del self.delayed_exceptions[:]
203 if self.rvqueue is not None:
206 yield self.rvqueue.get_nowait()
210 yield self.rvqueue.get_nowait()
215 class ParallelFilter(ParallelMap):
219 def __filter(self, x):
220 if self.filter_condition(x):
223 return self._FILTERED
def __init__(self, filter_condition, maxthreads = None, maxqueue = None):
    """Create a thread pool that keeps only items passing ``filter_condition``.

    Result collection is always enabled (``results=True``) because the
    kept items are read back from the result queue during iteration.
    """
    super(ParallelFilter, self).__init__(
        maxthreads = maxthreads,
        maxqueue = maxqueue,
        results = True)
    self.filter_condition = filter_condition
230 super(ParallelFilter, self).put(self.__filter, what)
def put_nowait(self, what):
    """Submit ``what`` for filtering without blocking on a full queue.

    The item is wrapped in a call to the private ``__filter`` helper,
    which evaluates ``filter_condition`` on a worker thread.
    """
    checker = self.__filter
    super(ParallelFilter, self).put_nowait(checker, what)
236 for rv in super(ParallelFilter, self).__iter__():
237 if rv is not self._FILTERED:
240 class ParallelRun(ParallelMap):
243 return fn(*args, **kwargs)
def __init__(self, maxthreads = None, maxqueue = None):
    """Create a pool that runs arbitrary submitted calls and keeps their results."""
    super(ParallelRun, self).__init__(
        maxthreads = maxthreads,
        maxqueue = maxqueue,
        results = True)
def put(self, what, *args, **kwargs):
    """Schedule ``what(*args, **kwargs)`` on the pool.

    The call is bundled into one ``(what, args, kwargs)`` tuple. That tuple
    is passed as a single argument to the private ``__run`` helper, which
    a worker thread executes.
    """
    call_spec = (what, args, kwargs)
    super(ParallelRun, self).put(self.__run, call_spec)
def put_nowait(self, what, *args, **kwargs):
    """Non-blocking counterpart of ``put``: schedule ``what(*args, **kwargs)``.

    The call is bundled into a ``(what, args, kwargs)`` tuple and dispatched
    through the private ``__run`` helper, exactly as ``put`` does. Because
    the underlying queue operation is ``put_nowait``, a bounded task queue
    with no room raises ``Queue.Full`` instead of waiting.
    """
    # Dispatch through ``__run``, matching ``put``. The former
    # ``self.__filter`` was name-mangled to ``self._ParallelRun__filter``.
    # ParallelRun never defines that attribute (only ParallelFilter has a
    # ``__filter``), so every call raised AttributeError.
    super(ParallelRun, self).put_nowait(self.__run, (what, args, kwargs))
255 def pmap(mapping, iterable, maxthreads = None, maxqueue = None):
256 mapper = ParallelMap(
257 maxthreads = maxthreads,
261 for elem in iterable:
267 def pfilter(condition, iterable, maxthreads = None, maxqueue = None):
268 filtrer = ParallelFilter(
270 maxthreads = maxthreads,
273 for elem in iterable: