Merge branch 'master' of ssh://git.planet-lab.org/git/sfa
[sfa.git] / sfa / trust / auth.py
index 6688767..9cb905d 100644 (file)
@@ -5,17 +5,16 @@
 ### $URL$
 #
 
 ### $URL$
 #
 
-import time
 
 
+from sfa.trust.certificate import Keypair, Certificate
 from sfa.trust.credential import Credential
 from sfa.trust.trustedroot import TrustedRootList
 from sfa.trust.credential import Credential
 from sfa.trust.trustedroot import TrustedRootList
-from sfa.trust.rights import RightList
 from sfa.util.faults import *
 from sfa.trust.hierarchy import Hierarchy
 from sfa.util.config import *
 from sfa.util.namespace import *
 from sfa.util.faults import *
 from sfa.trust.hierarchy import Hierarchy
 from sfa.util.config import *
 from sfa.util.namespace import *
-from sfa.trust.gid import GID
 from sfa.util.sfaticket import *
 from sfa.util.sfaticket import *
+import sys
 
 class Auth:
     """
 
 class Auth:
     """
@@ -27,10 +26,33 @@ class Auth:
         self.hierarchy = Hierarchy()
         if not config:
             self.config = Config()
         self.hierarchy = Hierarchy()
         if not config:
             self.config = Config()
-        self.trusted_cert_list = TrustedRootList(self.config.get_trustedroots_dir()).get_list()
+        self.load_trusted_certs()
 
 
+    def load_trusted_certs(self):
+        self.trusted_cert_list = TrustedRootList(self.config.get_trustedroots_dir()).get_list()
+        self.trusted_cert_file_list = TrustedRootList(self.config.get_trustedroots_dir()).get_file_list()
 
 
-    def check(self, cred, operation):
+        
+        
+    def checkCredentials(self, creds, operation, hrn = None):
+        valid = []
+        if not isinstance(creds, list):
+            creds = [creds]
+        for cred in creds:
+            try:
+                self.check(cred, operation, hrn)
+                valid.append(cred)
+            except:
+                error = sys.exc_info()[:2]
+                continue
+            
+        if not len(valid):
+            raise InsufficientRights('Access denied: %s -- %s' % (error[0],error[1]))
+        
+        return valid
+        
+        
+    def check(self, cred, operation, hrn = None):
         """
         Check the credential against the peer cert (callerGID included 
         in the credential matches the caller that is connected to the 
         """
         Check the credential against the peer cert (callerGID included 
         in the credential matches the caller that is connected to the 
@@ -56,14 +78,18 @@ class Auth:
                 raise InsufficientRights(operation)
 
         if self.trusted_cert_list:
                 raise InsufficientRights(operation)
 
         if self.trusted_cert_list:
-            self.client_cred.verify_chain(self.trusted_cert_list)
-            if self.client_gid:
-                self.client_gid.verify_chain(self.trusted_cert_list)
-            if self.object_gid:
-                self.object_gid.verify_chain(self.trusted_cert_list)
+            self.client_cred.verify(self.trusted_cert_file_list)
         else:
            raise MissingTrustedRoots(self.config.get_trustedroots_dir())
         else:
            raise MissingTrustedRoots(self.config.get_trustedroots_dir())
-
+       
+        # Make sure the credential's target matches the specified hrn. 
+        # This check does not apply to trusted peers 
+        trusted_peers = [gid.get_hrn() for gid in self.trusted_cert_list]
+        if hrn and self.client_gid.get_hrn() not in trusted_peers:
+            target_hrn = self.object_gid.get_hrn()
+            if not hrn == target_hrn:
+                raise PermissionError("Target hrn: %s doesn't match specified hrn: %s " % \
+                                       (target_hrn, hrn) )       
         return True
 
     def check_ticket(self, ticket):
         return True
 
     def check_ticket(self, ticket):
@@ -80,17 +106,8 @@ class Auth:
 
     def verifyPeerCert(self, cert, gid):
         # make sure the client_gid matches client's certificate
 
     def verifyPeerCert(self, cert, gid):
         # make sure the client_gid matches client's certificate
-        if not cert:
-            peer_cert = self.peer_cert
-        else:
-            peer_cert = cert
-
-        if not gid:
-            peer_gid = self.client_gid
-        else:
-            peer_gid = gid
-        if not peer_cert.is_pubkey(peer_gid.get_pubkey()):
-            raise ConnectionKeyGIDMismatch(peer_gid.get_subject())            
+        if not cert.is_pubkey(gid.get_pubkey()):
+            raise ConnectionKeyGIDMismatch(gid.get_subject()+":"+cert.get_subject())            
 
     def verifyGidRequestHash(self, gid, hash, arglist):
         key = gid.get_pubkey()
 
     def verifyGidRequestHash(self, gid, hash, arglist):
         key = gid.get_pubkey()
@@ -107,13 +124,7 @@ class Auth:
 
     def validateCred(self, cred):
         if self.trusted_cert_list:
 
     def validateCred(self, cred):
         if self.trusted_cert_list:
-            cred.verify_chain(self.trusted_cert_list)
-            caller_gid = cred.get_gid_caller()
-            object_gid = cred.get_gid_object()
-            if caller_gid:
-                caller_gid.verify_chain(self.trusted_cert_list)
-            if object_gid:
-                object_gid.verify_chain(self.trusted_cert_list)
+            cred.verify(self.trusted_cert_file_list)
 
     def authenticateGid(self, gidStr, argList, requestHash=None):
         gid = GID(string = gidStr)
 
     def authenticateGid(self, gidStr, argList, requestHash=None):
         gid = GID(string = gidStr)
@@ -229,6 +240,7 @@ class Auth:
         rl = RightList()
         type = record['type']
 
         rl = RightList()
         type = record['type']
 
+
         if type=="slice":
             researchers = record.get("researcher", [])
             pis = record.get("PI", [])
         if type=="slice":
             researchers = record.get("researcher", [])
             pis = record.get("PI", [])
@@ -243,11 +255,15 @@ class Auth:
             pis = record.get("PI", [])
             operators = record.get("operator", [])
             if (caller_hrn == self.config.SFA_INTERFACE_HRN):
             pis = record.get("PI", [])
             operators = record.get("operator", [])
             if (caller_hrn == self.config.SFA_INTERFACE_HRN):
-                rl.add("authority,sa,ma",)
+                rl.add("authority")
+                rl.add("sa")
+                rl.add("ma")
             if (caller_hrn in pis):
             if (caller_hrn in pis):
-                rl.add("authority,sa")
+                rl.add("authority")
+                rl.add("sa")
             if (caller_hrn in operators):
             if (caller_hrn in operators):
-                rl.add("authority,ma")
+                rl.add("authority")
+                rl.add("ma")
 
         elif type == "user":
             rl.add("refresh")
 
         elif type == "user":
             rl.add("refresh")
@@ -285,3 +301,20 @@ class Auth:
 
     def get_authority(self, hrn):
         return get_authority(hrn)
 
     def get_authority(self, hrn):
         return get_authority(hrn)
+
+    def filter_creds_by_caller(self, creds, caller_hrn):
+        """
+        Returns a list of creds who's gid caller matches the 
+        specified caller hrn
+        """
+        if not isinstance(creds, list):
+            creds = [creds]
+        creds = []
+        for cred in creds:
+            try:
+                tmp_cred = Credential(string=cred)
+                if tmp_cred.get_gid_caller().get_hrn() == caller_hrn:
+                    creds.append(cred)
+            except: pass
+        return creds
+