Update instructions
[sfa.git] / geni / aggregate.py
index 533b365..702c3bb 100644 (file)
@@ -4,6 +4,7 @@ import datetime
 import time
 import xmlrpclib
 
 import time
 import xmlrpclib
 
+from types import StringTypes, ListType
 from geni.util.geniserver import GeniServer
 from geni.util.geniclient import *
 from geni.util.cert import Keypair, Certificate
 from geni.util.geniserver import GeniServer
 from geni.util.geniclient import *
 from geni.util.cert import Keypair, Certificate
@@ -39,7 +40,7 @@ class Aggregate(GeniServer):
     # @param key_file private key filename of registry
     # @param cert_file certificate filename containing public key (could be a GID file)     
 
     # @param key_file private key filename of registry
     # @param cert_file certificate filename containing public key (could be a GID file)     
 
-    def __init__(self, ip, port, key_file, cert_file, config = "/usr/share/geniwrapper/geni/util/geni_config"):
+    def __init__(self, ip, port, key_file, cert_file, config = "/etc/planetlab/geni_config"):
         GeniServer.__init__(self, ip, port, key_file, cert_file)
         self.key_file = key_file
         self.cert_file = cert_file
         GeniServer.__init__(self, ip, port, key_file, cert_file)
         self.key_file = key_file
         self.cert_file = cert_file
@@ -63,6 +64,7 @@ class Aggregate(GeniServer):
         timestamp_file = os.sep.join([self.server_basedir, 'agg.' + self.hrn + '.timestamp']) 
         self.timestamp = SimpleStorage(timestamp_file)
 
         timestamp_file = os.sep.join([self.server_basedir, 'agg.' + self.hrn + '.timestamp']) 
         self.timestamp = SimpleStorage(timestamp_file)
 
+        # How long before we refresh nodes cache
         self.nodes_ttl = 1
 
         self.connectPLC()
         self.nodes_ttl = 1
 
         self.connectPLC()
@@ -76,7 +78,7 @@ class Aggregate(GeniServer):
         # connect to registry using GeniClient
         address = self.config.GENI_REGISTRY_HOSTNAME
         port = self.config.GENI_REGISTRY_PORT
         # connect to registry using GeniClient
         address = self.config.GENI_REGISTRY_HOSTNAME
         port = self.config.GENI_REGISTRY_PORT
-        url = 'https://%(address)s:%(port)s' % locals()
+        url = 'http://%(address)s:%(port)s' % locals()
         self.registry = GeniClient(url, self.key_file, self.cert_file)
 
     
         self.registry = GeniClient(url, self.key_file, self.cert_file)
 
     
@@ -113,28 +115,29 @@ class Aggregate(GeniServer):
         Attempt to load credential from file if it exists. If it doesnt get 
         credential from registry.
         """ 
         Attempt to load credential from file if it exists. If it doesnt get 
         credential from registry.
         """ 
-
-        self_cred_filename = self.server_basedir + os.sep + "agg." + self.hrn + ".cred"
-        ma_cred_filename = self.server_basedir + os.sep + "agg." + self.hrn + ".ma.cred"
         
         # see if this file exists
         
         # see if this file exists
+        ma_cred_filename = self.server_basedir + os.sep + "agg." + self.hrn + ".ma.cred"
         try:
         try:
-            cred = Credential(filename = ma_cred_filename)
-            self.credential = cred.save_to_string()
+            self.credential = Credential(filename = ma_cred_filename)
         except IOError:
             # get self credential
         except IOError:
             # get self credential
-            #self_cred = self.registry.get_credential(None, 'ma', self.hrn)
-            #self_credential = Credential(string = self_cred)
-            #self_credential.save_to_file(self_cred_filename)
+            self.credential = self.getCredentialFromRegistry()
 
 
-            # get ma credential
-            #ma_cred = self.registry.get_gredential(self_cred)
-            #ma_credential = Credential(string = ma_cred)
-            #ma_credential.save_to_file(ma_cred_filename)
+    def getCredentialFromRegistry(self):
+        """
+        Get our current credential from the registry
+        """
+        # get self credential
+        self_cred_filename = self.server_basedir + os.sep + "agg." + self.hrn + ".cred"
+        self_cred = self.registry.get_credential(None, 'ma', self.hrn)
+        self_cred.save_to_file(self_cred_filename, save_parents = True)
 
 
-            ma_cred = Certificate(filename = self.cert_file)
-            
-            self.credential = ma_cred.save_to_string()
+        # get ma credential
+        ma_cred_filename = self.server_basedir + os.sep + "agg." + self.hrn + ".ma.cred"
+        ma_cred = self.registry.get_credential(self_cred, 'ma', self.hrn)
+        ma_cred.save_to_file(ma_cred_filename, save_parents=True)
+        return ma_cred
 
     def hostname_to_hrn(self, login_base, hostname):
         """
 
     def hostname_to_hrn(self, login_base, hostname):
         """
@@ -153,50 +156,35 @@ class Aggregate(GeniServer):
     def refresh_components(self):
         """
         Update the cached list of nodes and save in 4 differnt formats
     def refresh_components(self):
         """
         Update the cached list of nodes and save in 4 differnt formats
-        (rspec, dns, ip, hrn)
+        (rspec, dns, ip)
         """
 
         """
 
-        node_details = {}
         # get node list in rspec format
         rspec = Rspec()
         rspec.parseString(self.get_rspec(self.hrn, 'aggregate'))
         # get node list in rspec format
         rspec = Rspec()
         rspec.parseString(self.get_rspec(self.hrn, 'aggregate'))
+        
         # filter nodes according to policy
         rspec.filter('NodeSpec', 'name', blacklist=self.policy['blacklist'], whitelist=self.policy['whitelist'])
         # filter nodes according to policy
         rspec.filter('NodeSpec', 'name', blacklist=self.policy['blacklist'], whitelist=self.policy['whitelist'])
-        # extract ifspec info to get ip's
+        
+        # extract ifspecs from rspec to get ip's
         ips = []
         ifspecs = rspec.getDictsByTagName('IfSpec')
         for ifspec in ifspecs:
             if ifspec.has_key('addr') and ifspec['addr']:
                 ips.append(ifspec['addr']) 
 
         ips = []
         ifspecs = rspec.getDictsByTagName('IfSpec')
         for ifspec in ifspecs:
             if ifspec.has_key('addr') and ifspec['addr']:
                 ips.append(ifspec['addr']) 
 
-        # resolve component hostnames 
-        nodes = self.shell.GetNodes(self.auth, {}, ['hostname', 'site_id'])
-    
-        # resolve site login_bases
-        site_ids = [node['site_id'] for node in nodes]
-        sites = self.shell.GetSites(self.auth, site_ids, ['site_id', 'login_base'])
-        site_dict = {}
-        for site in sites:
-            site_dict[site['site_id']] = site['login_base']
-
-        # filter nodes according to policy
-        # convert plc names to geni hrn
-        nodedict = {}
-        for node in nodes:
-            node_hrn = self.hostname_to_hrn(site_dict[node['site_id']], node['hostname'])
-            # apply policy. 
-            # Do not allow nodes found in blacklist, only allow nodes found in whitelist
-            if self.policy['whitelist'] and node_hrn not in self.polciy['whitelist']:
-                continue
-            if self.policy['blacklist'] and node_hrn in self.policy['blacklist']:
-                continue
-            nodedict[node_hrn] = node['hostname']
+        # extract nodespecs from rspec to get dns names
+        hostnames = []
+        nodespecs = rspec.getDictsByTagName('NodeSpec')
+        for nodespec in nodespecs:
+            if nodespec.has_key('name') and nodespec['name']:
+                hostnames.append(nodespec['name'])
 
         
 
         
+        node_details = {}
         node_details['rspec'] = rspec.toxml()
         node_details['rspec'] = rspec.toxml()
-        node_details['hrn'] = nodedict.keys()
-        node_details['dns'] = nodedict.values()
         node_details['ip'] = ips
         node_details['ip'] = ips
+        node_details['dns'] = hostnames
         # save state 
         self.nodes = SimpleStorage(self.nodes.db_filename, node_details)
         self.nodes.write()
         # save state 
         self.nodes = SimpleStorage(self.nodes.db_filename, node_details)
         self.nodes.write()
@@ -228,14 +216,16 @@ class Aggregate(GeniServer):
         self.policy.load()
 
 
         self.policy.load()
 
 
-    def getComponents(self, type = 'rspec'):
+    def getNodes(self, format = 'rspec'):
         """
         Return a list of components at this aggregate.
         """
         """
         Return a list of components at this aggregate.
         """
-        valid_types = ['rspec', 'hrn', 'dns', 'ip']
-        if type not in valid_types:
-            raise Exception, "Invalid type specified, must be one of the following: %s" \
-                             % ", ".join(valid_types)
+        valid_formats = ['rspec', 'hrn', 'dns', 'ip']
+        if not format:
+            format = 'rspec'
+        if format not in valid_formats:
+            raise Exception, "Invalid format specified, must be one of the following: %s" \
+                             % ", ".join(valid_formats)
         
         # Reload components list
         now = datetime.datetime.now()
         
         # Reload components list
         now = datetime.datetime.now()
@@ -244,15 +234,17 @@ class Aggregate(GeniServer):
             self.refresh_components()
         elif now < self.threshold and not self.nodes.keys(): 
             self.load_components()
             self.refresh_components()
         elif now < self.threshold and not self.nodes.keys(): 
             self.load_components()
-        return self.nodes[type]
+        return self.nodes[format]
     
     
-    def getSlices(self, hrn):
+    def getSlices(self):
         """
         Return a list of instnatiated managed by this slice manager.
         """
 
         """
         Return a list of instnatiated managed by this slice manager.
         """
 
-        # XX list only the slices at the specfied hrn
-        return dict(self.slices)
+        slices = self.shell.GetSlices(self.auth, {}, ['name'])
+        slice_hrns = [self.slicename_to_hrn(slice['name']) for slice in slices]  
+
+        return slice_hrns
  
     def get_rspec(self, hrn, type):
         """
  
     def get_rspec(self, hrn, type):
         """
@@ -262,12 +254,18 @@ class Aggregate(GeniServer):
         # Get the required nodes
         if type in ['aggregate']:
             nodes = self.shell.GetNodes(self.auth)
         # Get the required nodes
         if type in ['aggregate']:
             nodes = self.shell.GetNodes(self.auth)
+            try:  linkspecs = self.shell.GetLinkSpecs() # if call is supported
+            except:  linkspecs = []
         elif type in ['slice']:
             slicename = hrn_to_pl_slicename(hrn)
             slices = self.shell.GetSlices(self.auth, [slicename])
             node_ids = slices[0]['node_ids']
             nodes = self.shell.GetNodes(self.auth, node_ids) 
         
         elif type in ['slice']:
             slicename = hrn_to_pl_slicename(hrn)
             slices = self.shell.GetSlices(self.auth, [slicename])
             node_ids = slices[0]['node_ids']
             nodes = self.shell.GetNodes(self.auth, node_ids) 
         
+        # Filter out whitelisted nodes
+        public_nodes = lambda n: n.has_key('slice_ids_whitelist') and not n['slice_ids_whitelist']
+        nodes = filter(public_nodes, nodes)
         # Get all network interfaces
         interface_ids = []
         for node in nodes:
         # Get all network interfaces
         interface_ids = []
         for node in nodes:
@@ -298,7 +296,11 @@ class Aggregate(GeniServer):
         duration = end_time - start_time
 
         # create the plc dict
         duration = end_time - start_time
 
         # create the plc dict
-        networks = [{'nodes': nodes, 'name': self.hrn, 'start_time': start_time, 'duration': duration}] 
+        networks = [{'nodes': nodes, 
+                        'name': self.hrn, 
+                        'start_time': start_time, 
+                        'duration': duration, 
+                        'links': linkspecs}] 
         resources = {'networks': networks, 'start_time': start_time, 'duration': duration}
 
         # convert the plc dict to an rspec dict
         resources = {'networks': networks, 'start_time': start_time, 'duration': duration}
 
         # convert the plc dict to an rspec dict
@@ -316,12 +318,28 @@ class Aggregate(GeniServer):
         
         return rspec
  
         
         return rspec
  
+    
+    def getTicket(self, hrn, rspec):
+        """
+        Retrieve a ticket. This operation is currently implemented on PLC
+        only (see SFA, engineering decisions); it is not implemented on
+        components.
+
+        @param name name of the slice to retrieve a ticket for
+        @param rspec resource specification dictionary
+        @return the string representation of a ticket object
+        """
+        #self.registry.get_ticket(name, rspec)
+
+        return         
+
 
     def createSlice(self, slice_hrn, rspec, attributes = []):
         """
         Instantiate the specified slice according to whats defined in the rspec.
         """
         
 
     def createSlice(self, slice_hrn, rspec, attributes = []):
         """
         Instantiate the specified slice according to whats defined in the rspec.
         """
         
+        spec = Rspec(rspec)
         # save slice state locally
         # we can assume that spec object has been validated so its safer to
         # save this instead of the unvalidated rspec the user gave us
         # save slice state locally
         # we can assume that spec object has been validated so its safer to
         # save this instead of the unvalidated rspec the user gave us
@@ -329,21 +347,50 @@ class Aggregate(GeniServer):
         self.slices.write()
         
         # Get slice info
         self.slices.write()
         
         # Get slice info
+        # if slice doesnt exist add it
         slicename = hrn_to_pl_slicename(slice_hrn)
         slices = self.shell.GetSlices(self.auth, [slicename], ['node_ids'])
         slicename = hrn_to_pl_slicename(slice_hrn)
         slices = self.shell.GetSlices(self.auth, [slicename], ['node_ids'])
-        if not slice:
-            raise RecordNotFound(slice_hrn)
-        slice = slices[0]
+        if not slices:
+            parts = slicename.split("_")
+            login_base = parts[0]
+            slice_record = self.registry.resolve(self.cred, slice_hrn)
+            slice_info = slice_record.as_dict()
+            slice = slice_info['pl_info']
+            
+            # if site doesnt exist add it
+            sites = self.shell.GetSites(self.auth, [login_base])
+            if not sites:
+                authority = get_authority(slice_hrn)
+                site_record = self.registry.reolve(self.cred, authority) 
+                site_info = site_record.as_dict()
+                site = site_info['pl_info']
+
+                # add the site
+                site.pop('site_id')
+                site_id = self.shell.AddSite(self.auth, site)
+
+            # add the slice
+            self.shell.AddSlice(self.auth, slice_info)
+            
+            # add the slice users
+            
+        else:    
+            slice = slices[0]
+
 
         # find out where this slice is currently running
 
         # find out where this slice is currently running
-        nodes = self.shell.GetNodes(self.auth, slice['node_ids'], ['hostname'])
-        hostnames = [node['hostname'] for node in nodes]
+        nodelist = self.shell.GetNodes(self.auth, slice['node_ids'], ['hostname'])
+        hostnames = [node['hostname'] for node in nodelist]
 
         # get netspec details
 
         # get netspec details
-        spec = Rspec(rspec)
         nodespecs = spec.getDictsByTagName('NodeSpec')
         nodespecs = spec.getDictsByTagName('NodeSpec')
-        nodes = [nodespec['name'] for nodespec in nodespecs]    
-       
+        nodes = []
+        for nodespec in nodespecs:
+            if isinstance(nodespec['name'], list):
+                nodes.extend(nodespec['name'])
+            elif isinstance(nodespec['name'], StringTypes):
+                nodes.append(nodespec['name'])
+                
         # save slice state locally
         # we can assume that spec object has been validated so its safer to 
         # save this instead of the unvalidated rspec the user gave us
         # save slice state locally
         # we can assume that spec object has been validated so its safer to 
         # save this instead of the unvalidated rspec the user gave us
@@ -351,29 +398,24 @@ class Aggregate(GeniServer):
         self.slices.write()
 
         # remove nodes not in rspec
         self.slices.write()
 
         # remove nodes not in rspec
-        delete_nodes = set(hostnames).difference(nodes)
+        deleted_nodes = list(set(hostnames).difference(nodes))
         # add nodes from rspec
         # add nodes from rspec
-        added_nodes = set(nodes).difference(hostnames)
+        added_nodes = list(set(nodes).difference(hostnames))
     
     
-        shell.AddSliceToNodes(self.auth, slicename, added_nodes)
-        shell.DeleteSliceFromNodes(self.auth, slicename, deleted_nodes)
+        self.shell.AddSliceToNodes(self.auth, slicename, added_nodes)
+        self.shell.DeleteSliceFromNodes(self.auth, slicename, deleted_nodes)
 
 
-        for attribute in attributes:
-            type, value, node, nodegroup = attribute['type'], attribute['value'], attribute['node'], attribute['nodegroup']
-            shell.AddSliceAttribute(self.auth, slicename, type, value, node, nodegroup)
-    
         # contact registry to get slice users and add them to the slice
         # contact registry to get slice users and add them to the slice
-        slice_record = self.registry.resolve(self.credential, slice_hrn)
+        #slice_record = self.registry.resolve(self.credential, slice_hrn)
         # persons = slice_record['users']
         # persons = slice_record['users']
-        
-        #for person in persons:
+        # for perosn in persons:
         #    shell.AddPersonToSlice(person['email'], slice_name)
         return 1
 
         #    shell.AddPersonToSlice(person['email'], slice_name)
         return 1
 
-    def update_slice(self, slice_hrn, rspec, attributes = []):
+    def updateSlice(self, slice_hrn, rspec, attributes = []):
         return self.create_slice(slice_hrn, rspec, attributes)
          
         return self.create_slice(slice_hrn, rspec, attributes)
          
-    def deleteSlice_(self, slice_hrn):
+    def deleteSlice(self, slice_hrn):
         """
         Remove this slice from all components it was previouly associated with and 
         free up the resources it was using.
         """
         Remove this slice from all components it was previouly associated with and 
         free up the resources it was using.
@@ -383,12 +425,12 @@ class Aggregate(GeniServer):
             self.slices.write()
 
         slicename = hrn_to_pl_slicename(slice_hrn)
             self.slices.write()
 
         slicename = hrn_to_pl_slicename(slice_hrn)
-        slices = shell.GetSlices(self.auth, [slicename])
-        if not slice:
+        slices = self.shell.GetSlices(self.auth, [slicename])
+        if not slices:
             return 1  
         slice = slices[0]
       
             return 1  
         slice = slices[0]
       
-        shell.DeleteSliceFromNodes(self.auth, slicename, slice['node_ids'])
+        self.shell.DeleteSliceFromNodes(self.auth, slicename, slice['node_ids'])
         return 1
 
     def startSlice(self, slice_hrn):
         return 1
 
     def startSlice(self, slice_hrn):
@@ -442,50 +484,58 @@ class Aggregate(GeniServer):
 ## Server methods here for now
 ##############################
 
 ## Server methods here for now
 ##############################
 
-    def list_components(self):
-        return self.getComponents()
 
 
-    def list_slices(self, cred, hrn):
-        self.decode_authentication(cred, 'list')
-        return self.getSlices(hrn)
+    # XX fix rights, should be function name defined in 
+    # privilege_table (from util/rights.py)
+    def list_nodes(self, cred):
+        self.decode_authentication(cred, 'listnodes')
+        return self.getNodes()
+
+    def list_slices(self, cred):
+        self.decode_authentication(cred, 'listslices')
+        return self.getSlices()
 
     def get_resources(self, cred, hrn):
 
     def get_resources(self, cred, hrn):
-        self.decode_authentication(cred, 'info')
+        self.decode_authentication(cred, 'listnodes')
         return self.getResources(hrn)
 
         return self.getResources(hrn)
 
+    def get_ticket(self, cred, hrn, rspec):
+        self.decode_authentication(cred, 'getticket')
+        return self.getTicket(hrn, rspec)
     def get_policy(self, cred):
     def get_policy(self, cred):
-        self.decode_authentication(cred, 'info')
+        self.decode_authentication(cred, 'getpolicy')
         return self.getPolicy()
 
     def create_slice(self, cred, hrn, rspec):
         return self.getPolicy()
 
     def create_slice(self, cred, hrn, rspec):
-        self.decode_authentication(cred, 'embed')
-        return self.createSlice(hrn)
+        self.decode_authentication(cred, 'createslice')
+        return self.createSlice(hrn, rspec)
 
     def update_slice(self, cred, hrn, rspec):
 
     def update_slice(self, cred, hrn, rspec):
-        self.decode_authentication(cred, 'embed')
+        self.decode_authentication(cred, 'updateslice')
         return self.updateSlice(hrn)    
 
     def delete_slice(self, cred, hrn):
         return self.updateSlice(hrn)    
 
     def delete_slice(self, cred, hrn):
-        self.decode_authentication(cred, 'embed')
+        self.decode_authentication(cred, 'deleteslice')
         return self.deleteSlice(hrn)
 
     def start_slice(self, cred, hrn):
         return self.deleteSlice(hrn)
 
     def start_slice(self, cred, hrn):
-        self.decode_authentication(cred, 'control')
+        self.decode_authentication(cred, 'startslice')
         return self.startSlice(hrn)
 
     def stop_slice(self, cred, hrn):
         return self.startSlice(hrn)
 
     def stop_slice(self, cred, hrn):
-        self.decode_authentication(cred, 'control')
+        self.decode_authentication(cred, 'stopslice')
         return self.stopSlice(hrn)
 
     def reset_slice(self, cred, hrn):
         return self.stopSlice(hrn)
 
     def reset_slice(self, cred, hrn):
-        self.decode_authentication(cred, 'control')
+        self.decode_authentication(cred, 'resetslice')
         return self.resetSlice(hrn)
 
     def register_functions(self):
         GeniServer.register_functions(self)
 
         # Aggregate interface methods
         return self.resetSlice(hrn)
 
     def register_functions(self):
         GeniServer.register_functions(self)
 
         # Aggregate interface methods
-        self.server.register_function(self.list_components)
+        self.server.register_function(self.list_nodes)
         self.server.register_function(self.list_slices)
         self.server.register_function(self.get_resources)
         self.server.register_function(self.get_policy)
         self.server.register_function(self.list_slices)
         self.server.register_function(self.get_resources)
         self.server.register_function(self.get_policy)