check for the correct rights. fixed many other bugs
authorTony Mack <tmack@cs.princeton.edu>
Wed, 4 Mar 2009 17:43:38 +0000 (17:43 +0000)
committerTony Mack <tmack@cs.princeton.edu>
Wed, 4 Mar 2009 17:43:38 +0000 (17:43 +0000)
geni/slicemgr.py

index ed25063..a0c8726 100644 (file)
@@ -40,39 +40,40 @@ class SliceMgr(GeniServer):
     # @param cert_file certificate filename containing public key (could be a GID file)     
 
     def __init__(self, ip, port, key_file, cert_file, config = "/usr/share/geniwrapper/geni/util/geni_config"):
     # @param cert_file certificate filename containing public key (could be a GID file)     
 
     def __init__(self, ip, port, key_file, cert_file, config = "/usr/share/geniwrapper/geni/util/geni_config"):
-        GeniServer.__init__(ip, port, key_file, cert_file)
+        GeniServer.__init__(self, ip, port, key_file, cert_file)
         self.key_file = key_file
         self.cert_file = cert_file
         self.key_file = key_file
         self.cert_file = cert_file
-        self.conf = Config(config)
-        basedir = self.conf.GENI_BASE_DIR + os.sep
-        server_basedir = basedir + os.sep + "geni" + os.sep
-        self.hrn = conf.GENI_INTERFACE_HRN    
+        self.config = Config(config)
+        self.basedir = self.config.GENI_BASE_DIR + os.sep
+        self.server_basedir = self.basedir + os.sep + "geni" + os.sep
+        self.hrn = self.config.GENI_INTERFACE_HRN    
 
         # Get list of aggregates this sm talks to
         # XX do we use simplestorage to maintain this file manually?
 
         # Get list of aggregates this sm talks to
         # XX do we use simplestorage to maintain this file manually?
-        aggregates_file = server_basedir + os.sep + 'aggregates'
+        aggregates_file = self.server_basedir + os.sep + 'aggregates'
         self.aggregates = SimpleStorage(aggregates_file)
         self.aggregates = SimpleStorage(aggregates_file)
-        self.connect_aggregates(aggregates_file) 
         
         
-        nodes_file = os.sep.join([server_basedir, 'smgr.' + self.hrn + '.components'])
+        nodes_file = os.sep.join([self.server_basedir, 'smgr.' + self.hrn + '.components'])
         self.nodes = SimpleStorage(nodes_file)
         self.nodes.load()
         
         self.nodes = SimpleStorage(nodes_file)
         self.nodes.load()
         
-        slices_file = os.sep.join([server_basedir, 'smgr' + self.hrn + '.slices'])
+        slices_file = os.sep.join([self.server_basedir, 'smgr' + self.hrn + '.slices'])
         self.slices = SimpleStorage(slices_file)
         self.slices.load()
 
         self.slices = SimpleStorage(slices_file)
         self.slices.load()
 
-        policy_file = os.sep.join([server_basedir, 'smgr.' + self.hrn + '.policy'])
+        policy_file = os.sep.join([self.server_basedir, 'smgr.' + self.hrn + '.policy'])
         self.policy = SimpleStorage(policy_file)
         self.policy.load()
 
         self.policy = SimpleStorage(policy_file)
         self.policy.load()
 
-        timestamp_file = os.sep.join([server_basedir, 'smgr.' + self.hrn + '.timestamp'])
+        timestamp_file = os.sep.join([self.server_basedir, 'smgr.' + self.hrn + '.timestamp'])
         self.timestamp = SimpleStorage(timestamp_file)
         self.timestamp = SimpleStorage(timestamp_file)
+        
+        # How long before we refresh nodes cache  
         self.nodes_ttl = 1
         self.nodes_ttl = 1
-        self.connectAggregates()
+
         self.connectRegistry()
         self.loadCredential()
         self.connectRegistry()
         self.loadCredential()
+        self.connectAggregates(aggregates_file)
 
 
     def loadCredential(self):
 
 
     def loadCredential(self):
@@ -83,24 +84,31 @@ class SliceMgr(GeniServer):
 
         self_cred_filename = self.server_basedir + os.sep + "smgr." + self.hrn + ".cred"
         ma_cred_filename = self.server_basedir + os.sep + "smgr." + self.hrn + ".sa.cred"
 
         self_cred_filename = self.server_basedir + os.sep + "smgr." + self.hrn + ".cred"
         ma_cred_filename = self.server_basedir + os.sep + "smgr." + self.hrn + ".sa.cred"
-
+        
         # see if this file exists
         try:
         # see if this file exists
         try:
-            cred = Credential(filename = ma_cred_filename)
-            self.credential = cred.save_to_string()
+            self.credential = Credential(filename = ma_cred_filename)
         except IOError:
             # get self credential
             self_cred = self.registry.get_credential(None, 'ma', self.hrn)
         except IOError:
             # get self credential
             self_cred = self.registry.get_credential(None, 'ma', self.hrn)
-            self_credential = Credential(string = self_cred)
-            self_credential.save_to_file(self_cred_filename)
+            self_cred.save_to_file(self_cred_filename, save_parents=True)
 
             # get ma credential
 
             # get ma credential
-            ma_cred = self.registry.get_gredential(self_cred)
-            ma_credential = Credential(string = ma_cred)
-            ma_credential.save_to_file(ma_cred_filename)
-            self.credential = ma_cred        
+            ma_cred = self.registry.get_credential(self_cred, 'sa', self.hrn)
+            ma_cred.save_to_file(ma_cred_filename, save_parents=True)
+            self.credential = ma_cred
+
+    def connectRegistry(self):
+        """
+        Connect to the registry
+        """
+        # connect to registry using GeniClient
+        address = self.config.GENI_REGISTRY_HOSTNAME
+        port = self.config.GENI_REGISTRY_PORT
+        url = 'http://%(address)s:%(port)s' % locals()
+        self.registry = GeniClient(url, self.key_file, self.cert_file)
 
 
-    def connect_aggregates(self, aggregates_file):
+    def connectAggregates(self, aggregates_file):
         """
         Get info about the aggregates available to us from file and create 
         an xmlrpc connection to each. If any info is invalid, skip it. 
         """
         Get info about the aggregates available to us from file and create 
         an xmlrpc connection to each. If any info is invalid, skip it. 
@@ -111,12 +119,13 @@ class SliceMgr(GeniServer):
             lines = f.readlines()
             f.close()
         except: raise 
             lines = f.readlines()
             f.close()
         except: raise 
-    
+        
         for line in lines:
             # Skip comments
         for line in lines:
             # Skip comments
-            if line.strip.startswith("#"):
+            if line.strip().startswith("#"):
                 continue
                 continue
-            agg_info = line.split("\t").split(" ")
+            line = line.replace("\t", " ").replace("\n", "").replace("\r", "").strip()
+            agg_info = line.split(" ")
         
             # skip invalid info
             if len(agg_info) != 3:
         
             # skip invalid info
             if len(agg_info) != 3:
@@ -124,9 +133,9 @@ class SliceMgr(GeniServer):
 
             # create xmlrpc connection using GeniClient
             hrn, address, port = agg_info[0], agg_info[1], agg_info[2]
 
             # create xmlrpc connection using GeniClient
             hrn, address, port = agg_info[0], agg_info[1], agg_info[2]
-            url = 'https://%(address)s:%(port)s' % locals()
+            url = 'http://%(address)s:%(port)s' % locals()
             self.aggregates[hrn] = GeniClient(url, self.key_file, self.cert_file)
             self.aggregates[hrn] = GeniClient(url, self.key_file, self.cert_file)
-
+            self.aggregates[hrn].list_nodes(self.credential)
 
     def item_hrns(self, items):
         """
 
     def item_hrns(self, items):
         """
@@ -164,43 +173,74 @@ class SliceMgr(GeniServer):
         """
         Update the cached list of nodes.
         """
         """
         Update the cached list of nodes.
         """
-        print "refreshing"
     
     
+         # convert and threshold to ints
+        if self.timestamp.has_key('timestamp') and self.timestamp['timestamp']:
+            timestamp = self.timestamp['timestamp']
+            threshold = self.threshold
+        else:
+            timestamp = datetime.datetime.now()
+            delta = datetime.timedelta(hours=self.nodes_ttl)
+            threshold = timestamp + delta
+
+        start_time = int(timestamp.strftime("%s"))
+        end_time = int(threshold.strftime("%s"))
+        duration = end_time - start_time
+
         aggregates = self.aggregates.keys()
         aggregates = self.aggregates.keys()
-        all_nodes = []
-        nodedict = {}
+        rspecs = {}
         for aggregate in aggregates:
             try:
         for aggregate in aggregates:
             try:
-                # resolve components hostnames
-                nodes = self.aggregates[aggregate].get_components()
-                all_nodes.extend(nodes)    
+                # get the rspec from the aggregate
+                agg_server = self.aggregates[aggregate]
+                nodes = self.aggregates[aggregate].list_nodes(self.credential)
+                rspecs[aggregate] = nodes
+                
+                # XX apply policy whitelist, balcklist here
             except:
                 # XX print out to some error log
             except:
                 # XX print out to some error log
-                pass    
+                print "Error calling list nodes at aggregate %s" % aggregate
+                raise    
    
    
-        for node in all_nodes:
-            if self.polciy['whitelist'] and node not in self.polciy['whitelist']:
-                continue
-            if self.polciy['blacklist'] and node in self.policy['blacklist']:
-                continue
-
-            nodedict[node] = node
-
-        self.nodes = SimpleStorate(self.nodes.db_filename, nodedict)
+        # extract the netspec from each aggregates rspec
+        networks = []
+        for rs in rspecs:
+            r = Rspec()
+            r.parseString(rspecs[rs])
+            networks.extend(r.getDictsByTagName('NetSpec'))
+        
+        # create the plc dict
+        resources = {'networks': networks, 'start_time': start_time, 'duration': duration}
+        
+        # convert plc dict to rspec dict
+        resourceDict = RspecDict(resources)
+        
+        # convert rspec dict to xml
+        rspec = Rspec()
+        rspec.parseDict(resourceDict)
+        
+        #for node in all_nodes:
+        #    if self.polciy['whitelist'] and node not in self.polciy['whitelist']:
+        #        continue
+        #    if self.polciy['blacklist'] and node in self.policy['blacklist']:
+        #        continue
+        #    nodedict[node] = node
+
+        nodedict = {'rspec': rspec.toxml()}
+        self.nodes = SimpleStorage(self.nodes.db_filename, nodedict)
         self.nodes.write()
 
         # update timestamp and threshold
         self.timestamp['timestamp'] = datetime.datetime.now()
         self.nodes.write()
 
         # update timestamp and threshold
         self.timestamp['timestamp'] = datetime.datetime.now()
-        delta = datetime.timedelta(hours=self.nodes_tt1)
+        delta = datetime.timedelta(hours=self.nodes_ttl)
         self.threshold = self.timestamp['timestamp'] + delta
         self.timestamp.write()
         
  
     def load_components(self):
         """
         self.threshold = self.timestamp['timestamp'] + delta
         self.timestamp.write()
         
  
     def load_components(self):
         """
-        Read cached list of nodes and slices.
+        Read cached list of nodes.
         """
         """
-        print "loading nodes"
         # Read component list from cached file 
         self.nodes.load()
         self.timestamp.load()
         # Read component list from cached file 
         self.nodes.load()
         self.timestamp.load()
@@ -220,11 +260,10 @@ class SliceMgr(GeniServer):
         """
         Read current slice instantiation states.
         """
         """
         Read current slice instantiation states.
         """
-        print "loading slices"
         self.slices.load()
 
 
         self.slices.load()
 
 
-    def getComponents(self):
+    def getNodes(self, format = 'rspec'):
         """
         Return a list of components managed by this slice manager.
         """
         """
         Return a list of components managed by this slice manager.
         """
@@ -233,9 +272,9 @@ class SliceMgr(GeniServer):
         #self.load_components()
         if not self.threshold or not self.timestamp or now > self.threshold:
             self.refresh_components()
         #self.load_components()
         if not self.threshold or not self.timestamp or now > self.threshold:
             self.refresh_components()
-        elif now < self.threshold and not self.components: 
+        elif now < self.threshold and not self.nodes: 
             self.load_components()
             self.load_components()
-        return self.nodes.keys()
+        return self.nodes[format]
    
      
     def getSlices(self):
    
      
     def getSlices(self):
@@ -249,7 +288,6 @@ class SliceMgr(GeniServer):
         """
         Return the current rspec for the specified slice.
         """
         """
         Return the current rspec for the specified slice.
         """
-        cred = self.credential
 
         if slice_hrn in self.slices.keys():
             # check if we alreay have this slices state saved
 
         if slice_hrn in self.slices.keys():
             # check if we alreay have this slices state saved
@@ -260,7 +298,7 @@ class SliceMgr(GeniServer):
             for hrn in self.aggregates.keys():
                 # XX need to use the right credentials for this call
                 # check if the slice has resources at this hrn
             for hrn in self.aggregates.keys():
                 # XX need to use the right credentials for this call
                 # check if the slice has resources at this hrn
-                tempresources = self.aggregates[hrn].resources(cred, slice_hrn)
+                tempresources = self.aggregates[hrn].resources(self.credential, slice_hrn)
                 temprspec = Rspec()
                 temprspec.parseString(temprspec)
                 if temprspec.getDictsByTagName('NodeSpec'):
                 temprspec = Rspec()
                 temprspec.parseString(temprspec)
                 if temprspec.getDictsByTagName('NodeSpec'):
@@ -390,46 +428,51 @@ class SliceMgr(GeniServer):
 ## Server methods here for now
 ##############################
 
 ## Server methods here for now
 ##############################
 
-    def list_components(self):
-        return self.getComponents()
+    def list_nodes(self, cred):
+        self.decode_authentication(cred, 'listnodes')
+        return self.getNodes()
 
     def list_slices(self, cred, hrn):
 
     def list_slices(self, cred, hrn):
-        self.decode_authentication(cred, 'list')
+        self.decode_authentication(cred, 'listslices')
         return self.getSlices(hrn)
 
     def get_resources(self, cred, hrn):
         return self.getSlices(hrn)
 
     def get_resources(self, cred, hrn):
-        self.decode_authentication(cred, 'info')
+        self.decode_authentication(cred, 'listnodes')
         return self.getResources(hrn)
 
         return self.getResources(hrn)
 
+    def get_ticket(self, cred, hrn, rspec):
+        self.decode_authentication(cred, 'getticket')
+        return self.getTicket(hrn, rspec)
+
     def get_policy(self, cred):
     def get_policy(self, cred):
-        self.decode_authentication(cred, 'info')
+        self.decode_authentication(cred, 'getpolicy')
         return self.getPolicy()
 
     def create_slice(self, cred, hrn, rspec):
         return self.getPolicy()
 
     def create_slice(self, cred, hrn, rspec):
-        self.decode_authentication(cred, 'embed')
+        self.decode_authentication(cred, 'creatslice')
         return self.createSlice(hrn)
 
     def delete_slice(self, cred, hrn):
         return self.createSlice(hrn)
 
     def delete_slice(self, cred, hrn):
-        self.decode_authentication(cred, 'embed')
+        self.decode_authentication(cred, 'deleteslice')
         return self.deleteSlice(hrn)
 
     def start_slice(self, cred, hrn):
         return self.deleteSlice(hrn)
 
     def start_slice(self, cred, hrn):
-        self.decode_authentication(cred, 'control')
+        self.decode_authentication(cred, 'startslice')
         return self.startSlice(hrn)
 
     def stop_slice(self, cred, hrn):
         return self.startSlice(hrn)
 
     def stop_slice(self, cred, hrn):
-        self.decode_authentication(cred, 'control')
+        self.decode_authentication(cred, 'stopslice')
         return self.stopSlice(hrn)
 
     def reset_slice(self, cred, hrn):
         return self.stopSlice(hrn)
 
     def reset_slice(self, cred, hrn):
-        self.decode_authentication(cred, 'control')
+        self.decode_authentication(cred, 'resetslice')
         return self.resetSlice(hrn)
 
     def register_functions(self):
         GeniServer.register_functions(self)
 
         # Aggregate interface methods
         return self.resetSlice(hrn)
 
     def register_functions(self):
         GeniServer.register_functions(self)
 
         # Aggregate interface methods
-        self.server.register_function(self.list_components)
+        self.server.register_function(self.list_nodes)
         self.server.register_function(self.list_slices)
         self.server.register_function(self.get_resources)
         self.server.register_function(self.get_policy)
         self.server.register_function(self.list_slices)
         self.server.register_function(self.get_resources)
         self.server.register_function(self.get_policy)