python3 - 2to3 + miscell obvious tweaks
[sfa.git] / sfa / util / xml.py
index ddb06e4..d9e23d5 100755 (executable)
-#!/usr/bin/python 
-from types import StringTypes
+#!/usr/bin/env python3
 from lxml import etree
 from lxml import etree
-from StringIO import StringIO
-
 from sfa.util.faults import InvalidXML
 from sfa.util.faults import InvalidXML
+from sfa.rspecs.elements.element import Element
+
+from sfa.util.py23 import StringType
+from sfa.util.py23 import StringIO
+
+# helper functions to help build xpaths
+
 
 class XpathFilter:
 
 class XpathFilter:
+
     @staticmethod
     @staticmethod
-    def xpath(filter={}):
+    def filter_value(key, value):
+        xpath = ""
+        if isinstance(value, str):
+            if '*' in value:
+                value = value.replace('*', '')
+                xpath = 'contains(%s, "%s")' % (key, value)
+            else:
+                xpath = '%s="%s"' % (key, value)
+        return xpath
+
+    @staticmethod
+    def xpath(filter=None):
+        if filter is None:
+            filter = {}
         xpath = ""
         if filter:
             filter_list = []
         xpath = ""
         if filter:
             filter_list = []
-            for (key, value) in filter.items():
+            for (key, value) in list(filter.items()):
                 if key == 'text':
                     key = 'text()'
                 else:
                 if key == 'text':
                     key = 'text()'
                 else:
-                    key = '@'+key
+                    key = '@' + key
                 if isinstance(value, str):
                 if isinstance(value, str):
-                    filter_list.append('%s="%s"' % (key, value))
+                    filter_list.append(XpathFilter.filter_value(key, value))
                 elif isinstance(value, list):
                 elif isinstance(value, list):
-                    filter_list.append('contains("%s", %s)' % (' '.join(map(str, value)), key))
+                    stmt = ' or '.join(
+                        [XpathFilter.filter_value(key, str(val)) for val in value])
+                    filter_list.append(stmt)
             if filter_list:
                 xpath = ' and '.join(filter_list)
                 xpath = '[' + xpath + ']'
         return xpath
 
             if filter_list:
                 xpath = ' and '.join(filter_list)
                 xpath = '[' + xpath + ']'
         return xpath
 
+# a wrapper class around lxml.etree._Element
+# the reason why we need this one is because of the limitations
+# we've found in xpath to address documents with multiple namespaces defined
+# in a nutshell, we deal with xml documents that have
+# a default namespace defined (xmlns="http://default.com/") and specific prefixes defined
+# (xmlns:foo="http://foo.com")
+# according to the documentation instead of writing
+# element.xpath ( "//node/foo:subnode" )
+# we'd then need to write xpaths like
+# element.xpath ( "//{http://default.com/}node/{http://foo.com}subnode" )
+# which is a real pain..
+# So just so we can keep some reasonable programming style we need to manage the
+# namespace map that goes with the _Element (its internal .nsmap being
+# unmutable)
+
+
+class XmlElement:
+
+    def __init__(self, element, namespaces):
+        self.element = element
+        self.namespaces = namespaces
+
+    # redefine as few methods as possible
+    def xpath(self, xpath, namespaces=None):
+        if not namespaces:
+            namespaces = self.namespaces
+        elems = self.element.xpath(xpath, namespaces=namespaces)
+        return [XmlElement(elem, namespaces) for elem in elems]
+
+    def add_element(self, tagname, **kwds):
+        element = etree.SubElement(self.element, tagname, **kwds)
+        return XmlElement(element, self.namespaces)
+
+    def append(self, elem):
+        if isinstance(elem, XmlElement):
+            self.element.append(elem.element)
+        else:
+            self.element.append(elem)
+
+    def getparent(self):
+        return XmlElement(self.element.getparent(), self.namespaces)
+
+    def get_instance(self, instance_class=None, fields=None):
+        """
+        Returns an instance (dict) of this xml element. The instance
+        holds a reference to this xml element.   
+        """
+        if fields is None:
+            fields = []
+        if not instance_class:
+            instance_class = Element
+        if not fields and hasattr(instance_class, 'fields'):
+            fields = instance_class.fields
+
+        if not fields:
+            instance = instance_class(self.attrib, self)
+        else:
+            instance = instance_class({}, self)
+            for field in fields:
+                if field in self.attrib:
+                    instance[field] = self.attrib[field]
+        return instance
+
+    def add_instance(self, name, instance, fields=None):
+        """
+        Adds the specifed instance(s) as a child element of this xml 
+        element. 
+        """
+        if fields is None:
+            fields = []
+        if not fields and hasattr(instance, 'keys'):
+            fields = list(instance.keys())
+        elem = self.add_element(name)
+        for field in fields:
+            if field in instance and instance[field]:
+                elem.set(field, str(instance[field]))
+        return elem
+
+    def remove_elements(self, name):
+        """
+        Removes all occurences of an element from the tree. Start at
+        specified root_node if specified, otherwise start at tree's root.
+        """
+
+        if not element_name.startswith('//'):
+            element_name = '//' + element_name
+        elements = self.element.xpath('%s ' % name, namespaces=self.namespaces)
+        for element in elements:
+            parent = element.getparent()
+            parent.remove(element)
+
+    def delete(self):
+        parent = self.getparent()
+        parent.remove(self)
+
+    def remove(self, element):
+        if isinstance(element, XmlElement):
+            self.element.remove(element.element)
+        else:
+            self.element.remove(element)
+
+    def set_text(self, text):
+        self.element.text = text
+
+    # Element does not have unset ?!?
+    def unset(self, key):
+        del self.element.attrib[key]
+
+    def toxml(self):
+        return etree.tostring(self.element, encoding='UTF-8', pretty_print=True)
+
+    def __str__(self):
+        return self.toxml()
+
+    # other method calls or attribute access like .text or .tag or .get
+    # are redirected on self.element
+    def __getattr__(self, name):
+        if not hasattr(self.element, name):
+            raise AttributeError(name)
+        return getattr(self.element, name)
+
+
 class XML:
 class XML:
-    def __init__(self, xml=None):
+
+    def __init__(self, xml=None, namespaces=None):
         self.root = None
         self.root = None
-        self.namespaces = None
+        self.namespaces = namespaces
         self.default_namespace = None
         self.schema = None
         self.default_namespace = None
         self.schema = None
-        if isinstance(xml, basestring):
+        if isinstance(xml, StringType):
             self.parse_xml(xml)
             self.parse_xml(xml)
-        elif isinstance(xml, etree._ElementTree):
-            self.root = xml.getroot()
-        elif isinstance(xml, etree._Element):
-            self.root = xml 
+        if isinstance(xml, XmlElement):
+            self.root = xml
+            self.namespaces = xml.namespaces
+        elif isinstance(xml, etree._ElementTree) or isinstance(xml, etree._Element):
+            self.parse_xml(etree.tostring(xml))
 
     def parse_xml(self, xml):
         """
 
     def parse_xml(self, xml):
         """
@@ -50,61 +193,66 @@ class XML:
             # 'rspec' file doesnt exist. 'rspec' is proably an xml string
             try:
                 tree = etree.parse(StringIO(xml), parser)
             # 'rspec' file doesnt exist. 'rspec' is proably an xml string
             try:
                 tree = etree.parse(StringIO(xml), parser)
-            except Exception, e:
+            except Exception as e:
                 raise InvalidXML(str(e))
                 raise InvalidXML(str(e))
-        self.root = tree.getroot()
+        root = tree.getroot()
+        self.namespaces = dict(root.nsmap)
         # set namespaces map
         # set namespaces map
-        self.namespaces = dict(self.root.nsmap)
-        if 'default' not in self.namespaces and None in self.namespaces: 
-            # If the 'None' exist, then it's pointing to the default namespace. This makes 
-            # it hard for us to write xpath queries for the default naemspace because lxml 
-            # wont understand a None prefix. We will just associate the default namespeace 
-            # with a key named 'default'.     
+        if 'default' not in self.namespaces and None in self.namespaces:
+            # If the 'None' exist, then it's pointing to the default namespace. This makes
+            # it hard for us to write xpath queries for the default naemspace because lxml
+            # wont understand a None prefix. We will just associate the default namespeace
+            # with a key named 'default'.
             self.namespaces['default'] = self.namespaces.pop(None)
             self.namespaces['default'] = self.namespaces.pop(None)
-            
+
         else:
         else:
-            self.namespaces['default'] = 'default' 
-        # set schema 
-        for key in self.root.attrib.keys():
+            self.namespaces['default'] = 'default'
+
+        self.root = XmlElement(root, self.namespaces)
+        # set schema
+        for key in list(self.root.attrib.keys()):
             if key.endswith('schemaLocation'):
             if key.endswith('schemaLocation'):
-                # schema location should be at the end of the list
-                schema_parts  = self.root.attrib[key].split(' ')
-                self.schema = schema_parts[1]    
-                namespace, schema  = schema_parts[0], schema_parts[1]
+                # schemaLocation should be at the end of the list.
+                # Use list comprehension to filter out empty strings
+                schema_parts = [
+                    x for x in self.root.attrib[key].split(' ') if x]
+                self.schema = schema_parts[1]
+                namespace, schema = schema_parts[0], schema_parts[1]
                 break
 
                 break
 
-    def parse_dict(self, d, root_tag_name='xml', element = None):
-        if element is None: 
+    def parse_dict(self, d, root_tag_name='xml', element=None):
+        if element is None:
             if self.root is None:
                 self.parse_xml('<%s/>' % root_tag_name)
             if self.root is None:
                 self.parse_xml('<%s/>' % root_tag_name)
-            element = self.root
+            element = self.root.element
 
         if 'text' in d:
             text = d.pop('text')
             element.text = text
 
         # handle repeating fields
 
         if 'text' in d:
             text = d.pop('text')
             element.text = text
 
         # handle repeating fields
-        for (key, value) in d.items():
+        for (key, value) in list(d.items()):
             if isinstance(value, list):
                 value = d.pop(key)
                 for val in value:
                     if isinstance(val, dict):
                         child_element = etree.SubElement(element, key)
                         self.parse_dict(val, key, child_element)
             if isinstance(value, list):
                 value = d.pop(key)
                 for val in value:
                     if isinstance(val, dict):
                         child_element = etree.SubElement(element, key)
                         self.parse_dict(val, key, child_element)
-                    elif isinstance(val, basestring):
-                        child_element = etree.SubElement(element, key).text = val
-                        
+                    elif isinstance(val, StringType):
+                        child_element = etree.SubElement(
+                            element, key).text = val
+
             elif isinstance(value, int):
             elif isinstance(value, int):
-                d[key] = unicode(d[key])  
+                d[key] = str(d[key])
             elif value is None:
                 d.pop(key)
 
         # element.attrib.update will explode if DateTimes are in the
         # dcitionary.
             elif value is None:
                 d.pop(key)
 
         # element.attrib.update will explode if DateTimes are in the
         # dcitionary.
-        d=d.copy()
+        d = d.copy()
         # looks like iteritems won't stand side-effects
         # looks like iteritems won't stand side-effects
-        for k in d.keys():
-            if not isinstance(d[k],StringTypes):
+        for k in list(d.keys()):
+            if not isinstance(d[k], StringType):
                 del d[k]
 
         element.attrib.update(d)
                 del d[k]
 
         element.attrib.update(d)
@@ -129,68 +277,38 @@ class XML:
     def set(self, key, value):
         return self.root.set(key, value)
 
     def set(self, key, value):
         return self.root.set(key, value)
 
-    def add_attribute(self, elem, name, value):
-        """
-        Add attribute to specified etree element    
-        """
-        opt = etree.SubElement(elem, name)
-        opt.text = value
+    def remove_attribute(self, name, element=None):
+        if not element:
+            element = self.root
+        element.remove_attribute(name)
 
 
-    def add_element(self, name, attrs={}, parent=None, text=""):
+    def add_element(self, *args, **kwds):
         """
         Wrapper around etree.SubElement(). Adds an element to 
         specified parent node. Adds element to root node is parent is 
         not specified. 
         """
         """
         Wrapper around etree.SubElement(). Adds an element to 
         specified parent node. Adds element to root node is parent is 
         not specified. 
         """
-        if parent == None:
-            parent = self.root
-        element = etree.SubElement(parent, name)
-        if text:
-            element.text = text
-        if isinstance(attrs, dict):
-            for attr in attrs:
-                element.set(attr, attrs[attr])  
-        return element
+        return self.root.add_element(*args, **kwds)
 
 
-    def remove_attribute(self, elem, name, value):
-        """
-        Removes an attribute from an element
-        """
-        if elem is not None:
-            opts = elem.iterfind(name)
-            if opts is not None:
-                for opt in opts:
-                    if opt.text == value:
-                        elem.remove(opt)
-
-    def remove_element(self, element_name, root_node = None):
+    def remove_elements(self, name, element=None):
         """
         Removes all occurences of an element from the tree. Start at 
         specified root_node if specified, otherwise start at tree's root.   
         """
         """
         Removes all occurences of an element from the tree. Start at 
         specified root_node if specified, otherwise start at tree's root.   
         """
-        if not root_node:
-            root_node = self.root
+        if not element:
+            element = self.root
 
 
-        if not element_name.startswith('//'):
-            element_name = '//' + element_name
+        element.remove_elements(name)
 
 
-        elements = root_node.xpath('%s ' % element_name, namespaces=self.namespaces)
-        for element in elements:
-            parent = element.getparent()
-            parent.remove(element)
+    def add_instance(self, *args, **kwds):
+        return self.root.add_instance(*args, **kwds)
 
 
-    def attributes_list(self, elem):
-        # convert a list of attribute tags into list of tuples
-        # (tagnme, text_value)
-        opts = []
-        if elem is not None:
-            for e in elem:
-                opts.append((e.tag, str(e.text).strip()))
-        return opts
+    def get_instance(self, *args, **kwds):
+        return self.root.get_instnace(*args, **kwds)
 
     def get_element_attributes(self, elem=None, depth=0):
         if elem == None:
 
     def get_element_attributes(self, elem=None, depth=0):
         if elem == None:
-            elem = self.root_node
+            elem = self.root
         if not hasattr(elem, 'attrib'):
             # this is probably not an element node with attribute. could be just and an
             # attribute, return it
         if not hasattr(elem, 'attrib'):
             # this is probably not an element node with attribute. could be just and an
             # attribute, return it
@@ -202,13 +320,21 @@ class XML:
             for child_elem in list(elem):
                 key = str(child_elem.tag)
                 if key not in attrs:
             for child_elem in list(elem):
                 key = str(child_elem.tag)
                 if key not in attrs:
-                    attrs[key] = [self.get_element_attributes(child_elem, depth-1)]
+                    attrs[key] = [self.get_element_attributes(
+                        child_elem, depth - 1)]
                 else:
                 else:
-                    attrs[key].append(self.get_element_attributes(child_elem, depth-1))
+                    attrs[key].append(
+                        self.get_element_attributes(child_elem, depth - 1))
         else:
             attrs['child_nodes'] = list(elem)
         return attrs
 
         else:
             attrs['child_nodes'] = list(elem)
         return attrs
 
+    def append(self, elem):
+        return self.root.append(elem)
+
+    def iterchildren(self):
+        return self.root.iterchildren()
+
     def merge(self, in_xml):
         pass
 
     def merge(self, in_xml):
         pass
 
@@ -216,8 +342,8 @@ class XML:
         return self.toxml()
 
     def toxml(self):
         return self.toxml()
 
     def toxml(self):
-        return etree.tostring(self.root, encoding='UTF-8', pretty_print=True)  
-    
+        return etree.tostring(self.root.element, encoding='UTF-8', pretty_print=True)
+
     # XXX smbaker, for record.load_from_string
     def todict(self, elem=None):
         if elem is None:
     # XXX smbaker, for record.load_from_string
     def todict(self, elem=None):
         if elem is None:
@@ -230,18 +356,17 @@ class XML:
                 d[child.tag] = []
             d[child.tag].append(self.todict(child))
 
                 d[child.tag] = []
             d[child.tag].append(self.todict(child))
 
-        if len(d)==1 and ("text" in d):
+        if len(d) == 1 and ("text" in d):
             d = d["text"]
 
         return d
             d = d["text"]
 
         return d
-        
+
     def save(self, filename):
         f = open(filename, 'w')
         f.write(self.toxml())
         f.close()
 
     def save(self, filename):
         f = open(filename, 'w')
         f.write(self.toxml())
         f.close()
 
-# no RSpec in scope 
-#if __name__ == '__main__':
+# no RSpec in scope
+# if __name__ == '__main__':
 #    rspec = RSpec('/tmp/resources.rspec')
 #    print rspec
 #    rspec = RSpec('/tmp/resources.rspec')
 #    print rspec
-