fix bugs
[sfa.git] / plc / aggregate.py
index a5c061b..cf4b257 100644 (file)
@@ -4,37 +4,49 @@ import datetime
 import time
 import xmlrpclib
 
-from geniserver import *
-from excep import *
-from misc import *
-from config import Config
+from util.geniserver import *
+from util.cert import *
+from util.trustedroot import *
+from util.excep import *
+from util.misc import *
+from util.config import Config
 
-conf = Config()
-basedir = conf.GENI_BASE_DIR + os.sep 
-server_basedir = basedir + os.sep + "plc" + os.sep 
-agg_hrn = conf.GENI_INTERFACE_HRN
-
-class Aggregate:
+class Aggregate(GeniServer):
 
     hrn = None
-    aggregate_file = None
     components_file = None
-    slices_file = None 
     components_ttl = None
     components = []
-    slices = []        
-    policies = {}
+    whitelist_file = None
+    blacklist_file = None      
+    policy = {}
     timestamp = None
     threshold = None   
-    server = None
-     
+    shell = None
+  
+    ##
+    # Create a new aggregate object.
+    #
+    # @param ip the ip address to listen on
+    # @param port the port to listen on
+    # @param key_file private key filename of registry
+    # @param cert_file certificate filename containing public key (could be a GID file)     
 
-    def __init__(self, hrn = agg_hrn, components_ttl = 1):
-        self.hrn = hrn
+    def __init__(self, ip, port, key_file, cert_file, config = "/usr/share/geniwrapper/util/geni_config"):
+        GeniServer.__init__(ip, port, keyfile, cert_file)
+       
+       conf = Config(config)
+        basedir = conf.GENI_BASE_DIR + os.sep
+        server_basedir = basedir + os.sep + "plc" + os.sep
+       self.hrn = conf.GENI_INTERFACE_HRN
        self.components_file = os.sep.join([server_basedir, 'components', hrn + '.comp'])
-       self.slices_file = os.sep.join([server_basedir, 'components', hrn + '.slices'])
+       self.whitelist_file = os.sep.join([server_basedir, 'policy', 'whitelist'])
+       self.blacklist_file = os.sep.join([server_basedir, 'policy', 'blacklist'])
        self.timestamp_file = os.sep.join([server_basedir, 'components', hrn + '.timestamp']) 
        self.components_ttl = components_ttl
+       self.policy['whitelist'] = []
+       self.policy['blacklist'] = []
+       self.connect()
 
     def connect(self):
        """
@@ -80,15 +92,12 @@ class Aggregate:
 
     def refresh_components(self):
        """
-       Update the cached list of nodes and slices.
+       Update the cached list of nodes.
        """
        print "refreshing"      
        # resolve component hostnames 
        nodes = self.shell.GetNodes(self.auth, {}, ['hostname', 'site_id'])
        
-       # resolve slices
-       slices = self.shell.GetSlices(self.auth, {}, ['name', 'site_id'])
-   
        # resolve site login_bases
        site_ids = [node['site_id'] for node in nodes]
        sites = self.shell.GetSites(self.auth, site_ids, ['site_id', 'login_base'])
@@ -98,8 +107,16 @@ class Aggregate:
 
        # convert plc names to geni hrn
        self.components = [self.hostname_to_hrn(site_dict[node['site_id']], node['hostname']) for node in nodes]
-       self.slices = [self.slicename_to_hrn(slice['name']) for slice in slices]
-               
+
+       # apply policy. Do not allow nodes found in blacklist, only allow nodes found in whitelist
+       whitelist_policy = lambda node: node in self.policy['whitelist']
+       blacklist_policy = lambda node: node not in self.policy['blacklist']
+
+       if self.policy['blacklist']:
+           self.components = blacklist_policy(self.components)
+       if self.policy['whitelist']:
+           self.components = whitelist_policy(self.components)
+               
        # update timestamp and threshold
        self.timestamp = datetime.datetime.now()
        delta = datetime.timedelta(hours=self.components_ttl)
@@ -108,29 +125,21 @@ class Aggregate:
        f = open(self.components_file, 'w')
        f.write(str(self.components))
        f.close()
-       f = open(self.slices_file, 'w')
-       f.write(str(self.slices))
-       f.close()
        f = open(self.timestamp_file, 'w')
        f.write(str(self.threshold))
        f.close()
  
     def load_components(self):
        """
-       Read cached list of nodes and slices.
+       Read cached list of nodes.
        """
-       print "loading"
+       print "loading components"
        # Read component list from cached file 
        if os.path.exists(self.components_file):
            f = open(self.components_file, 'r')
            self.components = eval(f.read())
            f.close()
        
-       if os.path.exists(self.slices_file):
-            f = open(self.components_file, 'r')
-            self.slices = eval(f.read())
-            f.close()
-
        time_format = "%Y-%m-%d %H:%M:%S"
        if os.path.exists(self.timestamp_file):
            f = open(self.timestamp_file, 'r')
@@ -140,6 +149,32 @@ class Aggregate:
             self.threshold = self.timestamp + delta
            f.close()   
 
+    def load_policy(self):
+       """
+       Read the list of blacklisted and whitelisted nodes.
+       """
+       whitelist = []
+       blacklist = []
+       if os.path.exists(self.whitelist_file):
+            f = open(self.whitelist_file, 'r')
+           lines = f.readlines()
+            f.close()
+           for line in lines:
+               line = line.strip().replace(" ", "").replace("\n", "")
+               whitelist.extend(line.split(","))
+                       
+       
+       if os.path.exists(self.blacklist_file):
+            f = open(self.blacklist_file, 'r')
+            lines = f.readlines()
+            f.close()
+           for line in lines:
+                line = line.strip().replace(" ", "").replace("\n", "")
+                blacklist.extend(line.split(","))
+
+       self.policy['whitelist'] = whitelist
+       self.policy['blacklist'] = blacklist
+
     def get_components(self):
        """
        Return a list of components at this aggregate.
@@ -152,47 +187,82 @@ class Aggregate:
        elif now < self.threshold and not self.components: 
            self.load_components()
        return self.components
-   
      
-    def get_slices(self):
-       """
-       Return a list of instnatiated slices at this aggregate.
-       """
-       now = datetime.datetime.now()
-       #self.load_components()
-       if not self.threshold or not self.timestamp or now > self.threshold:
-           self.refresh_components()
-       elif now < self.threshold and not self.slices:
-           self.load_components()
-       return self.slices
-
     def get_rspec(self, hrn, type):
        #rspec = Rspec()
+       #rspec['nodespec'] = {'name': self.conf.GENI_INTERFACE_HRN}
+       #rsepc['nodespec']['nodes'] = []
        if type in ['node']:
            nodes = self.shell.GetNodes(self.auth)
        elif type in ['slice']:
-           slices = self.shell.GetSlices(self.auth)
+           slicename = hrn_to_pl_slicename(hrn)
+           slices = self.shell.GetSlices(self.auth, [slicename])
+           node_ids = slices[0]['node_ids']
+           nodes = self.shell.GetNodes(self.auth, node_ids) 
+           for node in nodes:
+               nodespec = {'name': node['hostname'], 'type': 'std'}
+           #   rspec['nodespec']['nodes'].append(nodespec)
+               
        elif type in ['aggregate']:
            pass
 
+       #return rspec
+
     def get_resources(self, slice_hrn):
        """
        Return the current rspec for the specified slice.
        """
        slicename = hrn_to_plcslicename(slice_hrn)
-       rspec = self.get_rspec(slicenamem, 'slice' )
+       rspec = self.get_rspec(slicenamem, 'slice')
         
        return rspec
  
-    def create_slice(self, slice_hrn, rspec):
+    def create_slice(self, slice_hrn, rspec, attributes):
        """
        Instantiate the specified slice according to whats defined in the rspec.
        """
        slicename = self.hrn_to_plcslicename(slice_hrn)
        #spec = Rspec(rspec)
-       #components = spec.components()
-       #shell.AddSliceToNodes(self.auth, slicename, components)
+       #nodespec = spec['networks']['nodes']
+       #nodes = [nspec['name'] for nspec in nodespec]
+       #self.shell.AddSliceToNodes(self.auth, slicename, nodes)
+       #for attribute in attributes:
+           #type, value, node, nodegroup = attribute['type'], attribute['value'], attribute['node'], attribute['nodegroup']
+           #shell.AddSliceAttribute(self.auth, slicename, type, value, node, nodegroup)
+
        return 1
+
+    def update_slice(self, slice_hrn, rspec, attributes):
+       """
+       Update the specified slice.
+       """
+       # Get slice info
+       slicename = self.hrn_to_plcslicename(slice_hrn)
+        slices = self.shell.GetSlices(self.auth, [slicename], ['node_ids'])
+        if not slice:
+            raise RecordNotFound(slice_hrn)
+        slice = slices[0]
+
+       # find out where this slice is currently running
+        nodes = self.shell.GetNodes(self.auth, slice['node_ids'], ['hostname'])
+        hostnames = [node['hostname'] for node in nodes]
+
+       # get netspec details
+       #spec = Rspec(rspec)
+        #nodespec = spec['networks']['nodes']
+        #nodes = [nspec['name'] for nspec in nodespec]
+
+       # remove nodes not in rspec
+       #delete_nodes = set(hostnames).difference(nodes)
+       # add nodes from rspec
+       #added_nodes = set(nodes).difference(hostnames)
+       
+        #shell.AddSliceToNodes(self.auth, slicename, added_nodes)
+       #shell.DeleteSliceFromNodes(self.auth, slicename, deleted_nodes)
+
+        #for attribute in attributes:
+            #type, value, node, nodegroup = attribute['type'], attribute['value'], attribute['node'], attribute['nodegroup']
+            #shell.AddSliceAttribute(self.auth, slicename, type, value, node, nodegroup)
        
     def delete_slice_(self, slice_hrn):
        """
@@ -200,9 +270,12 @@ class Aggregate:
        free up the resources it was using.
        """
        slicename = self.hrn_to_plcslicename(slice_hrn)
-       rspec = self.get_resources(slice_hrn)
-       components = rspec.components()
-       shell.DeleteSliceFromNodes(self.auth, slicename, components)
+       slices = shell.GetSlices(self.auth, [slicename])
+       if not slice:
+           raise RecordNotFound(slice_hrn)
+       slice = slices[0]
+         
+       shell.DeleteSliceFromNodes(self.auth, slicename, slice['node_ids'])
        return 1
 
     def start_slice(self, slice_hrn):
@@ -233,6 +306,7 @@ class Aggregate:
        self.shell.UpdateSliceAttribute(self.auth, attribute_id, "0")
        return 1
 
+
     def reset_slice(self, slice_hrn):
        """
        Reset the slice
@@ -242,9 +316,71 @@ class Aggregate:
 
     def get_policy(self):
        """
-       Return this aggregates policy as an rspec
+       Return this aggregates policy.
        """
-       rspec = self.get_rspec(self.hrn, 'aggregate')
-       return rspec
+       
+       return self.policy
        
-                       
+       
+
+##############################
+## Server methods here for now
+##############################
+
+    def nodes(self):
+        return self..get_components()
+
+    #def slices(self):
+    #    return self.get_slices()
+
+    def resources(self, cred, hrn):
+        self.decode_authentication(cred, 'info')
+        self.verify_object_belongs_to_me(hrn)
+
+        return self.get_resources(hrn)
+
+    def create(self, cred, hrn, rspec):
+        self.decode_authentication(cred, 'embed')
+        self.verify_object_belongs_to_me(hrn)
+        return self.create(hrn)
+
+    def update(self, cred, hrn, rspec):
+        self.decode_authentication(cred, 'embed')
+        self.verify_object_belongs_to_me(hrn)
+        return self.update(hrn)        
+
+    def delete(self, cred, hrn):
+        self.decode_authentication(cred, 'embed')
+        self.verify_object_belongs_to_me(hrn)
+        return self.delete_slice(hrn)
+
+    def start(self, cred, hrn):
+        self.decode_authentication(cred, 'control')
+        return self.start(hrn)
+
+    def stop(self, cred, hrn):
+        self.decode_authentication(cred, 'control')
+        return self.stop(hrn)
+
+    def reset(self, cred, hrn):
+        self.decode_authentication(cred, 'control')
+        return self.reset(hrn)
+
+    def policy(self, cred):
+        self.decode_authentication(cred, 'info')
+        return self.get_policy()
+
+    def register_functions(self):
+        GeniServer.register_functions(self)
+
+        # Aggregate interface methods
+        self.server.register_function(self.components)
+        #self.server.register_function(self.slices)
+        self.server.register_function(self.resources)
+        self.server.register_function(self.create)
+        self.server.register_function(self.delete)
+        self.server.register_function(self.start)
+        self.server.register_function(self.stop)
+        self.server.register_function(self.reset)
+        self.server.register_function(self.policy)
+