Merge branch 'upstreammaster'
[sfa.git] / sfa / util / xml.py
1 #!/usr/bin/python 
2 from types import StringTypes
3 from lxml import etree
4 from StringIO import StringIO
5
6 from sfa.util.faults import InvalidXML
7
8 class XpathFilter:
9     @staticmethod
10     def xpath(filter={}):
11         xpath = ""
12         if filter:
13             filter_list = []
14             for (key, value) in filter.items():
15                 if key == 'text':
16                     key = 'text()'
17                 else:
18                     key = '@'+key
19                 if isinstance(value, str):
20                     filter_list.append('%s="%s"' % (key, value))
21                 elif isinstance(value, list):
22                     filter_list.append('contains("%s", %s)' % (' '.join(map(str, value)), key))
23             if filter_list:
24                 xpath = ' and '.join(filter_list)
25                 xpath = '[' + xpath + ']'
26         return xpath
27
28 class XML:
29  
30     def __init__(self, xml=None):
31         self.root = None
32         self.namespaces = None
33         self.default_namespace = None
34         self.schema = None
35         if isinstance(xml, basestring):
36             self.parse_xml(xml)
37         elif isinstance(xml, etree._ElementTree):
38             self.root = xml.getroot()
39         elif isinstance(xml, etree._Element):
40             self.root = xml 
41
42     def parse_xml(self, xml):
43         """
44         parse rspec into etree
45         """
46         parser = etree.XMLParser(remove_blank_text=True)
47         try:
48             tree = etree.parse(xml, parser)
49         except IOError:
50             # 'rspec' file doesnt exist. 'rspec' is proably an xml string
51             try:
52                 tree = etree.parse(StringIO(xml), parser)
53             except Exception, e:
54                 raise InvalidXML(str(e))
55         self.root = tree.getroot()
56         # set namespaces map
57         self.namespaces = dict(self.root.nsmap)
58         if 'default' not in self.namespaces and None in self.namespaces: 
59             # If the 'None' exist, then it's pointing to the default namespace. This makes 
60             # it hard for us to write xpath queries for the default naemspace because lxml 
61             # wont understand a None prefix. We will just associate the default namespeace 
62             # with a key named 'default'.     
63             self.namespaces['default'] = self.namespaces[None]
64         else:
65             self.namespaces['default'] = 'default' 
66
67         # set schema 
68         for key in self.root.attrib.keys():
69             if key.endswith('schemaLocation'):
70                 # schema location should be at the end of the list
71                 schema_parts  = self.root.attrib[key].split(' ')
72                 self.schema = schema_parts[1]    
73                 namespace, schema  = schema_parts[0], schema_parts[1]
74                 break
75
76     def parse_dict(self, d, root_tag_name='xml', element = None):
77         if element is None: 
78             if self.root is None:
79                 self.parse_xml('<%s/>' % root_tag_name)
80             element = self.root
81
82         if 'text' in d:
83             text = d.pop('text')
84             element.text = text
85
86         # handle repeating fields
87         for (key, value) in d.items():
88             if isinstance(value, list):
89                 value = d.pop(key)
90                 for val in value:
91                     if isinstance(val, dict):
92                         child_element = etree.SubElement(element, key)
93                         self.parse_dict(val, key, child_element)
94                     elif isinstance(val, basestring):
95                         child_element = etree.SubElement(element, key).text = val
96                         
97             elif isinstance(value, int):
98                 d[key] = unicode(d[key])  
99             elif value is None:
100                 d.pop(key)
101
102         # element.attrib.update will explode if DateTimes are in the
103         # dcitionary.
104         d=d.copy()
105         for (k,v) in d.iteritems():
106             if not isinstance(v,StringTypes): del d[k]
107         for k in d.keys():
108             if (type(d[k]) != str) and (type(d[k]) != unicode):
109                 del d[k]
110
111         element.attrib.update(d)
112
113     def validate(self, schema):
114         """
115         Validate against rng schema
116         """
117         relaxng_doc = etree.parse(schema)
118         relaxng = etree.RelaxNG(relaxng_doc)
119         if not relaxng(self.root):
120             error = relaxng.error_log.last_error
121             message = "%s (line %s)" % (error.message, error.line)
122             raise InvalidXML(message)
123         return True
124
125     def xpath(self, xpath, namespaces=None):
126         if not namespaces:
127             namespaces = self.namespaces
128         return self.root.xpath(xpath, namespaces=namespaces)
129
130     def set(self, key, value):
131         return self.root.set(key, value)
132
133     def add_attribute(self, elem, name, value):
134         """
135         Add attribute to specified etree element    
136         """
137         opt = etree.SubElement(elem, name)
138         opt.text = value
139
140     def add_element(self, name, attrs={}, parent=None, text=""):
141         """
142         Generic wrapper around etree.SubElement(). Adds an element to 
143         specified parent node. Adds element to root node is parent is 
144         not specified. 
145         """
146         if parent == None:
147             parent = self.root
148         element = etree.SubElement(parent, name)
149         if text:
150             element.text = text
151         if isinstance(attrs, dict):
152             for attr in attrs:
153                 element.set(attr, attrs[attr])  
154         return element
155
156     def remove_attribute(self, elem, name, value):
157         """
158         Removes an attribute from an element
159         """
160         if elem is not None:
161             opts = elem.iterfind(name)
162             if opts is not None:
163                 for opt in opts:
164                     if opt.text == value:
165                         elem.remove(opt)
166
167     def remove_element(self, element_name, root_node = None):
168         """
169         Removes all occurences of an element from the tree. Start at 
170         specified root_node if specified, otherwise start at tree's root.   
171         """
172         if not root_node:
173             root_node = self.root
174
175         if not element_name.startswith('//'):
176             element_name = '//' + element_name
177
178         elements = root_node.xpath('%s ' % element_name, namespaces=self.namespaces)
179         for element in elements:
180             parent = element.getparent()
181             parent.remove(element)
182
183     def attributes_list(self, elem):
184         # convert a list of attribute tags into list of tuples
185         # (tagnme, text_value)
186         opts = []
187         if elem is not None:
188             for e in elem:
189                 opts.append((e.tag, str(e.text).strip()))
190         return opts
191
192     def get_element_attributes(self, elem=None, depth=0):
193         if elem == None:
194             elem = self.root_node
195         if not hasattr(elem, 'attrib'):
196             # this is probably not an element node with attribute. could be just and an
197             # attribute, return it
198             return elem
199         attrs = dict(elem.attrib)
200         attrs['text'] = str(elem.text).strip()
201         attrs['parent'] = elem.getparent()
202         if isinstance(depth, int) and depth > 0:
203             for child_elem in list(elem):
204                 key = str(child_elem.tag)
205                 if key not in attrs:
206                     attrs[key] = [self.get_element_attributes(child_elem, depth-1)]
207                 else:
208                     attrs[key].append(self.get_element_attributes(child_elem, depth-1))
209         else:
210             attrs['child_nodes'] = list(elem)
211         return attrs
212
213     def merge(self, in_xml):
214         pass
215
216     def __str__(self):
217         return self.toxml()
218
219     def toxml(self):
220         return etree.tostring(self.root, encoding='UTF-8', pretty_print=True)  
221     
222     # XXX smbaker, for record.load_from_string
223     def todict(self, elem=None):
224         if elem is None:
225             elem = self.root
226         d = {}
227         d.update(elem.attrib)
228         d['text'] = elem.text
229         for child in elem.iterchildren():
230             if child.tag not in d:
231                 d[child.tag] = []
232             d[child.tag].append(self.todict(child))
233
234         if len(d)==1 and ("text" in d):
235             d = d["text"]
236
237         return d
238         
239     def save(self, filename):
240         f = open(filename, 'w')
241         f.write(self.toxml())
242         f.close()
243
244 # no RSpec in scope 
245 #if __name__ == '__main__':
246 #    rspec = RSpec('/tmp/resources.rspec')
247 #    print rspec
248