do not depend on types.StringTypes anymore
[sfa.git] / sfa / trust / auth.py
index 2e9ada0..512c58b 100644 (file)
@@ -3,10 +3,12 @@
 #
 import sys
 
 #
 import sys
 
-from sfa.util.faults import InsufficientRights, MissingCallerGID, MissingTrustedRoots, PermissionError, \
-    BadRequestHash, ConnectionKeyGIDMismatch, SfaPermissionDenied, CredentialNotVerifiable, Forbidden, \
-    BadArgs
+from sfa.util.faults import InsufficientRights, MissingCallerGID, \
+    MissingTrustedRoots, PermissionError, BadRequestHash, \
+    ConnectionKeyGIDMismatch, SfaPermissionDenied, CredentialNotVerifiable, \
+    Forbidden, BadArgs
 from sfa.util.sfalogging import logger
 from sfa.util.sfalogging import logger
+from sfa.util.py23 import StringType
 from sfa.util.config import Config
 from sfa.util.xrn import Xrn, get_authority
 
 from sfa.util.config import Config
 from sfa.util.xrn import Xrn, get_authority
 
@@ -33,18 +35,43 @@ class Auth:
         self.load_trusted_certs()
 
     def load_trusted_certs(self):
         self.load_trusted_certs()
 
     def load_trusted_certs(self):
-        self.trusted_cert_list = TrustedRoots(self.config.get_trustedroots_dir()).get_list()
-        self.trusted_cert_file_list = TrustedRoots(self.config.get_trustedroots_dir()).get_file_list()
+        self.trusted_cert_list = \
+            TrustedRoots(self.config.get_trustedroots_dir()).get_list()
+        self.trusted_cert_file_list = \
+            TrustedRoots(self.config.get_trustedroots_dir()).get_file_list()
+
+    # this convenience methods extracts speaking_for_xrn
+    # from the passed options using 'geni_speaking_for'
+    def checkCredentialsSpeaksFor (self, *args, **kwds):
+        if 'options' not in kwds:
+            logger.error ("checkCredentialsSpeaksFor was not passed options=options")
+            return
+        # remove the options arg
+        options = kwds['options']; del kwds['options']
+        # compute the speaking_for_xrn arg and pass it to checkCredentials
+        if options is None: speaking_for_xrn = None
+        else:               speaking_for_xrn = options.get('geni_speaking_for', None)
+        kwds['speaking_for_xrn'] = speaking_for_xrn
+        return self.checkCredentials(*args, **kwds)
 
     # do not use mutable as default argument 
     # http://docs.python-guide.org/en/latest/writing/gotchas/#mutable-default-arguments
     def checkCredentials(self, creds, operation, xrns=None, 
 
     # do not use mutable as default argument 
     # http://docs.python-guide.org/en/latest/writing/gotchas/#mutable-default-arguments
     def checkCredentials(self, creds, operation, xrns=None, 
-                         check_sliver_callback=None, speaking_for_hrn=None):
-        if xrns is None: xrns=[]
+                         check_sliver_callback=None, 
+                         speaking_for_xrn=None):
+        if xrns is None: xrns = []
+        error = (None, None)
         def log_invalid_cred(cred):
         def log_invalid_cred(cred):
-            cred_obj=Credential(string=cred)
-            logger.debug("failed to validate credential - dump=%s"%cred_obj.dump_string(dump_parents=True))
-            error = sys.exc_info()[:2]
+            if not isinstance (cred, StringType):
+                logger.info("cannot validate credential %s - expecting a string"%cred)
+                error = ('TypeMismatch',
+                         "checkCredentials: expected a string, received {} -- {}"
+                         .format(type(cred), cred))
+            else:
+                cred_obj = Credential(string=cred)
+                logger.info("failed to validate credential - dump=%s"%\
+                            cred_obj.dump_string(dump_parents=True))
+                error = sys.exc_info()[:2]
             return error
 
         # if xrns are specified they cannot be None or empty string
             return error
 
         # if xrns are specified they cannot be None or empty string
@@ -57,7 +84,7 @@ class Auth:
         if not isinstance(xrns, list):
             xrns = [xrns]
 
         if not isinstance(xrns, list):
             xrns = [xrns]
 
-        slice_xrns = Xrn.filter_type(xrns, 'slice')
+        slice_xrns  = Xrn.filter_type(xrns, 'slice')
         sliver_xrns = Xrn.filter_type(xrns, 'sliver')
 
         # we are not able to validate slivers in the traditional way so 
         sliver_xrns = Xrn.filter_type(xrns, 'sliver')
 
         # we are not able to validate slivers in the traditional way so 
@@ -70,12 +97,9 @@ class Auth:
         # won't work if either creds or hrns is empty - let's make it more explicit
         if not creds: raise Forbidden("no credential provided")
         if not hrns: hrns = [None]
         # won't work if either creds or hrns is empty - let's make it more explicit
         if not creds: raise Forbidden("no credential provided")
         if not hrns: hrns = [None]
-        error=[None,None]
 
 
-        # if speaks for gid matches caller cert then we've found a valid
-        # speaks for credential
-        speaks_for_gid = determine_speaks_for(logger, creds, self.peer_cert, \
-                                              options, self.trusted_cert_list)
+        speaks_for_gid = determine_speaks_for(logger, creds, self.peer_cert,
+                                              speaking_for_xrn, self.trusted_cert_list)
 
         if self.peer_cert and \
            not self.peer_cert.is_pubkey(speaks_for_gid.get_pubkey()):
 
         if self.peer_cert and \
            not self.peer_cert.is_pubkey(speaks_for_gid.get_pubkey()):
@@ -100,15 +124,11 @@ class Auth:
         if not len(valid):
             raise Forbidden("Invalid credential %s -- %s"%(error[0],error[1]))
         
         if not len(valid):
             raise Forbidden("Invalid credential %s -- %s"%(error[0],error[1]))
         
-        if speaking_for_hrn and not speaks_for_cred:
-            raise InsufficientRights('Access denied: "geni_speaking_for" option specified but no valid speaks for credential found: %s -- %s' % (error[0],error[1]))
-        
         return valid
         
         return valid
         
-        
     def check(self, credential, operation, hrn = None):
         """
     def check(self, credential, operation, hrn = None):
         """
-        Check the credential against the peer cert (callerGID included 
+        Check the credential against the peer cert (callerGID) included 
         in the credential matches the caller that is connected to the 
         HTTPS connection, check if the credential was signed by a 
         trusted cert and check if the credential is allowed to perform 
         in the credential matches the caller that is connected to the 
         HTTPS connection, check if the credential was signed by a 
         trusted cert and check if the credential is allowed to perform 
@@ -117,7 +137,7 @@ class Auth:
         cred = Credential(cred=credential)    
         self.client_cred = cred
         logger.debug("Auth.check: handling hrn=%s and credential=%s"%\
         cred = Credential(cred=credential)    
         self.client_cred = cred
         logger.debug("Auth.check: handling hrn=%s and credential=%s"%\
-                         (hrn,cred.get_summary_tostring()))
+                         (hrn,cred.pretty_cred()))
 
         if cred.type not in ['geni_sfa']:
             raise CredentialNotVerifiable(cred.type, "%s not supported" % cred.type)
 
         if cred.type not in ['geni_sfa']:
             raise CredentialNotVerifiable(cred.type, "%s not supported" % cred.type)
@@ -126,7 +146,7 @@ class Auth:
         
         # make sure the client_gid is not blank
         if not self.client_gid:
         
         # make sure the client_gid is not blank
         if not self.client_gid:
-            raise MissingCallerGID(self.client_cred.get_subject())
+            raise MissingCallerGID(self.client_cred.pretty_subject())
        
         # validate the client cert if it exists
         if self.peer_cert:
        
         # validate the client cert if it exists
         if self.peer_cert:
@@ -138,7 +158,8 @@ class Auth:
                 raise InsufficientRights(operation)
 
         if self.trusted_cert_list:
                 raise InsufficientRights(operation)
 
         if self.trusted_cert_list:
-            self.client_cred.verify(self.trusted_cert_file_list, self.config.SFA_CREDENTIAL_SCHEMA)
+            self.client_cred.verify(self.trusted_cert_file_list,
+                                    self.config.SFA_CREDENTIAL_SCHEMA)
         else:
            raise MissingTrustedRoots(self.config.get_trustedroots_dir())
        
         else:
            raise MissingTrustedRoots(self.config.get_trustedroots_dir())
        
@@ -154,7 +175,7 @@ class Auth:
 
     def check_ticket(self, ticket):
         """
 
     def check_ticket(self, ticket):
         """
-        Check if the tickt was signed by a trusted cert
+        Check if the ticket was signed by a trusted cert
         """
         if self.trusted_cert_list:
             client_ticket = SfaTicket(string=ticket)
         """
         if self.trusted_cert_list:
             client_ticket = SfaTicket(string=ticket)
@@ -301,13 +322,14 @@ class Auth:
         rl = Rights()
         type = reg_record.type
 
         rl = Rights()
         type = reg_record.type
 
-        logger.debug("entering determine_user_rights with record %s and caller_hrn %s"%(reg_record, caller_hrn))
+        logger.debug("entering determine_user_rights with record %s and caller_hrn %s"%\
+                     (reg_record, caller_hrn))
 
         if type == 'slice':
             # researchers in the slice are in the DB as-is
             researcher_hrns = [ user.hrn for user in reg_record.reg_researchers ]
             # locating PIs attached to that slice
 
         if type == 'slice':
             # researchers in the slice are in the DB as-is
             researcher_hrns = [ user.hrn for user in reg_record.reg_researchers ]
             # locating PIs attached to that slice
-            slice_pis=reg_record.get_pis()
+            slice_pis = reg_record.get_pis()
             pi_hrns = [ user.hrn for user in slice_pis ]
             if (caller_hrn in researcher_hrns + pi_hrns):
                 rl.add('refresh')
             pi_hrns = [ user.hrn for user in slice_pis ]
             if (caller_hrn in researcher_hrns + pi_hrns):
                 rl.add('refresh')