namespace module is gone, plxrn provides PL-specific translations
[sfa.git] / sfa / trust / certificate.py
index 9fd58ed..839d1df 100644 (file)
@@ -1,3 +1,26 @@
+#----------------------------------------------------------------------
+# Copyright (c) 2008 Board of Trustees, Princeton University
+#
+# Permission is hereby granted, free of charge, to any person obtaining
+# a copy of this software and/or hardware specification (the "Work") to
+# deal in the Work without restriction, including without limitation the
+# rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Work, and to permit persons to whom the Work
+# is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Work.
+#
+# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 
+# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
+# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS 
+# IN THE WORK.
+#----------------------------------------------------------------------
+
 ##
 # SFA uses two crypto libraries: pyOpenSSL and M2Crypto to implement
 # the necessary crypto functionality. Ideally just one of these libraries
 # This module exports two classes: Keypair and Certificate.
 ##
 #
-### $Id$
-### $URL$
-#
 
 import os
 import tempfile
 import base64
 import traceback
+from tempfile import mkstemp
+
 from OpenSSL import crypto
 import M2Crypto
 from M2Crypto import X509
-from tempfile import mkstemp
-from sfa.util.sfalogging import logger
-from sfa.util.namespace import urn_to_hrn
+
+from sfa.util.sfalogging import sfa_logger
+from sfa.util.xrn import urn_to_hrn
 from sfa.util.faults import *
 
 def convert_public_key(key):
-    keyconvert_path = "/usr/bin/keyconvert"
+    keyconvert_path = "/usr/bin/keyconvert.py"
     if not os.path.isfile(keyconvert_path):
         raise IOError, "Could not find keyconvert in %s" % keyconvert_path
 
@@ -54,7 +76,7 @@ def convert_public_key(key):
     try:
         k.load_pubkey_from_file(ssl_fn)
     except:
-        traceback.print_exc()
+        sfa_logger().log_exc("convert_public_key caught exception")
         k = None
 
     # remove the temporary files
@@ -100,6 +122,7 @@ class Keypair:
 
     def save_to_file(self, filename):
         open(filename, 'w').write(self.as_pem())
+        self.filename=filename
 
     ##
     # Load the private key from a file. Implicity the private key includes the public key.
@@ -107,6 +130,7 @@ class Keypair:
     def load_from_file(self, filename):
         buffer = open(filename, 'r').read()
         self.load_from_string(buffer)
+        self.filename=filename
 
     ##
     # Load the private key from a string. Implicitly the private key includes the public key.
@@ -145,6 +169,7 @@ class Keypair:
 
         # get the pyopenssl pkey from the pyopenssl x509
         self.key = pyx509.get_pubkey()
+        self.filename=filename
 
     ##
     # Load the public key from a string. No private key is loaded.
@@ -183,7 +208,6 @@ class Keypair:
     def get_openssl_pkey(self):
         return self.key
 
-
     ##
     # Given another Keypair object, return TRUE if the two keys are the same.
 
@@ -205,6 +229,20 @@ class Keypair:
     def compute_hash(self, value):
         return self.sign_string(str(value))
 
+    # only informative
+    def get_filename(self):
+        return getattr(self,'filename',None)
+
+    def dump (self, *args, **kwargs):
+        print self.dump_string(*args, **kwargs)
+
+    def dump_string (self):
+        result=""
+        result += "KEYPAIR: pubkey=%40s..."%self.get_pubkey_string()
+        filename=self.get_filename()
+        if filename: result += "Filename %s\n"%filename
+        return result
+    
 ##
 # The certificate class implements a general purpose X509 certificate, making
 # use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds
@@ -274,11 +312,18 @@ class Certificate:
         # load it (support for the ---parent--- tag as well as normal chained certs)
 
         string = string.strip()
-
-
-        if not string.startswith('-----'):
+        
+        # If it's not in proper PEM format, wrap it
+        if string.count('-----BEGIN CERTIFICATE') == 0:
             string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string
 
+        # If there is a PEM cert in there, but there is some other text first
+        # such as the text of the certificate, skip the text
+        beg = string.find('-----BEGIN CERTIFICATE')
+        if beg > 0:
+            # skipping over non cert beginning                                                                                                              
+            string = string[beg:]
+
         parts = []
 
         if string.count('-----BEGIN CERTIFICATE-----') > 1 and \
@@ -303,6 +348,7 @@ class Certificate:
         file = open(filename)
         string = file.read()
         self.load_from_string(string)
+        self.filename=filename
 
     ##
     # Save the certificate to a string.
@@ -327,6 +373,7 @@ class Certificate:
             f = open(filename, 'w')
         f.write(string)
         f.close()
+        self.filename=filename
 
     ##
     # Save the certificate to a random file in /tmp/
@@ -434,9 +481,11 @@ class Certificate:
     # Get an X509 extension from the certificate
 
     def get_extension(self, name):
+
         # pyOpenSSL does not have a way to get extensions
         m2x509 = X509.load_cert_string(self.save_to_string())
         value = m2x509.get_ext(name).get_value()
+        
         return value
 
     ##
@@ -448,7 +497,7 @@ class Certificate:
         # pyOpenSSL only allows us to add extensions, so if we try to set the
         # same extension more than once, it will not work
         if self.data.has_key(field):
-            raise "cannot set ", field, " more than once"
+            raise "Cannot set ", field, " more than once"
         self.data[field] = str
         self.add_extension(field, 0, str)
 
@@ -471,6 +520,7 @@ class Certificate:
     # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer().
 
     def sign(self):
+        sfa_logger().debug('certificate.sign')
         assert self.cert != None
         assert self.issuerSubject != None
         assert self.issuerKey != None
@@ -510,6 +560,7 @@ class Certificate:
     # @param cert certificate object
 
     def is_signed_by_cert(self, cert):
+        print 'is_signed_by_cert'
         k = cert.get_pubkey()
         result = self.verify(k)
         return result
@@ -553,29 +604,80 @@ class Certificate:
         # the public key contained in it's parent. The chain is recursed
         # until a certificate is found that is signed by a trusted root.
 
-        # TODO: verify expiration time
-        #print "====Verify Chain====="
+        # verify expiration time
+        if self.cert.has_expired():
+            sfa_logger().debug("verify_chain: NO our certificate has expired")
+            raise CertExpired(self.get_subject(), "client cert")   
+        
         # if this cert is signed by a trusted_cert, then we are set
         for trusted_cert in trusted_certs:
-            #print "***************"
-            # TODO: verify expiration of trusted_cert ?
-            #print "CLIENT CERT", self.dump()
-            #print "TRUSTED CERT", trusted_cert.dump()
-            #print "Client is signed by Trusted?", self.is_signed_by_cert(trusted_cert)
             if self.is_signed_by_cert(trusted_cert):
-                return trusted_cert
+                # verify expiration of trusted_cert ?
+                if not trusted_cert.cert.has_expired():
+                    sfa_logger().debug("verify_chain: YES cert %s signed by trusted cert %s"%(
+                            self.get_subject(), trusted_cert.get_subject()))
+                    return trusted_cert
+                else:
+                    sfa_logger().debug("verify_chain: NO cert %s is signed by trusted_cert %s, but this is expired..."%(
+                            self.get_subject(),trusted_cert.get_subject()))
+                    raise CertExpired(self.get_subject(),"trusted_cert %s"%trusted_cert.get_subject())
 
         # if there is no parent, then no way to verify the chain
         if not self.parent:
-            #print self.get_subject(), "has no parent"
+            sfa_logger().debug("verify_chain: NO %s has no parent and is not in trusted roots"%self.get_subject())
             raise CertMissingParent(self.get_subject())
 
         # if it wasn't signed by the parent...
         if not self.is_signed_by_cert(self.parent):
-            #print self.get_subject(), "is not signed by parent"
+            sfa_logger().debug("verify_chain: NO %s is not signed by parent"%self.get_subject())
             return CertNotSignedByParent(self.get_subject())
 
         # if the parent isn't verified...
+        sfa_logger().debug("verify_chain: .. %s, -> verifying parent %s",self.get_subject(),self.parent.get_subject())
         self.parent.verify_chain(trusted_certs)
 
         return
+
+    ### more introspection
+    def get_extensions(self):
+        # pyOpenSSL does not have a way to get extensions
+        triples=[]
+        m2x509 = X509.load_cert_string(self.save_to_string())
+        nb_extensions=m2x509.get_ext_count()
+        sfa_logger().debug("X509 had %d extensions"%nb_extensions)
+        for i in range(nb_extensions):
+            ext=m2x509.get_ext_at(i)
+            triples.append( (ext.get_name(), ext.get_value(), ext.get_critical(),) )
+        return triples
+
+    def get_data_names(self):
+        return self.data.keys()
+
+    def get_all_datas (self):
+        triples=self.get_extensions()
+        for name in self.get_data_names(): 
+            triples.append( (name,self.get_data(name),'data',) )
+        return triples
+
+    # only informative
+    def get_filename(self):
+        return getattr(self,'filename',None)
+
+    def dump (self, *args, **kwargs):
+        print self.dump_string(*args, **kwargs)
+
+    def dump_string (self,show_extensions=False):
+        result = ""
+        result += "CERTIFICATE for %s\n"%self.get_subject()
+        result += "Issued by %s\n"%self.get_issuer()
+        filename=self.get_filename()
+        if filename: result += "Filename %s\n"%filename
+        if show_extensions:
+            all_datas=self.get_all_datas()
+            result += " has %d extensions/data attached"%len(all_datas)
+            for (n,v,c) in all_datas:
+                if c=='data':
+                    result += "   data: %s=%s\n"%(n,v)
+                else:
+                    result += "    ext: %s (crit=%s)=<<<%s>>>\n"%(n,c,v)
+        return result