ec_shutdown
[nepi.git] / src / nepi / util / parallel.py
index 8dc39a7..6868c4a 100644 (file)
@@ -1,3 +1,25 @@
+#
+#    NEPI, a framework to manage network experiments
+#    Copyright (C) 2013 INRIA
+#
+#    This program is free software: you can redistribute it and/or modify
+#    it under the terms of the GNU General Public License as published by
+#    the Free Software Foundation, either version 3 of the License, or
+#    (at your option) any later version.
+#
+#    This program is distributed in the hope that it will be useful,
+#    but WITHOUT ANY WARRANTY; without even the implied warranty of
+#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+#    GNU General Public License for more details.
+#
+#    You should have received a copy of the GNU General Public License
+#    along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+# Author: Claudio Freire <claudio-daniel.freire@inria.fr>
+#
+
+# A.Q. TODO: BUG FIX THREADCACHE. Not needed!! remove it completely!
+
 import threading
 import Queue
 import traceback
@@ -6,9 +28,6 @@ import os
 
 N_PROCS = None
 
-THREADCACHE = []
-THREADCACHEPID = None
-
 class WorkerThread(threading.Thread):
     class QUIT:
         pass
@@ -78,9 +97,8 @@ class WorkerThread(threading.Thread):
 class ParallelMap(object):
     def __init__(self, maxthreads = None, maxqueue = None, results = True):
         global N_PROCS
-        global THREADCACHE
-        global THREADCACHEPID
-        
+       
+        # Compute maximum number of threads allowed by the system
         if maxthreads is None:
             if N_PROCS is None:
                 try:
@@ -104,25 +122,18 @@ class ParallelMap(object):
             self.rvqueue = Queue.Queue()
         else:
             self.rvqueue = None
-        
-        # Check threadcache
-        if THREADCACHEPID is None or THREADCACHEPID != os.getpid():
-            del THREADCACHE[:]
-            THREADCACHEPID = os.getpid()
     
         self.workers = []
+
+        # initialize workers
         for x in xrange(maxthreads):
             t = None
-            if THREADCACHE:
-                try:
-                    t = THREADCACHE.pop()
-                except:
-                    pass
             if t is None:
                 t = WorkerThread()
                 t.setDaemon(True)
             else:
                 t.waitdone()
+
             t.attach(self.queue, self.rvqueue, self.delayed_exceptions)
             self.workers.append(t)
     
@@ -130,20 +141,15 @@ class ParallelMap(object):
         self.destroy()
     
     def destroy(self):
-        # Check threadcache
-        global THREADCACHE
-        global THREADCACHEPID
-        if THREADCACHEPID is None or THREADCACHEPID != os.getpid():
-            del THREADCACHE[:]
-            THREADCACHEPID = os.getpid()
-
         for worker in self.workers:
             worker.waitdone()
         for worker in self.workers:
             worker.detach()
         for worker in self.workers:
             worker.detach_signal()
-        THREADCACHE.extend(self.workers)
+        for worker in self.workers:
+            worker.quit()
+
         del self.workers[:]
         
     def put(self, callable, *args, **kwargs):
@@ -192,32 +198,6 @@ class ParallelMap(object):
                     except Queue.Empty:
                         raise StopIteration
             
-    
-class ParallelFilter(ParallelMap):
-    class _FILTERED:
-        pass
-    
-    def __filter(self, x):
-        if self.filter_condition(x):
-            return x
-        else:
-            return self._FILTERED
-    
-    def __init__(self, filter_condition, maxthreads = None, maxqueue = None):
-        super(ParallelFilter, self).__init__(maxthreads, maxqueue, True)
-        self.filter_condition = filter_condition
-
-    def put(self, what):
-        super(ParallelFilter, self).put(self.__filter, what)
-    
-    def put_nowait(self, what):
-        super(ParallelFilter, self).put_nowait(self.__filter, what)
-        
-    def __iter__(self):
-        for rv in super(ParallelFilter, self).__iter__():
-            if rv is not self._FILTERED:
-                yield rv
-
 class ParallelRun(ParallelMap):
     def __run(self, x):
         fn, args, kwargs = x
@@ -233,27 +213,3 @@ class ParallelRun(ParallelMap):
         super(ParallelRun, self).put_nowait(self.__filter, (what, args, kwargs))
 
 
-def pmap(mapping, iterable, maxthreads = None, maxqueue = None):
-    mapper = ParallelMap(
-        maxthreads = maxthreads,
-        maxqueue = maxqueue,
-        results = True)
-    mapper.start()
-    for elem in iterable:
-        mapper.put(elem)
-    rv = list(mapper)
-    mapper.join()
-    return rv
-
-def pfilter(condition, iterable, maxthreads = None, maxqueue = None):
-    filtrer = ParallelFilter(
-        condition,
-        maxthreads = maxthreads,
-        maxqueue = maxqueue)
-    filtrer.start()
-    for elem in iterable:
-        filtrer.put(elem)
-    rv = list(filtrer)
-    filtrer.join()
-    return rv
-