xpath() returns sfa.util.xml.XmlNode instances. Added iterchildren() method
[sfa.git] / sfa / util / xml.py
1 #!/usr/bin/python 
2 from types import StringTypes
3 from lxml import etree
4 from StringIO import StringIO
5 from sfa.util.faults import InvalidXML
6
7 class XpathFilter:
8     @staticmethod
9
10     def filter_value(key, value):
11         xpath = ""    
12         if isinstance(value, str):
13             if '*' in value:
14                 value = value.replace('*', '')
15                 xpath = 'contains(%s, "%s")' % (key, value)
16             else:
17                 xpath = '%s="%s"' % (key, value)                
18         return xpath
19
20     @staticmethod
21     def xpath(filter={}):
22         xpath = ""
23         if filter:
24             filter_list = []
25             for (key, value) in filter.items():
26                 if key == 'text':
27                     key = 'text()'
28                 else:
29                     key = '@'+key
30                 if isinstance(value, str):
31                     filter_list.append(XpathFilter.filter_value(key, value))
32                 elif isinstance(value, list):
33                     stmt = ' or '.join([XpathFilter.filter_value(key, str(val)) for val in value])
34                     filter_list.append(stmt)   
35             if filter_list:
36                 xpath = ' and '.join(filter_list)
37                 xpath = '[' + xpath + ']'
38         return xpath
39
40 class XmlNode:
41     def __init__(self, node, namespaces):
42         self.node = node
43         self.namespaces = namespaces
44         self.attrib = node.attrib
45
46     def xpath(self, xpath, namespaces=None):
47         if not namespaces:
48             namespaces = self.namespaces 
49         elems = self.node.xpath(xpath, namespaces=namespaces)
50         return [XmlNode(elem, namespaces) for elem in elems]
51     
52     def add_element(name, *args, **kwds):
53         element = etree.SubElement(name, args, kwds)
54         return XmlNode(element, self.namespaces)
55
56     def remove_elements(name):
57         """
58         Removes all occurences of an element from the tree. Start at
59         specified root_node if specified, otherwise start at tree's root.
60         """
61         
62         if not element_name.startswith('//'):
63             element_name = '//' + element_name
64         elements = self.node.xpath('%s ' % name, namespaces=self.namespaces) 
65         for element in elements:
66             parent = element.getparent()
67             parent.remove(element)
68
69     def set(self, key, value):
70         self.node.set(key, value)
71     
72     def set_text(self, text):
73         self.node.text = text
74     
75     def unset(self, key):
76         del self.node.attrib[key]
77   
78     def iterchildren(self):
79         return self.node.iterchildren()
80      
81     def toxml(self):
82         return etree.tostring(self.node, encoding='UTF-8', pretty_print=True)                    
83
84     def __str__(self):
85         return self.toxml()
86
87 class XML:
88  
89     def __init__(self, xml=None, namespaces=None):
90         self.root = None
91         self.namespaces = namespaces
92         self.default_namespace = None
93         self.schema = None
94         if isinstance(xml, basestring):
95             self.parse_xml(xml)
96         if isinstance(xml, XmlNode):
97             self.root = xml
98             self.namespces = xml.namespaces
99         elif isinstance(xml, etree._ElementTree) or isinstance(xml, etree._Element):
100             self.parse_xml(etree.tostring(xml))
101
102     def parse_xml(self, xml):
103         """
104         parse rspec into etree
105         """
106         parser = etree.XMLParser(remove_blank_text=True)
107         try:
108             tree = etree.parse(xml, parser)
109         except IOError:
110             # 'rspec' file doesnt exist. 'rspec' is proably an xml string
111             try:
112                 tree = etree.parse(StringIO(xml), parser)
113             except Exception, e:
114                 raise InvalidXML(str(e))
115         root = tree.getroot()
116         self.namespaces = dict(root.nsmap)
117         # set namespaces map
118         if 'default' not in self.namespaces and None in self.namespaces: 
119             # If the 'None' exist, then it's pointing to the default namespace. This makes 
120             # it hard for us to write xpath queries for the default naemspace because lxml 
121             # wont understand a None prefix. We will just associate the default namespeace 
122             # with a key named 'default'.     
123             self.namespaces['default'] = self.namespaces.pop(None)
124             
125         else:
126             self.namespaces['default'] = 'default' 
127
128         self.root = XmlNode(root, self.namespaces)
129         # set schema 
130         for key in self.root.attrib.keys():
131             if key.endswith('schemaLocation'):
132                 # schema location should be at the end of the list
133                 schema_parts  = self.root.attrib[key].split(' ')
134                 self.schema = schema_parts[1]    
135                 namespace, schema  = schema_parts[0], schema_parts[1]
136                 break
137
138     def parse_dict(self, d, root_tag_name='xml', element = None):
139         if element is None: 
140             if self.root is None:
141                 self.parse_xml('<%s/>' % root_tag_name)
142             element = self.root
143
144         if 'text' in d:
145             text = d.pop('text')
146             element.text = text
147
148         # handle repeating fields
149         for (key, value) in d.items():
150             if isinstance(value, list):
151                 value = d.pop(key)
152                 for val in value:
153                     if isinstance(val, dict):
154                         child_element = etree.SubElement(element, key)
155                         self.parse_dict(val, key, child_element)
156                     elif isinstance(val, basestring):
157                         child_element = etree.SubElement(element, key).text = val
158                         
159             elif isinstance(value, int):
160                 d[key] = unicode(d[key])  
161             elif value is None:
162                 d.pop(key)
163
164         # element.attrib.update will explode if DateTimes are in the
165         # dcitionary.
166         d=d.copy()
167         # looks like iteritems won't stand side-effects
168         for k in d.keys():
169             if not isinstance(d[k],StringTypes):
170                 del d[k]
171
172         element.attrib.update(d)
173
174     def validate(self, schema):
175         """
176         Validate against rng schema
177         """
178         relaxng_doc = etree.parse(schema)
179         relaxng = etree.RelaxNG(relaxng_doc)
180         if not relaxng(self.root):
181             error = relaxng.error_log.last_error
182             message = "%s (line %s)" % (error.message, error.line)
183             raise InvalidXML(message)
184         return True
185
186     def xpath(self, xpath, namespaces=None):
187         if not namespaces:
188             namespaces = self.namespaces
189         return self.root.xpath(xpath, namespaces=namespaces)
190
191     def set(self, key, value, node=None):
192         if not node:
193             node = self.root 
194         return node.set(key, value)
195
196     def remove_attribute(self, name, node=None):
197         if not node:
198             node = self.root
199         node.remove_attribute(name) 
200         
201
202     def add_element(self, name, attrs={}, parent=None, text=""):
203         """
204         Wrapper around etree.SubElement(). Adds an element to 
205         specified parent node. Adds element to root node is parent is 
206         not specified. 
207         """
208         if parent == None:
209             parent = self.root
210         element = etree.SubElement(parent, name)
211         if text:
212             element.text = text
213         if isinstance(attrs, dict):
214             for attr in attrs:
215                 element.set(attr, attrs[attr])  
216         return XmlNode(element, self.namespaces)
217
218     def remove_elements(self, name, node = None):
219         """
220         Removes all occurences of an element from the tree. Start at 
221         specified root_node if specified, otherwise start at tree's root.   
222         """
223         if not node:
224             node = self.root
225
226         node.remove_elements(name)
227
228     def attributes_list(self, elem):
229         # convert a list of attribute tags into list of tuples
230         # (tagnme, text_value)
231         opts = []
232         if elem is not None:
233             for e in elem:
234                 opts.append((e.tag, str(e.text).strip()))
235         return opts
236
237     def get_element_attributes(self, elem=None, depth=0):
238         if elem == None:
239             elem = self.root_node
240         if not hasattr(elem, 'attrib'):
241             # this is probably not an element node with attribute. could be just and an
242             # attribute, return it
243             return elem
244         attrs = dict(elem.attrib)
245         attrs['text'] = str(elem.text).strip()
246         attrs['parent'] = elem.getparent()
247         if isinstance(depth, int) and depth > 0:
248             for child_elem in list(elem):
249                 key = str(child_elem.tag)
250                 if key not in attrs:
251                     attrs[key] = [self.get_element_attributes(child_elem, depth-1)]
252                 else:
253                     attrs[key].append(self.get_element_attributes(child_elem, depth-1))
254         else:
255             attrs['child_nodes'] = list(elem)
256         return attrs
257
258     def merge(self, in_xml):
259         pass
260
261     def __str__(self):
262         return self.toxml()
263
264     def toxml(self):
265         return etree.tostring(self.root, encoding='UTF-8', pretty_print=True)  
266     
267     # XXX smbaker, for record.load_from_string
268     def todict(self, elem=None):
269         if elem is None:
270             elem = self.root
271         d = {}
272         d.update(elem.attrib)
273         d['text'] = elem.text
274         for child in elem.iterchildren():
275             if child.tag not in d:
276                 d[child.tag] = []
277             d[child.tag].append(self.todict(child))
278
279         if len(d)==1 and ("text" in d):
280             d = d["text"]
281
282         return d
283         
284     def save(self, filename):
285         f = open(filename, 'w')
286         f.write(self.toxml())
287         f.close()
288
289 # no RSpec in scope 
290 #if __name__ == '__main__':
291 #    rspec = RSpec('/tmp/resources.rspec')
292 #    print rspec
293