From: Tony Mack Date: Wed, 9 Nov 2011 17:08:23 +0000 (-0500) Subject: fix XpathFilter.xpath() X-Git-Tag: sfa-2.1-24~35^2 X-Git-Url: http://git.onelab.eu/?p=sfa.git;a=commitdiff_plain;h=2424addc9fb0bd817d5b5a811b4c9c49e27e86a0 fix XpathFilter.xpath() --- diff --git a/sfa/util/xml.py b/sfa/util/xml.py index ddb06e4f..b2aea13b 100755 --- a/sfa/util/xml.py +++ b/sfa/util/xml.py @@ -2,11 +2,22 @@ from types import StringTypes from lxml import etree from StringIO import StringIO - from sfa.util.faults import InvalidXML class XpathFilter: @staticmethod + + def filter_value(key, value): + xpath = "" + if isinstance(value, str): + if '*' in value: + value = value.replace('*', '') + xpath = 'contains(%s, "%s")' % (key, value) + else: + xpath = '%s="%s"' % (key, value) + return xpath + + @staticmethod def xpath(filter={}): xpath = "" if filter: @@ -17,27 +28,72 @@ class XpathFilter: else: key = '@'+key if isinstance(value, str): - filter_list.append('%s="%s"' % (key, value)) + filter_list.append(XpathFilter.filter_value(key, value)) elif isinstance(value, list): - filter_list.append('contains("%s", %s)' % (' '.join(map(str, value)), key)) + stmt = ' or '.join([XpathFilter.filter_value(key, str(val)) for val in value]) + filter_list.append(stmt) if filter_list: xpath = ' and '.join(filter_list) xpath = '[' + xpath + ']' return xpath +class XmlNode: + def __init__(self, node, namespaces): + self.node = node + self.namespaces = namespaces + self.attrib = node.attrib + + def xpath(self, xpath, namespaces=None): + if not namespaces: + namespaces = self.namespaces + return self.node.xpath(xpath, namespaces=namespaces) + + def add_element(name, *args, **kwds): + element = etree.SubElement(name, args, kwds) + return XmlNode(element, self.namespaces) + + def remove_elements(name): + """ + Removes all occurences of an element from the tree. Start at + specified root_node if specified, otherwise start at tree's root. + """ + + if not element_name.startswith('//'): + element_name = '//' + element_name + elements = self.node.xpath('%s ' % name, namespaces=self.namespaces) + for element in elements: + parent = element.getparent() + parent.remove(element) + + def set(self, key, value): + self.node.set(key, value) + + def set_text(self, text): + self.node.text = text + + def unset(self, key): + del self.node.attrib[key] + + def toxml(self): + return etree.tostring(self.node, encoding='UTF-8', pretty_print=True) + + def __str__(self): + return self.toxml() + class XML: - def __init__(self, xml=None): + def __init__(self, xml=None, namespaces=None): self.root = None - self.namespaces = None + self.namespaces = namespaces self.default_namespace = None self.schema = None if isinstance(xml, basestring): self.parse_xml(xml) - elif isinstance(xml, etree._ElementTree): - self.root = xml.getroot() - elif isinstance(xml, etree._Element): - self.root = xml + if isinstance(xml, XmlNode): + self.root = xml + self.namespces = xml.namespaces + elif isinstance(xml, etree._ElementTree) or isinstance(xml, etree._Element): + self.parse_xml(etree.tostring(xml)) def parse_xml(self, xml): """ @@ -52,9 +108,9 @@ class XML: tree = etree.parse(StringIO(xml), parser) except Exception, e: raise InvalidXML(str(e)) - self.root = tree.getroot() + root = tree.getroot() + self.namespaces = dict(root.nsmap) # set namespaces map - self.namespaces = dict(self.root.nsmap) if 'default' not in self.namespaces and None in self.namespaces: # If the 'None' exist, then it's pointing to the default namespace. This makes # it hard for us to write xpath queries for the default naemspace because lxml @@ -64,6 +120,8 @@ class XML: else: self.namespaces['default'] = 'default' + + self.root = XmlNode(root, self.namespaces) # set schema for key in self.root.attrib.keys(): if key.endswith('schemaLocation'): @@ -126,15 +184,16 @@ class XML: namespaces = self.namespaces return self.root.xpath(xpath, namespaces=namespaces) - def set(self, key, value): - return self.root.set(key, value) + def set(self, key, value, node=None): + if not node: + node = self.root + return node.set(key, value) - def add_attribute(self, elem, name, value): - """ - Add attribute to specified etree element - """ - opt = etree.SubElement(elem, name) - opt.text = value + def remove_attribute(self, name, node=None): + if not node: + node = self.root + node.remove_attribute(name) + def add_element(self, name, attrs={}, parent=None, text=""): """ @@ -150,34 +209,17 @@ class XML: if isinstance(attrs, dict): for attr in attrs: element.set(attr, attrs[attr]) - return element + return XmlNode(element, self.namespaces) - def remove_attribute(self, elem, name, value): - """ - Removes an attribute from an element - """ - if elem is not None: - opts = elem.iterfind(name) - if opts is not None: - for opt in opts: - if opt.text == value: - elem.remove(opt) - - def remove_element(self, element_name, root_node = None): + def remove_elements(self, name, node = None): """ Removes all occurences of an element from the tree. Start at specified root_node if specified, otherwise start at tree's root. """ - if not root_node: - root_node = self.root - - if not element_name.startswith('//'): - element_name = '//' + element_name + if not node: + node = self.root - elements = root_node.xpath('%s ' % element_name, namespaces=self.namespaces) - for element in elements: - parent = element.getparent() - parent.remove(element) + node.remove_elements(name) def attributes_list(self, elem): # convert a list of attribute tags into list of tuples