another step of moving stuff around where it belongs
[sfa.git] / sfa / managers / slice_manager.py
1 import sys
2 import time
3 from StringIO import StringIO
4 from copy import copy
5 from lxml import etree
6
7 from sfa.trust.sfaticket import SfaTicket
8 from sfa.trust.credential import Credential
9
10 from sfa.util.sfalogging import logger
11 from sfa.util.xrn import Xrn, urn_to_hrn
12 from sfa.util.version import version_core
13 from sfa.util.callids import Callids
14
15 from sfa.server.threadmanager import ThreadManager
16
17 from sfa.rspecs.rspec_converter import RSpecConverter
18 from sfa.rspecs.version_manager import VersionManager
19 from sfa.rspecs.rspec import RSpec 
20 from sfa.client.client_helper import sfa_to_pg_users_arg
21
def _call_id_supported(api, server):
    """
    Return True if ``server`` accepts the optional trailing call_id argument.

    The decision is based on the server's cached GetVersion answer, as
    returned by ``api.get_cached_server_version(server)``.  Only servers whose
    answer has an 'sfa' key are considered; their 'code_tag' is expected to
    look like "<major>.<minor>-<rev>" (e.g. "1.0-21"), and call_id is assumed
    to be supported by every release strictly newer than 1.0-20.

    Non-SFA servers, and SFA servers whose code_tag is missing or cannot be
    parsed, are reported as not supporting call_id, so the extra argument is
    never sent to a server that might reject it.
    """
    server_version = api.get_cached_server_version(server)
    if not server_version or 'sfa' not in server_version:
        return False

    try:
        code_tag = server_version['code_tag']
        # "1.0-21" -> release "1.0", rev "21"; anything after a second '-'
        # is ignored
        release, rev = code_tag.split("-")[0:2]
        major, minor = release.split(".")[0:2]
        version = (int(major), int(minor), int(rev))
    except (KeyError, AttributeError, TypeError, ValueError):
        return False

    # Compare the version as a whole: testing "major > 1" first would reject
    # every 1.x release, and also releases such as 2.0-5.
    return version > (1, 0, 20)
39
# We have specialized xmlrpclib.ServerProxy so that it remembers the url it was
# built from and exposes it as get_url(); but it is not clear that we only ever
# deal with such instances, hence the fallback on xmlrpclib internals.
def get_serverproxy_url(server):
    """
    Return the url that the XML-RPC proxy ``server`` talks to.

    Uses ``server.get_url()`` when that works.  Otherwise it logs a warning
    and rebuilds the url from the name-mangled ``xmlrpclib.ServerProxy``
    attributes ``__host`` and ``__handler``.
    """
    try:
        return server.get_url()
    # Deliberately broad (but no longer a bare except, so KeyboardInterrupt
    # and SystemExit propagate): on a plain xmlrpclib.ServerProxy,
    # __getattr__ turns ``get_url`` into a remote method, so the failure
    # shows up as an XML-RPC/socket error rather than an AttributeError.
    except Exception:
        logger.warning("GetVersion, falling back to xmlrpclib.ServerProxy internals")
        return server._ServerProxy__host + server._ServerProxy__handler
48
def GetVersion(api):
    """
    Answer GetVersion on behalf of this slice manager.

    The reply is produced by ``version_core`` from a dict advertising:
      * 'interface'  -- always 'slicemgr'
      * 'hrn'/'urn'  -- this SM's 'authority+sa' Xrn built from ``api.hrn``
      * 'peers'      -- url of each aggregate in ``api.aggregates`` other
                        than ``api.hrn`` itself
      * 'request_rspec_versions' / 'ad_rspec_versions' -- VersionManager's
        versions split by content type ('*' counts for both lists)
      * 'default_ad_rspec' -- the "sfa 1" rspec version
    When ``api.hrn`` is itself listed among the aggregates (a local
    aggregate), its url is then added to the reply's 'peers' with
    'localhost' replaced by the reply's 'hostname' entry.
    """
    # peers declared in aggregates.xml; our own entry is dealt with below
    peers = {}
    for (peer_hrn, proxy) in api.aggregates.iteritems():
        if peer_hrn != api.hrn:
            peers[peer_hrn] = get_serverproxy_url(proxy)

    manager = VersionManager()
    request_versions, ad_versions = [], []
    for candidate in manager.versions:
        if candidate.content_type in ['*', 'ad']:
            ad_versions.append(candidate.to_dict())
        if candidate.content_type in ['*', 'request']:
            request_versions.append(candidate.to_dict())
    default_version = manager.get_version("sfa 1").to_dict()

    authority = Xrn(api.hrn, 'authority+sa')
    sm_version = version_core({
        'interface': 'slicemgr',
        'hrn': authority.get_hrn(),
        'urn': authority.get_urn(),
        'peers': peers,
        'request_rspec_versions': request_versions,
        'ad_rspec_versions': ad_versions,
        'default_ad_rspec': default_version,
    })

    # a local aggregate needs 'localhost' resolved into our real hostname
    if api.hrn in api.aggregates:
        local_url = get_serverproxy_url(api.aggregates[api.hrn])
        sm_version['peers'][api.hrn] = local_url.replace('localhost', sm_version['hostname'])
    return sm_version
77
def drop_slicemgr_stats(rspec):
    """
    Remove, in place, every <statistics> element found in ``rspec.xml``.

    Best effort: any failure is logged as a warning and otherwise ignored.
    """
    try:
        for stats in rspec.xml.xpath('//statistics'):
            stats.getparent().remove(stats)
    except Exception as e:
        logger.warn("drop_slicemgr_stats failed: %s " % (str(e)))
85
def add_slicemgr_stat(rspec, callname, aggname, elapsed, status):
    """
    Record one aggregate's outcome for ``callname`` inside ``rspec``.

    An <aggregate name=... elapsed=... status=...> element is appended to
    the first <statistics call="callname"> element; that element is created
    under ``rspec.xml.root`` when none exists yet.  All attribute values are
    stringified.  Best effort: any failure is logged as a warning and
    otherwise ignored.
    """
    try:
        existing = rspec.xml.xpath('//statistics[@call="%s"]' % callname)
        if not existing:
            stats_tag = etree.SubElement(rspec.xml.root, "statistics", call=callname)
        else:
            stats_tag = existing[0]
        etree.SubElement(stats_tag, "aggregate", name=str(aggname),
                         elapsed=str(elapsed), status=str(status))
    except Exception as e:
        logger.warn("add_slicemgr_stat failed on  %s: %s" % (aggname, str(e)))
97
def ListResources(api, creds, options, call_id):
    """
    Gather resource rspecs from the known aggregates and merge them.

    ``options`` may carry 'geni_slice_urn' (restrict to one slice; the merged
    result then uses the 'manifest' flavour of the requested rspec version
    instead of 'ad') and 'rspec_version' (requested output format).  Any
    'geni_compressed' entry is removed from ``options`` in place before the
    request is forwarded.

    Without a slice urn, and when module-level ``caching`` is on and
    ``api.cache`` is set, a cached copy keyed on the rspec version is returned
    if present, and the freshly merged result is stored there.  Every
    aggregate's outcome is recorded as <statistics> in the merged rspec.
    Returns the merged rspec as an XML string, or "" when ``call_id`` was
    already handled.
    """
    manager = VersionManager()

    def _ListResources(aggregate, server, credential, opts, call_id):
        # each aggregate gets its own copy, since the options may be
        # adjusted per aggregate below
        forwarded = copy(opts)
        call_args = [credential, forwarded]
        started = time.time()
        try:
            if _call_id_supported(api, server):
                call_args.append(call_id)
            server_version = api.get_cached_server_version(server)
            if 'sfa' not in server_version.keys():
                # force ProtoGENI aggregates to give us a v2 RSpec
                forwarded['rspec_version'] = manager.get_version('ProtoGENI 2').to_dict()
            answer = server.ListResources(*call_args)
            return {"aggregate": aggregate, "rspec": answer,
                    "elapsed": time.time() - started, "status": "success"}
        except Exception:
            api.logger.log_exc("ListResources failed at %s" % (server.url))
            return {"aggregate": aggregate,
                    "elapsed": time.time() - started, "status": "exception"}

    if Callids().already_handled(call_id):
        return ""

    # an optional slice urn turns the advertisement into a manifest request
    xrn = options.get('geni_slice_urn', '')
    (hrn, _) = urn_to_hrn(xrn)
    if 'geni_compressed' in options:
        del options['geni_compressed']

    # the cache is keyed on the requested return format
    requested = manager.get_version(options.get('rspec_version'))
    cache_key = "rspec_%s" % (requested.to_string())

    if caching and api.cache and not xrn:
        cached = api.cache.get(cache_key)
        if cached:
            return cached

    # identify the caller from its validated credential
    valid_cred = api.auth.checkCredentials(creds, 'listnodes', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # a credential delegated to us by the caller wins over our own
    cred = api.getDelegatedCredential(creds) or api.getCredential()

    threads = ThreadManager()
    for aggregate in api.aggregates:
        # never bounce the request back to the calling peer (that would
        # loop), unless that aggregate shares this SM's hrn
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = api.server_proxy(api.aggregates[aggregate], cred)
        threads.run(_ListResources, aggregate, server, [cred], options, call_id)
    results = threads.get_results()

    requested = manager.get_version(options.get('rspec_version'))
    flavour = 'manifest' if xrn else 'ad'
    result_version = manager._get_version(requested.type, requested.version, flavour)
    merged = RSpec(version=result_version)
    for result in results:
        add_slicemgr_stat(merged, "ListResources", result["aggregate"], result["elapsed"], result["status"])
        if result["status"] != "success":
            continue
        try:
            merged.version.merge(result["rspec"])
        except:
            api.logger.log_exc("SM.ListResources: Failed to merge aggregate rspec")

    if caching and api.cache and not xrn:
        api.cache.add(cache_key, merged.toxml())

    return merged.toxml()
177
178
def CreateSliver(api, xrn, creds, rspec_str, users, call_id):
    """
    Create slivers for slice ``xrn`` at the known aggregates.

    The whole request rspec ``rspec_str`` is sent to each aggregate, after
    any <statistics> section has been dropped.  An aggregate whose cached
    GetVersion answer has no 'sfa' key but has 'geni_api' gets two
    conversions:
      * the rspec is converted to a ProtoGENI request rspec and filtered on
        component_manager_id == the aggregate's urn;
      * ``users`` is converted with sfa_to_pg_users_arg.
    The calling peer is skipped unless it shares this SM's hrn.

    The manifests returned are merged into one manifest rspec (of the
    request's rspec type/version).  That rspec is annotated with
    per-aggregate <statistics>.

    Returns the merged manifest as an XML string, or "" when ``call_id`` was
    already handled.
    """
    version_manager = VersionManager()

    def _CreateSliver(aggregate, server, xrn, credential, rspec, users, call_id):
        tStart = time.time()
        try:
            # the aggregate's GetVersion answer tells us which rspec
            # type/format it supports before calling CreateSliver
            server_version = api.get_cached_server_version(server)
            requested_users = users
            if 'sfa' not in server_version and 'geni_api' in server_version:
                # sfa aggregates support both sfa and pg rspecs, so only
                # non-sfa geni aggregates need a pg rspec and pg users
                pg_rspec = RSpec(RSpecConverter.to_pg_rspec(rspec, 'request'))
                pg_rspec.filter({'component_manager_id': server_version['urn']})
                rspec = pg_rspec.toxml()
                requested_users = sfa_to_pg_users_arg(users)
            args = [xrn, credential, rspec, requested_users]
            if _call_id_supported(api, server):
                args.append(call_id)
            manifest = server.CreateSliver(*args)
            return {"aggregate": aggregate, "rspec": manifest,
                    "elapsed": time.time() - tStart, "status": "success"}
        except Exception:
            # not a bare except: SystemExit/KeyboardInterrupt must propagate
            logger.log_exc('Something wrong in _CreateSliver with URL %s' % server.url)
            return {"aggregate": aggregate,
                    "elapsed": time.time() - tStart, "status": "exception"}

    if Callids().already_handled(call_id):
        return ""

    # Validating the RSpec against a schema (e.g. /var/www/html/schemas/pl.rng)
    # is disabled for now: the schema would need to aggregate the PL and VINI
    # schemas.
    rspec = RSpec(rspec_str)

    # if there is a <statistics> section, the aggregates don't care about it
    drop_slicemgr_stats(rspec)

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()

    # get the callers hrn
    hrn, _ = urn_to_hrn(xrn)
    valid_cred = api.auth.checkCredentials(creds, 'createsliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # every aggregate receives the same request: serialize it only once
    rspec_xml = rspec.toxml()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop: don't send the request back to the caller,
        # unless that aggregate shares this SM's hrn
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = api.server_proxy(api.aggregates[aggregate], cred)
        threads.run(_CreateSliver, aggregate, server, xrn, [cred], rspec_xml, users, call_id)

    results = threads.get_results()
    manifest_version = version_manager._get_version(rspec.version.type, rspec.version.version, 'manifest')
    result_rspec = RSpec(version=manifest_version)
    for result in results:
        add_slicemgr_stat(result_rspec, "CreateSliver", result["aggregate"], result["elapsed"], result["status"])
        if result["status"] == "success":
            try:
                result_rspec.version.merge(result["rspec"])
            except Exception:
                api.logger.log_exc("SM.CreateSliver: Failed to merge aggregate rspec")
    return result_rspec.toxml()
250
def RenewSliver(api, xrn, creds, expiration_time, call_id):
    """
    Renew slice ``xrn`` until ``expiration_time`` at the known aggregates.

    The calling peer is skipped unless it shares this SM's hrn.  Returns
    True when every collected answer is truthy (including when there is no
    answer at all) and False otherwise.  Also returns True when ``call_id``
    was already handled.
    """
    def _RenewSliver(server, xrn, creds, expiration_time, call_id):
        args = [xrn, creds, expiration_time]
        # call_id is optional: append it exactly once, and only for servers
        # known to accept it
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.RenewSliver(*args)

    if Callids().already_handled(call_id):
        return True

    (hrn, _) = urn_to_hrn(xrn)
    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'renewsliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop: don't send the request back to the caller,
        # unless that aggregate shares this SM's hrn
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = api.server_proxy(api.aggregates[aggregate], cred)
        threads.run(_RenewSliver, server, xrn, [cred], expiration_time, call_id)
    # 'and' the results
    return all(threads.get_results())
281
def DeleteSliver(api, xrn, creds, call_id):
    """
    Delete the slivers of slice ``xrn`` at the known aggregates.

    The calling peer is skipped unless it shares this SM's hrn.  Returns 1
    after collecting the aggregates' answers (which are not inspected), or ""
    when ``call_id`` was already handled.
    """
    def _DeleteSliver(server, xrn, creds, call_id):
        args = [xrn, creds]
        # call_id is optional; only send it to servers known to accept it
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.DeleteSliver(*args)

    if Callids().already_handled(call_id):
        return ""
    (hrn, _) = urn_to_hrn(xrn)
    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'deletesliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop: don't send the request back to the caller,
        # unless that aggregate shares this SM's hrn
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = api.server_proxy(api.aggregates[aggregate], cred)
        threads.run(_DeleteSliver, server, xrn, [cred], call_id)
    threads.get_results()
    return 1
311
312
# first draft at a merging SliverStatus
def SliverStatus(api, slice_xrn, creds, call_id):
    """
    Query the known aggregates for the status of ``slice_xrn`` and merge
    the answers.

    Answers that are empty or carry no (or empty) 'geni_resources' are
    ignored.  The merged dict is built as follows:
      * 'geni_urn' comes from the first remaining answer;
      * 'pl_login' comes from the first answer that provides one, and is
        omitted when none does;
      * 'geni_resources' is the concatenation of all answers' resources;
      * 'status' is 'ready' (or 'unknown' if no resource remains).

    Returns {} when ``call_id`` was already handled or when no aggregate
    returned resources.
    """
    def _SliverStatus(server, xrn, creds, call_id):
        args = [xrn, creds]
        # call_id is optional; only send it to servers known to accept it
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.SliverStatus(*args)

    if Callids().already_handled(call_id):
        return {}
    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    # NOTE(review): unlike the other calls in this module, no credential
    # check and no "don't send back to the caller" filtering happen here
    threads = ThreadManager()
    for aggregate in api.aggregates:
        server = api.server_proxy(api.aggregates[aggregate], cred)
        threads.run(_SliverStatus, server, slice_xrn, [cred], call_id)

    # get rid of any void result - e.g. when call_id was hit where by
    # convention we return {} - and of answers without resources
    results = [result for result in threads.get_results()
               if result and result.get('geni_resources')]

    # do not try to combine if there's no result
    if not results:
        return {}

    # all results are expected to carry the same urn; trust the first one
    overall = {'geni_urn': results[0]['geni_urn']}
    # pl_login is only sent by some aggregates: take the first one available
    for result in results:
        if 'pl_login' in result:
            overall['pl_login'] = result['pl_login']
            break
    resources = []
    for result in results:
        resources.extend(result['geni_resources'])
    overall['geni_resources'] = resources
    overall['status'] = 'ready' if resources else 'unknown'
    return overall
354
# Module-wide switch read by ListResources (for requests without a slice urn)
# and ListSlices: when True and api.cache is set, their merged results are
# served from, and stored into, api.cache.  Set to False to disable this.
caching=True
def ListSlices(api, creds, call_id):
    """
    Return the slices reported by the known aggregates, concatenated.

    The calling peer is skipped unless it shares this SM's hrn.  When
    module-level ``caching`` is on and ``api.cache`` is set:
      * a non-empty list cached under 'slices' is returned directly;
      * otherwise the freshly gathered list is stored there.
    Returns [] when ``call_id`` was already handled.
    """
    def _ListSlices(server, creds, call_id):
        args = [creds]
        # call_id is optional; only send it to servers known to accept it
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.ListSlices(*args)

    if Callids().already_handled(call_id):
        return []

    # look in cache first
    if caching and api.cache:
        slices = api.cache.get('slices')
        if slices:
            return slices

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'listslices', None)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    # fetch from aggregates
    for aggregate in api.aggregates:
        # prevent infinite loop: don't send the request back to the caller,
        # unless that aggregate shares this SM's hrn
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = api.server_proxy(api.aggregates[aggregate], cred)
        threads.run(_ListSlices, server, [cred], call_id)

    # combine results
    slices = []
    for result in threads.get_results():
        slices.extend(result)

    # cache the result
    if caching and api.cache:
        api.cache.add('slices', slices)

    return slices
403
404
def get_ticket(api, xrn, creds, rspec, users):
    """
    Obtain tickets for slice ``xrn`` from the aggregates named in ``rspec``
    and combine them into a single ticket issued by this slice manager.

    Each <network> element directly under the rspec root designates an
    aggregate (its first attribute value is used as the aggregate hrn).  The
    whole ``rspec`` string is sent to each such aggregate with GetTicket.
    Networks that are not in ``api.aggregates`` are skipped with a warning,
    and the calling peer is skipped unless it shares this SM's hrn.

    The combined ticket carries:
      * the merged rspec;
      * the concatenated 'initscripts' and 'slivers' attributes;
      * the gid object of the first ticket providing one.
    It is signed with ``api.key``.

    Raises RuntimeError when no aggregate ticket provided a gid object,
    since there is then nothing to build the combined ticket from.
    Returns the new ticket serialized with save_to_string(save_parents=True).
    """
    slice_hrn, _ = urn_to_hrn(xrn)
    # get the netspecs contained within the clients rspec
    # NOTE(review): the first attribute value is assumed to be the network
    # name / aggregate hrn -- confirm against the rspec schema
    aggregate_rspecs = {}
    tree = etree.parse(StringIO(rspec))
    for element in tree.findall('./network'):
        aggregate_hrn = element.values()[0]
        # the whole rspec is forwarded, not just this network's part
        aggregate_rspecs[aggregate_hrn] = rspec

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'getticket', slice_hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    for (aggregate, aggregate_rspec) in aggregate_rspecs.items():
        # prevent infinite loop: don't send the request back to the caller,
        # unless that aggregate shares this SM's hrn
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        if aggregate not in api.aggregates:
            # skip networks we have no aggregate for rather than aborting
            # the whole call with a KeyError
            logger.warning("get_ticket: no known aggregate for network %s, skipping" % aggregate)
            continue
        server = api.server_proxy(api.aggregates[aggregate], cred)
        threads.run(server.GetTicket, xrn, [cred], aggregate_rspec, users)

    results = threads.get_results()

    # gather information from each ticket
    merged_rspec = None
    initscripts = []
    slivers = []
    object_gid = None
    for result in results:
        agg_ticket = SfaTicket(string=result)
        attrs = agg_ticket.get_attributes()
        if not object_gid:
            object_gid = agg_ticket.get_gid_object()
        if not merged_rspec:
            merged_rspec = RSpec(agg_ticket.get_rspec())
        else:
            merged_rspec.version.merge(agg_ticket.get_rspec())
        initscripts.extend(attrs.get('initscripts', []))
        slivers.extend(attrs.get('slivers', []))

    if object_gid is None or merged_rspec is None:
        # without any aggregate ticket there is no gid, pubkey or rspec to
        # build the combined ticket from
        raise RuntimeError("get_ticket: no aggregate ticket obtained for %s" % slice_hrn)

    # merge info
    attributes = {'initscripts': initscripts,
                  'slivers': slivers}

    # create a new ticket
    ticket = SfaTicket(subject=slice_hrn)
    ticket.set_gid_caller(api.auth.client_gid)
    ticket.set_issuer(key=api.key, subject=api.hrn)
    ticket.set_gid_object(object_gid)
    ticket.set_pubkey(object_gid.get_pubkey())
    ticket.set_attributes(attributes)
    ticket.set_rspec(merged_rspec.toxml())
    ticket.encode()
    ticket.sign()
    return ticket.save_to_string(save_parents=True)
469
def start_slice(api, xrn, creds):
    """
    Forward a Start request for slice ``xrn`` to the known aggregates.

    The calling peer is skipped unless it shares this SM's hrn.  Always
    returns 1; the aggregates' answers are collected but not inspected.
    """
    slice_hrn, _ = urn_to_hrn(xrn)

    # identify the caller from its validated credential
    caller_cred = api.auth.checkCredentials(creds, 'startslice', slice_hrn)[0]
    caller_hrn = Credential(string=caller_cred).get_gid_caller().get_hrn()

    # a credential delegated to us by the caller wins over our own
    cred = api.getDelegatedCredential(creds) or api.getCredential()

    pool = ThreadManager()
    # never bounce the request back to the calling peer (that would loop)
    targets = [name for name in api.aggregates
               if caller_hrn != name or name == api.hrn]
    for name in targets:
        proxy = api.server_proxy(api.aggregates[name], cred)
        pool.run(proxy.Start, xrn, cred)
    pool.get_results()
    return 1
492  
def stop_slice(api, xrn, creds):
    """
    Forward a Stop request for slice ``xrn`` to the known aggregates.

    The calling peer is skipped unless it shares this SM's hrn.  Always
    returns 1; the aggregates' answers are collected but not inspected.
    """
    hrn, _ = urn_to_hrn(xrn)

    checked = api.auth.checkCredentials(creds, 'stopslice', hrn)
    caller_hrn = Credential(string=checked[0]).get_gid_caller().get_hrn()

    # prefer a credential delegated by the caller over our own
    cred = api.getDelegatedCredential(creds) or api.getCredential()

    pool = ThreadManager()
    for agg_hrn in api.aggregates:
        is_caller = caller_hrn == agg_hrn
        if is_caller and agg_hrn != api.hrn:
            # don't bounce the request back to the peer that sent it
            continue
        pool.run(api.server_proxy(api.aggregates[agg_hrn], cred).Stop, xrn, cred)
    pool.get_results()
    return 1
515
def reset_slice(api, xrn):
    """
    Reset slice ``xrn`` -- not implemented by the slice manager.

    Nothing is forwarded to the aggregates; 1 is returned unconditionally.
    """
    return 1
521
def shutdown(api, xrn, creds):
    """
    Shut down slice ``xrn`` -- not implemented by the slice manager.

    Nothing is forwarded to the aggregates and ``creds`` are not checked;
    1 is returned unconditionally.
    """
    return 1
527
def status(api, xrn, creds):
    """
    Report the status of slice ``xrn`` -- not implemented by the slice manager.

    Nothing is forwarded to the aggregates and ``creds`` are not checked;
    1 is returned unconditionally.
    """
    return 1
533
# This module has no standalone entry point.  The old main() sketch was broken
# (it called CreateSliver with a 4-argument list, while CreateSliver takes 6)
# and only survived as commented-out code, so calling the undefined main()
# here could only raise NameError.  Exit with an explicit message instead.
if __name__ == "__main__":
    sys.exit("slice_manager is a library module and cannot be run directly")
543