list nodes now collapsed into list resources
[sfa.git] / geni / aggregate.py
index 9ec0a6e..0ba4a0f 100644 (file)
@@ -4,6 +4,7 @@ import datetime
 import time
 import xmlrpclib
 
+from types import StringTypes, ListType
 from geni.util.geniserver import GeniServer
 from geni.util.geniclient import *
 from geni.util.cert import Keypair, Certificate
@@ -63,6 +64,7 @@ class Aggregate(GeniServer):
         timestamp_file = os.sep.join([self.server_basedir, 'agg.' + self.hrn + '.timestamp']) 
         self.timestamp = SimpleStorage(timestamp_file)
 
+        # How long before we refresh nodes cache
         self.nodes_ttl = 1
 
         self.connectPLC()
@@ -76,7 +78,7 @@ class Aggregate(GeniServer):
         # connect to registry using GeniClient
         address = self.config.GENI_REGISTRY_HOSTNAME
         port = self.config.GENI_REGISTRY_PORT
-        url = 'https://%(address)s:%(port)s' % locals()
+        url = 'http://%(address)s:%(port)s' % locals()
         self.registry = GeniClient(url, self.key_file, self.cert_file)
 
     
@@ -114,24 +116,30 @@ class Aggregate(GeniServer):
         credential from registry.
         """ 
 
-        self_cred_filename = self.server_basedir + os.sep + "agg." + self.hrn + ".cred"
         ma_cred_filename = self.server_basedir + os.sep + "agg." + self.hrn + ".ma.cred"
         
         # see if this file exists
         try:
-            cred = Credential(filename = ma_cred_filename, subject=self.hrn)
-            self.credential = cred.save_to_string()
+            self.credential = Credential(filename = ma_cred_filename)
         except IOError:
-            # get self credential
-            self_cred = self.registry.get_credential(None, 'ma', self.hrn)
-            self_credential = Credential(string = self_cred)
-            self_credential.save_to_file(self_cred_filename)
+            self.credential = self.getCredentialFromRegistry()
+
+    def getCredentialFromRegistry(self):
+        """
+        Get our current credential from the registry
+        """
+        # get self credential
+        self_cred_filename = self.server_basedir + os.sep + "agg." + self.hrn + ".cred"
+        self_cred = self.registry.get_credential(None, 'ma', self.hrn)
+        self_cred.save_to_file(self_cred_filename, save_parents = True)
+
+        
+        # get ma credential
+        ma_cred_filename = self.server_basedir + os.sep + "agg." + self.hrn + ".ma.cred"
+        ma_cred = self.registry.get_credential(self_cred, 'ma', self.hrn)
+        ma_cred.save_to_file(ma_cred_filename, save_parents=True)
+        return ma_cred        
 
-            # get ma credential
-            ma_cred = self.registry.get_credential(self_cred)
-            ma_credential = Credential(string = ma_cred)
-            ma_credential.save_to_file(ma_cred_filename)
-            self.credential = ma_cred
 
     def hostname_to_hrn(self, login_base, hostname):
         """
@@ -144,56 +152,43 @@ class Aggregate(GeniServer):
         """
         Convert hrn to planetlab name.
         """
-        slicename = slicename.replace("_", ".")
-        return ".".join([self.hrn, slicename])
+        parts = slicename.split("_")
+        slice_hrn = ".".join([self.hrn, parts[0]]) + "." + "_".join(parts[1:])
+          
+        return slice_hrn
 
     def refresh_components(self):
         """
         Update the cached list of nodes and save in 4 differnt formats
-        (rspec, dns, ip, hrn)
+        (rspec, dns, ip)
         """
 
-        node_details = {}
         # get node list in rspec format
         rspec = Rspec()
         rspec.parseString(self.get_rspec(self.hrn, 'aggregate'))
+        
         # filter nodes according to policy
         rspec.filter('NodeSpec', 'name', blacklist=self.policy['blacklist'], whitelist=self.policy['whitelist'])
-        # extract ifspec info to get ip's
+        
+        # extract ifspecs from rspec to get ip's
         ips = []
         ifspecs = rspec.getDictsByTagName('IfSpec')
         for ifspec in ifspecs:
             if ifspec.has_key('addr') and ifspec['addr']:
                 ips.append(ifspec['addr']) 
 
-        # resolve component hostnames 
-        nodes = self.shell.GetNodes(self.auth, {}, ['hostname', 'site_id'])
-    
-        # resolve site login_bases
-        site_ids = [node['site_id'] for node in nodes]
-        sites = self.shell.GetSites(self.auth, site_ids, ['site_id', 'login_base'])
-        site_dict = {}
-        for site in sites:
-            site_dict[site['site_id']] = site['login_base']
-
-        # filter nodes according to policy
-        # convert plc names to geni hrn
-        nodedict = {}
-        for node in nodes:
-            node_hrn = self.hostname_to_hrn(site_dict[node['site_id']], node['hostname'])
-            # apply policy. 
-            # Do not allow nodes found in blacklist, only allow nodes found in whitelist
-            if self.policy['whitelist'] and node_hrn not in self.polciy['whitelist']:
-                continue
-            if self.policy['blacklist'] and node_hrn in self.policy['blacklist']:
-                continue
-            nodedict[node_hrn] = node['hostname']
+        # extract nodespecs from rspec to get dns names
+        hostnames = []
+        nodespecs = rspec.getDictsByTagName('NodeSpec')
+        for nodespec in nodespecs:
+            if nodespec.has_key('name') and nodespec['name']:
+                hostnames.append(nodespec['name'])
 
         
+        node_details = {}
         node_details['rspec'] = rspec.toxml()
-        node_details['hrn'] = nodedict.keys()
-        node_details['dns'] = nodedict.values()
         node_details['ip'] = ips
+        node_details['dns'] = hostnames
         # save state 
         self.nodes = SimpleStorage(self.nodes.db_filename, node_details)
         self.nodes.write()
@@ -225,14 +220,16 @@ class Aggregate(GeniServer):
         self.policy.load()
 
 
-    def getComponents(self, type = 'rspec'):
+    def getNodes(self, format = 'rspec'):
         """
         Return a list of components at this aggregate.
         """
-        valid_types = ['rspec', 'hrn', 'dns', 'ip']
-        if type not in valid_types:
-            raise Exception, "Invalid type specified, must be one of the following: %s" \
-                             % ", ".join(valid_types)
+        valid_formats = ['rspec', 'hrn', 'dns', 'ip']
+        if not format:
+            format = 'rspec'
+        if format not in valid_formats:
+            raise Exception, "Invalid format specified, must be one of the following: %s" \
+                             % ", ".join(valid_formats)
         
         # Reload components list
         now = datetime.datetime.now()
@@ -241,15 +238,17 @@ class Aggregate(GeniServer):
             self.refresh_components()
         elif now < self.threshold and not self.nodes.keys(): 
             self.load_components()
-        return self.nodes[type]
+        return self.nodes[format]
     
-    def getSlices(self, hrn):
+    def getSlices(self):
         """
         Return a list of instnatiated managed by this slice manager.
         """
 
-        # XX list only the slices at the specfied hrn
-        return dict(self.slices)
+        slices = self.shell.GetSlices(self.auth, {}, ['name'])
+        slice_hrns = [self.slicename_to_hrn(slice['name']) for slice in slices]  
+        
+        return slice_hrns
  
     def get_rspec(self, hrn, type):
         """
@@ -259,12 +258,18 @@ class Aggregate(GeniServer):
         # Get the required nodes
         if type in ['aggregate']:
             nodes = self.shell.GetNodes(self.auth)
+            try:  linkspecs = self.shell.GetLinkSpecs() # if call is supported
+            except:  linkspecs = []
         elif type in ['slice']:
             slicename = hrn_to_pl_slicename(hrn)
             slices = self.shell.GetSlices(self.auth, [slicename])
             node_ids = slices[0]['node_ids']
             nodes = self.shell.GetNodes(self.auth, node_ids) 
         
+        # Filter out whitelisted nodes
+        public_nodes = lambda n: n.has_key('slice_ids_whitelist') and not n['slice_ids_whitelist']
+        nodes = filter(public_nodes, nodes)
         # Get all network interfaces
         interface_ids = []
         for node in nodes:
@@ -295,7 +300,11 @@ class Aggregate(GeniServer):
         duration = end_time - start_time
 
         # create the plc dict
-        networks = [{'nodes': nodes, 'name': self.hrn, 'start_time': start_time, 'duration': duration}] 
+        networks = [{'nodes': nodes,
+                     'links': linkspecs, 
+                     'name': self.hrn, 
+                     'start_time': start_time, 
+                     'duration': duration}] 
         resources = {'networks': networks, 'start_time': start_time, 'duration': duration}
 
         # convert the plc dict to an rspec dict
@@ -313,6 +322,21 @@ class Aggregate(GeniServer):
         
         return rspec
  
+    
+    def getTicket(self, hrn, rspec):
+        """
+        Retrieve a ticket. This operation is currently implemented on PLC
+        only (see SFA, engineering decisions); it is not implemented on
+        components.
+
+        @param name name of the slice to retrieve a ticket for
+        @param rspec resource specification dictionary
+        @return the string representation of a ticket object
+        """
+        #self.registry.get_ticket(name, rspec)
+
+        return         
+
 
     def createSlice(self, slice_hrn, rspec, attributes = []):
         """
@@ -330,17 +354,70 @@ class Aggregate(GeniServer):
         slicename = hrn_to_pl_slicename(slice_hrn)
         slices = self.shell.GetSlices(self.auth, [slicename], ['node_ids'])
         if not slices:
-            raise RecordNotFound(slice_hrn)
-        slice = slices[0]
+            parts = slicename.split("_")
+            login_base = parts[0]
+            slice_record = self.registry.resolve(self.cred, slice_hrn)
+            slice_info = slice_record.as_dict()
+            slice = slice_info['pl_info']
+
+            # if site doesnt exist add it
+            sites = self.shell.GetSites(self.auth, [login_base]) 
+            if not sites:
+                authority = get_authority(slice_hrn)
+                site_record = self.registry.reolve(self.cred, authority)
+                site_info = site_record.as_dict()
+                site = site_info['pl_info'] 
+                
+                # add the site
+                site.pop('site_id') 
+                site_id = self.shell.AddSite(self.auth, site)
+            else:
+                site = sites[0]
+                
+            self.shell.AddSlice(self.auth, slice_info)
+        else:
+            slice = slices[0]
 
+        
+        # get the list of valid slice users from the registry and make 
+        # they are added to the slice 
+        slice_records = self.registry.resolve(self.credential, slice_hrn)
+        if not slice_records:
+            raise Error, "record for %s not found" % slice_hrn
+        slice_record = slice_records[0]
+        slice_record_dict = slice_record.as_dict()
+        geni_info = slice_record_dict['geni_info']
+        researchers = geni_info['researcher']
+        for researcher in researchers:
+            person_records = self.registry.resolve(self.credential, researcher)
+            if not person_records:
+                pass
+            person_record = person_records[0]
+            person_dict = person_record.as_dict()['plc_info']
+            persons = self.shell.GetPersons(self.auth, [person_dict['email']], ['person_id', 'key_ids'])
+            
+            # Create the person record 
+            if not persons:
+                self.shell.AddPerson(self.auth, person_dict)
+            self.shell.AddPersonToSlice(self.auth, person_dict['email'], login_base)
+            # Add this person's public keys
+            for personkey in person_dict['keys']:
+                key = {'type': 'ssh', 'key': personkey}      
+                self.shellAddPersonKey(self.auth, person_dict['email'], key)
         # find out where this slice is currently running
-        nodes = self.shell.GetNodes(self.auth, slice['node_ids'], ['hostname'])
-        hostnames = [node['hostname'] for node in nodes]
+        nodelist = self.shell.GetNodes(self.auth, slice['node_ids'], ['hostname'])
+        hostnames = [node['hostname'] for node in nodelist]
 
         # get netspec details
         nodespecs = spec.getDictsByTagName('NodeSpec')
-        nodes = [nodespec['name'] for nodespec in nodespecs]    
-       
+        nodes = []
+        for nodespec in nodespecs:
+            if isinstance(nodespec['name'], list):
+                nodes.extend(nodespec['name'])
+            elif isinstance(nodespec['name'], StringTypes):
+                nodes.append(nodespec['name'])
+                
         # save slice state locally
         # we can assume that spec object has been validated so its safer to 
         # save this instead of the unvalidated rspec the user gave us
@@ -355,16 +432,6 @@ class Aggregate(GeniServer):
         self.shell.AddSliceToNodes(self.auth, slicename, added_nodes)
         self.shell.DeleteSliceFromNodes(self.auth, slicename, deleted_nodes)
 
-        for attribute in attributes:
-            type, value, node, nodegroup = attribute['type'], attribute['value'], attribute['node'], attribute['nodegroup']
-            self.shell.AddSliceAttribute(self.auth, slicename, type, value, node, nodegroup)
-    
-        # contact registry to get slice users and add them to the slice
-        slice_record = self.registry.resolve(self.credential, slice_hrn)
-        # persons = slice_record['users']
-        
-        #for person in persons:
-        #    shell.AddPersonToSlice(person['email'], slice_name)
         return 1
 
     def updateSlice(self, slice_hrn, rspec, attributes = []):
@@ -439,50 +506,61 @@ class Aggregate(GeniServer):
 ## Server methods here for now
 ##############################
 
-    def list_components(self):
-        return self.getComponents()
 
-    def list_slices(self, cred, hrn):
-        self.decode_authentication(cred, 'list')
-        return self.getSlices(hrn)
+    # XX fix rights, should be function name defined in 
+    # privilege_table (from util/rights.py)
+    def list_nodes(self, cred):
+        self.decode_authentication(cred, 'listnodes')
+        return self.getNodes()
 
-    def get_resources(self, cred, hrn):
-        self.decode_authentication(cred, 'info')
-        return self.getResources(hrn)
+    def list_slices(self, cred):
+        self.decode_authentication(cred, 'listslices')
+        return self.getSlices()
 
+    def get_resources(self, cred, hrn = None):
+        self.decode_authentication(cred, 'listnodes')
+        if not hrn: 
+            return self.getNodes()
+        else: 
+            return self.getResources(hrn)
+
+    def get_ticket(self, cred, hrn, rspec):
+        self.decode_authentication(cred, 'getticket')
+        return self.getTicket(hrn, rspec)
     def get_policy(self, cred):
-        self.decode_authentication(cred, 'info')
+        self.decode_authentication(cred, 'getpolicy')
         return self.getPolicy()
 
     def create_slice(self, cred, hrn, rspec):
-        self.decode_authentication(cred, 'embed')
+        self.decode_authentication(cred, 'createslice')
         return self.createSlice(hrn, rspec)
 
     def update_slice(self, cred, hrn, rspec):
-        self.decode_authentication(cred, 'embed')
+        self.decode_authentication(cred, 'updateslice')
         return self.updateSlice(hrn)    
 
     def delete_slice(self, cred, hrn):
-        self.decode_authentication(cred, 'embed')
+        self.decode_authentication(cred, 'deleteslice')
         return self.deleteSlice(hrn)
 
     def start_slice(self, cred, hrn):
-        self.decode_authentication(cred, 'control')
+        self.decode_authentication(cred, 'startslice')
         return self.startSlice(hrn)
 
     def stop_slice(self, cred, hrn):
-        self.decode_authentication(cred, 'control')
+        self.decode_authentication(cred, 'stopslice')
         return self.stopSlice(hrn)
 
     def reset_slice(self, cred, hrn):
-        self.decode_authentication(cred, 'control')
+        self.decode_authentication(cred, 'resetslice')
         return self.resetSlice(hrn)
 
     def register_functions(self):
         GeniServer.register_functions(self)
 
         # Aggregate interface methods
-        self.server.register_function(self.list_components)
+        self.server.register_function(self.list_nodes)
         self.server.register_function(self.list_slices)
         self.server.register_function(self.get_resources)
         self.server.register_function(self.get_policy)