big drastic change: use SSLContext.wrap_socket()
[sfa.git] / sfa / server / threadedserver.py
1 ##
2 # This module implements a general-purpose server layer for sfa.
3 # The same basic server should be usable on the registry, component, or
4 # other interfaces.
5 #
6 # TODO: investigate ways to combine this with existing PLC server?
7 ##
8
9 import sys
10 import socket
11 import traceback
12 import threading
13 from queue import Queue
14 import socketserver
15 import ssl
16 import http.server
17 import xmlrpc.server
18 import xmlrpc.client
19
20 from sfa.util.sfalogging import logger
21 from sfa.util.config import Config
22 from sfa.util.cache import Cache
23 from sfa.trust.certificate import Certificate
24 from sfa.trust.trustedroots import TrustedRoots
25
26 # don't hard code an api class anymore here
27 from sfa.generic import Generic
28
29 ##
30 # Verification callback for pyOpenSSL. We do our own checking of keys because
31 # we have our own authentication spec. Thus we disable several of the normal
32 # prohibitions that OpenSSL places on certificates
33
34
35 def verify_callback(conn, x509, err, depth, preverify):
36     # if the cert has been preverified, then it is ok
37     if preverify:
38         # print "  preverified"
39         return True
40
41     # the certificate verification done by openssl checks a number of things
42     # that we aren't interested in, so we look out for those error messages
43     # and ignore them
44
45     # XXX SMBAKER: I don't know what this error is, but it's being returned
46     # xxx thierry: this most likely means the cert
47     #     has a validity range in the future
48     # by newer pl nodes.
49     if err == 9:
50         # print "  X509_V_ERR_CERT_NOT_YET_VALID"
51         return True
52
53     # allow self-signed certificates
54     if err == 18:
55         # print "  X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT"
56         return False
57
58     # allow certs that don't have an issuer
59     if err == 20:
60         # print "  X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY"
61         return False
62
63     # allow chained certs with self-signed roots
64     if err == 19:
65         return False
66
67     # allow certs that are untrusted
68     if err == 21:
69         # print "  X509_V_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE"
70         return False
71
72     # allow certs that are untrusted
73     if err == 27:
74         # print "  X509_V_ERR_CERT_UNTRUSTED"
75         return False
76
77     # ignore X509_V_ERR_CERT_SIGNATURE_FAILURE
78     if err == 7:
79         return False
80
81     logger.debug("  unhandled error %s in verify_callback" % err)
82
83     return False
84
85 ##
86 # taken from the web (XXX find reference). Implements HTTPS xmlrpc request
87 # handler
88
89 # python-2.7 http://code.activestate.com/recipes/442473-simple-http-server-supporting-ssl-secure-communica/
90 # python-3.3 https://gist.github.com/ubershmekel/6194556
91 class SecureXMLRpcRequestHandler(xmlrpc.server.SimpleXMLRPCRequestHandler):
92     """
93     Secure XML-RPC request handler class.
94
95     It it very similar to SimpleXMLRPCRequestHandler
96     but it uses HTTPS for transporting XML data.
97     """
98
99     # porting to python3
100     # setup() no longer needed
101
102     def do_POST(self):
103         """
104         Handles the HTTPS POST request.
105
106         It was copied out from SimpleXMLRPCServer.py and modified to shutdown
107         the socket cleanly.
108         """
109         try:
110             peer_cert = Certificate()
111             peer_cert.load_from_pyopenssl_x509(
112                 self.connection.getpeercert())
113             generic = Generic.the_flavour()
114             self.api = generic.make_api(peer_cert=peer_cert,
115                                         interface=self.server.interface,
116                                         key_file=self.server.key_file,
117                                         cert_file=self.server.cert_file,
118                                         cache=self.cache)
119             # logger.info("SecureXMLRpcRequestHandler.do_POST:")
120             # logger.info("interface=%s"%self.server.interface)
121             # logger.info("key_file=%s"%self.server.key_file)
122             # logger.info("api=%s"%self.api)
123             # logger.info("server=%s"%self.server)
124             # logger.info("handler=%s"%self)
125             # get arguments
126             request = self.rfile.read(int(self.headers["content-length"]))
127             remote_addr = (
128                 remote_ip, remote_port) = self.connection.getpeername()
129             self.api.remote_addr = remote_addr
130             response = self.api.handle(
131                 remote_addr, request, self.server.method_map)
132         except Exception as fault:
133             # This should only happen if the module is buggy
134             # internal error, report as HTTP server error
135             logger.log_exc("server.do_POST")
136             response = self.api.prepare_response(fault)
137             # self.send_response(500)
138             # self.end_headers()
139
140         # avoid session/connection leaks : do this no matter what
141         finally:
142             self.send_response(200)
143             self.send_header("Content-type", "text/xml")
144             self.send_header("Content-length", str(len(response)))
145             self.end_headers()
146             self.wfile.write(response.encode())
147             self.wfile.flush()
148             # close db connection
149             self.api.close_dbsession()
150             # shut down the connection
151             self.connection.shutdown(socket.SHUT_RDWR)  # Modified here!
152
153 ##
154 # Taken from the web (XXX find reference). Implements an HTTPS xmlrpc server
155
156
157 class SecureXMLRPCServer(http.server.HTTPServer,
158                          xmlrpc.server.SimpleXMLRPCDispatcher):
159
160     def __init__(self, server_address, HandlerClass,
161                  key_file, cert_file, logRequests=True):
162         """
163         Secure XML-RPC server.
164
165         It it very similar to SimpleXMLRPCServer
166         but it uses HTTPS for transporting XML data.
167         """
168         logger.debug(
169             f"SecureXMLRPCServer.__init__, server_address={server_address}, "
170             f"cert_file={cert_file}, key_file={key_file}")
171         self.logRequests = logRequests
172         self.interface = None
173         self.key_file = key_file
174         self.cert_file = cert_file
175         self.method_map = {}
176         # add cache to the request handler
177         HandlerClass.cache = Cache()
178         xmlrpc.server.SimpleXMLRPCDispatcher.__init__(self, True, None)
179         socketserver.BaseServer.__init__(self, server_address, HandlerClass)
180         ssl_context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
181         ssl_context.load_cert_chain(cert_file, key_file)
182         # If you wanted to verify certs against known CAs..
183         # this is how you would do it
184         # ssl_context.load_verify_locations('/etc/sfa/trusted_roots/plc.gpo.gid')
185         config = Config()
186         trusted_cert_files = TrustedRoots(
187             config.get_trustedroots_dir()).get_file_list()
188         cadata = ""
189         for cert_file in trusted_cert_files:
190             with open(cert_file) as cafile:
191                 cadata += cafile.read()
192         ssl_context.load_verify_locations(cadata=cadata)
193 #        ctx.set_verify(SSL.VERIFY_PEER |
194 #                       SSL.VERIFY_FAIL_IF_NO_PEER_CERT, verify_callback)
195 #        ctx.set_verify_depth(5)
196 #        ctx.set_app_data(self)
197         # with python3 we use standard library SSLContext.wrap_socket()
198         # instead of an OpenSSL.SSL.Connection object
199         self.socket = ssl_context.wrap_socket(
200             socket.socket(self.address_family, self.socket_type),
201             server_side=True)
202         self.server_bind()
203         self.server_activate()
204
205     # _dispatch
206     #
207     # Convert an exception on the server to a full stack trace and send it to
208     # the client.
209
210     def _dispatch(self, method, params):
211         logger.debug("SecureXMLRPCServer._dispatch, method=%s" % method)
212         try:
213             return xmlrpc.server.SimpleXMLRPCDispatcher._dispatch(
214                 self, method, params)
215         except:
216             # can't use format_exc() as it is not available in jython yet
217             # (even in trunk).
218             type, value, tb = sys.exc_info()
219             raise xmlrpc.client.Fault(1, ''.join(
220                 traceback.format_exception(type, value, tb)))
221
222     # porting to python3
223     # shutdown_request() no longer needed
224
225
226 # From Active State code: http://code.activestate.com/recipes/574454/
227 # This is intended as a drop-in replacement for the ThreadingMixIn class in
228 # module SocketServer of the standard lib. Instead of spawning a new thread
229 # for each request, requests are processed by of pool of reusable threads.
230
231 class ThreadPoolMixIn(socketserver.ThreadingMixIn):
232     """
233     use a thread pool instead of a new thread on every request
234     """
235     # XX TODO: Make this configurable
236     # config = Config()
237     # numThreads = config.SFA_SERVER_NUM_THREADS
238     numThreads = 25
239     allow_reuse_address = True  # seems to fix socket.error on server restart
240
241     def serve_forever(self):
242         """
243         Handle one request at a time until doomsday.
244         """
245         # set up the threadpool
246         self.requests = Queue()
247
248         for _ in range(self.numThreads):
249             thread = threading.Thread(target=self.process_request_thread)
250             thread.setDaemon(1)
251             thread.start()
252
253         # server main loop
254         while True:
255             self.handle_request()
256
257         self.server_close()
258
259     def process_request_thread(self):
260         """
261         obtain request from queue instead of directly from server socket
262         """
263         while True:
264             socketserver.ThreadingMixIn.process_request_thread(
265                 self, *self.requests.get())
266
267     def handle_request(self):
268         """
269         simply collect requests and put them on the queue for the workers.
270         """
271         try:
272             request, client_address = self.get_request()
273         except socket.error:
274             return
275         if self.verify_request(request, client_address):
276             self.requests.put((request, client_address))
277
278
279 class ThreadedServer(ThreadPoolMixIn, SecureXMLRPCServer):
280     pass