0ef71d96b06ab740185639d98193f34cfc26cb4b
[sfa.git] / sfa / util / xml.py
1 #!/usr/bin/python 
2 from lxml import etree
3 from StringIO import StringIO
4
5 from sfa.util.faults import InvalidXML
6
7 class XpathFilter:
8     @staticmethod
9     def xpath(filter={}):
10         xpath = ""
11         if filter:
12             filter_list = []
13             for (key, value) in filter.items():
14                 if key == 'text':
15                     key = 'text()'
16                 else:
17                     key = '@'+key
18                 if isinstance(value, str):
19                     filter_list.append('%s="%s"' % (key, value))
20                 elif isinstance(value, list):
21                     filter_list.append('contains("%s", %s)' % (' '.join(map(str, value)), key))
22             if filter_list:
23                 xpath = ' and '.join(filter_list)
24                 xpath = '[' + xpath + ']'
25         return xpath
26
27 class XML:
28  
29     def __init__(self, xml=None):
30         self.root = None
31         self.namespaces = None
32         self.default_namespace = None
33         self.schema = None
34         if isinstance(xml, basestring):
35             self.parse_xml(xml)
36         elif isinstance(xml, etree._ElementTree):
37             self.root = xml.getroot()
38         elif isinstance(xml, etree._Element):
39             self.root = xml 
40
41     def parse_xml(self, xml):
42         """
43         parse rspec into etree
44         """
45         parser = etree.XMLParser(remove_blank_text=True)
46         try:
47             tree = etree.parse(xml, parser)
48         except IOError:
49             # 'rspec' file doesnt exist. 'rspec' is proably an xml string
50             try:
51                 tree = etree.parse(StringIO(xml), parser)
52             except Exception, e:
53                 raise InvalidXML(str(e))
54         self.root = tree.getroot()
55         # set namespaces map
56         self.namespaces = dict(self.root.nsmap)
57         if 'default' not in self.namespaces and None in self.namespaces: 
58             # If the 'None' exist, then it's pointing to the default namespace. This makes 
59             # it hard for us to write xpath queries for the default naemspace because lxml 
60             # wont understand a None prefix. We will just associate the default namespeace 
61             # with a key named 'default'.     
62             self.namespaces['default'] = self.namespaces[None]
63         else:
64             self.namespaces['default'] = 'default' 
65
66         # set schema 
67         for key in self.root.attrib.keys():
68             if key.endswith('schemaLocation'):
69                 # schema location should be at the end of the list
70                 schema_parts  = self.root.attrib[key].split(' ')
71                 self.schema = schema_parts[1]    
72                 namespace, schema  = schema_parts[0], schema_parts[1]
73                 break
74
75     def parse_dict(self, d, root_tag_name='xml', element = None):
76         if element is None: 
77             if self.root is None:
78                 self.parse_xml('<%s/>' % root_tag_name)
79             element = self.root
80
81         if 'text' in d:
82             text = d.pop('text')
83             element.text = text
84
85         # handle repeating fields
86         for (key, value) in d.items():
87             if isinstance(value, list):
88                 value = d.pop(key)
89                 for val in value:
90                     if isinstance(val, dict):
91                         child_element = etree.SubElement(element, key)
92                         self.parse_dict(val, key, child_element)
93                     elif isinstance(val, basestring):
94                         child_element = etree.SubElement(element, key).text = val
95                         
96             elif isinstance(value, int):
97                 d[key] = unicode(d[key])  
98             elif value is None:
99                 d.pop(key)
100
101         # element.attrib.update will explode if DateTimes are in the
102         # dcitionary.
103         d=d.copy()
104         for k in d.keys():
105             if (type(d[k]) != str) and (type(d[k]) != unicode):
106                 del d[k]
107
108         element.attrib.update(d)
109
110     def validate(self, schema):
111         """
112         Validate against rng schema
113         """
114         relaxng_doc = etree.parse(schema)
115         relaxng = etree.RelaxNG(relaxng_doc)
116         if not relaxng(self.root):
117             error = relaxng.error_log.last_error
118             message = "%s (line %s)" % (error.message, error.line)
119             raise InvalidXML(message)
120         return True
121
122     def xpath(self, xpath, namespaces=None):
123         if not namespaces:
124             namespaces = self.namespaces
125         return self.root.xpath(xpath, namespaces=namespaces)
126
127     def set(self, key, value):
128         return self.root.set(key, value)
129
130     def add_attribute(self, elem, name, value):
131         """
132         Add attribute to specified etree element    
133         """
134         opt = etree.SubElement(elem, name)
135         opt.text = value
136
137     def add_element(self, name, attrs={}, parent=None, text=""):
138         """
139         Generic wrapper around etree.SubElement(). Adds an element to 
140         specified parent node. Adds element to root node is parent is 
141         not specified. 
142         """
143         if parent == None:
144             parent = self.root
145         element = etree.SubElement(parent, name)
146         if text:
147             element.text = text
148         if isinstance(attrs, dict):
149             for attr in attrs:
150                 element.set(attr, attrs[attr])  
151         return element
152
153     def remove_attribute(self, elem, name, value):
154         """
155         Removes an attribute from an element
156         """
157         if elem is not None:
158             opts = elem.iterfind(name)
159             if opts is not None:
160                 for opt in opts:
161                     if opt.text == value:
162                         elem.remove(opt)
163
164     def remove_element(self, element_name, root_node = None):
165         """
166         Removes all occurences of an element from the tree. Start at 
167         specified root_node if specified, otherwise start at tree's root.   
168         """
169         if not root_node:
170             root_node = self.root
171
172         if not element_name.startswith('//'):
173             element_name = '//' + element_name
174
175         elements = root_node.xpath('%s ' % element_name, namespaces=self.namespaces)
176         for element in elements:
177             parent = element.getparent()
178             parent.remove(element)
179
180     def attributes_list(self, elem):
181         # convert a list of attribute tags into list of tuples
182         # (tagnme, text_value)
183         opts = []
184         if elem is not None:
185             for e in elem:
186                 opts.append((e.tag, str(e.text).strip()))
187         return opts
188
189     def get_element_attributes(self, elem=None, depth=0):
190         if elem == None:
191             elem = self.root_node
192         if not hasattr(elem, 'attrib'):
193             # this is probably not an element node with attribute. could be just and an
194             # attribute, return it
195             return elem
196         attrs = dict(elem.attrib)
197         attrs['text'] = str(elem.text).strip()
198         attrs['parent'] = elem.getparent()
199         if isinstance(depth, int) and depth > 0:
200             for child_elem in list(elem):
201                 key = str(child_elem.tag)
202                 if key not in attrs:
203                     attrs[key] = [self.get_element_attributes(child_elem, depth-1)]
204                 else:
205                     attrs[key].append(self.get_element_attributes(child_elem, depth-1))
206         else:
207             attrs['child_nodes'] = list(elem)
208         return attrs
209
210     def merge(self, in_xml):
211         pass
212
213     def __str__(self):
214         return self.toxml()
215
216     def toxml(self):
217         return etree.tostring(self.root, encoding='UTF-8', pretty_print=True)  
218     
219     # XXX smbaker, for record.load_from_string
220     def todict(self, elem=None):
221         if elem is None:
222             elem = self.root
223         d = {}
224         d.update(elem.attrib)
225         d['text'] = elem.text
226         for child in elem.iterchildren():
227             if child.tag not in d:
228                 d[child.tag] = []
229             d[child.tag].append(self.todict2(child))
230
231         if len(d)==1 and ("text" in d):
232             d = d["text"]
233
234         return d
235         
236     def save(self, filename):
237         f = open(filename, 'w')
238         f.write(self.toxml())
239         f.close()
240
241 # no RSpec in scope 
242 #if __name__ == '__main__':
243 #    rspec = RSpec('/tmp/resources.rspec')
244 #    print rspec
245