parse_dict now supports ints and Non values
[sfa.git] / sfa / util / xml.py
1 #!/usr/bin/python 
2 from lxml import etree
3 from StringIO import StringIO
4
5 from sfa.util.faults import InvalidXML
6
7 class XpathFilter:
8     @staticmethod
9     def xpath(filter={}):
10         xpath = ""
11         if filter:
12             filter_list = []
13             for (key, value) in filter.items():
14                 if key == 'text':
15                     key = 'text()'
16                 else:
17                     key = '@'+key
18                 if isinstance(value, str):
19                     filter_list.append('%s="%s"' % (key, value))
20                 elif isinstance(value, list):
21                     filter_list.append('contains("%s", %s)' % (' '.join(map(str, value)), key))
22             if filter_list:
23                 xpath = ' and '.join(filter_list)
24                 xpath = '[' + xpath + ']'
25         return xpath
26
27 class XML:
28  
29     def __init__(self, xml=None):
30         self.root = None
31         self.namespaces = None
32         self.default_namespace = None
33         self.schema = None
34         if isinstance(xml, basestring):
35             self.parse_xml(xml)
36         elif isinstance(xml, etree._ElementTree):
37             self.root = xml.getroot()
38         elif isinstance(xml, etree._Element):
39             self.root = xml 
40
41     def parse_xml(self, xml):
42         """
43         parse rspec into etree
44         """
45         parser = etree.XMLParser(remove_blank_text=True)
46         try:
47             tree = etree.parse(xml, parser)
48         except IOError:
49             # 'rspec' file doesnt exist. 'rspec' is proably an xml string
50             try:
51                 tree = etree.parse(StringIO(xml), parser)
52             except Exception, e:
53                 raise InvalidXML(str(e))
54         self.root = tree.getroot()
55         # set namespaces map
56         self.namespaces = dict(self.root.nsmap)
57         if 'default' not in self.namespaces and None in self.namespaces: 
58             self.namespaces['default'] = self.namespaces[None]
59         # If the 'None' exist, then it's pointing to the default namespace. This makes 
60         # it hard for us to write xpath queries for the default naemspace because lxml 
61         # wont understand a None prefix. We will just associate the default namespeace 
62         # with a key named 'default'.     
63         if None in self.namespaces:
64             default_namespace = self.namespaces.pop(None)
65             self.namespaces['default'] = default_namespace
66
67         # set schema 
68         for key in self.root.attrib.keys():
69             if key.endswith('schemaLocation'):
70                 # schema location should be at the end of the list
71                 schema_parts  = self.root.attrib[key].split(' ')
72                 self.schema = schema_parts[1]    
73                 namespace, schema  = schema_parts[0], schema_parts[1]
74                 break
75
76     def parse_dict(self, d, root_tag_name='xml', element = None):
77         if element is None: 
78             if self.root is None:
79                 self.parse_xml('<%s/>' % root_tag_name)
80             element = self.root
81
82         if 'text' in d:
83             text = d.pop('text')
84             element.text = text
85
86         # handle repeating fields
87         for (key, value) in d.items():
88             if isinstance(value, list):
89                 value = d.pop(key)
90                 for val in value:
91                     if isinstance(val, dict):
92                         child_element = etree.SubElement(element, key)
93                         self.parse_dict(val, key, child_element)
94             elif isinstance(value, int):
95                 d[key] = unicode(d[key])  
96             elif value is None:
97                 d.pop(key)          
98              
99         element.attrib.update(d)
100
101     def validate(self, schema):
102         """
103         Validate against rng schema
104         """
105         relaxng_doc = etree.parse(schema)
106         relaxng = etree.RelaxNG(relaxng_doc)
107         if not relaxng(self.root):
108             error = relaxng.error_log.last_error
109             message = "%s (line %s)" % (error.message, error.line)
110             raise InvalidXML(message)
111         return True
112
113     def xpath(self, xpath, namespaces=None):
114         if not namespaces:
115             namespaces = self.namespaces
116         return self.root.xpath(xpath, namespaces=namespaces)
117
118     def set(self, key, value):
119         return self.root.set(key, value)
120
121     def add_attribute(self, elem, name, value):
122         """
123         Add attribute to specified etree element    
124         """
125         opt = etree.SubElement(elem, name)
126         opt.text = value
127
128     def add_element(self, name, attrs={}, parent=None, text=""):
129         """
130         Generic wrapper around etree.SubElement(). Adds an element to 
131         specified parent node. Adds element to root node is parent is 
132         not specified. 
133         """
134         if parent == None:
135             parent = self.root
136         element = etree.SubElement(parent, name)
137         if text:
138             element.text = text
139         if isinstance(attrs, dict):
140             for attr in attrs:
141                 element.set(attr, attrs[attr])  
142         return element
143
144     def remove_attribute(self, elem, name, value):
145         """
146         Removes an attribute from an element
147         """
148         if elem is not None:
149             opts = elem.iterfind(name)
150             if opts is not None:
151                 for opt in opts:
152                     if opt.text == value:
153                         elem.remove(opt)
154
155     def remove_element(self, element_name, root_node = None):
156         """
157         Removes all occurences of an element from the tree. Start at 
158         specified root_node if specified, otherwise start at tree's root.   
159         """
160         if not root_node:
161             root_node = self.root
162
163         if not element_name.startswith('//'):
164             element_name = '//' + element_name
165
166         elements = root_node.xpath('%s ' % element_name, namespaces=self.namespaces)
167         for element in elements:
168             parent = element.getparent()
169             parent.remove(element)
170
171     def attributes_list(self, elem):
172         # convert a list of attribute tags into list of tuples
173         # (tagnme, text_value)
174         opts = []
175         if elem is not None:
176             for e in elem:
177                 opts.append((e.tag, str(e.text).strip()))
178         return opts
179
180     def get_element_attributes(self, elem=None, depth=0):
181         if elem == None:
182             elem = self.root_node
183         if not hasattr(elem, 'attrib'):
184             # this is probably not an element node with attribute. could be just and an
185             # attribute, return it
186             return elem
187         attrs = dict(elem.attrib)
188         attrs['text'] = str(elem.text).strip()
189         attrs['parent'] = elem.getparent()
190         if isinstance(depth, int) and depth > 0:
191             for child_elem in list(elem):
192                 key = str(child_elem.tag)
193                 if key not in attrs:
194                     attrs[key] = [self.get_element_attributes(child_elem, depth-1)]
195                 else:
196                     attrs[key].append(self.get_element_attributes(child_elem, depth-1))
197         else:
198             attrs['child_nodes'] = list(elem)
199         return attrs
200
201     def merge(self, in_xml):
202         pass
203
204     def __str__(self):
205         return self.toxml()
206
207     def toxml(self):
208         return etree.tostring(self.root, encoding='UTF-8', pretty_print=True)  
209     
210     def todict(self, elem=None):
211         if elem is None:
212             elem = self.root
213         d = {}
214         d.update(elem.attrib)
215         d['text'] = elem.text
216         for child in elem.iterchildren():
217             if child.tag not in d:
218                 d[child.tag] = []
219             d[child.tag].append(self.todict(child))
220         return d            
221         
222     def save(self, filename):
223         f = open(filename, 'w')
224         f.write(self.toxml())
225         f.close()
226
227 # no RSpec in scope 
228 #if __name__ == '__main__':
229 #    rspec = RSpec('/tmp/resources.rspec')
230 #    print rspec
231