really fixed the redundant logging issue this time.
[sfa.git] / sfa / trust / certificate.py
index f12c86b..959b763 100644 (file)
 # This module exports two classes: Keypair and Certificate.
 ##
 #
-### $Id$
-### $URL$
-#
 
+import functools
 import os
 import tempfile
 import base64
@@ -48,9 +46,44 @@ from OpenSSL import crypto
 import M2Crypto
 from M2Crypto import X509
 
-from sfa.util.sfalogging import sfa_logger
-from sfa.util.namespace import urn_to_hrn
+from sfa.util.sfalogging import logger
+from sfa.util.xrn import urn_to_hrn
 from sfa.util.faults import *
+from sfa.util.sfalogging import logger
+
+glo_passphrase_callback = None
+
+##
+# A global callback msy be implemented for requesting passphrases from the
+# user. The function will be called with three arguments:
+#
+#    keypair_obj: the keypair object that is calling the passphrase
+#    string: the string containing the private key that's being loaded
+#    x: unknown, appears to be 0, comes from pyOpenSSL and/or m2crypto
+#
+# The callback should return a string containing the passphrase.
+
+def set_passphrase_callback(callback_func):
+    global glo_passphrase_callback
+
+    glo_passphrase_callback = callback_func
+
+##
+# Sets a fixed passphrase.
+
+def set_passphrase(passphrase):
+    set_passphrase_callback( lambda k,s,x: passphrase )
+
+##
+# Check to see if a passphrase works for a particular private key string.
+# Intended to be used by passphrase callbacks for input validation.
+
+def test_passphrase(string, passphrase):
+    try:
+        crypto.load_privatekey(crypto.FILETYPE_PEM, string, (lambda x: passphrase))
+        return True
+    except:
+        return False
 
 def convert_public_key(key):
     keyconvert_path = "/usr/bin/keyconvert.py"
@@ -79,7 +112,7 @@ def convert_public_key(key):
     try:
         k.load_pubkey_from_file(ssl_fn)
     except:
-        sfa_logger().log_exc("convert_public_key caught exception")
+        logger.log_exc("convert_public_key caught exception")
         k = None
 
     # remove the temporary files
@@ -125,11 +158,13 @@ class Keypair:
 
     def save_to_file(self, filename):
         open(filename, 'w').write(self.as_pem())
+        self.filename=filename
 
     ##
     # Load the private key from a file. Implicity the private key includes the public key.
 
     def load_from_file(self, filename):
+        self.filename=filename
         buffer = open(filename, 'r').read()
         self.load_from_string(buffer)
 
@@ -137,8 +172,12 @@ class Keypair:
     # Load the private key from a string. Implicitly the private key includes the public key.
 
     def load_from_string(self, string):
-        self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
-        self.m2key = M2Crypto.EVP.load_key_string(string)
+        if glo_passphrase_callback:
+            self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string, functools.partial(glo_passphrase_callback, self, string) )
+            self.m2key = M2Crypto.EVP.load_key_string(string, functools.partial(glo_passphrase_callback, self, string) )
+        else:
+            self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
+            self.m2key = M2Crypto.EVP.load_key_string(string)
 
     ##
     #  Load the public key from a string. No private key is loaded.
@@ -161,6 +200,9 @@ class Keypair:
         ASN1.set_time(500)
         m2x509.set_not_before(ASN1)
         m2x509.set_not_after(ASN1)
+        # x509v3 so it can have extensions
+        # prob not necc since this cert itself is junk but still...
+        m2x509.set_version(2)
         junk_key = Keypair(create=True)
         m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")
 
@@ -170,6 +212,7 @@ class Keypair:
 
         # get the pyopenssl pkey from the pyopenssl x509
         self.key = pyx509.get_pubkey()
+        self.filename=filename
 
     ##
     # Load the public key from a string. No private key is loaded.
@@ -208,7 +251,6 @@ class Keypair:
     def get_openssl_pkey(self):
         return self.key
 
-
     ##
     # Given another Keypair object, return TRUE if the two keys are the same.
 
@@ -230,6 +272,20 @@ class Keypair:
     def compute_hash(self, value):
         return self.sign_string(str(value))
 
+    # only informative
+    def get_filename(self):
+        return getattr(self,'filename',None)
+
+    def dump (self, *args, **kwargs):
+        print self.dump_string(*args, **kwargs)
+
+    def dump_string (self):
+        result=""
+        result += "KEYPAIR: pubkey=%40s..."%self.get_pubkey_string()
+        filename=self.get_filename()
+        if filename: result += "Filename %s\n"%filename
+        return result
+    
 ##
 # The certificate class implements a general purpose X509 certificate, making
 # use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds
@@ -275,7 +331,6 @@ class Certificate:
         if intermediate:
             self.set_intermediate_ca(intermediate)
 
-    ##
     # Create a blank X509 certificate and store it in this object.
 
     def create(self):
@@ -283,6 +338,8 @@ class Certificate:
         self.cert.set_serial_number(3)
         self.cert.gmtime_adj_notBefore(0)
         self.cert.gmtime_adj_notAfter(60*60*24*365*5) # five years
+        self.cert.set_version(2) # x509v3 so it can have extensions        
+
 
     ##
     # Given a pyOpenSSL X509 object, store that object inside of this
@@ -299,11 +356,18 @@ class Certificate:
         # load it (support for the ---parent--- tag as well as normal chained certs)
 
         string = string.strip()
-
-
-        if not string.startswith('-----'):
+        
+        # If it's not in proper PEM format, wrap it
+        if string.count('-----BEGIN CERTIFICATE') == 0:
             string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string
 
+        # If there is a PEM cert in there, but there is some other text first
+        # such as the text of the certificate, skip the text
+        beg = string.find('-----BEGIN CERTIFICATE')
+        if beg > 0:
+            # skipping over non cert beginning                                                                                                              
+            string = string[beg:]
+
         parts = []
 
         if string.count('-----BEGIN CERTIFICATE-----') > 1 and \
@@ -328,6 +392,7 @@ class Certificate:
         file = open(filename)
         string = file.read()
         self.load_from_string(string)
+        self.filename=filename
 
     ##
     # Save the certificate to a string.
@@ -352,6 +417,7 @@ class Certificate:
             f = open(filename, 'w')
         f.write(string)
         f.close()
+        self.filename=filename
 
     ##
     # Save the certificate to a random file in /tmp/
@@ -498,7 +564,7 @@ class Certificate:
     # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer().
 
     def sign(self):
-        sfa_logger().debug('certificate.sign')
+        logger.debug('certificate.sign')
         assert self.cert != None
         assert self.issuerSubject != None
         assert self.issuerKey != None
@@ -538,7 +604,6 @@ class Certificate:
     # @param cert certificate object
 
     def is_signed_by_cert(self, cert):
-        print 'is_signed_by_cert'
         k = cert.get_pubkey()
         result = self.verify(k)
         return result
@@ -584,7 +649,7 @@ class Certificate:
 
         # verify expiration time
         if self.cert.has_expired():
-            sfa_logger().debug("verify_chain: NO our certificate has expired")
+            logger.debug("verify_chain: NO our certificate has expired")
             raise CertExpired(self.get_subject(), "client cert")   
         
         # if this cert is signed by a trusted_cert, then we are set
@@ -592,26 +657,26 @@ class Certificate:
             if self.is_signed_by_cert(trusted_cert):
                 # verify expiration of trusted_cert ?
                 if not trusted_cert.cert.has_expired():
-                    sfa_logger().debug("verify_chain: YES cert %s signed by trusted cert %s"%(
+                    logger.debug("verify_chain: YES cert %s signed by trusted cert %s"%(
                             self.get_subject(), trusted_cert.get_subject()))
                     return trusted_cert
                 else:
-                    sfa_logger().debug("verify_chain: NO cert %s is signed by trusted_cert %s, but this is expired..."%(
+                    logger.debug("verify_chain: NO cert %s is signed by trusted_cert %s, but this is expired..."%(
                             self.get_subject(),trusted_cert.get_subject()))
                     raise CertExpired(self.get_subject(),"trusted_cert %s"%trusted_cert.get_subject())
 
         # if there is no parent, then no way to verify the chain
         if not self.parent:
-            sfa_logger().debug("verify_chain: NO %s has no parent and is not in trusted roots"%self.get_subject())
+            logger.debug("verify_chain: NO %s has no parent and is not in trusted roots"%self.get_subject())
             raise CertMissingParent(self.get_subject())
 
         # if it wasn't signed by the parent...
         if not self.is_signed_by_cert(self.parent):
-            sfa_logger().debug("verify_chain: NO %s is not signed by parent"%self.get_subject())
+            logger.debug("verify_chain: NO %s is not signed by parent"%self.get_subject())
             return CertNotSignedByParent(self.get_subject())
 
         # if the parent isn't verified...
-        sfa_logger().debug("verify_chain: .. %s, -> verifying parent %s",self.get_subject(),self.parent.get_subject())
+        logger.debug("verify_chain: .. %s, -> verifying parent %s"%(self.get_subject(),self.parent.get_subject()))
         self.parent.verify_chain(trusted_certs)
 
         return
@@ -622,7 +687,7 @@ class Certificate:
         triples=[]
         m2x509 = X509.load_cert_string(self.save_to_string())
         nb_extensions=m2x509.get_ext_count()
-        sfa_logger().debug("X509 had %d extensions"%nb_extensions)
+        logger.debug("X509 had %d extensions"%nb_extensions)
         for i in range(nb_extensions):
             ext=m2x509.get_ext_at(i)
             triples.append( (ext.get_name(), ext.get_value(), ext.get_critical(),) )
@@ -637,16 +702,25 @@ class Certificate:
             triples.append( (name,self.get_data(name),'data',) )
         return triples
 
+    # only informative
+    def get_filename(self):
+        return getattr(self,'filename',None)
+
     def dump (self, *args, **kwargs):
         print self.dump_string(*args, **kwargs)
 
-    def dump_string (self):
+    def dump_string (self,show_extensions=False):
         result = ""
-        result += "Certificate for %s\n"%self.get_subject()
+        result += "CERTIFICATE for %s\n"%self.get_subject()
         result += "Issued by %s\n"%self.get_issuer()
-        for (n,v,c) in self.get_all_datas():
-            if c=='data':
-                result += "   data: %s=%s\n"%(n,v)
-            else:
-                result += "    ext: %s=%s (%s)\n"%(n,v,c)
+        filename=self.get_filename()
+        if filename: result += "Filename %s\n"%filename
+        if show_extensions:
+            all_datas=self.get_all_datas()
+            result += " has %d extensions/data attached"%len(all_datas)
+            for (n,v,c) in all_datas:
+                if c=='data':
+                    result += "   data: %s=%s\n"%(n,v)
+                else:
+                    result += "    ext: %s (crit=%s)=<<<%s>>>\n"%(n,c,v)
         return result