2to3 -f raise
[sfa.git] / sfa / trust / certificate.py
index c3f3790..6f3ecc6 100644 (file)
@@ -35,6 +35,8 @@
 ##
 #
 
+from __future__ import print_function
+
 import functools
 import os
 import tempfile
@@ -88,11 +90,11 @@ def test_passphrase(string, passphrase):
 def convert_public_key(key):
     keyconvert_path = "/usr/bin/keyconvert.py"
     if not os.path.isfile(keyconvert_path):
-        raise IOError, "Could not find keyconvert in %s" % keyconvert_path
+        raise IOError("Could not find keyconvert in %s" % keyconvert_path)
 
     # we can only convert rsa keys
     if "ssh-dss" in key:
-        raise Exception, "keyconvert: dss keys are not supported"
+        raise Exception("keyconvert: dss keys are not supported")
 
     (ssh_f, ssh_fn) = tempfile.mkstemp()
     ssl_fn = tempfile.mktemp()
@@ -106,7 +108,7 @@ def convert_public_key(key):
     # that it can be expected to see why it failed.
     # TODO: for production, cleanup the temporary files
     if not os.path.exists(ssl_fn):
-        raise Exception, "keyconvert: generated certificate not found. keyconvert may have failed."
+        raise Exception("keyconvert: generated certificate not found. keyconvert may have failed.")
 
     k = Keypair()
     try:
@@ -151,7 +153,7 @@ class Keypair:
 
     def create(self):
         self.key = crypto.PKey()
-        self.key.generate_key(crypto.TYPE_RSA, 1024)
+        self.key.generate_key(crypto.TYPE_RSA, 2048)
 
     ##
     # Save the private key to a file
@@ -174,8 +176,10 @@ class Keypair:
 
     def load_from_string(self, string):
         if glo_passphrase_callback:
-            self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string, functools.partial(glo_passphrase_callback, self, string) )
-            self.m2key = M2Crypto.EVP.load_key_string(string, functools.partial(glo_passphrase_callback, self, string) )
+            self.key = crypto.load_privatekey(
+                crypto.FILETYPE_PEM, string, functools.partial(glo_passphrase_callback, self, string))
+            self.m2key = M2Crypto.EVP.load_key_string(
+                string, functools.partial(glo_passphrase_callback, self, string))
         else:
             self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
             self.m2key = M2Crypto.EVP.load_key_string(string)
@@ -205,7 +209,7 @@ class Keypair:
         # prob not necc since this cert itself is junk but still...
         m2x509.set_version(2)
         junk_key = Keypair(create=True)
-        m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")
+        m2x509.sign(pkey=junk_key.get_m2_pubkey(), md="sha1")
 
         # convert the m2 x509 cert to a pyopenssl x509
         m2pem = m2x509.as_pem()
@@ -234,7 +238,7 @@ class Keypair:
     ##
     # Return an M2Crypto key object
 
-    def get_m2_pkey(self):
+    def get_m2_pubkey(self):
         if not self.m2key:
             self.m2key = M2Crypto.EVP.load_key_string(self.as_pem())
         return self.m2key
@@ -243,7 +247,7 @@ class Keypair:
     # Returns a string containing the public key represented by this object.
 
     def get_pubkey_string(self):
-        m2pkey = self.get_m2_pkey()
+        m2pkey = self.get_m2_pubkey()
         return base64.b64encode(m2pkey.as_der())
 
     ##
@@ -259,13 +263,13 @@ class Keypair:
         return self.as_pem() == pkey.as_pem()
 
     def sign_string(self, data):
-        k = self.get_m2_pkey()
+        k = self.get_m2_pubkey()
         k.sign_init()
         k.sign_update(data)
         return base64.b64encode(k.sign_final())
 
     def verify_string(self, data, sig):
-        k = self.get_m2_pkey()
+        k = self.get_m2_pubkey()
         k.verify_init()
         k.verify_update(data)
         return M2Crypto.m2.verify_final(k.ctx, base64.b64decode(sig), k.pkey)
@@ -278,7 +282,7 @@ class Keypair:
         return getattr(self,'filename',None)
 
     def dump (self, *args, **kwargs):
-        print self.dump_string(*args, **kwargs)
+        print(self.dump_string(*args, **kwargs))
 
     def dump_string (self):
         result=""
@@ -300,7 +304,7 @@ class Keypair:
 # whether to save the parent certificates as well.
 
 class Certificate:
-    digest = "md5"
+    digest = "sha256"
 
 #    x509 = None
 #    issuerKey = None
@@ -367,6 +371,10 @@ class Certificate:
         # if it is a chain of multiple certs, then split off the first one and
         # load it (support for the ---parent--- tag as well as normal chained certs)
 
+        if string is None or string.strip() == "":
+            logger.warn("Empty string in load_from_string")
+            return
+
         string = string.strip()
         
         # If it's not in proper PEM format, wrap it
@@ -391,6 +399,9 @@ class Certificate:
 
         self.x509 = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])
 
+        if self.x509 is None:
+            logger.warn("Loaded from string but cert is None: %s" % string)
+
         # if there are more certs, then create a parent and let the parent load
         # itself from the remainder of the string
         if len(parts) > 1 and parts[1] != '':
@@ -412,6 +423,9 @@ class Certificate:
     # @param save_parents If save_parents==True, then also save the parent certificates.
 
     def save_to_string(self, save_parents=True):
+        if self.x509 is None:
+            logger.warn("None cert in certificate.save_to_string")
+            return ""
         string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.x509)
         if save_parents and self.parent:
             string = string + self.parent.save_to_string(save_parents)
@@ -498,6 +512,7 @@ class Certificate:
     ##
     # Get a pretty-print subject name of the certificate
     # let's try to make this a little more usable as is makes logs hairy
+    # FIXME: Consider adding 'urn:publicid' and 'uuid' back for GENI?
     pretty_fields = ['email']
     def filter_chunk(self, chunk):
         for field in self.pretty_fields:
@@ -555,7 +570,7 @@ class Certificate:
 
         if self.isCA != None:
             # Can't double set properties
-            raise Exception, "Cannot set basicConstraints CA:?? more than once. Was %s, trying to set as %s" % (self.isCA, val)
+            raise Exception("Cannot set basicConstraints CA:?? more than once. Was %s, trying to set as %s" % (self.isCA, val))
 
         self.isCA = val
         if val:
@@ -601,8 +616,19 @@ class Certificate:
 
     def get_extension(self, name):
 
+        if name is None:
+            return None
+
+        certstr = self.save_to_string()
+        if certstr is None or certstr == "":
+            return None
         # pyOpenSSL does not have a way to get extensions
-        m2x509 = X509.load_cert_string(self.save_to_string())
+        m2x509 = X509.load_cert_string(certstr)
+        if m2x509 is None:
+            logger.warn("No cert loaded in get_extension")
+            return None
+        if m2x509.get_ext(name) is None:
+            return None
         value = m2x509.get_ext(name).get_value()
 
         return value
@@ -651,12 +677,14 @@ class Certificate:
     # @param pkey is a Keypair object representing a public key. If Pkey
     #     did not sign the certificate, then an exception will be thrown.
 
-    def verify(self, pkey):
+    def verify(self, pubkey):
         # pyOpenSSL does not have a way to verify signatures
         m2x509 = X509.load_cert_string(self.save_to_string())
-        m2pkey = pkey.get_m2_pkey()
+        m2pubkey = pubkey.get_m2_pubkey()
         # verify it
-        return m2x509.verify(m2pkey)
+        # verify returns -1 or 0 on failure depending on how serious the
+        # error conditions are
+        return m2x509.verify(m2pubkey) == 1
 
         # XXX alternatively, if openssl has been patched, do the much simpler:
         # try:
@@ -734,7 +762,7 @@ class Certificate:
             if self.is_signed_by_cert(trusted_cert):
                 # verify expiration of trusted_cert ?
                 if not trusted_cert.x509.has_expired():
-                    if debug_verify_chain: 
+                    if debug_verify_chain:
                         logger.debug("verify_chain: YES. Cert %s signed by trusted cert %s"%(
                             self.pretty_cert(), trusted_cert.pretty_cert()))
                     return trusted_cert
@@ -757,11 +785,11 @@ class Certificate:
         if not self.is_signed_by_cert(self.parent):
             if debug_verify_chain:
                 logger.debug("verify_chain: NO. %s is not signed by parent %s, but by %s"%\
-                             (self.pretty_cert(), 
-                              self.parent.pretty_cert(), 
+                             (self.pretty_cert(),
+                              self.parent.pretty_cert(),
                               self.get_issuer()))
             raise CertNotSignedByParent("%s: Parent %s, issuer %s"\
-                                            % (self.pretty_cert(), 
+                                            % (self.pretty_cert(),
                                                self.parent.pretty_cert(),
                                                self.get_issuer()))
 
@@ -811,7 +839,7 @@ class Certificate:
         return getattr(self,'filename',None)
 
     def dump (self, *args, **kwargs):
-        print self.dump_string(*args, **kwargs)
+        print(self.dump_string(*args, **kwargs))
 
     def dump_string (self,show_extensions=False):
         result = ""