systematic use of context managers for dealing with files instead of open()/close...
[nepi.git] / src / nepi / util / parallel.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License version 2 as
7 #    published by the Free Software Foundation;
8 #
9 #    This program is distributed in the hope that it will be useful,
10 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
11 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 #    GNU General Public License for more details.
13 #
14 #    You should have received a copy of the GNU General Public License
15 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 #
17 # Author: Claudio Freire <claudio-daniel.freire@inria.fr>
18 #         Alina Quereilhac <alina.quereilhac@inria.fr>
19 #
20
21 import threading
22 import Queue
23 import traceback
24 import sys
25 import os
26
27 N_PROCS = None
28
29 class WorkerThread(threading.Thread):
30     class QUIT:
31         pass
32
33     def run(self):
34         while True:
35             task = self.queue.get()
36
37             if task is self.QUIT:
38                 self.queue.task_done()
39                 break
40
41             try:
42                 try:
43                     callable, args, kwargs = task
44                     rv = callable(*args, **kwargs)
45                     
46                     if self.rvqueue is not None:
47                         self.rvqueue.put(rv)
48                 finally:
49                     self.queue.task_done()
50             except:
51                 traceback.print_exc(file = sys.stderr)
52                 self.delayed_exceptions.append(sys.exc_info())
53
54     def attach(self, queue, rvqueue, delayed_exceptions):
55         self.queue = queue
56         self.rvqueue = rvqueue
57         self.delayed_exceptions = delayed_exceptions
58    
59     def quit(self):
60         self.queue.put(self.QUIT)
61
62 class ParallelRun(object):
63     def __init__(self, maxthreads = None, maxqueue = None, results = True):
64         self.maxqueue = maxqueue
65         self.maxthreads = maxthreads
66         
67         self.queue = Queue.Queue(self.maxqueue or 0)
68         
69         self.delayed_exceptions = []
70         
71         if results:
72             self.rvqueue = Queue.Queue()
73         else:
74             self.rvqueue = None
75     
76         self.initialize_workers()
77
78     def initialize_workers(self):
79         global N_PROCS
80
81         maxthreads = self.maxthreads
82        
83         # Compute maximum number of threads allowed by the system
84         if maxthreads is None:
85             if N_PROCS is None:
86                 try:
87                     with open("/proc/cpuinfo") as f:
88                         N_PROCS = sum("processor" in l for l in f)
89                 except:
90                     pass
91             maxthreads = N_PROCS
92         
93         if maxthreads is None:
94             maxthreads = 4
95  
96         self.workers = []
97
98         # initialize workers
99         for x in xrange(maxthreads):
100             worker = WorkerThread()
101             worker.attach(self.queue, self.rvqueue, self.delayed_exceptions)
102             worker.setDaemon(True)
103
104             self.workers.append(worker)
105     
106     def __del__(self):
107         self.destroy()
108
109     def empty(self):
110         while True:
111             try:
112                 self.queue.get(block = False)
113                 self.queue.task_done()
114             except Queue.Empty:
115                 break
116   
117     def destroy(self):
118         self.join()
119
120         del self.workers[:]
121         
122     def put(self, callable, *args, **kwargs):
123         self.queue.put((callable, args, kwargs))
124     
125     def put_nowait(self, callable, *args, **kwargs):
126         self.queue.put_nowait((callable, args, kwargs))
127
128     def start(self):
129         for worker in self.workers:
130             if not worker.isAlive():
131                 worker.start()
132     
133     def join(self):
134         # Wait until all queued tasks have been processed
135         self.queue.join()
136
137         for worker in self.workers:
138             worker.quit()
139
140         for worker in self.workers:
141             worker.join()
142     
143     def sync(self):
144         if self.delayed_exceptions:
145             typ,val,loc = self.delayed_exceptions[0]
146             del self.delayed_exceptions[:]
147             raise typ,val,loc
148         
149     def __iter__(self):
150         if self.rvqueue is not None:
151             while True:
152                 try:
153                     yield self.rvqueue.get_nowait()
154                 except Queue.Empty:
155                     self.queue.join()
156                     try:
157                         yield self.rvqueue.get_nowait()
158                     except Queue.Empty:
159                         raise StopIteration
160