Merge default
[nepi.git] / src / nepi / util / parallel.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import threading
5 import Queue
6 import traceback
7 import sys
8
# Cached CPU count, filled in by the first ParallelMap built without maxthreads.
N_PROCS = None


class ParallelMap(object):
    """Execute queued callables on a fixed pool of worker threads.

    Typical use: construct, start(), put() tasks, optionally iterate the
    instance to collect return values, then join() to stop the workers.

    When ``results`` is true, iterating the instance yields the tasks'
    return values in completion order (not submission order).  Exceptions
    raised by tasks are printed to stderr, recorded in
    ``delayed_exceptions``, and the first one is re-raised by join().
    """

    def __init__(self, maxthreads = None, maxqueue = None, results = True):
        """Create the pool; the worker threads do not run until start().

        :param maxthreads: number of worker threads.  When None, the CPU
            count is used (cached in N_PROCS), or 4 if it cannot be found.
        :param maxqueue: bound on pending tasks; None or 0 means unbounded.
        :param results: when true, collect task return values so they can
            be read by iterating the instance.
        """
        global N_PROCS

        if maxthreads is None:
            if N_PROCS is None:
                try:
                    import multiprocessing
                    N_PROCS = multiprocessing.cpu_count()
                except (ImportError, NotImplementedError):
                    pass  # best effort: fall back to the default below
            # "or 4" also covers a detected count of 0, which would create
            # no workers and make join()/iteration block forever.
            maxthreads = N_PROCS or 4

        self.queue = Queue.Queue(maxqueue or 0)

        self.workers = [threading.Thread(target = self.worker)
                        for _ in range(maxthreads)]

        # sys.exc_info() triples recorded by tasks that raised.
        self.delayed_exceptions = []

        self.rvqueue = Queue.Queue() if results else None

    def put(self, callable, *args, **kwargs):
        """Queue ``callable(*args, **kwargs)``; blocks while the queue is full."""
        self.queue.put((callable, args, kwargs))

    def put_nowait(self, callable, *args, **kwargs):
        """Queue ``callable(*args, **kwargs)``; raises Queue.Full if full."""
        self.queue.put_nowait((callable, args, kwargs))

    def start(self):
        """Start all worker threads."""
        for thread in self.workers:
            thread.start()

    def join(self):
        """Stop the workers after the queued tasks finish, and wait for them.

        Re-raises the first exception recorded by a task, if any.
        """
        # One None sentinel per worker: each worker exits after taking one.
        for _ in self.workers:
            self.queue.put(None)

        self.queue.join()
        for thread in self.workers:
            thread.join()

        if self.delayed_exceptions:
            val = self.delayed_exceptions[0][1]
            # ``raise typ, val, tb`` is a SyntaxError on Python 3.  On
            # Python 3, ``val`` still carries the worker traceback in
            # __traceback__.  On Python 2, worker() has already printed it.
            raise val

    def sync(self):
        """Block until every task queued so far has been processed."""
        self.queue.join()

    def worker(self):
        """Thread body: run queued tasks until a None sentinel arrives."""
        while True:
            task = self.queue.get()
            if task is None:
                self.queue.task_done()
                break

            try:
                func, args, kwargs = task
                rv = func(*args, **kwargs)
                if self.rvqueue is not None:
                    self.rvqueue.put(rv)
            except:
                # Deliberately catch everything (even SystemExit).  A dead
                # worker would leave one shutdown sentinel unconsumed, and
                # join() would then hang.  The failure is deferred to join().
                traceback.print_exc(file = sys.stderr)
                self.delayed_exceptions.append(sys.exc_info())
            finally:
                # Runs only after any failure has been recorded, so callers
                # woken by queue.join() (sync(), iteration) already see it.
                self.queue.task_done()

    def __iter__(self):
        """Yield task return values as they become available.

        When nothing is ready, wait for all queued tasks to finish, drain
        what they produced, then stop.  Yields nothing if ``results`` was
        false.
        """
        if self.rvqueue is None:
            return
        while True:
            try:
                rv = self.rvqueue.get_nowait()
            except Queue.Empty:
                self.queue.join()
                try:
                    rv = self.rvqueue.get_nowait()
                except Queue.Empty:
                    # ``raise StopIteration`` here would be a RuntimeError
                    # on Python 3.7+ (PEP 479).  A plain return works on 2 and 3.
                    return
            yield rv
100     
class ParallelFilter(ParallelMap):
    """ParallelMap that keeps only the items accepted by a predicate.

    Items queued with put()/put_nowait() are tested with
    ``filter_condition`` on the worker threads.  Iterating the instance
    yields the accepted items in the order the workers finish them.
    """

    class _FILTERED:
        """Sentinel result standing in for a rejected item."""

    def __init__(self, filter_condition, maxthreads = None, maxqueue = None):
        # Results are always collected: they carry the accepted items back.
        super(ParallelFilter, self).__init__(maxthreads, maxqueue, True)
        self.filter_condition = filter_condition

    def __filter(self, item):
        """Worker-side task: return ``item`` if accepted, else the sentinel."""
        accepted = self.filter_condition(item)
        return item if accepted else self._FILTERED

    def put(self, what):
        """Queue ``what`` for testing."""
        super(ParallelFilter, self).put(self.__filter, what)

    def put_nowait(self, what):
        """Queue ``what`` for testing without blocking."""
        super(ParallelFilter, self).put_nowait(self.__filter, what)

    def __iter__(self):
        rejected = self._FILTERED
        for item in super(ParallelFilter, self).__iter__():
            if item is rejected:
                continue
            yield item
125
class ParallelRun(ParallelMap):
    """ParallelMap whose put() takes a callable plus its arguments.

    Return values are always collected, so iterating the instance yields
    them in completion order.
    """

    def __run(self, x):
        """Worker-side task: unpack ``(fn, args, kwargs)`` and call ``fn``."""
        fn, args, kwargs = x
        return fn(*args, **kwargs)

    def __init__(self, maxthreads = None, maxqueue = None):
        super(ParallelRun, self).__init__(maxthreads, maxqueue, True)

    def put(self, what, *args, **kwargs):
        """Queue ``what(*args, **kwargs)``."""
        super(ParallelRun, self).put(self.__run, (what, args, kwargs))

    def put_nowait(self, what, *args, **kwargs):
        """Queue ``what(*args, **kwargs)`` without blocking."""
        # Must dispatch through __run, the same as put().  This class defines
        # no __filter (that one is private to ParallelFilter).
        super(ParallelRun, self).put_nowait(self.__run, (what, args, kwargs))
139
140
def pmap(mapping, iterable, maxthreads = None, maxqueue = None):
    """Return ``[mapping(x) for x in iterable]``, computed on worker threads.

    The results are in the same order as ``iterable``.  The worker threads
    are always shut down before this returns.  If ``mapping`` raised for
    any element, the first recorded exception is re-raised (via
    ParallelMap.join()).
    """
    def apply_indexed(index, elem):
        # Tag each result with its input position so order can be restored.
        return index, mapping(elem)

    mapper = ParallelMap(
        maxthreads = maxthreads,
        maxqueue = maxqueue,
        results = True)
    mapper.start()
    try:
        for index, elem in enumerate(iterable):
            mapper.put(apply_indexed, index, elem)
        tagged = list(mapper)
    finally:
        # join() sends the shutdown sentinels.  Without it the non-daemon
        # workers wait on the queue forever and keep the process alive.
        mapper.join()
    tagged.sort(key = lambda pair: pair[0])
    return [value for _, value in tagged]
150
def pfilter(condition, iterable, maxthreads = None, maxqueue = None):
    """Return the elements of ``iterable`` for which ``condition`` is true.

    ``condition`` is evaluated on worker threads, and the result keeps the
    order of ``iterable``.  The worker threads are always shut down before
    this returns.  If ``condition`` raised for any element, the first
    recorded exception is re-raised (via ParallelMap.join()).
    """
    # Items travel as (index, elem) pairs so input order can be restored.
    filtrer = ParallelFilter(
        lambda pair: condition(pair[1]),
        maxthreads = maxthreads,
        maxqueue = maxqueue)
    filtrer.start()
    try:
        for pair in enumerate(iterable):
            filtrer.put(pair)
        kept = list(filtrer)
    finally:
        # join() sends the shutdown sentinels.  Without it the non-daemon
        # workers wait on the queue forever and keep the process alive.
        filtrer.join()
    kept.sort(key = lambda pair: pair[0])
    return [elem for _, elem in kept]
160