X-Git-Url: http://git.onelab.eu/?a=blobdiff_plain;f=src%2Fnepi%2Futil%2Fparallel.py;h=3b6e281bb7119986904413e0c4152d2c57c47439;hb=e55924b6886bd7382a28e1ae235c4810f852e163;hp=b5869b9c10a1df0d333f9eed838c8b4c4b9251e9;hpb=0e53a081db8d0678e12e3d5a29d73efbc221307c;p=nepi.git diff --git a/src/nepi/util/parallel.py b/src/nepi/util/parallel.py index b5869b9c..3b6e281b 100644 --- a/src/nepi/util/parallel.py +++ b/src/nepi/util/parallel.py @@ -1,17 +1,86 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- +# +# NEPI, a framework to manage network experiments +# Copyright (C) 2013 INRIA +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 as +# published by the Free Software Foundation; +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# +# Author: Claudio Freire +# Alina Quereilhac +# import threading import Queue import traceback import sys +import os N_PROCS = None -class ParallelMap(object): +class WorkerThread(threading.Thread): + class QUIT: + pass + + def run(self): + while True: + task = self.queue.get() + + if task is self.QUIT: + self.queue.task_done() + break + + try: + try: + callable, args, kwargs = task + rv = callable(*args, **kwargs) + + if self.rvqueue is not None: + self.rvqueue.put(rv) + finally: + self.queue.task_done() + except: + traceback.print_exc(file = sys.stderr) + self.delayed_exceptions.append(sys.exc_info()) + + def attach(self, queue, rvqueue, delayed_exceptions): + self.queue = queue + self.rvqueue = rvqueue + self.delayed_exceptions = delayed_exceptions + + def quit(self): + self.queue.put(self.QUIT) + +class ParallelRun(object): def __init__(self, maxthreads = None, maxqueue = None, results = True): - global N_PROCS + self.maxqueue = maxqueue + self.maxthreads = maxthreads + + self.queue = Queue.Queue(self.maxqueue or 0) + + self.delayed_exceptions = [] + if results: + self.rvqueue = Queue.Queue() + else: + self.rvqueue = None + + self.initialize_workers() + + def initialize_workers(self): + global N_PROCS + + maxthreads = self.maxthreads + + # Compute maximum number of threads allowed by the system if maxthreads is None: if N_PROCS is None: try: @@ -26,18 +95,32 @@ class ParallelMap(object): if maxthreads is None: maxthreads = 4 + + self.workers = [] + + # initialize workers + for x in xrange(maxthreads): + worker = WorkerThread() + worker.attach(self.queue, self.rvqueue, self.delayed_exceptions) + worker.setDaemon(True) - self.queue = Queue.Queue(maxqueue or 0) + self.workers.append(worker) - self.workers = [ threading.Thread(target = self.worker) - for x in xrange(maxthreads) ] - - self.delayed_exceptions = [] - - if results: - self.rvqueue = Queue.Queue() - else: - self.rvqueue = None + def __del__(self): + self.destroy() + + def empty(self): + while True: + try: + self.queue.get(block = False) + self.queue.task_done() + except Queue.Empty: + break + + def destroy(self): + self.join() + + del self.workers[:] def put(self, callable, *args, **kwargs): self.queue.put((callable, args, kwargs)) @@ -46,45 +129,26 @@ class ParallelMap(object): self.queue.put_nowait((callable, args, kwargs)) def start(self): - for thread in self.workers: - thread.start() + for worker in self.workers: + if not worker.isAlive(): + worker.start() def join(self): - for thread in self.workers: - # That's the shutdown signal - self.queue.put(None) - + # Wait until all queued tasks have been processed self.queue.join() - for thread in self.workers: - thread.join() - + + for worker in self.workers: + worker.quit() + + for worker in self.workers: + worker.join() + + def sync(self): if self.delayed_exceptions: typ,val,loc = self.delayed_exceptions[0] + del self.delayed_exceptions[:] raise typ,val,loc - - def sync(self): - self.queue.join() - def worker(self): - while True: - task = self.queue.get() - if task is None: - self.queue.task_done() - break - - try: - try: - callable, args, kwargs = task - rv = callable(*args, **kwargs) - - if self.rvqueue is not None: - self.rvqueue.put(rv) - finally: - self.queue.task_done() - except: - traceback.print_exc(file = sys.stderr) - self.delayed_exceptions.append(sys.exc_info()) - def __iter__(self): if self.rvqueue is not None: while True: @@ -97,64 +161,3 @@ class ParallelMap(object): except Queue.Empty: raise StopIteration - -class ParallelFilter(ParallelMap): - class _FILTERED: - pass - - def __filter(self, x): - if self.filter_condition(x): - return x - else: - return self._FILTERED - - def __init__(self, filter_condition, maxthreads = None, maxqueue = None): - super(ParallelFilter, self).__init__(maxthreads, maxqueue, True) - self.filter_condition = filter_condition - - def put(self, what): - super(ParallelFilter, self).put(self.__filter, what) - - def put_nowait(self, what): - super(ParallelFilter, self).put_nowait(self.__filter, what) - - def __iter__(self): - for rv in super(ParallelFilter, self).__iter__(): - if rv is not self._FILTERED: - yield rv - -class ParallelRun(ParallelMap): - def __run(self, x): - fn, args, kwargs = x - return fn(*args, **kwargs) - - def __init__(self, maxthreads = None, maxqueue = None): - super(ParallelRun, self).__init__(maxthreads, maxqueue, True) - - def put(self, what, *args, **kwargs): - super(ParallelRun, self).put(self.__run, (what, args, kwargs)) - - def put_nowait(self, what, *args, **kwargs): - super(ParallelRun, self).put_nowait(self.__filter, (what, args, kwargs)) - - -def pmap(mapping, iterable, maxthreads = None, maxqueue = None): - mapper = ParallelMap( - maxthreads = maxthreads, - maxqueue = maxqueue, - results = True) - mapper.start() - for elem in iterable: - mapper.put(elem) - return list(mapper) - -def pfilter(condition, iterable, maxthreads = None, maxqueue = None): - filtrer = ParallelFilter( - condition, - maxthreads = maxthreads, - maxqueue = maxqueue) - filtrer.start() - for elem in iterable: - filtrer.put(elem) - return list(filtrer) -