use sfa.util.threadmanager in get_rspec() to contact aggregates in parallel instead...
[sfa.git] / sfa / managers / slice_manager_pl.py
1 ### $Id: slices.py 15842 2009-11-22 09:56:13Z anil $
2 ### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/plc/slices.py $
3
4 import datetime
5 import time
6 import traceback
7 import sys
8 from copy import deepcopy
9 from lxml import etree
10 from StringIO import StringIO
11 from types import StringTypes
12
13 from sfa.util.namespace import *
14 from sfa.util.rspec import *
15 from sfa.util.specdict import *
16 from sfa.util.faults import *
17 from sfa.util.record import SfaRecord
18 from sfa.util.policy import Policy
19 from sfa.util.prefixTree import prefixTree
20 from sfa.util.sfaticket import *
21 from sfa.util.threadmanager import ThreadManager
22 from sfa.util.debug import log
23 import sfa.plc.peers as peers
24
25 def delete_slice(api, xrn, origin_hrn=None):
26     credential = api.getCredential()
27     aggregates = api.aggregates
28     for aggregate in aggregates:
29         success = False
30         # request hash is optional so lets try the call without it
31         try:
32             aggregates[aggregate].delete_slice(credential, xrn, origin_hrn)
33             success = True
34         except:
35             print >> log, "%s" % (traceback.format_exc())
36             print >> log, "Error calling delete slice at aggregate %s" % aggregate
37     return 1
38
39 def create_slice(api, xrn, rspec, origin_hrn=None):
40     hrn, type = urn_to_hrn(xrn)
41
42     # Validate the RSpec against PlanetLab's schema --disabled for now
43     # The schema used here needs to aggregate the PL and VINI schemas
44     # schema = "/var/www/html/schemas/pl.rng"
45     schema = None
46     if schema:
47         try:
48             tree = etree.parse(StringIO(rspec))
49         except etree.XMLSyntaxError:
50             message = str(sys.exc_info()[1])
51             raise InvalidRSpec(message)
52
53         relaxng_doc = etree.parse(schema)
54         relaxng = etree.RelaxNG(relaxng_doc)
55         
56         if not relaxng(tree):
57             error = relaxng.error_log.last_error
58             message = "%s (line %s)" % (error.message, error.line)
59             raise InvalidRSpec(message)
60
61     aggs = api.aggregates
62     cred = api.getCredential()                                                 
63     for agg in aggs:
64         if agg not in [api.auth.client_cred.get_gid_caller().get_hrn()]:      
65             try:
66                 # Just send entire RSpec to each aggregate
67                 aggs[agg].create_slice(cred, xrn, rspec, origin_hrn)
68             except:
69                 print >> log, "Error creating slice %s at %s" % (hrn, agg)
70                 traceback.print_exc()
71
72     return True
73
74 def get_ticket(api, xrn, rspec, origin_hrn=None):
75     slice_hrn, type = urn_to_hrn(xrn)
76     # get the netspecs contained within the clients rspec
77     client_rspec = RSpec(xml=rspec)
78     netspecs = client_rspec.getDictsByTagName('NetSpec')
79     
80     # create an rspec for each individual rspec 
81     rspecs = {}
82     temp_rspec = RSpec()
83     for netspec in netspecs:
84         net_hrn = netspec['name']
85         resources = {'start_time': 0, 'end_time': 0 , 
86                      'network': {'NetSpec' : netspec}}
87         resourceDict = {'RSpec': resources}
88         temp_rspec.parseDict(resourceDict)
89         rspecs[net_hrn] = temp_rspec.toxml() 
90     
91     # send the rspec to the appropiate aggregate/sm
92     aggregates = api.aggregates
93     credential = api.getCredential()
94     tickets = {}
95     for net_hrn in rspecs:
96         net_urn = urn_to_hrn(net_hrn)     
97         try:
98             # if we are directly connected to the aggregate then we can just
99             # send them the request. if not, then we may be connected to an sm
100             # thats connected to the aggregate
101             if net_hrn in aggregates:
102                 ticket = aggregates[net_hrn].get_ticket(credential, xrn, \
103                             rspecs[net_hrn], origin_hrn)
104                 tickets[net_hrn] = ticket
105             else:
106                 # lets forward this rspec to a sm that knows about the network
107                 for agg in aggregates:
108                     network_found = aggregates[agg].get_aggregates(credential, net_urn)
109                     if network_found:
110                         ticket = aggregates[aggregate].get_ticket(credential, \
111                                         slice_hrn, rspecs[net_hrn], origin_hrn)
112                         tickets[aggregate] = ticket
113         except:
114             print >> log, "Error getting ticket for %(slice_hrn)s at aggregate %(net_hrn)s" % \
115                            locals()
116             
117     # create a new ticket
118     new_ticket = SfaTicket(subject = slice_hrn)
119     new_ticket.set_gid_caller(api.auth.client_gid)
120     new_ticket.set_issuer(key=api.key, subject=api.hrn)
121    
122     tmp_rspec = RSpec()
123     networks = []
124     valid_data = {
125         'timestamp': int(time.time()),
126         'initscripts': [],
127         'slivers': [] 
128     } 
129     # merge data from aggregate ticket into new ticket 
130     for agg_ticket in tickets.values():
131         # get data from this ticket
132         agg_ticket = SfaTicket(string=agg_ticket)
133         attributes = agg_ticket.get_attributes()
134         if attributes.get('initscripts', []) != None:
135             valid_data['initscripts'].extend(attributes.get('initscripts', []))
136         if attributes.get('slivers', []) != None:
137             valid_data['slivers'].extend(attributes.get('slivers', []))
138  
139         # set the object gid
140         object_gid = agg_ticket.get_gid_object()
141         new_ticket.set_gid_object(object_gid)
142         new_ticket.set_pubkey(object_gid.get_pubkey())
143
144         # build the rspec
145         tmp_rspec.parseString(agg_ticket.get_rspec())
146         networks.extend([{'NetSpec': tmp_rspec.getDictsByTagName('NetSpec')}])
147     
148     #new_ticket.set_parent(api.auth.hierarchy.get_auth_ticket(auth_hrn))
149     new_ticket.set_attributes(valid_data)
150     resources = {'networks': networks, 'start_time': 0, 'duration': 0}
151     resourceDict = {'RSpec': resources}
152     tmp_rspec.parseDict(resourceDict)
153     new_ticket.set_rspec(tmp_rspec.toxml())
154     new_ticket.encode()
155     new_ticket.sign()          
156     return new_ticket.save_to_string(save_parents=True)
157
158 def start_slice(api, xrn):
159     hrn, type = urn_to_hrn(xrn)
160     slicename = hrn_to_pl_slicename(hrn)
161     slices = api.plshell.GetSlices(api.plauth, {'name': slicename}, ['slice_id'])
162     if not slices:
163         raise RecordNotFound(hrn)
164     slice_id = slices[0]
165     attributes = api.plshell.GetSliceTags(api.plauth, {'slice_id': slice_id, 'name': 'enabled'}, ['slice_attribute_id'])
166     attribute_id = attreibutes[0]['slice_attribute_id']
167     api.plshell.UpdateSliceTag(api.plauth, attribute_id, "1" )
168
169     return 1
170  
171 def stop_slice(api, xrn):
172     hrn, type = urn_to_hrn(xrn)
173     slicename = hrn_to_pl_slicename(hrn)
174     slices = api.plshell.GetSlices(api.plauth, {'name': slicename}, ['slice_id'])
175     if not slices:
176         raise RecordNotFound(hrn)
177     slice_id = slices[0]['slice_id']
178     attributes = api.plshell.GetSliceTags(api.plauth, {'slice_id': slice_id, 'name': 'enabled'}, ['slice_attribute_id'])
179     attribute_id = attributes[0]['slice_attribute_id']
180     api.plshell.UpdateSliceTag(api.plauth, attribute_id, "0")
181     return 1
182
183 def reset_slice(api, xrn):
184     # XX not implemented at this interface
185     return 1
186
187 def get_slices(api):
188     # look in cache first
189     if api.cache:
190         slices = api.cache.get('slices')
191         if slices:
192             return slices    
193
194     # fetch from aggregates
195     slices = []
196     credential = api.getCredential()
197     for aggregate in api.aggregates:
198         try:
199             tmp_slices = api.aggregates[aggregate].get_slices(credential)
200             slices.extend(tmp_slices)
201         except:
202             print >> log, "%s" % (traceback.format_exc())
203             print >> log, "Error calling slices at aggregate %(aggregate)s" % locals()
204
205     # cache the result
206     if api.cache:
207         api.cache.add('slices', slices)
208
209     return slices
210  
211 def get_rspec(api, xrn=None, origin_hrn=None):
212     # look in cache first 
213     if api.cache and not xrn:
214         rspec =  api.cache.get('nodes')
215         if rspec:
216             return rspec
217
218     hrn, type = urn_to_hrn(xrn)
219     rspec = None
220     aggs = api.aggregates
221     cred = api.getCredential()
222     threads = ThreadManager()
223     for agg in aggs:
224         if agg not in [api.auth.client_cred.get_gid_caller().get_hrn()]:      
225                 # get the rspec from the aggregate
226                 #agg_rspec = aggs[agg].get_resources(cred, xrn, origin_hrn)
227                 threads.run(aggs[agg].get_resources, cred, xrn, origin_hrn)
228
229     results = threads.get_results()
230     # combine the rspecs into a single rspec 
231     for agg_rspec in results:
232         try:
233             tree = etree.parse(StringIO(agg_rspec))
234         except etree.XMLSyntaxError:
235             message = agg + ": " + str(sys.exc_info()[1])
236             raise InvalidRSpec(message)
237
238         root = tree.getroot()
239         if root.get("type") in ["SFA"]:
240             if rspec == None:
241                 rspec = root
242             else:
243                 for network in root.iterfind("./network"):
244                     rspec.append(deepcopy(network))
245                 for request in root.iterfind("./request"):
246                     rspec.append(deepcopy(request))
247
248     rspec =  etree.tostring(rspec, xml_declaration=True, pretty_print=True)
249     # cache the result
250     if api.cache and not xrn:
251         api.cache.add('nodes', rspec)
252  
253     return rspec
254
255 """
256 Returns the request context required by sfatables. At some point, this
257 mechanism should be changed to refer to "contexts", which is the
258 information that sfatables is requesting. But for now, we just return
259 the basic information needed in a dict.
260 """
261 def fetch_context(slice_hrn, user_hrn, contexts):
262     #slice_hrn = urn_to_hrn(slice_xrn)[0]
263     #user_hrn = urn_to_hrn(user_xrn)[0]
264     base_context = {'sfa':{'user':{'hrn':user_hrn}, 'slice':{'hrn':slice_hrn}}}
265     return base_context
266
267 def main():
268     r = RSpec()
269     r.parseFile(sys.argv[1])
270     rspec = r.toDict()
271     create_slice(None,'plc.princeton.tmacktestslice',rspec)
272
273 if __name__ == "__main__":
274     main()
275