solve conflicts
[sfa.git] / sfa / util / xml.py
1 #!/usr/bin/python 
2 from lxml import etree
3 from StringIO import StringIO
4
5 from sfa.util.faults import InvalidXML
6
7 class XpathFilter:
8     @staticmethod
9     def xpath(filter={}):
10         xpath = ""
11         if filter:
12             filter_list = []
13             for (key, value) in filter.items():
14                 if key == 'text':
15                     key = 'text()'
16                 else:
17                     key = '@'+key
18                 if isinstance(value, str):
19                     filter_list.append('%s="%s"' % (key, value))
20                 elif isinstance(value, list):
21                     filter_list.append('contains("%s", %s)' % (' '.join(map(str, value)), key))
22             if filter_list:
23                 xpath = ' and '.join(filter_list)
24                 xpath = '[' + xpath + ']'
25         return xpath
26
27 class XML:
28  
29     def __init__(self, xml=None):
30         self.root = None
31         self.namespaces = None
32         self.default_namespace = None
33         self.schema = None
34         if isinstance(xml, basestring):
35             self.parse_xml(xml)
36         elif isinstance(xml, etree._ElementTree):
37             self.root = xml.getroot()
38         elif isinstance(xml, etree._Element):
39             self.root = xml 
40
41     def parse_xml(self, xml):
42         """
43         parse rspec into etree
44         """
45         parser = etree.XMLParser(remove_blank_text=True)
46         try:
47             tree = etree.parse(xml, parser)
48         except IOError:
49             # 'rspec' file doesnt exist. 'rspec' is proably an xml string
50             try:
51                 tree = etree.parse(StringIO(xml), parser)
52             except Exception, e:
53                 raise InvalidXML(str(e))
54         self.root = tree.getroot()
55         # set namespaces map
56         self.namespaces = dict(self.root.nsmap)
57         if 'default' not in self.namespaces and None in self.namespaces: 
58             # If the 'None' exist, then it's pointing to the default namespace. This makes 
59             # it hard for us to write xpath queries for the default naemspace because lxml 
60             # wont understand a None prefix. We will just associate the default namespeace 
61             # with a key named 'default'.     
62             self.namespaces['default'] = self.namespaces[None]
63         else:
64             self.namespaces['default'] = 'default' 
65
66         # set schema 
67         for key in self.root.attrib.keys():
68             if key.endswith('schemaLocation'):
69                 # schema location should be at the end of the list
70                 schema_parts  = self.root.attrib[key].split(' ')
71                 self.schema = schema_parts[1]    
72                 namespace, schema  = schema_parts[0], schema_parts[1]
73                 break
74
75     def parse_dict(self, d, root_tag_name='xml', element = None):
76         if element is None: 
77             if self.root is None:
78                 self.parse_xml('<%s/>' % root_tag_name)
79             element = self.root
80
81         if 'text' in d:
82             text = d.pop('text')
83             element.text = text
84
85         # handle repeating fields
86         for (key, value) in d.items():
87             if isinstance(value, list):
88                 value = d.pop(key)
89                 for val in value:
90                     if isinstance(val, dict):
91                         child_element = etree.SubElement(element, key)
92                         self.parse_dict(val, key, child_element)
93                     elif isinstance(val, basestring):
94                         child_element = etree.SubElement(element, key).text = val
95                         
96             elif isinstance(value, int):
97                 d[key] = unicode(d[key])  
98             elif value is None:
99                 d.pop(key)          
100              
101         element.attrib.update(d)
102
103     def validate(self, schema):
104         """
105         Validate against rng schema
106         """
107         relaxng_doc = etree.parse(schema)
108         relaxng = etree.RelaxNG(relaxng_doc)
109         if not relaxng(self.root):
110             error = relaxng.error_log.last_error
111             message = "%s (line %s)" % (error.message, error.line)
112             raise InvalidXML(message)
113         return True
114
115     def xpath(self, xpath, namespaces=None):
116         if not namespaces:
117             namespaces = self.namespaces
118         return self.root.xpath(xpath, namespaces=namespaces)
119
120     def set(self, key, value):
121         return self.root.set(key, value)
122
123     def add_attribute(self, elem, name, value):
124         """
125         Add attribute to specified etree element    
126         """
127         opt = etree.SubElement(elem, name)
128         opt.text = value
129
130     def add_element(self, name, attrs={}, parent=None, text=""):
131         """
132         Generic wrapper around etree.SubElement(). Adds an element to 
133         specified parent node. Adds element to root node is parent is 
134         not specified. 
135         """
136         if parent == None:
137             parent = self.root
138         element = etree.SubElement(parent, name)
139         if text:
140             element.text = text
141         if isinstance(attrs, dict):
142             for attr in attrs:
143                 element.set(attr, attrs[attr])  
144         return element
145
146     def remove_attribute(self, elem, name, value):
147         """
148         Removes an attribute from an element
149         """
150         if elem is not None:
151             opts = elem.iterfind(name)
152             if opts is not None:
153                 for opt in opts:
154                     if opt.text == value:
155                         elem.remove(opt)
156
157     def remove_element(self, element_name, root_node = None):
158         """
159         Removes all occurences of an element from the tree. Start at 
160         specified root_node if specified, otherwise start at tree's root.   
161         """
162         if not root_node:
163             root_node = self.root
164
165         if not element_name.startswith('//'):
166             element_name = '//' + element_name
167
168         elements = root_node.xpath('%s ' % element_name, namespaces=self.namespaces)
169         for element in elements:
170             parent = element.getparent()
171             parent.remove(element)
172
173     def attributes_list(self, elem):
174         # convert a list of attribute tags into list of tuples
175         # (tagnme, text_value)
176         opts = []
177         if elem is not None:
178             for e in elem:
179                 opts.append((e.tag, str(e.text).strip()))
180         return opts
181
182     def get_element_attributes(self, elem=None, depth=0):
183         if elem == None:
184             elem = self.root_node
185         if not hasattr(elem, 'attrib'):
186             # this is probably not an element node with attribute. could be just and an
187             # attribute, return it
188             return elem
189         attrs = dict(elem.attrib)
190         attrs['text'] = str(elem.text).strip()
191         attrs['parent'] = elem.getparent()
192         if isinstance(depth, int) and depth > 0:
193             for child_elem in list(elem):
194                 key = str(child_elem.tag)
195                 if key not in attrs:
196                     attrs[key] = [self.get_element_attributes(child_elem, depth-1)]
197                 else:
198                     attrs[key].append(self.get_element_attributes(child_elem, depth-1))
199         else:
200             attrs['child_nodes'] = list(elem)
201         return attrs
202
203     def merge(self, in_xml):
204         pass
205
206     def __str__(self):
207         return self.toxml()
208
209     def toxml(self):
210         return etree.tostring(self.root, encoding='UTF-8', pretty_print=True)  
211     
212     def todict(self, elem=None):
213         if elem is None:
214             elem = self.root
215         d = {}
216         d.update(elem.attrib)
217         d['text'] = elem.text
218         for child in elem.iterchildren():
219             if child.tag not in d:
220                 d[child.tag] = []
221             d[child.tag].append(self.todict(child))
222         return d            
223         
224     def save(self, filename):
225         f = open(filename, 'w')
226         f.write(self.toxml())
227         f.close()
228
229 # no RSpec in scope 
230 #if __name__ == '__main__':
231 #    rspec = RSpec('/tmp/resources.rspec')
232 #    print rspec
233