* tried to put some sense in the way things get logged, at least on server-side for now
[sfa.git] / sfa / trust / auth.py
index 3b0af02..9cb905d 100644 (file)
@@ -1,20 +1,20 @@
 #
 #
-# GeniAPI authentication 
+# SfaAPI authentication 
 #
 ### $Id$
 ### $URL$
 #
 
 #
 ### $Id$
 ### $URL$
 #
 
-import time
 
 
+from sfa.trust.certificate import Keypair, Certificate
 from sfa.trust.credential import Credential
 from sfa.trust.trustedroot import TrustedRootList
 from sfa.trust.credential import Credential
 from sfa.trust.trustedroot import TrustedRootList
-from sfa.trust.rights import RightList
 from sfa.util.faults import *
 from sfa.trust.hierarchy import Hierarchy
 from sfa.util.faults import *
 from sfa.trust.hierarchy import Hierarchy
-from sfa.util.genitable import GeniTable
 from sfa.util.config import *
 from sfa.util.config import *
-from sfa.util.misc import *
+from sfa.util.namespace import *
+from sfa.util.sfaticket import *
+import sys
 
 class Auth:
     """
 
 class Auth:
     """
@@ -24,12 +24,35 @@ class Auth:
     def __init__(self, peer_cert = None, config = None ):
         self.peer_cert = peer_cert
         self.hierarchy = Hierarchy()
     def __init__(self, peer_cert = None, config = None ):
         self.peer_cert = peer_cert
         self.hierarchy = Hierarchy()
-        self.trusted_cert_list = TrustedRootList().get_list() 
         if not config:
         if not config:
-            self.config = Config() 
-    
+            self.config = Config()
+        self.load_trusted_certs()
+
+    def load_trusted_certs(self):
+        self.trusted_cert_list = TrustedRootList(self.config.get_trustedroots_dir()).get_list()
+        self.trusted_cert_file_list = TrustedRootList(self.config.get_trustedroots_dir()).get_file_list()
 
 
-    def check(self, cred, operation):
+        
+        
+    def checkCredentials(self, creds, operation, hrn = None):
+        valid = []
+        if not isinstance(creds, list):
+            creds = [creds]
+        for cred in creds:
+            try:
+                self.check(cred, operation, hrn)
+                valid.append(cred)
+            except:
+                error = sys.exc_info()[:2]
+                continue
+            
+        if not len(valid):
+            raise InsufficientRights('Access denied: %s -- %s' % (error[0],error[1]))
+        
+        return valid
+        
+        
+    def check(self, cred, operation, hrn = None):
         """
         Check the credential against the peer cert (callerGID included 
         in the credential matches the caller that is connected to the 
         """
         Check the credential against the peer cert (callerGID included 
         in the credential matches the caller that is connected to the 
@@ -44,11 +67,10 @@ class Auth:
         # make sure the client_gid is not blank
         if not self.client_gid:
             raise MissingCallerGID(self.client_cred.get_subject())
         # make sure the client_gid is not blank
         if not self.client_gid:
             raise MissingCallerGID(self.client_cred.get_subject())
-
-        # make sure the client_gid matches client's certificate
-        peer_cert = self.peer_cert
-        if not peer_cert.is_pubkey(self.client_gid.get_pubkey()):
-            raise ConnectionKeyGIDMismatch(self.client_gid.get_subject())
+       
+        # validate the client cert if it exists
+        if self.peer_cert:
+            self.verifyPeerCert(self.peer_cert, self.client_gid)                   
 
         # make sure the client is allowed to perform the operation
         if operation:
 
         # make sure the client is allowed to perform the operation
         if operation:
@@ -56,14 +78,91 @@ class Auth:
                 raise InsufficientRights(operation)
 
         if self.trusted_cert_list:
                 raise InsufficientRights(operation)
 
         if self.trusted_cert_list:
-            self.client_cred.verify_chain(self.trusted_cert_list)
-            if self.client_gid:
-                self.client_gid.verify_chain(self.trusted_cert_list)
-            if self.object_gid:
-                self.object_gid.verify_chain(self.trusted_cert_list)
-
+            self.client_cred.verify(self.trusted_cert_file_list)
+        else:
+           raise MissingTrustedRoots(self.config.get_trustedroots_dir())
+       
+        # Make sure the credential's target matches the specified hrn. 
+        # This check does not apply to trusted peers 
+        trusted_peers = [gid.get_hrn() for gid in self.trusted_cert_list]
+        if hrn and self.client_gid.get_hrn() not in trusted_peers:
+            target_hrn = self.object_gid.get_hrn()
+            if not hrn == target_hrn:
+                raise PermissionError("Target hrn: %s doesn't match specified hrn: %s " % \
+                                       (target_hrn, hrn) )       
         return True
 
         return True
 
+    def check_ticket(self, ticket):
+        """
+        Check if the tickt was signed by a trusted cert
+        """
+        if self.trusted_cert_list:
+            client_ticket = SfaTicket(string=ticket)
+            client_ticket.verify_chain(self.trusted_cert_list)
+        else:
+           raise MissingTrustedRoots(self.config.get_trustedroots_dir())
+
+        return True 
+
+    def verifyPeerCert(self, cert, gid):
+        # make sure the client_gid matches client's certificate
+        if not cert.is_pubkey(gid.get_pubkey()):
+            raise ConnectionKeyGIDMismatch(gid.get_subject()+":"+cert.get_subject())            
+
+    def verifyGidRequestHash(self, gid, hash, arglist):
+        key = gid.get_pubkey()
+        if not key.verify_string(str(arglist), hash):
+            raise BadRequestHash(hash)
+
+    def verifyCredRequestHash(self, cred, hash, arglist):
+        gid = cred.get_gid_caller()
+        self.verifyGidRequestHash(gid, hash, arglist)
+
+    def validateGid(self, gid):
+        if self.trusted_cert_list:
+            gid.verify_chain(self.trusted_cert_list)
+
+    def validateCred(self, cred):
+        if self.trusted_cert_list:
+            cred.verify(self.trusted_cert_file_list)
+
+    def authenticateGid(self, gidStr, argList, requestHash=None):
+        gid = GID(string = gidStr)
+        self.validateGid(gid)
+        # request_hash is optional
+        if requestHash:
+            self.verifyGidRequestHash(gid, requestHash, argList)
+        return gid
+
+    def authenticateCred(self, credStr, argList, requestHash=None):
+        cred = Credential(string = credStr)
+        self.validateCred(cred)
+        # request hash is optional
+        if requestHash:
+            self.verifyCredRequestHash(cred, requestHash, argList)
+        return cred
+
+    def authenticateCert(self, certStr, requestHash):
+        cert = Certificate(string=certStr)
+        self.validateCert(self, cert)   
+
+    def gidNoop(self, gidStr, value, requestHash):
+        self.authenticateGid(gidStr, [gidStr, value], requestHash)
+        return value
+
+    def credNoop(self, credStr, value, requestHash):
+        self.authenticateCred(credStr, [credStr, value], requestHash)
+        return value
+
+    def verify_cred_is_me(self, credential):
+        is_me = False 
+        cred = Credential(string=credential)
+        caller_gid = cred.get_gid_caller()
+        caller_hrn = caller_gid.get_hrn()
+        if caller_hrn != self.config.SFA_INTERFACE_HRN:
+            raise SfaPermissionDenied(self.config.SFA_INTEFACE_HRN)
+
+        return   
         
     def get_auth_info(self, auth_hrn):
         """
         
     def get_auth_info(self, auth_hrn):
         """
@@ -76,26 +175,6 @@ class Auth:
         return self.hierarchy.get_auth_info(auth_hrn)
 
 
         return self.hierarchy.get_auth_info(auth_hrn)
 
 
-    def get_auth_table(self, auth_name):
-        """
-        Given an authority name, return the database table for that authority.
-        If the databse table does not exist, then one will be automatically
-        created.
-
-        @param auth_name human readable name of authority
-        """
-        auth_info = self.get_auth_info(auth_name)
-        table = GeniTable(hrn=auth_name,
-                          cninfo=auth_info.get_dbinfo())
-        # if the table doesn't exist, then it means we haven't put any records
-        # into this authority yet.
-
-        if not table.exists():
-            print >> log, "Registry: creating table for authority", auth_name
-            table.create()
-    
-        return table
-
     def veriry_auth_belongs_to_me(self, name):
         """
         Verify that an authority belongs to our hierarchy. 
     def veriry_auth_belongs_to_me(self, name):
         """
         Verify that an authority belongs to our hierarchy. 
@@ -106,6 +185,7 @@ class Auth:
         @param auth_name human readable name of authority
         """
 
         @param auth_name human readable name of authority
         """
 
+        # get auth info will throw an exception if the authority doesnt exist
         self.get_auth_info(name)
 
 
         self.get_auth_info(name)
 
 
@@ -119,8 +199,8 @@ class Auth:
         """
         auth_name = self.get_authority(name)
         if not auth_name:
         """
         auth_name = self.get_authority(name)
         if not auth_name:
-            # the root authority belongs to the registry by default?
-            # TODO: is this true?
+            auth_name = name 
+        if name == self.config.SFA_INTERFACE_HRN:
             return
         self.verify_auth_belongs_to_me(auth_name) 
              
             return
         self.verify_auth_belongs_to_me(auth_name) 
              
@@ -143,39 +223,28 @@ class Auth:
             return
         if name.startswith(object_hrn + "."):
             return
             return
         if name.startswith(object_hrn + "."):
             return
-        if name.startswith(get_authority(name)):
-            return
+        #if name.startswith(get_authority(name)):
+            #return
     
         raise PermissionError(name)
 
     
         raise PermissionError(name)
 
-    def determine_user_rights(self, src_cred, record):
+    def determine_user_rights(self, caller_hrn, record):
         """
         Given a user credential and a record, determine what set of rights the
         user should have to that record.
         """
         Given a user credential and a record, determine what set of rights the
         user should have to that record.
-
-        Src_cred can be None when obtaining a user credential, but should be
-        set to a valid user credential when obtaining a slice or authority
-        credential.
-
+        
         This is intended to replace determine_rights() and
         verify_cancreate_credential()
         """
 
         This is intended to replace determine_rights() and
         verify_cancreate_credential()
         """
 
-        type = record.get_type()
-        if src_cred:
-            cred_object_hrn = src_cred.get_gid_object().get_hrn()
-        else:
-            # supplying src_cred==None is only valid when obtaining user
-            # credentials.
-            #assert(type == "user")
-            
-            cred_object_hrn = None
-
         rl = RightList()
         rl = RightList()
+        type = record['type']
+
 
         if type=="slice":
             researchers = record.get("researcher", [])
 
         if type=="slice":
             researchers = record.get("researcher", [])
-            if (cred_object_hrn in researchers):
+            pis = record.get("PI", [])
+            if (caller_hrn in researchers + pis):
                 rl.add("refresh")
                 rl.add("embed")
                 rl.add("bind")
                 rl.add("refresh")
                 rl.add("embed")
                 rl.add("bind")
@@ -183,12 +252,17 @@ class Auth:
                 rl.add("info")
 
         elif type == "authority":
                 rl.add("info")
 
         elif type == "authority":
-            pis = record.get("pi", [])
+            pis = record.get("PI", [])
             operators = record.get("operator", [])
             operators = record.get("operator", [])
-            rl.add("authority,sa,ma")
-            if (cred_object_hrn in pis):
+            if (caller_hrn == self.config.SFA_INTERFACE_HRN):
+                rl.add("authority")
+                rl.add("sa")
+                rl.add("ma")
+            if (caller_hrn in pis):
+                rl.add("authority")
                 rl.add("sa")
                 rl.add("sa")
-            if (cred_object_hrn in operators):
+            if (caller_hrn in operators):
+                rl.add("authority")
                 rl.add("ma")
 
         elif type == "user":
                 rl.add("ma")
 
         elif type == "user":
@@ -196,6 +270,9 @@ class Auth:
             rl.add("resolve")
             rl.add("info")
 
             rl.add("resolve")
             rl.add("info")
 
+        elif type == "node":
+            rl.add("operator")
+
         return rl
 
     def verify_cancreate_credential(self, src_cred, record):
         return rl
 
     def verify_cancreate_credential(self, src_cred, record):
@@ -224,3 +301,20 @@ class Auth:
 
     def get_authority(self, hrn):
         return get_authority(hrn)
 
     def get_authority(self, hrn):
         return get_authority(hrn)
+
+    def filter_creds_by_caller(self, creds, caller_hrn):
+        """
+        Returns a list of creds who's gid caller matches the 
+        specified caller hrn
+        """
+        if not isinstance(creds, list):
+            creds = [creds]
+        creds = []
+        for cred in creds:
+            try:
+                tmp_cred = Credential(string=cred)
+                if tmp_cred.get_gid_caller().get_hrn() == caller_hrn:
+                    creds.append(cred)
+            except: pass
+        return creds
+