check xrns arg at beginning of method
[sfa.git] / sfa / trust / auth.py
index 0c03279..4eedb09 100644 (file)
@@ -4,10 +4,11 @@
 import sys
 
 from sfa.util.faults import InsufficientRights, MissingCallerGID, MissingTrustedRoots, PermissionError, \
 import sys
 
 from sfa.util.faults import InsufficientRights, MissingCallerGID, MissingTrustedRoots, PermissionError, \
-    BadRequestHash, ConnectionKeyGIDMismatch, SfaPermissionDenied
+    BadRequestHash, ConnectionKeyGIDMismatch, SfaPermissionDenied, CredentialNotVerifiable, Forbidden, \
+    BadArgs
 from sfa.util.sfalogging import logger
 from sfa.util.config import Config
 from sfa.util.sfalogging import logger
 from sfa.util.config import Config
-from sfa.util.xrn import get_authority
+from sfa.util.xrn import Xrn, get_authority
 
 from sfa.trust.gid import GID
 from sfa.trust.rights import Rights
 
 from sfa.trust.gid import GID
 from sfa.trust.rights import Rights
@@ -34,30 +35,56 @@ class Auth:
         self.trusted_cert_list = TrustedRoots(self.config.get_trustedroots_dir()).get_list()
         self.trusted_cert_file_list = TrustedRoots(self.config.get_trustedroots_dir()).get_file_list()
 
         self.trusted_cert_list = TrustedRoots(self.config.get_trustedroots_dir()).get_list()
         self.trusted_cert_file_list = TrustedRoots(self.config.get_trustedroots_dir()).get_file_list()
 
+    def checkCredentials(self, creds, operation, xrns=[], check_sliver_callback=None):
+        # if xrns are specified they cannot be None or empty string
+        if xrns:
+            for xrn in xrns:
+                if not xrn:
+                    raise BadArgs("Invalid urn or hrn: %s" % hrn)
+
         
         
-        
-    def checkCredentials(self, creds, operation, hrn = None):
+        if not isinstance(xrns, list):
+            xrns = [xrns]
+
+        slice_xrns = Xrn.filter_type(xrns, 'slice')
+        sliver_xrns = Xrn.filter_type(xrns, 'sliver')
+
+        # we are not able to validate slivers in the traditional way so 
+        # we make sure not to include sliver urns/hrns in the core validation loop
+        hrns = [Xrn(xrn).hrn for xrn in xrns if xrn not in sliver_xrns] 
         valid = []
         if not isinstance(creds, list):
             creds = [creds]
         valid = []
         if not isinstance(creds, list):
             creds = [creds]
-        logger.debug("Auth.checkCredentials with %d creds"%len(creds))
+        logger.debug("Auth.checkCredentials with %d creds on hrns=%s"%(len(creds),hrns))
+        # won't work if either creds or hrns is empty - let's make it more explicit
+        if not creds: raise Forbidden("no credential provided")
+        if not hrns: hrns = [None]
         for cred in creds:
         for cred in creds:
-            try:
-                self.check(cred, operation, hrn)
-                valid.append(cred)
-            except:
-                cred_obj=Credential(string=cred)
-                logger.debug("failed to validate credential - dump=%s"%cred_obj.dump_string(dump_parents=True))
-                error = sys.exc_info()[:2]
-                continue
-            
+            for hrn in hrns:
+                try:
+                    self.check(cred, operation, hrn)
+                    valid.append(cred)
+                except:
+                    cred_obj=Credential(cred=cred)
+                    logger.debug("failed to validate credential - dump=%s"%cred_obj.dump_string(dump_parents=True))
+                    error = sys.exc_info()[:2]
+                    continue
+        
+        # make sure all sliver xrns are validated against the valid credentials
+        if sliver_xrns:
+            if not check_sliver_callback:
+                msg = "sliver verification callback method not found." 
+                msg += " Unable to validate sliver xrns: %s" % sliver_xrns
+                raise Forbidden(msg)
+            check_sliver_callback(valid, sliver_xrns)
+                
         if not len(valid):
         if not len(valid):
-            raise InsufficientRights('Access denied: %s -- %s' % (error[0],error[1]))
+            raise Forbidden("Invalid credential")
         
         return valid
         
         
         
         return valid
         
         
-    def check(self, cred, operation, hrn = None):
+    def check(self, credential, operation, hrn = None):
         """
         Check the credential against the peer cert (callerGID included 
         in the credential matches the caller that is connected to the 
         """
         Check the credential against the peer cert (callerGID included 
         in the credential matches the caller that is connected to the 
@@ -65,7 +92,13 @@ class Auth:
         trusted cert and check if the credential is allowed to perform 
         the specified operation.    
         """
         trusted cert and check if the credential is allowed to perform 
         the specified operation.    
         """
-        self.client_cred = Credential(string = cred)
+        cred = Credential(cred=credential)    
+        self.client_cred = cred
+        logger.debug("Auth.check: handling hrn=%s and credential=%s"%\
+                         (hrn,cred.get_summary_tostring()))
+
+        if cred.type not in ['geni_sfa']:
+            raise CredentialNotVerifiable(cred.type, "%s not supported" % cred.type)
         self.client_gid = self.client_cred.get_gid_caller()
         self.object_gid = self.client_cred.get_gid_object()
         
         self.client_gid = self.client_cred.get_gid_caller()
         self.object_gid = self.client_cred.get_gid_object()