X-Git-Url: http://git.onelab.eu/?a=blobdiff_plain;f=sfa%2Ftrust%2Fcertificate.py;h=f0a2d71c916ff0f625ee0fe28734d80884088e13;hb=57b6a99255d4a88be9c0f910f8524677e34ff4bc;hp=98588a976d49a38cb3836faa28754ac8cbff8273;hpb=dbce495b6f2e7d8dccbfb18c5507907d784c143b;p=sfa.git diff --git a/sfa/trust/certificate.py b/sfa/trust/certificate.py index 98588a97..f0a2d71c 100644 --- a/sfa/trust/certificate.py +++ b/sfa/trust/certificate.py @@ -35,19 +35,52 @@ ## # +import functools import os import tempfile import base64 -import traceback from tempfile import mkstemp from OpenSSL import crypto import M2Crypto from M2Crypto import X509 -from sfa.util.sfalogging import sfa_logger -from sfa.util.namespace import urn_to_hrn -from sfa.util.faults import * +from sfa.util.faults import CertExpired, CertMissingParent, CertNotSignedByParent +from sfa.util.sfalogging import logger + +glo_passphrase_callback = None + +## +# A global callback msy be implemented for requesting passphrases from the +# user. The function will be called with three arguments: +# +# keypair_obj: the keypair object that is calling the passphrase +# string: the string containing the private key that's being loaded +# x: unknown, appears to be 0, comes from pyOpenSSL and/or m2crypto +# +# The callback should return a string containing the passphrase. + +def set_passphrase_callback(callback_func): + global glo_passphrase_callback + + glo_passphrase_callback = callback_func + +## +# Sets a fixed passphrase. + +def set_passphrase(passphrase): + set_passphrase_callback( lambda k,s,x: passphrase ) + +## +# Check to see if a passphrase works for a particular private key string. +# Intended to be used by passphrase callbacks for input validation. + +def test_passphrase(string, passphrase): + try: + crypto.load_privatekey(crypto.FILETYPE_PEM, string, (lambda x: passphrase)) + return True + except: + return False def convert_public_key(key): keyconvert_path = "/usr/bin/keyconvert.py" @@ -76,7 +109,7 @@ def convert_public_key(key): try: k.load_pubkey_from_file(ssl_fn) except: - sfa_logger().log_exc("convert_public_key caught exception") + logger.log_exc("convert_public_key caught exception") k = None # remove the temporary files @@ -128,16 +161,20 @@ class Keypair: # Load the private key from a file. Implicity the private key includes the public key. def load_from_file(self, filename): + self.filename=filename buffer = open(filename, 'r').read() self.load_from_string(buffer) - self.filename=filename ## # Load the private key from a string. Implicitly the private key includes the public key. def load_from_string(self, string): - self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string) - self.m2key = M2Crypto.EVP.load_key_string(string) + if glo_passphrase_callback: + self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string, functools.partial(glo_passphrase_callback, self, string) ) + self.m2key = M2Crypto.EVP.load_key_string(string, functools.partial(glo_passphrase_callback, self, string) ) + else: + self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string) + self.m2key = M2Crypto.EVP.load_key_string(string) ## # Load the public key from a string. No private key is loaded. @@ -160,6 +197,9 @@ class Keypair: ASN1.set_time(500) m2x509.set_not_before(ASN1) m2x509.set_not_after(ASN1) + # x509v3 so it can have extensions + # prob not necc since this cert itself is junk but still... + m2x509.set_version(2) junk_key = Keypair(create=True) m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1") @@ -242,7 +282,7 @@ class Keypair: filename=self.get_filename() if filename: result += "Filename %s\n"%filename return result - + ## # The certificate class implements a general purpose X509 certificate, making # use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds @@ -262,22 +302,25 @@ class Certificate: issuerKey = None issuerSubject = None parent = None + isCA = None # will be a boolean once set separator="-----parent-----" ## # Create a certificate object. # + # @param lifeDays life of cert in days - default is 1825==5 years # @param create If create==True, then also create a blank X509 certificate. # @param subject If subject!=None, then create a blank certificate and set # it's subject name. # @param string If string!=None, load the certficate from the string. # @param filename If filename!=None, load the certficiate from the file. + # @param isCA If !=None, set whether this cert is for a CA - def __init__(self, create=False, subject=None, string=None, filename=None, intermediate=None): + def __init__(self, lifeDays=1825, create=False, subject=None, string=None, filename=None, isCA=None): self.data = {} if create or subject: - self.create() + self.create(lifeDays) if subject: self.set_subject(subject) if string: @@ -285,17 +328,20 @@ class Certificate: if filename: self.load_from_file(filename) - if intermediate: - self.set_intermediate_ca(intermediate) + # Set the CA bit if a value was supplied + if isCA != None: + self.set_is_ca(isCA) - ## # Create a blank X509 certificate and store it in this object. - def create(self): + def create(self, lifeDays=1825): self.cert = crypto.X509() + # FIXME: Use different serial #s self.cert.set_serial_number(3) - self.cert.gmtime_adj_notBefore(0) - self.cert.gmtime_adj_notAfter(60*60*24*365*5) # five years + self.cert.gmtime_adj_notBefore(0) # 0 means now + self.cert.gmtime_adj_notAfter(lifeDays*60*60*24) # five years is default + self.cert.set_version(2) # x509v3 so it can have extensions + ## # Given a pyOpenSSL X509 object, store that object inside of this @@ -312,11 +358,18 @@ class Certificate: # load it (support for the ---parent--- tag as well as normal chained certs) string = string.strip() - - - if not string.startswith('-----'): + + # If it's not in proper PEM format, wrap it + if string.count('-----BEGIN CERTIFICATE') == 0: string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string + # If there is a PEM cert in there, but there is some other text first + # such as the text of the certificate, skip the text + beg = string.find('-----BEGIN CERTIFICATE') + if beg > 0: + # skipping over non cert beginning + string = string[beg:] + parts = [] if string.count('-----BEGIN CERTIFICATE-----') > 1 and \ @@ -424,6 +477,7 @@ class Certificate: else: setattr(subj, "CN", name) self.cert.set_subject(subj) + ## # Get the subject name of the certificate @@ -431,6 +485,13 @@ class Certificate: x = self.cert.get_subject() return getattr(x, which) + ## + # Get a pretty-print subject name of the certificate + + def get_printable_subject(self): + x = self.cert.get_subject() + return "[ OU: %s, CN: %s, SubjectAltName: %s ]" % (getattr(x, "OU"), getattr(x, "CN"), self.get_data()) + ## # Get the public key of the certificate. # @@ -452,9 +513,24 @@ class Certificate: return pkey def set_intermediate_ca(self, val): - self.intermediate = val + return self.set_is_ca(val) + + # Set whether this cert is for a CA. All signers and only signers should be CAs. + # The local member starts unset, letting us check that you only set it once + # @param val Boolean indicating whether this cert is for a CA + def set_is_ca(self, val): + if val is None: + return + + if self.isCA != None: + # Can't double set properties + raise Exception, "Cannot set basicConstraints CA:?? more than once. Was %s, trying to set as %s" % (self.isCA, val) + + self.isCA = val if val: self.add_extension('basicConstraints', 1, 'CA:TRUE') + else: + self.add_extension('basicConstraints', 1, 'CA:FALSE') @@ -467,6 +543,25 @@ class Certificate: # @param value string containing value of the extension def add_extension(self, name, critical, value): + oldExtVal = None + try: + oldExtVal = self.get_extension(name) + except: + # M2Crypto LookupError when the extension isn't there (yet) + pass + + # This code limits you from adding the extension with the same value + # The method comment says you shouldn't do this with the same name + # But actually it (m2crypto) appears to allow you to do this. + if oldExtVal and oldExtVal == value: + # don't add this extension again + # just do nothing as here + return + # FIXME: What if they are trying to set with a different value? + # Is this ever OK? Or should we raise an exception? +# elif oldExtVal: +# raise "Cannot add extension %s which had val %s with new val %s" % (name, oldExtVal, value) + ext = crypto.X509Extension (name, critical, value) self.cert.add_extensions([ext]) @@ -478,7 +573,7 @@ class Certificate: # pyOpenSSL does not have a way to get extensions m2x509 = X509.load_cert_string(self.save_to_string()) value = m2x509.get_ext(name).get_value() - + return value ## @@ -513,7 +608,7 @@ class Certificate: # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer(). def sign(self): - sfa_logger().debug('certificate.sign') + logger.debug('certificate.sign') assert self.cert != None assert self.issuerSubject != None assert self.issuerKey != None @@ -553,7 +648,6 @@ class Certificate: # @param cert certificate object def is_signed_by_cert(self, cert): - print 'is_signed_by_cert' k = cert.get_pubkey() result = self.verify(k) return result @@ -588,6 +682,7 @@ class Certificate: # child. If a parent did not sign a child, then an exception is thrown. If # the bottom of the recursion is reached and the certificate does not match # a trusted root, then an exception is thrown. + # Also require that parents are CAs. # # @param Trusted_certs is a list of certificates that are trusted. # @@ -599,34 +694,53 @@ class Certificate: # verify expiration time if self.cert.has_expired(): - sfa_logger().debug("verify_chain: NO our certificate has expired") - raise CertExpired(self.get_subject(), "client cert") - + logger.debug("verify_chain: NO, Certificate %s has expired" % self.get_printable_subject()) + raise CertExpired(self.get_printable_subject(), "client cert") + # if this cert is signed by a trusted_cert, then we are set for trusted_cert in trusted_certs: if self.is_signed_by_cert(trusted_cert): # verify expiration of trusted_cert ? if not trusted_cert.cert.has_expired(): - sfa_logger().debug("verify_chain: YES cert %s signed by trusted cert %s"%( - self.get_subject(), trusted_cert.get_subject())) + logger.debug("verify_chain: YES. Cert %s signed by trusted cert %s"%( + self.get_printable_subject(), trusted_cert.get_printable_subject())) return trusted_cert else: - sfa_logger().debug("verify_chain: NO cert %s is signed by trusted_cert %s, but this is expired..."%( - self.get_subject(),trusted_cert.get_subject())) - raise CertExpired(self.get_subject(),"trusted_cert %s"%trusted_cert.get_subject()) + logger.debug("verify_chain: NO. Cert %s is signed by trusted_cert %s, but that signer is expired..."%( + self.get_printable_subject(),trusted_cert.get_printable_subject())) + raise CertExpired(self.get_printable_subject()," signer trusted_cert %s"%trusted_cert.get_printable_subject()) # if there is no parent, then no way to verify the chain if not self.parent: - sfa_logger().debug("verify_chain: NO %s has no parent and is not in trusted roots"%self.get_subject()) - raise CertMissingParent(self.get_subject()) + logger.debug("verify_chain: NO. %s has no parent and issuer %s is not in %d trusted roots"%(self.get_printable_subject(), self.get_issuer(), len(trusted_certs))) + raise CertMissingParent(self.get_printable_subject() + ": Issuer %s not trusted by any of %d trusted roots, and cert has no parent." % (self.get_issuer(), len(trusted_certs))) # if it wasn't signed by the parent... if not self.is_signed_by_cert(self.parent): - sfa_logger().debug("verify_chain: NO %s is not signed by parent"%self.get_subject()) - return CertNotSignedByParent(self.get_subject()) + logger.debug("verify_chain: NO. %s is not signed by parent %s, but by %s"%\ + (self.get_printable_subject(), + self.parent.get_printable_subject(), + self.get_issuer())) + raise CertNotSignedByParent("%s: Parent %s, issuer %s"\ + % (self.get_printable_subject(), + self.parent.get_printable_subject(), + self.get_issuer())) + + # Confirm that the parent is a CA. Only CAs can be trusted as + # signers. + # Note that trusted roots are not parents, so don't need to be + # CAs. + # Ugly - cert objects aren't parsed so we need to read the + # extension and hope there are no other basicConstraints + if not self.parent.isCA and not (self.parent.get_extension('basicConstraints') == 'CA:TRUE'): + logger.warn("verify_chain: cert %s's parent %s is not a CA" % \ + (self.get_printable_subject(), self.parent.get_printable_subject())) + raise CertNotSignedByParent("%s: Parent %s not a CA" % (self.get_printable_subject(), + self.parent.get_printable_subject())) # if the parent isn't verified... - sfa_logger().debug("verify_chain: .. %s, -> verifying parent %s",self.get_subject(),self.parent.get_subject()) + logger.debug("verify_chain: .. %s, -> verifying parent %s"%\ + (self.get_printable_subject(),self.parent.get_printable_subject())) self.parent.verify_chain(trusted_certs) return @@ -637,7 +751,7 @@ class Certificate: triples=[] m2x509 = X509.load_cert_string(self.save_to_string()) nb_extensions=m2x509.get_ext_count() - sfa_logger().debug("X509 had %d extensions"%nb_extensions) + logger.debug("X509 had %d extensions"%nb_extensions) for i in range(nb_extensions): ext=m2x509.get_ext_at(i) triples.append( (ext.get_name(), ext.get_value(), ext.get_critical(),) ) @@ -648,7 +762,7 @@ class Certificate: def get_all_datas (self): triples=self.get_extensions() - for name in self.get_data_names(): + for name in self.get_data_names(): triples.append( (name,self.get_data(name),'data',) ) return triples @@ -661,7 +775,7 @@ class Certificate: def dump_string (self,show_extensions=False): result = "" - result += "CERTIFICATE for %s\n"%self.get_subject() + result += "CERTIFICATE for %s\n"%self.get_printable_subject() result += "Issued by %s\n"%self.get_issuer() filename=self.get_filename() if filename: result += "Filename %s\n"%filename