updates
[sfa.git] / geni / aggregate.py
index 46cab18..df0f672 100644 (file)
@@ -4,23 +4,23 @@ import datetime
 import time
 import xmlrpclib
 
 import time
 import xmlrpclib
 
-from geni.util.geniserver import *
+from geni.util.geniserver import GeniServer
 from geni.util.geniclient import *
 from geni.util.geniclient import *
-from geni.util.cert import *
-from geni.util.trustedroot import *
+from geni.util.cert import Keypair, Certificate
+from geni.util.trustedroot import TrustedRootList
 from geni.util.excep import *
 from geni.util.misc import *
 from geni.util.config import Config
 from geni.util.rspec import Rspec
 from geni.util.excep import *
 from geni.util.misc import *
 from geni.util.config import Config
 from geni.util.rspec import Rspec
+from geni.util.specdict import *
+from geni.util.storage import SimpleStorage
 
 class Aggregate(GeniServer):
 
     hrn = None
 
 class Aggregate(GeniServer):
 
     hrn = None
-    components_file = None
-    components_ttl = None
-    components = []
-    whitelist_file = None
-    blacklist_file = None    
+    nodes_ttl = None
+    nodes = {}
+    slices = {} 
     policy = {}
     timestamp = None
     threshold = None    
     policy = {}
     timestamp = None
     threshold = None    
@@ -36,18 +36,26 @@ class Aggregate(GeniServer):
     # @param cert_file certificate filename containing public key (could be a GID file)     
 
     def __init__(self, ip, port, key_file, cert_file, config = "/usr/share/geniwrapper/util/geni_config"):
     # @param cert_file certificate filename containing public key (could be a GID file)     
 
     def __init__(self, ip, port, key_file, cert_file, config = "/usr/share/geniwrapper/util/geni_config"):
-        GeniServer.__init__(ip, port, key_file, cert_file)
-        conf = Config(config)
-        basedir = conf.GENI_BASE_DIR + os.sep
-        server_basedir = basedir + os.sep + "plc" + os.sep
-        self.hrn = conf.GENI_INTERFACE_HRN
-        self.components_file = os.sep.join([server_basedir, 'components', hrn + '.comp'])
-        self.whitelist_file = os.sep.join([server_basedir, 'policy', 'whitelist'])
-        self.blacklist_file = os.sep.join([server_basedir, 'policy', 'blacklist'])
-        self.timestamp_file = os.sep.join([server_basedir, 'components', hrn + '.timestamp']) 
-        self.components_ttl = components_ttl
-        self.policy['whitelist'] = []
-        self.policy['blacklist'] = []
+        GeniServer.__init__(self, ip, port, key_file, cert_file)
+        self.conf = Config(config)
+        basedir = self.conf.GENI_BASE_DIR + os.sep
+        server_basedir = basedir + os.sep + "geni" + os.sep
+        self.hrn = self.conf.GENI_INTERFACE_HRN
+        
+        nodes_file = os.sep.join([server_basedir, 'agg.' + self.hrn + '.components'])
+        self.nodes = SimpleStorage(nodes_file)
+       
+        node_slices_file = os.sep.join([server_basedir, 'agg.' + self.hrn + '.slices'])
+        self.slices = SimpleStorage(node_slices_file)
+        self.slices.load()
+        policy_file = os.sep.join([server_basedir, 'policy'])
+        self.policy = SimpleStorage(policy_file, {'whitelist': [], 'blacklist': []})
+        
+        timestamp_file = os.sep.join([server_basedir, 'components', self.hrn + '.timestamp']) 
+        self.timestamp = SimpleStorage(timestamp_file)
+
+        self.nodes_ttl = 1
         self.connectPLC()
         self.connectRegistry()
 
         self.connectPLC()
         self.connectRegistry()
 
@@ -62,25 +70,25 @@ class Aggregate(GeniServer):
         Connect to the plc api interface. First attempt to impor thte shell, if that fails
         try to connect to the xmlrpc server.
         """
         Connect to the plc api interface. First attempt to impor thte shell, if that fails
         try to connect to the xmlrpc server.
         """
-        self.auth = {'Username': conf.GENI_PLC_USER,
+        self.auth = {'Username': self.conf.GENI_PLC_USER,
                      'AuthMethod': 'password',
                      'AuthMethod': 'password',
-                     'AuthString': conf.GENI_PLC_PASSWORD}
+                     'AuthString': self.conf.GENI_PLC_PASSWORD}
 
         try:
            # try to import PLC.Shell directly
 
         try:
            # try to import PLC.Shell directly
-            sys.path.append(conf.GENI_PLC_SHELL_PATH) 
+            sys.path.append(self.conf.GENI_PLC_SHELL_PATH) 
             import PLC.Shell
             self.shell = PLC.Shell.Shell(globals())
             self.shell.AuthCheck()
         except ImportError:
             # connect to plc api via xmlrpc
             import PLC.Shell
             self.shell = PLC.Shell.Shell(globals())
             self.shell.AuthCheck()
         except ImportError:
             # connect to plc api via xmlrpc
-            plc_host = conf.GENI_PLC_HOST
-            plc_port = conf.GENI_PLC_PORT
-            plc_api_path = conf.GENI_PLC_API_PATH                 
+            plc_host = self.conf.GENI_PLC_HOST
+            plc_port = self.conf.GENI_PLC_PORT
+            plc_api_path = self.conf.GENI_PLC_API_PATH                 
             url = "https://%(plc_host)s:%(plc_port)s/%(plc_api_path)s/" % locals()
             url = "https://%(plc_host)s:%(plc_port)s/%(plc_api_path)s/" % locals()
-            self.auth = {'Username': conf.GENI_PLC_USER,
+            self.auth = {'Username': self.conf.GENI_PLC_USER,
                  'AuthMethod': 'password',
                  'AuthMethod': 'password',
-                 'AuthString': conf.GENI_PLC_PASSWORD} 
+                 'AuthString': self.conf.GENI_PLC_PASSWORD} 
 
             self.shell = xmlrpclib.Server(url, verbose = 0, allow_none = True) 
             self.shell.AuthCheck(self.auth) 
 
             self.shell = xmlrpclib.Server(url, verbose = 0, allow_none = True) 
             self.shell.AuthCheck(self.auth) 
@@ -114,73 +122,45 @@ class Aggregate(GeniServer):
             site_dict[site['site_id']] = site['login_base']
 
         # convert plc names to geni hrn
             site_dict[site['site_id']] = site['login_base']
 
         # convert plc names to geni hrn
-        self.components = [self.hostname_to_hrn(site_dict[node['site_id']], node['hostname']) for node in nodes]
-
-        # apply policy. Do not allow nodes found in blacklist, only allow nodes found in whitelist
-        whitelist_policy = lambda node: node in self.policy['whitelist']
-        blacklist_policy = lambda node: node not in self.policy['blacklist']
+        nodedict = {}
+        for node in nodes:
+            node_hrn = self.hostname_to_hrn(site_dict[node['site_id']], node['hostname'])
+            # apply policy. 
+            # Do not allow nodes found in blacklist, only allow nodes found in whitelist
+            if self.polciy['whitelist'] and node_hrn not in self.polciy['whitelist']:
+                continue
+            if self.polciy['blacklist'] and node_hrn in self.policy['blacklist']:
+                continue
+            nodedict[node_hrn] = node['hostname']
+        
+        self.nodes = SimpleStorage(self.nodes.db_filename, nodedict)
+        self.nodes.write()
 
 
-        if self.policy['blacklist']:
-            self.components = blacklist_policy(self.components)
-        if self.policy['whitelist']:
-            self.components = whitelist_policy(self.components)
-            
         # update timestamp and threshold
         # update timestamp and threshold
-        self.timestamp = datetime.datetime.now()
-        delta = datetime.timedelta(hours=self.components_ttl)
-        self.threshold = self.timestamp + delta 
-    
-        f = open(self.components_file, 'w')
-        f.write(str(self.components))
-        f.close()
-        f = open(self.timestamp_file, 'w')
-        f.write(str(self.threshold))
-        f.close()
+        self.timestamp['timestamp'] =  datetime.datetime.now()
+        delta = datetime.timedelta(hours=self.nodes_ttl)
+        self.threshold = self.timestamp['timestamp'] + delta 
+        self.timestamp.write()        
  
     def load_components(self):
         """
         Read cached list of nodes.
         """
         # Read component list from cached file 
  
     def load_components(self):
         """
         Read cached list of nodes.
         """
         # Read component list from cached file 
-        if os.path.exists(self.components_file):
-            f = open(self.components_file, 'r')
-            self.components = eval(f.read())
-            f.close()
-    
+        self.nodes.load()
+        self.timestamp.load() 
         time_format = "%Y-%m-%d %H:%M:%S"
         time_format = "%Y-%m-%d %H:%M:%S"
-        if os.path.exists(self.timestamp_file):
-            f = open(self.timestamp_file, 'r')
-            timestamp = str(f.read()).split(".")[0]
-            self.timestamp = datetime.datetime.fromtimestamp(time.mktime(time.strptime(timestamp, time_format)))
-            delta = datetime.timedelta(hours=self.components_ttl)
-            self.threshold = self.timestamp + delta
-            f.close()    
+        timestamp = self.timestamp['timestamp']
+        self.timestamp['timestamp'] = datetime.datetime.fromtimestamp(time.mktime(time.strptime(timestamp, time_format)))
+        delta = datetime.timedelta(hours=self.nodes_ttl)
+        self.threshold = self.timestamp['timestamp'] + delta
 
     def load_policy(self):
         """
         Read the list of blacklisted and whitelisted nodes.
         """
 
     def load_policy(self):
         """
         Read the list of blacklisted and whitelisted nodes.
         """
-        whitelist = []
-        blacklist = []
-        if os.path.exists(self.whitelist_file):
-            f = open(self.whitelist_file, 'r')
-            lines = f.readlines()
-            f.close()
-        for line in lines:
-            line = line.strip().replace(" ", "").replace("\n", "")
-            whitelist.extend(line.split(","))
-            
-    
-        if os.path.exists(self.blacklist_file):
-            f = open(self.blacklist_file, 'r')
-            lines = f.readlines()
-            f.close()
-        for line in lines:
-            line = line.strip().replace(" ", "").replace("\n", "")
-            blacklist.extend(line.split(","))
+        self.policy.load()
 
 
-        self.policy['whitelist'] = whitelist
-        self.policy['blacklist'] = blacklist
 
     def get_components(self):
         """
 
     def get_components(self):
         """
@@ -189,29 +169,55 @@ class Aggregate(GeniServer):
         # Reload components list
         now = datetime.datetime.now()
         #self.load_components()
         # Reload components list
         now = datetime.datetime.now()
         #self.load_components()
-        if not self.threshold or not self.timestamp or now > self.threshold:
+        if not self.threshold or not self.timestamp['timestamp'] or now > self.threshold:
             self.refresh_components()
             self.refresh_components()
-        elif now < self.threshold and not self.components
+        elif now < self.threshold and not self.nodes.keys()
             self.load_components()
             self.load_components()
-        return self.components
+        return self.nodes.keys()
      
     def get_rspec(self, hrn, type):
      
     def get_rspec(self, hrn, type):
-        rspec = Rspec()
-        rspec['nodespec'] = {'name': self.conf.GENI_INTERFACE_HRN}
-        rsepc['nodespec']['nodes'] = []
-        if type in ['node']:
+        
+        # Get the required nodes
+        if type in ['aggregate']:
             nodes = self.shell.GetNodes(self.auth)
         elif type in ['slice']:
             slicename = hrn_to_pl_slicename(hrn)
             slices = self.shell.GetSlices(self.auth, [slicename])
             node_ids = slices[0]['node_ids']
             nodes = self.shell.GetNodes(self.auth, node_ids) 
             nodes = self.shell.GetNodes(self.auth)
         elif type in ['slice']:
             slicename = hrn_to_pl_slicename(hrn)
             slices = self.shell.GetSlices(self.auth, [slicename])
             node_ids = slices[0]['node_ids']
             nodes = self.shell.GetNodes(self.auth, node_ids) 
-            for node in nodes:
-                nodespec = {'name': node['hostname'], 'type': 'std'}
-        elif type in ['aggregate']:
-            pass
-
-        return rspec
+        
+        # Get all network interfaces
+        interface_ids = []
+        for node in nodes:
+            interface_ids.extend(node['nodenetwork_ids'])
+        interfaces = self.shell.GetNodeNetworks(self.auth, interface_ids)
+        interface_dict = {}
+        for interface in interfaces:
+            interface_dict[interface['nodenetwork_id']] = interface
+        
+        # join nodes with thier interfaces
+        for node in nodes:
+            node['interfaces'] = []
+            for nodenetwork_id in node['nodenetwork_ids']:
+                node['interfaces'].append(interface_dict[nodenetwork_id])
+
+        # convert and threshold to ints
+        timestamp = self.timestamp['timestamp']
+        start_time = int(self.timestamp['timestamp'].strftime("%s"))
+        end_time = int(self.duration.strftime("%s"))
+        duration = end_time - start_time
+
+        # create the plc dict
+        networks = {'nodes': nodes, 'name': self.hrn, 'start_time': start_time, 'duration': duration} 
+        resources = {'networks': networks, 'start_time': start_time, 'duration': duration}
+
+        # convert the plc dict to an rspec dict
+        resouceDict = RspecDict(resources)
+
+        # convert the rspec dict to xml
+        rspec = Rspec()
+        rspec.parseDict(resourceDict)
+        return rspec.toxml()
 
     def get_resources(self, slice_hrn):
         """
 
     def get_resources(self, slice_hrn):
         """
@@ -227,9 +233,19 @@ class Aggregate(GeniServer):
         Instantiate the specified slice according to whats defined in the rspec.
         """
         slicename = self.hrn_to_plcslicename(slice_hrn)
         Instantiate the specified slice according to whats defined in the rspec.
         """
         slicename = self.hrn_to_plcslicename(slice_hrn)
+        
+        # extract node list from rspec
         spec = Rspec(rspec)
         nodespecs = spec.getDictsByTagName('NodeSpec')
         spec = Rspec(rspec)
         nodespecs = spec.getDictsByTagName('NodeSpec')
-        nodes = [nodespec['name'] for nodespec in nodespecs]    
+        nodes = [nodespec['name'] for nodespec in nodespecs]
+
+        # save slice state locally
+        # we can assume that spec object has been validated so its safer to 
+        # save this instead of the unvalidated rspec the user gave us
+        self.slices[slice_hrn] = spec.toxml()
+        self.slices.write()
+
+        # add slice to nodes at plc    
         self.shell.AddSliceToNodes(self.auth, slicename, nodes)
         for attribute in attributes:
             type, value, node, nodegroup = attribute['type'], attribute['value'], attribute['node'], attribute['nodegroup']
         self.shell.AddSliceToNodes(self.auth, slicename, nodes)
         for attribute in attributes:
             type, value, node, nodegroup = attribute['type'], attribute['value'], attribute['node'], attribute['nodegroup']
@@ -265,6 +281,13 @@ class Aggregate(GeniServer):
         spec = Rspec(rspec)
         nodespecs = spec.getDictsByTagName('NodeSpec')
         nodes = [nodespec['name'] for nodespec in nodespecs]    
         spec = Rspec(rspec)
         nodespecs = spec.getDictsByTagName('NodeSpec')
         nodes = [nodespec['name'] for nodespec in nodespecs]    
+       
+        # save slice state locally
+        # we can assume that spec object has been validated so its safer to 
+        # save this instead of the unvalidated rspec the user gave us
+        self.slices[slice_hrn] = spec.toxml()
+        self.slices.write()
+
         # remove nodes not in rspec
         delete_nodes = set(hostnames).difference(nodes)
         # add nodes from rspec
         # remove nodes not in rspec
         delete_nodes = set(hostnames).difference(nodes)
         # add nodes from rspec
@@ -282,12 +305,18 @@ class Aggregate(GeniServer):
         # persons = slice_record['users']
         
         #for person in persons:
         # persons = slice_record['users']
         
         #for person in persons:
-        #    shell.AddPersonToSlice(person['email'], slice_name) 
+        #    shell.AddPersonToSlice(person['email'], slice_name)
+
+         
     def delete_slice_(self, slice_hrn):
         """
         Remove this slice from all components it was previouly associated with and 
         free up the resources it was using.
         """
     def delete_slice_(self, slice_hrn):
         """
         Remove this slice from all components it was previouly associated with and 
         free up the resources it was using.
         """
+        if self.slices.has_key(slice_hrn):
+            self.slices.pop(slice_hrn)
+            self.slices.write()
+
         slicename = self.hrn_to_plcslicename(slice_hrn)
         slices = shell.GetSlices(self.auth, [slicename])
         if not slice:
         slicename = self.hrn_to_plcslicename(slice_hrn)
         slices = shell.GetSlices(self.auth, [slicename])
         if not slice:
@@ -346,7 +375,7 @@ class Aggregate(GeniServer):
 ## Server methods here for now
 ##############################
 
 ## Server methods here for now
 ##############################
 
-    def nodes(self):
+    def components(self):
         return self.get_components()
 
     #def slices(self):
         return self.get_components()
 
     #def slices(self):
@@ -358,30 +387,30 @@ class Aggregate(GeniServer):
 
         return self.get_resources(hrn)
 
 
         return self.get_resources(hrn)
 
-    def create(self, cred, hrn, rspec):
+    def createSlice(self, cred, hrn, rspec):
         self.decode_authentication(cred, 'embed')
         self.verify_object_belongs_to_me(hrn)
         self.decode_authentication(cred, 'embed')
         self.verify_object_belongs_to_me(hrn)
-        return self.create(hrn)
+        return self.create_slice(hrn)
 
 
-    def update(self, cred, hrn, rspec):
+    def updateSlice(self, cred, hrn, rspec):
         self.decode_authentication(cred, 'embed')
         self.verify_object_belongs_to_me(hrn)
         self.decode_authentication(cred, 'embed')
         self.verify_object_belongs_to_me(hrn)
-        return self.update(hrn)    
+        return self.update_slice(hrn)    
 
 
-    def delete(self, cred, hrn):
+    def deleteSlice(self, cred, hrn):
         self.decode_authentication(cred, 'embed')
         self.verify_object_belongs_to_me(hrn)
         return self.delete_slice(hrn)
 
         self.decode_authentication(cred, 'embed')
         self.verify_object_belongs_to_me(hrn)
         return self.delete_slice(hrn)
 
-    def start(self, cred, hrn):
+    def startSlice(self, cred, hrn):
         self.decode_authentication(cred, 'control')
         self.decode_authentication(cred, 'control')
-        return self.start(hrn)
+        return self.start_slice(hrn)
 
 
-    def stop(self, cred, hrn):
+    def stopSlice(self, cred, hrn):
         self.decode_authentication(cred, 'control')
         return self.stop(hrn)
 
         self.decode_authentication(cred, 'control')
         return self.stop(hrn)
 
-    def reset(self, cred, hrn):
+    def resetSlice(self, cred, hrn):
         self.decode_authentication(cred, 'control')
         return self.reset(hrn)
 
         self.decode_authentication(cred, 'control')
         return self.reset(hrn)
 
@@ -396,10 +425,10 @@ class Aggregate(GeniServer):
         self.server.register_function(self.components)
         #self.server.register_function(self.slices)
         self.server.register_function(self.resources)
         self.server.register_function(self.components)
         #self.server.register_function(self.slices)
         self.server.register_function(self.resources)
-        self.server.register_function(self.create)
-        self.server.register_function(self.delete)
-        self.server.register_function(self.start)
-        self.server.register_function(self.stop)
-        self.server.register_function(self.reset)
+        self.server.register_function(self.createSlice)
+        self.server.register_function(self.deleteSlice)
+        self.server.register_function(self.startSlice)
+        self.server.register_function(self.stopSlice)
+        self.server.register_function(self.resetSlice)
         self.server.register_function(self.policy)
               
         self.server.register_function(self.policy)