added sfa.pdf
[sfa.git] / geni / aggregate.py
index 6c3d930..a473345 100644 (file)
@@ -4,8 +4,9 @@ import datetime
 import time
 import xmlrpclib
 
+from types import StringTypes, ListType
 from geni.util.geniserver import GeniServer
-from geni.util.geniclient import *
+from geni.util.geniclient import GeniClient
 from geni.util.cert import Keypair, Certificate
 from geni.util.credential import Credential
 from geni.util.trustedroot import TrustedRootList
@@ -41,6 +42,7 @@ class Aggregate(GeniServer):
 
     def __init__(self, ip, port, key_file, cert_file, config = "/usr/share/geniwrapper/geni/util/geni_config"):
         GeniServer.__init__(self, ip, port, key_file, cert_file)
+        self.server.interface = 'aggregate'
         self.key_file = key_file
         self.cert_file = cert_file
         self.config = Config(config)
@@ -67,8 +69,8 @@ class Aggregate(GeniServer):
         self.nodes_ttl = 1
 
         self.connectPLC()
-        #self.connectRegistry()
-        #self.loadCredential()
+        self.connectRegistry()
+        self.loadCredential()
 
     def connectRegistry(self):
         """
@@ -77,7 +79,7 @@ class Aggregate(GeniServer):
         # connect to registry using GeniClient
         address = self.config.GENI_REGISTRY_HOSTNAME
         port = self.config.GENI_REGISTRY_PORT
-        url = 'https://%(address)s:%(port)s' % locals()
+        url = 'http://%(address)s:%(port)s' % locals()
         self.registry = GeniClient(url, self.key_file, self.cert_file)
 
     
@@ -115,24 +117,30 @@ class Aggregate(GeniServer):
         credential from registry.
         """ 
 
-        self_cred_filename = self.server_basedir + os.sep + "agg." + self.hrn + ".cred"
         ma_cred_filename = self.server_basedir + os.sep + "agg." + self.hrn + ".ma.cred"
         
         # see if this file exists
         try:
-            cred = Credential(filename = ma_cred_filename, subject=self.hrn)
-            self.credential = cred.save_to_string()
+            self.credential = Credential(filename = ma_cred_filename)
         except IOError:
-            # get self credential
-            self_cred = self.registry.get_credential(None, 'ma', self.hrn)
-            self_credential = Credential(string = self_cred)
-            self_credential.save_to_file(self_cred_filename)
+            self.credential = self.getCredentialFromRegistry()
+
+    def getCredentialFromRegistry(self):
+        """
+        Get our current credential from the registry
+        """
+        # get self credential
+        self_cred_filename = self.server_basedir + os.sep + "agg." + self.hrn + ".cred"
+        self_cred = self.registry.get_credential(None, 'ma', self.hrn)
+        self_cred.save_to_file(self_cred_filename, save_parents = True)
+
+        
+        # get ma credential
+        ma_cred_filename = self.server_basedir + os.sep + "agg." + self.hrn + ".ma.cred"
+        ma_cred = self.registry.get_credential(self_cred, 'ma', self.hrn)
+        ma_cred.save_to_file(ma_cred_filename, save_parents=True)
+        return ma_cred        
 
-            # get ma credential
-            ma_cred = self.registry.get_credential(self_cred)
-            ma_credential = Credential(string = ma_cred)
-            ma_credential.save_to_file(ma_cred_filename)
-            self.credential = ma_cred
 
     def hostname_to_hrn(self, login_base, hostname):
         """
@@ -145,8 +153,10 @@ class Aggregate(GeniServer):
         """
         Convert hrn to planetlab name.
         """
-        slicename = slicename.replace("_", ".")
-        return ".".join([self.hrn, slicename])
+        parts = slicename.split("_")
+        slice_hrn = ".".join([self.hrn, parts[0]]) + "." + "_".join(parts[1:])
+          
+        return slice_hrn
 
     def refresh_components(self):
         """
@@ -238,7 +248,7 @@ class Aggregate(GeniServer):
 
         slices = self.shell.GetSlices(self.auth, {}, ['name'])
         slice_hrns = [self.slicename_to_hrn(slice['name']) for slice in slices]  
-
+        
         return slice_hrns
  
     def get_rspec(self, hrn, type):
@@ -249,11 +259,17 @@ class Aggregate(GeniServer):
         # Get the required nodes
         if type in ['aggregate']:
             nodes = self.shell.GetNodes(self.auth)
+            try:  linkspecs = self.shell.GetLinkSpecs() # if call is supported
+            except:  linkspecs = []
         elif type in ['slice']:
             slicename = hrn_to_pl_slicename(hrn)
             slices = self.shell.GetSlices(self.auth, [slicename])
-            node_ids = slices[0]['node_ids']
-            nodes = self.shell.GetNodes(self.auth, node_ids) 
+            if not slices:
+                nodes = []
+            else:
+                slice = slices[0]     
+                node_ids = slice['node_ids']
+                nodes = self.shell.GetNodes(self.auth, node_ids) 
         
         # Filter out whitelisted nodes
         public_nodes = lambda n: n.has_key('slice_ids_whitelist') and not n['slice_ids_whitelist']
@@ -289,7 +305,12 @@ class Aggregate(GeniServer):
         duration = end_time - start_time
 
         # create the plc dict
-        networks = [{'nodes': nodes, 'name': self.hrn, 'start_time': start_time, 'duration': duration}] 
+        networks = [{'nodes': nodes,
+                     'name': self.hrn, 
+                     'start_time': start_time, 
+                     'duration': duration}]
+        if type in ['aggregate']:
+            networks[0]['links'] = linkspecs 
         resources = {'networks': networks, 'start_time': start_time, 'duration': duration}
 
         # convert the plc dict to an rspec dict
@@ -334,22 +355,92 @@ class Aggregate(GeniServer):
         # save this instead of the unvalidated rspec the user gave us
         self.slices[slice_hrn] = spec.toxml()
         self.slices.write()
-        
-        # Get slice info
+       
+        # Get the slice record from geni
+        slice = {}
+        records = self.registry.resolve(self.credential, slice_hrn)
+            
+        for record in records:
+            if record.get_type() in ['slice']:
+                slice_info = record.as_dict()
+                slice = slice_info['pl_info']
+        if not slice:
+            raise RecordNotFound(slice_hrn)
+                    
+        # Make sure slice exists at plc, if it doesnt add it
         slicename = hrn_to_pl_slicename(slice_hrn)
         slices = self.shell.GetSlices(self.auth, [slicename], ['node_ids'])
         if not slices:
-            raise RecordNotFound(slice_hrn)
-        slice = slices[0]
-
+            parts = slicename.split("_")
+            login_base = parts[0]
+            # if site doesnt exist add it
+            sites = self.shell.GetSites(self.auth, [login_base]) 
+            if not sites:
+                authority = get_authority(slice_hrn)
+                site_records = self.registry.resolve(self.credential, authority)
+                site_record = {}
+                if not site_records:
+                    raise RecordNotFound(authority)
+                site_record = site_records[0]     
+                site_info = site_record.as_dict()
+                site = site_info['pl_info'] 
+                
+                # add the site
+                site.pop('site_id') 
+                site_id = self.shell.AddSite(self.auth, site)
+            else:
+                site = sites[0]
+                
+            self.shell.AddSlice(self.auth, slice)
+        
+        # get the list of valid slice users from the registry and make 
+        # they are added to the slice 
+        geni_info = slice_info['geni_info']
+        researchers = geni_info['researcher']
+        for researcher in researchers:
+            person_record = {}
+            person_records = self.registry.resolve(self.credential, researcher)
+            for record in person_records:
+                if record.get_type() in ['user']:
+                    person_record = record
+            if not person_record:
+                pass
+            person_dict = person_record.as_dict()['pl_info']
+            persons = self.shell.GetPersons(self.auth, [person_dict['email']], ['person_id', 'key_ids'])
+            
+            # Create the person record 
+            if not persons:
+                self.shell.AddPerson(self.auth, person_dict)
+                key_ids = []
+            else:
+                key_ids = persons[0]['key_ids']
+           
+            self.shell.AddPersonToSlice(self.auth, person_dict['email'], slicename)
+            
+            # Get this users local keys
+            keylist = self.shell.GetKeys(self.auth, key_ids, ['key'])
+            keys = [key['key'] for key in keylist]
+            
+            # add keys that arent already there 
+            for personkey in person_dict['keys']:
+                if personkey not in keys:
+                    key = {'key_type': 'ssh', 'key': personkey}      
+                    self.shell.AddPersonKey(self.auth, person_dict['email'], key)
         # find out where this slice is currently running
-        nodes = self.shell.GetNodes(self.auth, slice['node_ids'], ['hostname'])
-        hostnames = [node['hostname'] for node in nodes]
+        nodelist = self.shell.GetNodes(self.auth, slice['node_ids'], ['hostname'])
+        hostnames = [node['hostname'] for node in nodelist]
 
         # get netspec details
         nodespecs = spec.getDictsByTagName('NodeSpec')
-        nodes = [nodespec['name'] for nodespec in nodespecs]    
-       
+        nodes = []
+        for nodespec in nodespecs:
+            if isinstance(nodespec['name'], list):
+                nodes.extend(nodespec['name'])
+            elif isinstance(nodespec['name'], StringTypes):
+                nodes.append(nodespec['name'])
+                
         # save slice state locally
         # we can assume that spec object has been validated so its safer to 
         # save this instead of the unvalidated rspec the user gave us
@@ -364,16 +455,6 @@ class Aggregate(GeniServer):
         self.shell.AddSliceToNodes(self.auth, slicename, added_nodes)
         self.shell.DeleteSliceFromNodes(self.auth, slicename, deleted_nodes)
 
-        for attribute in attributes:
-            type, value, node, nodegroup = attribute['type'], attribute['value'], attribute['node'], attribute['nodegroup']
-            self.shell.AddSliceAttribute(self.auth, slicename, type, value, node, nodegroup)
-    
-        # contact registry to get slice users and add them to the slice
-        #slice_record = self.registry.resolve(self.credential, slice_hrn)
-        # persons = slice_record['users']
-        
-        #for person in persons:
-        #    shell.AddPersonToSlice(person['email'], slice_name)
         return 1
 
     def updateSlice(self, slice_hrn, rspec, attributes = []):
@@ -459,9 +540,12 @@ class Aggregate(GeniServer):
         self.decode_authentication(cred, 'listslices')
         return self.getSlices()
 
-    def get_resources(self, cred, hrn):
+    def get_resources(self, cred, hrn = None):
         self.decode_authentication(cred, 'listnodes')
-        return self.getResources(hrn)
+        if not hrn: 
+            return self.getNodes()
+        else: 
+            return self.getResources(hrn)
 
     def get_ticket(self, cred, hrn, rspec):
         self.decode_authentication(cred, 'getticket')
@@ -509,4 +593,46 @@ class Aggregate(GeniServer):
         self.server.register_function(self.start_slice)
         self.server.register_function(self.stop_slice)
         self.server.register_function(self.reset_slice)
-              
+
+
+
+
+class Aggregates(dict):
+    
+    def __init__(self, api):
+        dict.__init__(self, {})
+        self.api = api
+        aggregates_file = self.api.server_basedir + os.sep + 'aggregates.xml'
+        connection_dict = {'hrn': '', 'addr': '', 'port': ''}
+        self.aggregate_info = XmlStorage(aggregates_file, {'aggregates': {'aggregate': [connection_dict]}})
+        self.aggregate_info.load()
+        self.connectAggregates()
+
+
+    def connectAggregates(self):
+        """
+        Get connection details for the trusted peer aggregates from file and 
+        create an GeniClient connection to each. 
+        """
+        required_fields = ['hrn', 'addr', 'port']
+        aggregates = self.aggregate_info['aggregates']['aggregate']
+        if isinstance(aggregates, dict):
+            aggregates = [aggregates]
+        if isinstance(aggregates, list):
+            for aggregate in aggregates:
+                # create xmlrpc connection using GeniClient
+                if not set(required_fields).issubset(aggregate.keys()):
+                    continue
+                hrn, address, port = aggregate['hrn'], aggregate['addr'], aggregate['port']
+                if not hrn or not address or not port:
+                    continue
+                url = 'http://%(address)s:%(port)s' % locals()
+                self[hrn] = GeniClient(url, self.api.key_file, self.api.cert_file)
+
+        # set up a connection to the local registry
+        # connect to registry using GeniClient
+        address = self.api.config.GENI_AGGREGATE_HOSTNAME
+        port = self.api.config.GENI_AGGREGATE_PORT
+        url = 'http://%(address)s:%(port)s' % locals()
+        self[self.api.hrn] = GeniClient(url, self.api.key_file, self.api.cert_file)
+