rename src/nepi/ into just nepi/
[nepi.git] / nepi / util / parallel.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License version 2 as
7 #    published by the Free Software Foundation;
8 #
9 #    This program is distributed in the hope that it will be useful,
10 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
11 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 #    GNU General Public License for more details.
13 #
14 #    You should have received a copy of the GNU General Public License
15 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 #
17 # Author: Claudio Freire <claudio-daniel.freire@inria.fr>
18 #         Alina Quereilhac <alina.quereilhac@inria.fr>
19 #
20
21 import threading
22 import traceback
23 import sys
24 import os
25
26 from six.moves import queue
27
28 N_PROCS = None
29
30 class WorkerThread(threading.Thread):
31     class QUIT:
32         pass
33
34     def run(self):
35         while True:
36             task = self.queue.get()
37
38             if task is self.QUIT:
39                 self.queue.task_done()
40                 break
41
42             try:
43                 try:
44                     callable, args, kwargs = task
45                     rv = callable(*args, **kwargs)
46                     
47                     if self.rvqueue is not None:
48                         self.rvqueue.put(rv)
49                 finally:
50                     self.queue.task_done()
51             except:
52                 traceback.print_exc(file = sys.stderr)
53                 self.delayed_exceptions.append(sys.exc_info())
54
55     def attach(self, queue, rvqueue, delayed_exceptions):
56         self.queue = queue
57         self.rvqueue = rvqueue
58         self.delayed_exceptions = delayed_exceptions
59    
60     def quit(self):
61         self.queue.put(self.QUIT)
62
63 class ParallelRun(object):
64     def __init__(self, maxthreads = None, maxqueue = None, results = True):
65         self.maxqueue = maxqueue
66         self.maxthreads = maxthreads
67         
68         self.queue = queue.Queue(self.maxqueue or 0)
69         
70         self.delayed_exceptions = []
71         
72         if results:
73             self.rvqueue = queue.Queue()
74         else:
75             self.rvqueue = None
76     
77         self.initialize_workers()
78
79     def initialize_workers(self):
80         global N_PROCS
81
82         maxthreads = self.maxthreads
83        
84         # Compute maximum number of threads allowed by the system
85         if maxthreads is None:
86             if N_PROCS is None:
87                 try:
88                     with open("/proc/cpuinfo") as f:
89                         N_PROCS = sum("processor" in l for l in f)
90                 except:
91                     pass
92             maxthreads = N_PROCS
93         
94         if maxthreads is None:
95             maxthreads = 4
96  
97         self.workers = []
98
99         # initialize workers
100         for x in range(maxthreads):
101             worker = WorkerThread()
102             worker.attach(self.queue, self.rvqueue, self.delayed_exceptions)
103             worker.setDaemon(True)
104
105             self.workers.append(worker)
106     
107     def __del__(self):
108         self.destroy()
109
110     def empty(self):
111         while True:
112             try:
113                 self.queue.get(block = False)
114                 self.queue.task_done()
115             except queue.Empty:
116                 break
117   
118     def destroy(self):
119         self.join()
120
121         del self.workers[:]
122         
123     def put(self, callable, *args, **kwargs):
124         self.queue.put((callable, args, kwargs))
125     
126     def put_nowait(self, callable, *args, **kwargs):
127         self.queue.put_nowait((callable, args, kwargs))
128
129     def start(self):
130         for worker in self.workers:
131             if not worker.isAlive():
132                 worker.start()
133     
134     def join(self):
135         # Wait until all queued tasks have been processed
136         self.queue.join()
137
138         for worker in self.workers:
139             worker.quit()
140
141         for worker in self.workers:
142             worker.join()
143     
144     def sync(self):
145         if self.delayed_exceptions:
146             typ,val,loc = self.delayed_exceptions[0]
147             del self.delayed_exceptions[:]
148             raise typ(val).with_traceback(loc)
149         
150     def __iter__(self):
151         if self.rvqueue is not None:
152             while True:
153                 try:
154                     yield self.rvqueue.get_nowait()
155                 except queue.Empty:
156                     self.queue.join()
157                     try:
158                         yield self.rvqueue.get_nowait()
159                     except queue.Empty:
160                         raise StopIteration
161