15e1bb2a2badc22415f28839325f69432c9076a4
[nepi.git] / src / nepi / util / parallel.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
import sys
import threading
import traceback

try:
    import Queue
except ImportError:
    # Python 3 renamed the module to ``queue``; keep the historical name
    # so the rest of the file can keep referring to ``Queue.Queue``.
    import queue as Queue
8
# Cached processor count. Filled in lazily by ParallelMap.__init__ the first
# time it is constructed without an explicit ``maxthreads``; stays None while
# the count is unknown.
N_PROCS = None

class ParallelMap(object):
    """
    Small thread pool fed through a task queue.

    Tasks are ``(callable, args, kwargs)`` tuples added with put() or
    put_nowait(). Worker threads created here (and launched by start())
    execute them and, when ``results`` is true, push every return value
    onto a result queue that can be drained by iterating over the instance.
    """

    def __init__(self, maxthreads = None, maxqueue = None, results = True):
        """
        Build the queues and the (not yet started) worker threads.

        maxthreads: number of worker threads. When None, the number of
            processors listed in /proc/cpuinfo is used, falling back to 4
            if that cannot be determined.
        maxqueue: maximum number of pending tasks; None or 0 means unbounded.
        results: when true, return values of the tasks are collected and
            made available through iteration.
        """
        global N_PROCS

        if maxthreads is None:
            if N_PROCS is None:
                N_PROCS = self._count_processors()
            maxthreads = N_PROCS

        if maxthreads is None:
            maxthreads = 4

        self.queue = Queue.Queue(maxqueue or 0)

        self.workers = [ threading.Thread(target = self.worker)
                         for x in range(maxthreads) ]

        # sys.exc_info() triples recorded by the workers, in arrival order
        self.delayed_exceptions = []

        if results:
            self.rvqueue = Queue.Queue()
        else:
            self.rvqueue = None

    @staticmethod
    def _count_processors():
        """
        Return the number of /proc/cpuinfo lines mentioning "processor",
        or None when the file cannot be read or no such line exists.
        """
        try:
            with open("/proc/cpuinfo") as f:
                count = sum("processor" in line for line in f)
        except (IOError, OSError, ValueError):
            # Missing/unreadable file or undecodable content: detection is
            # best-effort, the caller falls back to a default.
            return None

        # A count of zero would create a pool without workers, which can
        # never complete a task; report it as "unknown" instead.
        return count or None

    @staticmethod
    def _reraise(typ, val, loc):
        """
        Re-raise a recorded exception, keeping its original traceback.
        """
        if sys.version_info[0] >= 3:
            raise val.with_traceback(loc)
        # The three-argument raise statement is a syntax error on Python 3,
        # so it is kept inside a string and only compiled on Python 2.
        exec("raise typ, val, loc")

    def put(self, callable, *args, **kwargs):
        """
        Queue ``callable(*args, **kwargs)`` for execution, blocking while
        the task queue is full.
        """
        self.queue.put((callable, args, kwargs))

    def put_nowait(self, callable, *args, **kwargs):
        """
        Queue ``callable(*args, **kwargs)`` for execution without blocking;
        raises Queue.Full if the task queue is full.
        """
        self.queue.put_nowait((callable, args, kwargs))

    def start(self):
        """
        Start all worker threads.
        """
        for thread in self.workers:
            thread.start()

    def join(self):
        """
        Shut the pool down and wait for it to finish.

        One shutdown marker per worker is queued behind the pending tasks,
        then the queue and the threads are waited for, so the workers must
        have been started. If any task raised, the first recorded exception
        is re-raised here.
        """
        for thread in self.workers:
            # That's the shutdown signal
            self.queue.put(None)

        self.queue.join()
        for thread in self.workers:
            thread.join()

        if self.delayed_exceptions:
            typ, val, loc = self.delayed_exceptions[0]
            self._reraise(typ, val, loc)

    def worker(self):
        """
        Thread body: run queued tasks until a shutdown marker (None) shows up.

        Exceptions raised by a task are printed to stderr and recorded in
        ``delayed_exceptions``; they do not stop the worker.
        """
        while True:
            task = self.queue.get()
            if task is None:
                self.queue.task_done()
                break

            try:
                try:
                    func, args, kwargs = task
                    rv = func(*args, **kwargs)

                    if self.rvqueue is not None:
                        self.rvqueue.put(rv)
                finally:
                    # Always account for the task, or queue.join() would
                    # wait forever after a failure.
                    self.queue.task_done()
            except:
                # Deliberately broad: whatever the task raised is kept for
                # join() to re-raise, and the worker keeps serving tasks.
                traceback.print_exc(file = sys.stderr)
                self.delayed_exceptions.append(sys.exc_info())

    def __iter__(self):
        """
        Yield the collected return values, in the order workers produced them.

        When the result queue runs dry, waits for all queued tasks to be
        processed and checks once more; iteration ends when it is still
        empty. Yields nothing if results are not being collected.
        """
        if self.rvqueue is None:
            return

        while True:
            try:
                yield self.rvqueue.get_nowait()
            except Queue.Empty:
                self.queue.join()
                try:
                    yield self.rvqueue.get_nowait()
                except Queue.Empty:
                    # Plain return: raising StopIteration inside a generator
                    # is turned into RuntimeError by newer interpreters.
                    return
96             
97     
class ParallelFilter(ParallelMap):
    """
    ParallelMap specialization that evaluates a condition on every queued
    item and, on iteration, yields only the items for which it held.
    """

    class _FILTERED:
        """Marker returned by a task in place of a rejected item."""

    def __init__(self, filter_condition, maxthreads = None, maxqueue = None):
        """
        filter_condition: callable applied to each queued item; the item
            is kept when it returns a true value.
        maxthreads, maxqueue: forwarded to ParallelMap; result collection
            is always enabled.
        """
        super(ParallelFilter, self).__init__(maxthreads, maxqueue, True)
        self.filter_condition = filter_condition

    def __filter(self, x):
        # Task body: hand the item back if accepted, the marker otherwise
        accepted = self.filter_condition(x)
        return x if accepted else self._FILTERED

    def put(self, what):
        """
        Queue ``what`` for evaluation, blocking while the queue is full.
        """
        enqueue = super(ParallelFilter, self).put
        enqueue(self.__filter, what)

    def put_nowait(self, what):
        """
        Queue ``what`` for evaluation without blocking.
        """
        enqueue = super(ParallelFilter, self).put_nowait
        enqueue(self.__filter, what)

    def __iter__(self):
        """
        Yield the accepted items, skipping the rejection markers.
        """
        rejected = self._FILTERED
        for rv in super(ParallelFilter, self).__iter__():
            if rv is rejected:
                continue
            yield rv
122
class ParallelRun(ParallelMap):
    """
    ParallelMap variant whose tasks are packed into a single
    ``(fn, args, kwargs)`` argument and unpacked by the worker-side
    trampoline. Results are always collected.
    """

    def __run(self, x):
        # Task body: unpack the call description and perform the call
        fn, args, kwargs = x
        return fn(*args, **kwargs)

    def __init__(self, maxthreads = None, maxqueue = None):
        """
        maxthreads, maxqueue: forwarded to ParallelMap.
        """
        super(ParallelRun, self).__init__(maxthreads, maxqueue, True)

    def put(self, what, *args, **kwargs):
        """
        Queue ``what(*args, **kwargs)``, blocking while the queue is full.
        """
        super(ParallelRun, self).put(self.__run, (what, args, kwargs))

    def put_nowait(self, what, *args, **kwargs):
        """
        Queue ``what(*args, **kwargs)`` without blocking.
        """
        # Must be __run: this class has no __filter (the name would be
        # mangled to _ParallelRun__filter and fail with AttributeError).
        super(ParallelRun, self).put_nowait(self.__run, (what, args, kwargs))
136
137
def pmap(mapping, iterable, maxthreads = None, maxqueue = None):
    """
    Apply ``mapping`` to every element of ``iterable`` using a ParallelMap
    and return the list of results.

    Results are listed in the order the workers produced them, which is not
    necessarily the order of ``iterable``. ``maxthreads`` and ``maxqueue``
    are forwarded to ParallelMap. If any call raised, the first recorded
    exception is re-raised once the pool has been shut down.
    """
    mapper = ParallelMap(
        maxthreads = maxthreads,
        maxqueue = maxqueue,
        results = True)
    mapper.start()
    for elem in iterable:
        # The task is mapping(elem); the element itself is not a callable
        mapper.put(mapping, elem)
    results = list(mapper)
    # Stop the worker threads; otherwise they stay blocked on the queue
    mapper.join()
    return results
147
def pfilter(condition, iterable, maxthreads = None, maxqueue = None):
    """
    Evaluate ``condition`` on every element of ``iterable`` using a
    ParallelFilter and return the list of elements for which it held.

    Elements are listed in the order the workers produced them, which is
    not necessarily the order of ``iterable``. ``maxthreads`` and
    ``maxqueue`` are forwarded to ParallelFilter. If any evaluation raised,
    the first recorded exception is re-raised once the pool has been shut
    down.
    """
    filtrer = ParallelFilter(
        condition,
        maxthreads = maxthreads,
        maxqueue = maxqueue)
    filtrer.start()
    for elem in iterable:
        filtrer.put(elem)
    accepted = list(filtrer)
    # Stop the worker threads; otherwise they stay blocked on the queue
    filtrer.join()
    return accepted
157