Merge branch 'master' of ssh://git.planet-lab.org/git/sfa
[sfa.git] / sfa / trust / certificate.py
index 98588a9..9a2e862 100644 (file)
@@ -35,6 +35,7 @@
 ##
 #
 
+import functools
 import os
 import tempfile
 import base64
@@ -46,9 +47,43 @@ import M2Crypto
 from M2Crypto import X509
 
 from sfa.util.sfalogging import sfa_logger
-from sfa.util.namespace import urn_to_hrn
+from sfa.util.xrn import urn_to_hrn
 from sfa.util.faults import *
 
+glo_passphrase_callback = None
+
+##
+# A global callback msy be implemented for requesting passphrases from the
+# user. The function will be called with three arguments:
+#
+#    keypair_obj: the keypair object that is calling the passphrase
+#    string: the string containing the private key that's being loaded
+#    x: unknown, appears to be 0, comes from pyOpenSSL and/or m2crypto
+#
+# The callback should return a string containing the passphrase.
+
+def set_passphrase_callback(callback_func):
+    global glo_passphrase_callback
+
+    glo_passphrase_callback = callback_func
+
+##
+# Sets a fixed passphrase.
+
+def set_passphrase(passphrase):
+    set_passphrase_callback( lambda k,s,x: passphrase )
+
+##
+# Check to see if a passphrase works for a particular private key string.
+# Intended to be used by passphrase callbacks for input validation.
+
+def test_passphrase(string, passphrase):
+    try:
+        crypto.load_privatekey(crypto.FILETYPE_PEM, string, (lambda x: passphrase))
+        return True
+    except:
+        return False
+
 def convert_public_key(key):
     keyconvert_path = "/usr/bin/keyconvert.py"
     if not os.path.isfile(keyconvert_path):
@@ -128,16 +163,20 @@ class Keypair:
     # Load the private key from a file. Implicity the private key includes the public key.
 
     def load_from_file(self, filename):
+        self.filename=filename
         buffer = open(filename, 'r').read()
         self.load_from_string(buffer)
-        self.filename=filename
 
     ##
     # Load the private key from a string. Implicitly the private key includes the public key.
 
     def load_from_string(self, string):
-        self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
-        self.m2key = M2Crypto.EVP.load_key_string(string)
+        if glo_passphrase_callback:
+            self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string, functools.partial(glo_passphrase_callback, self, string) )
+            self.m2key = M2Crypto.EVP.load_key_string(string, functools.partial(glo_passphrase_callback, self, string) )
+        else:
+            self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
+            self.m2key = M2Crypto.EVP.load_key_string(string)
 
     ##
     #  Load the public key from a string. No private key is loaded.
@@ -160,6 +199,9 @@ class Keypair:
         ASN1.set_time(500)
         m2x509.set_not_before(ASN1)
         m2x509.set_not_after(ASN1)
+        # x509v3 so it can have extensions
+        # prob not necc since this cert itself is junk but still...
+        m2x509.set_version(2)
         junk_key = Keypair(create=True)
         m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")
 
@@ -296,6 +338,8 @@ class Certificate:
         self.cert.set_serial_number(3)
         self.cert.gmtime_adj_notBefore(0)
         self.cert.gmtime_adj_notAfter(60*60*24*365*5) # five years
+        self.cert.set_version(2) # x509v3 so it can have extensions        
+
 
     ##
     # Given a pyOpenSSL X509 object, store that object inside of this
@@ -312,11 +356,18 @@ class Certificate:
         # load it (support for the ---parent--- tag as well as normal chained certs)
 
         string = string.strip()
-
-
-        if not string.startswith('-----'):
+        
+        # If it's not in proper PEM format, wrap it
+        if string.count('-----BEGIN CERTIFICATE') == 0:
             string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string
 
+        # If there is a PEM cert in there, but there is some other text first
+        # such as the text of the certificate, skip the text
+        beg = string.find('-----BEGIN CERTIFICATE')
+        if beg > 0:
+            # skipping over non cert beginning                                                                                                              
+            string = string[beg:]
+
         parts = []
 
         if string.count('-----BEGIN CERTIFICATE-----') > 1 and \
@@ -553,7 +604,6 @@ class Certificate:
     # @param cert certificate object
 
     def is_signed_by_cert(self, cert):
-        print 'is_signed_by_cert'
         k = cert.get_pubkey()
         result = self.verify(k)
         return result
@@ -626,7 +676,7 @@ class Certificate:
             return CertNotSignedByParent(self.get_subject())
 
         # if the parent isn't verified...
-        sfa_logger().debug("verify_chain: .. %s, -> verifying parent %s",self.get_subject(),self.parent.get_subject())
+        sfa_logger().debug("verify_chain: .. %s, -> verifying parent %s"%(self.get_subject(),self.parent.get_subject()))
         self.parent.verify_chain(trusted_certs)
 
         return