Merge branch 'master' of ssh://git.planet-lab.org/git/sfa
[sfa.git] / sfa / trust / auth.py
index 470dc5e..9cb905d 100644 (file)
@@ -6,6 +6,7 @@
 #
 
 
+from sfa.trust.certificate import Keypair, Certificate
 from sfa.trust.credential import Credential
 from sfa.trust.trustedroot import TrustedRootList
 from sfa.util.faults import *
@@ -13,8 +14,6 @@ from sfa.trust.hierarchy import Hierarchy
 from sfa.util.config import *
 from sfa.util.namespace import *
 from sfa.util.sfaticket import *
-from sfa.util.sfalogging import logger
-
 import sys
 
 class Auth:
@@ -37,6 +36,8 @@ class Auth:
         
     def checkCredentials(self, creds, operation, hrn = None):
         valid = []
+        if not isinstance(creds, list):
+            creds = [creds]
         for cred in creds:
             try:
                 self.check(cred, operation, hrn)
@@ -85,9 +86,10 @@ class Auth:
         # This check does not apply to trusted peers 
         trusted_peers = [gid.get_hrn() for gid in self.trusted_cert_list]
         if hrn and self.client_gid.get_hrn() not in trusted_peers:
-            if not hrn == self.object_gid.get_hrn():
+            target_hrn = self.object_gid.get_hrn()
+            if not hrn == target_hrn:
                 raise PermissionError("Target hrn: %s doesn't match specified hrn: %s " % \
-                                       (self.object_gid.get_hrn(), hrn) )       
+                                       (target_hrn, hrn) )       
         return True
 
     def check_ticket(self, ticket):
@@ -299,3 +301,20 @@ class Auth:
 
     def get_authority(self, hrn):
         return get_authority(hrn)
+
+    def filter_creds_by_caller(self, creds, caller_hrn):
+        """
+        Returns a list of creds who's gid caller matches the 
+        specified caller hrn
+        """
+        if not isinstance(creds, list):
+            creds = [creds]
+        creds = []
+        for cred in creds:
+            try:
+                tmp_cred = Credential(string=cred)
+                if tmp_cred.get_gid_caller().get_hrn() == caller_hrn:
+                    creds.append(cred)
+            except: pass
+        return creds
+