Flushing scheduler before shutdown
[nepi.git] / src / nepi / util / parallel.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License as published by
7 #    the Free Software Foundation, either version 3 of the License, or
8 #    (at your option) any later version.
9 #
10 #    This program is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #    GNU General Public License for more details.
14 #
15 #    You should have received a copy of the GNU General Public License
16 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18 # Author: Claudio Freire <claudio-daniel.freire@inria.fr>
19 #         Alina Quereilhac <alina.quereilhac@inria.fr>
20 #
21
22 import threading
23 import Queue
24 import traceback
25 import sys
26 import os
27
28 N_PROCS = None
29
30 class WorkerThread(threading.Thread):
31     class QUIT:
32         pass
33     class REASSIGNED:
34         pass
35     
36     def run(self):
37         while True:
38             task = self.queue.get()
39             if task is None:
40                 self.done = True
41                 self.queue.task_done()
42                 continue
43             elif task is self.QUIT:
44                 self.done = True
45                 self.queue.task_done()
46                 break
47             elif task is self.REASSIGNED:
48                 continue
49             else:
50                 self.done = False
51             
52             try:
53                 try:
54                     callable, args, kwargs = task
55                     rv = callable(*args, **kwargs)
56                     
57                     if self.rvqueue is not None:
58                         self.rvqueue.put(rv)
59                 finally:
60                     self.queue.task_done()
61             except:
62                 traceback.print_exc(file = sys.stderr)
63                 self.delayed_exceptions.append(sys.exc_info())
64     
65     def waitdone(self):
66         while not self.queue.empty() and not self.done:
67             self.queue.join()
68     
69     def attach(self, queue, rvqueue, delayed_exceptions):
70         if self.isAlive():
71             self.waitdone()
72             oldqueue = self.queue
73         self.queue = queue
74         self.rvqueue = rvqueue
75         self.delayed_exceptions = delayed_exceptions
76         if self.isAlive():
77             oldqueue.put(self.REASSIGNED)
78     
79     def detach(self):
80         if self.isAlive():
81             self.waitdone()
82             self.oldqueue = self.queue
83         self.queue = Queue.Queue()
84         self.rvqueue = None
85         self.delayed_exceptions = []
86     
87     def detach_signal(self):
88         if self.isAlive():
89             self.oldqueue.put(self.REASSIGNED)
90             del self.oldqueue
91         
92     def quit(self):
93         self.queue.put(self.QUIT)
94         self.join()
95
96 class ParallelMap(object):
97     def __init__(self, maxthreads = None, maxqueue = None, results = True):
98         global N_PROCS
99        
100         # Compute maximum number of threads allowed by the system
101         if maxthreads is None:
102             if N_PROCS is None:
103                 try:
104                     f = open("/proc/cpuinfo")
105                     try:
106                         N_PROCS = sum("processor" in l for l in f)
107                     finally:
108                         f.close()
109                 except:
110                     pass
111             maxthreads = N_PROCS
112         
113         if maxthreads is None:
114             maxthreads = 4
115         
116         self.queue = Queue.Queue(maxqueue or 0)
117
118         self.delayed_exceptions = []
119         
120         if results:
121             self.rvqueue = Queue.Queue()
122         else:
123             self.rvqueue = None
124     
125         self.workers = []
126
127         # initialize workers
128         for x in xrange(maxthreads):
129             t = None
130             if t is None:
131                 t = WorkerThread()
132                 t.setDaemon(True)
133             else:
134                 t.waitdone()
135
136             t.attach(self.queue, self.rvqueue, self.delayed_exceptions)
137             self.workers.append(t)
138     
139     def __del__(self):
140         self.destroy()
141     
142     def destroy(self):
143         for worker in self.workers:
144             worker.waitdone()
145         for worker in self.workers:
146             worker.detach()
147         for worker in self.workers:
148             worker.detach_signal()
149         for worker in self.workers:
150             worker.quit()
151
152         del self.workers[:]
153         
154     def put(self, callable, *args, **kwargs):
155         self.queue.put((callable, args, kwargs))
156     
157     def put_nowait(self, callable, *args, **kwargs):
158         self.queue.put_nowait((callable, args, kwargs))
159
160     def start(self):
161         for thread in self.workers:
162             if not thread.isAlive():
163                 thread.start()
164     
165     def join(self):
166         for thread in self.workers:
167             # That's the sync signal
168             self.queue.put(None)
169             
170         self.queue.join()
171         for thread in self.workers:
172             thread.waitdone()
173         
174         if self.delayed_exceptions:
175             typ,val,loc = self.delayed_exceptions[0]
176             del self.delayed_exceptions[:]
177             raise typ,val,loc
178         
179         self.destroy()
180     
181     def sync(self):
182         self.queue.join()
183         if self.delayed_exceptions:
184             typ,val,loc = self.delayed_exceptions[0]
185             del self.delayed_exceptions[:]
186             raise typ,val,loc
187         
188     def __iter__(self):
189         if self.rvqueue is not None:
190             while True:
191                 try:
192                     yield self.rvqueue.get_nowait()
193                 except Queue.Empty:
194                     self.queue.join()
195                     try:
196                         yield self.rvqueue.get_nowait()
197                     except Queue.Empty:
198                         raise StopIteration
199             
200 class ParallelRun(ParallelMap):
201     def __run(self, x):
202         fn, args, kwargs = x
203         return fn(*args, **kwargs)
204     
205     def __init__(self, maxthreads = None, maxqueue = None):
206         super(ParallelRun, self).__init__(maxthreads, maxqueue, True)
207
208     def put(self, what, *args, **kwargs):
209         super(ParallelRun, self).put(self.__run, (what, args, kwargs))
210     
211     def put_nowait(self, what, *args, **kwargs):
212         super(ParallelRun, self).put_nowait(self.__filter, (what, args, kwargs))
213
214