parse_dict() improvements
[sfa.git] / sfa / util / xml.py
1 #!/usr/bin/python 
2 from lxml import etree
3 from StringIO import StringIO
4
5 from sfa.util.faults import InvalidXML
6
7 class XpathFilter:
8     @staticmethod
9     def xpath(filter={}):
10         xpath = ""
11         if filter:
12             filter_list = []
13             for (key, value) in filter.items():
14                 if key == 'text':
15                     key = 'text()'
16                 else:
17                     key = '@'+key
18                 if isinstance(value, str):
19                     filter_list.append('%s="%s"' % (key, value))
20                 elif isinstance(value, list):
21                     filter_list.append('contains("%s", %s)' % (' '.join(map(str, value)), key))
22             if filter_list:
23                 xpath = ' and '.join(filter_list)
24                 xpath = '[' + xpath + ']'
25         return xpath
26
27 class XML:
28  
29     def __init__(self, xml=None):
30         self.root = None
31         self.namespaces = None
32         self.default_namespace = None
33         self.schema = None
34         if isinstance(xml, basestring):
35             self.parse_xml(xml)
36         elif isinstance(xml, etree._ElementTree):
37             self.root = xml.getroot()
38         elif isinstance(xml, etree._Element):
39             self.root = xml 
40
41     def parse_xml(self, xml):
42         """
43         parse rspec into etree
44         """
45         parser = etree.XMLParser(remove_blank_text=True)
46         try:
47             tree = etree.parse(xml, parser)
48         except IOError:
49             # 'rspec' file doesnt exist. 'rspec' is proably an xml string
50             try:
51                 tree = etree.parse(StringIO(xml), parser)
52             except Exception, e:
53                 raise InvalidXML(str(e))
54         self.root = tree.getroot()
55         # set namespaces map
56         self.namespaces = dict(self.root.nsmap)
57         if 'default' not in self.namespaces and None in self.namespaces: 
58             self.namespaces['default'] = self.namespaces[None]
59         # If the 'None' exist, then it's pointing to the default namespace. This makes 
60         # it hard for us to write xpath queries for the default naemspace because lxml 
61         # wont understand a None prefix. We will just associate the default namespeace 
62         # with a key named 'default'.     
63         if None in self.namespaces:
64             default_namespace = self.namespaces.pop(None)
65             self.namespaces['default'] = default_namespace
66
67         # set schema 
68         for key in self.root.attrib.keys():
69             if key.endswith('schemaLocation'):
70                 # schema location should be at the end of the list
71                 schema_parts  = self.root.attrib[key].split(' ')
72                 self.schema = schema_parts[1]    
73                 namespace, schema  = schema_parts[0], schema_parts[1]
74                 break
75
76     def parse_dict(self, d, root_tag_name='xml', element = None):
77         if element is None: 
78             if self.root is None:
79                 self.parse_xml('<%s/>' % root_tag_name)
80             element = self.root
81
82         if 'text' in d:
83             text = d.pop('text')
84             element.text = text
85
86         # handle repeating fields
87         for (key, value) in d.items():
88             if isinstance(value, list):
89                 value = d.pop(key)
90                 for val in value:
91                     if isinstance(val, dict):
92                         child_element = etree.SubElement(element, key)
93                         self.parse_dict(val, key, child_element)
94                     elif isinstance(val, basestring):
95                         child_element = etree.SubElement(element, key).text = val
96                         
97             elif isinstance(value, int):
98                 d[key] = unicode(d[key])  
99             elif value is None:
100                 d.pop(key)          
101              
102         element.attrib.update(d)
103
104     def validate(self, schema):
105         """
106         Validate against rng schema
107         """
108         relaxng_doc = etree.parse(schema)
109         relaxng = etree.RelaxNG(relaxng_doc)
110         if not relaxng(self.root):
111             error = relaxng.error_log.last_error
112             message = "%s (line %s)" % (error.message, error.line)
113             raise InvalidXML(message)
114         return True
115
116     def xpath(self, xpath, namespaces=None):
117         if not namespaces:
118             namespaces = self.namespaces
119         return self.root.xpath(xpath, namespaces=namespaces)
120
121     def set(self, key, value):
122         return self.root.set(key, value)
123
124     def add_attribute(self, elem, name, value):
125         """
126         Add attribute to specified etree element    
127         """
128         opt = etree.SubElement(elem, name)
129         opt.text = value
130
131     def add_element(self, name, attrs={}, parent=None, text=""):
132         """
133         Generic wrapper around etree.SubElement(). Adds an element to 
134         specified parent node. Adds element to root node is parent is 
135         not specified. 
136         """
137         if parent == None:
138             parent = self.root
139         element = etree.SubElement(parent, name)
140         if text:
141             element.text = text
142         if isinstance(attrs, dict):
143             for attr in attrs:
144                 element.set(attr, attrs[attr])  
145         return element
146
147     def remove_attribute(self, elem, name, value):
148         """
149         Removes an attribute from an element
150         """
151         if elem is not None:
152             opts = elem.iterfind(name)
153             if opts is not None:
154                 for opt in opts:
155                     if opt.text == value:
156                         elem.remove(opt)
157
158     def remove_element(self, element_name, root_node = None):
159         """
160         Removes all occurences of an element from the tree. Start at 
161         specified root_node if specified, otherwise start at tree's root.   
162         """
163         if not root_node:
164             root_node = self.root
165
166         if not element_name.startswith('//'):
167             element_name = '//' + element_name
168
169         elements = root_node.xpath('%s ' % element_name, namespaces=self.namespaces)
170         for element in elements:
171             parent = element.getparent()
172             parent.remove(element)
173
174     def attributes_list(self, elem):
175         # convert a list of attribute tags into list of tuples
176         # (tagnme, text_value)
177         opts = []
178         if elem is not None:
179             for e in elem:
180                 opts.append((e.tag, str(e.text).strip()))
181         return opts
182
183     def get_element_attributes(self, elem=None, depth=0):
184         if elem == None:
185             elem = self.root_node
186         if not hasattr(elem, 'attrib'):
187             # this is probably not an element node with attribute. could be just and an
188             # attribute, return it
189             return elem
190         attrs = dict(elem.attrib)
191         attrs['text'] = str(elem.text).strip()
192         attrs['parent'] = elem.getparent()
193         if isinstance(depth, int) and depth > 0:
194             for child_elem in list(elem):
195                 key = str(child_elem.tag)
196                 if key not in attrs:
197                     attrs[key] = [self.get_element_attributes(child_elem, depth-1)]
198                 else:
199                     attrs[key].append(self.get_element_attributes(child_elem, depth-1))
200         else:
201             attrs['child_nodes'] = list(elem)
202         return attrs
203
204     def merge(self, in_xml):
205         pass
206
207     def __str__(self):
208         return self.toxml()
209
210     def toxml(self):
211         return etree.tostring(self.root, encoding='UTF-8', pretty_print=True)  
212     
213     def todict(self, elem=None):
214         if elem is None:
215             elem = self.root
216         d = {}
217         d.update(elem.attrib)
218         d['text'] = elem.text
219         for child in elem.iterchildren():
220             if child.tag not in d:
221                 d[child.tag] = []
222             d[child.tag].append(self.todict(child))
223         return d            
224         
225     def save(self, filename):
226         f = open(filename, 'w')
227         f.write(self.toxml())
228         f.close()
229
230 # no RSpec in scope 
231 #if __name__ == '__main__':
232 #    rspec = RSpec('/tmp/resources.rspec')
233 #    print rspec
234