acf7864c54f26567cd4dd545961c9dd42b08eb60
[sfa.git] / sfa / managers / slice_manager_pl.py
1
2 import sys
3 import time,datetime
4 from StringIO import StringIO
5 from types import StringTypes
6 from copy import deepcopy
7 from copy import copy
8 from lxml import etree
9
10 from sfa.util.sfalogging import sfa_logger
11 from sfa.util.rspecHelper import merge_rspecs
12 from sfa.util.xrn import Xrn, urn_to_hrn, hrn_to_urn
13 from sfa.util.plxrn import hrn_to_pl_slicename
14 from sfa.util.rspec import *
15 from sfa.util.specdict import *
16 from sfa.util.faults import *
17 from sfa.util.record import SfaRecord
18 from sfa.rspecs.pg_rspec import PGRSpec
19 from sfa.rspecs.sfa_rspec import SfaRSpec
20 from sfa.rspecs.rspec_converter import RSpecConverter
21 from sfa.rspecs.rspec_parser import parse_rspec    
22 from sfa.util.policy import Policy
23 from sfa.util.prefixTree import prefixTree
24 from sfa.util.sfaticket import *
25 from sfa.trust.credential import Credential
26 from sfa.util.threadmanager import ThreadManager
27 import sfa.util.xmlrpcprotocol as xmlrpcprotocol     
28 import sfa.plc.peers as peers
29 from sfa.util.version import version_core
30 from sfa.rspecs.rspec_version import RSpecVersion
31 from sfa.rspecs.pl_rspec_version import supported_rspecs
32 from sfa.util.callids import Callids
33
# We have specialized xmlrpclib.ServerProxy to remember the input url.
# OTOH it's not clear we only ever deal with such XMLRPCServerProxy instances,
# so fall back to the stock ServerProxy internals when the attribute is absent.
def get_serverproxy_url(server):
    """Return the URL a server proxy talks to.

    Prefers the ``url`` attribute set by the specialized proxy class. When that
    attribute does not exist, the address is rebuilt from the name-mangled
    private attributes of a plain ``xmlrpclib.ServerProxy``.

    Only a missing attribute (AttributeError) triggers the fallback; any other
    error raised while reading ``server.url`` propagates to the caller instead
    of being silently swallowed.
    """
    try:
        return server.url
    except AttributeError:
        sfa_logger().warning("GetVersion, falling back to xmlrpclib.ServerProxy internals")
        # NOTE(review): relies on xmlrpclib private attributes; the host part
        # may not include the URL scheme -- confirm before relying on the format.
        return server._ServerProxy__host + server._ServerProxy__handler
42
def GetVersion(api):
    """Describe this slice manager and the aggregates it federates with.

    Builds the 'interface', 'hrn', 'urn' and 'peers' entries, adds the
    supported rspec descriptions and hands everything to version_core.
    'peers' maps every aggregate hrn in api.aggregates to its URL. A local
    aggregate (same hrn as this slice manager) is added afterwards, with
    'localhost' in its URL replaced by the hostname version_core reported.
    """
    # peers explicitly in aggregates.xml -- the local one is handled below
    peer_urls = {}
    for (peer_hrn, proxy) in api.aggregates.iteritems():
        if peer_hrn == api.hrn:
            continue
        peer_urls[peer_hrn] = get_serverproxy_url(proxy)

    own_xrn = Xrn(api.hrn)
    details = {
        'interface': 'slicemgr',
        'hrn': own_xrn.get_hrn(),
        'urn': own_xrn.get_urn(),
        'peers': peer_urls,
    }
    details.update(supported_rspecs)
    sm_version = version_core(details)

    if api.hrn in api.aggregates:
        # the local aggregate's URL needs 'localhost' resolved to our hostname
        local_url = get_serverproxy_url(api.aggregates[api.hrn])
        sm_version['peers'][api.hrn] = local_url.replace('localhost', sm_version['hostname'])
    return sm_version
59
def CreateSliver(api, xrn, creds, rspec_str, users, call_id):
    """Create slivers for a slice on every known aggregate.

    The request rspec is forwarded unchanged to each aggregate, except the one
    that is calling us (loop prevention). The rspecs the aggregates send back
    are merged into a single SFA rspec.

    :param api: the slice manager API object (aggregates, auth, credentials)
    :param xrn: slice urn or hrn
    :param creds: caller's credentials, checked for the 'createsliver' privilege
    :param rspec_str: request rspec as an XML string
    :param users: user information forwarded as-is to the aggregates
    :param call_id: identifier used to drop duplicate calls
    :returns: the merged rspec as XML, or "" if call_id was already handled
    """

    def _CreateSliver(server, xrn, credential, rspec, users, call_id):
        # Check which API flavour the aggregate speaks before sending.
        version = server.GetVersion()
        if 'sfa' in version:
            # SFA AM/SM: just send the whole rspec and hand back its answer.
            return server.CreateSliver(xrn, credential, rspec, users, call_id)
        elif 'geni_api' in version:
            # TODO: convert the request to a PG rspec before sending; until
            # that is implemented GENI-only aggregates are skipped.
            pass
        return None

    if Callids().already_handled(call_id): return ""

    # Validate the RSpec against PlanetLab's schema -- disabled for now.
    # The schema used here needs to aggregate the PL and VINI schemas.
    # schema = "/var/www/html/schemas/pl.rng"
    request_rspec = parse_rspec(rspec_str)
    schema = None
    if schema:
        request_rspec.validate(schema)

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()

    # get the caller's hrn
    hrn, _ = urn_to_hrn(xrn)
    valid_cred = api.auth.checkCredentials(creds, 'createsliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # the same serialized request goes to every aggregate
    request_xml = request_rspec.toxml()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Don't send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = api.aggregates[aggregate]
        threads.run(_CreateSliver, server, xrn, credential, request_xml, users, call_id)

    results = threads.get_results()
    manifest = SfaRSpec()
    for result in results:
        # aggregates that were skipped (e.g. GENI-only ones) yield None
        if result:
            manifest.merge(result)
    return manifest.toxml()
108
def RenewSliver(api, xrn, creds, expiration_time, call_id):
    """Push a slice's expiration time to every aggregate.

    Returns True when call_id was already handled. Otherwise returns the
    logical 'and' of all aggregate answers, seeded with True, so the result is
    True when no aggregate was contacted.
    """
    if Callids().already_handled(call_id):
        return True

    slice_hrn, _ = urn_to_hrn(xrn)
    # validate the caller and find out who it is
    checked_cred = api.auth.checkCredentials(creds, 'renewsliver', slice_hrn)[0]
    caller_hrn = Credential(string=checked_cred).get_gid_caller().get_hrn()

    # a credential delegated to us wins over our own
    credential = api.getDelegatedCredential(creds) or api.getCredential()

    # loop prevention: never call back the caller, unless it carries our own hrn
    targets = [agg_hrn for agg_hrn in api.aggregates
               if caller_hrn != agg_hrn or agg_hrn == api.hrn]
    workers = ThreadManager()
    for agg_hrn in targets:
        workers.run(api.aggregates[agg_hrn].RenewSliver, xrn, [credential], expiration_time, call_id)

    outcome = True
    for answer in workers.get_results():
        outcome = outcome and answer
    return outcome
132
def get_ticket(api, xrn, creds, rspec, users):
    """Collect tickets for a slice from the aggregates named in an rspec.

    Every <network> element directly below the rspec root names an aggregate,
    and the whole client rspec is sent to each of them. Aggregates we do not
    know directly are looked up through our peers' get_aggregates(). The
    returned tickets are merged into one new ticket issued and signed by this
    slice manager.

    :param api: the slice manager API object
    :param xrn: slice urn or hrn
    :param creds: caller credentials, checked for the 'getticket' privilege
    :param rspec: client rspec as an XML string
    :param users: user information forwarded to the aggregates
    :returns: the merged ticket serialized with its parents
    :raises Exception: when no aggregate returned a ticket
    """
    slice_hrn, _ = urn_to_hrn(xrn)

    # get the netspecs contained within the client's rspec
    aggregate_rspecs = {}
    tree = etree.parse(StringIO(rspec))
    for element in tree.findall('./network'):
        # NOTE(review): uses the value of the element's first attribute as the
        # aggregate hrn (presumably its 'name'); confirm against the schema.
        aggregate_hrn = element.values()[0]
        aggregate_rspecs[aggregate_hrn] = rspec

    # get the caller's hrn
    valid_cred = api.auth.checkCredentials(creds, 'getticket', slice_hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()

    threads = ThreadManager()
    for (aggregate, aggregate_rspec) in aggregate_rspecs.iteritems():
        # prevent infinite loop. Don't send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = None
        if aggregate in api.aggregates:
            server = api.aggregates[aggregate]
        else:
            net_urn = hrn_to_urn(aggregate, 'authority')
            # we may have a peer that knows about this aggregate
            for agg in api.aggregates:
                target_aggs = api.aggregates[agg].get_aggregates(credential, net_urn)
                # the record must identify the aggregate AND carry the 'url'
                # we are about to read, otherwise it is of no use to us
                if not target_aggs or 'hrn' not in target_aggs[0] \
                        or 'url' not in target_aggs[0]:
                    continue
                # send the request to this address
                url = target_aggs[0]['url']
                server = xmlrpcprotocol.get_server(url, api.key_file, api.cert_file)
                # aggregate found, no need to keep looping
                break
        if server is None:
            continue
        threads.run(server.GetTicket, xrn, credential, aggregate_rspec, users)

    results = threads.get_results()

    # gather information from each ticket
    rspecs = []
    initscripts = []
    slivers = []
    object_gid = None
    for result in results:
        agg_ticket = SfaTicket(string=result)
        attrs = agg_ticket.get_attributes()
        if not object_gid:
            object_gid = agg_ticket.get_gid_object()
        rspecs.append(agg_ticket.get_rspec())
        initscripts.extend(attrs.get('initscripts', []))
        slivers.extend(attrs.get('slivers', []))

    # without any ticket there is no object gid/pubkey to build a new one from
    if object_gid is None:
        raise Exception("get_ticket: no aggregate returned a ticket for %s" % slice_hrn)

    # merge info
    attributes = {'initscripts': initscripts,
                  'slivers': slivers}
    merged_rspec = merge_rspecs(rspecs)

    # create a new ticket
    ticket = SfaTicket(subject=slice_hrn)
    ticket.set_gid_caller(api.auth.client_gid)
    ticket.set_issuer(key=api.key, subject=api.hrn)
    ticket.set_gid_object(object_gid)
    ticket.set_pubkey(object_gid.get_pubkey())
    #new_ticket.set_parent(api.auth.hierarchy.get_auth_ticket(auth_hrn))
    ticket.set_attributes(attributes)
    ticket.set_rspec(merged_rspec)
    ticket.encode()
    ticket.sign()
    return ticket.save_to_string(save_parents=True)
209
210
def DeleteSliver(api, xrn, creds, call_id):
    """Delete the slice's slivers on every known aggregate.

    Returns "" when call_id was already handled; otherwise 1, after the
    per-aggregate results have been collected from the ThreadManager.
    """
    if Callids().already_handled(call_id):
        return ""
    slice_hrn, _ = urn_to_hrn(xrn)

    # who is asking?
    authorized = api.auth.checkCredentials(creds, 'deletesliver', slice_hrn)[0]
    caller_hrn = Credential(string=authorized).get_gid_caller().get_hrn()

    # delegated credential if we have one, our own otherwise
    credential = api.getDelegatedCredential(creds) or api.getCredential()

    pool = ThreadManager()
    for agg_hrn in api.aggregates:
        # loop prevention: skip the calling aggregate, unless it has our own hrn
        bounces_back = caller_hrn == agg_hrn and agg_hrn != api.hrn
        if bounces_back:
            continue
        pool.run(api.aggregates[agg_hrn].DeleteSliver, xrn, credential, call_id)
    pool.get_results()
    return 1
232
def start_slice(api, xrn, creds):
    """Ask the aggregates to start the slice; returns 1 once results are in.

    The aggregate whose hrn equals the caller's is skipped (loop prevention)
    unless that hrn is also this slice manager's own hrn.
    """
    slice_hrn, _ = urn_to_hrn(xrn)

    cred_string = api.auth.checkCredentials(creds, 'startslice', slice_hrn)[0]
    caller_hrn = Credential(string=cred_string).get_gid_caller().get_hrn()

    # a credential delegated to us wins over our own
    credential = api.getDelegatedCredential(creds) or api.getCredential()

    workers = ThreadManager()
    for agg_hrn in api.aggregates:
        if caller_hrn != agg_hrn or agg_hrn == api.hrn:
            workers.run(api.aggregates[agg_hrn].Start, xrn, credential)
    workers.get_results()
    return 1
254  
def stop_slice(api, xrn, creds):
    """Ask the aggregates to stop the slice; returns 1 once results are in.

    The aggregate whose hrn equals the caller's is skipped (loop prevention)
    unless that hrn is also this slice manager's own hrn.
    """
    slice_hrn, _ = urn_to_hrn(xrn)
    caller_cred = api.auth.checkCredentials(creds, 'stopslice', slice_hrn)[0]
    caller_hrn = Credential(string=caller_cred).get_gid_caller().get_hrn()

    credential = api.getDelegatedCredential(creds)
    if not credential:
        # nothing delegated to us: act with our own credential
        credential = api.getCredential()

    recipients = [agg_hrn for agg_hrn in api.aggregates
                  if not (caller_hrn == agg_hrn and agg_hrn != api.hrn)]
    threads = ThreadManager()
    for agg_hrn in recipients:
        threads.run(api.aggregates[agg_hrn].Stop, xrn, credential)
    threads.get_results()
    return 1
276
def reset_slice(api, xrn):
    """Placeholder: resetting a slice is not implemented; always returns 1."""
    success = 1
    return success
282
def shutdown(api, xrn, creds):
    """Placeholder: shutting a slice down is not implemented; always returns 1."""
    success = 1
    return success
288
def status(api, xrn, creds):
    """Placeholder: slice status is not implemented; always returns 1."""
    success = 1
    return success
294
# Module-wide switch for the result caching done by ListSlices and
# ListResources; those functions only cache when api.cache is also set.
# Thierry: caching at the slicemgr level makes sense to some extent.
# Set to False to disable it.
caching = True
def ListSlices(api, creds, call_id):
    """List the slices known to the aggregates, possibly served from cache.

    The caller's credentials are validated BEFORE the cache is consulted, so a
    warm cache never hands the slice list to a caller whose credentials do
    not check out.

    :param api: the slice manager API object
    :param creds: caller credentials, checked for the 'listslices' privilege
    :param call_id: identifier used to drop duplicate calls
    :returns: concatenation of every contacted aggregate's ListSlices answer
              ([] if call_id was already handled)
    """
    if Callids().already_handled(call_id): return []

    # get the caller's hrn -- done first so cached data is only returned to
    # authorized callers.
    # NOTE(review): assumes checkCredentials raises on invalid credentials.
    valid_cred = api.auth.checkCredentials(creds, 'listslices', None)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # look in cache first
    if caching and api.cache:
        slices = api.cache.get('slices')
        if slices:
            return slices

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()
    threads = ThreadManager()
    # fetch from aggregates
    for aggregate in api.aggregates:
        # prevent infinite loop. Don't send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = api.aggregates[aggregate]
        threads.run(server.ListSlices, credential, call_id)

    # combine results, ignoring void (None/empty) answers
    slices = []
    for result in threads.get_results():
        if result:
            slices.extend(result)

    # cache the result
    if caching and api.cache:
        api.cache.add('slices', slices)

    return slices
336     return slices
337
338
339 def ListResources(api, creds, options, call_id):
340
341     if Callids().already_handled(call_id): return ""
342
343     # get slice's hrn from options
344     xrn = options.get('geni_slice_urn', '')
345     (hrn, type) = urn_to_hrn(xrn)
346
347     # get the rspec's return format from options
348     rspec_version = RSpecVersion(options.get('rspec_version', 'SFA 1'))
349     version_string = "rspec_%s" % (rspec_version.get_version_name())
350
351     # look in cache first
352     if caching and api.cache and not xrn:
353         rspec =  api.cache.get(version_string)
354         if rspec:
355             return rspec
356
357     # get the callers hrn
358     valid_cred = api.auth.checkCredentials(creds, 'listnodes', hrn)[0]
359     caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
360
361     # attempt to use delegated credential first
362     credential = api.getDelegatedCredential(creds)
363     if not credential:
364         credential = api.getCredential()
365     threads = ThreadManager()
366     for aggregate in api.aggregates:
367         # prevent infinite loop. Dont send request back to caller
368         # unless the caller is the aggregate's SM
369         if caller_hrn == aggregate and aggregate != api.hrn:
370             continue
371         # get the rspec from the aggregate
372         server = api.aggregates[aggregate]
373         my_opts = copy(options)
374         my_opts['geni_compressed'] = False
375         threads.run(server.ListResources, credential, my_opts, call_id)
376                     
377     results = threads.get_results()
378     #results.append(open('/root/protogeni.rspec', 'r').read())
379     rspec = SfaRSpec()
380     for result in results:
381         try:
382             tmp_rspec = parse_rspec(result)
383             if isinstance(tmp_rspec, SfaRSpec):
384                 rspec.merge(result)
385             elif isinstance(tmp_rspec, PGRSpec):
386                 rspec.merge(RSpecConverter.to_sfa_rspec(result))
387             else:
388                 api.logger.info("SM.ListResources: invalid aggregate rspec")                        
389         except:
390             api.logger.info("SM.ListResources: Failed to merge aggregate rspec")
391
392     # cache the result
393     if caching and api.cache and not xrn:
394         api.cache.add(version_string, rspec.toxml())
395  
396     return rspec.toxml()
397
398 # first draft at a merging SliverStatus
399 def SliverStatus(api, slice_xrn, creds, call_id):
400     if Callids().already_handled(call_id): return {}
401     # attempt to use delegated credential first
402     credential = api.getDelegatedCredential(creds)
403     if not credential:
404         credential = api.getCredential()
405     threads = ThreadManager()
406     for aggregate in api.aggregates:
407         server = api.aggregates[aggregate]
408         threads.run (server.SliverStatus, slice_xrn, credential, call_id)
409     results = threads.get_results()
410
411     # get rid of any void result - e.g. when call_id was hit where by convention we return {}
412     results = [ result for result in results if result and result['geni_resources']]
413
414     # do not try to combine if there's no result
415     if not results : return {}
416
417     # otherwise let's merge stuff
418     overall = {}
419
420     # mmh, it is expected that all results carry the same urn
421     overall['geni_urn'] = results[0]['geni_urn']
422
423     # consolidate geni_status - simple model using max on a total order
424     states = [ 'ready', 'configuring', 'failed', 'unknown' ]
425     # hash name to index
426     shash = dict ( zip ( states, range(len(states)) ) )
427     def combine_status (x,y):
428         return shash [ max (shash(x),shash(y)) ]
429     overall['geni_status'] = reduce (combine_status, [ result['geni_status'] for result in results], 'ready' )
430
431     # {'ready':0,'configuring':1,'failed':2,'unknown':3}
432     # append all geni_resources
433     overall['geni_resources'] = \
434         reduce (lambda x,y: x+y, [ result['geni_resources'] for result in results] , [])
435
436     return overall
437
438 def main():
439     r = RSpec()
440     r.parseFile(sys.argv[1])
441     rspec = r.toDict()
442     CreateSliver(None,'plc.princeton.tmacktestslice',rspec,'create-slice-tmacktestslice')
443
444 if __name__ == "__main__":
445     main()
446