renamed sfaticket from util/ to trust/
[sfa.git] / sfa / managers / slice_manager.py
1 #
2 import sys
3 import time,datetime
4 from StringIO import StringIO
5 from types import StringTypes
6 from copy import deepcopy
7 from copy import copy
8 from lxml import etree
9
10 from sfa.util.sfalogging import logger
11 from sfa.util.rspecHelper import merge_rspecs
12 from sfa.util.xrn import Xrn, urn_to_hrn, hrn_to_urn
13 from sfa.util.plxrn import hrn_to_pl_slicename
14 from sfa.util.specdict import *
15 from sfa.util.faults import *
16 from sfa.util.record import SfaRecord
17 from sfa.rspecs.rspec_converter import RSpecConverter
18 from sfa.client.client_helper import sfa_to_pg_users_arg
19 from sfa.rspecs.version_manager import VersionManager
20 from sfa.rspecs.rspec import RSpec 
21 from sfa.util.policy import Policy
22 from sfa.util.prefixTree import prefixTree
23 from sfa.trust.sfaticket import SfaTicket
24 from sfa.trust.credential import Credential
25 from sfa.util.threadmanager import ThreadManager
26 import sfa.util.xmlrpcprotocol as xmlrpcprotocol     
27 import sfa.plc.peers as peers
28 from sfa.util.version import version_core
29 from sfa.util.callids import Callids
30
31
32 def _call_id_supported(api, server):
33     """
34     Returns true if server support the optional call_id arg, false otherwise.
35     """
36     server_version = api.get_cached_server_version(server)
37
38     if 'sfa' in server_version:
39         code_tag = server_version['code_tag']
40         code_tag_parts = code_tag.split("-")
41
42         version_parts = code_tag_parts[0].split(".")
43         major, minor = version_parts[0:2]
44         rev = code_tag_parts[1]
45         if int(major) > 1:
46             if int(minor) > 0 or int(rev) > 20:
47                 return True
48     return False
49
# We have specialized xmlrpclib.ServerProxy to remember the input url.
# OTOH it's not clear if we're only dealing with XMLRPCServerProxy instances,
# so fall back to xmlrpclib.ServerProxy's private attributes when needed.
def get_serverproxy_url(server):
    """
    Return the URL that ``server`` talks to.

    Prefers the ``get_url()`` method of the specialized proxy.  If that call
    fails, the URL is rebuilt from xmlrpclib.ServerProxy's name-mangled
    ``__host`` and ``__handler`` attributes.

    :param server: server proxy object
    :returns: URL string
    """
    try:
        return server.get_url()
    except Exception:
        # A plain xmlrpclib.ServerProxy turns any unknown attribute into a
        # remote method, so the failure shows up when calling get_url(), not
        # as an AttributeError; hence the broad (but not bare) except.
        logger.warning("GetVersion, falling back to xmlrpclib.ServerProxy internals")
        return server._ServerProxy__host + server._ServerProxy__handler
58
def GetVersion(api):
    """
    Build this slice manager's GetVersion reply.

    The reply is what ``version_core`` returns for a dict describing the
    'slicemgr' interface: this authority's hrn/urn, the URLs of the peers in
    ``api.aggregates`` (other than ourselves), the RSpec versions known to
    the VersionManager split into advertisement and request lists, and the
    "sfa 1" version as default advertisement format.  When ``api.hrn`` itself
    appears in ``api.aggregates``, its URL is added to the peers with
    'localhost' replaced by the reported hostname.
    """
    # Peers explicitly configured in aggregates.xml, ourselves excluded.
    peer_urls = {}
    for peer_hrn, proxy in api.aggregates.iteritems():
        if peer_hrn == api.hrn:
            continue
        peer_urls[peer_hrn] = get_serverproxy_url(proxy)

    manager = VersionManager()
    ad_versions, request_versions = [], []
    for known_version in manager.versions:
        kind = known_version.content_type
        if kind in ['*', 'ad']:
            ad_versions.append(known_version.to_dict())
        if kind in ['*', 'request']:
            request_versions.append(known_version.to_dict())
    default_ad = manager.get_version("sfa 1").to_dict()

    authority = Xrn(api.hrn, 'authority+sa')
    details = {
        'interface': 'slicemgr',
        'hrn': authority.get_hrn(),
        'urn': authority.get_urn(),
        'peers': peer_urls,
        'request_rspec_versions': request_versions,
        'ad_rspec_versions': ad_versions,
        'default_ad_rspec': default_ad,
    }
    reply = version_core(details)

    # The local aggregate, if present, must not be advertised as 'localhost'.
    if api.hrn in api.aggregates:
        local_url = get_serverproxy_url(api.aggregates[api.hrn])
        reply['peers'][api.hrn] = local_url.replace('localhost', reply['hostname'])
    return reply
87
def drop_slicemgr_stats(rspec):
    """
    Remove every <statistics> element from ``rspec`` in place.

    This is best effort: any failure is logged and swallowed.

    :param rspec: RSpec-like object whose ``xml`` attribute supports
        ``xpath()`` returning nodes with ``getparent()``
    """
    try:
        for stats in rspec.xml.xpath('//statistics'):
            stats.getparent().remove(stats)
    except Exception as e:
        # There is no ``api`` in this scope, so log through the module logger
        # (referencing api.logger here raised NameError).
        logger.warning("drop_slicemgr_stats failed: %s " % (str(e)))
95
def add_slicemgr_stat(rspec, callname, aggname, elapsed, status):
    """
    Record the outcome of one aggregate call inside ``rspec``.

    Appends ``<aggregate name=.. elapsed=.. status=..>`` under the
    ``<statistics call="callname">`` element.  That element is created under
    ``rspec.xml.root`` if no such element exists yet.  This is best effort:
    any failure is logged and swallowed.

    :param rspec: RSpec object being assembled
    :param callname: API call name, e.g. "ListResources"
    :param aggname: aggregate hrn
    :param elapsed: seconds spent on the call
    :param status: "success" or "exception"
    """
    try:
        existing = rspec.xml.xpath('//statistics[@call="%s"]' % callname)
        if existing:
            stats_tag = existing[0]
        else:
            stats_tag = etree.SubElement(rspec.xml.root, "statistics", call=callname)

        etree.SubElement(stats_tag, "aggregate", name=str(aggname),
                         elapsed=str(elapsed), status=str(status))
    except Exception as e:
        # There is no ``api`` in this scope, so log through the module logger
        # (referencing api.logger here raised NameError).
        logger.warning("add_slicemgr_stat failed on  %s: %s" % (aggname, str(e)))
107
def ListResources(api, creds, options, call_id):
    """
    Fan ListResources out to every known aggregate and merge the answers.

    When ``options['geni_slice_urn']`` is set the merged result is a manifest
    for that slice; otherwise it is an advertisement, which may be served
    from / stored into ``api.cache`` (module flag ``caching``).  Per-aggregate
    timing and status are recorded in a <statistics> section.

    :param api: slice manager API object
    :param creds: caller credentials
    :param options: options dict; 'rspec_version' selects the output format.
        NOTE: 'geni_compressed' is removed from this dict in place.
    :param call_id: optional call id used to suppress duplicate requests
    :returns: merged RSpec as an XML string ("" if call_id was already handled)
    """
    version_manager = VersionManager()

    def _ListResources(aggregate, server, credential, opts, call_id):
        # Runs in a ThreadManager worker; never raises, always returns a
        # status dict so the caller can record per-aggregate statistics.
        my_opts = copy(opts)
        args = [credential, my_opts]
        tStart = time.time()
        try:
            if _call_id_supported(api, server):
                args.append(call_id)
            version = api.get_cached_server_version(server)
            # force ProtoGENI aggregates to give us a v2 RSpec
            # (my_opts is a copy, so the caller's options are untouched)
            if 'sfa' not in version:
                my_opts['rspec_version'] = version_manager.get_version('ProtoGENI 2').to_dict()
            rspec = server.ListResources(*args)
            return {"aggregate": aggregate, "rspec": rspec, "elapsed": time.time()-tStart, "status": "success"}
        except Exception:
            api.logger.log_exc("ListResources failed at %s" % (server.url))
            return {"aggregate": aggregate, "elapsed": time.time()-tStart, "status": "exception"}

    if Callids().already_handled(call_id):
        return ""

    # get slice's hrn from options
    xrn = options.get('geni_slice_urn', '')
    (hrn, _) = urn_to_hrn(xrn)
    options.pop('geni_compressed', None)

    # get the rspec's return format from options; workers only modify copies
    # of options, so this stays valid for the merge step below
    rspec_version = version_manager.get_version(options.get('rspec_version'))
    version_string = "rspec_%s" % (rspec_version.to_string())

    # look in cache first (advertisements only, never slice manifests)
    use_cache = caching and api.cache and not xrn
    if use_cache:
        rspec = api.cache.get(version_string)
        if rspec:
            return rspec

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'listnodes', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue

        # get the rspec from the aggregate
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        threads.run(_ListResources, aggregate, server, [cred], options, call_id)

    results = threads.get_results()
    if xrn:
        content_type = 'manifest'
    else:
        content_type = 'ad'
    result_version = version_manager._get_version(rspec_version.type, rspec_version.version, content_type)
    rspec = RSpec(version=result_version)
    for result in results:
        add_slicemgr_stat(rspec, "ListResources", result["aggregate"], result["elapsed"], result["status"])
        if result["status"] == "success":
            try:
                rspec.version.merge(result["rspec"])
            except Exception:
                api.logger.log_exc("SM.ListResources: Failed to merge aggregate rspec")

    rspec_xml = rspec.toxml()
    # cache the result
    if use_cache:
        api.cache.add(version_string, rspec_xml)

    return rspec_xml
187
188
def CreateSliver(api, xrn, creds, rspec_str, users, call_id):
    """
    Create slice ``xrn`` on every known aggregate and merge the manifests.

    The whole request RSpec (minus any <statistics> section) is forwarded to
    each aggregate.  Aggregates that report 'geni_api' but not 'sfa' in their
    GetVersion reply get a ProtoGENI request RSpec filtered on their own
    component manager, and the users converted by ``sfa_to_pg_users_arg``.

    :param api: slice manager API object
    :param xrn: slice urn
    :param creds: caller credentials
    :param rspec_str: request RSpec as a string
    :param users: users argument forwarded to the aggregates
    :param call_id: optional call id used to suppress duplicate requests
    :returns: merged manifest as an XML string with per-aggregate statistics
        ("" if call_id was already handled)
    """
    version_manager = VersionManager()

    def _CreateSliver(aggregate, server, xrn, credential, rspec, users, call_id):
        # Runs in a ThreadManager worker; never raises, always returns a
        # status dict so the caller can record per-aggregate statistics.
        tStart = time.time()
        try:
            # Need to call GetVersion at an aggregate to determine the supported
            # rspec type/format before calling CreateSliver at an Aggregate.
            server_version = api.get_cached_server_version(server)
            requested_users = users
            if 'sfa' not in server_version and 'geni_api' in server_version:
                # sfa aggregates support both sfa and pg rspecs, no need to convert
                # if aggregate supports sfa rspecs. otherwise convert to pg rspec
                pg_rspec = RSpec(RSpecConverter.to_pg_rspec(rspec, 'request'))
                pg_rspec.filter({'component_manager_id': server_version['urn']})
                rspec = pg_rspec.toxml()
                requested_users = sfa_to_pg_users_arg(users)
            args = [xrn, credential, rspec, requested_users]
            if _call_id_supported(api, server):
                args.append(call_id)
            rspec = server.CreateSliver(*args)
            return {"aggregate": aggregate, "rspec": rspec, "elapsed": time.time()-tStart, "status": "success"}
        except Exception:
            logger.log_exc('Something wrong in _CreateSliver with URL %s' % server.url)
            return {"aggregate": aggregate, "elapsed": time.time()-tStart, "status": "exception"}

    if Callids().already_handled(call_id):
        return ""
    # Validate the RSpec against PlanetLab's schema --disabled for now
    # The schema used here needs to aggregate the PL and VINI schemas
    # schema = "/var/www/html/schemas/pl.rng"
    rspec = RSpec(rspec_str)
    schema = None
    if schema:
        rspec.validate(schema)

    # if there is a <statistics> section, the aggregates don't care about it,
    # so delete it.
    drop_slicemgr_stats(rspec)

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()

    # get the callers hrn
    hrn, _ = urn_to_hrn(xrn)
    valid_cred = api.auth.checkCredentials(creds, 'createsliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # Every aggregate gets the same request document: serialize it once.
    request_xml = rspec.toxml()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        threads.run(_CreateSliver, aggregate, server, xrn, [cred], request_xml, users, call_id)

    results = threads.get_results()
    manifest_version = version_manager._get_version(rspec.version.type, rspec.version.version, 'manifest')
    result_rspec = RSpec(version=manifest_version)
    for result in results:
        add_slicemgr_stat(result_rspec, "CreateSliver", result["aggregate"], result["elapsed"], result["status"])
        if result["status"] == "success":
            try:
                result_rspec.version.merge(result["rspec"])
            except Exception:
                api.logger.log_exc("SM.CreateSliver: Failed to merge aggregate rspec")
    return result_rspec.toxml()
260
def RenewSliver(api, xrn, creds, expiration_time, call_id):
    """
    Renew slice ``xrn`` until ``expiration_time`` on every known aggregate.

    :param api: slice manager API object
    :param xrn: slice urn
    :param creds: caller credentials
    :param expiration_time: requested expiration, forwarded unchanged
    :param call_id: optional call id used to suppress duplicate requests
    :returns: the per-aggregate results combined with ``and`` (True when no
        aggregate was contacted or call_id was already handled)
    """
    def _RenewSliver(server, xrn, creds, expiration_time, call_id):
        # call_id is only sent to servers that accept it, and only once
        # (it used to be in the base list *and* appended again).
        args = [xrn, creds, expiration_time]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.RenewSliver(*args)

    if Callids().already_handled(call_id):
        return True

    (hrn, _) = urn_to_hrn(xrn)
    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'renewsliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        threads.run(_RenewSliver, server, xrn, [cred], expiration_time, call_id)

    # 'and' the results, with the same value semantics as
    # reduce(lambda x, y: x and y, results, True)
    combined = True
    for result in threads.get_results():
        combined = combined and result
    return combined
291
def DeleteSliver(api, xrn, creds, call_id):
    """
    Delete the slivers of slice ``xrn`` on every known aggregate.

    Per-aggregate results are gathered with ``threads.get_results()`` but
    not inspected.

    :param api: slice manager API object
    :param xrn: slice urn
    :param creds: caller credentials
    :param call_id: optional call id used to suppress duplicate requests
    :returns: 1, or "" if call_id was already handled
    """
    def _DeleteSliver(server, xrn, creds, call_id):
        # _call_id_supported performs the (cached) version lookup itself.
        args = [xrn, creds]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.DeleteSliver(*args)

    if Callids().already_handled(call_id):
        return ""
    (hrn, _) = urn_to_hrn(xrn)
    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'deletesliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        threads.run(_DeleteSliver, server, xrn, [cred], call_id)
    threads.get_results()
    return 1
321
322
# first draft at a merging SliverStatus
def SliverStatus(api, slice_xrn, creds, call_id):
    """
    Query SliverStatus for ``slice_xrn`` on every known aggregate and merge
    the replies (see ``_merge_sliver_status``).

    NOTE(review): unlike the other fan-out calls in this module, no
    api.auth.checkCredentials() is performed and no aggregate is skipped for
    being the caller; when no delegated credential is found the SM's own
    credential is used.  Confirm this is intended.

    :returns: merged status dict, or {} when nothing usable came back or
        call_id was already handled
    """
    def _SliverStatus(server, xrn, creds, call_id):
        # _call_id_supported performs the (cached) version lookup itself.
        args = [xrn, creds]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.SliverStatus(*args)

    if Callids().already_handled(call_id):
        return {}
    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        threads.run(_SliverStatus, server, slice_xrn, [cred], call_id)
    return _merge_sliver_status(threads.get_results())


def _merge_sliver_status(results):
    """
    Merge per-aggregate SliverStatus replies into a single reply.

    Replies that are empty or carry no (or empty) 'geni_resources' are
    ignored; e.g. an aggregate that hit call_id returns {} by convention.
    'geni_urn' is taken from the first remaining reply (all are expected to
    carry the same urn), as is 'pl_login' when that reply has one.
    'geni_resources' is the concatenation of every reply's resources and
    'status' is 'ready' when that list is non-empty.

    :param results: iterable of SliverStatus reply dicts
    :returns: merged dict, or {} when no usable reply remains
    """
    usable = [result for result in results
              if result and result.get('geni_resources')]
    if not usable:
        return {}

    first = usable[0]
    overall = {'geni_urn': first['geni_urn']}
    if 'pl_login' in first:
        overall['pl_login'] = first['pl_login']

    resources = []
    for result in usable:
        resources.extend(result['geni_resources'])
    overall['geni_resources'] = resources
    overall['status'] = 'ready' if resources else 'unknown'
    return overall
364
# Module-wide switch for the api.cache lookups in ListResources and ListSlices.
caching = True
#caching = False
def ListSlices(api, creds, call_id):
    """
    Return the slices known to every reachable aggregate.

    A non-empty list cached under 'slices' in ``api.cache`` is returned as-is
    when caching is enabled; otherwise the per-aggregate lists are
    concatenated and (when caching) stored back under 'slices'.

    :param api: slice manager API object
    :param creds: caller credentials
    :param call_id: optional call id used to suppress duplicate requests
    :returns: list of slices ([] if call_id was already handled)
    """
    def _ListSlices(server, creds, call_id):
        # _call_id_supported performs the (cached) version lookup itself.
        args = [creds]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.ListSlices(*args)

    if Callids().already_handled(call_id):
        return []

    # look in cache first
    if caching and api.cache:
        slices = api.cache.get('slices')
        if slices:
            return slices

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'listslices', None)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    # fetch from aggregates
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        threads.run(_ListSlices, server, [cred], call_id)

    # combine results; skip empty answers (e.g. None) instead of failing on extend()
    slices = []
    for result in threads.get_results():
        if result:
            slices.extend(result)

    # cache the result
    if caching and api.cache:
        api.cache.add('slices', slices)

    return slices
413
414
def get_ticket(api, xrn, creds, rspec, users):
    """
    Gather tickets for slice ``xrn`` and re-issue them as a single ticket.

    Each <network> child of the RSpec root names an aggregate (by its first
    attribute value), and every such aggregate is asked for a ticket against
    the complete RSpec.  The 'initscripts' and 'slivers' attributes of the
    returned tickets are concatenated, their RSpecs merged with
    ``merge_rspecs``, and the gid object is taken from the first ticket that
    provides one.  The new ticket is signed with ``api.key``.

    :returns: the ticket serialized with ``save_to_string(save_parents=True)``
    """
    slice_hrn, _ = urn_to_hrn(xrn)

    # map each aggregate named in the rspec to the (full) rspec it receives
    aggregate_rspecs = {}
    for network in etree.parse(StringIO(rspec)).findall('./network'):
        aggregate_rspecs[network.values()[0]] = rspec

    # who is calling us?
    valid_cred = api.auth.checkCredentials(creds, 'getticket', slice_hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # delegated credential if we have one, our own otherwise
    cred = api.getDelegatedCredential(creds) or api.getCredential()

    threads = ThreadManager()
    for aggregate, aggregate_rspec in aggregate_rspecs.items():
        # prevent infinite loop: don't send the request back to the caller
        # unless the caller is the aggregate's SM
        if aggregate == caller_hrn and aggregate != api.hrn:
            continue
        server = api.get_server(api.aggregates[aggregate], cred)
        threads.run(server.GetTicket, xrn, [cred], aggregate_rspec, users)

    # pull rspecs, initscripts, slivers and the object gid out of each ticket
    collected_rspecs, initscripts, slivers = [], [], []
    object_gid = None
    for raw_ticket in threads.get_results():
        agg_ticket = SfaTicket(string=raw_ticket)
        attributes = agg_ticket.get_attributes()
        if not object_gid:
            object_gid = agg_ticket.get_gid_object()
        collected_rspecs.append(agg_ticket.get_rspec())
        initscripts.extend(attributes.get('initscripts', []))
        slivers.extend(attributes.get('slivers', []))

    merged_attributes = {'initscripts': initscripts, 'slivers': slivers}
    merged_rspec = merge_rspecs(collected_rspecs)

    # issue the combined ticket
    ticket = SfaTicket(subject=slice_hrn)
    ticket.set_gid_caller(api.auth.client_gid)
    ticket.set_issuer(key=api.key, subject=api.hrn)
    ticket.set_gid_object(object_gid)
    ticket.set_pubkey(object_gid.get_pubkey())
    #new_ticket.set_parent(api.auth.hierarchy.get_auth_ticket(auth_hrn))
    ticket.set_attributes(merged_attributes)
    ticket.set_rspec(merged_rspec)
    ticket.encode()
    ticket.sign()
    return ticket.save_to_string(save_parents=True)
477
def start_slice(api, xrn, creds):
    """
    Ask every known aggregate to start slice ``xrn``.

    The aggregate whose hrn equals the caller's is skipped, unless it is
    ``api.hrn`` itself.  Results are gathered but not inspected.

    NOTE(review): the credential is passed to ``server.Start`` as a single
    value, whereas the other fan-out calls here pass ``[cred]``; confirm the
    remote Start accepts both forms.

    :returns: 1
    """
    hrn, _ = urn_to_hrn(xrn)

    # who is calling us?
    checked = api.auth.checkCredentials(creds, 'startslice', hrn)
    caller_hrn = Credential(string=checked[0]).get_gid_caller().get_hrn()

    # delegated credential if we have one, our own otherwise
    cred = api.getDelegatedCredential(creds) or api.getCredential()

    threads = ThreadManager()
    for aggregate in api.aggregates:
        # never bounce the request back to the caller's own aggregate
        is_caller = caller_hrn == aggregate
        if is_caller and aggregate != api.hrn:
            continue
        server = api.get_server(api.aggregates[aggregate], cred)
        threads.run(server.Start, xrn, cred)
    threads.get_results()
    return 1
500  
def stop_slice(api, xrn, creds):
    """
    Ask every known aggregate to stop slice ``xrn``.

    The aggregate whose hrn equals the caller's is skipped, unless it is
    ``api.hrn`` itself.  Results are gathered but not inspected.

    NOTE(review): the credential is passed to ``server.Stop`` as a single
    value, whereas the other fan-out calls here pass ``[cred]``; confirm the
    remote Stop accepts both forms.

    :returns: 1
    """
    hrn, _ = urn_to_hrn(xrn)

    # who is calling us?
    checked = api.auth.checkCredentials(creds, 'stopslice', hrn)
    caller_hrn = Credential(string=checked[0]).get_gid_caller().get_hrn()

    # delegated credential if we have one, our own otherwise
    cred = api.getDelegatedCredential(creds) or api.getCredential()

    threads = ThreadManager()
    for aggregate in api.aggregates:
        # never bounce the request back to the caller's own aggregate
        is_caller = caller_hrn == aggregate
        if is_caller and aggregate != api.hrn:
            continue
        server = api.get_server(api.aggregates[aggregate], cred)
        threads.run(server.Stop, xrn, cred)
    threads.get_results()
    return 1
523
def reset_slice(api, xrn):
    """
    Placeholder for resetting slice ``xrn``; not implemented, does nothing.

    :returns: 1, unconditionally
    """
    result = 1
    return result
529
def shutdown(api, xrn, creds):
    """
    Placeholder for shutting slice ``xrn`` down; not implemented, does nothing.

    :returns: 1, unconditionally
    """
    result = 1
    return result
535
def status(api, xrn, creds):
    """
    Placeholder for reporting slice ``xrn`` status; not implemented.

    :returns: 1, unconditionally
    """
    result = 1
    return result
541
def main():
    """
    Ad-hoc debugging entry point: ``slice_manager.py RSPEC_FILE``.

    Reads the RSpec file named on the command line and hands its text to
    CreateSliver for a hard-coded test slice.  The arguments now line up with
    CreateSliver(api, xrn, creds, rspec_str, users, call_id); the former
    four-argument call raised TypeError.

    NOTE(review): no API object is built here (``api`` is None), so
    CreateSliver will still fail once it touches ``api``.
    """
    if len(sys.argv) < 2:
        sys.stderr.write("usage: %s RSPEC_FILE\n" % sys.argv[0])
        sys.exit(2)
    with open(sys.argv[1]) as rspec_file:
        rspec_str = rspec_file.read()
    CreateSliver(None, 'plc.princeton.tmacktestslice', [], rspec_str, [],
                 'create-slice-tmacktestslice')


if __name__ == "__main__":
    main()
550