ec_shutdown
[nepi.git] / src / nepi / util / parallel.py
index 015c570..6868c4a 100644 (file)
@@ -1,15 +1,33 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
+#
+#    NEPI, a framework to manage network experiments
+#    Copyright (C) 2013 INRIA
+#
+#    This program is free software: you can redistribute it and/or modify
+#    it under the terms of the GNU General Public License as published by
+#    the Free Software Foundation, either version 3 of the License, or
+#    (at your option) any later version.
+#
+#    This program is distributed in the hope that it will be useful,
+#    but WITHOUT ANY WARRANTY; without even the implied warranty of
+#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+#    GNU General Public License for more details.
+#
+#    You should have received a copy of the GNU General Public License
+#    along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+# Author: Claudio Freire <claudio-daniel.freire@inria.fr>
+#
+
+# A.Q. TODO: BUG FIX THREADCACHE. Not needed!! remove it completely!
 
 import threading
 import Queue
 import traceback
 import sys
+import os
 
 N_PROCS = None
 
-THREADCACHE = []
-
 class WorkerThread(threading.Thread):
     class QUIT:
         pass
@@ -79,7 +97,8 @@ class WorkerThread(threading.Thread):
 class ParallelMap(object):
     def __init__(self, maxthreads = None, maxqueue = None, results = True):
         global N_PROCS
-        
+       
+        # Compute maximum number of threads allowed by the system
         if maxthreads is None:
             if N_PROCS is None:
                 try:
@@ -105,18 +124,16 @@ class ParallelMap(object):
             self.rvqueue = None
     
         self.workers = []
+
+        # initialize workers
         for x in xrange(maxthreads):
             t = None
-            if THREADCACHE:
-                try:
-                    t = THREADCACHE.pop()
-                except:
-                    pass
             if t is None:
                 t = WorkerThread()
                 t.setDaemon(True)
             else:
                 t.waitdone()
+
             t.attach(self.queue, self.rvqueue, self.delayed_exceptions)
             self.workers.append(t)
     
@@ -130,7 +147,9 @@ class ParallelMap(object):
             worker.detach()
         for worker in self.workers:
             worker.detach_signal()
-        THREADCACHE.extend(self.workers)
+        for worker in self.workers:
+            worker.quit()
+
         del self.workers[:]
         
     def put(self, callable, *args, **kwargs):
@@ -179,32 +198,6 @@ class ParallelMap(object):
                     except Queue.Empty:
                         raise StopIteration
             
-    
-class ParallelFilter(ParallelMap):
-    class _FILTERED:
-        pass
-    
-    def __filter(self, x):
-        if self.filter_condition(x):
-            return x
-        else:
-            return self._FILTERED
-    
-    def __init__(self, filter_condition, maxthreads = None, maxqueue = None):
-        super(ParallelFilter, self).__init__(maxthreads, maxqueue, True)
-        self.filter_condition = filter_condition
-
-    def put(self, what):
-        super(ParallelFilter, self).put(self.__filter, what)
-    
-    def put_nowait(self, what):
-        super(ParallelFilter, self).put_nowait(self.__filter, what)
-        
-    def __iter__(self):
-        for rv in super(ParallelFilter, self).__iter__():
-            if rv is not self._FILTERED:
-                yield rv
-
 class ParallelRun(ParallelMap):
     def __run(self, x):
         fn, args, kwargs = x
@@ -220,27 +213,3 @@ class ParallelRun(ParallelMap):
         super(ParallelRun, self).put_nowait(self.__filter, (what, args, kwargs))
 
 
-def pmap(mapping, iterable, maxthreads = None, maxqueue = None):
-    mapper = ParallelMap(
-        maxthreads = maxthreads,
-        maxqueue = maxqueue,
-        results = True)
-    mapper.start()
-    for elem in iterable:
-        mapper.put(elem)
-    rv = list(mapper)
-    mapper.join()
-    return rv
-
-def pfilter(condition, iterable, maxthreads = None, maxqueue = None):
-    filtrer = ParallelFilter(
-        condition,
-        maxthreads = maxthreads,
-        maxqueue = maxqueue)
-    filtrer.start()
-    for elem in iterable:
-        filtrer.put(elem)
-    rv = list(filtrer)
-    filtrer.join()
-    return rv
-