Merge branch 'master' of ssh://git.planet-lab.org/git/sfa
[sfa.git] / sfa / managers / slice_manager.py
1 import sys
2 import time,datetime
3 from StringIO import StringIO
4 from types import StringTypes
5 from copy import deepcopy
6 from copy import copy
7 from lxml import etree
8
9 from sfa.util.sfalogging import logger
10 from sfa.util.rspecHelper import merge_rspecs
11 from sfa.util.xrn import Xrn, urn_to_hrn, hrn_to_urn
12 from sfa.util.plxrn import hrn_to_pl_slicename
13 from sfa.util.faults import *
14 from sfa.util.record import SfaRecord
15 from sfa.rspecs.rspec_converter import RSpecConverter
16 from sfa.client.client_helper import sfa_to_pg_users_arg
17 from sfa.rspecs.version_manager import VersionManager
18 from sfa.rspecs.rspec import RSpec 
19 from sfa.util.policy import Policy
20 from sfa.util.prefixTree import prefixTree
21 from sfa.trust.sfaticket import SfaTicket
22 from sfa.trust.credential import Credential
23 from sfa.util.threadmanager import ThreadManager
24 import sfa.util.xmlrpcprotocol as xmlrpcprotocol     
25 import sfa.plc.peers as peers
26 from sfa.util.version import version_core
27 from sfa.util.callids import Callids
28
29
def _call_id_supported(api, server):
    """
    Returns true if server support the optional call_id arg, false otherwise.

    The decision is taken from the server's cached GetVersion result: the
    server must advertise 'sfa' and carry a 'code_tag' of the form
    "<major>.<minor>-<rev>".  A missing or unparseable code_tag is treated
    as "not supported" rather than raising, so that a peer with an odd
    version string does not make the whole call fail.
    """
    server_version = api.get_cached_server_version(server)
    if 'sfa' not in server_version:
        return False

    try:
        code_tag_parts = server_version['code_tag'].split("-")
        # only the first two dotted components are relevant
        major, minor = code_tag_parts[0].split(".")[0:2]
        rev = code_tag_parts[1]
        # NOTE(review): threshold kept exactly as before (major > 1 and
        # either minor > 0 or rev > 20) -- confirm this is the intended cutoff
        if int(major) > 1:
            return int(minor) > 0 or int(rev) > 20
    except (KeyError, IndexError, ValueError, AttributeError, TypeError):
        # no code_tag, no "-<rev>" part, a single version component,
        # or non-numeric pieces
        return False
    return False
47
48 # we have specialized xmlrpclib.ServerProxy to remember the input url
49 # OTOH it's not clear if we're only dealing with XMLRPCServerProxy instances
def get_serverproxy_url (server):
    """
    Return the url of an xmlrpc server proxy.

    server.get_url() is tried first; if that fails for any reason a warning
    is logged and the url is rebuilt from the private host and handler
    attributes of xmlrpclib.ServerProxy.
    """
    try:
        return server.get_url()
    except Exception:
        # 'except Exception' rather than a bare except, so that
        # SystemExit / KeyboardInterrupt are not swallowed here
        logger.warning("GetVersion, falling back to xmlrpclib.ServerProxy internals")
        return server._ServerProxy__host + server._ServerProxy__handler
56
def GetVersion(api):
    """
    Build the GetVersion answer of this slice manager.

    The answer is version_core() applied to a dict that carries the
    interface name ('slicemgr'), the hrn and urn of this authority, the urls
    of the peer aggregates, and the rspec versions known to VersionManager,
    split into an advertisement list and a request list.
    """
    # peers explicitly in aggregates.xml
    peer_urls = {}
    for (peername, proxy) in api.aggregates.iteritems():
        if peername != api.hrn:
            peer_urls[peername] = get_serverproxy_url(proxy)

    manager = VersionManager()
    ad_versions = []
    request_versions = []
    for candidate in manager.versions:
        # a '*' content type is listed on both sides
        if candidate.content_type in ('*', 'ad'):
            ad_versions.append(candidate.to_dict())
        if candidate.content_type in ('*', 'request'):
            request_versions.append(candidate.to_dict())
    default_version = manager.get_version("sfa 1").to_dict()

    authority_xrn = Xrn(api.hrn, 'authority+sa')
    sm_version = version_core({'interface':'slicemgr',
                               'hrn' : authority_xrn.get_hrn(),
                               'urn' : authority_xrn.get_urn(),
                               'peers': peer_urls,
                               'request_rspec_versions': request_versions,
                               'ad_rspec_versions': ad_versions,
                               'default_ad_rspec': default_version
                               })

    # local aggregate if present needs to have localhost resolved
    if api.hrn in api.aggregates:
        local_url = get_serverproxy_url(api.aggregates[api.hrn])
        sm_version['peers'][api.hrn] = local_url.replace('localhost', sm_version['hostname'])
    return sm_version
85
def drop_slicemgr_stats(rspec):
    """
    Remove every <statistics> element found in the rspec's xml tree.

    Best effort: any failure is logged as a warning and otherwise ignored.
    """
    try:
        for node in rspec.xml.xpath('//statistics'):
            node.getparent().remove(node)
    except Exception as e:
        # there is no 'api' in scope here (referencing it raised NameError
        # from inside the handler); use the module-level logger instead
        logger.warning("drop_slicemgr_stats failed: %s " % (str(e)))
93
def add_slicemgr_stat(rspec, callname, aggname, elapsed, status):
    """
    Record the outcome of one aggregate call inside the rspec.

    An <aggregate name=... elapsed=... status=...> element is appended to
    the <statistics call="callname"> element of the rspec; that element is
    created under the rspec root when it does not exist yet.

    Best effort: any failure is logged as a warning and otherwise ignored.
    """
    try:
        stats_tags = rspec.xml.xpath('//statistics[@call="%s"]' % callname)
        if stats_tags:
            stats_tag = stats_tags[0]
        else:
            stats_tag = etree.SubElement(rspec.xml.root, "statistics", call=callname)

        etree.SubElement(stats_tag, "aggregate", name=str(aggname), elapsed=str(elapsed), status=str(status))
    except Exception as e:
        # there is no 'api' in scope here (referencing it raised NameError
        # from inside the handler); use the module-level logger instead
        logger.warning("add_slicemgr_stat failed on  %s: %s" %(aggname, str(e)))
105
def ListResources(api, creds, options, call_id):
    """
    Query every known aggregate for its resources and merge the answers
    into a single rspec.

    api      -- the api object (aggregates, auth, cache, credentials, logger)
    creds    -- caller credentials, checked for the 'listnodes' privilege
    options  -- dict; 'geni_slice_urn' names a slice (the merged result is
                then a manifest instead of an advertisement), 'rspec_version'
                selects the output format; 'geni_compressed' is removed
                from the dict
    call_id  -- when Callids reports it as already handled, "" is returned

    Returns the merged rspec as an xml string.  Per-aggregate timing and
    status are recorded in the rspec through add_slicemgr_stat.
    """
    version_manager = VersionManager()

    def _ListResources(aggregate, server, credential, opts, call_id):
        # worker run by the ThreadManager: never raises, the outcome is
        # reported through the 'status' entry of the returned dict
        my_opts = copy(opts)
        args = [credential, my_opts]
        tStart = time.time()
        try:
            if _call_id_supported(api, server):
                args.append(call_id)
            version = api.get_cached_server_version(server)
            # force ProtoGENI aggregates to give us a v2 RSpec
            # (my_opts is already referenced from args, so this is seen by the call)
            if 'sfa' not in version:
                my_opts['rspec_version'] = version_manager.get_version('ProtoGENI 2').to_dict()
            rspec = server.ListResources(*args)
            return {"aggregate": aggregate, "rspec": rspec, "elapsed": time.time()-tStart, "status": "success"}
        except Exception:
            api.logger.log_exc("ListResources failed at %s" %(server.url))
            return {"aggregate": aggregate, "elapsed": time.time()-tStart, "status": "exception"}

    if Callids().already_handled(call_id): return ""

    # get slice's hrn from options
    xrn = options.get('geni_slice_urn', '')
    (hrn, hrn_type) = urn_to_hrn(xrn)
    if 'geni_compressed' in options:
        del(options['geni_compressed'])

    # get the rspec's return format from options
    rspec_version = version_manager.get_version(options.get('rspec_version'))
    version_string = "rspec_%s" % (rspec_version.to_string())

    # look in cache first (only the slice-less advertisement is cached)
    if caching and api.cache and not xrn:
        rspec =  api.cache.get(version_string)
        if rspec:
            return rspec

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'listnodes', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue

        # get the rspec from the aggregate
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        threads.run(_ListResources, aggregate, server, [cred], options, call_id)

    results = threads.get_results()

    # a slice-specific query yields a manifest, a plain query an advertisement
    if xrn:
        content_type = 'manifest'
    else:
        content_type = 'ad'
    result_version = version_manager._get_version(rspec_version.type, rspec_version.version, content_type)
    rspec = RSpec(version=result_version)
    for result in results:
        add_slicemgr_stat(rspec, "ListResources", result["aggregate"], result["elapsed"], result["status"])
        if result["status"]=="success":
            try:
                rspec.version.merge(result["rspec"])
            except Exception:
                # one bad aggregate answer must not spoil the others
                api.logger.log_exc("SM.ListResources: Failed to merge aggregate rspec")

    # cache the result
    if caching and api.cache and not xrn:
        api.cache.add(version_string, rspec.toxml())

    return rspec.toxml()
185
186
def CreateSliver(api, xrn, creds, rspec_str, users, call_id):
    """
    Send a request rspec to every known aggregate and merge the answers
    into a single manifest rspec.

    api        -- the api object (aggregates, auth, credentials, logger)
    xrn        -- urn of the slice
    creds      -- caller credentials, checked for the 'createsliver' privilege
    rspec_str  -- request rspec as a string; any <statistics> section is
                  dropped before it is forwarded
    users      -- users argument forwarded to the aggregates; converted with
                  sfa_to_pg_users_arg for aggregates that advertise
                  'geni_api' but not 'sfa'
    call_id    -- when Callids reports it as already handled, "" is returned

    Returns the merged manifest as an xml string.  Per-aggregate timing and
    status are recorded in the manifest through add_slicemgr_stat.
    """
    version_manager = VersionManager()

    def _CreateSliver(aggregate, server, xrn, credential, rspec, users, call_id):
        # worker run by the ThreadManager: never raises, the outcome is
        # reported through the 'status' entry of the returned dict
        tStart = time.time()
        try:
            # Need to call GetVersion at an aggregate to determine the supported
            # rspec type/format before calling CreateSliver at an Aggregate.
            server_version = api.get_cached_server_version(server)
            requested_users = users
            if 'sfa' not in server_version and 'geni_api' in server_version:
                # sfa aggregates support both sfa and pg rspecs, no need to convert
                # if aggregate supports sfa rspecs. otherwise convert to pg rspec
                rspec = RSpec(RSpecConverter.to_pg_rspec(rspec, 'request'))
                cm_filter = {'component_manager_id': server_version['urn']}
                rspec.filter(cm_filter)
                rspec = rspec.toxml()
                requested_users = sfa_to_pg_users_arg(users)
            args = [xrn, credential, rspec, requested_users]
            if _call_id_supported(api, server):
                args.append(call_id)
            rspec = server.CreateSliver(*args)
            return {"aggregate": aggregate, "rspec": rspec, "elapsed": time.time()-tStart, "status": "success"}
        except Exception:
            logger.log_exc('Something wrong in _CreateSliver with URL %s'%server.url)
            return {"aggregate": aggregate, "elapsed": time.time()-tStart, "status": "exception"}

    if Callids().already_handled(call_id): return ""
    # Validate the RSpec against PlanetLab's schema --disabled for now
    # The schema used here needs to aggregate the PL and VINI schemas
    # schema = "/var/www/html/schemas/pl.rng"
    rspec = RSpec(rspec_str)
    schema = None
    if schema:
        rspec.validate(schema)

    # if there is a <statistics> section, the aggregates don't care about it,
    # so delete it.
    drop_slicemgr_stats(rspec)

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()

    # get the callers hrn
    hrn, hrn_type = urn_to_hrn(xrn)
    valid_cred = api.auth.checkCredentials(creds, 'createsliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # Just send entire RSpec to each aggregate; serialize it once rather
    # than once per aggregate
    request_xml = rspec.toxml()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        threads.run(_CreateSliver, aggregate, server, xrn, [cred], request_xml, users, call_id)

    results = threads.get_results()
    manifest_version = version_manager._get_version(rspec.version.type, rspec.version.version, 'manifest')
    result_rspec = RSpec(version=manifest_version)
    for result in results:
        add_slicemgr_stat(result_rspec, "CreateSliver", result["aggregate"], result["elapsed"], result["status"])
        if result["status"]=="success":
            try:
                result_rspec.version.merge(result["rspec"])
            except Exception:
                # one bad aggregate answer must not spoil the others
                api.logger.log_exc("SM.CreateSliver: Failed to merge aggregate rspec")
    return result_rspec.toxml()
258
def RenewSliver(api, xrn, creds, expiration_time, call_id):
    """
    Ask every known aggregate to renew the slivers of a slice.

    api              -- the api object (aggregates, auth, credentials)
    xrn              -- urn of the slice
    creds            -- caller credentials, checked for 'renewsliver'
    expiration_time  -- forwarded unchanged to each aggregate
    call_id          -- when Callids reports it as already handled, True is
                        returned at once

    Returns the 'and' of all aggregate results (True when no aggregate
    was contacted).
    """
    def _RenewSliver(server, xrn, creds, expiration_time, call_id):
        # call_id is optional on the wire: it must only be sent, and only
        # once, to servers that accept it
        args = [xrn, creds, expiration_time]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.RenewSliver(*args)

    if Callids().already_handled(call_id): return True

    (hrn, hrn_type) = urn_to_hrn(xrn)
    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'renewsliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        threads.run(_RenewSliver, server, xrn, [cred], expiration_time, call_id)
    # 'and' the results
    return reduce (lambda x,y: x and y, threads.get_results() , True)
289
def DeleteSliver(api, xrn, creds, call_id):
    """
    Ask every known aggregate to delete the slivers of a slice.

    api      -- the api object (aggregates, auth, credentials)
    xrn      -- urn of the slice
    creds    -- caller credentials, checked for 'deletesliver'
    call_id  -- when Callids reports it as already handled, "" is returned

    Returns 1 once all aggregate calls have completed; the individual
    results are not inspected.
    """
    def _DeleteSliver(server, xrn, creds, call_id):
        # no separate version lookup is needed here: _call_id_supported
        # fetches the cached server version itself
        args = [xrn, creds]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.DeleteSliver(*args)

    if Callids().already_handled(call_id): return ""
    (hrn, hrn_type) = urn_to_hrn(xrn)
    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'deletesliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        threads.run(_DeleteSliver, server, xrn, [cred], call_id)
    threads.get_results()
    return 1
319
320
321 # first draft at a merging SliverStatus
def SliverStatus(api, slice_xrn, creds, call_id):
    """
    Collect the sliver status of a slice from every known aggregate and
    merge the answers (first draft at a merging SliverStatus).

    api        -- the api object (aggregates, credentials)
    slice_xrn  -- xrn of the slice, forwarded to each aggregate
    creds      -- caller credentials, used to obtain a delegated credential
    call_id    -- when Callids reports it as already handled, {} is returned

    Returns {} when no aggregate reported any resource; otherwise a dict
    with 'geni_urn' and 'pl_login' taken from the first non-void answer,
    'geni_resources' holding the concatenation of all answers, and
    'status' set to 'ready'.
    """
    def _SliverStatus(server, xrn, creds, call_id):
        # no separate version lookup is needed here: _call_id_supported
        # fetches the cached server version itself
        args = [xrn, creds]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.SliverStatus(*args)

    if Callids().already_handled(call_id): return {}
    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        threads.run (_SliverStatus, server, slice_xrn, [cred], call_id)
    results = threads.get_results()

    # get rid of any void result - e.g. when call_id was hit where by convention we return {}
    results = [ result for result in results if result and result['geni_resources']]

    # do not try to combine if there's no result
    if not results : return {}

    # otherwise let's merge stuff
    overall = {}

    # mmh, it is expected that all results carry the same urn
    overall['geni_urn'] = results[0]['geni_urn']
    overall['pl_login'] = results[0]['pl_login']
    # append all geni_resources (extend instead of repeated list addition,
    # which copied the accumulated list once per aggregate)
    all_resources = []
    for result in results:
        all_resources.extend(result['geni_resources'])
    overall['geni_resources'] = all_resources
    overall['status'] = 'unknown'
    if overall['geni_resources']:
        overall['status'] = 'ready'

    return overall
362
# module-wide switch: when true (and api.cache is set), ListResources and
# ListSlices answer from api.cache and store freshly computed results in it
caching=True
# flip to the line below to bypass the cache
#caching=False
def ListSlices(api, creds, call_id):
    """
    Return the slices known to all aggregates, concatenated in one list.

    api      -- the api object (aggregates, auth, cache, credentials)
    creds    -- caller credentials, checked for the 'listslices' privilege
    call_id  -- when Callids reports it as already handled, [] is returned

    When caching is enabled and api.cache is set, a non-empty cached
    'slices' entry is returned directly, and a freshly computed list is
    stored under that key.
    """
    def _ListSlices(server, creds, call_id):
        # no separate version lookup is needed here: _call_id_supported
        # fetches the cached server version itself
        args = [creds]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.ListSlices(*args)

    if Callids().already_handled(call_id): return []

    # look in cache first
    if caching and api.cache:
        slices = api.cache.get('slices')
        if slices:
            return slices

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'listslices', None)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    # fetch from aggregates
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        threads.run(_ListSlices, server, [cred], call_id)

    # combine results
    results = threads.get_results()
    slices = []
    for result in results:
        slices.extend(result)

    # cache the result
    if caching and api.cache:
        api.cache.add('slices', slices)

    return slices
411
412
def get_ticket(api, xrn, creds, rspec, users):
    """
    Obtain a ticket from each aggregate named in the rspec and combine them
    into one ticket signed by this slice manager.

    api    -- the api object (aggregates, auth, credentials, key, hrn)
    xrn    -- urn of the slice
    creds  -- caller credentials, checked for the 'getticket' privilege
    rspec  -- rspec string; the first attribute value of each <network>
              element is used as an aggregate hrn
    users  -- forwarded unchanged to each aggregate's GetTicket

    Returns the new ticket serialized with save_to_string(save_parents=True).
    """
    slice_hrn, slice_type = urn_to_hrn(xrn)

    # get the netspecs contained within the clients rspec;
    # every aggregate named there is handed the complete rspec
    rspec_tree = etree.parse(StringIO(rspec))
    aggregate_rspecs = {}
    for network in rspec_tree.findall('./network'):
        aggregate_rspecs[network.values()[0]] = rspec

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'getticket', slice_hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds) or api.getCredential()

    threads = ThreadManager()
    for (aggregate, aggregate_rspec) in aggregate_rspecs.iteritems():
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = api.get_server(api.aggregates[aggregate], cred)
        threads.run(server.GetTicket, xrn, [cred], aggregate_rspec, users)

    # gather information from each ticket
    rspecs, initscripts, slivers = [], [], []
    object_gid = None
    for ticket_string in threads.get_results():
        agg_ticket = SfaTicket(string=ticket_string)
        attrs = agg_ticket.get_attributes()
        # the object gid of the first ticket is reused for the merged one
        if not object_gid:
            object_gid = agg_ticket.get_gid_object()
        rspecs.append(agg_ticket.get_rspec())
        initscripts.extend(attrs.get('initscripts', []))
        slivers.extend(attrs.get('slivers', []))

    # merge info
    attributes = {'initscripts': initscripts,
                  'slivers': slivers}
    merged_rspec = merge_rspecs(rspecs)

    # create a new ticket
    ticket = SfaTicket(subject = slice_hrn)
    ticket.set_gid_caller(api.auth.client_gid)
    ticket.set_issuer(key=api.key, subject=api.hrn)
    ticket.set_gid_object(object_gid)
    ticket.set_pubkey(object_gid.get_pubkey())
    #new_ticket.set_parent(api.auth.hierarchy.get_auth_ticket(auth_hrn))
    ticket.set_attributes(attributes)
    ticket.set_rspec(merged_rspec)
    ticket.encode()
    ticket.sign()
    return ticket.save_to_string(save_parents=True)
475
def start_slice(api, xrn, creds):
    """
    Call Start for the slice at every known aggregate.

    api    -- the api object (aggregates, auth, credentials)
    xrn    -- urn of the slice
    creds  -- caller credentials, checked for the 'startslice' privilege

    Returns 1 once all aggregate calls have completed.
    """
    hrn, hrn_type = urn_to_hrn(xrn)

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'startslice', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds) or api.getCredential()

    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        caller_is_peer = caller_hrn == aggregate and aggregate != api.hrn
        if not caller_is_peer:
            server = api.get_server(api.aggregates[aggregate], cred)
            # NOTE(review): the credential is passed bare here while other
            # calls in this module pass [cred] -- confirm Start accepts both
            threads.run(server.Start, xrn, cred)
    threads.get_results()
    return 1
498  
def stop_slice(api, xrn, creds):
    """
    Call Stop for the slice at every known aggregate.

    api    -- the api object (aggregates, auth, credentials)
    xrn    -- urn of the slice
    creds  -- caller credentials, checked for the 'stopslice' privilege

    Returns 1 once all aggregate calls have completed.
    """
    hrn, hrn_type = urn_to_hrn(xrn)

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'stopslice', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds) or api.getCredential()

    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        caller_is_peer = caller_hrn == aggregate and aggregate != api.hrn
        if not caller_is_peer:
            server = api.get_server(api.aggregates[aggregate], cred)
            # NOTE(review): the credential is passed bare here while other
            # calls in this module pass [cred] -- confirm Stop accepts both
            threads.run(server.Stop, xrn, cred)
    threads.get_results()
    return 1
521
def reset_slice(api, xrn):
    """
    Placeholder: resetting a slice is not implemented.
    Both arguments are ignored and 1 is always returned.
    """
    return 1
527
def shutdown(api, xrn, creds):
    """
    Placeholder: shutting down a slice is not implemented.
    All arguments are ignored and 1 is always returned.
    """
    return 1
533
def status(api, xrn, creds):
    """
    Placeholder: reporting a slice's status is not implemented.
    All arguments are ignored and 1 is always returned.
    """
    return 1
539
def main():
    """
    Ad-hoc command-line entry point: loads the rspec file named by
    sys.argv[1] and hands it to CreateSliver.
    """
    r = RSpec()
    r.parseFile(sys.argv[1])
    rspec = r.toDict()
    # NOTE(review): CreateSliver in this module is declared as
    # (api, xrn, creds, rspec_str, users, call_id) but only four positional
    # arguments are passed here, and api is None -- this call looks stale;
    # confirm before relying on it
    CreateSliver(None,'plc.princeton.tmacktestslice',rspec,'create-slice-tmacktestslice')

# run main() only when this file is executed as a script
if __name__ == "__main__":
    main()
548