renamged get_element_type() to get_rspec_element(). Added load_rspec_elements() method
authorTony Mack <tmack@paris.CS.Princeton.EDU>
Thu, 1 Sep 2011 20:45:18 +0000 (16:45 -0400)
committerTony Mack <tmack@paris.CS.Princeton.EDU>
Thu, 1 Sep 2011 20:45:18 +0000 (16:45 -0400)
sfa/rspecs/rspec.py

index 764114a..5872e33 100755 (executable)
@@ -7,7 +7,6 @@ from sfa.util.plxrn import hostname_to_urn
 from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements 
 from sfa.util.faults import SfaNotImplemented, InvalidRSpec, InvalidRSpecElement
 
-
 class XpathFilter:
     @staticmethod
     def xpath(filter={}):
@@ -92,12 +91,18 @@ class RSpec:
     def xpath(self, xpath):
         return self.xml.xpath(xpath, namespaces=self.namespaces)
 
-    def register_element_type(self, element_type, element_name, element_path):
+    def load_rspec_elements(self, rspec_elements):
+        self.elements = {}
+        for rspec_element in rspec_elements:
+            if isinstance(rspec_element, RSpecElement):
+                self.elements[rspec_element.type] = rspec_element
+
+    def register_rspec_element(self, element_type, element_name, element_path):
         if element_type not in RSpecElements:
             raise InvalidRSpecElement(element_type, extra="no such element type: %s. Must specify a valid RSpecElement" % element_type)
         self.elements[element_type] = RSpecElement(element_type, element_name, element_path)
 
-    def get_element_type(self, element_type):
+    def get_rspec_element(self, element_type):
         if element_type not in self.elements:
             msg = "ElementType %s not registerd for this rspec" % element_type
             raise InvalidRSpecElement(element_type, extra=msg)
@@ -165,8 +170,13 @@ class RSpec:
     def get_element_attributes(self, elem=None, depth=0):
         if elem == None:
             elem = self.root_node
+        if not hasattr(elem, 'attrib'):
+            # this is probably not an element node with attribute. could be just and an
+            # attribute, return it
+            return elem
         attrs = dict(elem.attrib)
         attrs['text'] = str(elem.text).strip()
+        attrs['parent'] = elem.getparent()
         if isinstance(depth, int) and depth > 0:
             for child_elem in list(elem):
                 key = str(child_elem.tag)
@@ -179,8 +189,8 @@ class RSpec:
         return attrs
 
     def get(self, element_type, filter={}, depth=0):
-        elements = [self.get_element_attributes(element, depth=depth) for element in \
-                    self.get_elements(element_type, filter)]
+        elements = self.get_elements(element_type, filter)
+        elements = [self.get_element_attributes(element, depth=depth) for element in elements]
         return elements
 
     def get_elements(self, element_type, filter={}):
@@ -191,7 +201,7 @@ class RSpec:
             msg = "Unable to search for element %s in rspec, expath expression not found." % \
                    element_type
             raise InvalidRSpecElement(element_type, extra=msg)
-        rspec_element = self.get_element_type(element_type)
+        rspec_element = self.get_rspec_element(element_type)
         xpath = rspec_element.path + XpathFilter.xpath(filter)
         return self.xpath(xpath)
 
@@ -237,8 +247,8 @@ class RSpec:
 if __name__ == '__main__':
     rspec = RSpec('/tmp/resources.rspec')
     print rspec
-    #rspec.register_element_type(RSpecElements.NETWORK, 'network', '//network')
-    #rspec.register_element_type(RSpecElements.NODE, 'node', '//node')
+    #rspec.register_rspec_element(RSpecElements.NETWORK, 'network', '//network')
+    #rspec.register_rspec_element(RSpecElements.NODE, 'node', '//node')
     #print rspec.find(RSpecElements.NODE)[0]
     #print rspec.find(RSpecElements.NODE, depth=1)[0]