fix bug in getResources
[sfa.git] / geni / slicemgr.py
index 87fcc0f..605daf9 100644 (file)
@@ -10,10 +10,10 @@ from geni.util.credential import Credential
 from geni.util.trustedroot import *
 from geni.util.excep import *
 from geni.util.misc import *
-from geni.util.config import Config
+from geni.util.config import *
 from geni.util.rspec import Rspec
 from geni.util.specdict import *
-from geni.util.storage import SimpleStorage
+from geni.util.storage import SimpleStorage, XmlStorage
 
 class SliceMgr(GeniServer):
 
@@ -39,7 +39,7 @@ class SliceMgr(GeniServer):
     # @param key_file private key filename of registry
     # @param cert_file certificate filename containing public key (could be a GID file)     
 
-    def __init__(self, ip, port, key_file, cert_file, config = os.getcwd() + "/geni/util/geni_config"):
+    def __init__(self, ip, port, key_file, cert_file, config = "/usr/share/geniwrapper/geni/util/geni_config"):
         GeniServer.__init__(self, ip, port, key_file, cert_file)
         self.key_file = key_file
         self.cert_file = cert_file
@@ -50,18 +50,22 @@ class SliceMgr(GeniServer):
         self.time_format = "%Y-%m-%d %H:%M:%S"
         
         # Get list of aggregates this sm talks to
-        # XX do we use simplestorage to maintain this file manually?
-        aggregates_file = self.server_basedir + os.sep + 'aggregates'
-        self.aggregates = SimpleStorage(aggregates_file)
+        aggregates_file = self.server_basedir + os.sep + 'aggregates.xml'
+        connection_dict = {'hrn': '', 'addr': '', 'port': ''}
+        self.aggregate_info = XmlStorage(aggregates_file, {'aggregates': {'aggregate': [connection_dict]}} )
+        self.aggregate_info.load()
         
+        # Get cached list of nodes (rspec) 
         nodes_file = os.sep.join([self.server_basedir, 'smgr.' + self.hrn + '.components'])
         self.nodes = SimpleStorage(nodes_file)
         self.nodes.load()
         
+        # Get cacheds slice states
         slices_file = os.sep.join([self.server_basedir, 'smgr.' + self.hrn + '.slices'])
         self.slices = SimpleStorage(slices_file)
         self.slices.load()
 
+        # Get the policy
         policy_file = os.sep.join([self.server_basedir, 'smgr.' + self.hrn + '.policy'])
         self.policy = SimpleStorage(policy_file, {'whitelist': [], 'blacklist': []})
         self.policy.load()
@@ -71,7 +75,7 @@ class SliceMgr(GeniServer):
 
         self.connectRegistry()
         self.loadCredential()
-        self.connectAggregates(aggregates_file)
+        self.connectAggregates()
 
 
     def loadCredential(self):
@@ -85,7 +89,7 @@ class SliceMgr(GeniServer):
         try:
             self.credential = Credential(filename = ma_cred_filename)
         except IOError:
-            self.credential = self.getCrednetialFromRegistry()
+            self.credential = self.getCredentialFromRegistry()
             
         
     def getCredentialFromRegistry(self):
@@ -113,33 +117,23 @@ class SliceMgr(GeniServer):
         url = 'http://%(address)s:%(port)s' % locals()
         self.registry = GeniClient(url, self.key_file, self.cert_file)
 
-    def connectAggregates(self, aggregates_file):
+    def connectAggregates(self):
         """
-        Get info about the aggregates available to us from file and create 
-        an xmlrpc connection to each. If any info is invalid, skip it. 
+        Get connection details for the trusted peer aggregates from file and 
+        create a GeniClient connection to each.      
         """
-        lines = []
-        try:
-            f = open(aggregates_file, 'r')
-            lines = f.readlines()
-            f.close()
-        except: raise 
-        
-        for line in lines:
-            # Skip comments
-            if line.strip().startswith("#"):
-                continue
-            line = line.replace("\t", " ").replace("\n", "").replace("\r", "").strip()
-            agg_info = line.split(" ")
-        
-            # skip invalid info
-            if len(agg_info) != 3:
-                continue
-
-            # create xmlrpc connection using GeniClient
-            hrn, address, port = agg_info[0], agg_info[1], agg_info[2]
-            url = 'http://%(address)s:%(port)s' % locals()
-            self.aggregates[hrn] = GeniClient(url, self.key_file, self.cert_file)
+        self.aggregates = {} 
+        aggregates = self.aggregate_info['aggregates']['aggregate']
+        if isinstance(aggregates, dict):
+            aggregates = [aggregates]
+        if isinstance(aggregates, list):
+            for aggregate in aggregates:         
+                # create xmlrpc connection using GeniClient
+                hrn, address, port = aggregate['hrn'], aggregate['addr'], aggregate['port']
+                if not hrn or not address or not port:
+                    continue
+                url = 'http://%(address)s:%(port)s' % locals()
+                self.aggregates[hrn] = GeniClient(url, self.key_file, self.cert_file)
 
     def item_hrns(self, items):
         """
@@ -291,7 +285,7 @@ class SliceMgr(GeniServer):
             # check if the slice has resources at this hrn
             slice_resources = self.aggregates[hrn].get_resources(self.credential, slice_hrn)
             rspec.parseString(slice_resources)
-            networks.extend({'NetSpec': rspec.getDictsByTagName('NetSpec')})
+            networks.extend([{'NetSpec': rspec.getDictsByTagName('NetSpec')}])
             
         # merge all these rspecs into one
         start_time = int(datetime.datetime.now().strftime("%s"))
@@ -316,34 +310,37 @@ class SliceMgr(GeniServer):
         # save slice state locally
         # we can assume that spec object has been validated so its safer to
         # save this instead of the unvalidated rspec the user gave us
-        rspec = Rspec()
+        spec = Rspec()
         tempspec = Rspec()
-        rspec.parseString(rspec)
+        spec.parseString(rspec)
 
-        self.slices[slice_hrn] = rspec.toxml()
+        self.slices[slice_hrn] = spec.toxml()
         self.slices.write()
 
         # extract network list from the rspec and create a separate
         # rspec for each network
-        slicename = self.hrn_to_plcslicename(slice_hrn)
-        specDict = rspec.toDict()
-        start_time = specDict['start_time']
-        end_time = specDict['end_time']
-
+        slicename = hrn_to_pl_slicename(slice_hrn)
+        specDict = spec.toDict()
+        if specDict.has_key('Rspec'): specDict = specDict['Rspec']
+        if specDict.has_key('start_time'): start_time = specDict['start_time']
+        else: start_time = 0
+        if specDict.has_key('end_time'): end_time = specDict['end_time']
+        else: end_time = 0
+   
         rspecs = {}
         # only attempt to extract information about the aggregates we know about
         for hrn in self.aggregates.keys():
             netspec = spec.getDictByTagNameValue('NetSpec', hrn)
             if netspec:
                 # creat a plc dict 
-                resources = {'start_time': star_time, 'end_time': end_time, 'networks': netspec}
+                resources = {'start_time': start_time, 'end_time': end_time, 'networks': netspec}
                 resourceDict = {'Rspec': resources}
                 tempspec.parseDict(resourceDict)
                 rspecs[hrn] = tempspec.toxml()
 
         # notify the aggregates
-        for hrn in self.rspecs.keys():
-            self.aggregates[hrn].createSlice(self.credential, rspecs[hrn])
+        for hrn in rspecs.keys():
+            self.aggregates[hrn].create_slice(self.credential, slice_hrn, rspecs[hrn])
             
         return 1
 
@@ -417,9 +414,12 @@ class SliceMgr(GeniServer):
         self.decode_authentication(cred, 'listslices')
         return self.getSlices()
 
-    def get_resources(self, cred, hrn):
+    def get_resources(self, cred, hrn=None):
         self.decode_authentication(cred, 'listnodes')
-        return self.getResources(hrn)
+        if not hrn: 
+            return self.getNodes()
+        else:
+            return self.getResources(hrn)
 
     def get_ticket(self, cred, hrn, rspec):
         self.decode_authentication(cred, 'getticket')
@@ -430,7 +430,7 @@ class SliceMgr(GeniServer):
         return self.getPolicy()
 
     def create_slice(self, cred, hrn, rspec):
-        self.decode_authentication(cred, 'creatslice')
+        self.decode_authentication(cred, 'createslice')
         return self.createSlice(cred, hrn, rspec)
 
     def delete_slice(self, cred, hrn):