added XPathFilter, RSpecElement classes. added find(), find_elements(), get_element_a...
[sfa.git] / sfa / rspecs / rspec.py
1 #!/usr/bin/python 
2 from lxml import etree
3 from StringIO import StringIO
4 from datetime import datetime, timedelta
5 from sfa.util.xrn import *
6 from sfa.util.plxrn import hostname_to_urn
7 from sfa.util.enumeration import Enum
8 from sfa.util.faults import SfaNotImplemented, InvalidRSpec, InvalidRSpecElement
9
10
11 class XpathFilter:
12     @staticmethod
13     def xpath(filter={}):
14         xpath = ""
15         if filter:
16             filter_list = []
17             for (key, value) in filter.items():
18                 if key == 'text':
19                     key = 'text()'
20                 else:
21                     key = '@'+key
22                 if isinstance(value, str):
23                     filter_list.append('%s="%s"' % (key, value))
24                 elif isinstance(value, list):
25                     filter_list.append('contains("%s", %s)' % (' '.join(map(str, value)), key))
26             if filter_list:
27                 xpath = ' and '.join(filter_list)
28                 xpath = '[' + xpath + ']'
29         return xpath
30
31 # recognized top level rspec elements 
32 RSpecElements = Enum('NETWORK', 'NODE', 'SLIVER', 'INTERFACE', 'LINK', 'VLINK')
33
34 class RSpecElement:
35     def __init__(self, element_type, name, path):
36         if not element_type in RSpecElements:
37             raise InvalidRSpecElement(element_type)
38         self.type = element_type
39         self.name = name
40         self.path = path     
41
42 class RSpec:
43     header = '<?xml version="1.0"?>\n'
44     template = """<RSpec></RSpec>"""
45     xml = None
46     type = None
47     version = None
48     namespaces = None
49     user_options = {}
50  
51     def __init__(self, rspec="", namespaces={}, type=None, user_options={}):
52         self.type = type
53         self.user_options = user_options
54         self.elements = {}
55         if rspec:
56             self.parse_rspec(rspec, namespaces)
57         else:
58             self.create()
59
60     def create(self):
61         """
62         Create root element
63         """
64         # eg. 2011-03-23T19:53:28Z 
65         date_format = '%Y-%m-%dT%H:%M:%SZ'
66         now = datetime.utcnow()
67         generated_ts = now.strftime(date_format)
68         expires_ts = (now + timedelta(hours=1)).strftime(date_format) 
69         self.parse_rspec(self.template, self.namespaces)
70         self.xml.set('expires', expires_ts)
71         self.xml.set('generated', generated_ts)
72     
73     def parse_rspec(self, rspec, namespaces={}):
74         """
75         parse rspec into etree
76         """
77         parser = etree.XMLParser(remove_blank_text=True)
78         try:
79             tree = etree.parse(rspec, parser)
80         except IOError:
81             # 'rspec' file doesnt exist. 'rspec' is proably an xml string
82             try:
83                 tree = etree.parse(StringIO(rspec), parser)
84             except Exception, e:
85                 raise InvalidRSpec(str(e))
86         self.xml = tree.getroot()  
87         if namespaces:
88            self.namespaces = namespaces
89
90     def validate(self, schema):
91         """
92         Validate against rng schema
93         """
94
95         relaxng_doc = etree.parse(schema)
96         relaxng = etree.RelaxNG(relaxng_doc)
97         if not relaxng(self.xml):
98             error = relaxng.error_log.last_error
99             message = "%s (line %s)" % (error.message, error.line)
100             raise InvalidRSpec(message)
101         return True
102
103     def xpath(self, xpath):
104         return self.xml.xpath(xpath, namespaces=self.namespaces)
105
106     def register_element_type(self, element_type, element_name, element_path):
107         if element_type not in RSpecElements:
108             raise InvalidRSpecElement(element_type, extra="no such element type: %s. Must specify a valid RSpecElement" % element_type)
109         self.elements[element_type] = RSpecElement(element_type, element_name, element_path)
110
111     def get_element_type(self, element_type):
112         if element_type not in self.elements:
113             msg = "ElementType %s not registerd for this rspec" % element_type
114             raise InvalidRSpecElement(element_type, extra=msg)
115         return self.elements[element_type]
116
117     def add_attribute(self, elem, name, value):
118         """
119         Add attribute to specified etree element    
120         """
121         opt = etree.SubElement(elem, name)
122         opt.text = value
123
124     def add_element(self, name, attrs={}, parent=None, text=""):
125         """
126         Generic wrapper around etree.SubElement(). Adds an element to 
127         specified parent node. Adds element to root node is parent is 
128         not specified. 
129         """
130         if parent == None:
131             parent = self.xml
132         element = etree.SubElement(parent, name)
133         if text:
134             element.text = text
135         if isinstance(attrs, dict):
136             for attr in attrs:
137                 element.set(attr, attrs[attr])  
138         return element
139
140     def remove_attribute(self, elem, name, value):
141         """
142         Removes an attribute from an element
143         """
144         if elem is not None:
145             opts = elem.iterfind(name)
146             if opts is not None:
147                 for opt in opts:
148                     if opt.text == value:
149                         elem.remove(opt)
150
151     def remove_element(self, element_name, root_node = None):
152         """
153         Removes all occurences of an element from the tree. Start at 
154         specified root_node if specified, otherwise start at tree's root.   
155         """
156         if not root_node:
157             root_node = self.xml
158
159         if not element_name.startswith('//'):
160             element_name = '//' + element_name
161
162         elements = root_node.xpath('%s ' % element_name, namespaces=self.namespaces)
163         for element in elements:
164             parent = element.getparent()
165             parent.remove(element)
166
167     def get_element_attributes(self, elem=None, depth=0):
168         if elem == None:
169             elem = self.root_node
170         attrs = dict(elem.attrib)
171         attrs['text'] = str(elem.text).strip()
172         if isinstance(depth, int) and depth > 0:
173             for child_elem in list(elem):
174                 key = str(child_elem.tag)
175                 if key not in attrs:
176                     attrs[key] = [self.get_element_attributes(child_elem, depth-1)]
177                 else:
178                     attrs[key].append(self.get_element_attributes(child_elem, depth-1))
179         else:
180             attrs['child_nodes'] = list(elem)
181         return attrs
182
183     def find(self, element_type, filter={}, depth=0):
184         elements = [self.get_element_attributes(element, depth=depth) for element in \
185                     self.find_elements(element_type, filter)]
186         return elements
187
188     def find_elements(self, element_type, filter={}):
189         """
190         search for a registered element
191         """
192         if element_type not in self.elements:
193             msg = "Unable to search for element %s in rspec, expath expression not found." % \
194                    element_type
195             raise InvalidRSpecElement(element_type, extra=msg)
196         rspec_element = self.get_element_type(element_type)
197         xpath = rspec_element.path + XpathFilter.xpath(filter)
198         return self.xpath(xpath)
199
200     def merge(self, in_rspec):
201         pass
202
203     def cleanup(self):
204         """
205         Optional method which inheriting classes can choose to implent. 
206         """
207         pass 
208
209     def _process_slivers(self, slivers):
210         """
211         Creates a dict of sliver details for each sliver host
212         
213         @param slivers a single hostname, list of hostanmes or list of dicts keys on hostname,
214         Returns a list of dicts 
215         """
216         if not isinstance(slivers, list):
217             slivers = [slivers]
218         dicts = []
219         for sliver in slivers:
220             if isinstance(sliver, dict):
221                 dicts.append(sliver)
222             elif isinstance(sliver, basestring):
223                 dicts.append({'hostname': sliver}) 
224         return dicts
225
226     def __str__(self):
227         return self.toxml()
228
229     def toxml(self, cleanup=False):
230         if cleanup:
231             self.cleanup()
232         return self.header + etree.tostring(self.xml, pretty_print=True)  
233         
234     def save(self, filename):
235         f = open(filename, 'w')
236         f.write(self.toxml())
237         f.close()
238  
239 if __name__ == '__main__':
240     rspec = RSpec('/tmp/resources.rspec')
241     print rspec
242     #rspec.register_element_type(RSpecElements.NETWORK, 'network', '//network')
243     #rspec.register_element_type(RSpecElements.NODE, 'node', '//node')
244     #print rspec.find(RSpecElements.NODE)[0]
245     #print rspec.find(RSpecElements.NODE, depth=1)[0]
246