f46443a8425f1de7cb51a276bad9573c987a2aa4
[sfa.git] / sfa / util / xml.py
1 #!/usr/bin/python 
2 from types import StringTypes
3 from lxml import etree
4 from StringIO import StringIO
5 from sfa.util.faults import InvalidXML
6 from sfa.rspecs.elements.element import Element
7
8 # helper functions to help build xpaths
9 class XpathFilter:
10     @staticmethod
11
12     def filter_value(key, value):
13         xpath = ""    
14         if isinstance(value, str):
15             if '*' in value:
16                 value = value.replace('*', '')
17                 xpath = 'contains(%s, "%s")' % (key, value)
18             else:
19                 xpath = '%s="%s"' % (key, value)                
20         return xpath
21
22     @staticmethod
23     def xpath(filter=None):
24         if filter is None: filter={}
25         xpath = ""
26         if filter:
27             filter_list = []
28             for (key, value) in filter.items():
29                 if key == 'text':
30                     key = 'text()'
31                 else:
32                     key = '@'+key
33                 if isinstance(value, str):
34                     filter_list.append(XpathFilter.filter_value(key, value))
35                 elif isinstance(value, list):
36                     stmt = ' or '.join([XpathFilter.filter_value(key, str(val)) for val in value])
37                     filter_list.append(stmt)   
38             if filter_list:
39                 xpath = ' and '.join(filter_list)
40                 xpath = '[' + xpath + ']'
41         return xpath
42
43 # a wrapper class around lxml.etree._Element
44 # the reason why we need this one is because of the limitations
45 # we've found in xpath to address documents with multiple namespaces defined
46 # in a nutshell, we deal with xml documents that have
47 # a default namespace defined (xmlns="http://default.com/") and specific prefixes defined
48 # (xmlns:foo="http://foo.com")
49 # according to the documentation instead of writing
50 # element.xpath ( "//node/foo:subnode" ) 
51 # we'd then need to write xpaths like
52 # element.xpath ( "//{http://default.com/}node/{http://foo.com}subnode" ) 
53 # which is a real pain..
54 # So just so we can keep some reasonable programming style we need to manage the
55 # namespace map that goes with the _Element (its internal .nsmap being unmutable)
56
57 class XmlElement:
58     def __init__(self, element, namespaces):
59         self.element = element
60         self.namespaces = namespaces
61         
62     # redefine as few methods as possible
63     def xpath(self, xpath, namespaces=None):
64         if not namespaces:
65             namespaces = self.namespaces 
66         elems = self.element.xpath(xpath, namespaces=namespaces)
67         return [XmlElement(elem, namespaces) for elem in elems]
68
69     def add_element(self, tagname, **kwds):
70         element = etree.SubElement(self.element, tagname, **kwds)
71         return XmlElement(element, self.namespaces)
72
73     def append(self, elem):
74         if isinstance(elem, XmlElement):
75             self.element.append(elem.element)
76         else:
77             self.element.append(elem)
78
79     def getparent(self):
80         return XmlElement(self.element.getparent(), self.namespaces)
81
82     def get_instance(self, instance_class=None, fields=None):
83         """
84         Returns an instance (dict) of this xml element. The instance
85         holds a reference to this xml element.   
86         """
87         if fields is None: fields=[]
88         if not instance_class:
89             instance_class = Element
90         if not fields and hasattr(instance_class, 'fields'):
91             fields = instance_class.fields
92
93         if not fields:
94             instance = instance_class(self.attrib, self)
95         else:
96             instance = instance_class({}, self)
97             for field in fields:
98                 if field in self.attrib:
99                    instance[field] = self.attrib[field]  
100         return instance             
101
102     def add_instance(self, name, instance, fields=None):
103         """
104         Adds the specifed instance(s) as a child element of this xml 
105         element. 
106         """
107         if fields is None: fields=[]
108         if not fields and hasattr(instance, 'keys'):
109             fields = instance.keys()
110         elem = self.add_element(name)
111         for field in fields:
112             if field in instance and instance[field]:
113                 elem.set(field, unicode(instance[field]))
114         return elem                  
115
116     def remove_elements(self, name):
117         """
118         Removes all occurences of an element from the tree. Start at
119         specified root_node if specified, otherwise start at tree's root.
120         """
121         
122         if not element_name.startswith('//'):
123             element_name = '//' + element_name
124         elements = self.element.xpath('%s ' % name, namespaces=self.namespaces) 
125         for element in elements:
126             parent = element.getparent()
127             parent.remove(element)
128
129     def delete(self):
130         parent = self.getparent()
131         parent.remove(self)
132
133     def remove(self, element):
134         if isinstance(element, XmlElement):
135             self.element.remove(element.element)
136         else:
137             self.element.remove(element)
138
139     def set_text(self, text):
140         self.element.text = text
141     
142     # Element does not have unset ?!?
143     def unset(self, key):
144         del self.element.attrib[key]
145   
146     def toxml(self):
147         return etree.tostring(self.element, encoding='UTF-8', pretty_print=True)                    
148
149     def __str__(self):
150         return self.toxml()
151
152     ### other method calls or attribute access like .text or .tag or .get 
153     # are redirected on self.element
154     def __getattr__ (self, name):
155         if not hasattr(self.element, name):
156             raise AttributeError(name)
157         return getattr(self.element, name)
158
159 class XML:
160  
161     def __init__(self, xml=None, namespaces=None):
162         self.root = None
163         self.namespaces = namespaces
164         self.default_namespace = None
165         self.schema = None
166         if isinstance(xml, basestring):
167             self.parse_xml(xml)
168         if isinstance(xml, XmlElement):
169             self.root = xml
170             self.namespaces = xml.namespaces
171         elif isinstance(xml, etree._ElementTree) or isinstance(xml, etree._Element):
172             self.parse_xml(etree.tostring(xml))
173
174     def parse_xml(self, xml):
175         """
176         parse rspec into etree
177         """
178         parser = etree.XMLParser(remove_blank_text=True)
179         try:
180             tree = etree.parse(xml, parser)
181         except IOError:
182             # 'rspec' file doesnt exist. 'rspec' is proably an xml string
183             try:
184                 tree = etree.parse(StringIO(xml), parser)
185             except Exception as e:
186                 raise InvalidXML(str(e))
187         root = tree.getroot()
188         self.namespaces = dict(root.nsmap)
189         # set namespaces map
190         if 'default' not in self.namespaces and None in self.namespaces: 
191             # If the 'None' exist, then it's pointing to the default namespace. This makes 
192             # it hard for us to write xpath queries for the default naemspace because lxml 
193             # wont understand a None prefix. We will just associate the default namespeace 
194             # with a key named 'default'.     
195             self.namespaces['default'] = self.namespaces.pop(None)
196             
197         else:
198             self.namespaces['default'] = 'default' 
199
200         self.root = XmlElement(root, self.namespaces)
201         # set schema
202         for key in self.root.attrib.keys():
203             if key.endswith('schemaLocation'):
204                 # schemaLocation should be at the end of the list.
205                 # Use list comprehension to filter out empty strings 
206                 schema_parts  = [x for x in self.root.attrib[key].split(' ') if x]
207                 self.schema = schema_parts[1]    
208                 namespace, schema  = schema_parts[0], schema_parts[1]
209                 break
210
211     def parse_dict(self, d, root_tag_name='xml', element = None):
212         if element is None: 
213             if self.root is None:
214                 self.parse_xml('<%s/>' % root_tag_name)
215             element = self.root.element
216
217         if 'text' in d:
218             text = d.pop('text')
219             element.text = text
220
221         # handle repeating fields
222         for (key, value) in d.items():
223             if isinstance(value, list):
224                 value = d.pop(key)
225                 for val in value:
226                     if isinstance(val, dict):
227                         child_element = etree.SubElement(element, key)
228                         self.parse_dict(val, key, child_element)
229                     elif isinstance(val, basestring):
230                         child_element = etree.SubElement(element, key).text = val
231
232             elif isinstance(value, int):
233                 d[key] = unicode(d[key])
234             elif value is None:
235                 d.pop(key)
236
237         # element.attrib.update will explode if DateTimes are in the
238         # dcitionary.
239         d=d.copy()
240         # looks like iteritems won't stand side-effects
241         for k in d.keys():
242             if not isinstance(d[k],StringTypes):
243                 del d[k]
244
245         element.attrib.update(d)
246
247     def validate(self, schema):
248         """
249         Validate against rng schema
250         """
251         relaxng_doc = etree.parse(schema)
252         relaxng = etree.RelaxNG(relaxng_doc)
253         if not relaxng(self.root):
254             error = relaxng.error_log.last_error
255             message = "%s (line %s)" % (error.message, error.line)
256             raise InvalidXML(message)
257         return True
258
259     def xpath(self, xpath, namespaces=None):
260         if not namespaces:
261             namespaces = self.namespaces
262         return self.root.xpath(xpath, namespaces=namespaces)
263
264     def set(self, key, value):
265         return self.root.set(key, value)
266
267     def remove_attribute(self, name, element=None):
268         if not element:
269             element = self.root
270         element.remove_attribute(name) 
271
272     def add_element(self, *args, **kwds):
273         """
274         Wrapper around etree.SubElement(). Adds an element to 
275         specified parent node. Adds element to root node is parent is 
276         not specified. 
277         """
278         return self.root.add_element(*args, **kwds)
279
280     def remove_elements(self, name, element = None):
281         """
282         Removes all occurences of an element from the tree. Start at 
283         specified root_node if specified, otherwise start at tree's root.   
284         """
285         if not element:
286             element = self.root
287
288         element.remove_elements(name)
289
290     def add_instance(self, *args, **kwds):
291         return self.root.add_instance(*args, **kwds)
292
293     def get_instance(self, *args, **kwds):
294         return self.root.get_instnace(*args, **kwds)
295
296     def get_element_attributes(self, elem=None, depth=0):
297         if elem == None:
298             elem = self.root
299         if not hasattr(elem, 'attrib'):
300             # this is probably not an element node with attribute. could be just and an
301             # attribute, return it
302             return elem
303         attrs = dict(elem.attrib)
304         attrs['text'] = str(elem.text).strip()
305         attrs['parent'] = elem.getparent()
306         if isinstance(depth, int) and depth > 0:
307             for child_elem in list(elem):
308                 key = str(child_elem.tag)
309                 if key not in attrs:
310                     attrs[key] = [self.get_element_attributes(child_elem, depth-1)]
311                 else:
312                     attrs[key].append(self.get_element_attributes(child_elem, depth-1))
313         else:
314             attrs['child_nodes'] = list(elem)
315         return attrs
316
317     def append(self, elem):
318         return self.root.append(elem)
319
320     def iterchildren(self):
321         return self.root.iterchildren()    
322
323     def merge(self, in_xml):
324         pass
325
326     def __str__(self):
327         return self.toxml()
328
329     def toxml(self):
330         return etree.tostring(self.root.element, encoding='UTF-8', pretty_print=True)  
331     
332     # XXX smbaker, for record.load_from_string
333     def todict(self, elem=None):
334         if elem is None:
335             elem = self.root
336         d = {}
337         d.update(elem.attrib)
338         d['text'] = elem.text
339         for child in elem.iterchildren():
340             if child.tag not in d:
341                 d[child.tag] = []
342             d[child.tag].append(self.todict(child))
343
344         if len(d)==1 and ("text" in d):
345             d = d["text"]
346
347         return d
348         
349     def save(self, filename):
350         f = open(filename, 'w')
351         f.write(self.toxml())
352         f.close()
353
354 # no RSpec in scope 
355 #if __name__ == '__main__':
356 #    rspec = RSpec('/tmp/resources.rspec')
357 #    print rspec
358