From 5d52c3f960d8d120e5d3b4f5bf63ff6f2bb226cc Mon Sep 17 00:00:00 2001 From: Claudio-Daniel Freire Date: Wed, 17 Aug 2011 14:07:37 +0200 Subject: [PATCH] Thread cache, to re-use threads. It so happens that the per-user process limit includes DEAD threads, which is nonsense, but it does. Maybe python doesn't clean up properly, but whatever the reason, there is a limit even in nonconcurrent threads created. To mitigate that, a thread cache is used in the parallel module to avoid creating thousands of threads and reaching that limit. --- src/nepi/util/parallel.py | 146 ++++++++++++++++++++++++++++++-------- 1 file changed, 116 insertions(+), 30 deletions(-) diff --git a/src/nepi/util/parallel.py b/src/nepi/util/parallel.py index b5869b9c..015c5704 100644 --- a/src/nepi/util/parallel.py +++ b/src/nepi/util/parallel.py @@ -8,6 +8,74 @@ import sys N_PROCS = None +THREADCACHE = [] + +class WorkerThread(threading.Thread): + class QUIT: + pass + class REASSIGNED: + pass + + def run(self): + while True: + task = self.queue.get() + if task is None: + self.done = True + self.queue.task_done() + continue + elif task is self.QUIT: + self.done = True + self.queue.task_done() + break + elif task is self.REASSIGNED: + continue + else: + self.done = False + + try: + try: + callable, args, kwargs = task + rv = callable(*args, **kwargs) + + if self.rvqueue is not None: + self.rvqueue.put(rv) + finally: + self.queue.task_done() + except: + traceback.print_exc(file = sys.stderr) + self.delayed_exceptions.append(sys.exc_info()) + + def waitdone(self): + while not self.queue.empty() and not self.done: + self.queue.join() + + def attach(self, queue, rvqueue, delayed_exceptions): + if self.isAlive(): + self.waitdone() + oldqueue = self.queue + self.queue = queue + self.rvqueue = rvqueue + self.delayed_exceptions = delayed_exceptions + if self.isAlive(): + oldqueue.put(self.REASSIGNED) + + def detach(self): + if self.isAlive(): + self.waitdone() + self.oldqueue = self.queue + self.queue = Queue.Queue() + self.rvqueue = None + self.delayed_exceptions = [] + + def detach_signal(self): + if self.isAlive(): + self.oldqueue.put(self.REASSIGNED) + del self.oldqueue + + def quit(self): + self.queue.put(self.QUIT) + self.join() + class ParallelMap(object): def __init__(self, maxthreads = None, maxqueue = None, results = True): global N_PROCS @@ -26,18 +94,44 @@ class ParallelMap(object): if maxthreads is None: maxthreads = 4 - - self.queue = Queue.Queue(maxqueue or 0) - - self.workers = [ threading.Thread(target = self.worker) - for x in xrange(maxthreads) ] + self.queue = Queue.Queue(maxqueue or 0) + self.delayed_exceptions = [] if results: self.rvqueue = Queue.Queue() else: self.rvqueue = None + + self.workers = [] + for x in xrange(maxthreads): + t = None + if THREADCACHE: + try: + t = THREADCACHE.pop() + except: + pass + if t is None: + t = WorkerThread() + t.setDaemon(True) + else: + t.waitdone() + t.attach(self.queue, self.rvqueue, self.delayed_exceptions) + self.workers.append(t) + + def __del__(self): + self.destroy() + + def destroy(self): + for worker in self.workers: + worker.waitdone() + for worker in self.workers: + worker.detach() + for worker in self.workers: + worker.detach_signal() + THREADCACHE.extend(self.workers) + del self.workers[:] def put(self, callable, *args, **kwargs): self.queue.put((callable, args, kwargs)) @@ -47,44 +141,32 @@ class ParallelMap(object): def start(self): for thread in self.workers: - thread.start() + if not thread.isAlive(): + thread.start() def join(self): for thread in self.workers: - # That's the shutdown signal + # That's the sync signal self.queue.put(None) self.queue.join() for thread in self.workers: - thread.join() + thread.waitdone() if self.delayed_exceptions: typ,val,loc = self.delayed_exceptions[0] + del self.delayed_exceptions[:] raise typ,val,loc + + self.destroy() def sync(self): self.queue.join() + if self.delayed_exceptions: + typ,val,loc = self.delayed_exceptions[0] + del self.delayed_exceptions[:] + raise typ,val,loc - def worker(self): - while True: - task = self.queue.get() - if task is None: - self.queue.task_done() - break - - try: - try: - callable, args, kwargs = task - rv = callable(*args, **kwargs) - - if self.rvqueue is not None: - self.rvqueue.put(rv) - finally: - self.queue.task_done() - except: - traceback.print_exc(file = sys.stderr) - self.delayed_exceptions.append(sys.exc_info()) - def __iter__(self): if self.rvqueue is not None: while True: @@ -146,7 +228,9 @@ def pmap(mapping, iterable, maxthreads = None, maxqueue = None): mapper.start() for elem in iterable: mapper.put(elem) - return list(mapper) + rv = list(mapper) + mapper.join() + return rv def pfilter(condition, iterable, maxthreads = None, maxqueue = None): filtrer = ParallelFilter( @@ -156,5 +240,7 @@ def pfilter(condition, iterable, maxthreads = None, maxqueue = None): filtrer.start() for elem in iterable: filtrer.put(elem) - return list(filtrer) + rv = list(filtrer) + filtrer.join() + return rv -- 2.47.0