really fixed the redundant logging issue this time.
[sfa.git] / sfa / trust / credential.py
index 70544ff..3e1fbcc 100644 (file)
 # Credentials are signed XML files that assign a subject gid privileges to an object gid
 ##
 
-### $Id$
-### $URL$
-
 import os
 import datetime
-from xml.dom.minidom import Document, parseString
 from tempfile import mkstemp
-from sfa.trust.certificate import Keypair
-from sfa.trust.credential_legacy import CredentialLegacy
-from sfa.trust.rights import *
-from sfa.trust.gid import *
-from sfa.util.faults import *
+import dateutil.parser
+from StringIO import StringIO 
+from xml.dom.minidom import Document, parseString
+from lxml import etree
 
+from sfa.util.faults import *
 from sfa.util.sfalogging import logger
-from dateutil.parser import parse
-
-
+from sfa.trust.certificate import Keypair
+from sfa.trust.credential_legacy import CredentialLegacy
+from sfa.trust.rights import Right, Rights
+from sfa.trust.gid import GID
+from sfa.util.xrn import urn_to_hrn
 
-# Two years, in seconds 
-DEFAULT_CREDENTIAL_LIFETIME = 60 * 60 * 24 * 365 * 2
+# 2 weeks, in seconds 
+DEFAULT_CREDENTIAL_LIFETIME = 86400 * 14
 
 
 # TODO:
@@ -83,7 +81,7 @@ signature_template = \
 # Convert a string into a bool
 
 def str2bool(str):
-    if str.lower() in ['yes','true','1']:
+    if str.lower() in ['true','1']:
         return True
     return False
 
@@ -171,6 +169,21 @@ class Signature(object):
 # not be changed else the signature is no longer valid.  So, once
 # you have loaded an existing signed credential, do not call encode() or sign() on it.
 
+def filter_creds_by_caller(creds, caller_hrn):
+        """
+        Returns a list of creds who's gid caller matches the
+        specified caller hrn
+        """
+        if not isinstance(creds, list): creds = [creds]
+        caller_creds = []
+        for cred in creds:
+            try:
+                tmp_cred = Credential(string=cred)
+                if tmp_cred.get_gid_caller().get_hrn() == caller_hrn:
+                    caller_creds.append(cred)
+            except: pass
+        return caller_creds
+
 class Credential(object):
 
     ##
@@ -201,6 +214,7 @@ class Credential(object):
                 str = string
             elif filename:
                 str = file(filename).read()
+                self.filename=filename
                 
             if str.strip().startswith("-----"):
                 self.legacy = CredentialLegacy(False,string=str)
@@ -242,10 +256,9 @@ class Credential(object):
         self.gidObject = legacy.get_gid_object()
         lifetime = legacy.get_lifetime()
         if not lifetime:
-            # Default to two years
-            self.set_lifetime(DEFAULT_CREDENTIAL_LIFETIME)
+            self.set_expiration(datetime.datetime.utcnow() + datetime.timedelta(seconds=DEFAULT_CREDENTIAL_LIFETIME))
         else:
-            self.set_lifetime(int(lifetime))
+            self.set_expiration(int(lifetime))
         self.lifeTime = legacy.get_lifetime()
         self.set_privileges(legacy.get_privileges())
         self.get_privileges().delegate_all_privileges(legacy.get_delegate())
@@ -300,43 +313,45 @@ class Credential(object):
             self.decode()
         return self.gidObject
 
+
+            
     ##
-    # set the lifetime of this credential
-    #
-    # @param lifetime lifetime of credential
-    # . if lifeTime is a datetime object, it is used for the expiration time
-    # . if lifeTime is an integer value, it is considered the number of seconds
-    #   remaining before expiration
-
-    def set_lifetime(self, lifeTime):
-        if isinstance(lifeTime, int):
-            self.expiration = datetime.timedelta(seconds=lifeTime) + datetime.datetime.utcnow()
+    # Expiration: an absolute UTC time of expiration (as either an int or datetime)
+    # 
+    def set_expiration(self, expiration):
+        if isinstance(expiration, int):
+            self.expiration = datetime.datetime.fromtimestamp(expiration)
         else:
-            self.expiration = lifeTime
+            self.expiration = expiration
+            
 
     ##
     # get the lifetime of the credential (in datetime format)
 
-    def get_lifetime(self):
+    def get_expiration(self):
         if not self.expiration:
             self.decode()
         return self.expiration
 
+    ##
+    # For legacy sake
+    def get_lifetime(self):
+        return self.get_expiration()
  
     ##
     # set the privileges
     #
-    # @param privs either a comma-separated list of privileges of a RightList object
+    # @param privs either a comma-separated list of privileges of a Rights object
 
     def set_privileges(self, privs):
         if isinstance(privs, str):
-            self.privileges = RightList(string = privs)
+            self.privileges = Rights(string = privs)
         else:
             self.privileges = privs
         
 
     ##
-    # return the privileges as a RightList object
+    # return the privileges as a Rights object
 
     def get_privileges(self):
         if not self.privileges:
@@ -384,7 +399,7 @@ class Credential(object):
         append_sub(doc, cred, "target_urn", self.gidObject.get_urn())
         append_sub(doc, cred, "uuid", "")
         if not self.expiration:
-            self.set_lifetime(DEFAULT_CREDENTIAL_LIFETIME)
+            self.set_expiration(datetime.datetime.utcnow() + datetime.timedelta(seconds=DEFAULT_CREDENTIAL_LIFETIME))
         self.expiration = self.expiration.replace(microsecond=0)
         append_sub(doc, cred, "expires", self.expiration.isoformat())
         privileges = doc.createElement("privileges")
@@ -437,6 +452,7 @@ class Credential(object):
             f = open(filename, "w")
         f.write(self.xml)
         f.close()
+        self.filename=filename
 
     def save_to_string(self, save_parents=True):
         if not self.xml:
@@ -567,14 +583,14 @@ class Credential(object):
         
 
         self.set_refid(cred.getAttribute("xml:id"))
-        self.set_lifetime(parse(getTextNode(cred, "expires")))
+        self.set_expiration(dateutil.parser.parse(getTextNode(cred, "expires")))
         self.gidCaller = GID(string=getTextNode(cred, "owner_gid"))
         self.gidObject = GID(string=getTextNode(cred, "target_gid"))   
 
 
         # Process privileges
         privs = cred.getElementsByTagName("privileges")[0]
-        rlist = RightList()
+        rlist = Rights()
         for priv in privs.getElementsByTagName("privilege"):
             kind = getTextNode(priv, "name")
             deleg = str2bool(getTextNode(priv, "can_delegate"))
@@ -632,11 +648,24 @@ class Credential(object):
     #   must be done elsewhere
     #
     # @param trusted_certs: The certificates of trusted CA certificates
-    def verify(self, trusted_certs):
+    # @param schema: The RelaxNG schema to validate the credential against 
+    def verify(self, trusted_certs, schema=None):
         if not self.xml:
             self.decode()        
+        
+        # validate against RelaxNG schema
+        if not self.legacy:
+            if schema and os.path.exists(schema):
+                tree = etree.parse(StringIO(self.xml))
+                schema_doc = etree.parse(schema)
+                xmlschema = etree.XMLSchema(schema_doc)
+                if not xmlschema.validate(tree):
+                    error = xmlschema.error_log.last_error
+                    message = "%s (line %s)" % (error.message, error.line)
+                    raise CredentialNotVerifiable(message) 
+            
 
-#        trusted_cert_objects = [GID(filename=f) for f in trusted_certs]
+#       trusted_cert_objects = [GID(filename=f) for f in trusted_certs]
         trusted_cert_objects = []
         ok_trusted_certs = []
         for f in trusted_certs:
@@ -646,7 +675,7 @@ class Credential(object):
                 trusted_cert_objects.append(GID(filename=f))
                 ok_trusted_certs.append(f)
             except Exception, exc:
-                logger.error("Failed to load trusted cert from %s: %r", f, exc)
+                logger.error("Failed to load trusted cert from %s: %r"%( f, exc))
         trusted_certs = ok_trusted_certs
 
         # Use legacy verification if this is a legacy credential
@@ -657,9 +686,10 @@ class Credential(object):
             if self.legacy.object_gid:
                 self.legacy.object_gid.verify_chain(trusted_cert_objects)
             return True
+
         
         # make sure it is not expired
-        if self.get_lifetime() < datetime.datetime.utcnow():
+        if self.get_expiration() < datetime.datetime.utcnow():
             raise CredentialNotVerifiable("Credential expired at %s" % self.expiration.isoformat())
 
         # Verify the signatures
@@ -767,7 +797,7 @@ class Credential(object):
             raise CredentialNotVerifiable("Target gid not equal between parent and child")
 
         # make sure my expiry time is <= my parent's
-        if not parent_cred.get_lifetime() >= self.get_lifetime():
+        if not parent_cred.get_expiration() >= self.get_expiration():
             raise CredentialNotVerifiable("Delegated credential expires after parent")
 
         # make sure my signer is the parent's caller
@@ -780,58 +810,67 @@ class Credential(object):
             parent_cred.verify_parent(parent_cred.parent)
 
 
-    def delegate(self, delegee_gid, keyfile):
+    def delegate(self, delegee_gidfile, caller_keyfile, caller_gidfile):
         """
         Return a delegated copy of this credential, delegated to the 
         specified gid's user.    
         """
         # get the gid of the object we are delegating
         object_gid = self.get_gid_object()
-        object_hrn = self.get_hrn()        
+        object_hrn = object_gid.get_hrn()        
  
         # the hrn of the user who will be delegated to
-        if isinstance(delegee_gid, str):
-            delegee_gid = GID(string=records[0]['gid'])
+        delegee_gid = GID(filename=delegee_gidfile)
         delegee_hrn = delegee_gid.get_hrn()
-   
-        user_key = Keypair(filename=keyfile)
-        user_hrn = self.get_gid_caller().get_hrn()
+  
+        #user_key = Keypair(filename=keyfile)
+        #user_hrn = self.get_gid_caller().get_hrn()
         subject_string = "%s delegated to %s" % (object_hrn, delegee_hrn)
         dcred = Credential(subject=subject_string)
         dcred.set_gid_caller(delegee_gid)
         dcred.set_gid_object(object_gid)
-        privs = self.get_privileges()
+        dcred.set_parent(self)
+        dcred.set_expiration(self.get_expiration())
         dcred.set_privileges(self.get_privileges())
         dcred.get_privileges().delegate_all_privileges(True)
-        dcred.set_pubkey(object_gid.get_pubkey())
-        dcred.set_issuer(user_key, user_hrn)
-        dcred.set_parent(self)
+        #dcred.set_issuer_keys(keyfile, delegee_gidfile)
+        dcred.set_issuer_keys(caller_keyfile, caller_gidfile)
         dcred.encode()
         dcred.sign()
 
         return dcred 
-    ##
-    # Dump the contents of a credential to stdout in human-readable format
-    #
-    # @param dump_parents If true, also dump the parent certificates
-
-    def dump(self, dump_parents=False):
-        print "CREDENTIAL", self.get_subject()
 
-        print "      privs:", self.get_privileges().save_to_string()
+    # only informative
+    def get_filename(self):
+        return getattr(self,'filename',None)
 
-        print "  gidCaller:"
+    # @param dump_parents If true, also dump the parent certificates
+    def dump (self, *args, **kwargs):
+        print self.dump_string(*args, **kwargs)
+
+    def dump_string(self, dump_parents=False):
+        result=""
+        result += "CREDENTIAL %s\n" % self.get_subject() 
+        filename=self.get_filename()
+        if filename: result += "Filename %s\n"%filename
+        result += "      privs: %s\n" % self.get_privileges().save_to_string()
         gidCaller = self.get_gid_caller()
         if gidCaller:
-            gidCaller.dump(8, dump_parents)
+            result += "  gidCaller:\n"
+            result += gidCaller.dump_string(8, dump_parents)
+
+        if self.get_signature():
+            print "  gidIssuer:"
+            self.get_signature().get_issuer_gid().dump(8, dump_parents)
 
-        print "  gidObject:"
         gidObject = self.get_gid_object()
         if gidObject:
-            gidObject.dump(8, dump_parents)
-
+            result += "  gidObject:\n"
+            result += gidObject.dump_string(8, dump_parents)
 
         if self.parent and dump_parents:
-            print "PARENT",
-            self.parent.dump_parents()
+            result += "\nPARENT"
+            result += self.parent.dump(True)
+
+        return result