changes proposed by Aaron Helsinger and GENI
authorThierry Parmentelat <thierry.parmentelat@inria.fr>
Thu, 26 Nov 2015 08:43:23 +0000 (09:43 +0100)
committerThierry Parmentelat <thierry.parmentelat@inria.fr>
Thu, 26 Nov 2015 08:43:23 +0000 (09:43 +0100)
sfa/trust/certificate.py
sfa/trust/credential.py
sfa/util/sfalogging.py

index c3f3790..d24c1a2 100644 (file)
@@ -151,7 +151,7 @@ class Keypair:
 
     def create(self):
         self.key = crypto.PKey()
-        self.key.generate_key(crypto.TYPE_RSA, 1024)
+        self.key.generate_key(crypto.TYPE_RSA, 2048)
 
     ##
     # Save the private key to a file
@@ -300,7 +300,7 @@ class Keypair:
 # whether to save the parent certificates as well.
 
 class Certificate:
-    digest = "md5"
+    digest = "sha256"
 
 #    x509 = None
 #    issuerKey = None
@@ -367,6 +367,10 @@ class Certificate:
         # if it is a chain of multiple certs, then split off the first one and
         # load it (support for the ---parent--- tag as well as normal chained certs)
 
+        if string is None or string.strip() == "":
+            logger.warn("Empty string in load_from_string")
+            return
+
         string = string.strip()
         
         # If it's not in proper PEM format, wrap it
@@ -391,6 +395,9 @@ class Certificate:
 
         self.x509 = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])
 
+        if self.x509 is None:
+            logger.warn("Loaded from string but cert is None: %s" % string)
+
         # if there are more certs, then create a parent and let the parent load
         # itself from the remainder of the string
         if len(parts) > 1 and parts[1] != '':
@@ -412,6 +419,9 @@ class Certificate:
     # @param save_parents If save_parents==True, then also save the parent certificates.
 
     def save_to_string(self, save_parents=True):
+        if self.x509 is None:
+            logger.warn("None cert in certificate.save_to_string")
+            return ""
         string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.x509)
         if save_parents and self.parent:
             string = string + self.parent.save_to_string(save_parents)
@@ -498,6 +508,7 @@ class Certificate:
     ##
     # Get a pretty-print subject name of the certificate
     # let's try to make this a little more usable as is makes logs hairy
+    # FIXME: Consider adding 'urn:publicid' and 'uuid' back for GENI?
     pretty_fields = ['email']
     def filter_chunk(self, chunk):
         for field in self.pretty_fields:
@@ -601,8 +612,19 @@ class Certificate:
 
     def get_extension(self, name):
 
+        if name is None:
+            return None
+
+        certstr = self.save_to_string()
+        if certstr is None or certstr == "":
+            return None
         # pyOpenSSL does not have a way to get extensions
-        m2x509 = X509.load_cert_string(self.save_to_string())
+        m2x509 = X509.load_cert_string(certstr)
+        if m2x509 is None:
+            logger.warn("No cert loaded in get_extension")
+            return None
+        if m2x509.get_ext(name) is None:
+            return None
         value = m2x509.get_ext(name).get_value()
 
         return value
@@ -734,7 +756,7 @@ class Certificate:
             if self.is_signed_by_cert(trusted_cert):
                 # verify expiration of trusted_cert ?
                 if not trusted_cert.x509.has_expired():
-                    if debug_verify_chain: 
+                    if debug_verify_chain:
                         logger.debug("verify_chain: YES. Cert %s signed by trusted cert %s"%(
                             self.pretty_cert(), trusted_cert.pretty_cert()))
                     return trusted_cert
@@ -757,11 +779,11 @@ class Certificate:
         if not self.is_signed_by_cert(self.parent):
             if debug_verify_chain:
                 logger.debug("verify_chain: NO. %s is not signed by parent %s, but by %s"%\
-                             (self.pretty_cert(), 
-                              self.parent.pretty_cert(), 
+                             (self.pretty_cert(),
+                              self.parent.pretty_cert(),
                               self.get_issuer()))
             raise CertNotSignedByParent("%s: Parent %s, issuer %s"\
-                                            % (self.pretty_cert(), 
+                                            % (self.pretty_cert(),
                                                self.parent.pretty_cert(),
                                                self.get_issuer()))
 
index 66401f8..6ef5e30 100644 (file)
@@ -51,7 +51,7 @@ from sfa.trust.gid import GID
 from sfa.util.xrn import urn_to_hrn, hrn_authfor_hrn
 
 # 31 days, in seconds 
-DEFAULT_CREDENTIAL_LIFETIME = 86400 * 28
+DEFAULT_CREDENTIAL_LIFETIME = 86400 * 31
 
 
 # TODO:
@@ -187,11 +187,32 @@ class Signature(object):
             logger.log_exc ("Failed to parse credential, %s"%self.xml)
             raise
         sig = doc.getElementsByTagName("Signature")[0]
-        self.set_refid(sig.getAttribute("xml:id").strip("Sig_"))
-        keyinfo = sig.getElementsByTagName("X509Data")[0]
-        szgid = getTextNode(keyinfo, "X509Certificate")
-        szgid = "-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----" % szgid
-        self.set_issuer_gid(GID(string=szgid))        
+        ## This code until the end of function rewritten by Aaron Helsinger
+        ref_id = sig.getAttribute("xml:id").strip().strip("Sig_")
+        # The xml:id tag is optional, and could be in a 
+        # Reference xml:id or Reference UID sub element instead
+        if not ref_id or ref_id == '':
+            reference = sig.getElementsByTagName('Reference')[0]
+            ref_id = reference.getAttribute('xml:id').strip().strip('Sig_')
+            if not ref_id or ref_id == '':
+                ref_id = reference.getAttribute('URI').strip().strip('#')
+        self.set_refid(ref_id)
+        keyinfos = sig.getElementsByTagName("X509Data")
+        gids = None
+        for keyinfo in keyinfos:
+            certs = keyinfo.getElementsByTagName("X509Certificate")
+            for cert in certs:
+                if len(cert.childNodes) > 0:
+                    szgid = cert.childNodes[0].nodeValue
+                    szgid = szgid.strip()
+                    szgid = "-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----" % szgid
+                    if gids is None:
+                        gids = szgid
+                    else:
+                        gids += "\n" + szgid
+        if gids is None:
+            raise CredentialNotVerifiable("Malformed XML: No certificate found in signature")
+        self.set_issuer_gid(GID(string=gids))
         
     def encode(self):
         self.xml = signature_template % (self.get_refid(), self.get_refid())
@@ -221,6 +242,8 @@ def filter_creds_by_caller(creds, caller_hrn_list):
         for cred in creds:
             try:
                 tmp_cred = Credential(string=cred)
+                if tmp_cred.get_cred_type() != Credential.SFA_CREDENTIAL_TYPE:
+                    continue
                 if tmp_cred.get_gid_caller().get_hrn() in caller_hrn_list:
                     caller_creds.append(cred)
             except: pass
@@ -228,6 +251,8 @@ def filter_creds_by_caller(creds, caller_hrn_list):
 
 class Credential(object):
 
+    SFA_CREDENTIAL_TYPE = "geni_sfa"
+
     ##
     # Create a Credential object
     #
@@ -248,19 +273,18 @@ class Credential(object):
         self.signature = None
         self.xml = None
         self.refid = None
-        self.type = None
+        self.cred_type = Credential.SFA_CREDENTIAL_TYPE
         self.version = None
 
         if cred:
             if isinstance(cred, StringTypes):
                 string = cred
-                self.type = 'geni_sfa'
-                self.version = '1.0'
+                self.cred_type = Credential.SFA_CREDENTIAL_TYPE
+                self.version = '3'
             elif isinstance(cred, dict):
                 string = cred['geni_value']
-                self.type = cred['geni_type']
+                self.cred_type = cred['geni_type']
                 self.version = cred['geni_version']
-                
 
         if string or filename:
             if string:                
@@ -283,6 +307,16 @@ class Credential(object):
             if os.path.isfile(path + '/' + 'xmlsec1'):
                 self.xmlsec_path = path + '/' + 'xmlsec1'
                 break
+        if not self.xmlsec_path:
+            logger.warn("Could not locate binary for xmlsec1 - SFA will be unable to sign stuff !!")
+
+    def get_cred_type(self): 
+        return self.cred_type
+
+    def get_subject(self):
+        if not self.gidObject:
+            self.decode()
+        return self.gidObject.get_subject()
 
     def pretty_subject(self):
         subject = ""
@@ -434,12 +468,13 @@ class Credential(object):
         # cause those schemas are identical.
         # Also note these PG schemas talk about PG tickets and CM policies.
         signed_cred.setAttribute("xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance")
+        # FIXME: See v2 schema at www.geni.net/resources/credential/2/credential.xsd
         signed_cred.setAttribute("xsi:noNamespaceSchemaLocation", "http://www.planet-lab.org/resources/sfa/credential.xsd")
         signed_cred.setAttribute("xsi:schemaLocation", "http://www.planet-lab.org/resources/sfa/ext/policy/1 http://www.planet-lab.org/resources/sfa/ext/policy/1/policy.xsd")
 
         # PG says for those last 2:
-       #signed_cred.setAttribute("xsi:noNamespaceSchemaLocation", "http://www.protogeni.net/resources/credential/credential.xsd")
-       # signed_cred.setAttribute("xsi:schemaLocation", "http://www.protogeni.net/resources/credential/ext/policy/1 http://www.protogeni.net/resources/credential/ext/policy/1/policy.xsd")
+        #        signed_cred.setAttribute("xsi:noNamespaceSchemaLocation", "http://www.protogeni.net/resources/credential/credential.xsd")
+        #        signed_cred.setAttribute("xsi:schemaLocation", "http://www.protogeni.net/resources/credential/ext/policy/1 http://www.protogeni.net/resources/credential/ext/policy/1/policy.xsd")
 
         doc.appendChild(signed_cred)  
         
@@ -458,6 +493,9 @@ class Credential(object):
             logger.debug("Creating credential valid for %s s"%DEFAULT_CREDENTIAL_LIFETIME)
             self.set_expiration(datetime.datetime.utcnow() + datetime.timedelta(seconds=DEFAULT_CREDENTIAL_LIFETIME))
         self.expiration = self.expiration.replace(microsecond=0)
+        if self.expiration.tzinfo is not None and self.expiration.tzinfo.utcoffset(self.expiration) is not None:
+            # TZ aware. Make sure it is UTC - by Aaron Helsinger
+            self.expiration = self.expiration.astimezone(tz.tzutc())
         append_sub(doc, cred, "expires", self.expiration.strftime(SFATIME_FORMAT))
         privileges = doc.createElement("privileges")
         cred.appendChild(privileges)
@@ -536,7 +574,7 @@ class Credential(object):
                 signatures.appendChild(ele)
                 
         # Get the finished product
-        self.xml = doc.toxml()
+        self.xml = doc.toxml("utf-8")
 
 
     def save_to_random_tmp_file(self):       
@@ -615,7 +653,11 @@ class Credential(object):
     # you have loaded an existing signed credential, do not call encode() or sign() on it.
 
     def sign(self):
-        if not self.issuer_privkey or not self.issuer_gid:
+        if not self.issuer_privkey:
+            logger.warn("Cannot sign credential (no private key)")
+            return
+        if not self.issuer_gid:
+            logger.warn("Cannot sign credential (no issuer gid)")
             return
         doc = parseString(self.get_xml())
         sigs = doc.getElementsByTagName("signatures")[0]
@@ -627,7 +669,7 @@ class Credential(object):
         sig_ele = doc.importNode(sdoc.getElementsByTagName("Signature")[0], True)
         sigs.appendChild(sig_ele)
 
-        self.xml = doc.toxml()
+        self.xml = doc.toxml("utf-8")
 
 
         # Split the issuer GID into multiple certificates if it's a chain
@@ -644,8 +686,10 @@ class Credential(object):
         # Call out to xmlsec1 to sign it
         ref = 'Sig_%s' % self.get_refid()
         filename = self.save_to_random_tmp_file()
-        signed = os.popen('%s --sign --node-id "%s" --privkey-pem %s,%s %s' \
-                 % (self.xmlsec_path, ref, self.issuer_privkey, ",".join(gid_files), filename)).read()
+        command='%s --sign --node-id "%s" --privkey-pem %s,%s %s' \
+            % (self.xmlsec_path, ref, self.issuer_privkey, ",".join(gid_files), filename)
+#        print 'command',command
+        signed = os.popen(command).read()
         os.remove(filename)
 
         for gid_file in gid_files:
@@ -656,7 +700,7 @@ class Credential(object):
         # Update signatures
         self.decode()       
 
-        
+
     ##
     # Retrieve the attributes of the credential from the XML.
     # This is automatically called by the various get_* methods of
@@ -694,25 +738,28 @@ class Credential(object):
         self.set_refid(cred.getAttribute("xml:id"))
         self.set_expiration(utcparse(getTextNode(cred, "expires")))
         self.gidCaller = GID(string=getTextNode(cred, "owner_gid"))
-        self.gidObject = GID(string=getTextNode(cred, "target_gid"))   
+        self.gidObject = GID(string=getTextNode(cred, "target_gid"))
 
 
+        ## This code until the end of function rewritten by Aaron Helsinger
         # Process privileges
-        privs = cred.getElementsByTagName("privileges")[0]
         rlist = Rights()
-        for priv in privs.getElementsByTagName("privilege"):
-            kind = getTextNode(priv, "name")
-            deleg = str2bool(getTextNode(priv, "can_delegate"))
-            if kind == '*':
-                # Convert * into the default privileges for the credential's type
-                # Each inherits the delegatability from the * above
-                _ , type = urn_to_hrn(self.gidObject.get_urn())
-                rl = determine_rights(type, self.gidObject.get_urn())
-                for r in rl.rights:
-                    r.delegate = deleg
-                    rlist.add(r)
-            else:
-                rlist.add(Right(kind.strip(), deleg))
+        priv_nodes = cred.getElementsByTagName("privileges")
+        if len(priv_nodes) > 0:
+            privs = priv_nodes[0]
+            for priv in privs.getElementsByTagName("privilege"):
+                kind = getTextNode(priv, "name")
+                deleg = str2bool(getTextNode(priv, "can_delegate"))
+                if kind == '*':
+                    # Convert * into the default privileges for the credential's type
+                    # Each inherits the delegatability from the * above
+                    _ , type = urn_to_hrn(self.gidObject.get_urn())
+                    rl = determine_rights(type, self.gidObject.get_urn())
+                    for r in rl.rights:
+                        r.delegate = deleg
+                        rlist.add(r)
+                else:
+                    rlist.add(Right(kind.strip(), deleg))
         self.set_privileges(rlist)
 
 
@@ -720,13 +767,15 @@ class Credential(object):
         parent = cred.getElementsByTagName("parent")
         if len(parent) > 0:
             parent_doc = parent[0].getElementsByTagName("credential")[0]
-            parent_xml = parent_doc.toxml()
+            parent_xml = parent_doc.toxml("utf-8")
+            if parent_xml is None or parent_xml.strip() == "":
+                raise CredentialNotVerifiable("Malformed XML: Had parent tag but it is empty")
             self.parent = Credential(string=parent_xml)
             self.updateRefID()
 
         # Assign the signatures to the credentials
         for sig in sigs:
-            Sig = Signature(string=sig.toxml())
+            Sig = Signature(string=sig.toxml("utf-8"))
 
             for cur_cred in self.get_credential_list():
                 if cur_cred.get_refid() == Sig.get_refid():
@@ -891,6 +940,10 @@ class Credential(object):
     def verify_issuer(self, trusted_gids):
         root_cred = self.get_credential_list()[-1]
         root_target_gid = root_cred.get_gid_object()
+        if root_cred.get_signature() is None:
+            # malformed
+            raise CredentialNotVerifiable("Could not verify credential owned by %s for object %s. Cred has no signature" % (self.gidCaller.get_urn(), self.gidObject.get_urn()))
+
         root_cred_signer = root_cred.get_signature().get_issuer_gid()
 
         # Case 1:
@@ -1051,13 +1104,13 @@ class Credential(object):
     # only informative
     def get_filename(self):
         return getattr(self,'filename',None)
-    
+
     def actual_caller_hrn (self):
         """a helper method used by some API calls like e.g. Allocate
         to try and find out who really is the original caller
-        
+
         This admittedly is a bit of a hack, please USE IN LAST RESORT
-        
+
         This code uses a heuristic to identify a delegated credential
 
         A first known restriction if for traffic that gets through a slice manager
@@ -1069,7 +1122,7 @@ class Credential(object):
         subject_hrn = self.get_gid_object().get_hrn()
         # if we find that the caller_hrn is an immediate descendant of the issuer, then
         # this seems to be a 'regular' credential
-        if caller_hrn.startswith(issuer_hrn): 
+        if caller_hrn.startswith(issuer_hrn):
             actual_caller_hrn=caller_hrn
         # else this looks like a delegated credential, and the real caller is the issuer
         else:
@@ -1077,7 +1130,7 @@ class Credential(object):
         logger.info("actual_caller_hrn: caller_hrn=%s, issuer_hrn=%s, returning %s"
                     %(caller_hrn,issuer_hrn,actual_caller_hrn))
         return actual_caller_hrn
-            
+
     ##
     # Dump the contents of a credential to stdout in human-readable format
     #
@@ -1085,8 +1138,8 @@ class Credential(object):
     def dump (self, *args, **kwargs):
         print self.dump_string(*args, **kwargs)
 
-    # show_xml is ignored
-    def dump_string(self, dump_parents=False, show_xml=None):
+    # SFA code ignores show_xml and disables printing the cred xml
+    def dump_string(self, dump_parents=False, show_xml=False):
         result=""
         result += "CREDENTIAL %s\n" % self.pretty_subject()
         filename=self.get_filename()
@@ -1095,18 +1148,18 @@ class Credential(object):
         if privileges:
             result += "      privs: %s\n" % privileges.save_to_string()
         else:
-            result += "      privs: \n" 
+            result += "      privs: \n"
         gidCaller = self.get_gid_caller()
         if gidCaller:
             result += "  gidCaller:\n"
             result += gidCaller.dump_string(8, dump_parents)
 
         if self.get_signature():
-            print "  gidIssuer:"
-            self.get_signature().get_issuer_gid().dump(8, dump_parents)
+            result += "  gidIssuer:\n"
+            result += self.get_signature().get_issuer_gid().dump_string(8, dump_parents)
 
         if self.expiration:
-            print "  expiration:", self.expiration.strftime(SFATIME_FORMAT)
+            result += "  expiration: " + self.expiration.strftime(SFATIME_FORMAT) + "\n"
 
         gidObject = self.get_gid_object()
         if gidObject:
@@ -1117,4 +1170,16 @@ class Credential(object):
             result += "\nPARENT"
             result += self.parent.dump_string(True)
 
+        if show_xml and HAVELXML:
+            try:
+                tree = etree.parse(StringIO(self.xml))
+                aside = etree.tostring(tree, pretty_print=True)
+                result += "\nXML:\n\n"
+                result += aside
+                result += "\nEnd XML\n"
+            except:
+                import traceback
+                print "exc. Credential.dump_string / XML"
+                traceback.print_exc()
+
         return result
index 61d76a6..361a243 100644 (file)
@@ -48,9 +48,15 @@ class _SfaLogger:
         try:
             handler=logging.handlers.RotatingFileHandler(logfile,maxBytes=1000000, backupCount=5) 
         except IOError:
-            # This is usually a permissions error becaue the file is
+            # This is usually a permissions error because the file is
             # owned by root, but httpd is trying to access it.
-            tmplogfile=os.getenv("TMPDIR", "/tmp") + os.path.sep + os.path.basename(logfile)
+            tmplogfile=os.path.join(os.getenv("TMPDIR", os.getenv("TMP", os.path.normpath("/tmp"))), os.path.basename(logfile))
+            tmplogfile = os.path.normpath(tmplogfile)
+
+            tmpdir = os.path.dirname(tmplogfile)
+            if tmpdir and tmpdir != "" and not os.path.exists(tmpdir):
+                os.makedirs(tmpdir)
+
             # In strange uses, 2 users on same machine might use same code,
             # meaning they would clobber each others files
             # We could (a) rename the tmplogfile, or (b)