trimmed useless imports, unstarred all imports
[sfa.git] / sfa / util / xml.py
1 #!/usr/bin/python 
2 from lxml import etree
3 from StringIO import StringIO
4
5 from sfa.util.faults import InvalidXML
6
7 class XpathFilter:
8     @staticmethod
9     def xpath(filter={}):
10         xpath = ""
11         if filter:
12             filter_list = []
13             for (key, value) in filter.items():
14                 if key == 'text':
15                     key = 'text()'
16                 else:
17                     key = '@'+key
18                 if isinstance(value, str):
19                     filter_list.append('%s="%s"' % (key, value))
20                 elif isinstance(value, list):
21                     filter_list.append('contains("%s", %s)' % (' '.join(map(str, value)), key))
22             if filter_list:
23                 xpath = ' and '.join(filter_list)
24                 xpath = '[' + xpath + ']'
25         return xpath
26
27 class XML:
28  
29     def __init__(self, xml=None):
30         self.root = None
31         self.namespaces = None
32         self.default_namespace = None
33         self.schema = None
34         if isinstance(xml, basestring):
35             self.parse_xml(xml)
36         elif isinstance(xml, etree._ElementTree):
37             self.root = xml.getroot()
38         elif isinstance(xml, etree._Element):
39             self.root = xml 
40
41     def parse_xml(self, xml):
42         """
43         parse rspec into etree
44         """
45         parser = etree.XMLParser(remove_blank_text=True)
46         try:
47             tree = etree.parse(xml, parser)
48         except IOError:
49             # 'rspec' file doesnt exist. 'rspec' is proably an xml string
50             try:
51                 tree = etree.parse(StringIO(xml), parser)
52             except Exception, e:
53                 raise InvalidXML(str(e))
54         self.root = tree.getroot()
55         # set namespaces map
56         self.namespaces = dict(self.root.nsmap)
57         if 'default' not in self.namespaces and None in self.namespaces: 
58             self.namespaces['default'] = self.namespaces[None]
59         # If the 'None' exist, then it's pointing to the default namespace. This makes 
60         # it hard for us to write xpath queries for the default naemspace because lxml 
61         # wont understand a None prefix. We will just associate the default namespeace 
62         # with a key named 'default'.     
63         if None in self.namespaces:
64             default_namespace = self.namespaces.pop(None)
65             self.namespaces['default'] = default_namespace
66
67         # set schema 
68         for key in self.root.attrib.keys():
69             if key.endswith('schemaLocation'):
70                 # schema location should be at the end of the list
71                 schema_parts  = self.root.attrib[key].split(' ')
72                 self.schema = schema_parts[1]    
73                 namespace, schema  = schema_parts[0], schema_parts[1]
74                 break
75
76     def parse_dict(self, d, root_tag_name='xml', element = None):
77         if element is None: 
78             self.parse_xml('<%s/>' % root_tag_name)
79             element = self.root
80
81         if 'text' in d:
82             text = d.pop('text')
83             element.text = text
84
85         # handle repeating fields
86         for (key, value) in d.items():
87             if isinstance(value, list):
88                 value = d.pop(key)
89                 for val in value:
90                     if isinstance(val, dict):
91                         child_element = etree.SubElement(element, key)
92                         self.parse_dict(val, key, child_element) 
93         
94         element.attrib.update(d)
95
96     def validate(self, schema):
97         """
98         Validate against rng schema
99         """
100         relaxng_doc = etree.parse(schema)
101         relaxng = etree.RelaxNG(relaxng_doc)
102         if not relaxng(self.root):
103             error = relaxng.error_log.last_error
104             message = "%s (line %s)" % (error.message, error.line)
105             raise InvalidXML(message)
106         return True
107
108     def xpath(self, xpath, namespaces=None):
109         if not namespaces:
110             namespaces = self.namespaces
111         return self.root.xpath(xpath, namespaces=namespaces)
112
113     def set(self, key, value):
114         return self.root.set(key, value)
115
116     def add_attribute(self, elem, name, value):
117         """
118         Add attribute to specified etree element    
119         """
120         opt = etree.SubElement(elem, name)
121         opt.text = value
122
123     def add_element(self, name, attrs={}, parent=None, text=""):
124         """
125         Generic wrapper around etree.SubElement(). Adds an element to 
126         specified parent node. Adds element to root node is parent is 
127         not specified. 
128         """
129         if parent == None:
130             parent = self.root
131         element = etree.SubElement(parent, name)
132         if text:
133             element.text = text
134         if isinstance(attrs, dict):
135             for attr in attrs:
136                 element.set(attr, attrs[attr])  
137         return element
138
139     def remove_attribute(self, elem, name, value):
140         """
141         Removes an attribute from an element
142         """
143         if elem is not None:
144             opts = elem.iterfind(name)
145             if opts is not None:
146                 for opt in opts:
147                     if opt.text == value:
148                         elem.remove(opt)
149
150     def remove_element(self, element_name, root_node = None):
151         """
152         Removes all occurences of an element from the tree. Start at 
153         specified root_node if specified, otherwise start at tree's root.   
154         """
155         if not root_node:
156             root_node = self.root
157
158         if not element_name.startswith('//'):
159             element_name = '//' + element_name
160
161         elements = root_node.xpath('%s ' % element_name, namespaces=self.namespaces)
162         for element in elements:
163             parent = element.getparent()
164             parent.remove(element)
165
166     def attributes_list(self, elem):
167         # convert a list of attribute tags into list of tuples
168         # (tagnme, text_value)
169         opts = []
170         if elem is not None:
171             for e in elem:
172                 opts.append((e.tag, str(e.text).strip()))
173         return opts
174
175     def get_element_attributes(self, elem=None, depth=0):
176         if elem == None:
177             elem = self.root_node
178         if not hasattr(elem, 'attrib'):
179             # this is probably not an element node with attribute. could be just and an
180             # attribute, return it
181             return elem
182         attrs = dict(elem.attrib)
183         attrs['text'] = str(elem.text).strip()
184         attrs['parent'] = elem.getparent()
185         if isinstance(depth, int) and depth > 0:
186             for child_elem in list(elem):
187                 key = str(child_elem.tag)
188                 if key not in attrs:
189                     attrs[key] = [self.get_element_attributes(child_elem, depth-1)]
190                 else:
191                     attrs[key].append(self.get_element_attributes(child_elem, depth-1))
192         else:
193             attrs['child_nodes'] = list(elem)
194         return attrs
195
196     def merge(self, in_xml):
197         pass
198
199     def __str__(self):
200         return self.toxml()
201
202     def toxml(self):
203         return etree.tostring(self.root, encoding='UTF-8', pretty_print=True)  
204     
205     def todict(self, elem=None):
206         if elem is None:
207             elem = self.root
208         d = {}
209         d.update(elem.attrib)
210         d['text'] = elem.text
211         for child in elem.iterchildren():
212             if child.tag not in d:
213                 d[child.tag] = []
214             d[child.tag].append(self.todict(child))
215         return d            
216         
217     def save(self, filename):
218         f = open(filename, 'w')
219         f.write(self.toxml())
220         f.close()
221
222 # no RSpec in scope 
223 #if __name__ == '__main__':
224 #    rspec = RSpec('/tmp/resources.rspec')
225 #    print rspec
226