Add network argument to rspec calls
authorAndy Bavier <acb@acb-imac.cs.princeton.edu>
Tue, 8 Feb 2011 15:11:30 +0000 (10:11 -0500)
committerAndy Bavier <acb@acb-imac.cs.princeton.edu>
Tue, 8 Feb 2011 15:11:30 +0000 (10:11 -0500)
Added an optional network parameter to the rspec helper functions, to make them
useful on rspecs that have multiple <network> elements.

sfa/util/rspecHelper.py

index 8174e6b..6cac3b3 100755 (executable)
@@ -73,33 +73,52 @@ class RSpec:
         tree = etree.parse(StringIO(xml), parser)
         self.rspec = tree.getroot()
 
-    def get_node_element(self, hostname):
-        names = self.rspec.iterfind("./network/site/node/hostname")
+        # If there is only one network in the rspec, make it the default
+        self.network = None
+        networks = self.get_network_list()
+        if len(networks) == 1:
+            self.network = networks[0]
+
+    def get_node_element(self, hostname, network=None):
+        if network == None:
+            network = self.network
+        names = self.rspec.iterfind("./network[@name='%s']/site/node/hostname" % network)
         for name in names:
             if name.text == hostname:
                 return name.getparent()
         return None
         
-    def get_node_list(self):
-        result = self.rspec.xpath("./network/site/node/hostname/text()")
+    def get_node_list(self, network=None):
+        if network == None:
+            network = self.network
+        result = self.rspec.xpath("./network[@name='%s']/site/node/hostname/text()" % network)
         return result
 
     def get_network_list(self):
         return self.rspec.xpath("./network[@name]/@name")
 
-    def get_nodes_from_network(self, network):
-        return self.rspec.xpath("./network[@name='%s']/site/node/hostname/text()" % network)
+    def get_sliver_list(self, network=None):
+        if network == None:
+            network = self.network
+        result = self.rspec.xpath("./network[@name='%s']/site/node[sliver]/hostname/text()" % network)
+        return result
 
-    def get_sliver_list(self):
-        result = self.rspec.xpath("./network/site/node[sliver]/hostname/text()")
+    def get_available_node_list(self, network=None):
+        if network == None:
+            network = self.network
+        result = self.rspec.xpath("./network[@name='%s']/site/node[not(sliver)]/hostname/text()" % network)
         return result
 
-    def add_sliver(self, hostname):
-        node = self.get_node_element(hostname)
+    def add_sliver(self, hostname, network=None):
+        if network == None:
+            network = self.network
+        node = self.get_node_element(hostname, network)
         etree.SubElement(node, "sliver")
 
-    def remove_sliver(self, hostname):
-        node = self.get_node_element(hostname)
+    def remove_sliver(self, hostname, network=None):
+        if network == None:
+            network = self.network
+        node = self.get_node_element(hostname, network)
         node.remove(node.find("sliver"))
 
     def attributes_list(self, elem):
@@ -109,12 +128,16 @@ class RSpec:
                 opts.append((e.tag, e.text))
         return opts
 
-    def get_default_sliver_attributes(self):
-        defaults = self.rspec.find(".//sliver_defaults")
+    def get_default_sliver_attributes(self, network=None):
+        if network == None:
+            network = self.network
+        defaults = self.rspec.find("./network[@name='%s']/sliver_defaults" % network)
         return self.attributes_list(defaults)
 
-    def get_sliver_attributes(self, hostname):
-        node = self.get_node_element(hostname)
+    def get_sliver_attributes(self, hostname, network=None):
+        if network == None:
+            network = self.network
+        node = self.get_node_element(hostname, network)
         sliver = node.find("sliver")
         return self.attributes_list(sliver)
 
@@ -122,16 +145,20 @@ class RSpec:
         opt = etree.SubElement(elem, name)
         opt.text = value
 
-    def add_default_sliver_attribute(self, name, value):
-        defaults = self.rspec.find(".//sliver_defaults")
+    def add_default_sliver_attribute(self, name, value, network=None):
+        if network == None:
+            network = self.network
+        defaults = self.rspec.find("./network[@name='%s']/sliver_defaults" % network)
         if defaults is None:
             defaults = etree.Element("sliver_defaults")
-            network = self.rspec.find(".//network")
+            network = self.rspec.find("./network[@name='%s']" % network)
             network.insert(0, defaults)
         self.add_attribute(defaults, name, value)
 
-    def add_sliver_attribute(self, hostname, name, value):
-        node = self.get_node_element(hostname)
+    def add_sliver_attribute(self, hostname, name, value, network=None):
+        if network == None:
+            network = self.network
+        node = self.get_node_element(hostname, network)
         sliver = node.find("sliver")
         self.add_attribute(sliver, name, value)
 
@@ -143,38 +170,48 @@ class RSpec:
                     if opt.text == value:
                         elem.remove(opt)
 
-    def remove_default_sliver_attribute(self, name, value):
-        defaults = self.rspec.find(".//sliver_defaults")
+    def remove_default_sliver_attribute(self, name, value, network=None):
+        if network == None:
+            network = self.network
+        defaults = self.rspec.find("./network[@name='%s']/sliver_defaults" % network)
         self.remove_attribute(defaults, name, value)
 
-    def remove_sliver_attribute(self, hostname, name, value):
-        node = self.get_node_element(hostname)
+    def remove_sliver_attribute(self, hostname, name, value, network=None):
+        if network == None:
+            network = self.network
+        node = self.get_node_element(hostname, network)
         sliver = node.find("sliver")
         self.remove_attribute(sliver, name, value)
 
-    def get_site_nodes(self, siteid):
-        query = './/site[@id="%s"]/node/hostname/text()' % siteid
+    def get_site_nodes(self, siteid, network=None):
+        if network == None:
+            network = self.network
+        query = './network[@name="%s"]/site[@id="%s"]/node/hostname/text()' % (network, siteid)
         result = self.rspec.xpath(query)
         return result
         
-    def get_link_list(self):
+    def get_link_list(self, network=None):
+        if network == None:
+            network = self.network
         linklist = []
-        links = self.rspec.iterfind(".//link")
+        links = self.rspec.iterfind("./network[@name='%s']/link" % network)
         for link in links:
             (end1, end2) = link.get("endpoints").split()
             name = link.find("description")
             linklist.append((name.text, 
-                             self.get_site_nodes(end1), 
-                             self.get_site_nodes(end2)))
+                             self.get_site_nodes(end1, network), 
+                             self.get_site_nodes(end2, network)))
         return linklist
 
-    def get_vlink_list(self):
+    def get_vlink_list(self, network=None):
+        if network == None:
+            network = self.network
         vlinklist = []
-        vlinks = self.rspec.iterfind(".//vlink")
+        vlinks = self.rspec.iterfind("./network[@name='%s']//vlink" % network)
         for vlink in vlinks:
             endpoints = vlink.get("endpoints")
             (end1, end2) = endpoints.split()
-            query = './/node[@id="%s"]/hostname/text()'
+            query = './network[@name="%s"]//node[@id="%s"]/hostname/text()' % network
             node1 = self.rspec.xpath(query % end1)[0]
             node2 = self.rspec.xpath(query % end2)[0]
             desc = "%s <--> %s" % (node1, node2) 
@@ -182,29 +219,35 @@ class RSpec:
             vlinklist.append((endpoints, desc, kbps.text))
         return vlinklist
 
-    def query_links(self, fromnode, tonode):
+    def query_links(self, fromnode, tonode, network=None):
+        if network == None:
+            network = self.network
         fromsite = fromnode.getparent()
         tosite = tonode.getparent()
         fromid = fromsite.get("id")
         toid = tosite.get("id")
 
-        query = ".//link[@endpoints = '%s %s']" % (fromid, toid)
+        query = "./network[@name='%s']/link[@endpoints = '%s %s']" % (network, fromid, toid)
         results = self.rspec.xpath(query)
         if results == None:
-            query = ".//link[@endpoints = '%s %s']" % (toid, fromid)
+            query = "./network[@name='%s']/link[@endpoints = '%s %s']" % (network, toid, fromid)
             results = self.rspec.xpath(query)
         return results
 
-    def query_vlinks(self, endpoints):
-        query = ".//vlink[@endpoints = '%s']" % endpoints
+    def query_vlinks(self, endpoints, network=None):
+        if network == None:
+            network = self.network
+        query = "./network[@name='%s']//vlink[@endpoints = '%s']" % (network, endpoints)
         results = self.rspec.xpath(query)
         return results
             
     
-    def add_vlink(self, fromhost, tohost, kbps):
-        fromnode = self.get_node_element(fromhost)
-        tonode = self.get_node_element(tohost)
-        links = self.query_links(fromnode, tonode)
+    def add_vlink(self, fromhost, tohost, kbps, network=None):
+        if network == None:
+            network = self.network
+        fromnode = self.get_node_element(fromhost, network)
+        tonode = self.get_node_element(tohost, network)
+        links = self.query_links(fromnode, tonode, network)
 
         for link in links:
             vlink = etree.SubElement(link, "vlink")
@@ -214,8 +257,10 @@ class RSpec:
             self.add_attribute(vlink, "kbps", kbps)
         
 
-    def remove_vlink(self, endpoints):
-        vlinks = self.query_vlinks(endpoints)
+    def remove_vlink(self, endpoints, network=None):
+        if network == None:
+            network = self.network
+        vlinks = self.query_vlinks(endpoints, network)
         for vlink in vlinks:
             vlink.getparent().remove(vlink)