renamed sfaticket from util/ to trust/
[sfa.git] / sfa / trust / auth.py
index 50957ea..c3dbc46 100644 (file)
@@ -1,21 +1,18 @@
 #
 # SfaAPI authentication 
 #
-### $Id$
-### $URL$
-#
-
-import time
+import sys
 
+from sfa.trust.certificate import Keypair, Certificate
 from sfa.trust.credential import Credential
-from sfa.trust.trustedroot import TrustedRootList
-from sfa.trust.rights import RightList
+from sfa.trust.trustedroots import TrustedRoots
 from sfa.util.faults import *
 from sfa.trust.hierarchy import Hierarchy
 from sfa.util.config import *
-from sfa.util.namespace import *
-from sfa.trust.gid import GID
-from sfa.util.sfaticket import *
+from sfa.util.xrn import get_authority
+from sfa.trust.sfaticket import SfaTicket
+
+from sfa.util.sfalogging import logger
 
 class Auth:
     """
@@ -30,14 +27,38 @@ class Auth:
         self.load_trusted_certs()
 
     def load_trusted_certs(self):
-        self.trusted_cert_list = TrustedRootList(self.config.get_trustedroots_dir()).get_list()
+        self.trusted_cert_list = TrustedRoots(self.config.get_trustedroots_dir()).get_list()
+        self.trusted_cert_file_list = TrustedRoots(self.config.get_trustedroots_dir()).get_file_list()
+
+        
+        
+    def checkCredentials(self, creds, operation, hrn = None):
+        valid = []
+        if not isinstance(creds, list):
+            creds = [creds]
+        logger.debug("Auth.checkCredentials with %d creds"%len(creds))
+        for cred in creds:
+            try:
+                self.check(cred, operation, hrn)
+                valid.append(cred)
+            except:
+                cred_obj=Credential(string=cred)
+                logger.debug("failed to validate credential - dump=%s"%cred_obj.dump_string(dump_parents=True))
+                error = sys.exc_info()[:2]
+                continue
+            
+        if not len(valid):
+            raise InsufficientRights('Access denied: %s -- %s' % (error[0],error[1]))
+        
+        return valid
+        
         
     def check(self, cred, operation, hrn = None):
         """
         Check the credential against the peer cert (callerGID included 
         in the credential matches the caller that is connected to the 
         HTTPS connection, check if the credential was signed by a 
-        trusted cert and check if the credential is allowd to perform 
+        trusted cert and check if the credential is allowed to perform 
         the specified operation.    
         """
         self.client_cred = Credential(string = cred)
@@ -58,21 +79,18 @@ class Auth:
                 raise InsufficientRights(operation)
 
         if self.trusted_cert_list:
-            self.client_cred.verify_chain(self.trusted_cert_list)
-            if self.client_gid:
-                self.client_gid.verify_chain(self.trusted_cert_list)
-            if self.object_gid:
-                self.object_gid.verify_chain(self.trusted_cert_list)
+            self.client_cred.verify(self.trusted_cert_file_list, self.config.SFA_CREDENTIAL_SCHEMA)
         else:
            raise MissingTrustedRoots(self.config.get_trustedroots_dir())
-
+       
         # Make sure the credential's target matches the specified hrn. 
         # This check does not apply to trusted peers 
-        trusted_peers = [gid.get_hrn() for gid in self.trusted_cert_list.get_list()]
+        trusted_peers = [gid.get_hrn() for gid in self.trusted_cert_list]
         if hrn and self.client_gid.get_hrn() not in trusted_peers:
-            if not hrn == self.object_gid.get_hrn():
+            target_hrn = self.object_gid.get_hrn()
+            if not hrn == target_hrn:
                 raise PermissionError("Target hrn: %s doesn't match specified hrn: %s " % \
-                                       (self.object_gid.get_hrn(), hrn) )       
+                                       (target_hrn, hrn) )       
         return True
 
     def check_ticket(self, ticket):
@@ -107,13 +125,7 @@ class Auth:
 
     def validateCred(self, cred):
         if self.trusted_cert_list:
-            cred.verify_chain(self.trusted_cert_list)
-            caller_gid = cred.get_gid_caller()
-            object_gid = cred.get_gid_object()
-            if caller_gid:
-                caller_gid.verify_chain(self.trusted_cert_list)
-            if object_gid:
-                object_gid.verify_chain(self.trusted_cert_list)
+            cred.verify(self.trusted_cert_file_list)
 
     def authenticateGid(self, gidStr, argList, requestHash=None):
         gid = GID(string = gidStr)
@@ -226,9 +238,10 @@ class Auth:
         verify_cancreate_credential()
         """
 
-        rl = RightList()
+        rl = Rights()
         type = record['type']
 
+
         if type=="slice":
             researchers = record.get("researcher", [])
             pis = record.get("PI", [])
@@ -243,11 +256,15 @@ class Auth:
             pis = record.get("PI", [])
             operators = record.get("operator", [])
             if (caller_hrn == self.config.SFA_INTERFACE_HRN):
-                rl.add("authority,sa,ma",)
+                rl.add("authority")
+                rl.add("sa")
+                rl.add("ma")
             if (caller_hrn in pis):
-                rl.add("authority,sa")
+                rl.add("authority")
+                rl.add("sa")
             if (caller_hrn in operators):
-                rl.add("authority,ma")
+                rl.add("authority")
+                rl.add("ma")
 
         elif type == "user":
             rl.add("refresh")
@@ -285,3 +302,22 @@ class Auth:
 
     def get_authority(self, hrn):
         return get_authority(hrn)
+
+    def filter_creds_by_caller(self, creds, caller_hrn_list):
+        """
+        Returns a list of creds who's gid caller matches the 
+        specified caller hrn
+        """
+        if not isinstance(creds, list):
+            creds = [creds]
+        creds = []
+        if not isinistance(caller_hrn_list, list):
+            caller_hrn_list = [caller_hrn_list]
+        for cred in creds:
+            try:
+                tmp_cred = Credential(string=cred)
+                if tmp_cred.get_gid_caller().get_hrn() in [caller_hrn_list]:
+                    creds.append(cred)
+            except: pass
+        return creds
+