3b6e281bb7119986904413e0c4152d2c57c47439
[nepi.git] / src / nepi / util / parallel.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License version 2 as
7 #    published by the Free Software Foundation;
8 #
9 #    This program is distributed in the hope that it will be useful,
10 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
11 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 #    GNU General Public License for more details.
13 #
14 #    You should have received a copy of the GNU General Public License
15 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 #
17 # Author: Claudio Freire <claudio-daniel.freire@inria.fr>
18 #         Alina Quereilhac <alina.quereilhac@inria.fr>
19 #
20
21 import threading
22 import Queue
23 import traceback
24 import sys
25 import os
26
27 N_PROCS = None
28
29 class WorkerThread(threading.Thread):
30     class QUIT:
31         pass
32
33     def run(self):
34         while True:
35             task = self.queue.get()
36
37             if task is self.QUIT:
38                 self.queue.task_done()
39                 break
40
41             try:
42                 try:
43                     callable, args, kwargs = task
44                     rv = callable(*args, **kwargs)
45                     
46                     if self.rvqueue is not None:
47                         self.rvqueue.put(rv)
48                 finally:
49                     self.queue.task_done()
50             except:
51                 traceback.print_exc(file = sys.stderr)
52                 self.delayed_exceptions.append(sys.exc_info())
53
54     def attach(self, queue, rvqueue, delayed_exceptions):
55         self.queue = queue
56         self.rvqueue = rvqueue
57         self.delayed_exceptions = delayed_exceptions
58    
59     def quit(self):
60         self.queue.put(self.QUIT)
61
62 class ParallelRun(object):
63     def __init__(self, maxthreads = None, maxqueue = None, results = True):
64         self.maxqueue = maxqueue
65         self.maxthreads = maxthreads
66         
67         self.queue = Queue.Queue(self.maxqueue or 0)
68         
69         self.delayed_exceptions = []
70         
71         if results:
72             self.rvqueue = Queue.Queue()
73         else:
74             self.rvqueue = None
75     
76         self.initialize_workers()
77
78     def initialize_workers(self):
79         global N_PROCS
80
81         maxthreads = self.maxthreads
82        
83         # Compute maximum number of threads allowed by the system
84         if maxthreads is None:
85             if N_PROCS is None:
86                 try:
87                     f = open("/proc/cpuinfo")
88                     try:
89                         N_PROCS = sum("processor" in l for l in f)
90                     finally:
91                         f.close()
92                 except:
93                     pass
94             maxthreads = N_PROCS
95         
96         if maxthreads is None:
97             maxthreads = 4
98  
99         self.workers = []
100
101         # initialize workers
102         for x in xrange(maxthreads):
103             worker = WorkerThread()
104             worker.attach(self.queue, self.rvqueue, self.delayed_exceptions)
105             worker.setDaemon(True)
106
107             self.workers.append(worker)
108     
109     def __del__(self):
110         self.destroy()
111
112     def empty(self):
113         while True:
114             try:
115                 self.queue.get(block = False)
116                 self.queue.task_done()
117             except Queue.Empty:
118                 break
119   
120     def destroy(self):
121         self.join()
122
123         del self.workers[:]
124         
125     def put(self, callable, *args, **kwargs):
126         self.queue.put((callable, args, kwargs))
127     
128     def put_nowait(self, callable, *args, **kwargs):
129         self.queue.put_nowait((callable, args, kwargs))
130
131     def start(self):
132         for worker in self.workers:
133             if not worker.isAlive():
134                 worker.start()
135     
136     def join(self):
137         # Wait until all queued tasks have been processed
138         self.queue.join()
139
140         for worker in self.workers:
141             worker.quit()
142
143         for worker in self.workers:
144             worker.join()
145     
146     def sync(self):
147         if self.delayed_exceptions:
148             typ,val,loc = self.delayed_exceptions[0]
149             del self.delayed_exceptions[:]
150             raise typ,val,loc
151         
152     def __iter__(self):
153         if self.rvqueue is not None:
154             while True:
155                 try:
156                     yield self.rvqueue.get_nowait()
157                 except Queue.Empty:
158                     self.queue.join()
159                     try:
160                         yield self.rvqueue.get_nowait()
161                     except Queue.Empty:
162                         raise StopIteration
163