fix bugs in refresh_components, getResources
[sfa.git] / geni / aggregate.py
index acaf16e..7c84d5f 100644 (file)
@@ -63,11 +63,12 @@ class Aggregate(GeniServer):
         timestamp_file = os.sep.join([self.server_basedir, 'agg.' + self.hrn + '.timestamp']) 
         self.timestamp = SimpleStorage(timestamp_file)
 
         timestamp_file = os.sep.join([self.server_basedir, 'agg.' + self.hrn + '.timestamp']) 
         self.timestamp = SimpleStorage(timestamp_file)
 
+        # How long before we refresh nodes cache
         self.nodes_ttl = 1
 
         self.connectPLC()
         self.connectRegistry()
         self.nodes_ttl = 1
 
         self.connectPLC()
         self.connectRegistry()
-        self.loadCredential()
+        #self.loadCredential()
 
     def connectRegistry(self):
         """
 
     def connectRegistry(self):
         """
@@ -119,22 +120,19 @@ class Aggregate(GeniServer):
         
         # see if this file exists
         try:
         
         # see if this file exists
         try:
-            cred = Credential(filename = ma_cred_filename)
+            cred = Credential(filename = ma_cred_filename, subject=self.hrn)
             self.credential = cred.save_to_string()
         except IOError:
             # get self credential
             self.credential = cred.save_to_string()
         except IOError:
             # get self credential
-            #self_cred = self.registry.get_credential(None, 'ma', self.hrn)
-            #self_credential = Credential(string = self_cred)
-            #self_credential.save_to_file(self_cred_filename)
+            self_cred = self.registry.get_credential(None, 'ma', self.hrn)
+            self_credential = Credential(string = self_cred)
+            self_credential.save_to_file(self_cred_filename)
 
             # get ma credential
 
             # get ma credential
-            #ma_cred = self.registry.get_gredential(self_cred)
-            #ma_credential = Credential(string = ma_cred)
-            #ma_credential.save_to_file(ma_cred_filename)
-
-            ma_cred = Certificate(filename = self.cert_file)
-            
-            self.credential = ma_cred.save_to_string()
+            ma_cred = self.registry.get_credential(self_cred)
+            ma_credential = Credential(string = ma_cred)
+            ma_credential.save_to_file(ma_cred_filename)
+            self.credential = ma_cred
 
     def hostname_to_hrn(self, login_base, hostname):
         """
 
     def hostname_to_hrn(self, login_base, hostname):
         """
@@ -153,50 +151,35 @@ class Aggregate(GeniServer):
     def refresh_components(self):
         """
         Update the cached list of nodes and save in 4 differnt formats
     def refresh_components(self):
         """
         Update the cached list of nodes and save in 4 differnt formats
-        (rspec, dns, ip, hrn)
+        (rspec, dns, ip)
         """
 
         """
 
-        node_details = {}
         # get node list in rspec format
         rspec = Rspec()
         rspec.parseString(self.get_rspec(self.hrn, 'aggregate'))
         # get node list in rspec format
         rspec = Rspec()
         rspec.parseString(self.get_rspec(self.hrn, 'aggregate'))
+        
         # filter nodes according to policy
         rspec.filter('NodeSpec', 'name', blacklist=self.policy['blacklist'], whitelist=self.policy['whitelist'])
         # filter nodes according to policy
         rspec.filter('NodeSpec', 'name', blacklist=self.policy['blacklist'], whitelist=self.policy['whitelist'])
-        # extract ifspec info to get ip's
+        
+        # extract ifspecs from rspec to get ip's
         ips = []
         ifspecs = rspec.getDictsByTagName('IfSpec')
         for ifspec in ifspecs:
             if ifspec.has_key('addr') and ifspec['addr']:
                 ips.append(ifspec['addr']) 
 
         ips = []
         ifspecs = rspec.getDictsByTagName('IfSpec')
         for ifspec in ifspecs:
             if ifspec.has_key('addr') and ifspec['addr']:
                 ips.append(ifspec['addr']) 
 
-        # resolve component hostnames 
-        nodes = self.shell.GetNodes(self.auth, {}, ['hostname', 'site_id'])
-    
-        # resolve site login_bases
-        site_ids = [node['site_id'] for node in nodes]
-        sites = self.shell.GetSites(self.auth, site_ids, ['site_id', 'login_base'])
-        site_dict = {}
-        for site in sites:
-            site_dict[site['site_id']] = site['login_base']
-
-        # filter nodes according to policy
-        # convert plc names to geni hrn
-        nodedict = {}
-        for node in nodes:
-            node_hrn = self.hostname_to_hrn(site_dict[node['site_id']], node['hostname'])
-            # apply policy. 
-            # Do not allow nodes found in blacklist, only allow nodes found in whitelist
-            if self.policy['whitelist'] and node_hrn not in self.polciy['whitelist']:
-                continue
-            if self.policy['blacklist'] and node_hrn in self.policy['blacklist']:
-                continue
-            nodedict[node_hrn] = node['hostname']
+        # extract nodespecs from rspec to get dns names
+        hostnames = []
+        nodespecs = rspec.getDictsByTagName('NodeSpec')
+        for nodespec in nodespecs:
+            if nodespec.has_key('name') and nodespec['name']:
+                hostnames.append(nodespec['name'])
 
         
 
         
+        node_details = {}
         node_details['rspec'] = rspec.toxml()
         node_details['rspec'] = rspec.toxml()
-        node_details['hrn'] = nodedict.keys()
-        node_details['dns'] = nodedict.values()
         node_details['ip'] = ips
         node_details['ip'] = ips
+        node_details['dns'] = hostnames
         # save state 
         self.nodes = SimpleStorage(self.nodes.db_filename, node_details)
         self.nodes.write()
         # save state 
         self.nodes = SimpleStorage(self.nodes.db_filename, node_details)
         self.nodes.write()
@@ -228,14 +211,16 @@ class Aggregate(GeniServer):
         self.policy.load()
 
 
         self.policy.load()
 
 
-    def getComponents(self, type = 'rspec'):
+    def getNodes(self, format = 'rspec'):
         """
         Return a list of components at this aggregate.
         """
         """
         Return a list of components at this aggregate.
         """
-        valid_types = ['rspec', 'hrn', 'dns', 'ip']
-        if type not in valid_types:
-            raise Exception, "Invalid type specified, must be one of the following: %s" \
-                             % ", ".join(valid_types)
+        valid_formats = ['rspec', 'hrn', 'dns', 'ip']
+        if not format:
+            format = 'rspec'
+        if format not in valid_formats:
+            raise Exception, "Invalid format specified, must be one of the following: %s" \
+                             % ", ".join(valid_formats)
         
         # Reload components list
         now = datetime.datetime.now()
         
         # Reload components list
         now = datetime.datetime.now()
@@ -244,15 +229,17 @@ class Aggregate(GeniServer):
             self.refresh_components()
         elif now < self.threshold and not self.nodes.keys(): 
             self.load_components()
             self.refresh_components()
         elif now < self.threshold and not self.nodes.keys(): 
             self.load_components()
-        return self.nodes.keys()
+        return self.nodes[format]
     
     
-    def getSlices(self, hrn):
+    def getSlices(self):
         """
         Return a list of instnatiated managed by this slice manager.
         """
 
         """
         Return a list of instnatiated managed by this slice manager.
         """
 
-        # XX list only the slices at the specfied hrn
-        return dict(self.slices)
+        slices = self.shell.GetSlices(self.auth, {}, ['name'])
+        slice_hrns = [self.slicename_to_hrn(slice['name']) for slice in slices]  
+
+        return slice_hrns
  
     def get_rspec(self, hrn, type):
         """
  
     def get_rspec(self, hrn, type):
         """
@@ -263,12 +250,15 @@ class Aggregate(GeniServer):
         if type in ['aggregate']:
             nodes = self.shell.GetNodes(self.auth)
         elif type in ['slice']:
         if type in ['aggregate']:
             nodes = self.shell.GetNodes(self.auth)
         elif type in ['slice']:
-            print hrn
             slicename = hrn_to_pl_slicename(hrn)
             slices = self.shell.GetSlices(self.auth, [slicename])
             node_ids = slices[0]['node_ids']
             nodes = self.shell.GetNodes(self.auth, node_ids) 
         
             slicename = hrn_to_pl_slicename(hrn)
             slices = self.shell.GetSlices(self.auth, [slicename])
             node_ids = slices[0]['node_ids']
             nodes = self.shell.GetNodes(self.auth, node_ids) 
         
+        # Filter out whitelisted nodes
+        public_nodes = lambda n: n.has_key('slice_ids_whitelist') and not n['slice_ids_whitelist']
+        nodes = filter(public_nodes, nodes)
         # Get all network interfaces
         interface_ids = []
         for node in nodes:
         # Get all network interfaces
         interface_ids = []
         for node in nodes:
@@ -317,12 +307,28 @@ class Aggregate(GeniServer):
         
         return rspec
  
         
         return rspec
  
+    
+    def getTicket(self, hrn, rspec):
+        """
+        Retrieve a ticket. This operation is currently implemented on PLC
+        only (see SFA, engineering decisions); it is not implemented on
+        components.
+
+        @param name name of the slice to retrieve a ticket for
+        @param rspec resource specification dictionary
+        @return the string representation of a ticket object
+        """
+        #self.registry.get_ticket(name, rspec)
+
+        return         
+
 
     def createSlice(self, slice_hrn, rspec, attributes = []):
         """
         Instantiate the specified slice according to whats defined in the rspec.
         """
         
 
     def createSlice(self, slice_hrn, rspec, attributes = []):
         """
         Instantiate the specified slice according to whats defined in the rspec.
         """
         
+        spec = Rspec(rspec)
         # save slice state locally
         # we can assume that spec object has been validated so its safer to
         # save this instead of the unvalidated rspec the user gave us
         # save slice state locally
         # we can assume that spec object has been validated so its safer to
         # save this instead of the unvalidated rspec the user gave us
@@ -332,7 +338,7 @@ class Aggregate(GeniServer):
         # Get slice info
         slicename = hrn_to_pl_slicename(slice_hrn)
         slices = self.shell.GetSlices(self.auth, [slicename], ['node_ids'])
         # Get slice info
         slicename = hrn_to_pl_slicename(slice_hrn)
         slices = self.shell.GetSlices(self.auth, [slicename], ['node_ids'])
-        if not slice:
+        if not slices:
             raise RecordNotFound(slice_hrn)
         slice = slices[0]
 
             raise RecordNotFound(slice_hrn)
         slice = slices[0]
 
@@ -341,7 +347,6 @@ class Aggregate(GeniServer):
         hostnames = [node['hostname'] for node in nodes]
 
         # get netspec details
         hostnames = [node['hostname'] for node in nodes]
 
         # get netspec details
-        spec = Rspec(rspec)
         nodespecs = spec.getDictsByTagName('NodeSpec')
         nodes = [nodespec['name'] for nodespec in nodespecs]    
        
         nodespecs = spec.getDictsByTagName('NodeSpec')
         nodes = [nodespec['name'] for nodespec in nodespecs]    
        
@@ -352,29 +357,29 @@ class Aggregate(GeniServer):
         self.slices.write()
 
         # remove nodes not in rspec
         self.slices.write()
 
         # remove nodes not in rspec
-        delete_nodes = set(hostnames).difference(nodes)
+        deleted_nodes = list(set(hostnames).difference(nodes))
         # add nodes from rspec
         # add nodes from rspec
-        added_nodes = set(nodes).difference(hostnames)
+        added_nodes = list(set(nodes).difference(hostnames))
     
     
-        shell.AddSliceToNodes(self.auth, slicename, added_nodes)
-        shell.DeleteSliceFromNodes(self.auth, slicename, deleted_nodes)
+        self.shell.AddSliceToNodes(self.auth, slicename, added_nodes)
+        self.shell.DeleteSliceFromNodes(self.auth, slicename, deleted_nodes)
 
         for attribute in attributes:
             type, value, node, nodegroup = attribute['type'], attribute['value'], attribute['node'], attribute['nodegroup']
 
         for attribute in attributes:
             type, value, node, nodegroup = attribute['type'], attribute['value'], attribute['node'], attribute['nodegroup']
-            shell.AddSliceAttribute(self.auth, slicename, type, value, node, nodegroup)
+            self.shell.AddSliceAttribute(self.auth, slicename, type, value, node, nodegroup)
     
         # contact registry to get slice users and add them to the slice
     
         # contact registry to get slice users and add them to the slice
-        slice_record = self.registry.resolve(self.credential, slice_hrn)
+        #slice_record = self.registry.resolve(self.credential, slice_hrn)
         # persons = slice_record['users']
         
         #for person in persons:
         #    shell.AddPersonToSlice(person['email'], slice_name)
         return 1
 
         # persons = slice_record['users']
         
         #for person in persons:
         #    shell.AddPersonToSlice(person['email'], slice_name)
         return 1
 
-    def update_slice(self, slice_hrn, rspec, attributes = []):
+    def updateSlice(self, slice_hrn, rspec, attributes = []):
         return self.create_slice(slice_hrn, rspec, attributes)
          
         return self.create_slice(slice_hrn, rspec, attributes)
          
-    def deleteSlice_(self, slice_hrn):
+    def deleteSlice(self, slice_hrn):
         """
         Remove this slice from all components it was previouly associated with and 
         free up the resources it was using.
         """
         Remove this slice from all components it was previouly associated with and 
         free up the resources it was using.
@@ -384,12 +389,12 @@ class Aggregate(GeniServer):
             self.slices.write()
 
         slicename = hrn_to_pl_slicename(slice_hrn)
             self.slices.write()
 
         slicename = hrn_to_pl_slicename(slice_hrn)
-        slices = shell.GetSlices(self.auth, [slicename])
-        if not slice:
+        slices = self.shell.GetSlices(self.auth, [slicename])
+        if not slices:
             return 1  
         slice = slices[0]
       
             return 1  
         slice = slices[0]
       
-        shell.DeleteSliceFromNodes(self.auth, slicename, slice['node_ids'])
+        self.shell.DeleteSliceFromNodes(self.auth, slicename, slice['node_ids'])
         return 1
 
     def startSlice(self, slice_hrn):
         return 1
 
     def startSlice(self, slice_hrn):
@@ -443,50 +448,58 @@ class Aggregate(GeniServer):
 ## Server methods here for now
 ##############################
 
 ## Server methods here for now
 ##############################
 
-    def list_components(self):
-        return self.getComponents()
 
 
-    def list_slices(self, cred, hrn):
-        self.decode_authentication(cred, 'list')
-        return self.getSlices(hrn)
+    # XX fix rights, should be function name defined in 
+    # privilege_table (from util/rights.py)
+    def list_nodes(self, cred):
+        self.decode_authentication(cred, 'listnodes')
+        return self.getNodes()
+
+    def list_slices(self, cred):
+        self.decode_authentication(cred, 'listslices')
+        return self.getSlices()
 
     def get_resources(self, cred, hrn):
 
     def get_resources(self, cred, hrn):
-        self.decode_authentication(cred, 'info')
+        self.decode_authentication(cred, 'listnodes')
         return self.getResources(hrn)
 
         return self.getResources(hrn)
 
+    def get_ticket(self, cred, hrn, rspec):
+        self.decode_authentication(cred, 'getticket')
+        return self.getTicket(hrn, rspec)
     def get_policy(self, cred):
     def get_policy(self, cred):
-        self.decode_authentication(cred, 'info')
+        self.decode_authentication(cred, 'getpolicy')
         return self.getPolicy()
 
     def create_slice(self, cred, hrn, rspec):
         return self.getPolicy()
 
     def create_slice(self, cred, hrn, rspec):
-        self.decode_authentication(cred, 'embed')
-        return self.createSlice(hrn)
+        self.decode_authentication(cred, 'createslice')
+        return self.createSlice(hrn, rspec)
 
     def update_slice(self, cred, hrn, rspec):
 
     def update_slice(self, cred, hrn, rspec):
-        self.decode_authentication(cred, 'embed')
+        self.decode_authentication(cred, 'updateslice')
         return self.updateSlice(hrn)    
 
     def delete_slice(self, cred, hrn):
         return self.updateSlice(hrn)    
 
     def delete_slice(self, cred, hrn):
-        self.decode_authentication(cred, 'embed')
+        self.decode_authentication(cred, 'deleteslice')
         return self.deleteSlice(hrn)
 
     def start_slice(self, cred, hrn):
         return self.deleteSlice(hrn)
 
     def start_slice(self, cred, hrn):
-        self.decode_authentication(cred, 'control')
+        self.decode_authentication(cred, 'startslice')
         return self.startSlice(hrn)
 
     def stop_slice(self, cred, hrn):
         return self.startSlice(hrn)
 
     def stop_slice(self, cred, hrn):
-        self.decode_authentication(cred, 'control')
+        self.decode_authentication(cred, 'stopslice')
         return self.stopSlice(hrn)
 
     def reset_slice(self, cred, hrn):
         return self.stopSlice(hrn)
 
     def reset_slice(self, cred, hrn):
-        self.decode_authentication(cred, 'control')
+        self.decode_authentication(cred, 'resetslice')
         return self.resetSlice(hrn)
 
     def register_functions(self):
         GeniServer.register_functions(self)
 
         # Aggregate interface methods
         return self.resetSlice(hrn)
 
     def register_functions(self):
         GeniServer.register_functions(self)
 
         # Aggregate interface methods
-        self.server.register_function(self.list_components)
+        self.server.register_function(self.list_nodes)
         self.server.register_function(self.list_slices)
         self.server.register_function(self.get_resources)
         self.server.register_function(self.get_policy)
         self.server.register_function(self.list_slices)
         self.server.register_function(self.get_resources)
         self.server.register_function(self.get_policy)