Fixed up URNs in GID creation
[sfa.git] / sfa / trust / certificate.py
index 54d98eb..6a5ee2d 100644 (file)
@@ -1,5 +1,5 @@
 ##
-# Geniwrapper uses two crypto libraries: pyOpenSSL and M2Crypto to implement
+# SFA uses two crypto libraries: pyOpenSSL and M2Crypto to implement
 # the necessary crypto functionality. Ideally just one of these libraries
 # would be used, but unfortunately each of these libraries is independently
 # lacking. The pyOpenSSL library is missing many necessary functions, and
@@ -22,7 +22,8 @@ import traceback
 from OpenSSL import crypto
 import M2Crypto
 from M2Crypto import X509
-from M2Crypto import EVP
+from tempfile import mkstemp
+from sfa.util.sfalogging import logger
 
 from sfa.util.faults import *
 
@@ -33,7 +34,6 @@ def convert_public_key(key):
 
     # we can only convert rsa keys 
     if "ssh-dss" in key:
-        print "XXX: DSA key encountered, ignoring"
         return None
     
     (ssh_f, ssh_fn) = tempfile.mkstemp()
@@ -43,19 +43,17 @@ def convert_public_key(key):
 
     cmd = keyconvert_path + " " + ssh_fn + " " + ssl_fn
     os.system(cmd)
-
+    
     # this check leaves the temporary file containing the public key so
     # that it can be expected to see why it failed.
     # TODO: for production, cleanup the temporary files
     if not os.path.exists(ssl_fn):
-        report.trace("  failed to convert key from " + ssh_fn + " to " + ssl_fn)
         return None
-
+    
     k = Keypair()
     try:
         k.load_pubkey_from_file(ssl_fn)
     except:
-        print "XXX: Error while converting key: ", key
         traceback.print_exc()
         k = None
 
@@ -73,7 +71,7 @@ def convert_public_key(key):
 class Keypair:
    key = None       # public/private keypair
    m2key = None     # public key (m2crypto format)
-
+   
    ##
    # Creates a Keypair object
    # @param create If create==True, creates a new public/private key and
@@ -221,7 +219,6 @@ class Keypair:
 class Certificate:
    digest = "md5"
 
-   data = None
    cert = None
    issuerKey = None
    issuerSubject = None
@@ -239,6 +236,7 @@ class Certificate:
    # @param filename If filename!=None, load the certficiate from the file.
 
    def __init__(self, create=False, subject=None, string=None, filename=None):
+       self.data = {}
        if create or subject:
            self.create()
        if subject:
@@ -269,20 +267,35 @@ class Certificate:
 
    def load_from_string(self, string):
        # if it is a chain of multiple certs, then split off the first one and
-       # load it
-       parts = string.split(Certificate.separator, 1)
+       # load it (support for the ---parent--- tag as well as normal chained certs)       
+
+       string = string.strip()       
+       
+       
+       if not string.startswith('-----'):
+           string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string
+           
+       parts = []
+       
+       if string.count('-----BEGIN CERTIFICATE-----') > 1 and \
+              string.count(Certificate.separator) == 0:
+           parts = string.split('-----END CERTIFICATE-----',1)
+           parts[0] += '-----END CERTIFICATE-----'
+       else:
+           parts = string.split(Certificate.separator, 1)
+       
        self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])
 
        # if there are more certs, then create a parent and let the parent load
        # itself from the remainder of the string
-       if len(parts) > 1:
+       if len(parts) > 1 and parts[1] != '':
            self.parent = self.__class__()
            self.parent.load_from_string(parts[1])
 
    ##
    # Load the certificate from a file
 
-   def load_from_file(self, filename):
+   def load_from_file(self, filename):     
        file = open(filename)
        string = file.read()
        self.load_from_string(string)
@@ -292,19 +305,33 @@ class Certificate:
    #
    # @param save_parents If save_parents==True, then also save the parent certificates.
 
-   def save_to_string(self, save_parents=False):
+   def save_to_string(self, save_parents=True):
        string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)
        if save_parents and self.parent:
-          string = string + Certificate.separator + self.parent.save_to_string(save_parents)
+          string = string + self.parent.save_to_string(save_parents)
        return string
 
    ##
    # Save the certificate to a file.
    # @param save_parents If save_parents==True, then also save the parent certificates.
 
-   def save_to_file(self, filename, save_parents=False):
+   def save_to_file(self, filename, save_parents=True, filep=None):
        string = self.save_to_string(save_parents=save_parents)
-       open(filename, 'w').write(string)
+       if filep:
+           f = filep
+       else:
+           f = open(filename, 'w')
+       f.write(string)
+       f.close()
+
+   ##
+   # Save the certificate to a random file in /tmp/
+   # @param save_parents If save_parents==True, then also save the parent certificates.  
+   def save_to_random_tmp_file(self, save_parents=True):       
+       fp, filename = mkstemp(suffix='cert', text=True)
+       fp = os.fdopen(fp, "w")
+       self.save_to_file(filename, save_parents=True, filep=fp)
+       return filename   
 
    ##
    # Sets the issuer private key and name
@@ -406,31 +433,28 @@ class Certificate:
    # the X509 subject_alt_name extension. Set_data can only be called once, due
    # to limitations in the underlying library.
 
-   def set_data(self, str):
+   def set_data(self, str, field='subjectAltName'):
        # pyOpenSSL only allows us to add extensions, so if we try to set the
        # same extension more than once, it will not work
-       if self.data != None:
-          raise "cannot set subjectAltName more than once"
-       self.data = str
-       self.add_extension("subjectAltName", 0, "URI:http://" + str)
+       if self.data.has_key(field):
+          raise "cannot set ", field, " more than once"
+       self.data[field] = str
+       self.add_extension(field, 0, str)
 
    ##
    # Return the data string that was previously set with set_data
 
-   def get_data(self):
-       if self.data:
-           return self.data
+   def get_data(self, field='subjectAltName'):
+       if self.data.has_key(field):
+           return self.data[field]
 
        try:
-           uri = self.get_extension("subjectAltName")
+           uri = self.get_extension(field)
+           self.data[field] = uri           
        except LookupError:
-           self.data = None
-           return self.data
-
-       if not uri.startswith("URI:http://"):
-           raise "bad encoding in subjectAltName"
-       self.data = uri[11:]
-       return self.data
+           return None
+       
+       return self.data[field]
 
    ##
    # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer().
@@ -528,6 +552,10 @@ class Certificate:
             #print "TRUSTED CERT", trusted_cert.dump()
             #print "Client is signed by Trusted?", self.is_signed_by_cert(trusted_cert)
             if self.is_signed_by_cert(trusted_cert):
+                # make sure sure the trusted cert's hrn is a prefix of the
+                # signed cert's hrn
+                if not self.get_subject().startswith(trusted_cert.get_subject()):
+                    raise GidParentHrn(trusted_cert.get_subject()) 
                 #print self.get_subject(), "is signed by a root"
                 return