acl support for deployments
authorScott Baker <smbaker@gmail.com>
Fri, 6 Jun 2014 21:49:43 +0000 (14:49 -0700)
committerScott Baker <smbaker@gmail.com>
Fri, 6 Jun 2014 21:49:43 +0000 (14:49 -0700)
planetstack/core/acl.py [new file with mode: 0644]
planetstack/core/admin.py
planetstack/core/models/site.py
planetstack/core/models/sliver.py

diff --git a/planetstack/core/acl.py b/planetstack/core/acl.py
new file mode 100644 (file)
index 0000000..7fc6a02
--- /dev/null
@@ -0,0 +1,113 @@
+from fnmatch import fnmatch
+
+class AccessControlList:
+    def __init__(self, aclText=None):
+        self.rules = []
+        if aclText:
+            self.import_text(aclText)
+
+    def import_text(self, aclText):
+        # allow either newline or ';' to separate rules
+        aclText = aclText.replace("\n", ";")
+        for line in aclText.split(";"):
+            line = line.strip()
+            if line.startswith("#"):
+                continue
+
+            if line=="":
+                continue
+
+            parts = line.split()
+
+            if len(parts)==2 and (parts[1]=="all"):
+                # "allow all" has no pattern
+                parts = (parts[0], parts[1], "")
+
+            if len(parts)!=3:
+                raise ACLValidationError(line)
+
+            (action, object, pattern) = parts
+
+            if action not in ["allow", "deny"]:
+                raise ACLValidationError(line)
+
+            if object not in ["site", "user", "all"]:
+                raise ACLValidationError(line)
+
+            self.rules.append( (action, object, pattern) )
+
+    def __str__(self):
+        lines = []
+        for rule in self.rules:
+            lines.append( " ".join(rule) )
+        return ";\n".join(lines)
+
+    def test(self, user):
+        for rule in self.rules:
+            if self.match_rule(rule, user):
+                return rule[0]
+        return "deny"
+
+    def match_rule(self, rule, user):
+        (action, object, pattern) = rule
+
+        if (object == "site"):
+            if fnmatch(user.site.name, pattern):
+                return True
+        elif (object == "user"):
+            if fnmatch(user.email, pattern):
+                return True
+        elif (object == "all"):
+            return True
+
+        return False
+
+
+if __name__ == '__main__':
+    class fakesite:
+        def __init__(self, siteName):
+            self.name = siteName
+
+    class fakeuser:
+        def __init__(self, email, siteName):
+            self.email = email
+            self.site = fakesite(siteName)
+
+    u_scott = fakeuser("scott@onlab.us", "ON.Lab")
+    u_bill = fakeuser("bill@onlab.us", "ON.Lab")
+    u_andy = fakeuser("acb@cs.princeton.edu", "Princeton")
+    u_john = fakeuser("jhh@cs.arizona.edu", "Arizona")
+    u_hacker = fakeuser("somehacker@foo.com", "Not A Real Site")
+
+    # check the "deny all" rule
+    acl = AccessControlList("deny all")
+    assert(acl.test(u_scott) == "deny")
+
+    # a blank ACL results in "deny all"
+    acl = AccessControlList("")
+    assert(acl.test(u_scott) == "deny")
+
+    # check the "allow all" rule
+    acl = AccessControlList("allow all")
+    assert(acl.test(u_scott) == "allow")
+
+    # allow only one site
+    acl = AccessControlList("allow site ON.Lab")
+    assert(acl.test(u_scott) == "allow")
+    assert(acl.test(u_andy) == "deny")
+
+    # some complicated ACL
+    acl = AccessControlList("""allow site Princeton
+                 allow user *@cs.arizona.edu
+                 deny site Arizona
+                 deny user scott@onlab.us
+                 allow site ON.Lab""")
+
+    assert(acl.test(u_scott) == "deny")
+    assert(acl.test(u_bill) == "allow")
+    assert(acl.test(u_andy) == "allow")
+    assert(acl.test(u_john) == "allow")
+    assert(acl.test(u_hacker) == "deny")
+
+    print acl
+
index ea44376..fc12ada 100644 (file)
@@ -478,8 +478,11 @@ class DeploymentAdminForm(forms.ModelForm):
         model = Deployment
 
     def __init__(self, *args, **kwargs):
+      request = kwargs.pop('request', None)
       super(DeploymentAdminForm, self).__init__(*args, **kwargs)
 
+      self.fields['accessControl'].initial = "allow site " + request.user.site.name
+
       if self.instance and self.instance.pk:
         self.fields['sites'].initial = [x.site for x in self.instance.sitedeployments_set.all()]
 
@@ -524,9 +527,8 @@ class SiteAssocInline(PlStackTabularInline):
     suit_classes = 'suit-tab suit-tab-sites'
 
 class DeploymentAdmin(PlanetStackBaseAdmin):
-    #form = DeploymentAdminForm
     model = Deployment
-    fieldList = ['name','sites']
+    fieldList = ['name','sites', 'accessControl']
     fieldsets = [(None, {'fields': fieldList, 'classes':['suit-tab suit-tab-sites']})]
     inlines = [DeploymentPrivilegeInline,NodeInline,TagInline]
 
@@ -540,8 +542,17 @@ class DeploymentAdmin(PlanetStackBaseAdmin):
             kwargs["form"] = DeploymentAdminROForm
         else:
             kwargs["form"] = DeploymentAdminForm
-        return super(DeploymentAdmin,self).get_form(request, obj, **kwargs)
-\r
+        adminForm = super(DeploymentAdmin,self).get_form(request, obj, **kwargs)
+
+        # from stackexchange: pass the request object into the form
+
+        class AdminFormMetaClass(adminForm):
+           def __new__(cls, *args, **kwargs):
+               kwargs['request'] = request
+               return adminForm(*args, **kwargs)
+
+        return AdminFormMetaClass
+
 class ServiceAttrAsTabROInline(ReadOnlyTabularInline):
     model = ServiceAttribute
     fields = ['name','value']
index e675afb..1301ebf 100644 (file)
@@ -1,10 +1,10 @@
 import os
 from django.db import models
 from core.models import PlCoreBase
-#from core.models import Deployment
 from core.models import Tag
 from django.contrib.contenttypes import generic
 from geoposition.fields import GeopositionField
+from core.acl import AccessControlList
 
 class Site(PlCoreBase):
     """
@@ -83,7 +83,42 @@ class SitePrivilege(PlCoreBase):
 
 class Deployment(PlCoreBase):
     name = models.CharField(max_length=200, unique=True, help_text="Name of the Deployment")
-    #sites = models.ManyToManyField('Site', through='SiteDeployments', blank=True)
+
+    # smbaker: the default of 'allow all' is intended for evolutions of existing
+    #    deployments. When new deployments are created via the GUI, they are
+    #    given a default of 'allow site <site_of_creator>'
+    accessControl = models.TextField(max_length=200, blank=False, null=False, default="allow all",
+                                     help_text="Access control list that specifies which sites/users may use nodes in this deployment")
+
+    def get_acl(self):
+        return AccessControlList(self.accessControl)
+
+    def test_acl(self, slice=None, user=None):
+        potential_users=[]
+
+        if user:
+            potential_users.append(user)
+
+        if slice:
+            potential_users.append(slice.creator)
+            for priv in slice.slice_privileges.all():
+                if priv.user not in potential_users:
+                    potential_users.append(priv.user)
+
+        acl = self.get_acl()
+        for user in potential_users:
+            if acl.test(user) == "allow":
+                return True
+
+        return False
+
+    def select_by_acl(self, user):
+        acl = self.get_acl()
+        result = []
+        for deployment in Deployment.objects.all():
+            if acl.test(user):
+                result.append(deployment)
+        return result
 
     def __unicode__(self):  return u'%s' % (self.name)
 
index 6351bd1..0f37bc9 100644 (file)
@@ -37,13 +37,16 @@ class Sliver(PlCoreBase):
         else:
             return u'unsaved-sliver'
 
-
     def save(self, *args, **kwds):
         if not self.name:
             self.name = self.slice.name
         if not self.creator and hasattr(self, 'caller'):
             self.creator = self.caller
         self.deploymentNetwork = self.node.deployment
+
+        if not self.deploymentNetwork.test_acl(slice=self.slice):
+            raise exceptions.ValidationError("Deployment %s's ACL does not allow any of this slice %s's users" % (self.deploymentNetwork.name, self.slice.name))
+
         super(Sliver, self).save(*args, **kwds)
 
     def can_update(self, user):