Fixed nasty concurrency bug in EC
[nepi.git] / src / nepi / util / parallel.py
index 015c570..fffdea5 100644 (file)
@@ -1,14 +1,33 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
+#
+#    NEPI, a framework to manage network experiments
+#    Copyright (C) 2013 INRIA
+#
+#    This program is free software: you can redistribute it and/or modify
+#    it under the terms of the GNU General Public License as published by
+#    the Free Software Foundation, either version 3 of the License, or
+#    (at your option) any later version.
+#
+#    This program is distributed in the hope that it will be useful,
+#    but WITHOUT ANY WARRANTY; without even the implied warranty of
+#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+#    GNU General Public License for more details.
+#
+#    You should have received a copy of the GNU General Public License
+#    along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+# Author: Claudio Freire <claudio-daniel.freire@inria.fr>
+#
 
 import threading
 import Queue
 import traceback
 import sys
+import os
 
 N_PROCS = None
 
 THREADCACHE = []
+THREADCACHEPID = None
 
 class WorkerThread(threading.Thread):
     class QUIT:
@@ -79,6 +98,8 @@ class WorkerThread(threading.Thread):
 class ParallelMap(object):
     def __init__(self, maxthreads = None, maxqueue = None, results = True):
         global N_PROCS
+        global THREADCACHE
+        global THREADCACHEPID
         
         if maxthreads is None:
             if N_PROCS is None:
@@ -103,6 +124,11 @@ class ParallelMap(object):
             self.rvqueue = Queue.Queue()
         else:
             self.rvqueue = None
+        
+        # Check threadcache
+        if THREADCACHEPID is None or THREADCACHEPID != os.getpid():
+            del THREADCACHE[:]
+            THREADCACHEPID = os.getpid()
     
         self.workers = []
         for x in xrange(maxthreads):
@@ -124,6 +150,13 @@ class ParallelMap(object):
         self.destroy()
     
     def destroy(self):
+        # Check threadcache
+        global THREADCACHE
+        global THREADCACHEPID
+        if THREADCACHEPID is None or THREADCACHEPID != os.getpid():
+            del THREADCACHE[:]
+            THREADCACHEPID = os.getpid()
+
         for worker in self.workers:
             worker.waitdone()
         for worker in self.workers: