X-Git-Url: http://git.onelab.eu/?a=blobdiff_plain;f=sfa%2Ftrust%2Fcertificate.py;h=959b76380c288184ecce5fb7b2ec9b4f21729424;hb=db091e73c33c373b7f6c2c96bd2caf6a2acf0178;hp=f12c86b8ba60d116bf51e435701d040f8708480f;hpb=ea995a055eba04aedff577e86652abaaa5e881aa;p=sfa.git diff --git a/sfa/trust/certificate.py b/sfa/trust/certificate.py index f12c86b8..959b7638 100644 --- a/sfa/trust/certificate.py +++ b/sfa/trust/certificate.py @@ -34,10 +34,8 @@ # This module exports two classes: Keypair and Certificate. ## # -### $Id$ -### $URL$ -# +import functools import os import tempfile import base64 @@ -48,9 +46,44 @@ from OpenSSL import crypto import M2Crypto from M2Crypto import X509 -from sfa.util.sfalogging import sfa_logger -from sfa.util.namespace import urn_to_hrn +from sfa.util.sfalogging import logger +from sfa.util.xrn import urn_to_hrn from sfa.util.faults import * +from sfa.util.sfalogging import logger + +glo_passphrase_callback = None + +## +# A global callback msy be implemented for requesting passphrases from the +# user. The function will be called with three arguments: +# +# keypair_obj: the keypair object that is calling the passphrase +# string: the string containing the private key that's being loaded +# x: unknown, appears to be 0, comes from pyOpenSSL and/or m2crypto +# +# The callback should return a string containing the passphrase. + +def set_passphrase_callback(callback_func): + global glo_passphrase_callback + + glo_passphrase_callback = callback_func + +## +# Sets a fixed passphrase. + +def set_passphrase(passphrase): + set_passphrase_callback( lambda k,s,x: passphrase ) + +## +# Check to see if a passphrase works for a particular private key string. +# Intended to be used by passphrase callbacks for input validation. + +def test_passphrase(string, passphrase): + try: + crypto.load_privatekey(crypto.FILETYPE_PEM, string, (lambda x: passphrase)) + return True + except: + return False def convert_public_key(key): keyconvert_path = "/usr/bin/keyconvert.py" @@ -79,7 +112,7 @@ def convert_public_key(key): try: k.load_pubkey_from_file(ssl_fn) except: - sfa_logger().log_exc("convert_public_key caught exception") + logger.log_exc("convert_public_key caught exception") k = None # remove the temporary files @@ -125,11 +158,13 @@ class Keypair: def save_to_file(self, filename): open(filename, 'w').write(self.as_pem()) + self.filename=filename ## # Load the private key from a file. Implicity the private key includes the public key. def load_from_file(self, filename): + self.filename=filename buffer = open(filename, 'r').read() self.load_from_string(buffer) @@ -137,8 +172,12 @@ class Keypair: # Load the private key from a string. Implicitly the private key includes the public key. def load_from_string(self, string): - self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string) - self.m2key = M2Crypto.EVP.load_key_string(string) + if glo_passphrase_callback: + self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string, functools.partial(glo_passphrase_callback, self, string) ) + self.m2key = M2Crypto.EVP.load_key_string(string, functools.partial(glo_passphrase_callback, self, string) ) + else: + self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string) + self.m2key = M2Crypto.EVP.load_key_string(string) ## # Load the public key from a string. No private key is loaded. @@ -161,6 +200,9 @@ class Keypair: ASN1.set_time(500) m2x509.set_not_before(ASN1) m2x509.set_not_after(ASN1) + # x509v3 so it can have extensions + # prob not necc since this cert itself is junk but still... + m2x509.set_version(2) junk_key = Keypair(create=True) m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1") @@ -170,6 +212,7 @@ class Keypair: # get the pyopenssl pkey from the pyopenssl x509 self.key = pyx509.get_pubkey() + self.filename=filename ## # Load the public key from a string. No private key is loaded. @@ -208,7 +251,6 @@ class Keypair: def get_openssl_pkey(self): return self.key - ## # Given another Keypair object, return TRUE if the two keys are the same. @@ -230,6 +272,20 @@ class Keypair: def compute_hash(self, value): return self.sign_string(str(value)) + # only informative + def get_filename(self): + return getattr(self,'filename',None) + + def dump (self, *args, **kwargs): + print self.dump_string(*args, **kwargs) + + def dump_string (self): + result="" + result += "KEYPAIR: pubkey=%40s..."%self.get_pubkey_string() + filename=self.get_filename() + if filename: result += "Filename %s\n"%filename + return result + ## # The certificate class implements a general purpose X509 certificate, making # use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds @@ -275,7 +331,6 @@ class Certificate: if intermediate: self.set_intermediate_ca(intermediate) - ## # Create a blank X509 certificate and store it in this object. def create(self): @@ -283,6 +338,8 @@ class Certificate: self.cert.set_serial_number(3) self.cert.gmtime_adj_notBefore(0) self.cert.gmtime_adj_notAfter(60*60*24*365*5) # five years + self.cert.set_version(2) # x509v3 so it can have extensions + ## # Given a pyOpenSSL X509 object, store that object inside of this @@ -299,11 +356,18 @@ class Certificate: # load it (support for the ---parent--- tag as well as normal chained certs) string = string.strip() - - - if not string.startswith('-----'): + + # If it's not in proper PEM format, wrap it + if string.count('-----BEGIN CERTIFICATE') == 0: string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string + # If there is a PEM cert in there, but there is some other text first + # such as the text of the certificate, skip the text + beg = string.find('-----BEGIN CERTIFICATE') + if beg > 0: + # skipping over non cert beginning + string = string[beg:] + parts = [] if string.count('-----BEGIN CERTIFICATE-----') > 1 and \ @@ -328,6 +392,7 @@ class Certificate: file = open(filename) string = file.read() self.load_from_string(string) + self.filename=filename ## # Save the certificate to a string. @@ -352,6 +417,7 @@ class Certificate: f = open(filename, 'w') f.write(string) f.close() + self.filename=filename ## # Save the certificate to a random file in /tmp/ @@ -498,7 +564,7 @@ class Certificate: # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer(). def sign(self): - sfa_logger().debug('certificate.sign') + logger.debug('certificate.sign') assert self.cert != None assert self.issuerSubject != None assert self.issuerKey != None @@ -538,7 +604,6 @@ class Certificate: # @param cert certificate object def is_signed_by_cert(self, cert): - print 'is_signed_by_cert' k = cert.get_pubkey() result = self.verify(k) return result @@ -584,7 +649,7 @@ class Certificate: # verify expiration time if self.cert.has_expired(): - sfa_logger().debug("verify_chain: NO our certificate has expired") + logger.debug("verify_chain: NO our certificate has expired") raise CertExpired(self.get_subject(), "client cert") # if this cert is signed by a trusted_cert, then we are set @@ -592,26 +657,26 @@ class Certificate: if self.is_signed_by_cert(trusted_cert): # verify expiration of trusted_cert ? if not trusted_cert.cert.has_expired(): - sfa_logger().debug("verify_chain: YES cert %s signed by trusted cert %s"%( + logger.debug("verify_chain: YES cert %s signed by trusted cert %s"%( self.get_subject(), trusted_cert.get_subject())) return trusted_cert else: - sfa_logger().debug("verify_chain: NO cert %s is signed by trusted_cert %s, but this is expired..."%( + logger.debug("verify_chain: NO cert %s is signed by trusted_cert %s, but this is expired..."%( self.get_subject(),trusted_cert.get_subject())) raise CertExpired(self.get_subject(),"trusted_cert %s"%trusted_cert.get_subject()) # if there is no parent, then no way to verify the chain if not self.parent: - sfa_logger().debug("verify_chain: NO %s has no parent and is not in trusted roots"%self.get_subject()) + logger.debug("verify_chain: NO %s has no parent and is not in trusted roots"%self.get_subject()) raise CertMissingParent(self.get_subject()) # if it wasn't signed by the parent... if not self.is_signed_by_cert(self.parent): - sfa_logger().debug("verify_chain: NO %s is not signed by parent"%self.get_subject()) + logger.debug("verify_chain: NO %s is not signed by parent"%self.get_subject()) return CertNotSignedByParent(self.get_subject()) # if the parent isn't verified... - sfa_logger().debug("verify_chain: .. %s, -> verifying parent %s",self.get_subject(),self.parent.get_subject()) + logger.debug("verify_chain: .. %s, -> verifying parent %s"%(self.get_subject(),self.parent.get_subject())) self.parent.verify_chain(trusted_certs) return @@ -622,7 +687,7 @@ class Certificate: triples=[] m2x509 = X509.load_cert_string(self.save_to_string()) nb_extensions=m2x509.get_ext_count() - sfa_logger().debug("X509 had %d extensions"%nb_extensions) + logger.debug("X509 had %d extensions"%nb_extensions) for i in range(nb_extensions): ext=m2x509.get_ext_at(i) triples.append( (ext.get_name(), ext.get_value(), ext.get_critical(),) ) @@ -637,16 +702,25 @@ class Certificate: triples.append( (name,self.get_data(name),'data',) ) return triples + # only informative + def get_filename(self): + return getattr(self,'filename',None) + def dump (self, *args, **kwargs): print self.dump_string(*args, **kwargs) - def dump_string (self): + def dump_string (self,show_extensions=False): result = "" - result += "Certificate for %s\n"%self.get_subject() + result += "CERTIFICATE for %s\n"%self.get_subject() result += "Issued by %s\n"%self.get_issuer() - for (n,v,c) in self.get_all_datas(): - if c=='data': - result += " data: %s=%s\n"%(n,v) - else: - result += " ext: %s=%s (%s)\n"%(n,v,c) + filename=self.get_filename() + if filename: result += "Filename %s\n"%filename + if show_extensions: + all_datas=self.get_all_datas() + result += " has %d extensions/data attached"%len(all_datas) + for (n,v,c) in all_datas: + if c=='data': + result += " data: %s=%s\n"%(n,v) + else: + result += " ext: %s (crit=%s)=<<<%s>>>\n"%(n,c,v) return result