Merge branch 'master' of ssh://git.planet-lab.org/git/sfa
[sfa.git] / sfa / trust / auth.py
index c6e0d9d..9cb905d 100644 (file)
@@ -1,21 +1,20 @@
 #
 #
-# GeniAPI authentication 
+# SfaAPI authentication 
 #
 ### $Id$
 ### $URL$
 #
 
 #
 ### $Id$
 ### $URL$
 #
 
-import time
 
 
+from sfa.trust.certificate import Keypair, Certificate
 from sfa.trust.credential import Credential
 from sfa.trust.trustedroot import TrustedRootList
 from sfa.trust.credential import Credential
 from sfa.trust.trustedroot import TrustedRootList
-from sfa.trust.rights import RightList
 from sfa.util.faults import *
 from sfa.trust.hierarchy import Hierarchy
 from sfa.util.faults import *
 from sfa.trust.hierarchy import Hierarchy
-from sfa.util.genitable import GeniTable
 from sfa.util.config import *
 from sfa.util.config import *
-from sfa.util.misc import *
-from sfa.trust.gid import GID
+from sfa.util.namespace import *
+from sfa.util.sfaticket import *
+import sys
 
 class Auth:
     """
 
 class Auth:
     """
@@ -27,10 +26,33 @@ class Auth:
         self.hierarchy = Hierarchy()
         if not config:
             self.config = Config()
         self.hierarchy = Hierarchy()
         if not config:
             self.config = Config()
-        self.trusted_cert_list = TrustedRootList(self.config.get_trustedroots_dir()).get_list()
+        self.load_trusted_certs()
 
 
+    def load_trusted_certs(self):
+        self.trusted_cert_list = TrustedRootList(self.config.get_trustedroots_dir()).get_list()
+        self.trusted_cert_file_list = TrustedRootList(self.config.get_trustedroots_dir()).get_file_list()
 
 
-    def check(self, cred, operation):
+        
+        
+    def checkCredentials(self, creds, operation, hrn = None):
+        valid = []
+        if not isinstance(creds, list):
+            creds = [creds]
+        for cred in creds:
+            try:
+                self.check(cred, operation, hrn)
+                valid.append(cred)
+            except:
+                error = sys.exc_info()[:2]
+                continue
+            
+        if not len(valid):
+            raise InsufficientRights('Access denied: %s -- %s' % (error[0],error[1]))
+        
+        return valid
+        
+        
+    def check(self, cred, operation, hrn = None):
         """
         Check the credential against the peer cert (callerGID included 
         in the credential matches the caller that is connected to the 
         """
         Check the credential against the peer cert (callerGID included 
         in the credential matches the caller that is connected to the 
@@ -56,27 +78,36 @@ class Auth:
                 raise InsufficientRights(operation)
 
         if self.trusted_cert_list:
                 raise InsufficientRights(operation)
 
         if self.trusted_cert_list:
-            self.client_cred.verify_chain(self.trusted_cert_list)
-            if self.client_gid:
-                self.client_gid.verify_chain(self.trusted_cert_list)
-            if self.object_gid:
-                self.object_gid.verify_chain(self.trusted_cert_list)
-
+            self.client_cred.verify(self.trusted_cert_file_list)
+        else:
+           raise MissingTrustedRoots(self.config.get_trustedroots_dir())
+       
+        # Make sure the credential's target matches the specified hrn. 
+        # This check does not apply to trusted peers 
+        trusted_peers = [gid.get_hrn() for gid in self.trusted_cert_list]
+        if hrn and self.client_gid.get_hrn() not in trusted_peers:
+            target_hrn = self.object_gid.get_hrn()
+            if not hrn == target_hrn:
+                raise PermissionError("Target hrn: %s doesn't match specified hrn: %s " % \
+                                       (target_hrn, hrn) )       
         return True
 
         return True
 
-    def verifyPeerCert(self, cert, gid):
-        # make sure the client_gid matches client's certificate
-        if not cert:
-            peer_cert = self.peer_cert
+    def check_ticket(self, ticket):
+        """
+        Check if the tickt was signed by a trusted cert
+        """
+        if self.trusted_cert_list:
+            client_ticket = SfaTicket(string=ticket)
+            client_ticket.verify_chain(self.trusted_cert_list)
         else:
         else:
-            peer_cert = cert
+           raise MissingTrustedRoots(self.config.get_trustedroots_dir())
 
 
-        if not gid:
-            peer_gid = self.client_gid
-        else:
-            peer_gid = gid
-        if not peer_cert.is_pubkey(peer_gid.get_pubkey()):
-            raise ConnectionKeyGIDMismatch(peer_gid.get_subject())            
+        return True 
+
+    def verifyPeerCert(self, cert, gid):
+        # make sure the client_gid matches client's certificate
+        if not cert.is_pubkey(gid.get_pubkey()):
+            raise ConnectionKeyGIDMismatch(gid.get_subject()+":"+cert.get_subject())            
 
     def verifyGidRequestHash(self, gid, hash, arglist):
         key = gid.get_pubkey()
 
     def verifyGidRequestHash(self, gid, hash, arglist):
         key = gid.get_pubkey()
@@ -93,13 +124,7 @@ class Auth:
 
     def validateCred(self, cred):
         if self.trusted_cert_list:
 
     def validateCred(self, cred):
         if self.trusted_cert_list:
-            cred.verify_chain(self.trusted_cert_list)
-            caller_gid = cred.get_gid_caller()
-            object_gid = cred.get_gid_object()
-            if caller_gid:
-                caller_gid.verify_chain(self.trusted_cert_list)
-            if object_gid:
-                object_gid.verify_chain(self.trusted_cert_list)
+            cred.verify(self.trusted_cert_file_list)
 
     def authenticateGid(self, gidStr, argList, requestHash=None):
         gid = GID(string = gidStr)
 
     def authenticateGid(self, gidStr, argList, requestHash=None):
         gid = GID(string = gidStr)
@@ -135,7 +160,7 @@ class Auth:
         caller_gid = cred.get_gid_caller()
         caller_hrn = caller_gid.get_hrn()
         if caller_hrn != self.config.SFA_INTERFACE_HRN:
         caller_gid = cred.get_gid_caller()
         caller_hrn = caller_gid.get_hrn()
         if caller_hrn != self.config.SFA_INTERFACE_HRN:
-            raise GeniPermissionError(self.config.SFA_INTEFACE_HRN)
+            raise SfaPermissionDenied(self.config.SFA_INTEFACE_HRN)
 
         return   
         
 
         return   
         
@@ -203,34 +228,23 @@ class Auth:
     
         raise PermissionError(name)
 
     
         raise PermissionError(name)
 
-    def determine_user_rights(self, src_cred, record):
+    def determine_user_rights(self, caller_hrn, record):
         """
         Given a user credential and a record, determine what set of rights the
         user should have to that record.
         """
         Given a user credential and a record, determine what set of rights the
         user should have to that record.
-
-        Src_cred can be None when obtaining a user credential, but should be
-        set to a valid user credential when obtaining a slice or authority
-        credential.
-
+        
         This is intended to replace determine_rights() and
         verify_cancreate_credential()
         """
 
         This is intended to replace determine_rights() and
         verify_cancreate_credential()
         """
 
+        rl = RightList()
         type = record['type']
         type = record['type']
-        if src_cred:
-            cred_object_hrn = src_cred.get_gid_object().get_hrn()
-        else:
-            # supplying src_cred==None is only valid when obtaining user
-            # credentials.
-            #assert(type == "user")
-            
-            cred_object_hrn = None
 
 
-        rl = RightList()
 
         if type=="slice":
             researchers = record.get("researcher", [])
 
         if type=="slice":
             researchers = record.get("researcher", [])
-            if (cred_object_hrn in researchers):
+            pis = record.get("PI", [])
+            if (caller_hrn in researchers + pis):
                 rl.add("refresh")
                 rl.add("embed")
                 rl.add("bind")
                 rl.add("refresh")
                 rl.add("embed")
                 rl.add("bind")
@@ -238,19 +252,27 @@ class Auth:
                 rl.add("info")
 
         elif type == "authority":
                 rl.add("info")
 
         elif type == "authority":
-            rl.add("authority")
-            pis = record.get("pi", [])
+            pis = record.get("PI", [])
             operators = record.get("operator", [])
             operators = record.get("operator", [])
-            if (cred_object_hrn in pis):
-                rl.add("authority,sa")
-            if (cred_object_hrn in operators):
-                rl.add("authority,ma")
+            if (caller_hrn == self.config.SFA_INTERFACE_HRN):
+                rl.add("authority")
+                rl.add("sa")
+                rl.add("ma")
+            if (caller_hrn in pis):
+                rl.add("authority")
+                rl.add("sa")
+            if (caller_hrn in operators):
+                rl.add("authority")
+                rl.add("ma")
 
         elif type == "user":
             rl.add("refresh")
             rl.add("resolve")
             rl.add("info")
 
 
         elif type == "user":
             rl.add("refresh")
             rl.add("resolve")
             rl.add("info")
 
+        elif type == "node":
+            rl.add("operator")
+
         return rl
 
     def verify_cancreate_credential(self, src_cred, record):
         return rl
 
     def verify_cancreate_credential(self, src_cred, record):
@@ -279,3 +301,20 @@ class Auth:
 
     def get_authority(self, hrn):
         return get_authority(hrn)
 
     def get_authority(self, hrn):
         return get_authority(hrn)
+
+    def filter_creds_by_caller(self, creds, caller_hrn):
+        """
+        Returns a list of creds who's gid caller matches the 
+        specified caller hrn
+        """
+        if not isinstance(creds, list):
+            creds = [creds]
+        creds = []
+        for cred in creds:
+            try:
+                tmp_cred = Credential(string=cred)
+                if tmp_cred.get_gid_caller().get_hrn() == caller_hrn:
+                    creds.append(cred)
+            except: pass
+        return creds
+