X-Git-Url: http://git.onelab.eu/?a=blobdiff_plain;f=sfa%2Ftrust%2Fcertificate.py;h=9e0f82b880e21fdefecac4914653b8a34e71053f;hb=0a9902d2a55a0a9ac03601345c4284293669012b;hp=7fe977ffc8772a2520bf18c2b2dc831715e4e664;hpb=9194805e82cb8775d1257ba45b27a368aef9d8b7;p=sfa.git diff --git a/sfa/trust/certificate.py b/sfa/trust/certificate.py index 7fe977ff..9e0f82b8 100644 --- a/sfa/trust/certificate.py +++ b/sfa/trust/certificate.py @@ -11,13 +11,13 @@ # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Work. # -# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT -# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS +# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS # IN THE WORK. #---------------------------------------------------------------------- @@ -47,6 +47,8 @@ import OpenSSL # M2Crypto is imported on the fly to minimize crashes #import M2Crypto +from sfa.util.py23 import PY3 + from sfa.util.faults import CertExpired, CertMissingParent, CertNotSignedByParent from sfa.util.sfalogging import logger @@ -65,6 +67,7 @@ glo_passphrase_callback = None # # The callback should return a string containing the passphrase. + def set_passphrase_callback(callback_func): global glo_passphrase_callback @@ -73,24 +76,29 @@ def set_passphrase_callback(callback_func): ## # Sets a fixed passphrase. + def set_passphrase(passphrase): - set_passphrase_callback( lambda k,s,x: passphrase ) + set_passphrase_callback(lambda k, s, x: passphrase) ## # Check to see if a passphrase works for a particular private key string. # Intended to be used by passphrase callbacks for input validation. + def test_passphrase(string, passphrase): try: - OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, string, (lambda x: passphrase)) + OpenSSL.crypto.load_privatekey( + OpenSSL.crypto.FILETYPE_PEM, string, (lambda x: passphrase)) return True except: return False + def convert_public_key(key): keyconvert_path = "/usr/bin/keyconvert.py" if not os.path.isfile(keyconvert_path): - raise IOError("Could not find keyconvert in {}".format(keyconvert_path)) + raise IOError( + "Could not find keyconvert in {}".format(keyconvert_path)) # we can only convert rsa keys if "ssh-dss" in key: @@ -108,7 +116,8 @@ def convert_public_key(key): # that it can be expected to see why it failed. # TODO: for production, cleanup the temporary files if not os.path.exists(ssl_fn): - raise Exception("keyconvert: generated certificate not found. keyconvert may have failed.") + raise Exception( + "keyconvert: generated certificate not found. keyconvert may have failed.") k = Keypair() try: @@ -129,6 +138,7 @@ def convert_public_key(key): # A Keypair object may represent both a public and private key pair, or it # may represent only a public key (this usage is consistent with OpenSSL). + class Keypair: key = None # public/private keypair m2key = None # public key (m2crypto format) @@ -149,7 +159,8 @@ class Keypair: self.load_from_file(filename) ## - # Create a RSA public/private key pair and store it inside the keypair object + # Create a RSA public/private key pair and store it inside the keypair + # object def create(self): self.key = OpenSSL.crypto.PKey() @@ -164,7 +175,8 @@ class Keypair: self.filename = filename ## - # Load the private key from a file. Implicity the private key includes the public key. + # Load the private key from a file. Implicity the private key includes the + # public key. def load_from_file(self, filename): self.filename = filename @@ -172,7 +184,8 @@ class Keypair: self.load_from_string(buffer) ## - # Load the private key from a string. Implicitly the private key includes the public key. + # Load the private key from a string. Implicitly the private key includes + # the public key. def load_from_string(self, string): import M2Crypto @@ -182,7 +195,8 @@ class Keypair: self.m2key = M2Crypto.EVP.load_key_string( string, functools.partial(glo_passphrase_callback, self, string)) else: - self.key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, string) + self.key = OpenSSL.crypto.load_privatekey( + OpenSSL.crypto.FILETYPE_PEM, string) self.m2key = M2Crypto.EVP.load_key_string(string) ## @@ -197,7 +211,8 @@ class Keypair: # create an m2 x509 cert m2name = M2Crypto.X509.X509_Name() - m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0) + m2name.add_entry_by_txt(field="CN", type=0x1001, + entry="junk", len=-1, loc=-1, set=0) m2x509 = M2Crypto.X509.X509() m2x509.set_pubkey(self.m2key) m2x509.set_serial_number(0) @@ -215,7 +230,8 @@ class Keypair: # convert the m2 x509 cert to a pyopenssl x509 m2pem = m2x509.as_pem() - pyx509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, m2pem) + pyx509 = OpenSSL.crypto.load_certificate( + OpenSSL.crypto.FILETYPE_PEM, m2pem) # get the pyopenssl pkey from the pyopenssl x509 self.key = pyx509.get_pubkey() @@ -283,16 +299,17 @@ class Keypair: # only informative def get_filename(self): - return getattr(self,'filename',None) + return getattr(self, 'filename', None) def dump(self, *args, **kwargs): print(self.dump_string(*args, **kwargs)) def dump_string(self): - result = "" + result = "" result += "KEYPAIR: pubkey={:>40}...".format(self.get_pubkey_string()) filename = self.get_filename() - if filename: result += "Filename {}\n".format(filename) + if filename: + result += "Filename {}\n".format(filename) return result ## @@ -307,6 +324,7 @@ class Keypair: # When saving a certificate to a file or a string, the caller can choose # whether to save the parent certificates as well. + class Certificate: digest = "sha256" @@ -314,7 +332,7 @@ class Certificate: # issuerKey = None # issuerSubject = None # parent = None - isCA = None # will be a boolean once set + isCA = None # will be a boolean once set separator = "-----parent-----" @@ -356,10 +374,10 @@ class Certificate: self.x509 = OpenSSL.crypto.X509() # FIXME: Use different serial #s self.x509.set_serial_number(3) - self.x509.gmtime_adj_notBefore(0) # 0 means now - self.x509.gmtime_adj_notAfter(lifeDays*60*60*24) # five years is default - self.x509.set_version(2) # x509v3 so it can have extensions - + self.x509.gmtime_adj_notBefore(0) # 0 means now + self.x509.gmtime_adj_notAfter( + lifeDays * 60 * 60 * 24) # five years is default + self.x509.set_version(2) # x509v3 so it can have extensions ## # Given a pyOpenSSL X509 object, store that object inside of this @@ -373,14 +391,15 @@ class Certificate: def load_from_string(self, string): # if it is a chain of multiple certs, then split off the first one and - # load it (support for the ---parent--- tag as well as normal chained certs) + # load it (support for the ---parent--- tag as well as normal chained + # certs) if string is None or string.strip() == "": logger.warn("Empty string in load_from_string") return string = string.strip() - + # If it's not in proper PEM format, wrap it if string.count('-----BEGIN CERTIFICATE') == 0: string = '-----BEGIN CERTIFICATE-----\n{}\n-----END CERTIFICATE-----'\ @@ -390,22 +409,24 @@ class Certificate: # such as the text of the certificate, skip the text beg = string.find('-----BEGIN CERTIFICATE') if beg > 0: - # skipping over non cert beginning + # skipping over non cert beginning string = string[beg:] parts = [] if string.count('-----BEGIN CERTIFICATE-----') > 1 and \ - string.count(Certificate.separator) == 0: - parts = string.split('-----END CERTIFICATE-----',1) + string.count(Certificate.separator) == 0: + parts = string.split('-----END CERTIFICATE-----', 1) parts[0] += '-----END CERTIFICATE-----' else: parts = string.split(Certificate.separator, 1) - self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, parts[0]) + self.x509 = OpenSSL.crypto.load_certificate( + OpenSSL.crypto.FILETYPE_PEM, parts[0]) if self.x509 is None: - logger.warn("Loaded from string but cert is None: {}".format(string)) + logger.warn( + "Loaded from string but cert is None: {}".format(string)) # if there are more certs, then create a parent and let the parent load # itself from the remainder of the string @@ -431,7 +452,10 @@ class Certificate: if self.x509 is None: logger.warn("None cert in certificate.save_to_string") return "" - string = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, self.x509) + string = OpenSSL.crypto.dump_certificate( + OpenSSL.crypto.FILETYPE_PEM, self.x509) + if PY3 and isinstance(string, bytes): + string = string.decode() if save_parents and self.parent: string = string + self.parent.save_to_string(save_parents) return string @@ -446,6 +470,8 @@ class Certificate: f = filep else: f = open(filename, 'w') + if PY3 and isinstance(string, bytes): + string = string.decode() f.write(string) f.close() self.filename = filename @@ -519,30 +545,44 @@ class Certificate: # let's try to make this a little more usable as is makes logs hairy # FIXME: Consider adding 'urn:publicid' and 'uuid' back for GENI? pretty_fields = ['email'] + def filter_chunk(self, chunk): for field in self.pretty_fields: if field in chunk: - return " "+chunk + return " " + chunk def pretty_cert(self): message = "[Cert." x = self.x509.get_subject() ou = getattr(x, "OU") - if ou: message += " OU: {}".format(ou) + if ou: + message += " OU: {}".format(ou) cn = getattr(x, "CN") - if cn: message += " CN: {}".format(cn) + if cn: + message += " CN: {}".format(cn) data = self.get_data(field='subjectAltName') if data: message += " SubjectAltName:" counter = 0 filtered = [self.filter_chunk(chunk) for chunk in data.split()] - message += " ".join( [f for f in filtered if f]) + message += " ".join([f for f in filtered if f]) omitted = len([f for f in filtered if not f]) if omitted: message += "..+{} omitted".format(omitted) message += "]" return message + def pretty_chain(self): + message = "{}".format(self.x509.get_subject()) + parent = self.parent + while parent: + message += " -> {}".format(parent.x509.get_subject()) + parent = parent.parent + return message + + def pretty_name(self): + return self.get_filename() or self.pretty_chain() + ## # Get the public key of the certificate. # @@ -585,8 +625,6 @@ class Certificate: else: self.add_extension('basicConstraints', 1, 'CA:FALSE') - - ## # Add an X509 extension to the certificate. Add_extension can only be called # once for a particular extension name, due to limitations in the underlying @@ -596,7 +634,6 @@ class Certificate: # @param value string containing value of the extension def add_extension(self, name, critical, value): - import M2Crypto oldExtVal = None try: oldExtVal = self.get_extension(name) @@ -671,7 +708,8 @@ class Certificate: return self.data[field] ## - # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer(). + # Sign the certificate using the issuer private key and issuer subject + # previous set with set_issuer(). def sign(self): logger.debug('certificate.sign') @@ -692,9 +730,17 @@ class Certificate: m2x509 = M2Crypto.X509.load_cert_string(self.save_to_string()) m2pubkey = pubkey.get_m2_pubkey() # verify it - # verify returns -1 or 0 on failure depending on how serious the - # error conditions are - return m2x509.verify(m2pubkey) == 1 + # https://www.openssl.org/docs/man1.1.0/crypto/X509_verify.html + # verify returns + # 1 if it checks out + # 0 if if does not + # -1 if it could not be checked 'for some reason' + m2result = m2x509.verify(m2pubkey) + result = m2result == 1 + if debug_verify_chain: + logger.debug("Certificate.verify: <- {} (m2={}) ({} x {})" + .format(result, m2result, self.pretty_cert(), m2pubkey)) + return result # XXX alternatively, if openssl has been patched, do the much simpler: # try: @@ -717,6 +763,7 @@ class Certificate: # @param cert certificate object def is_signed_by_cert(self, cert): + logger.debug("Certificate.is_signed_by_cert -> invoking verify") k = cert.get_pubkey() result = self.verify(k) return result @@ -756,11 +803,12 @@ class Certificate: # @param Trusted_certs is a list of certificates that are trusted. # - def verify_chain(self, trusted_certs = None): + def verify_chain(self, trusted_certs=None): # Verify a chain of certificates. Each certificate must be signed by # the public key contained in it's parent. The chain is recursed # until a certificate is found that is signed by a trusted root. + logger.debug("Certificate.verify_chain {}".format(self.pretty_name())) # verify expiration time if self.x509.has_expired(): if debug_verify_chain: @@ -769,42 +817,47 @@ class Certificate: raise CertExpired(self.pretty_cert(), "client cert") # if this cert is signed by a trusted_cert, then we are set - for trusted_cert in trusted_certs: + for i, trusted_cert in enumerate(trusted_certs, 1): + logger.debug("Certificate.verify_chain - trying trusted #{} : {}" + .format(i, trusted_cert.pretty_name())) if self.is_signed_by_cert(trusted_cert): # verify expiration of trusted_cert ? if not trusted_cert.x509.has_expired(): if debug_verify_chain: logger.debug("verify_chain: YES. Cert {} signed by trusted cert {}" - .format(self.pretty_cert(), trusted_cert.pretty_cert())) + .format(self.pretty_name(), trusted_cert.pretty_name())) return trusted_cert else: if debug_verify_chain: logger.debug("verify_chain: NO. Cert {} is signed by trusted_cert {}, " "but that signer is expired..." - .format(self.pretty_cert(),trusted_cert.pretty_cert())) + .format(self.pretty_cert(), trusted_cert.pretty_cert())) raise CertExpired("{} signer trusted_cert {}" - .format(self.pretty_cert(), trusted_cert.pretty_cert())) + .format(self.pretty_name(), trusted_cert.pretty_name())) + else: + logger.debug("verify_chain: not a direct descendant of a trusted root". + format(self.pretty_name(), trusted_cert)) # if there is no parent, then no way to verify the chain if not self.parent: if debug_verify_chain: logger.debug("verify_chain: NO. {} has no parent " "and issuer {} is not in {} trusted roots" - .format(self.pretty_cert(), self.get_issuer(), len(trusted_certs))) + .format(self.pretty_name(), self.get_issuer(), len(trusted_certs))) raise CertMissingParent("{}: Issuer {} is not one of the {} trusted roots, " "and cert has no parent." - .format(self.pretty_cert(), self.get_issuer(), len(trusted_certs))) + .format(self.pretty_name(), self.get_issuer(), len(trusted_certs))) # if it wasn't signed by the parent... if not self.is_signed_by_cert(self.parent): if debug_verify_chain: - logger.debug("verify_chain: NO. {} is not signed by parent {}, but by {}" - .format(self.pretty_cert(), - self.parent.pretty_cert(), - self.get_issuer())) + logger.debug("verify_chain: NO. {} is not signed by parent {}" + .format(self.pretty_name(), + self.parent.pretty_name())) + self.save_to_file("/tmp/xxx-capture.pem", save_parents=True) raise CertNotSignedByParent("{}: Parent {}, issuer {}" - .format(self.pretty_cert(), - self.parent.pretty_cert(), + .format(self.pretty_name(), + self.parent.pretty_name(), self.get_issuer())) # Confirm that the parent is a CA. Only CAs can be trusted as @@ -815,19 +868,19 @@ class Certificate: # extension and hope there are no other basicConstraints if not self.parent.isCA and not (self.parent.get_extension('basicConstraints') == 'CA:TRUE'): logger.warn("verify_chain: cert {}'s parent {} is not a CA" - .format(self.pretty_cert(), self.parent.pretty_cert())) + .format(self.pretty_name(), self.parent.pretty_name())) raise CertNotSignedByParent("{}: Parent {} not a CA" - .format(self.pretty_cert(), self.parent.pretty_cert())) + .format(self.pretty_name(), self.parent.pretty_name())) # if the parent isn't verified... if debug_verify_chain: logger.debug("verify_chain: .. {}, -> verifying parent {}" - .format(self.pretty_cert(),self.parent.pretty_cert())) + .format(self.pretty_name(),self.parent.pretty_name())) self.parent.verify_chain(trusted_certs) return - ### more introspection + # more introspection def get_extensions(self): import M2Crypto # pyOpenSSL does not have a way to get extensions @@ -837,7 +890,8 @@ class Certificate: logger.debug("X509 had {} extensions".format(nb_extensions)) for i in range(nb_extensions): ext = m2x509.get_ext_at(i) - triples.append( (ext.get_name(), ext.get_value(), ext.get_critical(),) ) + triples.append( + (ext.get_name(), ext.get_value(), ext.get_critical(),)) return triples def get_data_names(self): @@ -846,12 +900,12 @@ class Certificate: def get_all_datas(self): triples = self.get_extensions() for name in self.get_data_names(): - triples.append( (name,self.get_data(name),'data',) ) + triples.append((name, self.get_data(name), 'data',)) return triples # only informative def get_filename(self): - return getattr(self,'filename',None) + return getattr(self, 'filename', None) def dump(self, *args, **kwargs): print(self.dump_string(*args, **kwargs))