fix bug in getResources
[sfa.git] / geni / slicemgr.py
index da49039..605daf9 100644 (file)
@@ -10,10 +10,10 @@ from geni.util.credential import Credential
 from geni.util.trustedroot import *
 from geni.util.excep import *
 from geni.util.misc import *
-from geni.util.config import Config
+from geni.util.config import *
 from geni.util.rspec import Rspec
 from geni.util.specdict import *
-from geni.util.storage import SimpleStorage
+from geni.util.storage import SimpleStorage, XmlStorage
 
 class SliceMgr(GeniServer):
 
@@ -50,18 +50,22 @@ class SliceMgr(GeniServer):
         self.time_format = "%Y-%m-%d %H:%M:%S"
         
         # Get list of aggregates this sm talks to
-        # XX do we use simplestorage to maintain this file manually?
-        aggregates_file = self.server_basedir + os.sep + 'aggregates'
-        self.aggregates = SimpleStorage(aggregates_file)
+        aggregates_file = self.server_basedir + os.sep + 'aggregates.xml'
+        connection_dict = {'hrn': '', 'addr': '', 'port': ''}
+        self.aggregate_info = XmlStorage(aggregates_file, {'aggregates': {'aggregate': [connection_dict]}} )
+        self.aggregate_info.load()
         
+        # Get cached list of nodes (rspec) 
         nodes_file = os.sep.join([self.server_basedir, 'smgr.' + self.hrn + '.components'])
         self.nodes = SimpleStorage(nodes_file)
         self.nodes.load()
         
-        slices_file = os.sep.join([self.server_basedir, 'smgr' + self.hrn + '.slices'])
+        # Get cacheds slice states
+        slices_file = os.sep.join([self.server_basedir, 'smgr.' + self.hrn + '.slices'])
         self.slices = SimpleStorage(slices_file)
         self.slices.load()
 
+        # Get the policy
         policy_file = os.sep.join([self.server_basedir, 'smgr.' + self.hrn + '.policy'])
         self.policy = SimpleStorage(policy_file, {'whitelist': [], 'blacklist': []})
         self.policy.load()
@@ -71,7 +75,7 @@ class SliceMgr(GeniServer):
 
         self.connectRegistry()
         self.loadCredential()
-        self.connectAggregates(aggregates_file)
+        self.connectAggregates()
 
 
     def loadCredential(self):
@@ -79,22 +83,29 @@ class SliceMgr(GeniServer):
         Attempt to load credential from file if it exists. If it doesnt get
         credential from registry.
         """
-
-        self_cred_filename = self.server_basedir + os.sep + "smgr." + self.hrn + ".cred"
-        ma_cred_filename = self.server_basedir + os.sep + "smgr." + self.hrn + ".sa.cred"
         
         # see if this file exists
+        ma_cred_filename = self.server_basedir + os.sep + "smgr." + self.hrn + ".sa.cred"
         try:
             self.credential = Credential(filename = ma_cred_filename)
         except IOError:
-            # get self credential
-            self_cred = self.registry.get_credential(None, 'ma', self.hrn)
-            self_cred.save_to_file(self_cred_filename, save_parents=True)
+            self.credential = self.getCredentialFromRegistry()
+            
+        
+    def getCredentialFromRegistry(self):
+        """
+        Get our current credential from the registry.
+        """
+        # get self credential
+        self_cred_filename = self.server_basedir + os.sep + "smgr." + self.hrn + ".cred"
+        self_cred = self.registry.get_credential(None, 'ma', self.hrn)
+        self_cred.save_to_file(self_cred_filename, save_parents=True)
 
-            # get ma credential
-            ma_cred = self.registry.get_credential(self_cred, 'sa', self.hrn)
-            ma_cred.save_to_file(ma_cred_filename, save_parents=True)
-            self.credential = ma_cred
+        # get ma credential
+        ma_cred_filename = self.server_basedir + os.sep + "smgr." + self.hrn + ".sa.cred"
+        ma_cred = self.registry.get_credential(self_cred, 'sa', self.hrn)
+        ma_cred.save_to_file(ma_cred_filename, save_parents=True)
+        return ma_cred        
 
     def connectRegistry(self):
         """
@@ -106,33 +117,23 @@ class SliceMgr(GeniServer):
         url = 'http://%(address)s:%(port)s' % locals()
         self.registry = GeniClient(url, self.key_file, self.cert_file)
 
-    def connectAggregates(self, aggregates_file):
+    def connectAggregates(self):
         """
-        Get info about the aggregates available to us from file and create 
-        an xmlrpc connection to each. If any info is invalid, skip it. 
+        Get connection details for the trusted peer aggregates from file and 
+        create a GeniClient connection to each.      
         """
-        lines = []
-        try:
-            f = open(aggregates_file, 'r')
-            lines = f.readlines()
-            f.close()
-        except: raise 
-        
-        for line in lines:
-            # Skip comments
-            if line.strip().startswith("#"):
-                continue
-            line = line.replace("\t", " ").replace("\n", "").replace("\r", "").strip()
-            agg_info = line.split(" ")
-        
-            # skip invalid info
-            if len(agg_info) != 3:
-                continue
-
-            # create xmlrpc connection using GeniClient
-            hrn, address, port = agg_info[0], agg_info[1], agg_info[2]
-            url = 'http://%(address)s:%(port)s' % locals()
-            self.aggregates[hrn] = GeniClient(url, self.key_file, self.cert_file)
+        self.aggregates = {} 
+        aggregates = self.aggregate_info['aggregates']['aggregate']
+        if isinstance(aggregates, dict):
+            aggregates = [aggregates]
+        if isinstance(aggregates, list):
+            for aggregate in aggregates:         
+                # create xmlrpc connection using GeniClient
+                hrn, address, port = aggregate['hrn'], aggregate['addr'], aggregate['port']
+                if not hrn or not address or not port:
+                    continue
+                url = 'http://%(address)s:%(port)s' % locals()
+                self.aggregates[hrn] = GeniClient(url, self.key_file, self.cert_file)
 
     def item_hrns(self, items):
         """
@@ -195,17 +196,15 @@ class SliceMgr(GeniServer):
         for aggregate in aggregates:
             try:
                 # get the rspec from the aggregate
-                agg_server = self.aggregates[aggregate]
                 agg_rspec = self.aggregates[aggregate].list_nodes(self.credential)
-                
                 # extract the netspec from each aggregates rspec
                 rspec.parseString(agg_rspec)
-                networks.extend({'NetSpec': rspec.getDictsByTagName('NetSpec')})
+                networks.extend([{'NetSpec': rspec.getDictsByTagName('NetSpec')}])
             except:
                 # XX print out to some error log
                 print "Error calling list nodes at aggregate %s" % aggregate
                 raise    
-   
+  
         # create the rspec dict
         resources = {'networks': networks, 'start_time': start_time, 'duration': duration}
         resourceDict = {'Rspec': resources} 
@@ -286,7 +285,7 @@ class SliceMgr(GeniServer):
             # check if the slice has resources at this hrn
             slice_resources = self.aggregates[hrn].get_resources(self.credential, slice_hrn)
             rspec.parseString(slice_resources)
-            networks.extend({'NetSpec': rspec.getDictsByTagName('NetSpec')})
+            networks.extend([{'NetSpec': rspec.getDictsByTagName('NetSpec')}])
             
         # merge all these rspecs into one
         start_time = int(datetime.datetime.now().strftime("%s"))
@@ -311,34 +310,37 @@ class SliceMgr(GeniServer):
         # save slice state locally
         # we can assume that spec object has been validated so its safer to
         # save this instead of the unvalidated rspec the user gave us
-        rspec = Rspec()
+        spec = Rspec()
         tempspec = Rspec()
-        rspec.parseString(rspec)
+        spec.parseString(rspec)
 
-        self.slices[slice_hrn] = rspec.toxml()
+        self.slices[slice_hrn] = spec.toxml()
         self.slices.write()
 
         # extract network list from the rspec and create a separate
         # rspec for each network
-        slicename = self.hrn_to_plcslicename(slice_hrn)
-        specDict = rspec.toDict()
-        start_time = specDict['start_time']
-        end_time = specDict['end_time']
-
+        slicename = hrn_to_pl_slicename(slice_hrn)
+        specDict = spec.toDict()
+        if specDict.has_key('Rspec'): specDict = specDict['Rspec']
+        if specDict.has_key('start_time'): start_time = specDict['start_time']
+        else: start_time = 0
+        if specDict.has_key('end_time'): end_time = specDict['end_time']
+        else: end_time = 0
+   
         rspecs = {}
         # only attempt to extract information about the aggregates we know about
         for hrn in self.aggregates.keys():
             netspec = spec.getDictByTagNameValue('NetSpec', hrn)
             if netspec:
                 # creat a plc dict 
-                resources = {'start_time': star_time, 'end_time': end_time, 'networks': netspec}
+                resources = {'start_time': start_time, 'end_time': end_time, 'networks': netspec}
                 resourceDict = {'Rspec': resources}
                 tempspec.parseDict(resourceDict)
                 rspecs[hrn] = tempspec.toxml()
 
         # notify the aggregates
-        for hrn in self.rspecs.keys():
-            self.aggregates[hrn].createSlice(self.credential, rspecs[hrn])
+        for hrn in rspecs.keys():
+            self.aggregates[hrn].create_slice(self.credential, slice_hrn, rspecs[hrn])
             
         return 1
 
@@ -412,9 +414,12 @@ class SliceMgr(GeniServer):
         self.decode_authentication(cred, 'listslices')
         return self.getSlices()
 
-    def get_resources(self, cred, hrn):
+    def get_resources(self, cred, hrn=None):
         self.decode_authentication(cred, 'listnodes')
-        return self.getResources(hrn)
+        if not hrn: 
+            return self.getNodes()
+        else:
+            return self.getResources(hrn)
 
     def get_ticket(self, cred, hrn, rspec):
         self.decode_authentication(cred, 'getticket')
@@ -425,7 +430,7 @@ class SliceMgr(GeniServer):
         return self.getPolicy()
 
     def create_slice(self, cred, hrn, rspec):
-        self.decode_authentication(cred, 'creatslice')
+        self.decode_authentication(cred, 'createslice')
         return self.createSlice(cred, hrn, rspec)
 
     def delete_slice(self, cred, hrn):