attempt to make api-creation code more robust and leak-proof
[sfa.git] / sfa / server / threadedserver.py
1 ##
2 # This module implements a general-purpose server layer for sfa.
3 # The same basic server should be usable on the registry, component, or
4 # other interfaces.
5 #
6 # TODO: investigate ways to combine this with existing PLC server?
7 ##
8
9 import sys
10 import socket
11 import traceback
12 import threading
13 from Queue import Queue
14 import xmlrpclib
15 import SocketServer
16 import BaseHTTPServer
17 import SimpleXMLRPCServer
18 from OpenSSL import SSL
19
20 from sfa.util.sfalogging import logger
21 from sfa.util.config import Config
22 from sfa.util.cache import Cache 
23 from sfa.trust.certificate import Certificate
24 from sfa.trust.trustedroots import TrustedRoots
25
26 # don't hard code an api class anymore here
27 from sfa.generic import Generic
28
29 ##
30 # Verification callback for pyOpenSSL. We do our own checking of keys because
31 # we have our own authentication spec. Thus we disable several of the normal
32 # prohibitions that OpenSSL places on certificates
33
34 def verify_callback(conn, x509, err, depth, preverify):
35     # if the cert has been preverified, then it is ok
36     if preverify:
37        #print "  preverified"
38        return 1
39
40
41     # the certificate verification done by openssl checks a number of things
42     # that we aren't interested in, so we look out for those error messages
43     # and ignore them
44
45     # XXX SMBAKER: I don't know what this error is, but it's being returned
46     # xxx thierry: this most likely means the cert has a validity range in the future
47     # by newer pl nodes.
48     if err == 9:
49        #print "  X509_V_ERR_CERT_NOT_YET_VALID"
50        return 1
51
52     # allow self-signed certificates
53     if err == 18:
54        #print "  X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT"
55        return 1
56
57     # allow certs that don't have an issuer
58     if err == 20:
59        #print "  X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY"
60        return 1
61
62     # allow chained certs with self-signed roots
63     if err == 19:
64         return 1
65     
66     # allow certs that are untrusted
67     if err == 21:
68        #print "  X509_V_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE"
69        return 1
70
71     # allow certs that are untrusted
72     if err == 27:
73        #print "  X509_V_ERR_CERT_UNTRUSTED"
74        return 1
75
76     # ignore X509_V_ERR_CERT_SIGNATURE_FAILURE
77     if err == 7:
78        return 1         
79
80     logger.debug("  error %s in verify_callback"%err)
81
82     return 0
83
84 ##
85 # taken from the web (XXX find reference). Implements HTTPS xmlrpc request handler
86 class SecureXMLRpcRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
87     """Secure XML-RPC request handler class.
88
89     It it very similar to SimpleXMLRPCRequestHandler but it uses HTTPS for transporting XML data.
90     """
91     def setup(self):
92         self.connection = self.request
93         self.rfile = socket._fileobject(self.request, "rb", self.rbufsize)
94         self.wfile = socket._fileobject(self.request, "wb", self.wbufsize)
95
96     def do_POST(self):
97         """Handles the HTTPS POST request.
98
99         It was copied out from SimpleXMLRPCServer.py and modified to shutdown 
100         the socket cleanly.
101         """
102         try:
103             peer_cert = Certificate()
104             peer_cert.load_from_pyopenssl_x509(self.connection.get_peer_certificate())
105             generic=Generic.the_flavour()
106             self.api = generic.make_api (peer_cert = peer_cert, 
107                                          interface = self.server.interface, 
108                                          key_file = self.server.key_file, 
109                                          cert_file = self.server.cert_file,
110                                          cache = self.cache)
111             #logger.info("SecureXMLRpcRequestHandler.do_POST:")
112             #logger.info("interface=%s"%self.server.interface)
113             #logger.info("key_file=%s"%self.server.key_file)
114             #logger.info("api=%s"%self.api)
115             #logger.info("server=%s"%self.server)
116             #logger.info("handler=%s"%self)
117             # get arguments
118             request = self.rfile.read(int(self.headers["content-length"]))
119             remote_addr = (remote_ip, remote_port) = self.connection.getpeername()
120             self.api.remote_addr = remote_addr            
121             response = self.api.handle(remote_addr, request, self.server.method_map)
122         except Exception, fault:
123             # This should only happen if the module is buggy
124             # internal error, report as HTTP server error
125             logger.log_exc("server.do_POST")
126             response = self.api.prepare_response(fault)
127             #self.send_response(500)
128             #self.end_headers()
129        
130         # avoid session/connection leaks : do this no matter what 
131         finally:
132             self.send_response(200)
133             self.send_header("Content-type", "text/xml")
134             self.send_header("Content-length", str(len(response)))
135             self.end_headers()
136             self.wfile.write(response)
137             self.wfile.flush()
138             # close db connection
139             self.api.close_dbsession()
140             # shut down the connection
141             self.connection.shutdown() # Modified here!
142
143 ##
144 # Taken from the web (XXX find reference). Implements an HTTPS xmlrpc server
145 class SecureXMLRPCServer(BaseHTTPServer.HTTPServer,SimpleXMLRPCServer.SimpleXMLRPCDispatcher):
146
147     def __init__(self, server_address, HandlerClass, key_file, cert_file, logRequests=True):
148         """Secure XML-RPC server.
149
150         It it very similar to SimpleXMLRPCServer but it uses HTTPS for transporting XML data.
151         """
152         logger.debug("SecureXMLRPCServer.__init__, server_address=%s, cert_file=%s, key_file=%s"%(server_address,cert_file,key_file))
153         self.logRequests = logRequests
154         self.interface = None
155         self.key_file = key_file
156         self.cert_file = cert_file
157         self.method_map = {}
158         # add cache to the request handler
159         HandlerClass.cache = Cache()
160         #for compatibility with python 2.4 (centos53)
161         if sys.version_info < (2, 5):
162             SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self)
163         else:
164            SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self, True, None)
165         SocketServer.BaseServer.__init__(self, server_address, HandlerClass)
166         ctx = SSL.Context(SSL.SSLv23_METHOD)
167         ctx.use_privatekey_file(key_file)        
168         ctx.use_certificate_file(cert_file)
169         # If you wanted to verify certs against known CAs.. this is how you would do it
170         #ctx.load_verify_locations('/etc/sfa/trusted_roots/plc.gpo.gid')
171         config = Config()
172         trusted_cert_files = TrustedRoots(config.get_trustedroots_dir()).get_file_list()
173         for cert_file in trusted_cert_files:
174             ctx.load_verify_locations(cert_file)
175         ctx.set_verify(SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT, verify_callback)
176         ctx.set_verify_depth(5)
177         ctx.set_app_data(self)
178         self.socket = SSL.Connection(ctx, socket.socket(self.address_family,
179                                                         self.socket_type))
180         self.server_bind()
181         self.server_activate()
182
183     # _dispatch
184     #
185     # Convert an exception on the server to a full stack trace and send it to
186     # the client.
187
188     def _dispatch(self, method, params):
189         logger.debug("SecureXMLRPCServer._dispatch, method=%s"%method)
190         try:
191             return SimpleXMLRPCServer.SimpleXMLRPCDispatcher._dispatch(self, method, params)
192         except:
193             # can't use format_exc() as it is not available in jython yet
194             # (even in trunk).
195             type, value, tb = sys.exc_info()
196             raise xmlrpclib.Fault(1,''.join(traceback.format_exception(type, value, tb)))
197
198     # override this one from the python 2.7 code
199     # originally defined in class TCPServer
200     def shutdown_request(self, request):
201         """Called to shutdown and close an individual request."""
202         # ---------- 
203         # the std python 2.7 code just attempts a request.shutdown(socket.SHUT_WR)
204         # this works fine with regular sockets
205         # However we are dealing with an instance of OpenSSL.SSL.Connection instead
206         # This one only supports shutdown(), and in addition this does not
207         # always perform as expected
208         # ---------- std python 2.7 code
209         try:
210             #explicitly shutdown.  socket.close() merely releases
211             #the socket and waits for GC to perform the actual close.
212             request.shutdown(socket.SHUT_WR)
213         except socket.error:
214             pass #some platforms may raise ENOTCONN here
215         # ----------
216         except TypeError:
217             # we are dealing with an OpenSSL.Connection object, 
218             # try to shut it down but never mind if that fails
219             try: request.shutdown()
220             except: pass
221         # ----------
222         self.close_request(request)
223
224 ## From Active State code: http://code.activestate.com/recipes/574454/
225 # This is intended as a drop-in replacement for the ThreadingMixIn class in 
226 # module SocketServer of the standard lib. Instead of spawning a new thread 
227 # for each request, requests are processed by of pool of reusable threads.
228 class ThreadPoolMixIn(SocketServer.ThreadingMixIn):
229     """
230     use a thread pool instead of a new thread on every request
231     """
232     # XX TODO: Make this configurable
233     # config = Config()
234     # numThreads = config.SFA_SERVER_NUM_THREADS
235     numThreads = 25
236     allow_reuse_address = True  # seems to fix socket.error on server restart
237
238     def serve_forever(self):
239         """
240         Handle one request at a time until doomsday.
241         """
242         # set up the threadpool
243         self.requests = Queue()
244
245         for x in range(self.numThreads):
246             t = threading.Thread(target = self.process_request_thread)
247             t.setDaemon(1)
248             t.start()
249
250         # server main loop
251         while True:
252             self.handle_request()
253             
254         self.server_close()
255
256     
257     def process_request_thread(self):
258         """
259         obtain request from queue instead of directly from server socket
260         """
261         while True:
262             SocketServer.ThreadingMixIn.process_request_thread(self, *self.requests.get())
263
264     
265     def handle_request(self):
266         """
267         simply collect requests and put them on the queue for the workers.
268         """
269         try:
270             request, client_address = self.get_request()
271         except socket.error:
272             return
273         if self.verify_request(request, client_address):
274             self.requests.put((request, client_address))
275
276 class ThreadedServer(ThreadPoolMixIn, SecureXMLRPCServer):
277     pass