b7caeac3fbb79f6118fbca8249a220437071599c
[nepi.git] / src / nepi / util / parallel.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License as published by
7 #    the Free Software Foundation, either version 3 of the License, or
8 #    (at your option) any later version.
9 #
10 #    This program is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #    GNU General Public License for more details.
14 #
15 #    You should have received a copy of the GNU General Public License
16 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18 # Author: Claudio Freire <claudio-daniel.freire@inria.fr>
19 #         Alina Quereilhac <alina.quereilhac@inria.fr>
20 #
21
22 import threading
23 import Queue
24 import traceback
25 import sys
26 import os
27
28 N_PROCS = None
29
30 class WorkerThread(threading.Thread):
31     class QUIT:
32         pass
33
34     def run(self):
35         while True:
36             task = self.queue.get()
37
38             if task is self.QUIT:
39                 self.queue.task_done()
40                 break
41
42             try:
43                 try:
44                     callable, args, kwargs = task
45                     rv = callable(*args, **kwargs)
46                     
47                     if self.rvqueue is not None:
48                         self.rvqueue.put(rv)
49                 finally:
50                     self.queue.task_done()
51             except:
52                 traceback.print_exc(file = sys.stderr)
53                 self.delayed_exceptions.append(sys.exc_info())
54
55     def attach(self, queue, rvqueue, delayed_exceptions):
56         self.queue = queue
57         self.rvqueue = rvqueue
58         self.delayed_exceptions = delayed_exceptions
59    
60     def quit(self):
61         self.queue.put(self.QUIT)
62
63 class ParallelRun(object):
64     def __init__(self, maxthreads = None, maxqueue = None, results = True):
65         self.maxqueue = maxqueue
66         self.maxthreads = maxthreads
67         
68         self.queue = Queue.Queue(self.maxqueue or 0)
69         
70         self.delayed_exceptions = []
71         
72         if results:
73             self.rvqueue = Queue.Queue()
74         else:
75             self.rvqueue = None
76     
77         self.initialize_workers()
78
79     def initialize_workers(self):
80         global N_PROCS
81
82         maxthreads = self.maxthreads
83        
84         # Compute maximum number of threads allowed by the system
85         if maxthreads is None:
86             if N_PROCS is None:
87                 try:
88                     f = open("/proc/cpuinfo")
89                     try:
90                         N_PROCS = sum("processor" in l for l in f)
91                     finally:
92                         f.close()
93                 except:
94                     pass
95             maxthreads = N_PROCS
96         
97         if maxthreads is None:
98             maxthreads = 4
99  
100         self.workers = []
101
102         # initialize workers
103         for x in xrange(maxthreads):
104             worker = WorkerThread()
105             worker.attach(self.queue, self.rvqueue, self.delayed_exceptions)
106             worker.setDaemon(True)
107
108             self.workers.append(worker)
109     
110     def __del__(self):
111         self.destroy()
112
113     def empty(self):
114         while True:
115             try:
116                 self.queue.get(block = False)
117                 self.queue.task_done()
118             except Queue.Empty:
119                 break
120   
121     def destroy(self):
122         self.join()
123
124         del self.workers[:]
125         
126     def put(self, callable, *args, **kwargs):
127         self.queue.put((callable, args, kwargs))
128     
129     def put_nowait(self, callable, *args, **kwargs):
130         self.queue.put_nowait((callable, args, kwargs))
131
132     def start(self):
133         for worker in self.workers:
134             if not worker.isAlive():
135                 worker.start()
136     
137     def join(self):
138         # Wait until all queued tasks have been processed
139         self.queue.join()
140
141         for worker in self.workers:
142             worker.quit()
143
144         for worker in self.workers:
145             worker.join()
146     
147     def sync(self):
148         if self.delayed_exceptions:
149             typ,val,loc = self.delayed_exceptions[0]
150             del self.delayed_exceptions[:]
151             raise typ,val,loc
152         
153     def __iter__(self):
154         if self.rvqueue is not None:
155             while True:
156                 try:
157                     yield self.rvqueue.get_nowait()
158                 except Queue.Empty:
159                     self.queue.join()
160                     try:
161                         yield self.rvqueue.get_nowait()
162                     except Queue.Empty:
163                         raise StopIteration
164