ec_shutdown
[nepi.git] / src / nepi / util / parallel.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License as published by
7 #    the Free Software Foundation, either version 3 of the License, or
8 #    (at your option) any later version.
9 #
10 #    This program is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #    GNU General Public License for more details.
14 #
15 #    You should have received a copy of the GNU General Public License
16 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18 # Author: Claudio Freire <claudio-daniel.freire@inria.fr>
19 #
20
21 # A.Q. TODO: BUG FIX THREADCACHE. Not needed!! remove it completely!
22
23 import threading
24 import Queue
25 import traceback
26 import sys
27 import os
28
29 N_PROCS = None
30
31 class WorkerThread(threading.Thread):
32     class QUIT:
33         pass
34     class REASSIGNED:
35         pass
36     
37     def run(self):
38         while True:
39             task = self.queue.get()
40             if task is None:
41                 self.done = True
42                 self.queue.task_done()
43                 continue
44             elif task is self.QUIT:
45                 self.done = True
46                 self.queue.task_done()
47                 break
48             elif task is self.REASSIGNED:
49                 continue
50             else:
51                 self.done = False
52             
53             try:
54                 try:
55                     callable, args, kwargs = task
56                     rv = callable(*args, **kwargs)
57                     
58                     if self.rvqueue is not None:
59                         self.rvqueue.put(rv)
60                 finally:
61                     self.queue.task_done()
62             except:
63                 traceback.print_exc(file = sys.stderr)
64                 self.delayed_exceptions.append(sys.exc_info())
65     
66     def waitdone(self):
67         while not self.queue.empty() and not self.done:
68             self.queue.join()
69     
70     def attach(self, queue, rvqueue, delayed_exceptions):
71         if self.isAlive():
72             self.waitdone()
73             oldqueue = self.queue
74         self.queue = queue
75         self.rvqueue = rvqueue
76         self.delayed_exceptions = delayed_exceptions
77         if self.isAlive():
78             oldqueue.put(self.REASSIGNED)
79     
80     def detach(self):
81         if self.isAlive():
82             self.waitdone()
83             self.oldqueue = self.queue
84         self.queue = Queue.Queue()
85         self.rvqueue = None
86         self.delayed_exceptions = []
87     
88     def detach_signal(self):
89         if self.isAlive():
90             self.oldqueue.put(self.REASSIGNED)
91             del self.oldqueue
92         
93     def quit(self):
94         self.queue.put(self.QUIT)
95         self.join()
96
97 class ParallelMap(object):
98     def __init__(self, maxthreads = None, maxqueue = None, results = True):
99         global N_PROCS
100        
101         # Compute maximum number of threads allowed by the system
102         if maxthreads is None:
103             if N_PROCS is None:
104                 try:
105                     f = open("/proc/cpuinfo")
106                     try:
107                         N_PROCS = sum("processor" in l for l in f)
108                     finally:
109                         f.close()
110                 except:
111                     pass
112             maxthreads = N_PROCS
113         
114         if maxthreads is None:
115             maxthreads = 4
116         
117         self.queue = Queue.Queue(maxqueue or 0)
118
119         self.delayed_exceptions = []
120         
121         if results:
122             self.rvqueue = Queue.Queue()
123         else:
124             self.rvqueue = None
125     
126         self.workers = []
127
128         # initialize workers
129         for x in xrange(maxthreads):
130             t = None
131             if t is None:
132                 t = WorkerThread()
133                 t.setDaemon(True)
134             else:
135                 t.waitdone()
136
137             t.attach(self.queue, self.rvqueue, self.delayed_exceptions)
138             self.workers.append(t)
139     
140     def __del__(self):
141         self.destroy()
142     
143     def destroy(self):
144         for worker in self.workers:
145             worker.waitdone()
146         for worker in self.workers:
147             worker.detach()
148         for worker in self.workers:
149             worker.detach_signal()
150         for worker in self.workers:
151             worker.quit()
152
153         del self.workers[:]
154         
155     def put(self, callable, *args, **kwargs):
156         self.queue.put((callable, args, kwargs))
157     
158     def put_nowait(self, callable, *args, **kwargs):
159         self.queue.put_nowait((callable, args, kwargs))
160
161     def start(self):
162         for thread in self.workers:
163             if not thread.isAlive():
164                 thread.start()
165     
166     def join(self):
167         for thread in self.workers:
168             # That's the sync signal
169             self.queue.put(None)
170             
171         self.queue.join()
172         for thread in self.workers:
173             thread.waitdone()
174         
175         if self.delayed_exceptions:
176             typ,val,loc = self.delayed_exceptions[0]
177             del self.delayed_exceptions[:]
178             raise typ,val,loc
179         
180         self.destroy()
181     
182     def sync(self):
183         self.queue.join()
184         if self.delayed_exceptions:
185             typ,val,loc = self.delayed_exceptions[0]
186             del self.delayed_exceptions[:]
187             raise typ,val,loc
188         
189     def __iter__(self):
190         if self.rvqueue is not None:
191             while True:
192                 try:
193                     yield self.rvqueue.get_nowait()
194                 except Queue.Empty:
195                     self.queue.join()
196                     try:
197                         yield self.rvqueue.get_nowait()
198                     except Queue.Empty:
199                         raise StopIteration
200             
201 class ParallelRun(ParallelMap):
202     def __run(self, x):
203         fn, args, kwargs = x
204         return fn(*args, **kwargs)
205     
206     def __init__(self, maxthreads = None, maxqueue = None):
207         super(ParallelRun, self).__init__(maxthreads, maxqueue, True)
208
209     def put(self, what, *args, **kwargs):
210         super(ParallelRun, self).put(self.__run, (what, args, kwargs))
211     
212     def put_nowait(self, what, *args, **kwargs):
213         super(ParallelRun, self).put_nowait(self.__filter, (what, args, kwargs))
214
215