fix bugs in refresh_components, getResources
authorTony Mack <tmack@cs.princeton.edu>
Tue, 10 Mar 2009 14:27:31 +0000 (14:27 +0000)
committerTony Mack <tmack@cs.princeton.edu>
Tue, 10 Mar 2009 14:27:31 +0000 (14:27 +0000)
geni/slicemgr.py

index a0c8726..95b6497 100644 (file)
@@ -47,7 +47,8 @@ class SliceMgr(GeniServer):
         self.basedir = self.config.GENI_BASE_DIR + os.sep
         self.server_basedir = self.basedir + os.sep + "geni" + os.sep
         self.hrn = self.config.GENI_INTERFACE_HRN    
-
+        self.time_format = "%Y-%m-%d %H:%M:%S"
+        
         # Get list of aggregates this sm talks to
         # XX do we use simplestorage to maintain this file manually?
         aggregates_file = self.server_basedir + os.sep + 'aggregates'
@@ -62,12 +63,9 @@ class SliceMgr(GeniServer):
         self.slices.load()
 
         policy_file = os.sep.join([self.server_basedir, 'smgr.' + self.hrn + '.policy'])
-        self.policy = SimpleStorage(policy_file)
+        self.policy = SimpleStorage(policy_file, {'whitelist': [], 'blacklist': []})
         self.policy.load()
 
-        timestamp_file = os.sep.join([self.server_basedir, 'smgr.' + self.hrn + '.timestamp'])
-        self.timestamp = SimpleStorage(timestamp_file)
-        
         # How long before we refresh nodes cache  
         self.nodes_ttl = 1
 
@@ -135,7 +133,6 @@ class SliceMgr(GeniServer):
             hrn, address, port = agg_info[0], agg_info[1], agg_info[2]
             url = 'http://%(address)s:%(port)s' % locals()
             self.aggregates[hrn] = GeniClient(url, self.key_file, self.cert_file)
-            self.aggregates[hrn].list_nodes(self.credential)
 
     def item_hrns(self, items):
         """
@@ -174,14 +171,18 @@ class SliceMgr(GeniServer):
         Update the cached list of nodes.
         """
     
-         # convert and threshold to ints
-        if self.timestamp.has_key('timestamp') and self.timestamp['timestamp']:
-            timestamp = self.timestamp['timestamp']
-            threshold = self.threshold
+        # convert and threshold to ints
+        if self.nodes.has_key('timestamp') and self.nodes['timestamp']:
+            hr_timestamp = self.nodes['timestamp']
+            timestamp = datetime.datetime.fromtimestamp(time.mktime(time.strptime(hr_timestamp, self.time_format)))
+            hr_threshold = self.nodes['threshold']
+            threshold = datetime.datetime.fromtimestamp(time.mktime(time.strptime(hr_threshold, self.time_format)))
         else:
             timestamp = datetime.datetime.now()
+            hr_timestamp = timestamp.strftime(self.time_format)
             delta = datetime.timedelta(hours=self.nodes_ttl)
             threshold = timestamp + delta
+            hr_threshold = threshold.strftime(self.time_format)
 
         start_time = int(timestamp.strftime("%s"))
         end_time = int(threshold.strftime("%s"))
@@ -189,66 +190,44 @@ class SliceMgr(GeniServer):
 
         aggregates = self.aggregates.keys()
         rspecs = {}
+        networks = []
+        rspec = Rspec()
         for aggregate in aggregates:
             try:
                 # get the rspec from the aggregate
                 agg_server = self.aggregates[aggregate]
-                nodes = self.aggregates[aggregate].list_nodes(self.credential)
-                rspecs[aggregate] = nodes
+                agg_rspec = self.aggregates[aggregate].list_nodes(self.credential)
                 
-                # XX apply policy whitelist, balcklist here
+                # extract the netspec from each aggregates rspec
+                rspec.parseString(agg_rspec)
+                networks.extend(rspec.getDictsByTagName('NetSpec'))
             except:
                 # XX print out to some error log
                 print "Error calling list nodes at aggregate %s" % aggregate
                 raise    
    
-        # extract the netspec from each aggregates rspec
-        networks = []
-        for rs in rspecs:
-            r = Rspec()
-            r.parseString(rspecs[rs])
-            networks.extend(r.getDictsByTagName('NetSpec'))
-        
-        # create the plc dict
+        # create the rspec dict
         resources = {'networks': networks, 'start_time': start_time, 'duration': duration}
-        
-        # convert plc dict to rspec dict
-        resourceDict = RspecDict(resources)
-        
+        resourceDict = {'Rspec': resources} 
         # convert rspec dict to xml
-        rspec = Rspec()
         rspec.parseDict(resourceDict)
-        
-        #for node in all_nodes:
-        #    if self.polciy['whitelist'] and node not in self.polciy['whitelist']:
-        #        continue
-        #    if self.polciy['blacklist'] and node in self.policy['blacklist']:
-        #        continue
-        #    nodedict[node] = node
-
-        nodedict = {'rspec': rspec.toxml()}
-        self.nodes = SimpleStorage(self.nodes.db_filename, nodedict)
-        self.nodes.write()
+       
+        # filter according to policy
+        rspec.filter('NodeSpec', 'name', blacklist=self.policy['blacklist'], whitelist=self.policy['whitelist'])
 
         # update timestamp and threshold
-        self.timestamp['timestamp'] = datetime.datetime.now()
+        timestamp = datetime.datetime.now()
+        hr_timestamp = timestamp.strftime(self.time_format)
         delta = datetime.timedelta(hours=self.nodes_ttl)
-        self.threshold = self.timestamp['timestamp'] + delta
-        self.timestamp.write()
+        threshold = timestamp + delta
+        hr_threshold = threshold.strftime(self.time_format)
         
-    def load_components(self):
-        """
-        Read cached list of nodes.
-        """
-        # Read component list from cached file 
-        self.nodes.load()
-        self.timestamp.load()
-        time_format = "%Y-%m-%d %H:%M:%S"
-        timestamp = self.timestamp['timestamp']
-        self.timestamp['timestamp'] = datetime.datetime.fromtimestamp(time.mktime(time.strptime(timestamp, time_format)))
-        delta = datetime.timedelta(hours=self.nodes_ttl)
-        self.threshold = self.timestamp['timestamp'] + delta
+        nodedict = {'rspec': rspec.toxml(),
+                    'timestamp': hr_timestamp,
+                    'threshold':  hr_threshold}
+
+        self.nodes = SimpleStorage(self.nodes.db_filename, nodedict)
+        self.nodes.write()
 
     def load_policy(self):
         """
@@ -269,11 +248,8 @@ class SliceMgr(GeniServer):
         """
         # Reload components list
         now = datetime.datetime.now()
-        #self.load_components()
-        if not self.threshold or not self.timestamp or now > self.threshold:
+        if not self.nodes.has_key('threshold') or not self.nodes['threshold'] or not self.nodes.has_key('timestamp') or not self.nodes['timestamp'] or now > self.nodes['threshold']:
             self.refresh_components()
-        elif now < self.threshold and not self.nodes: 
-            self.load_components()
         return self.nodes[format]
    
      
@@ -281,8 +257,17 @@ class SliceMgr(GeniServer):
         """
         Return a list of instnatiated managed by this slice manager.
         """
-        # XX return only the slices at the specified hrn
-        return dict(self.slices)
+        slice_hrns = []
+        for aggregate in self.aggregates:
+            try:
+                slices = self.aggregates[aggregate].list_slices(self.credential)
+                slice_hrns.extend(slices)
+            except:
+                raise
+                # print to some error log
+                pass
+
+        return slice_hrns
 
     def getResources(self, slice_hrn):
         """
@@ -291,78 +276,69 @@ class SliceMgr(GeniServer):
 
         if slice_hrn in self.slices.keys():
             # check if we alreay have this slices state saved
-            rspec = self.slices[slice_hrn]
+            return  self.slices[slice_hrn]
         else:
-            # request this slices state from all  known aggregates
+            # request this slices state from all known aggregates
+            rspec = Rspec()
             rspecdicts = []
+            networks = []
             for hrn in self.aggregates.keys():
-                # XX need to use the right credentials for this call
                 # check if the slice has resources at this hrn
-                tempresources = self.aggregates[hrn].resources(self.credential, slice_hrn)
-                temprspec = Rspec()
-                temprspec.parseString(temprspec)
-                if temprspec.getDictsByTagName('NodeSpec'):
-                    # append this rspec to the list of rspecs
-                    rspecdicts.append(temprspec.toDict())
-                
+                slice_resources = self.aggregates[hrn].get_resources(self.credential, slice_hrn)
+                rspec.parseString(slice_resources)
+                networks.extend(rspec.getDictsByTagName('NetSpec'))
+            
             # merge all these rspecs into one
-            start_time = int(self.timestamp['timestamp'].strftime("%s"))
-            end_time = int(self.duration.strftime("%s"))
+            start_time = int(datetime.datetime.now().strftime("%s"))
+            end_time = start_time
             duration = end_time - start_time
-                
-            # create a plc dict 
-            networks = [rspecdict['networks'][0] for rspecdict in rspecdicts]
+    
             resources = {'networks': networks, 'start_time': start_time, 'duration': duration}
-            # convert the plc dict to an rspec dict
-            resourceDict = RspecDict(resources)
-            resourceSpec = Rspec()
-            resourceSpec.parseDict(resourceDict)
-            rspec = resourceSpec.toxml() 
+            resourceDict = {'Rspec': resources}
+            # convert rspec dict to xml
+            rspec.parseDict(resourceDict)
             # save this slices resources
-            self.slices[slice_hrn] = rspec
+            self.slices[slice_hrn] = rspec.toxml()
             self.slices.write()
          
-        return rspec
+            return rspec.toxml()
  
     def createSlice(self, slice_hrn, rspec, attributes):
         """
         Instantiate the specified slice according to whats defined in the rspec.
         """
-        # XX need to gget the correct credentials
-        cred = self.credential
 
         # save slice state locally
         # we can assume that spec object has been validated so its safer to
         # save this instead of the unvalidated rspec the user gave us
-        self.slices[slice_hrn] = spec.toxml()
+        rspec = Rspec()
+        tempspec = Rspec()
+        rspec.parseString(rspec)
+
+        self.slices[slice_hrn] = rspec.toxml()
         self.slices.write()
 
         # extract network list from the rspec and create a separate
         # rspec for each network
         slicename = self.hrn_to_plcslicename(slice_hrn)
-        spec = Rspec()
-        spec.parseString(rspec)
-        specDict = spec.toDict()
+        specDict = rspec.toDict()
         start_time = specDict['start_time']
         end_time = specDict['end_time']
 
         rspecs = {}
         # only attempt to extract information about the aggregates we know about
         for hrn in self.aggregates.keys():
-            netspec = spec.getDictByTagNameValue('NetSpec', 'hrn')
+            netspec = spec.getDictByTagNameValue('NetSpec', hrn)
             if netspec:
                 # creat a plc dict 
-                tempdict = {'start_time': star_time, 'end_time': end_time, 'networks': netspec}
-                #convert the plc dict to rpsec dict
-                resourceDict = RspecDict(tempdict)
-                # parse rspec dict
-                tempspec = Rspec()
+                resources = {'start_time': star_time, 'end_time': end_time, 'networks': netspec}
+                resourceDict = {'Rspec': resources}
                 tempspec.parseDict(resourceDict)
                 rspecs[hrn] = tempspec.toxml()
 
         # notify the aggregates
         for hrn in self.rspecs.keys():
-            self.aggregates[hrn].createSlice(cred, rspecs[hrn])
+            self.aggregates[hrn].createSlice(self.credential, rspecs[hrn])
             
         return 1
 
@@ -432,9 +408,9 @@ class SliceMgr(GeniServer):
         self.decode_authentication(cred, 'listnodes')
         return self.getNodes()
 
-    def list_slices(self, cred, hrn):
+    def list_slices(self, cred):
         self.decode_authentication(cred, 'listslices')
-        return self.getSlices(hrn)
+        return self.getSlices()
 
     def get_resources(self, cred, hrn):
         self.decode_authentication(cred, 'listnodes')