fix bug in get_rspec when trying to get the rspec for a slice that doesnt exist
[sfa.git] / geni / aggregate.py
index 343f1eb..4b9f8a3 100644 (file)
@@ -115,13 +115,13 @@ class Aggregate(GeniServer):
         Attempt to load credential from file if it exists. If it doesnt get 
         credential from registry.
         """ 
         Attempt to load credential from file if it exists. If it doesnt get 
         credential from registry.
         """ 
+
+        ma_cred_filename = self.server_basedir + os.sep + "agg." + self.hrn + ".ma.cred"
         
         # see if this file exists
         
         # see if this file exists
-        ma_cred_filename = self.server_basedir + os.sep + "agg." + self.hrn + ".ma.cred"
         try:
             self.credential = Credential(filename = ma_cred_filename)
         except IOError:
         try:
             self.credential = Credential(filename = ma_cred_filename)
         except IOError:
-            # get self credential
             self.credential = self.getCredentialFromRegistry()
 
     def getCredentialFromRegistry(self):
             self.credential = self.getCredentialFromRegistry()
 
     def getCredentialFromRegistry(self):
@@ -133,11 +133,13 @@ class Aggregate(GeniServer):
         self_cred = self.registry.get_credential(None, 'ma', self.hrn)
         self_cred.save_to_file(self_cred_filename, save_parents = True)
 
         self_cred = self.registry.get_credential(None, 'ma', self.hrn)
         self_cred.save_to_file(self_cred_filename, save_parents = True)
 
+        
         # get ma credential
         ma_cred_filename = self.server_basedir + os.sep + "agg." + self.hrn + ".ma.cred"
         ma_cred = self.registry.get_credential(self_cred, 'ma', self.hrn)
         ma_cred.save_to_file(ma_cred_filename, save_parents=True)
         # get ma credential
         ma_cred_filename = self.server_basedir + os.sep + "agg." + self.hrn + ".ma.cred"
         ma_cred = self.registry.get_credential(self_cred, 'ma', self.hrn)
         ma_cred.save_to_file(ma_cred_filename, save_parents=True)
-        return ma_cred
+        return ma_cred        
+
 
     def hostname_to_hrn(self, login_base, hostname):
         """
 
     def hostname_to_hrn(self, login_base, hostname):
         """
@@ -150,8 +152,10 @@ class Aggregate(GeniServer):
         """
         Convert hrn to planetlab name.
         """
         """
         Convert hrn to planetlab name.
         """
-        slicename = slicename.replace("_", ".")
-        return ".".join([self.hrn, slicename])
+        parts = slicename.split("_")
+        slice_hrn = ".".join([self.hrn, parts[0]]) + "." + "_".join(parts[1:])
+          
+        return slice_hrn
 
     def refresh_components(self):
         """
 
     def refresh_components(self):
         """
@@ -243,7 +247,7 @@ class Aggregate(GeniServer):
 
         slices = self.shell.GetSlices(self.auth, {}, ['name'])
         slice_hrns = [self.slicename_to_hrn(slice['name']) for slice in slices]  
 
         slices = self.shell.GetSlices(self.auth, {}, ['name'])
         slice_hrns = [self.slicename_to_hrn(slice['name']) for slice in slices]  
-
+        
         return slice_hrns
  
     def get_rspec(self, hrn, type):
         return slice_hrns
  
     def get_rspec(self, hrn, type):
@@ -259,8 +263,12 @@ class Aggregate(GeniServer):
         elif type in ['slice']:
             slicename = hrn_to_pl_slicename(hrn)
             slices = self.shell.GetSlices(self.auth, [slicename])
         elif type in ['slice']:
             slicename = hrn_to_pl_slicename(hrn)
             slices = self.shell.GetSlices(self.auth, [slicename])
-            node_ids = slices[0]['node_ids']
-            nodes = self.shell.GetNodes(self.auth, node_ids) 
+            if not slices:
+                nodes = []
+            else:
+                slice = slices[0]     
+                node_ids = slice['node_ids']
+                nodes = self.shell.GetNodes(self.auth, node_ids) 
         
         # Filter out whitelisted nodes
         public_nodes = lambda n: n.has_key('slice_ids_whitelist') and not n['slice_ids_whitelist']
         
         # Filter out whitelisted nodes
         public_nodes = lambda n: n.has_key('slice_ids_whitelist') and not n['slice_ids_whitelist']
@@ -296,11 +304,12 @@ class Aggregate(GeniServer):
         duration = end_time - start_time
 
         # create the plc dict
         duration = end_time - start_time
 
         # create the plc dict
-        networks = [{'nodes': nodes, 
-                        'name': self.hrn, 
-                        'start_time': start_time, 
-                        'duration': duration, 
-                        'links': linkspecs}] 
+        networks = [{'nodes': nodes,
+                     'name': self.hrn, 
+                     'start_time': start_time, 
+                     'duration': duration}]
+        if type in ['aggregate']:
+            networks[0]['links'] = linkspecs 
         resources = {'networks': networks, 'start_time': start_time, 'duration': duration}
 
         # convert the plc dict to an rspec dict
         resources = {'networks': networks, 'start_time': start_time, 'duration': duration}
 
         # convert the plc dict to an rspec dict
@@ -345,39 +354,79 @@ class Aggregate(GeniServer):
         # save this instead of the unvalidated rspec the user gave us
         self.slices[slice_hrn] = spec.toxml()
         self.slices.write()
         # save this instead of the unvalidated rspec the user gave us
         self.slices[slice_hrn] = spec.toxml()
         self.slices.write()
-        
-        # Get slice info
-        # if slice doesnt exist add it
+       
+        # Get the slice record from geni
+        slice = {}
+        records = self.registry.resolve(self.credential, slice_hrn)
+            
+        for record in records:
+            if record.get_type() in ['slice']:
+                slice_info = record.as_dict()
+                slice = slice_info['pl_info']
+        if not slice:
+            raise RecordNotFound(slice_hrn)
+                    
+        # Make sure slice exists at plc, if it doesnt add it
         slicename = hrn_to_pl_slicename(slice_hrn)
         slices = self.shell.GetSlices(self.auth, [slicename], ['node_ids'])
         if not slices:
             parts = slicename.split("_")
             login_base = parts[0]
         slicename = hrn_to_pl_slicename(slice_hrn)
         slices = self.shell.GetSlices(self.auth, [slicename], ['node_ids'])
         if not slices:
             parts = slicename.split("_")
             login_base = parts[0]
-            slice_record = self.registry.resolve(self.cred, slice_hrn)
-            slice_info = slice_record.as_dict()
-            slice = slice_info['pl_info']
-            
             # if site doesnt exist add it
             # if site doesnt exist add it
-            sites = self.shell.GetSites(self.auth, [login_base])
+            sites = self.shell.GetSites(self.auth, [login_base]) 
             if not sites:
                 authority = get_authority(slice_hrn)
             if not sites:
                 authority = get_authority(slice_hrn)
-                site_record = self.registry.reolve(self.cred, authority) 
+                site_records = self.registry.resolve(self.credential, authority)
+                site_record = {}
+                if not site_records:
+                    raise RecordNotFound(authority)
+                site_record = site_records[0]     
                 site_info = site_record.as_dict()
                 site_info = site_record.as_dict()
-                site = site_info['pl_info']
-
+                site = site_info['pl_info'] 
+                
                 # add the site
                 # add the site
-                site.pop('site_id')
+                site.pop('site_id') 
                 site_id = self.shell.AddSite(self.auth, site)
                 site_id = self.shell.AddSite(self.auth, site)
-
-            # add the slice
-            self.shell.AddSlice(self.auth, slice_info)
+            else:
+                site = sites[0]
+                
+            self.shell.AddSlice(self.auth, slice)
+        
+        # get the list of valid slice users from the registry and make 
+        # they are added to the slice 
+        geni_info = slice_info['geni_info']
+        researchers = geni_info['researcher']
+        for researcher in researchers:
+            person_record = {}
+            person_records = self.registry.resolve(self.credential, researcher)
+            for record in person_records:
+                if record.get_type() in ['user']:
+                    person_record = record
+            if not person_record:
+                pass
+            person_dict = person_record.as_dict()['pl_info']
+            persons = self.shell.GetPersons(self.auth, [person_dict['email']], ['person_id', 'key_ids'])
             
             
-            # add the slice users
+            # Create the person record 
+            if not persons:
+                self.shell.AddPerson(self.auth, person_dict)
+                key_ids = []
+            else:
+                key_ids = persons[0]['key_ids']
+           
+            self.shell.AddPersonToSlice(self.auth, person_dict['email'], slicename)
             
             
-        else:    
-            slice = slices[0]
-
-
+            # Get this users local keys
+            keylist = self.shell.GetKeys(self.auth, key_ids, ['key'])
+            keys = [key['key'] for key in keylist]
+            
+            # add keys that arent already there 
+            for personkey in person_dict['keys']:
+                if personkey not in keys:
+                    key = {'key_type': 'ssh', 'key': personkey}      
+                    self.shell.AddPersonKey(self.auth, person_dict['email'], key)
         # find out where this slice is currently running
         nodelist = self.shell.GetNodes(self.auth, slice['node_ids'], ['hostname'])
         hostnames = [node['hostname'] for node in nodelist]
         # find out where this slice is currently running
         nodelist = self.shell.GetNodes(self.auth, slice['node_ids'], ['hostname'])
         hostnames = [node['hostname'] for node in nodelist]
@@ -405,11 +454,6 @@ class Aggregate(GeniServer):
         self.shell.AddSliceToNodes(self.auth, slicename, added_nodes)
         self.shell.DeleteSliceFromNodes(self.auth, slicename, deleted_nodes)
 
         self.shell.AddSliceToNodes(self.auth, slicename, added_nodes)
         self.shell.DeleteSliceFromNodes(self.auth, slicename, deleted_nodes)
 
-        # contact registry to get slice users and add them to the slice
-        #slice_record = self.registry.resolve(self.credential, slice_hrn)
-        # persons = slice_record['users']
-        # for perosn in persons:
-        #    shell.AddPersonToSlice(person['email'], slice_name)
         return 1
 
     def updateSlice(self, slice_hrn, rspec, attributes = []):
         return 1
 
     def updateSlice(self, slice_hrn, rspec, attributes = []):
@@ -495,9 +539,12 @@ class Aggregate(GeniServer):
         self.decode_authentication(cred, 'listslices')
         return self.getSlices()
 
         self.decode_authentication(cred, 'listslices')
         return self.getSlices()
 
-    def get_resources(self, cred, hrn):
+    def get_resources(self, cred, hrn = None):
         self.decode_authentication(cred, 'listnodes')
         self.decode_authentication(cred, 'listnodes')
-        return self.getResources(hrn)
+        if not hrn: 
+            return self.getNodes()
+        else: 
+            return self.getResources(hrn)
 
     def get_ticket(self, cred, hrn, rspec):
         self.decode_authentication(cred, 'getticket')
 
     def get_ticket(self, cred, hrn, rspec):
         self.decode_authentication(cred, 'getticket')