cleanup the component server (no more flavour=plcm, use pl) - needs
[sfa.git] / sfa / util / rspecHelper.py
index e629a86..deaa746 100755 (executable)
 #! /usr/bin/env python
 
 import sys
+
+from copy import deepcopy
 from lxml import etree
 from StringIO import StringIO
 from optparse import OptionParser
 
+from sfa.util.faults import InvalidRSpec
+from sfa.util.sfalogging import logger
+
+def merge_rspecs(rspecs):
+    """
+    Merge merge a list of RSpecs into 1 RSpec, and return the result.
+    rspecs must be a valid RSpec string or list of RSpec strings.
+    """
+    if not rspecs or not isinstance(rspecs, list):
+        return rspecs
+
+    # ugly hack to avoid sending the same info twice, when the call graph has dags
+    known_networks={}
+    def register_network (network):
+        try:
+            known_networks[network.get('name')]=True
+        except:
+            logger.error("merge_rspecs: cannot register network with no name in rspec")
+            pass
+    def is_registered_network (network):
+        try:
+            return network.get('name') in known_networks
+        except:
+            logger.error("merge_rspecs: cannot retrieve network with no name in rspec")
+            return False
+
+    # the resulting tree
+    rspec = None
+    for input_rspec in rspecs:
+        # ignore empty strings as returned with used call_ids
+        if not input_rspec: continue
+        try:
+            tree = etree.parse(StringIO(input_rspec))
+        except etree.XMLSyntaxError:
+            # consider failing silently here
+            logger.log_exc("merge_rspecs, parse error")
+            message = str(sys.exc_info()[1]) + ' with ' + input_rspec
+            raise InvalidRSpec(message)
+
+        root = tree.getroot()
+        if not root.get("type") in ["SFA"]:
+            logger.error("merge_rspecs: unexpected type for rspec root, %s"%root.get('type'))
+            continue
+        if rspec == None:
+            # we scan the first input, register all networks
+            # in addition we remove duplicates - needed until everyone runs 1.0-10
+            rspec = root
+            for network in root.iterfind("./network"):
+                if not is_registered_network(network):
+                    register_network(network)
+                else:
+                    # duplicate in the first input - trash it
+                    root.remove(network)
+        else:
+            for network in root.iterfind("./network"):
+                if not is_registered_network(network):
+                    rspec.append(deepcopy(network))
+                    register_network(network)
+            for request in root.iterfind("./request"):
+                rspec.append(deepcopy(request))
+    return etree.tostring(rspec, xml_declaration=True, pretty_print=True)
+
 class RSpec:
     def __init__(self, xml):
         parser = etree.XMLParser(remove_blank_text=True)
         tree = etree.parse(StringIO(xml), parser)
         self.rspec = tree.getroot()
 
-    def get_node_element(self, hostname):
-        names = self.rspec.iterfind("./network/site/node/hostname")
+        # If there is only one network in the rspec, make it the default
+        self.network = None
+        networks = self.get_network_list()
+        if len(networks) == 1:
+            self.network = networks[0]
+
+    # Thierry : need this to locate hostname even if several networks
+    def get_node_element(self, hostname, network=None):
+        if network == None and self.network:
+            network = self.network
+        if network != None:
+            names = self.rspec.iterfind("./network[@name='%s']/site/node/hostname" % network)
+        else:
+            names = self.rspec.iterfind("./network/site/node/hostname")
         for name in names:
             if name.text == hostname:
                 return name.getparent()
         return None
         
-    def get_node_list(self):
-        result = self.rspec.xpath("./network/site/node/hostname/text()")
+    # Thierry : need this to return all nodes in all networks
+    def get_node_list(self, network=None):
+        if network == None and self.network:
+            network = self.network
+        if network != None:
+            return self.rspec.xpath("./network[@name='%s']/site/node/hostname/text()" % network)
+        else:
+            return self.rspec.xpath("./network/site/node/hostname/text()")
+
+    def get_network_list(self):
+        return self.rspec.xpath("./network[@name]/@name")
+
+    def get_sliver_list(self, network=None):
+        if network == None:
+            network = self.network
+        result = self.rspec.xpath("./network[@name='%s']/site/node[sliver]/hostname/text()" % network)
         return result
 
-    def get_sliver_list(self):
-        result = self.rspec.xpath("./network/site/node[sliver]/hostname/text()")
+    def get_available_node_list(self, network=None):
+        if network == None:
+            network = self.network
+        result = self.rspec.xpath("./network[@name='%s']/site/node[not(sliver)]/hostname/text()" % network)
         return result
 
-    def add_sliver(self, hostname):
-        node = self.get_node_element(hostname)
+    def add_sliver(self, hostname, network=None):
+        if network == None:
+            network = self.network
+        node = self.get_node_element(hostname, network)
         etree.SubElement(node, "sliver")
 
-    def remove_sliver(self, hostname):
-        node = self.get_node_element(hostname)
+    def remove_sliver(self, hostname, network=None):
+        if network == None:
+            network = self.network
+        node = self.get_node_element(hostname, network)
         node.remove(node.find("sliver"))
 
     def attributes_list(self, elem):
@@ -41,12 +137,16 @@ class RSpec:
                 opts.append((e.tag, e.text))
         return opts
 
-    def get_default_sliver_attributes(self):
-        defaults = self.rspec.find(".//sliver_defaults")
+    def get_default_sliver_attributes(self, network=None):
+        if network == None:
+            network = self.network
+        defaults = self.rspec.find("./network[@name='%s']/sliver_defaults" % network)
         return self.attributes_list(defaults)
 
-    def get_sliver_attributes(self, hostname):
-        node = self.get_node_element(hostname)
+    def get_sliver_attributes(self, hostname, network=None):
+        if network == None:
+            network = self.network
+        node = self.get_node_element(hostname, network)
         sliver = node.find("sliver")
         return self.attributes_list(sliver)
 
@@ -54,12 +154,20 @@ class RSpec:
         opt = etree.SubElement(elem, name)
         opt.text = value
 
-    def add_default_sliver_attribute(self, name, value):
-        defaults = self.rspec.find(".//sliver_defaults")
+    def add_default_sliver_attribute(self, name, value, network=None):
+        if network == None:
+            network = self.network
+        defaults = self.rspec.find("./network[@name='%s']/sliver_defaults" % network)
+        if defaults is None:
+            defaults = etree.Element("sliver_defaults")
+            network = self.rspec.find("./network[@name='%s']" % network)
+            network.insert(0, defaults)
         self.add_attribute(defaults, name, value)
 
-    def add_sliver_attribute(self, hostname, name, value):
-        node = self.get_node_element(hostname)
+    def add_sliver_attribute(self, hostname, name, value, network=None):
+        if network == None:
+            network = self.network
+        node = self.get_node_element(hostname, network)
         sliver = node.find("sliver")
         self.add_attribute(sliver, name, value)
 
@@ -71,38 +179,48 @@ class RSpec:
                     if opt.text == value:
                         elem.remove(opt)
 
-    def remove_default_sliver_attribute(self, name, value):
-        defaults = self.rspec.find(".//sliver_defaults")
+    def remove_default_sliver_attribute(self, name, value, network=None):
+        if network == None:
+            network = self.network
+        defaults = self.rspec.find("./network[@name='%s']/sliver_defaults" % network)
         self.remove_attribute(defaults, name, value)
 
-    def remove_sliver_attribute(self, hostname, name, value):
-        node = self.get_node_element(hostname)
+    def remove_sliver_attribute(self, hostname, name, value, network=None):
+        if network == None:
+            network = self.network
+        node = self.get_node_element(hostname, network)
         sliver = node.find("sliver")
         self.remove_attribute(sliver, name, value)
 
-    def get_site_nodes(self, siteid):
-        query = './/site[@id="%s"]/node/hostname/text()' % siteid
+    def get_site_nodes(self, siteid, network=None):
+        if network == None:
+            network = self.network
+        query = './network[@name="%s"]/site[@id="%s"]/node/hostname/text()' % (network, siteid)
         result = self.rspec.xpath(query)
         return result
         
-    def get_link_list(self):
+    def get_link_list(self, network=None):
+        if network == None:
+            network = self.network
         linklist = []
-        links = self.rspec.iterfind(".//link")
+        links = self.rspec.iterfind("./network[@name='%s']/link" % network)
         for link in links:
             (end1, end2) = link.get("endpoints").split()
             name = link.find("description")
             linklist.append((name.text, 
-                             self.get_site_nodes(end1), 
-                             self.get_site_nodes(end2)))
+                             self.get_site_nodes(end1, network), 
+                             self.get_site_nodes(end2, network)))
         return linklist
 
-    def get_vlink_list(self):
+    def get_vlink_list(self, network=None):
+        if network == None:
+            network = self.network
         vlinklist = []
-        vlinks = self.rspec.iterfind(".//vlink")
+        vlinks = self.rspec.iterfind("./network[@name='%s']//vlink" % network)
         for vlink in vlinks:
             endpoints = vlink.get("endpoints")
             (end1, end2) = endpoints.split()
-            query = './/node[@id="%s"]/hostname/text()'
+            query = './network[@name="%s"]//node[@id="%s"]/hostname/text()' % network
             node1 = self.rspec.xpath(query % end1)[0]
             node2 = self.rspec.xpath(query % end2)[0]
             desc = "%s <--> %s" % (node1, node2) 
@@ -110,29 +228,35 @@ class RSpec:
             vlinklist.append((endpoints, desc, kbps.text))
         return vlinklist
 
-    def query_links(self, fromnode, tonode):
+    def query_links(self, fromnode, tonode, network=None):
+        if network == None:
+            network = self.network
         fromsite = fromnode.getparent()
         tosite = tonode.getparent()
         fromid = fromsite.get("id")
         toid = tosite.get("id")
 
-        query = ".//link[@endpoints = '%s %s']" % (fromid, toid)
+        query = "./network[@name='%s']/link[@endpoints = '%s %s']" % (network, fromid, toid)
         results = self.rspec.xpath(query)
         if results == None:
-            query = ".//link[@endpoints = '%s %s']" % (toid, fromid)
+            query = "./network[@name='%s']/link[@endpoints = '%s %s']" % (network, toid, fromid)
             results = self.rspec.xpath(query)
         return results
 
-    def query_vlinks(self, endpoints):
-        query = ".//vlink[@endpoints = '%s']" % endpoints
+    def query_vlinks(self, endpoints, network=None):
+        if network == None:
+            network = self.network
+        query = "./network[@name='%s']//vlink[@endpoints = '%s']" % (network, endpoints)
         results = self.rspec.xpath(query)
         return results
             
     
-    def add_vlink(self, fromhost, tohost, kbps):
-        fromnode = self.get_node_element(fromhost)
-        tonode = self.get_node_element(tohost)
-        links = self.query_links(fromnode, tonode)
+    def add_vlink(self, fromhost, tohost, kbps, network=None):
+        if network == None:
+            network = self.network
+        fromnode = self.get_node_element(fromhost, network)
+        tonode = self.get_node_element(tohost, network)
+        links = self.query_links(fromnode, tonode, network)
 
         for link in links:
             vlink = etree.SubElement(link, "vlink")
@@ -142,8 +266,10 @@ class RSpec:
             self.add_attribute(vlink, "kbps", kbps)
         
 
-    def remove_vlink(self, endpoints):
-        vlinks = self.query_vlinks(endpoints)
+    def remove_vlink(self, endpoints, network=None):
+        if network == None:
+            network = self.network
+        vlinks = self.query_vlinks(endpoints, network)
         for vlink in vlinks:
             vlink.getparent().remove(vlink)