Merge branch 'master' of ssh://git.onelab.eu/git/sfa
[sfa.git] / sfa / managers / slice_manager.py
1 /
2 import sys
3 import time,datetime
4 from StringIO import StringIO
5 from types import StringTypes
6 from copy import deepcopy
7 from copy import copy
8 from lxml import etree
9
10 from sfa.util.sfalogging import logger
11 from sfa.util.rspecHelper import merge_rspecs
12 from sfa.util.xrn import Xrn, urn_to_hrn, hrn_to_urn
13 from sfa.util.plxrn import hrn_to_pl_slicename
14 from sfa.util.faults import *
15 from sfa.util.record import SfaRecord
16 from sfa.rspecs.rspec_converter import RSpecConverter
17 from sfa.client.client_helper import sfa_to_pg_users_arg
18 from sfa.rspecs.version_manager import VersionManager
19 from sfa.rspecs.rspec import RSpec 
20 from sfa.util.policy import Policy
21 from sfa.util.prefixTree import prefixTree
22 from sfa.trust.sfaticket import SfaTicket
23 from sfa.trust.credential import Credential
24 from sfa.util.threadmanager import ThreadManager
25 import sfa.util.xmlrpcprotocol as xmlrpcprotocol     
26 import sfa.plc.peers as peers
27 from sfa.util.version import version_core
28 from sfa.util.callids import Callids
29
30
31 def _call_id_supported(api, server):
32     """
33     Returns true if server support the optional call_id arg, false otherwise.
34     """
35     server_version = api.get_cached_server_version(server)
36
37     if 'sfa' in server_version:
38         code_tag = server_version['code_tag']
39         code_tag_parts = code_tag.split("-")
40
41         version_parts = code_tag_parts[0].split(".")
42         major, minor = version_parts[0:2]
43         rev = code_tag_parts[1]
44         if int(major) > 1:
45             if int(minor) > 0 or int(rev) > 20:
46                 return True
47     return False
48
def get_serverproxy_url(server):
    """
    Return the url that the xmlrpc server proxy `server` talks to.

    We have specialized xmlrpclib.ServerProxy so that it remembers its input
    url and exposes it through get_url(). It is not clear we only ever deal
    with such instances, so for other proxies the url is rebuilt from
    xmlrpclib.ServerProxy's private host and handler attributes.
    """
    # Look get_url up on the class, not the instance. xmlrpclib.ServerProxy's
    # __getattr__ turns any unknown attribute into a remote method, so
    # server.get_url() on a plain ServerProxy would send a bogus XML-RPC
    # request instead of raising AttributeError.
    if getattr(server.__class__, 'get_url', None) is not None:
        return server.get_url()
    logger.warning("GetVersion, falling back to xmlrpclib.ServerProxy internals")
    return server._ServerProxy__host + server._ServerProxy__handler
57
def GetVersion(api):
    """
    Build this slice manager's GetVersion answer.

    On top of the generic fields produced by version_core(), it advertises:
    - the 'slicemgr' interface and our hrn/urn;
    - the known peer aggregates as a name -> url map;
    - the ad/request RSpec versions known to the VersionManager;
    - "sfa 1" as the default ad RSpec.
    """
    # every aggregate from aggregates.xml except ourselves
    peer_urls = {}
    for peername, proxy in api.aggregates.items():
        if peername == api.hrn:
            continue
        peer_urls[peername] = get_serverproxy_url(proxy)

    manager = VersionManager()
    ad_versions = []
    request_versions = []
    for known_version in manager.versions:
        kind = known_version.content_type
        if kind in ['*', 'ad']:
            ad_versions.append(known_version.to_dict())
        if kind in ['*', 'request']:
            request_versions.append(known_version.to_dict())
    default_ad = manager.get_version("sfa 1").to_dict()

    sm_xrn = Xrn(api.hrn, 'authority+sa')
    sm_version = version_core({
        'interface': 'slicemgr',
        'hrn': sm_xrn.get_hrn(),
        'urn': sm_xrn.get_urn(),
        'peers': peer_urls,
        'request_rspec_versions': request_versions,
        'ad_rspec_versions': ad_versions,
        'default_ad_rspec': default_ad,
    })

    # a local aggregate, if configured, is advertised with 'localhost'
    # replaced by our own hostname
    if api.hrn in api.aggregates:
        local_url = get_serverproxy_url(api.aggregates[api.hrn])
        sm_version['peers'][api.hrn] = local_url.replace('localhost', sm_version['hostname'])
    return sm_version
86
def drop_slicemgr_stats(rspec):
    """
    Remove every <statistics> element from rspec, in place.

    Aggregates do not care about the statistics the slice manager records,
    so they are stripped before a request rspec is forwarded. This is best
    effort: failures are logged as warnings and otherwise ignored.
    """
    try:
        # xpath() returns a list, so removing while iterating is safe
        for stats_element in rspec.xml.xpath('//statistics'):
            stats_element.getparent().remove(stats_element)
    except Exception:
        # 'api' is not in scope here: log through the module-level logger
        logger.warning("drop_slicemgr_stats failed: %s " % (str(sys.exc_info()[1]),))
94
def add_slicemgr_stat(rspec, callname, aggname, elapsed, status):
    """
    Record the outcome of one aggregate call in rspec's statistics.

    Appends <aggregate name=aggname elapsed=elapsed status=status> under the
    <statistics call=callname> element. That element is created under the
    rspec root the first time a given callname is seen. All attribute values
    are stored as strings.

    This is best effort: failures are logged as warnings and otherwise
    ignored.
    """
    try:
        stats_tags = rspec.xml.xpath('//statistics[@call="%s"]' % callname)
        if stats_tags:
            stats_tag = stats_tags[0]
        else:
            stats_tag = etree.SubElement(rspec.xml.root, "statistics", call=callname)

        etree.SubElement(stats_tag, "aggregate", name=str(aggname),
                         elapsed=str(elapsed), status=str(status))
    except Exception:
        # 'api' is not in scope here: log through the module-level logger
        logger.warning("add_slicemgr_stat failed on  %s: %s" % (aggname, str(sys.exc_info()[1])))
106
def ListResources(api, creds, options, call_id):
    """
    Collect resources from every known aggregate and merge them into one RSpec.

    Without a slice urn in options['geni_slice_urn'], the answer is an
    advertisement RSpec; it may be served from, and stored in, api.cache
    when caching is on. With a slice urn, the answer is a manifest RSpec for
    that slice. Per-aggregate timing and status are recorded in the rspec's
    <statistics> section.

    options is copied before being modified, so the caller's dict is left
    untouched. Returns the merged rspec as an XML string, or "" when call_id
    was already handled.
    """
    version_manager = VersionManager()

    def _ListResources(aggregate, server, credential, opts, call_id):
        # One ListResources call to one aggregate, run through ThreadManager.
        my_opts = copy(opts)
        args = [credential, my_opts]
        tStart = time.time()
        try:
            if _call_id_supported(api, server):
                args.append(call_id)
            version = api.get_cached_server_version(server)
            # force ProtoGENI aggregates to give us a v2 RSpec
            if 'sfa' not in version.keys():
                my_opts['rspec_version'] = version_manager.get_version('ProtoGENI 2').to_dict()
            rspec = server.ListResources(*args)
            return {"aggregate": aggregate, "rspec": rspec, "elapsed": time.time()-tStart, "status": "success"}
        except Exception:
            api.logger.log_exc("ListResources failed at %s" %(server.url))
            return {"aggregate": aggregate, "elapsed": time.time()-tStart, "status": "exception"}

    if Callids().already_handled(call_id): return ""

    # work on a private copy: keys are dropped below and the dict is handed
    # to the worker threads
    options = copy(options)

    # get slice's hrn from options
    xrn = options.get('geni_slice_urn', '')
    (hrn, _) = urn_to_hrn(xrn)
    # don't forward the compression flag; aggregate answers are merged as XML
    options.pop('geni_compressed', None)

    # get the rspec's return format from options
    rspec_version = version_manager.get_version(options.get('rspec_version'))
    version_string = "rspec_%s" % (rspec_version.to_string())

    # only advertisement rspecs (no slice urn) are cached
    use_cache = caching and api.cache and not xrn

    # look in cache first
    if use_cache:
        rspec = api.cache.get(version_string)
        if rspec:
            return rspec

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'listnodes', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    # set when an aggregate is left out, making the merged answer partial
    skipped_aggregate = False
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            skipped_aggregate = True
            continue

        # get the rspec from the aggregate
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        threads.run(_ListResources, aggregate, server, [cred], options, call_id)

    results = threads.get_results()
    if xrn:
        result_version = version_manager._get_version(rspec_version.type, rspec_version.version, 'manifest')
    else:
        result_version = version_manager._get_version(rspec_version.type, rspec_version.version, 'ad')
    rspec = RSpec(version=result_version)
    for result in results:
        add_slicemgr_stat(rspec, "ListResources", result["aggregate"], result["elapsed"], result["status"])
        if result["status"]=="success":
            try:
                rspec.version.merge(result["rspec"])
            except Exception:
                api.logger.log_exc("SM.ListResources: Failed to merge aggregate rspec")

    rspec_xml = rspec.toxml()
    # Cache the result, unless the caller's own aggregate was left out:
    # that partial view must not be served to other callers.
    if use_cache and not skipped_aggregate:
        api.cache.add(version_string, rspec_xml)

    return rspec_xml
186
187
def CreateSliver(api, xrn, creds, rspec_str, users, call_id):
    """
    Create slivers for slice xrn on every known aggregate from a request RSpec.

    The request (with any slice-manager <statistics> stripped) is sent whole
    to each aggregate. For GENI aggregates that are not SFA, it is first
    converted to a pg request rspec filtered on that aggregate's
    component_manager_id, and users go through sfa_to_pg_users_arg.

    Returns the merged manifest rspec XML, with per-aggregate statistics, or
    "" when call_id was already handled.
    """
    version_manager = VersionManager()

    def _CreateSliver(aggregate, server, xrn, credential, rspec, users, call_id):
        # One CreateSliver call to one aggregate, run through ThreadManager.
        tStart = time.time()
        try:
            # Need to call GetVersion at an aggregate to determine the supported
            # rspec type/format before calling CreateSliver at an Aggregate.
            server_version = api.get_cached_server_version(server)
            requested_users = users
            if 'sfa' not in server_version and 'geni_api' in server_version:
                # sfa aggregates support both sfa and pg rspecs, no need to convert
                # if aggregate supports sfa rspecs. otherwise convert to pg rspec
                rspec = RSpec(RSpecConverter.to_pg_rspec(rspec, 'request'))
                rspec_filter = {'component_manager_id': server_version['urn']}
                rspec.filter(rspec_filter)
                rspec = rspec.toxml()
                requested_users = sfa_to_pg_users_arg(users)
            args = [xrn, credential, rspec, requested_users]
            if _call_id_supported(api, server):
                args.append(call_id)
            rspec = server.CreateSliver(*args)
            return {"aggregate": aggregate, "rspec": rspec, "elapsed": time.time()-tStart, "status": "success"}
        except Exception:
            logger.log_exc('Something wrong in _CreateSliver with URL %s'%server.url)
            return {"aggregate": aggregate, "elapsed": time.time()-tStart, "status": "exception"}

    if Callids().already_handled(call_id): return ""
    # Validate the RSpec against PlanetLab's schema --disabled for now
    # The schema used here needs to aggregate the PL and VINI schemas
    # schema = "/var/www/html/schemas/pl.rng"
    rspec = RSpec(rspec_str)
    schema = None
    if schema:
        rspec.validate(schema)

    # if there is a <statistics> section, the aggregates don't care about it,
    # so delete it.
    drop_slicemgr_stats(rspec)

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()

    # get the callers hrn
    hrn, _ = urn_to_hrn(xrn)
    valid_cred = api.auth.checkCredentials(creds, 'createsliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # the same request goes to every aggregate: serialize it once
    request_xml = rspec.toxml()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        # Just send entire RSpec to each aggregate
        threads.run(_CreateSliver, aggregate, server, xrn, [cred], request_xml, users, call_id)

    results = threads.get_results()
    manifest_version = version_manager._get_version(rspec.version.type, rspec.version.version, 'manifest')
    result_rspec = RSpec(version=manifest_version)
    for result in results:
        add_slicemgr_stat(result_rspec, "CreateSliver", result["aggregate"], result["elapsed"], result["status"])
        if result["status"]=="success":
            try:
                result_rspec.version.merge(result["rspec"])
            except Exception:
                api.logger.log_exc("SM.CreateSliver: Failed to merge aggregate rspec")
    return result_rspec.toxml()
259
def RenewSliver(api, xrn, creds, expiration_time, call_id):
    """
    Ask every aggregate to extend slice xrn's slivers to expiration_time.

    A calling peer SM's own aggregate is skipped. Returns the 'and' of the
    per-aggregate answers (True if no aggregate was contacted). Returns True
    immediately when call_id was already handled.
    """
    def _RenewSliver(server, xrn, creds, expiration_time, call_id):
        # call_id is an optional trailing argument. It is appended only for
        # servers that accept it and must not be part of the base arguments,
        # otherwise it is sent twice, or sent to servers that reject it.
        args = [xrn, creds, expiration_time]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.RenewSliver(*args)

    if Callids().already_handled(call_id): return True

    (hrn, _) = urn_to_hrn(xrn)
    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'renewsliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        threads.run(_RenewSliver, server, xrn, [cred], expiration_time, call_id)
    # 'and' the results
    return reduce(lambda x, y: x and y, threads.get_results(), True)
290
def DeleteSliver(api, xrn, creds, call_id):
    """
    Ask every aggregate to delete slice xrn's slivers.

    A calling peer SM's own aggregate is skipped. The per-aggregate answers
    are waited for but otherwise ignored. Returns 1, or "" when call_id was
    already handled.
    """
    def _DeleteSliver(server, xrn, creds, call_id):
        args = [xrn, creds]
        # call_id is optional: only send it to servers that accept it
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.DeleteSliver(*args)

    if Callids().already_handled(call_id): return ""
    (hrn, _) = urn_to_hrn(xrn)
    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'deletesliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        threads.run(_DeleteSliver, server, xrn, [cred], call_id)
    threads.get_results()
    return 1
320
321
# first draft at a merging SliverStatus
def SliverStatus(api, slice_xrn, creds, call_id):
    """
    Ask every aggregate for the status of slice_xrn's slivers and merge the answers.

    Returns {} when call_id was already handled or when no aggregate
    reported any resource. Otherwise returns a dict with:
    - 'geni_urn': taken from the first non-empty answer;
    - 'pl_login': taken from the first answer that carries one, and only
      if any does;
    - 'geni_resources': the concatenated resources of all answers;
    - 'status': 'ready'.
    """
    def _SliverStatus(server, xrn, creds, call_id):
        args = [xrn, creds]
        # call_id is optional: only send it to servers that accept it
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.SliverStatus(*args)

    if Callids().already_handled(call_id): return {}
    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        threads.run(_SliverStatus, server, slice_xrn, [cred], call_id)
    results = threads.get_results()

    # Get rid of any void result, e.g. {} when call_id was hit. Answers
    # without (or with empty) geni_resources contribute nothing either.
    results = [result for result in results
               if isinstance(result, dict) and result.get('geni_resources')]

    # do not try to combine if there's no result
    if not results: return {}

    # otherwise let's merge stuff
    overall = {}

    # mmh, it is expected that all results carry the same urn
    overall['geni_urn'] = results[0]['geni_urn']
    # pl_login is SFA-specific; other aggregates may not report it
    for result in results:
        if 'pl_login' in result:
            overall['pl_login'] = result['pl_login']
            break
    # append all geni_resources (linear flatten instead of repeated list +)
    overall['geni_resources'] = [resource for result in results
                                 for resource in result['geni_resources']]
    # results were filtered to non-empty resource lists, so this is 'ready'
    overall['status'] = 'ready' if overall['geni_resources'] else 'unknown'

    return overall
363
# Module-wide switch for the result cache used by ListResources and
# ListSlices; the cache is only used when api.cache is set as well.
caching = True
def ListSlices(api, creds, call_id):
    """
    Return the concatenation of the slice lists reported by all aggregates.

    The list is served from, and stored in, api.cache under 'slices' when
    caching is on. A calling peer SM's own aggregate is skipped; such a
    partial list is not cached. Returns [] when call_id was already handled.
    """
    def _ListSlices(server, creds, call_id):
        args = [creds]
        # call_id is optional: only send it to servers that accept it
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.ListSlices(*args)

    if Callids().already_handled(call_id): return []

    use_cache = caching and api.cache

    # look in cache first
    if use_cache:
        slices = api.cache.get('slices')
        if slices:
            return slices

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'listslices', None)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    # set when an aggregate is left out, making the combined list partial
    skipped_aggregate = False
    # fetch from aggregates
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            skipped_aggregate = True
            continue
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        threads.run(_ListSlices, server, [cred], call_id)

    # combine results
    slices = []
    for result in threads.get_results():
        slices.extend(result)

    # Cache the result, unless the caller's own aggregate was left out:
    # that partial list must not be served to other callers.
    if use_cache and not skipped_aggregate:
        api.cache.add('slices', slices)

    return slices
412
413
def get_ticket(api, xrn, creds, rspec, users):
    """
    Get tickets for slice xrn from the aggregates named in rspec and combine
    them into one ticket signed by this slice manager.

    Each <network> child of the rspec's root element names an aggregate (its
    first attribute value is taken as the aggregate hrn). Every such
    aggregate receives the whole client rspec.

    Returns the merged, signed ticket serialized with its parents.
    """
    slice_hrn = urn_to_hrn(xrn)[0]

    # aggregate hrn -> rspec to send it (always the full client rspec)
    request_tree = etree.parse(StringIO(rspec))
    rspec_by_aggregate = dict((network.values()[0], rspec)
                              for network in request_tree.findall('./network'))

    # who is calling us
    checked_cred = api.auth.checkCredentials(creds, 'getticket', slice_hrn)[0]
    caller_hrn = Credential(string=checked_cred).get_gid_caller().get_hrn()

    # prefer a delegated credential, fall back to our own
    cred = api.getDelegatedCredential(creds) or api.getCredential()

    threads = ThreadManager()
    for aggregate_hrn, aggregate_rspec in rspec_by_aggregate.items():
        # never bounce the request back to the caller, unless the caller
        # is the aggregate's SM
        if aggregate_hrn == caller_hrn and aggregate_hrn != api.hrn:
            continue
        server = api.get_server(api.aggregates[aggregate_hrn], cred)
        threads.run(server.GetTicket, xrn, [cred], aggregate_rspec, users)
    ticket_strings = threads.get_results()

    # pull rspecs, initscripts and slivers out of each aggregate ticket;
    # the object gid comes from the first ticket that has one
    rspecs, initscripts, slivers = [], [], []
    object_gid = None
    for ticket_string in ticket_strings:
        agg_ticket = SfaTicket(string=ticket_string)
        agg_attributes = agg_ticket.get_attributes()
        if not object_gid:
            object_gid = agg_ticket.get_gid_object()
        rspecs.append(agg_ticket.get_rspec())
        initscripts.extend(agg_attributes.get('initscripts', []))
        slivers.extend(agg_attributes.get('slivers', []))

    combined_attributes = {'initscripts': initscripts, 'slivers': slivers}
    combined_rspec = merge_rspecs(rspecs)

    # issue and sign the combined ticket
    ticket = SfaTicket(subject=slice_hrn)
    ticket.set_gid_caller(api.auth.client_gid)
    ticket.set_issuer(key=api.key, subject=api.hrn)
    ticket.set_gid_object(object_gid)
    ticket.set_pubkey(object_gid.get_pubkey())
    #new_ticket.set_parent(api.auth.hierarchy.get_auth_ticket(auth_hrn))
    ticket.set_attributes(combined_attributes)
    ticket.set_rspec(combined_rspec)
    ticket.encode()
    ticket.sign()
    return ticket.save_to_string(save_parents=True)
476
def start_slice(api, xrn, creds):
    """
    Ask every aggregate to start slice xrn.

    A calling peer SM's own aggregate is skipped. The per-aggregate answers
    are waited for but otherwise ignored. Always returns 1.
    """
    hrn, _ = urn_to_hrn(xrn)

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'startslice', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        # send the credential as a list, like every other aggregate call here
        threads.run(server.Start, xrn, [cred])
    threads.get_results()
    return 1
499  
def stop_slice(api, xrn, creds):
    """
    Ask every aggregate to stop slice xrn.

    A calling peer SM's own aggregate is skipped. The per-aggregate answers
    are waited for but otherwise ignored. Always returns 1.
    """
    hrn, _ = urn_to_hrn(xrn)

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'stopslice', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        # send the credential as a list, like every other aggregate call here
        threads.run(server.Stop, xrn, [cred])
    threads.get_results()
    return 1
522
def reset_slice(api, xrn):
    """
    Reset a slice: not implemented by the slice manager.

    Does nothing and reports success by returning 1.
    """
    success = 1
    return success
528
def shutdown(api, xrn, creds):
    """
    Shut a slice down: not implemented by the slice manager.

    Does nothing and reports success by returning 1.
    """
    success = 1
    return success
534
def status(api, xrn, creds):
    """
    Report a slice's status: not implemented by the slice manager.

    Does nothing and reports success by returning 1.
    """
    success = 1
    return success
540
def main():
    """
    Ad-hoc manual driver: feed the request RSpec stored in the file named by
    sys.argv[1] to CreateSliver for a fixed test slice.

    NOTE(review): no api object is built here (None is passed), so
    CreateSliver fails as soon as it needs the api (starting with
    api.getDelegatedCredential). This driver only exercises the argument
    plumbing up to that point.
    """
    if len(sys.argv) < 2:
        sys.stderr.write("usage: %s <rspec-file>\n" % sys.argv[0])
        sys.exit(1)
    rspec_file = open(sys.argv[1])
    try:
        rspec_str = rspec_file.read()
    finally:
        rspec_file.close()
    # CreateSliver(api, xrn, creds, rspec_str, users, call_id): it takes the
    # request rspec as an XML string, plus explicit creds and users lists
    CreateSliver(None, 'plc.princeton.tmacktestslice', [], rspec_str, [],
                 'create-slice-tmacktestslice')


if __name__ == "__main__":
    main()
549