replaced Element.get_elements() with XmlNode.get_instance(). replaced Element.add_ele...
[sfa.git] / sfa / util / xml.py
index bb298a3..d6fd196 100755 (executable)
@@ -3,6 +3,7 @@ from types import StringTypes
 from lxml import etree
 from StringIO import StringIO
 from sfa.util.faults import InvalidXML
+from sfa.rspecs.elements.element import Element
 
 class XpathFilter:
     @staticmethod
@@ -37,32 +38,64 @@ class XpathFilter:
                 xpath = '[' + xpath + ']'
         return xpath
 
-class XmlNode:
-    def __init__(self, node, namespaces):
-        self.node = node
-        self.text = node.text
+class XmlElement:
+    def __init__(self, element, namespaces):
+        self.element = element
+        self.text = element.text
         self.namespaces = namespaces
-        self.attrib = node.attrib
+        self.attrib = element.attrib
         
 
     def xpath(self, xpath, namespaces=None):
         if not namespaces:
             namespaces = self.namespaces 
-        elems = self.node.xpath(xpath, namespaces=namespaces)
-        return [XmlNode(elem, namespaces) for elem in elems]
+        elems = self.element.xpath(xpath, namespaces=namespaces)
+        return [XmlElement(elem, namespaces) for elem in elems]
     
     def add_element(self, tagname, **kwds):
-        element = etree.SubElement(self.node, tagname, **kwds)
-        return XmlNode(element, self.namespaces)
+        element = etree.SubElement(self.element, tagname, **kwds)
+        return XmlElement(element, self.namespaces)
 
     def append(self, elem):
-        if isinstance(elem, XmlNode):
-            self.node.append(elem.node)
+        if isinstance(elem, XmlElement):
+            self.element.append(elem.element)
         else:
-            self.node.append(elem)
+            self.element.append(elem)
 
     def getparent(self):
-        return XmlNode(self.node.getparent(), self.namespaces)
+        return XmlElement(self.element.getparent(), self.namespaces)
+
+    def get_instance(self, instance_class=None, fields=[]):
+        """
+        Returns an instance (dict) of this xml element. The instance
+        holds a reference this xml element.   
+        """
+        if not instance_class:
+            instance_class = Element
+        if not fields and hasattr(instance_class, 'fields'):
+            fields = instance_class.fields
+
+        if not fields:
+            instance = instance_class(self.attrib, self)
+        else:
+            instance = instance_class({}, self)
+            for field in fields:
+                if field in self.attrib:
+                   instance[field] = self.attrib[field]  
+        return instance             
+
+    def add_instance(self, name, instance, fields=[]):
+        """
+        Adds the specifed instance(s) as a child element of this xml 
+        element. 
+        """
+        if not fields and hasattr(instance, 'keys'):
+            fields = instance.keys()
+        elem = self.add_element(name)
+        for field in fields:
+            if field in instance and instance[field]:
+                elem.set(field, unicode(instance[field]))
+        return elem                  
 
     def remove_elements(self, name):
         """
@@ -72,36 +105,36 @@ class XmlNode:
         
         if not element_name.startswith('//'):
             element_name = '//' + element_name
-        elements = self.node.xpath('%s ' % name, namespaces=self.namespaces) 
+        elements = self.element.xpath('%s ' % name, namespaces=self.namespaces) 
         for element in elements:
             parent = element.getparent()
             parent.remove(element)
 
     def remove(self, element):
-        if isinstance(element, XmlNode):
-            self.node.remove(element.node)
+        if isinstance(element, XmlElement):
+            self.element.remove(element.element)
         else:
-            self.node.remove(element)
+            self.element.remove(element)
 
     def get(self, key, *args):
-        return self.node.get(key, *args)
+        return self.element.get(key, *args)
 
-    def items(self): return self.node.items()
+    def items(self): return self.element.items()
 
     def set(self, key, value):
-        self.node.set(key, value)
+        self.element.set(key, value)
     
     def set_text(self, text):
-        self.node.text = text
+        self.element.text = text
     
     def unset(self, key):
-        del self.node.attrib[key]
+        del self.element.attrib[key]
   
     def iterchildren(self):
-        return self.node.iterchildren()
+        return self.element.iterchildren()
      
     def toxml(self):
-        return etree.tostring(self.node, encoding='UTF-8', pretty_print=True)                    
+        return etree.tostring(self.element, encoding='UTF-8', pretty_print=True)                    
 
     def __str__(self):
         return self.toxml()
@@ -115,7 +148,7 @@ class XML:
         self.schema = None
         if isinstance(xml, basestring):
             self.parse_xml(xml)
-        if isinstance(xml, XmlNode):
+        if isinstance(xml, XmlElement):
             self.root = xml
             self.namespaces = xml.namespaces
         elif isinstance(xml, etree._ElementTree) or isinstance(xml, etree._Element):
@@ -147,7 +180,7 @@ class XML:
         else:
             self.namespaces['default'] = 'default' 
 
-        self.root = XmlNode(root, self.namespaces)
+        self.root = XmlElement(root, self.namespaces)
         # set schema 
         for key in self.root.attrib.keys():
             if key.endswith('schemaLocation'):
@@ -210,48 +243,43 @@ class XML:
             namespaces = self.namespaces
         return self.root.xpath(xpath, namespaces=namespaces)
 
-    def set(self, key, value, node=None):
-        if not node:
-            node = self.root 
-        return node.set(key, value)
+    def set(self, key, value, element=None):
+        if not element:
+            element = self.root 
+        return element.set(key, value)
 
-    def remove_attribute(self, name, node=None):
-        if not node:
-            node = self.root
-        node.remove_attribute(name) 
+    def remove_attribute(self, name, element=None):
+        if not element:
+            element = self.root
+        element.remove_attribute(name) 
         
-    def add_element(self, name, **kwds):
+    def add_element(self, *args, **kwds):
         """
         Wrapper around etree.SubElement(). Adds an element to 
         specified parent node. Adds element to root node is parent is 
         not specified. 
         """
-        parent = self.root
-        xmlnode = parent.add_element(name, *kwds)
-        return xmlnode
+        return self.root.add_element(*args, **kwds)
 
-    def remove_elements(self, name, node = None):
+    def remove_elements(self, name, element = None):
         """
         Removes all occurences of an element from the tree. Start at 
         specified root_node if specified, otherwise start at tree's root.   
         """
-        if not node:
-            node = self.root
+        if not element:
+            element = self.root
+
+        element.remove_elements(name)
 
-        node.remove_elements(name)
+    def add_instance(self, *args, **kwds):
+        return self.root.add_instance(*args, **kwds)
 
-    def attributes_list(self, elem):
-        # convert a list of attribute tags into list of tuples
-        # (tagnme, text_value)
-        opts = []
-        if elem is not None:
-            for e in elem:
-                opts.append((e.tag, str(e.text).strip()))
-        return opts
+    def get_instance(self, *args, **kwds):
+        return self.root.get_instnace(*args, **kwds)
 
     def get_element_attributes(self, elem=None, depth=0):
         if elem == None:
-            elem = self.root_node
+            elem = self.root
         if not hasattr(elem, 'attrib'):
             # this is probably not an element node with attribute. could be just and an
             # attribute, return it
@@ -283,7 +311,7 @@ class XML:
         return self.toxml()
 
     def toxml(self):
-        return etree.tostring(self.root.node, encoding='UTF-8', pretty_print=True)  
+        return etree.tostring(self.root.element, encoding='UTF-8', pretty_print=True)  
     
     # XXX smbaker, for record.load_from_string
     def todict(self, elem=None):