Fixed up URNs in GID creation
[sfa.git] / sfa / trust / certificate.py
index 9b48835..6a5ee2d 100644 (file)
@@ -22,7 +22,8 @@ import traceback
 from OpenSSL import crypto
 import M2Crypto
 from M2Crypto import X509
-from M2Crypto import EVP
+from tempfile import mkstemp
+from sfa.util.sfalogging import logger
 
 from sfa.util.faults import *
 
@@ -70,7 +71,7 @@ def convert_public_key(key):
 class Keypair:
    key = None       # public/private keypair
    m2key = None     # public key (m2crypto format)
-
+   
    ##
    # Creates a Keypair object
    # @param create If create==True, creates a new public/private key and
@@ -218,7 +219,6 @@ class Keypair:
 class Certificate:
    digest = "md5"
 
-   data = None
    cert = None
    issuerKey = None
    issuerSubject = None
@@ -236,6 +236,7 @@ class Certificate:
    # @param filename If filename!=None, load the certficiate from the file.
 
    def __init__(self, create=False, subject=None, string=None, filename=None):
+       self.data = {}
        if create or subject:
            self.create()
        if subject:
@@ -266,20 +267,35 @@ class Certificate:
 
    def load_from_string(self, string):
        # if it is a chain of multiple certs, then split off the first one and
-       # load it
-       parts = string.split(Certificate.separator, 1)
+       # load it (support for the ---parent--- tag as well as normal chained certs)       
+
+       string = string.strip()       
+       
+       
+       if not string.startswith('-----'):
+           string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string
+           
+       parts = []
+       
+       if string.count('-----BEGIN CERTIFICATE-----') > 1 and \
+              string.count(Certificate.separator) == 0:
+           parts = string.split('-----END CERTIFICATE-----',1)
+           parts[0] += '-----END CERTIFICATE-----'
+       else:
+           parts = string.split(Certificate.separator, 1)
+       
        self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])
 
        # if there are more certs, then create a parent and let the parent load
        # itself from the remainder of the string
-       if len(parts) > 1:
+       if len(parts) > 1 and parts[1] != '':
            self.parent = self.__class__()
            self.parent.load_from_string(parts[1])
 
    ##
    # Load the certificate from a file
 
-   def load_from_file(self, filename):
+   def load_from_file(self, filename):     
        file = open(filename)
        string = file.read()
        self.load_from_string(string)
@@ -289,19 +305,33 @@ class Certificate:
    #
    # @param save_parents If save_parents==True, then also save the parent certificates.
 
-   def save_to_string(self, save_parents=False):
+   def save_to_string(self, save_parents=True):
        string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)
        if save_parents and self.parent:
-          string = string + Certificate.separator + self.parent.save_to_string(save_parents)
+          string = string + self.parent.save_to_string(save_parents)
        return string
 
    ##
    # Save the certificate to a file.
    # @param save_parents If save_parents==True, then also save the parent certificates.
 
-   def save_to_file(self, filename, save_parents=False):
+   def save_to_file(self, filename, save_parents=True, filep=None):
        string = self.save_to_string(save_parents=save_parents)
-       open(filename, 'w').write(string)
+       if filep:
+           f = filep
+       else:
+           f = open(filename, 'w')
+       f.write(string)
+       f.close()
+
+   ##
+   # Save the certificate to a random file in /tmp/
+   # @param save_parents If save_parents==True, then also save the parent certificates.  
+   def save_to_random_tmp_file(self, save_parents=True):       
+       fp, filename = mkstemp(suffix='cert', text=True)
+       fp = os.fdopen(fp, "w")
+       self.save_to_file(filename, save_parents=True, filep=fp)
+       return filename   
 
    ##
    # Sets the issuer private key and name
@@ -403,31 +433,28 @@ class Certificate:
    # the X509 subject_alt_name extension. Set_data can only be called once, due
    # to limitations in the underlying library.
 
-   def set_data(self, str):
+   def set_data(self, str, field='subjectAltName'):
        # pyOpenSSL only allows us to add extensions, so if we try to set the
        # same extension more than once, it will not work
-       if self.data != None:
-          raise "cannot set subjectAltName more than once"
-       self.data = str
-       self.add_extension("subjectAltName", 0, "URI:http://" + str)
+       if self.data.has_key(field):
+          raise "cannot set ", field, " more than once"
+       self.data[field] = str
+       self.add_extension(field, 0, str)
 
    ##
    # Return the data string that was previously set with set_data
 
-   def get_data(self):
-       if self.data:
-           return self.data
+   def get_data(self, field='subjectAltName'):
+       if self.data.has_key(field):
+           return self.data[field]
 
        try:
-           uri = self.get_extension("subjectAltName")
+           uri = self.get_extension(field)
+           self.data[field] = uri           
        except LookupError:
-           self.data = None
-           return self.data
-
-       if not uri.startswith("URI:http://"):
-           raise "bad encoding in subjectAltName"
-       self.data = uri[11:]
-       return self.data
+           return None
+       
+       return self.data[field]
 
    ##
    # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer().
@@ -525,6 +552,10 @@ class Certificate:
             #print "TRUSTED CERT", trusted_cert.dump()
             #print "Client is signed by Trusted?", self.is_signed_by_cert(trusted_cert)
             if self.is_signed_by_cert(trusted_cert):
+                # make sure sure the trusted cert's hrn is a prefix of the
+                # signed cert's hrn
+                if not self.get_subject().startswith(trusted_cert.get_subject()):
+                    raise GidParentHrn(trusted_cert.get_subject()) 
                 #print self.get_subject(), "is signed by a root"
                 return