added merge_rspecs() method
[sfa.git] / sfa / util / rspec.py
1 ### $Id$
2 ### $URL$
3
4 import sys
5 import pprint
6 import os
7 import httplib
8 from xml.dom import minidom
9 from types import StringTypes, ListType
10 from lxml import etree
11 from StringIO import StringIO
12
13 def merge_rspecs(rspecs):
14     """
15     Merge merge a set of RSpecs into 1 RSpec, and return the result.
16     rspecs must be a valid RSpec string or list of rspec strings. 
17     """
18     if not rspecs or not isinstance(rspecs, list):
19         return rspecs
20     
21     rspec = None
22     for tmp_rspec in rspecs:
23         try:
24             tree = etree.parse(StringIO(tmp_rspec))
25         except etree.XMLSyntaxError:
26             # consider failing silently here
27             message = str(agg_rspec) + ": " + str(sys.exc_info()[1])
28             raise InvalidRSpec(message)
29
30         root = tree.getroot()
31         if root.get("type") in ["SFA"]:
32             if rspec == None:
33                 rspec = root
34             else:
35                 for network in root.iterfind("./network"):
36                     rspec.append(deepcopy(network))
37                 for request in root.iterfind("./request"):
38                     rspec.append(deepcopy(request))    
39     return etree.tostring(rspec, xml_declaration=True, pretty_print=True)
40         
41
42
43 class RSpec:
44
45     def __init__(self, xml = None, xsd = None, NSURL = None):
46         '''
47         Class to manipulate RSpecs.  Reads and parses rspec xml into python dicts
48         and reads python dicts and writes rspec xml
49
50         self.xsd = # Schema.  Can be local or remote file.
51         self.NSURL = # If schema is remote, Name Space URL to query (full path minus filename)
52         self.rootNode = # root of the DOM
53         self.dict = # dict of the RSpec.
54         self.schemaDict = {} # dict of the Schema
55         '''
56  
57         self.xsd = xsd
58         self.rootNode = None
59         self.dict = {}
60         self.schemaDict = {}
61         self.NSURL = NSURL 
62         if xml:
63             if type(xml) == file:
64                 self.parseFile(xml)
65             if type(xml) in StringTypes:
66                 self.parseString(xml)
67             self.dict = self.toDict() 
68         if xsd:
69             self._parseXSD(self.NSURL + self.xsd)
70
71
72     def _getText(self, nodelist):
73         rc = ""
74         for node in nodelist:
75             if node.nodeType == node.TEXT_NODE:
76                 rc = rc + node.data
77         return rc
78   
79     # The rspec is comprised of 2 parts, and 1 reference:
80     # attributes/elements describe individual resources
81     # complexTypes are used to describe a set of attributes/elements
82     # complexTypes can include a reference to other complexTypes.
83   
84   
85     def _getName(self, node):
86         '''Gets name of node. If tag has no name, then return tag's localName'''
87         name = None
88         if not node.nodeName.startswith("#"):
89             if node.localName:
90                 name = node.localName
91             elif node.attributes.has_key("name"):
92                 name = node.attributes.get("name").value
93         return name     
94  
95  
96     # Attribute.  {name : nameofattribute, {items: values})
97     def _attributeDict(self, attributeDom):
98         '''Traverse single attribute node.  Create a dict {attributename : {name: value,}]}'''
99         node = {} # parsed dict
100         for attr in attributeDom.attributes.keys():
101             node[attr] = attributeDom.attributes.get(attr).value
102         return node
103   
104  
105     def appendToDictOrCreate(self, dict, key, value):
106         if (dict.has_key(key)):
107             dict[key].append(value)
108         else:
109             dict[key]=[value]
110         return dict
111
112     def toGenDict(self, nodeDom=None, parentdict=None, siblingdict={}, parent=None):
113         """
114         convert an XML to a nested dict:
115           * Non-terminal nodes (elements with string children and attributes) are simple dictionaries
116           * Terminal nodes (the rest) are nested dictionaries
117         """
118
119         if (not nodeDom):
120             nodeDom=self.rootNode
121
122         curNodeName = nodeDom.localName
123
124         if (nodeDom.hasChildNodes()):
125             childdict={}
126             for attribute in nodeDom.attributes.keys():
127                 childdict = self.appendToDictOrCreate(childdict, attribute, nodeDom.getAttribute(attribute))
128             for child in nodeDom.childNodes[:-1]:
129                 if (child.nodeValue):
130                     siblingdict = self.appendToDictOrCreate(siblingdict, curNodeName, child.nodeValue)
131                 else:
132                     childdict = self.toGenDict(child, None, childdict, curNodeName)
133
134             child = nodeDom.childNodes[-1]
135             if (child.nodeValue):
136                 siblingdict = self.appendToDictOrCreate(siblingdict, curNodeName, child.nodeValue)
137                 if (childdict):
138                     siblingdict = self.appendToDictOrCreate(siblingdict, curNodeName, childdict)
139             else:
140                 siblingdict = self.toGenDict(child, siblingdict, childdict, curNodeName)
141         else:
142             childdict={}
143             for attribute in nodeDom.attributes.keys():
144                 childdict = self.appendToDictOrCreate(childdict, attribute, nodeDom.getAttribute(attribute))
145
146             self.appendToDictOrCreate(siblingdict, curNodeName, childdict)
147             
148         if (parentdict is not None):
149             parentdict = self.appendToDictOrCreate(parentdict, parent, siblingdict)
150             return parentdict
151         else:
152             return siblingdict
153
154
155
156     def toDict(self, nodeDom = None):
157         """
158         convert this rspec to a dict and return it.
159         """
160         node = {}
161         if not nodeDom:
162              nodeDom = self.rootNode
163   
164         elementName = nodeDom.nodeName
165         if elementName and not elementName.startswith("#"):
166             # attributes have tags and values.  get {tag: value}, else {type: value}
167             node[elementName] = self._attributeDict(nodeDom)
168             # resolve the child nodes.
169             if nodeDom.hasChildNodes():
170                 for child in nodeDom.childNodes:
171                     childName = self._getName(child)
172                     
173                     # skip null children
174                     if not childName: continue
175
176                     # initialize the possible array of children
177                     if not node[elementName].has_key(childName): node[elementName][childName] = []
178
179                     if isinstance(child, minidom.Text):
180                         # add if data is not empty
181                         if child.data.strip():
182                             node[elementName][childName].append(nextchild.data)
183                     elif child.hasChildNodes() and isinstance(child.childNodes[0], minidom.Text):
184                         for nextchild in child.childNodes:  
185                             node[elementName][childName].append(nextchild.data)
186                     else:
187                         childdict = self.toDict(child)
188                         for value in childdict.values():
189                             node[elementName][childName].append(value)
190
191         return node
192
193   
194     def toxml(self):
195         """
196         convert this rspec to an xml string and return it.
197         """
198         return self.rootNode.toxml()
199
200   
201     def toprettyxml(self):
202         """
203         print this rspec in xml in a pretty format.
204         """
205         return self.rootNode.toprettyxml()
206
207   
208     def __removeWhitespaceNodes(self, parent):
209         for child in list(parent.childNodes):
210             if child.nodeType == minidom.Node.TEXT_NODE and child.data.strip() == '':
211                 parent.removeChild(child)
212             else:
213                 self.__removeWhitespaceNodes(child)
214
215     def parseFile(self, filename):
216         """
217         read a local xml file and store it as a dom object.
218         """
219         dom = minidom.parse(filename)
220         self.__removeWhitespaceNodes(dom)
221         self.rootNode = dom.childNodes[0]
222
223
224     def parseString(self, xml):
225         """
226         read an xml string and store it as a dom object.
227         """
228         dom = minidom.parseString(xml)
229         self.__removeWhitespaceNodes(dom)
230         self.rootNode = dom.childNodes[0]
231
232  
233     def _httpGetXSD(self, xsdURI):
234         # split the URI into relevant parts
235         host = xsdURI.split("/")[2]
236         if xsdURI.startswith("https"):
237             conn = httplib.HTTPSConnection(host,
238                 httplib.HTTPSConnection.default_port)
239         elif xsdURI.startswith("http"):
240             conn = httplib.HTTPConnection(host,
241                 httplib.HTTPConnection.default_port)
242         conn.request("GET", xsdURI)
243         # If we can't download the schema, raise an exception
244         r1 = conn.getresponse()
245         if r1.status != 200: 
246             raise Exception
247         return r1.read().replace('\n', '').replace('\t', '').strip() 
248
249
250     def _parseXSD(self, xsdURI):
251         """
252         Download XSD from URL, or if file, read local xsd file and set
253         schemaDict.
254         
255         Since the schema definiton is a global namespace shared by and
256         agreed upon by others, this should probably be a URL.  Check
257         for URL, download xsd, parse, or if local file, use that.
258         """
259         schemaDom = None
260         if xsdURI.startswith("http"):
261             try: 
262                 schemaDom = minidom.parseString(self._httpGetXSD(xsdURI))
263             except Exception, e:
264                 # logging.debug("%s: web file not found" % xsdURI)
265                 # logging.debug("Using local file %s" % self.xsd")
266                 print e
267                 print "Can't find %s on the web. Continuing." % xsdURI
268         if not schemaDom:
269             if os.path.exists(xsdURI):
270                 # logging.debug("using local copy.")
271                 print "Using local %s" % xsdURI
272                 schemaDom = minidom.parse(xsdURI)
273             else:
274                 raise Exception("Can't find xsd locally")
275         self.schemaDict = self.toDict(schemaDom.childNodes[0])
276
277
278     def dict2dom(self, rdict, include_doc = False):
279         """
280         convert a dict object into a dom object.
281         """
282      
283         def elementNode(tagname, rd):
284             element = minidom.Element(tagname)
285             for key in rd.keys():
286                 if isinstance(rd[key], StringTypes) or isinstance(rd[key], int):
287                     element.setAttribute(key, str(rd[key]))
288                 elif isinstance(rd[key], dict):
289                     child = elementNode(key, rd[key])
290                     element.appendChild(child)
291                 elif isinstance(rd[key], list):
292                     for item in rd[key]:
293                         if isinstance(item, dict):
294                             child = elementNode(key, item)
295                             element.appendChild(child)
296                         elif isinstance(item, StringTypes) or isinstance(item, int):
297                             child = minidom.Element(key)
298                             text = minidom.Text()
299                             text.data = item
300                             child.appendChild(text)
301                             element.appendChild(child) 
302             return element
303         
304         # Minidom does not allow documents to have more then one
305         # child, but elements may have many children. Because of
306         # this, the document's root node will be the first key/value
307         # pair in the dictionary.  
308         node = elementNode(rdict.keys()[0], rdict.values()[0])
309         if include_doc:
310             rootNode = minidom.Document()
311             rootNode.appendChild(node)
312         else:
313             rootNode = node
314         return rootNode
315
316  
317     def parseDict(self, rdict, include_doc = True):
318         """
319         Convert a dictionary into a dom object and store it.
320         """
321         self.rootNode = self.dict2dom(rdict, include_doc).childNodes[0]
322  
323  
324     def getDictsByTagName(self, tagname, dom = None):
325         """
326         Search the dom for all elements with the specified tagname
327         and return them as a list of dicts
328         """
329         if not dom:
330             dom = self.rootNode
331         dicts = []
332         doms = dom.getElementsByTagName(tagname)
333         dictlist = [self.toDict(d) for d in doms]
334         for item in dictlist:
335             for value in item.values():
336                 dicts.append(value)
337         return dicts
338
339     def getDictByTagNameValue(self, tagname, value, dom = None):
340         """
341         Search the dom for the first element with the specified tagname
342         and value and return it as a dict.
343         """
344         tempdict = {}
345         if not dom:
346             dom = self.rootNode
347         dicts = self.getDictsByTagName(tagname, dom)
348         
349         for rdict in dicts:
350             if rdict.has_key('name') and rdict['name'] in [value]:
351                 return rdict
352               
353         return tempdict
354
355
356     def filter(self, tagname, attribute, blacklist = [], whitelist = [], dom = None):
357         """
358         Removes all elements where:
359         1. tagname matches the element tag
360         2. attribute matches the element attribte
361         3. attribute value is in valuelist  
362         """
363
364         tempdict = {}
365         if not dom:
366             dom = self.rootNode
367        
368         if dom.localName in [tagname] and dom.attributes.has_key(attribute):
369             if whitelist and dom.attributes.get(attribute).value not in whitelist:
370                 dom.parentNode.removeChild(dom)
371             if blacklist and dom.attributes.get(attribute).value in blacklist:
372                 dom.parentNode.removeChild(dom)
373            
374         if dom.hasChildNodes():
375             for child in dom.childNodes:
376                 self.filter(tagname, attribute, blacklist, whitelist, child) 
377
378
379     def merge(self, rspecs, tagname, dom=None):
380         """
381         Merge this rspec with the requested rspec based on the specified 
382         starting tag name. The start tag (and all of its children) will be merged  
383         """
384         tempdict = {}
385         if not dom:
386             dom = self.rootNode
387
388         whitelist = []
389         blacklist = []
390             
391         if dom.localName in [tagname] and dom.attributes.has_key(attribute):
392             if whitelist and dom.attributes.get(attribute).value not in whitelist:
393                 dom.parentNode.removeChild(dom)
394             if blacklist and dom.attributes.get(attribute).value in blacklist:
395                 dom.parentNode.removeChild(dom)
396
397         if dom.hasChildNodes():
398             for child in dom.childNodes:
399                 self.filter(tagname, attribute, blacklist, whitelist, child) 
400
401     def validateDicts(self):
402         types = {
403             'EInt' : int,
404             'EString' : str,
405             'EByteArray' : list,
406             'EBoolean' : bool,
407             'EFloat' : float,
408             'EDate' : date}
409
410
411     def pprint(self, r = None, depth = 0):
412         """
413         Pretty print the dict
414         """
415         line = ""
416         if r == None: r = self.dict
417         # Set the dept
418         for tab in range(0,depth): line += "    "
419         # check if it's nested
420         if type(r) == dict:
421             for i in r.keys():
422                 print line + "%s:" % i
423                 self.pprint(r[i], depth + 1)
424         elif type(r) in (tuple, list):
425             for j in r: self.pprint(j, depth + 1)
426         # not nested so just print.
427         else:
428             print line + "%s" %  r
429     
430
431
432 class RecordSpec(RSpec):
433
434     root_tag = 'record'
435     def parseDict(self, rdict, include_doc = False):
436         """
437         Convert a dictionary into a dom object and store it.
438         """
439         self.rootNode = self.dict2dom(rdict, include_doc)
440
441     def dict2dom(self, rdict, include_doc = False):
442         record_dict = rdict
443         if not len(rdict.keys()) == 1:
444             record_dict = {self.root_tag : rdict}
445         return RSpec.dict2dom(self, record_dict, include_doc)
446
447         
448 # vim:ts=4:expandtab
449