X-Git-Url: http://git.onelab.eu/?a=blobdiff_plain;f=sfa%2Futil%2Fxml.py;h=bb298a3f6bf9fb799bea880ae9edb8929e165e70;hb=69fb221c274eb0b6e9f6ff6f895e5e6f90b17230;hp=89eef9e00aeb577711c64212ff48c2ca5a36ec43;hpb=becde8484a5149f51d0ab3801fa953568393464d;p=sfa.git diff --git a/sfa/util/xml.py b/sfa/util/xml.py index 89eef9e0..bb298a3f 100755 --- a/sfa/util/xml.py +++ b/sfa/util/xml.py @@ -1,11 +1,23 @@ #!/usr/bin/python +from types import StringTypes from lxml import etree from StringIO import StringIO - from sfa.util.faults import InvalidXML class XpathFilter: @staticmethod + + def filter_value(key, value): + xpath = "" + if isinstance(value, str): + if '*' in value: + value = value.replace('*', '') + xpath = 'contains(%s, "%s")' % (key, value) + else: + xpath = '%s="%s"' % (key, value) + return xpath + + @staticmethod def xpath(filter={}): xpath = "" if filter: @@ -16,27 +28,98 @@ class XpathFilter: else: key = '@'+key if isinstance(value, str): - filter_list.append('%s="%s"' % (key, value)) + filter_list.append(XpathFilter.filter_value(key, value)) elif isinstance(value, list): - filter_list.append('contains("%s", %s)' % (' '.join(map(str, value)), key)) + stmt = ' or '.join([XpathFilter.filter_value(key, str(val)) for val in value]) + filter_list.append(stmt) if filter_list: xpath = ' and '.join(filter_list) xpath = '[' + xpath + ']' return xpath +class XmlNode: + def __init__(self, node, namespaces): + self.node = node + self.text = node.text + self.namespaces = namespaces + self.attrib = node.attrib + + + def xpath(self, xpath, namespaces=None): + if not namespaces: + namespaces = self.namespaces + elems = self.node.xpath(xpath, namespaces=namespaces) + return [XmlNode(elem, namespaces) for elem in elems] + + def add_element(self, tagname, **kwds): + element = etree.SubElement(self.node, tagname, **kwds) + return XmlNode(element, self.namespaces) + + def append(self, elem): + if isinstance(elem, XmlNode): + self.node.append(elem.node) + else: + self.node.append(elem) + + def getparent(self): + return XmlNode(self.node.getparent(), self.namespaces) + + def remove_elements(self, name): + """ + Removes all occurences of an element from the tree. Start at + specified root_node if specified, otherwise start at tree's root. + """ + + if not element_name.startswith('//'): + element_name = '//' + element_name + elements = self.node.xpath('%s ' % name, namespaces=self.namespaces) + for element in elements: + parent = element.getparent() + parent.remove(element) + + def remove(self, element): + if isinstance(element, XmlNode): + self.node.remove(element.node) + else: + self.node.remove(element) + + def get(self, key, *args): + return self.node.get(key, *args) + + def items(self): return self.node.items() + + def set(self, key, value): + self.node.set(key, value) + + def set_text(self, text): + self.node.text = text + + def unset(self, key): + del self.node.attrib[key] + + def iterchildren(self): + return self.node.iterchildren() + + def toxml(self): + return etree.tostring(self.node, encoding='UTF-8', pretty_print=True) + + def __str__(self): + return self.toxml() + class XML: - def __init__(self, xml=None): + def __init__(self, xml=None, namespaces=None): self.root = None - self.namespaces = None + self.namespaces = namespaces self.default_namespace = None self.schema = None if isinstance(xml, basestring): self.parse_xml(xml) - elif isinstance(xml, etree._ElementTree): - self.root = xml.getroot() - elif isinstance(xml, etree._Element): - self.root = xml + if isinstance(xml, XmlNode): + self.root = xml + self.namespaces = xml.namespaces + elif isinstance(xml, etree._ElementTree) or isinstance(xml, etree._Element): + self.parse_xml(etree.tostring(xml)) def parse_xml(self, xml): """ @@ -51,18 +134,20 @@ class XML: tree = etree.parse(StringIO(xml), parser) except Exception, e: raise InvalidXML(str(e)) - self.root = tree.getroot() + root = tree.getroot() + self.namespaces = dict(root.nsmap) # set namespaces map - self.namespaces = dict(self.root.nsmap) if 'default' not in self.namespaces and None in self.namespaces: # If the 'None' exist, then it's pointing to the default namespace. This makes # it hard for us to write xpath queries for the default naemspace because lxml # wont understand a None prefix. We will just associate the default namespeace # with a key named 'default'. - self.namespaces['default'] = self.namespaces[None] + self.namespaces['default'] = self.namespaces.pop(None) + else: self.namespaces['default'] = 'default' + self.root = XmlNode(root, self.namespaces) # set schema for key in self.root.attrib.keys(): if key.endswith('schemaLocation'): @@ -101,8 +186,9 @@ class XML: # element.attrib.update will explode if DateTimes are in the # dcitionary. d=d.copy() + # looks like iteritems won't stand side-effects for k in d.keys(): - if (type(d[k]) != str) and (type(d[k]) != unicode): + if not isinstance(d[k],StringTypes): del d[k] element.attrib.update(d) @@ -124,58 +210,35 @@ class XML: namespaces = self.namespaces return self.root.xpath(xpath, namespaces=namespaces) - def set(self, key, value): - return self.root.set(key, value) - - def add_attribute(self, elem, name, value): - """ - Add attribute to specified etree element - """ - opt = etree.SubElement(elem, name) - opt.text = value + def set(self, key, value, node=None): + if not node: + node = self.root + return node.set(key, value) - def add_element(self, name, attrs={}, parent=None, text=""): + def remove_attribute(self, name, node=None): + if not node: + node = self.root + node.remove_attribute(name) + + def add_element(self, name, **kwds): """ - Generic wrapper around etree.SubElement(). Adds an element to + Wrapper around etree.SubElement(). Adds an element to specified parent node. Adds element to root node is parent is not specified. """ - if parent == None: - parent = self.root - element = etree.SubElement(parent, name) - if text: - element.text = text - if isinstance(attrs, dict): - for attr in attrs: - element.set(attr, attrs[attr]) - return element + parent = self.root + xmlnode = parent.add_element(name, *kwds) + return xmlnode - def remove_attribute(self, elem, name, value): - """ - Removes an attribute from an element - """ - if elem is not None: - opts = elem.iterfind(name) - if opts is not None: - for opt in opts: - if opt.text == value: - elem.remove(opt) - - def remove_element(self, element_name, root_node = None): + def remove_elements(self, name, node = None): """ Removes all occurences of an element from the tree. Start at specified root_node if specified, otherwise start at tree's root. """ - if not root_node: - root_node = self.root + if not node: + node = self.root - if not element_name.startswith('//'): - element_name = '//' + element_name - - elements = root_node.xpath('%s ' % element_name, namespaces=self.namespaces) - for element in elements: - parent = element.getparent() - parent.remove(element) + node.remove_elements(name) def attributes_list(self, elem): # convert a list of attribute tags into list of tuples @@ -207,6 +270,12 @@ class XML: attrs['child_nodes'] = list(elem) return attrs + def append(self, elem): + return self.root.append(elem) + + def iterchildren(self): + return self.root.iterchildren() + def merge(self, in_xml): pass @@ -214,8 +283,9 @@ class XML: return self.toxml() def toxml(self): - return etree.tostring(self.root, encoding='UTF-8', pretty_print=True) + return etree.tostring(self.root.node, encoding='UTF-8', pretty_print=True) + # XXX smbaker, for record.load_from_string def todict(self, elem=None): if elem is None: elem = self.root @@ -226,7 +296,11 @@ class XML: if child.tag not in d: d[child.tag] = [] d[child.tag].append(self.todict(child)) - return d + + if len(d)==1 and ("text" in d): + d = d["text"] + + return d def save(self, filename): f = open(filename, 'w')