do not depend on types.StringTypes anymore
[sfa.git] / sfa / util / xml.py
1 #!/usr/bin/python 
2 from lxml import etree
3 from StringIO import StringIO
4 from sfa.util.faults import InvalidXML
5 from sfa.rspecs.elements.element import Element
6
7 from sfa.util.py23 import StringType
8
9 # helper functions to help build xpaths
10 class XpathFilter:
11     @staticmethod
12
13     def filter_value(key, value):
14         xpath = ""    
15         if isinstance(value, str):
16             if '*' in value:
17                 value = value.replace('*', '')
18                 xpath = 'contains(%s, "%s")' % (key, value)
19             else:
20                 xpath = '%s="%s"' % (key, value)                
21         return xpath
22
23     @staticmethod
24     def xpath(filter=None):
25         if filter is None: filter={}
26         xpath = ""
27         if filter:
28             filter_list = []
29             for (key, value) in filter.items():
30                 if key == 'text':
31                     key = 'text()'
32                 else:
33                     key = '@'+key
34                 if isinstance(value, str):
35                     filter_list.append(XpathFilter.filter_value(key, value))
36                 elif isinstance(value, list):
37                     stmt = ' or '.join([XpathFilter.filter_value(key, str(val)) for val in value])
38                     filter_list.append(stmt)   
39             if filter_list:
40                 xpath = ' and '.join(filter_list)
41                 xpath = '[' + xpath + ']'
42         return xpath
43
44 # a wrapper class around lxml.etree._Element
45 # the reason why we need this one is because of the limitations
46 # we've found in xpath to address documents with multiple namespaces defined
47 # in a nutshell, we deal with xml documents that have
48 # a default namespace defined (xmlns="http://default.com/") and specific prefixes defined
49 # (xmlns:foo="http://foo.com")
50 # according to the documentation instead of writing
51 # element.xpath ( "//node/foo:subnode" ) 
52 # we'd then need to write xpaths like
53 # element.xpath ( "//{http://default.com/}node/{http://foo.com}subnode" ) 
54 # which is a real pain..
55 # So just so we can keep some reasonable programming style we need to manage the
56 # namespace map that goes with the _Element (its internal .nsmap being unmutable)
57
58 class XmlElement:
59     def __init__(self, element, namespaces):
60         self.element = element
61         self.namespaces = namespaces
62         
63     # redefine as few methods as possible
64     def xpath(self, xpath, namespaces=None):
65         if not namespaces:
66             namespaces = self.namespaces 
67         elems = self.element.xpath(xpath, namespaces=namespaces)
68         return [XmlElement(elem, namespaces) for elem in elems]
69
70     def add_element(self, tagname, **kwds):
71         element = etree.SubElement(self.element, tagname, **kwds)
72         return XmlElement(element, self.namespaces)
73
74     def append(self, elem):
75         if isinstance(elem, XmlElement):
76             self.element.append(elem.element)
77         else:
78             self.element.append(elem)
79
80     def getparent(self):
81         return XmlElement(self.element.getparent(), self.namespaces)
82
83     def get_instance(self, instance_class=None, fields=None):
84         """
85         Returns an instance (dict) of this xml element. The instance
86         holds a reference to this xml element.   
87         """
88         if fields is None: fields=[]
89         if not instance_class:
90             instance_class = Element
91         if not fields and hasattr(instance_class, 'fields'):
92             fields = instance_class.fields
93
94         if not fields:
95             instance = instance_class(self.attrib, self)
96         else:
97             instance = instance_class({}, self)
98             for field in fields:
99                 if field in self.attrib:
100                    instance[field] = self.attrib[field]  
101         return instance             
102
103     def add_instance(self, name, instance, fields=None):
104         """
105         Adds the specifed instance(s) as a child element of this xml 
106         element. 
107         """
108         if fields is None: fields=[]
109         if not fields and hasattr(instance, 'keys'):
110             fields = instance.keys()
111         elem = self.add_element(name)
112         for field in fields:
113             if field in instance and instance[field]:
114                 elem.set(field, unicode(instance[field]))
115         return elem                  
116
117     def remove_elements(self, name):
118         """
119         Removes all occurences of an element from the tree. Start at
120         specified root_node if specified, otherwise start at tree's root.
121         """
122         
123         if not element_name.startswith('//'):
124             element_name = '//' + element_name
125         elements = self.element.xpath('%s ' % name, namespaces=self.namespaces) 
126         for element in elements:
127             parent = element.getparent()
128             parent.remove(element)
129
130     def delete(self):
131         parent = self.getparent()
132         parent.remove(self)
133
134     def remove(self, element):
135         if isinstance(element, XmlElement):
136             self.element.remove(element.element)
137         else:
138             self.element.remove(element)
139
140     def set_text(self, text):
141         self.element.text = text
142     
143     # Element does not have unset ?!?
144     def unset(self, key):
145         del self.element.attrib[key]
146   
147     def toxml(self):
148         return etree.tostring(self.element, encoding='UTF-8', pretty_print=True)                    
149
150     def __str__(self):
151         return self.toxml()
152
153     ### other method calls or attribute access like .text or .tag or .get 
154     # are redirected on self.element
155     def __getattr__ (self, name):
156         if not hasattr(self.element, name):
157             raise AttributeError(name)
158         return getattr(self.element, name)
159
160 class XML:
161  
162     def __init__(self, xml=None, namespaces=None):
163         self.root = None
164         self.namespaces = namespaces
165         self.default_namespace = None
166         self.schema = None
167         if isinstance(xml, basestring):
168             self.parse_xml(xml)
169         if isinstance(xml, XmlElement):
170             self.root = xml
171             self.namespaces = xml.namespaces
172         elif isinstance(xml, etree._ElementTree) or isinstance(xml, etree._Element):
173             self.parse_xml(etree.tostring(xml))
174
175     def parse_xml(self, xml):
176         """
177         parse rspec into etree
178         """
179         parser = etree.XMLParser(remove_blank_text=True)
180         try:
181             tree = etree.parse(xml, parser)
182         except IOError:
183             # 'rspec' file doesnt exist. 'rspec' is proably an xml string
184             try:
185                 tree = etree.parse(StringIO(xml), parser)
186             except Exception as e:
187                 raise InvalidXML(str(e))
188         root = tree.getroot()
189         self.namespaces = dict(root.nsmap)
190         # set namespaces map
191         if 'default' not in self.namespaces and None in self.namespaces: 
192             # If the 'None' exist, then it's pointing to the default namespace. This makes 
193             # it hard for us to write xpath queries for the default naemspace because lxml 
194             # wont understand a None prefix. We will just associate the default namespeace 
195             # with a key named 'default'.     
196             self.namespaces['default'] = self.namespaces.pop(None)
197             
198         else:
199             self.namespaces['default'] = 'default' 
200
201         self.root = XmlElement(root, self.namespaces)
202         # set schema
203         for key in self.root.attrib.keys():
204             if key.endswith('schemaLocation'):
205                 # schemaLocation should be at the end of the list.
206                 # Use list comprehension to filter out empty strings 
207                 schema_parts  = [x for x in self.root.attrib[key].split(' ') if x]
208                 self.schema = schema_parts[1]    
209                 namespace, schema  = schema_parts[0], schema_parts[1]
210                 break
211
212     def parse_dict(self, d, root_tag_name='xml', element = None):
213         if element is None: 
214             if self.root is None:
215                 self.parse_xml('<%s/>' % root_tag_name)
216             element = self.root.element
217
218         if 'text' in d:
219             text = d.pop('text')
220             element.text = text
221
222         # handle repeating fields
223         for (key, value) in d.items():
224             if isinstance(value, list):
225                 value = d.pop(key)
226                 for val in value:
227                     if isinstance(val, dict):
228                         child_element = etree.SubElement(element, key)
229                         self.parse_dict(val, key, child_element)
230                     elif isinstance(val, basestring):
231                         child_element = etree.SubElement(element, key).text = val
232
233             elif isinstance(value, int):
234                 d[key] = unicode(d[key])
235             elif value is None:
236                 d.pop(key)
237
238         # element.attrib.update will explode if DateTimes are in the
239         # dcitionary.
240         d=d.copy()
241         # looks like iteritems won't stand side-effects
242         for k in d.keys():
243             if not isinstance(d[k], StringType):
244                 del d[k]
245
246         element.attrib.update(d)
247
248     def validate(self, schema):
249         """
250         Validate against rng schema
251         """
252         relaxng_doc = etree.parse(schema)
253         relaxng = etree.RelaxNG(relaxng_doc)
254         if not relaxng(self.root):
255             error = relaxng.error_log.last_error
256             message = "%s (line %s)" % (error.message, error.line)
257             raise InvalidXML(message)
258         return True
259
260     def xpath(self, xpath, namespaces=None):
261         if not namespaces:
262             namespaces = self.namespaces
263         return self.root.xpath(xpath, namespaces=namespaces)
264
265     def set(self, key, value):
266         return self.root.set(key, value)
267
268     def remove_attribute(self, name, element=None):
269         if not element:
270             element = self.root
271         element.remove_attribute(name) 
272
273     def add_element(self, *args, **kwds):
274         """
275         Wrapper around etree.SubElement(). Adds an element to 
276         specified parent node. Adds element to root node is parent is 
277         not specified. 
278         """
279         return self.root.add_element(*args, **kwds)
280
281     def remove_elements(self, name, element = None):
282         """
283         Removes all occurences of an element from the tree. Start at 
284         specified root_node if specified, otherwise start at tree's root.   
285         """
286         if not element:
287             element = self.root
288
289         element.remove_elements(name)
290
291     def add_instance(self, *args, **kwds):
292         return self.root.add_instance(*args, **kwds)
293
294     def get_instance(self, *args, **kwds):
295         return self.root.get_instnace(*args, **kwds)
296
297     def get_element_attributes(self, elem=None, depth=0):
298         if elem == None:
299             elem = self.root
300         if not hasattr(elem, 'attrib'):
301             # this is probably not an element node with attribute. could be just and an
302             # attribute, return it
303             return elem
304         attrs = dict(elem.attrib)
305         attrs['text'] = str(elem.text).strip()
306         attrs['parent'] = elem.getparent()
307         if isinstance(depth, int) and depth > 0:
308             for child_elem in list(elem):
309                 key = str(child_elem.tag)
310                 if key not in attrs:
311                     attrs[key] = [self.get_element_attributes(child_elem, depth-1)]
312                 else:
313                     attrs[key].append(self.get_element_attributes(child_elem, depth-1))
314         else:
315             attrs['child_nodes'] = list(elem)
316         return attrs
317
318     def append(self, elem):
319         return self.root.append(elem)
320
321     def iterchildren(self):
322         return self.root.iterchildren()    
323
324     def merge(self, in_xml):
325         pass
326
327     def __str__(self):
328         return self.toxml()
329
330     def toxml(self):
331         return etree.tostring(self.root.element, encoding='UTF-8', pretty_print=True)  
332     
333     # XXX smbaker, for record.load_from_string
334     def todict(self, elem=None):
335         if elem is None:
336             elem = self.root
337         d = {}
338         d.update(elem.attrib)
339         d['text'] = elem.text
340         for child in elem.iterchildren():
341             if child.tag not in d:
342                 d[child.tag] = []
343             d[child.tag].append(self.todict(child))
344
345         if len(d)==1 and ("text" in d):
346             d = d["text"]
347
348         return d
349         
350     def save(self, filename):
351         f = open(filename, 'w')
352         f.write(self.toxml())
353         f.close()
354
355 # no RSpec in scope 
356 #if __name__ == '__main__':
357 #    rspec = RSpec('/tmp/resources.rspec')
358 #    print rspec
359