python3 - 2to3 + miscell obvious tweaks
[sfa.git] / sfa / util / xml.py
index d2bb6a7..d9e23d5 100755 (executable)
-#!/usr/bin/python 
-from types import StringTypes
+#!/usr/bin/env python3
 from lxml import etree
 from lxml import etree
-from StringIO import StringIO
 from sfa.util.faults import InvalidXML
 from sfa.util.faults import InvalidXML
+from sfa.rspecs.elements.element import Element
+
+from sfa.util.py23 import StringType
+from sfa.util.py23 import StringIO
+
+# helper functions to help build xpaths
+
 
 class XpathFilter:
 
 class XpathFilter:
-    @staticmethod
 
 
+    @staticmethod
     def filter_value(key, value):
     def filter_value(key, value):
-        xpath = ""    
+        xpath = ""
         if isinstance(value, str):
             if '*' in value:
                 value = value.replace('*', '')
                 xpath = 'contains(%s, "%s")' % (key, value)
             else:
         if isinstance(value, str):
             if '*' in value:
                 value = value.replace('*', '')
                 xpath = 'contains(%s, "%s")' % (key, value)
             else:
-                xpath = '%s="%s"' % (key, value)                
+                xpath = '%s="%s"' % (key, value)
         return xpath
 
     @staticmethod
         return xpath
 
     @staticmethod
-    def xpath(filter={}):
+    def xpath(filter=None):
+        if filter is None:
+            filter = {}
         xpath = ""
         if filter:
             filter_list = []
         xpath = ""
         if filter:
             filter_list = []
-            for (key, value) in filter.items():
+            for (key, value) in list(filter.items()):
                 if key == 'text':
                     key = 'text()'
                 else:
                 if key == 'text':
                     key = 'text()'
                 else:
-                    key = '@'+key
+                    key = '@' + key
                 if isinstance(value, str):
                     filter_list.append(XpathFilter.filter_value(key, value))
                 elif isinstance(value, list):
                 if isinstance(value, str):
                     filter_list.append(XpathFilter.filter_value(key, value))
                 elif isinstance(value, list):
-                    stmt = ' or '.join([XpathFilter.filter_value(key, str(val)) for val in value])
-                    filter_list.append(stmt)   
+                    stmt = ' or '.join(
+                        [XpathFilter.filter_value(key, str(val)) for val in value])
+                    filter_list.append(stmt)
             if filter_list:
                 xpath = ' and '.join(filter_list)
                 xpath = '[' + xpath + ']'
         return xpath
 
             if filter_list:
                 xpath = ' and '.join(filter_list)
                 xpath = '[' + xpath + ']'
         return xpath
 
-class XmlNode:
-    def __init__(self, node, namespaces):
-        self.node = node
-        self.text = node.text
+# a wrapper class around lxml.etree._Element
+# the reason why we need this one is because of the limitations
+# we've found in xpath to address documents with multiple namespaces defined
+# in a nutshell, we deal with xml documents that have
+# a default namespace defined (xmlns="http://default.com/") and specific prefixes defined
+# (xmlns:foo="http://foo.com")
+# according to the documentation instead of writing
+# element.xpath ( "//node/foo:subnode" )
+# we'd then need to write xpaths like
+# element.xpath ( "//{http://default.com/}node/{http://foo.com}subnode" )
+# which is a real pain..
+# So just so we can keep some reasonable programming style we need to manage the
+# namespace map that goes with the _Element (its internal .nsmap being
+# unmutable)
+
+
+class XmlElement:
+
+    def __init__(self, element, namespaces):
+        self.element = element
         self.namespaces = namespaces
         self.namespaces = namespaces
-        self.attrib = node.attrib
-        
 
 
+    # redefine as few methods as possible
     def xpath(self, xpath, namespaces=None):
         if not namespaces:
     def xpath(self, xpath, namespaces=None):
         if not namespaces:
-            namespaces = self.namespaces 
-        elems = self.node.xpath(xpath, namespaces=namespaces)
-        return [XmlNode(elem, namespaces) for elem in elems]
-    
-    def add_element(self, name, **kwds):
-        element = etree.SubElement(self.node, name, **kwds)
-        return XmlNode(element, self.namespaces)
+            namespaces = self.namespaces
+        elems = self.element.xpath(xpath, namespaces=namespaces)
+        return [XmlElement(elem, namespaces) for elem in elems]
+
+    def add_element(self, tagname, **kwds):
+        element = etree.SubElement(self.element, tagname, **kwds)
+        return XmlElement(element, self.namespaces)
 
     def append(self, elem):
 
     def append(self, elem):
-        self.node.append(elem)
+        if isinstance(elem, XmlElement):
+            self.element.append(elem.element)
+        else:
+            self.element.append(elem)
 
 
-    def remove_elements(name):
+    def getparent(self):
+        return XmlElement(self.element.getparent(), self.namespaces)
+
+    def get_instance(self, instance_class=None, fields=None):
+        """
+        Returns an instance (dict) of this xml element. The instance
+        holds a reference to this xml element.   
+        """
+        if fields is None:
+            fields = []
+        if not instance_class:
+            instance_class = Element
+        if not fields and hasattr(instance_class, 'fields'):
+            fields = instance_class.fields
+
+        if not fields:
+            instance = instance_class(self.attrib, self)
+        else:
+            instance = instance_class({}, self)
+            for field in fields:
+                if field in self.attrib:
+                    instance[field] = self.attrib[field]
+        return instance
+
+    def add_instance(self, name, instance, fields=None):
+        """
+        Adds the specifed instance(s) as a child element of this xml 
+        element. 
+        """
+        if fields is None:
+            fields = []
+        if not fields and hasattr(instance, 'keys'):
+            fields = list(instance.keys())
+        elem = self.add_element(name)
+        for field in fields:
+            if field in instance and instance[field]:
+                elem.set(field, str(instance[field]))
+        return elem
+
+    def remove_elements(self, name):
         """
         Removes all occurences of an element from the tree. Start at
         specified root_node if specified, otherwise start at tree's root.
         """
         """
         Removes all occurences of an element from the tree. Start at
         specified root_node if specified, otherwise start at tree's root.
         """
-        
+
         if not element_name.startswith('//'):
             element_name = '//' + element_name
         if not element_name.startswith('//'):
             element_name = '//' + element_name
-        elements = self.node.xpath('%s ' % name, namespaces=self.namespaces) 
+        elements = self.element.xpath('%s ' % name, namespaces=self.namespaces)
         for element in elements:
             parent = element.getparent()
             parent.remove(element)
 
         for element in elements:
             parent = element.getparent()
             parent.remove(element)
 
-    def remove(element):
-        self.node.remove(element)
+    def delete(self):
+        parent = self.getparent()
+        parent.remove(self)
+
+    def remove(self, element):
+        if isinstance(element, XmlElement):
+            self.element.remove(element.element)
+        else:
+            self.element.remove(element)
 
 
-    def set(self, key, value):
-        self.node.set(key, value)
-    
     def set_text(self, text):
     def set_text(self, text):
-        self.node.text = text
-    
+        self.element.text = text
+
+    # Element does not have unset ?!?
     def unset(self, key):
     def unset(self, key):
-        del self.node.attrib[key]
-  
-    def iterchildren(self):
-        return self.node.iterchildren()
-     
+        del self.element.attrib[key]
+
     def toxml(self):
     def toxml(self):
-        return etree.tostring(self.node, encoding='UTF-8', pretty_print=True)                    
+        return etree.tostring(self.element, encoding='UTF-8', pretty_print=True)
 
     def __str__(self):
         return self.toxml()
 
 
     def __str__(self):
         return self.toxml()
 
+    # other method calls or attribute access like .text or .tag or .get
+    # are redirected on self.element
+    def __getattr__(self, name):
+        if not hasattr(self.element, name):
+            raise AttributeError(name)
+        return getattr(self.element, name)
+
+
 class XML:
 class XML:
+
     def __init__(self, xml=None, namespaces=None):
         self.root = None
         self.namespaces = namespaces
         self.default_namespace = None
         self.schema = None
     def __init__(self, xml=None, namespaces=None):
         self.root = None
         self.namespaces = namespaces
         self.default_namespace = None
         self.schema = None
-        if isinstance(xml, basestring):
+        if isinstance(xml, StringType):
             self.parse_xml(xml)
             self.parse_xml(xml)
-        if isinstance(xml, XmlNode):
+        if isinstance(xml, XmlElement):
             self.root = xml
             self.namespaces = xml.namespaces
         elif isinstance(xml, etree._ElementTree) or isinstance(xml, etree._Element):
             self.root = xml
             self.namespaces = xml.namespaces
         elif isinstance(xml, etree._ElementTree) or isinstance(xml, etree._Element):
@@ -118,63 +193,66 @@ class XML:
             # 'rspec' file doesnt exist. 'rspec' is proably an xml string
             try:
                 tree = etree.parse(StringIO(xml), parser)
             # 'rspec' file doesnt exist. 'rspec' is proably an xml string
             try:
                 tree = etree.parse(StringIO(xml), parser)
-            except Exception, e:
+            except Exception as e:
                 raise InvalidXML(str(e))
         root = tree.getroot()
         self.namespaces = dict(root.nsmap)
         # set namespaces map
                 raise InvalidXML(str(e))
         root = tree.getroot()
         self.namespaces = dict(root.nsmap)
         # set namespaces map
-        if 'default' not in self.namespaces and None in self.namespaces: 
-            # If the 'None' exist, then it's pointing to the default namespace. This makes 
-            # it hard for us to write xpath queries for the default naemspace because lxml 
-            # wont understand a None prefix. We will just associate the default namespeace 
-            # with a key named 'default'.     
+        if 'default' not in self.namespaces and None in self.namespaces:
+            # If the 'None' exist, then it's pointing to the default namespace. This makes
+            # it hard for us to write xpath queries for the default naemspace because lxml
+            # wont understand a None prefix. We will just associate the default namespeace
+            # with a key named 'default'.
             self.namespaces['default'] = self.namespaces.pop(None)
             self.namespaces['default'] = self.namespaces.pop(None)
-            
+
         else:
         else:
-            self.namespaces['default'] = 'default' 
+            self.namespaces['default'] = 'default'
 
 
-        self.root = XmlNode(root, self.namespaces)
-        # set schema 
-        for key in self.root.attrib.keys():
+        self.root = XmlElement(root, self.namespaces)
+        # set schema
+        for key in list(self.root.attrib.keys()):
             if key.endswith('schemaLocation'):
             if key.endswith('schemaLocation'):
-                # schema location should be at the end of the list
-                schema_parts  = self.root.attrib[key].split(' ')
-                self.schema = schema_parts[1]    
-                namespace, schema  = schema_parts[0], schema_parts[1]
+                # schemaLocation should be at the end of the list.
+                # Use list comprehension to filter out empty strings
+                schema_parts = [
+                    x for x in self.root.attrib[key].split(' ') if x]
+                self.schema = schema_parts[1]
+                namespace, schema = schema_parts[0], schema_parts[1]
                 break
 
                 break
 
-    def parse_dict(self, d, root_tag_name='xml', element = None):
-        if element is None: 
+    def parse_dict(self, d, root_tag_name='xml', element=None):
+        if element is None:
             if self.root is None:
                 self.parse_xml('<%s/>' % root_tag_name)
             if self.root is None:
                 self.parse_xml('<%s/>' % root_tag_name)
-            element = self.root
+            element = self.root.element
 
         if 'text' in d:
             text = d.pop('text')
             element.text = text
 
         # handle repeating fields
 
         if 'text' in d:
             text = d.pop('text')
             element.text = text
 
         # handle repeating fields
-        for (key, value) in d.items():
+        for (key, value) in list(d.items()):
             if isinstance(value, list):
                 value = d.pop(key)
                 for val in value:
                     if isinstance(val, dict):
                         child_element = etree.SubElement(element, key)
                         self.parse_dict(val, key, child_element)
             if isinstance(value, list):
                 value = d.pop(key)
                 for val in value:
                     if isinstance(val, dict):
                         child_element = etree.SubElement(element, key)
                         self.parse_dict(val, key, child_element)
-                    elif isinstance(val, basestring):
-                        child_element = etree.SubElement(element, key).text = val
-                        
+                    elif isinstance(val, StringType):
+                        child_element = etree.SubElement(
+                            element, key).text = val
+
             elif isinstance(value, int):
             elif isinstance(value, int):
-                d[key] = unicode(d[key])  
+                d[key] = str(d[key])
             elif value is None:
                 d.pop(key)
 
         # element.attrib.update will explode if DateTimes are in the
         # dcitionary.
             elif value is None:
                 d.pop(key)
 
         # element.attrib.update will explode if DateTimes are in the
         # dcitionary.
-        d=d.copy()
+        d = d.copy()
         # looks like iteritems won't stand side-effects
         # looks like iteritems won't stand side-effects
-        for k in d.keys():
-            if not isinstance(d[k],StringTypes):
+        for k in list(d.keys()):
+            if not isinstance(d[k], StringType):
                 del d[k]
 
         element.attrib.update(d)
                 del d[k]
 
         element.attrib.update(d)
@@ -196,49 +274,41 @@ class XML:
             namespaces = self.namespaces
         return self.root.xpath(xpath, namespaces=namespaces)
 
             namespaces = self.namespaces
         return self.root.xpath(xpath, namespaces=namespaces)
 
-    def set(self, key, value, node=None):
-        if not node:
-            node = self.root 
-        return node.set(key, value)
+    def set(self, key, value):
+        return self.root.set(key, value)
 
 
-    def remove_attribute(self, name, node=None):
-        if not node:
-            node = self.root
-        node.remove_attribute(name) 
-        
+    def remove_attribute(self, name, element=None):
+        if not element:
+            element = self.root
+        element.remove_attribute(name)
 
 
-    def add_element(self, name, **kwds):
+    def add_element(self, *args, **kwds):
         """
         Wrapper around etree.SubElement(). Adds an element to 
         specified parent node. Adds element to root node is parent is 
         not specified. 
         """
         """
         Wrapper around etree.SubElement(). Adds an element to 
         specified parent node. Adds element to root node is parent is 
         not specified. 
         """
-        parent = self.root
-        xmlnode = parent.add_element(name, *kwds)
-        return xmlnode
+        return self.root.add_element(*args, **kwds)
 
 
-    def remove_elements(self, name, node = None):
+    def remove_elements(self, name, element=None):
         """
         Removes all occurences of an element from the tree. Start at 
         specified root_node if specified, otherwise start at tree's root.   
         """
         """
         Removes all occurences of an element from the tree. Start at 
         specified root_node if specified, otherwise start at tree's root.   
         """
-        if not node:
-            node = self.root
+        if not element:
+            element = self.root
+
+        element.remove_elements(name)
 
 
-        node.remove_elements(name)
+    def add_instance(self, *args, **kwds):
+        return self.root.add_instance(*args, **kwds)
 
 
-    def attributes_list(self, elem):
-        # convert a list of attribute tags into list of tuples
-        # (tagnme, text_value)
-        opts = []
-        if elem is not None:
-            for e in elem:
-                opts.append((e.tag, str(e.text).strip()))
-        return opts
+    def get_instance(self, *args, **kwds):
+        return self.root.get_instnace(*args, **kwds)
 
     def get_element_attributes(self, elem=None, depth=0):
         if elem == None:
 
     def get_element_attributes(self, elem=None, depth=0):
         if elem == None:
-            elem = self.root_node
+            elem = self.root
         if not hasattr(elem, 'attrib'):
             # this is probably not an element node with attribute. could be just and an
             # attribute, return it
         if not hasattr(elem, 'attrib'):
             # this is probably not an element node with attribute. could be just and an
             # attribute, return it
@@ -250,13 +320,21 @@ class XML:
             for child_elem in list(elem):
                 key = str(child_elem.tag)
                 if key not in attrs:
             for child_elem in list(elem):
                 key = str(child_elem.tag)
                 if key not in attrs:
-                    attrs[key] = [self.get_element_attributes(child_elem, depth-1)]
+                    attrs[key] = [self.get_element_attributes(
+                        child_elem, depth - 1)]
                 else:
                 else:
-                    attrs[key].append(self.get_element_attributes(child_elem, depth-1))
+                    attrs[key].append(
+                        self.get_element_attributes(child_elem, depth - 1))
         else:
             attrs['child_nodes'] = list(elem)
         return attrs
 
         else:
             attrs['child_nodes'] = list(elem)
         return attrs
 
+    def append(self, elem):
+        return self.root.append(elem)
+
+    def iterchildren(self):
+        return self.root.iterchildren()
+
     def merge(self, in_xml):
         pass
 
     def merge(self, in_xml):
         pass
 
@@ -264,8 +342,8 @@ class XML:
         return self.toxml()
 
     def toxml(self):
         return self.toxml()
 
     def toxml(self):
-        return etree.tostring(self.root.node, encoding='UTF-8', pretty_print=True)  
-    
+        return etree.tostring(self.root.element, encoding='UTF-8', pretty_print=True)
+
     # XXX smbaker, for record.load_from_string
     def todict(self, elem=None):
         if elem is None:
     # XXX smbaker, for record.load_from_string
     def todict(self, elem=None):
         if elem is None:
@@ -278,18 +356,17 @@ class XML:
                 d[child.tag] = []
             d[child.tag].append(self.todict(child))
 
                 d[child.tag] = []
             d[child.tag].append(self.todict(child))
 
-        if len(d)==1 and ("text" in d):
+        if len(d) == 1 and ("text" in d):
             d = d["text"]
 
         return d
             d = d["text"]
 
         return d
-        
+
     def save(self, filename):
         f = open(filename, 'w')
         f.write(self.toxml())
         f.close()
 
     def save(self, filename):
         f = open(filename, 'w')
         f.write(self.toxml())
         f.close()
 
-# no RSpec in scope 
-#if __name__ == '__main__':
+# no RSpec in scope
+# if __name__ == '__main__':
 #    rspec = RSpec('/tmp/resources.rspec')
 #    print rspec
 #    rspec = RSpec('/tmp/resources.rspec')
 #    print rspec
-