2 # NEPI, a framework to manage network experiments
3 # Copyright (C) 2013 INRIA
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
# Worker thread for ParallelMap pools. It takes items from self.queue:
# the self.QUIT and self.REASSIGNED control markers are compared by
# identity, and any other task is unpacked as (callable, args, kwargs).
30 class WorkerThread(threading.Thread):
# Block until the next task or control marker is available.
38 task = self.queue.get()
41 self.queue.task_done()
# QUIT marker: acknowledged with task_done() on the queue it came from.
43 elif task is self.QUIT:
45 self.queue.task_done()
# REASSIGNED marker: posted on a worker's previous queue when the worker
# is moved elsewhere (see attach() and detach_signal()).
47 elif task is self.REASSIGNED:
# Regular task: call it. The return value only matters when a result
# queue (self.rvqueue) is attached.
54 callable, args, kwargs = task
55 rv = callable(*args, **kwargs)
57 if self.rvqueue is not None:
60 self.queue.task_done()
# Errors are printed to stderr, and their sys.exc_info() triple is stored
# in the shared delayed_exceptions list. The owning ParallelMap later
# reads that list back (delayed_exceptions[0]).
62 traceback.print_exc(file = sys.stderr)
63 self.delayed_exceptions.append(sys.exc_info())
# Wait loop: continue while the queue still holds items and this worker
# has not flagged itself as done.
66 while not self.queue.empty() and not self.done:
# Rebind this worker to a new task queue, result queue and shared
# exception list.
69 def attach(self, queue, rvqueue, delayed_exceptions):
74 self.rvqueue = rvqueue
75 self.delayed_exceptions = delayed_exceptions
# Post REASSIGNED on the previous queue. This is presumably meant to wake
# a worker still blocked in get() on that queue -- TODO confirm.
77 oldqueue.put(self.REASSIGNED)
# Detaching: remember the current queue as self.oldqueue, then give the
# worker a fresh private Queue and an empty exception list.
82 self.oldqueue = self.queue
83 self.queue = Queue.Queue()
85 self.delayed_exceptions = []
# Post REASSIGNED on the queue remembered as self.oldqueue.
87 def detach_signal(self):
89 self.oldqueue.put(self.REASSIGNED)
# Post the QUIT marker on the worker's current queue.
93 self.queue.put(self.QUIT)
# Thread pool that runs queued (callable, args, kwargs) tasks on
# WorkerThread instances and shares one exception list with them.
#
# maxthreads: number of workers. When None, it is derived from the count
#   of "processor" lines in /proc/cpuinfo, followed by a second None check
#   as a fallback.
# maxqueue: bound on the task queue; None/0 gives Queue.Queue(0), i.e.
#   an unbounded queue.
# results: presumably controls whether rvqueue (return values) is
#   created -- TODO confirm.
96 class ParallelMap(object):
97 def __init__(self, maxthreads = None, maxqueue = None, results = True):
100 global THREADCACHEPID
102 if maxthreads is None:
105 f = open("/proc/cpuinfo")
107 N_PROCS = sum("processor" in l for l in f)
114 if maxthreads is None:
117 self.queue = Queue.Queue(maxqueue or 0)
119 self.delayed_exceptions = []
122 self.rvqueue = Queue.Queue()
# Thread-cache bookkeeping is keyed on the process id. If THREADCACHEPID
# is unset or belongs to another process (e.g. the parent before a fork),
# it is re-stamped with the current pid.
127 if THREADCACHEPID is None or THREADCACHEPID != os.getpid():
129 THREADCACHEPID = os.getpid()
132 for x in xrange(maxthreads):
# Reuse a cached worker thread from THREADCACHE when one is available.
136 t = THREADCACHE.pop()
144 t.attach(self.queue, self.rvqueue, self.delayed_exceptions)
145 self.workers.append(t)
# Release this pool's workers. First re-check the thread-cache pid, as
# __init__ does. Then make three passes over self.workers; the last one
# sends detach_signal() to each worker. Finally append the workers to the
# shared THREADCACHE so that later ParallelMap instances can pop and reuse
# them.
153 global THREADCACHEPID
154 if THREADCACHEPID is None or THREADCACHEPID != os.getpid():
156 THREADCACHEPID = os.getpid()
158 for worker in self.workers:
160 for worker in self.workers:
162 for worker in self.workers:
163 worker.detach_signal()
164 THREADCACHE.extend(self.workers)
def put(self, callable, *args, **kwargs):
    """Queue ``callable(*args, **kwargs)`` for execution by a worker.

    Uses the blocking ``Queue.put``, so when the pool was built with a
    ``maxqueue`` bound this waits until a slot is free.
    """
    task = (callable, args, kwargs)
    self.queue.put(task)
def put_nowait(self, callable, *args, **kwargs):
    """Queue ``callable(*args, **kwargs)`` without waiting for a free slot.

    Delegates to ``Queue.put_nowait``, which raises ``Queue.Full`` when the
    task queue is bounded and currently full.
    """
    queue = self.queue
    queue.put_nowait((callable, args, kwargs))
# Start only the workers that are not alive yet. Workers popped from
# THREADCACHE in __init__ may already be running.
174 for thread in self.workers:
175 if not thread.isAlive():
# Issue one sync signal per worker on the task queue, then make a pass
# over the workers. Afterwards, take the first (type, value, traceback)
# triple that a worker recorded in delayed_exceptions and clear the list.
179 for thread in self.workers:
180 # That's the sync signal
184 for thread in self.workers:
187 if self.delayed_exceptions:
188 typ,val,loc = self.delayed_exceptions[0]
189 del self.delayed_exceptions[:]
# Worker threads append sys.exc_info() triples to the shared
# delayed_exceptions list. Take the first (type, value, traceback) triple
# and clear the list in place, so the workers keep sharing the same list
# object.
196 if self.delayed_exceptions:
197 typ,val,loc = self.delayed_exceptions[0]
198 del self.delayed_exceptions[:]
# Yield collected return values. This only happens when a result queue
# exists (rvqueue is not None), and each read is non-blocking
# (get_nowait).
202 if self.rvqueue is not None:
205 yield self.rvqueue.get_nowait()
209 yield self.rvqueue.get_nowait()
# ParallelMap variant that evaluates filter_condition on worker threads.
# Its iterator skips the _FILTERED sentinel, so only results for accepted
# items come out.
214 class ParallelFilter(ParallelMap):
# Worker-side wrapper. It returns the _FILTERED sentinel when
# filter_condition(x) is false; accepted items presumably pass through
# unchanged -- TODO confirm.
218 def __filter(self, x):
219 if self.filter_condition(x):
222 return self._FILTERED
def __init__(self, filter_condition, maxthreads = None, maxqueue = None):
    """Create a filtering pool.

    filter_condition: predicate evaluated on a worker thread for each item.
    maxthreads, maxqueue: forwarded unchanged to ParallelMap. Result
    collection is always enabled, because accepted items come back
    through the result queue.
    """
    super(ParallelFilter, self).__init__(
        maxthreads = maxthreads,
        maxqueue = maxqueue,
        results = True)
    self.filter_condition = filter_condition
# Enqueue `what` so that a worker runs it through __filter.
229 super(ParallelFilter, self).put(self.__filter, what)
def put_nowait(self, what):
    """Non-blocking counterpart of put(): queue `what` for filtering.

    The parent's put_nowait raises Queue.Full when the bounded task
    queue has no free slot.
    """
    base = super(ParallelFilter, self)
    base.put_nowait(self.__filter, what)
# Pass through the results of ParallelMap.__iter__, skipping the
# _FILTERED sentinel that __filter returns for rejected items.
235 for rv in super(ParallelFilter, self).__iter__():
236 if rv is not self._FILTERED:
# ParallelMap variant whose tasks are arbitrary calls. put() packs
# (what, args, kwargs), and the worker-side __run helper unpacks the
# packed call and returns its result.
239 class ParallelRun(ParallelMap):
242 return fn(*args, **kwargs)
def __init__(self, maxthreads = None, maxqueue = None):
    """Create a pool of workers that always collects return values.

    maxthreads and maxqueue are forwarded unchanged to ParallelMap.
    """
    super(ParallelRun, self).__init__(
        maxthreads = maxthreads,
        maxqueue = maxqueue,
        results = True)
def put(self, what, *args, **kwargs):
    """Queue the call ``what(*args, **kwargs)``.

    The call is packed into a single tuple that the __run helper unpacks
    on a worker thread. Blocks while a bounded task queue is full.
    """
    call = (what, args, kwargs)
    super(ParallelRun, self).put(self.__run, call)
def put_nowait(self, what, *args, **kwargs):
    """Queue the call ``what(*args, **kwargs)`` without blocking.

    Mirrors put(): the call is packed as (what, args, kwargs) and executed
    by the __run helper on a worker thread. The parent's put_nowait raises
    Queue.Full when the bounded task queue has no free slot.
    """
    # Must dispatch through __run, as put() does. ParallelRun has no
    # __filter: inside this class, self.__filter would be mangled to
    # self._ParallelRun__filter and raise AttributeError on every call.
    super(ParallelRun, self).put_nowait(self.__run, (what, args, kwargs))
# Parallel map helper: builds a ParallelMap pool, sized via maxthreads,
# and feeds it every element of `iterable`.
# NOTE(review): ParallelMap.put takes the callable first
# (put(callable, *args)). Verify that each elem is enqueued together with
# `mapping`, e.g. mapper.put(mapping, elem); otherwise `mapping` is never
# applied.
254 def pmap(mapping, iterable, maxthreads = None, maxqueue = None):
255 mapper = ParallelMap(
256 maxthreads = maxthreads,
260 for elem in iterable:
# Parallel filter helper: builds a ParallelFilter pool and feeds it every
# element of `iterable`. It presumably returns the elements for which
# `condition` holds -- TODO confirm.
266 def pfilter(condition, iterable, maxthreads = None, maxqueue = None):
267 filtrer = ParallelFilter(
269 maxthreads = maxthreads,
272 for elem in iterable: