enforce rights subsets in credentials
authorScott Baker <bakers@cs.arizona.edu>
Wed, 24 Sep 2008 22:20:10 +0000 (22:20 +0000)
committerScott Baker <bakers@cs.arizona.edu>
Wed, 24 Sep 2008 22:20:10 +0000 (22:20 +0000)
tests/testRights.py
util/credential.py
util/excep.py
util/rights.py

index 3670e5d..a9c29a0 100644 (file)
@@ -14,6 +14,22 @@ class TestRight(unittest.TestCase):
       self.assert_(right.can_perform("getticket"))
       self.assert_(not right.can_perform("resolve"))
 
+   def testIsSuperset(self):
+      pright = Right("sa")
+      cright = Right("embed")
+      self.assert_(pright.is_superset(cright))
+      self.assert_(not cright.is_superset(pright))
+
+      pright = Right("embed")
+      cright = Right("embed")
+      self.assert_(pright.is_superset(cright))
+      self.assert_(cright.is_superset(pright))
+
+      pright = Right("control")
+      cright = Right("embed")
+      self.assert_(not pright.is_superset(cright))
+      self.assert_(not cright.is_superset(pright))
+
 class TestRightList(unittest.TestCase):
     def setUp(self):
         pass
@@ -44,6 +60,27 @@ class TestRightList(unittest.TestCase):
         self.assert_(rightList.can_perform("getticket"))
         self.assert_(rightList.can_perform("resolve"))
 
+    def testIsSuperset(self):
+        pRightList = RightList(string="sa")
+        cRightList = RightList(string="embed")
+        self.assert_(pRightList.is_superset(cRightList))
+        self.assert_(not cRightList.is_superset(pRightList))
+
+        pRightList = RightList(string="embed")
+        cRightList = RightList(string="embed")
+        self.assert_(pRightList.is_superset(cRightList))
+        self.assert_(cRightList.is_superset(pRightList))
+
+        pRightList = RightList(string="control")
+        cRightList = RightList(string="embed")
+        self.assert_(not pRightList.is_superset(cRightList))
+        self.assert_(not cRightList.is_superset(pRightList))
+
+        pRightList = RightList(string="control,sa")
+        cRightList = RightList(string="embed")
+        self.assert_(pRightList.is_superset(cRightList))
+        self.assert_(not cRightList.is_superset(pRightList))
+
 
 if __name__ == "__main__":
     unittest.main()
index c4e16ea..0daaf42 100644 (file)
@@ -1,6 +1,6 @@
-# credential.py
+##
 #
-# implements GENI credentials
+# Implements Geni Credentials
 #
 # Credentials are layered on top of certificates, and are essentially a
 # certificate that stores a tuple of parameters.
@@ -10,6 +10,7 @@ from rights import *
 from gid import *
 import xmlrpclib
 
+##
 # Credential is a tuple:
 #     (GIDCaller, GIDObject, LifeTime, Privileges, Delegate)
 #
@@ -24,61 +25,117 @@ class Credential(Certificate):
     privileges = None
     delegate = False
 
+    ##
+    # Create a Credential object
+    #
+    # @param create If true, create a blank x509 certificate
+    # @param subject If subject!=None, create an x509 cert with the subject name
+    # @param string If string!=None, load the credential from the string
+    # @param filename If filename!=None, load the credential from the file
+
     def __init__(self, create=False, subject=None, string=None, filename=None):
         Certificate.__init__(self, create, subject, string, filename)
 
-    def create_similar(self):
-        return Credential()
+    ##
+    # set the GID of the caller
+    #
+    # @param gid GID object of the caller
 
     def set_gid_caller(self, gid):
         self.gidCaller = gid
 
+    ##
+    # get the GID of the object
+
     def get_gid_caller(self):
         if not self.gidCaller:
             self.decode()
         return self.gidCaller
 
+    ##
+    # set the GID of the object
+    #
+    # @param gid GID object of the object
+
     def set_gid_object(self, gid):
         self.gidObject = gid
 
+    ##
+    # get the GID of the object
+
     def get_gid_object(self):
         if not self.gidObject:
             self.decode()
         return self.gidObject
 
+    ##
+    # set the lifetime of this credential
+    #
+    # @param lifetime lifetime of credential
+
     def set_lifetime(self, lifeTime):
         self.lifeTime = lifeTime
 
+    ##
+    # get the lifetime of the credential
+
     def get_lifetime(self):
         if not self.lifeTime:
             self.decode()
         return self.lifeTime
 
+    ##
+    # set the delegate bit
+    #
+    # @param delegate boolean (True or False)
+
     def set_delegate(self, delegate):
         self.delegate = delegate
 
+    ##
+    # get the delegate bit
+
     def get_delegate(self):
         if not self.delegate:
             self.decode()
         return self.delegate
 
+    ##
+    # set the privileges
+    #
+    # @param privs either a comma-separated list of privileges of a RightList object
+
     def set_privileges(self, privs):
         if isinstance(privs, str):
             self.privileges = RightList(string = privs)
         else:
             self.privileges = privs
 
+    ##
+    # return the privileges as a RightList object
+
     def get_privileges(self):
         if not self.privileges:
             self.decode()
         return self.privileges
 
+    ##
+    # determine whether the credential allows a particular operation to be
+    # performed
+    #
+    # @param op_name string specifying name of operation ("lookup", "update", etc)
+
     def can_perform(self, op_name):
         rights = self.get_privileges()
         if not rights:
             return False
         return rights.can_perform(op_name)
 
+    ##
+    # Encode the attributes of the credential into a string and store that
+    # string in the alt-subject-name field of the X509 object. This should be
+    # done immediately before signing the credential.
+
     def encode(self):
         dict = {"gidCaller": None,
                 "gidObject": None,
@@ -94,6 +151,11 @@ class Credential(Certificate):
         str = xmlrpclib.dumps((dict,), allow_none=True)
         self.set_data(str)
 
+    ##
+    # Retrieve the attributes of the credential from the alt-subject-name field
+    # of the X509 certificate. This is automatically done by the various
+    # get_* methods of this class and should not need to be called explicitly.
+
     def decode(self):
         data = self.get_data()
         if data:
@@ -122,6 +184,12 @@ class Credential(Certificate):
         else:
             self.gidObject = None
 
+    ##
+    # Verify that a chain of credentials is valid (see cert.py:verify). In
+    # addition to the checks for ordinary certificates, verification also
+    # ensures that the delegate bit was set by each parent in the chain. If
+    # a delegate bit was not set, then an exception is thrown.
+
     def verify_chain(self, trusted_certs = None):
         # do the normal certificate verification stuff
         Certificate.verify_chain(self, trusted_certs)
@@ -131,10 +199,18 @@ class Credential(Certificate):
             if not self.parent.get_delegate():
                 raise MissingDelegateBit(self.parent.get_subject())
 
-            # XXX todo: make sure child rights are a subset of parent rights
+            # make sure the rights given to the child are a subset of the
+            # parents rights
+            if not self.parent.get_privileges().is_superset(self.get_privileges()):
+                raise ChildRightsNotSubsetOfParent(self.get_subject())
 
         return
 
+    ##
+    # Dump the contents of a credential to stdout in human-readable format
+    #
+    # @param dump_parents If true, also dump the parent certificates
+
     def dump(self, dump_parents=False):
         print "CREDENTIAL", self.get_subject()
 
@@ -156,7 +232,3 @@ class Credential(Certificate):
            print "PARENT",
            self.parent.dump(dump_parents)
 
-
-
-
-
index 0555709..5a6c99f 100644 (file)
@@ -97,14 +97,33 @@ class MissingDelegateBit(Exception):
     def __str__(self):
         return repr(self.value)
 
-class MissingParent(Exception):
+class ChildRightsNotSubsetOfParent(Exception):
     def __init__(self, value):
         self.value = value
     def __str__(self):
         return repr(self.value)
 
-class NotSignedByParent(Exception):
+class CertMissingParent(Exception):
     def __init__(self, value):
         self.value = value
     def __str__(self):
         return repr(self.value)
+
+class CertNotSignedByParent(Exception):
+    def __init__(self, value):
+        self.value = value
+    def __str__(self):
+        return repr(self.value)
+
+class GidInvalidParentHrn(Exception):
+    def __init__(self, value):
+        self.value = value
+    def __str__(self):
+        return repr(self.value)
+
+class SliverDoesNotExist(Exception):
+    def __init__(self, value):
+        self.value = value
+    def __str__(self):
+        return repr(self.value)
+
index b66c35e..ac5e0fc 100644 (file)
@@ -12,7 +12,7 @@ privilege_table = {"authority": ["*"],
                    "sa": ["*"],
                    "embed": ["getticket", "createslice", "deleteslice", "updateslice"],
                    "bind": ["getticket", "loanresources"],
-                   "control": ["updateslice", "stopslice", "startslice", "deleteslice"],
+                   "control": ["updateslice", "stopslice", "startslice", "deleteslice", "resetslice"],
                    "info": ["listslices", "listcomponentresources", "getsliceresources"],
                    "ma": ["*"]}
 
@@ -33,6 +33,19 @@ class Right:
 
       return (op_name.lower() in allowed_ops)
 
+   def is_superset(self, child):
+      my_allowed_ops = privilege_table.get(self.kind.lower(), None)
+      child_allowed_ops = privilege_table.get(child.kind.lower(), None)
+
+      if "*" in my_allowed_ops:
+          return True
+
+      for right in child_allowed_ops:
+          if not right in my_allowed_ops:
+              return False
+
+      return True
+
 # a "RightList" is a list of privileges
 
 class RightList:
@@ -41,6 +54,11 @@ class RightList:
         if string:
             self.load_from_string(string)
 
+    def add(self, right):
+        if isinstance(right, str):
+            right = Right(kind = right)
+        self.rights.append(right)
+
     def load_from_string(self, string):
         self.rights = []
 
@@ -65,4 +83,13 @@ class RightList:
                 return True
         return False
 
+    def is_superset(self, child):
+        for child_right in child.rights:
+            allowed = False
+            for my_right in self.rights:
+                if my_right.is_superset(child_right):
+                    allowed = True
+            if not allowed:
+                return False
+        return True