moved RSpecElement and RSpecElements to rspec_element.py
[sfa.git] / sfa / rspecs / rspec.py
1 #!/usr/bin/python 
2 from lxml import etree
3 from StringIO import StringIO
4 from datetime import datetime, timedelta
5 from sfa.util.xrn import *
6 from sfa.util.plxrn import hostname_to_urn
7 from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements 
8 from sfa.util.faults import SfaNotImplemented, InvalidRSpec, InvalidRSpecElement
9
10
11 class XpathFilter:
12     @staticmethod
13     def xpath(filter={}):
14         xpath = ""
15         if filter:
16             filter_list = []
17             for (key, value) in filter.items():
18                 if key == 'text':
19                     key = 'text()'
20                 else:
21                     key = '@'+key
22                 if isinstance(value, str):
23                     filter_list.append('%s="%s"' % (key, value))
24                 elif isinstance(value, list):
25                     filter_list.append('contains("%s", %s)' % (' '.join(map(str, value)), key))
26             if filter_list:
27                 xpath = ' and '.join(filter_list)
28                 xpath = '[' + xpath + ']'
29         return xpath
30
31 class RSpec:
32     header = '<?xml version="1.0"?>\n'
33     template = """<RSpec></RSpec>"""
34     xml = None
35     type = None
36     version = None
37     namespaces = None
38     user_options = {}
39  
40     def __init__(self, rspec="", namespaces={}, type=None, user_options={}):
41         self.type = type
42         self.user_options = user_options
43         self.elements = {}
44         if rspec:
45             self.parse_rspec(rspec, namespaces)
46         else:
47             self.create()
48
49     def create(self):
50         """
51         Create root element
52         """
53         # eg. 2011-03-23T19:53:28Z 
54         date_format = '%Y-%m-%dT%H:%M:%SZ'
55         now = datetime.utcnow()
56         generated_ts = now.strftime(date_format)
57         expires_ts = (now + timedelta(hours=1)).strftime(date_format) 
58         self.parse_rspec(self.template, self.namespaces)
59         self.xml.set('expires', expires_ts)
60         self.xml.set('generated', generated_ts)
61     
62     def parse_rspec(self, rspec, namespaces={}):
63         """
64         parse rspec into etree
65         """
66         parser = etree.XMLParser(remove_blank_text=True)
67         try:
68             tree = etree.parse(rspec, parser)
69         except IOError:
70             # 'rspec' file doesnt exist. 'rspec' is proably an xml string
71             try:
72                 tree = etree.parse(StringIO(rspec), parser)
73             except Exception, e:
74                 raise InvalidRSpec(str(e))
75         self.xml = tree.getroot()  
76         if namespaces:
77            self.namespaces = namespaces
78
79     def validate(self, schema):
80         """
81         Validate against rng schema
82         """
83
84         relaxng_doc = etree.parse(schema)
85         relaxng = etree.RelaxNG(relaxng_doc)
86         if not relaxng(self.xml):
87             error = relaxng.error_log.last_error
88             message = "%s (line %s)" % (error.message, error.line)
89             raise InvalidRSpec(message)
90         return True
91
92     def xpath(self, xpath):
93         return self.xml.xpath(xpath, namespaces=self.namespaces)
94
95     def register_element_type(self, element_type, element_name, element_path):
96         if element_type not in RSpecElements:
97             raise InvalidRSpecElement(element_type, extra="no such element type: %s. Must specify a valid RSpecElement" % element_type)
98         self.elements[element_type] = RSpecElement(element_type, element_name, element_path)
99
100     def get_element_type(self, element_type):
101         if element_type not in self.elements:
102             msg = "ElementType %s not registerd for this rspec" % element_type
103             raise InvalidRSpecElement(element_type, extra=msg)
104         return self.elements[element_type]
105
106     def add_attribute(self, elem, name, value):
107         """
108         Add attribute to specified etree element    
109         """
110         opt = etree.SubElement(elem, name)
111         opt.text = value
112
113     def add_element(self, name, attrs={}, parent=None, text=""):
114         """
115         Generic wrapper around etree.SubElement(). Adds an element to 
116         specified parent node. Adds element to root node is parent is 
117         not specified. 
118         """
119         if parent == None:
120             parent = self.xml
121         element = etree.SubElement(parent, name)
122         if text:
123             element.text = text
124         if isinstance(attrs, dict):
125             for attr in attrs:
126                 element.set(attr, attrs[attr])  
127         return element
128
129     def remove_attribute(self, elem, name, value):
130         """
131         Removes an attribute from an element
132         """
133         if elem is not None:
134             opts = elem.iterfind(name)
135             if opts is not None:
136                 for opt in opts:
137                     if opt.text == value:
138                         elem.remove(opt)
139
140     def remove_element(self, element_name, root_node = None):
141         """
142         Removes all occurences of an element from the tree. Start at 
143         specified root_node if specified, otherwise start at tree's root.   
144         """
145         if not root_node:
146             root_node = self.xml
147
148         if not element_name.startswith('//'):
149             element_name = '//' + element_name
150
151         elements = root_node.xpath('%s ' % element_name, namespaces=self.namespaces)
152         for element in elements:
153             parent = element.getparent()
154             parent.remove(element)
155
156     def attributes_list(self, elem):
157         # convert a list of attribute tags into list of tuples
158         # (tagnme, text_value)
159         opts = []
160         if elem is not None:
161             for e in elem:
162                 opts.append((e.tag, str(e.text).strip()))
163         return opts
164
165     def get_element_attributes(self, elem=None, depth=0):
166         if elem == None:
167             elem = self.root_node
168         attrs = dict(elem.attrib)
169         attrs['text'] = str(elem.text).strip()
170         if isinstance(depth, int) and depth > 0:
171             for child_elem in list(elem):
172                 key = str(child_elem.tag)
173                 if key not in attrs:
174                     attrs[key] = [self.get_element_attributes(child_elem, depth-1)]
175                 else:
176                     attrs[key].append(self.get_element_attributes(child_elem, depth-1))
177         else:
178             attrs['child_nodes'] = list(elem)
179         return attrs
180
181     def get(self, element_type, filter={}, depth=0):
182         elements = [self.get_element_attributes(element, depth=depth) for element in \
183                     self.get_elements(element_type, filter)]
184         return elements
185
186     def get_elements(self, element_type, filter={}):
187         """
188         search for a registered element
189         """
190         if element_type not in self.elements:
191             msg = "Unable to search for element %s in rspec, expath expression not found." % \
192                    element_type
193             raise InvalidRSpecElement(element_type, extra=msg)
194         rspec_element = self.get_element_type(element_type)
195         xpath = rspec_element.path + XpathFilter.xpath(filter)
196         return self.xpath(xpath)
197
198     def merge(self, in_rspec):
199         pass
200
201     def cleanup(self):
202         """
203         Optional method which inheriting classes can choose to implent. 
204         """
205         pass 
206
207     def _process_slivers(self, slivers):
208         """
209         Creates a dict of sliver details for each sliver host
210         
211         @param slivers a single hostname, list of hostanmes or list of dicts keys on hostname,
212         Returns a list of dicts 
213         """
214         if not isinstance(slivers, list):
215             slivers = [slivers]
216         dicts = []
217         for sliver in slivers:
218             if isinstance(sliver, dict):
219                 dicts.append(sliver)
220             elif isinstance(sliver, basestring):
221                 dicts.append({'hostname': sliver}) 
222         return dicts
223
224     def __str__(self):
225         return self.toxml()
226
227     def toxml(self, cleanup=False):
228         if cleanup:
229             self.cleanup()
230         return self.header + etree.tostring(self.xml, pretty_print=True)  
231         
232     def save(self, filename):
233         f = open(filename, 'w')
234         f.write(self.toxml())
235         f.close()
236  
237 if __name__ == '__main__':
238     rspec = RSpec('/tmp/resources.rspec')
239     print rspec
240     #rspec.register_element_type(RSpecElements.NETWORK, 'network', '//network')
241     #rspec.register_element_type(RSpecElements.NODE, 'node', '//node')
242     #print rspec.find(RSpecElements.NODE)[0]
243     #print rspec.find(RSpecElements.NODE, depth=1)[0]
244