simple_ssl_context() is now a helper exposed in module sfa.util.ssl
[sfa.git] / sfa / trust / certificate.py
index 936352f..a0e3a70 100644 (file)
@@ -45,7 +45,7 @@
 #
 
 
-from __future__ import print_function
+
 
 import functools
 import os
@@ -57,14 +57,12 @@ import OpenSSL
 # M2Crypto is imported on the fly to minimize crashes
 # import M2Crypto
 
-from sfa.util.py23 import PY3
-
 from sfa.util.faults import (CertExpired, CertMissingParent,
                              CertNotSignedByParent)
 from sfa.util.sfalogging import logger
 
 # this tends to generate quite some logs for little or no value
-debug_verify_chain = False
+debug_verify_chain = True
 
 glo_passphrase_callback = None
 
@@ -117,7 +115,7 @@ def convert_public_key(key):
 
     (ssh_f, ssh_fn) = tempfile.mkstemp()
     ssl_fn = tempfile.mktemp()
-    os.write(ssh_f, key)
+    os.write(ssh_f, key.encode())
     os.close(ssh_f)
 
     cmd = keyconvert_path + " " + ssh_fn + " " + ssl_fn
@@ -182,7 +180,8 @@ class Keypair:
     # @param filename name of file to store the keypair in
 
     def save_to_file(self, filename):
-        open(filename, 'w').write(self.as_pem())
+        with open(filename, 'wb') as output:
+            output.write(self.as_pem())
         self.filename = filename
 
     ##
@@ -205,12 +204,13 @@ class Keypair:
                 OpenSSL.crypto.FILETYPE_PEM, string,
                 functools.partial(glo_passphrase_callback, self, string))
             self.m2key = M2Crypto.EVP.load_key_string(
-                string, functools.partial(glo_passphrase_callback,
-                                          self, string))
+                string.encode(encoding="utf-8"),
+                functools.partial(glo_passphrase_callback, self, string))
         else:
             self.key = OpenSSL.crypto.load_privatekey(
                 OpenSSL.crypto.FILETYPE_PEM, string)
-            self.m2key = M2Crypto.EVP.load_key_string(string)
+            self.m2key = M2Crypto.EVP.load_key_string(
+                string.encode(encoding="utf-8"))
 
     ##
     #  Load the public key from a string. No private key is loaded.
@@ -357,8 +357,8 @@ class Certificate:
     # @param create If create==True, then also create a blank X509 certificate.
     # @param subject If subject!=None, then create a blank certificate and set
     #     it's subject name.
-    # @param string If string!=None, load the certficate from the string.
-    # @param filename If filename!=None, load the certficiate from the file.
+    # @param string If string!=None, load the certificate from the string.
+    # @param filename If filename!=None, load the certificate from the file.
     # @param isCA If !=None, set whether this cert is for a CA
 
     def __init__(self, lifeDays=1825, create=False, subject=None, string=None,
@@ -410,7 +410,7 @@ class Certificate:
         # certs)
 
         if string is None or string.strip() == "":
-            logger.warn("Empty string in load_from_string")
+            logger.warning("Empty string in load_from_string")
             return
 
         string = string.strip()
@@ -441,7 +441,7 @@ class Certificate:
             OpenSSL.crypto.FILETYPE_PEM, parts[0])
 
         if self.x509 is None:
-            logger.warn(
+            logger.warning(
                 "Loaded from string but cert is None: {}".format(string))
 
         # if there are more certs, then create a parent and let the parent load
@@ -467,11 +467,11 @@ class Certificate:
 
     def save_to_string(self, save_parents=True):
         if self.x509 is None:
-            logger.warn("None cert in certificate.save_to_string")
+            logger.warning("None cert in certificate.save_to_string")
             return ""
         string = OpenSSL.crypto.dump_certificate(
             OpenSSL.crypto.FILETYPE_PEM, self.x509)
-        if PY3 and isinstance(string, bytes):
+        if isinstance(string, bytes):
             string = string.decode()
         if save_parents and self.parent:
             string = string + self.parent.save_to_string(save_parents)
@@ -488,7 +488,7 @@ class Certificate:
             f = filep
         else:
             f = open(filename, 'w')
-        if PY3 and isinstance(string, bytes):
+        if isinstance(string, bytes):
             string = string.decode()
         f.write(string)
         f.close()
@@ -520,7 +520,7 @@ class Certificate:
                 req = OpenSSL.crypto.X509Req()
                 reqSubject = req.get_subject()
                 if isinstance(subject, dict):
-                    for key in reqSubject.keys():
+                    for key in list(reqSubject.keys()):
                         setattr(reqSubject, key, subject[key])
                 else:
                     setattr(reqSubject, "CN", subject)
@@ -547,7 +547,7 @@ class Certificate:
         req = OpenSSL.crypto.X509Req()
         subj = req.get_subject()
         if isinstance(name, dict):
-            for key in name.keys():
+            for key in list(name.keys()):
                 setattr(subj, key, name[key])
         else:
             setattr(subj, "CN", name)
@@ -583,7 +583,6 @@ class Certificate:
         data = self.get_data(field='subjectAltName')
         if data:
             message += " SubjectAltName:"
-            counter = 0
             filtered = [self.filter_chunk(chunk) for chunk in data.split()]
             message += " ".join([f for f in filtered if f])
             omitted = len([f for f in filtered if not f])
@@ -676,6 +675,11 @@ class Certificate:
 #            raise "Cannot add extension {} which had val {} with new val {}"\
 #                  .format(name, oldExtVal, value)
 
+        if isinstance(name, str):
+            name = name.encode()
+        if isinstance(value, str):
+            value = value.encode()
+
         ext = OpenSSL.crypto.X509Extension(name, critical, value)
         self.x509.add_extensions([ext])
 
@@ -694,7 +698,7 @@ class Certificate:
         # pyOpenSSL does not have a way to get extensions
         m2x509 = M2Crypto.X509.load_cert_string(certstr)
         if m2x509 is None:
-            logger.warn("No cert loaded in get_extension")
+            logger.warning("No cert loaded in get_extension")
             return None
         if m2x509.get_ext(name) is None:
             return None
@@ -714,7 +718,9 @@ class Certificate:
         if field in self.data:
             raise Exception("Cannot set {} more than once".format(field))
         self.data[field] = string
-        self.add_extension(field, 0, string)
+        # call str() because we've seen unicode there
+        # and the underlying C code doesn't like it
+        self.add_extension(field, 0, str(string))
 
     ##
     # Return the data string that was previously set with set_data
@@ -789,13 +795,15 @@ class Certificate:
     # @param cert certificate object
 
     def is_signed_by_cert(self, cert):
-        logger.debug("Certificate.is_signed_by_cert -> invoking verify")
-        k = cert.get_pubkey()
-        result = self.verify(k)
+        key = cert.get_pubkey()
+        logger.debug("Certificate.is_signed_by_cert -> verify on {}\n"
+                     "with pubkey {}"
+                     .format(self, key))
+        result = self.verify(key)
         return result
 
     ##
-    # Set the parent certficiate.
+    # Set the parent certificate.
     #
     # @param p certificate object.
 
@@ -834,7 +842,6 @@ class Certificate:
         # the public key contained in it's parent. The chain is recursed
         # until a certificate is found that is signed by a trusted root.
 
-        logger.debug("Certificate.verify_chain {}".format(self.pretty_name()))
         # verify expiration time
         if self.x509.has_expired():
             if debug_verify_chain:
@@ -844,7 +851,8 @@ class Certificate:
 
         # if this cert is signed by a trusted_cert, then we are set
         for i, trusted_cert in enumerate(trusted_certs, 1):
-            logger.debug("Certificate.verify_chain - trying trusted #{} : {}"
+            logger.debug(5*'-' +
+                         " Certificate.verify_chain - trying trusted #{} : {}"
                          .format(i, trusted_cert.pretty_name()))
             if self.is_signed_by_cert(trusted_cert):
                 # verify expiration of trusted_cert ?
@@ -867,7 +875,7 @@ class Certificate:
                                               trusted_cert.pretty_name()))
             else:
                 logger.debug("verify_chain: not a direct"
-                             " descendant of a trusted root")
+                             " descendant of trusted root #{}".format(i))
 
         # if there is no parent, then no way to verify the chain
         if not self.parent:
@@ -903,8 +911,8 @@ class Certificate:
         # extension and hope there are no other basicConstraints
         if not self.parent.isCA and not (
                 self.parent.get_extension('basicConstraints') == 'CA:TRUE'):
-            logger.warn("verify_chain: cert {}'s parent {} is not a CA"
-                        .format(self.pretty_name(), self.parent.pretty_name()))
+            logger.warning("verify_chain: cert {}'s parent {} is not a CA"
+                            .format(self.pretty_name(), self.parent.pretty_name()))
             raise CertNotSignedByParent("{}: Parent {} not a CA"
                                         .format(self.pretty_name(),
                                                 self.parent.pretty_name()))
@@ -933,7 +941,7 @@ class Certificate:
         return triples
 
     def get_data_names(self):
-        return self.data.keys()
+        return list(self.data.keys())
 
     def get_all_datas(self):
         triples = self.get_extensions()