add exception info to statistics section
[sfa.git] / sfa / managers / slice_manager.py
1 import sys
2 import traceback
3 import time
4 from StringIO import StringIO
5 from copy import copy
6 from lxml import etree
7
8 from sfa.trust.sfaticket import SfaTicket
9 from sfa.trust.credential import Credential
10
11 from sfa.util.sfalogging import logger
12 from sfa.util.xrn import Xrn, urn_to_hrn
13 from sfa.util.threadmanager import ThreadManager
14 from sfa.util.version import version_core
15 from sfa.util.callids import Callids
16
17 from sfa.rspecs.rspec_converter import RSpecConverter
18 from sfa.rspecs.version_manager import VersionManager
19 from sfa.rspecs.rspec import RSpec 
20 from sfa.client.client_helper import sfa_to_pg_users_arg
21
22 def _call_id_supported(api, server):
23     """
24     Returns true if server support the optional call_id arg, false otherwise.
25     """
26     server_version = api.get_cached_server_version(server)
27
28     if 'sfa' in server_version:
29         code_tag = server_version['code_tag']
30         code_tag_parts = code_tag.split("-")
31
32         version_parts = code_tag_parts[0].split(".")
33         major, minor = version_parts[0:2]
34         rev = code_tag_parts[1]
35         if int(major) > 1:
36             if int(minor) > 0 or int(rev) > 20:
37                 return True
38     return False
39
# we have specialized xmlrpclib.ServerProxy to remember the input url
# OTOH it's not clear if we're only dealing with XMLRPCServerProxy instances
def get_serverproxy_url(server):
    """
    Return the URL that `server` (a server proxy) was created for.

    Proxies that remember their URL expose get_url(). For anything else,
    fall back on xmlrpclib.ServerProxy's name-mangled private attributes
    (host + handler).

    On a plain xmlrpclib.ServerProxy, attribute access yields a
    remote-method stub, so ``server.get_url()`` does not raise
    AttributeError. Instead it attempts an XML-RPC call that fails with a
    fault or transport error. That is why the fallback catches Exception
    rather than AttributeError alone, while letting KeyboardInterrupt and
    SystemExit propagate (the previous bare ``except:`` swallowed them).
    """
    try:
        return server.get_url()
    except Exception:
        logger.warning("GetVersion, falling back to xmlrpclib.ServerProxy internals")
        return server._ServerProxy__host + server._ServerProxy__handler
48
def GetVersion(api):
    """
    Build this slice manager's GetVersion reply.

    The reply is version_core() extended with:
    - the slice-manager interface name, hrn and urn;
    - the URLs of the peer aggregates configured in api.aggregates
      (aggregates.xml);
    - the request and advertisement rspec versions known to VersionManager,
      with "sfa 1" as the default advertisement version.

    A co-located aggregate (one registered under our own hrn) is reported
    with 'localhost' in its URL replaced by our hostname.
    """
    # every configured aggregate except ourselves, mapped to its URL
    peers = dict((peer_name, get_serverproxy_url(proxy))
                 for (peer_name, proxy) in api.aggregates.iteritems()
                 if peer_name != api.hrn)

    manager = VersionManager()
    ad_versions = []
    request_versions = []
    for candidate in manager.versions:
        kind = candidate.content_type
        if kind in ('*', 'ad'):
            ad_versions.append(candidate.to_dict())
        if kind in ('*', 'request'):
            request_versions.append(candidate.to_dict())
    default_ad = manager.get_version("sfa 1").to_dict()

    sm_xrn = Xrn(api.hrn, 'authority+sa')
    extra_fields = {
        'interface': 'slicemgr',
        'hrn': sm_xrn.get_hrn(),
        'urn': sm_xrn.get_urn(),
        'peers': peers,
        'request_rspec_versions': request_versions,
        'ad_rspec_versions': ad_versions,
        'default_ad_rspec': default_ad,
    }
    reply = version_core(extra_fields)

    if api.hrn in api.aggregates:
        local_url = get_serverproxy_url(api.aggregates[api.hrn])
        reply['peers'][api.hrn] = local_url.replace('localhost', reply['hostname'])
    return reply
77
def drop_slicemgr_stats(rspec):
    """
    Remove every <statistics> element from `rspec` in place.

    Any failure is logged as a warning and otherwise ignored.
    """
    try:
        for stats_node in rspec.xml.xpath('//statistics'):
            stats_node.getparent().remove(stats_node)
    except Exception as e:
        logger.warn("drop_slicemgr_stats failed: %s " % (str(e)))
85
def add_slicemgr_stat(rspec, callname, aggname, elapsed, status, exc_info=None):
    """
    Record the outcome of one aggregate call in the rspec's statistics section.

    Looks up the <statistics call="`callname`"> element, creating it under
    the rspec root if needed. It then appends an
    <aggregate name=.. elapsed=.. status=..> child for `aggname`.

    When `exc_info` (a sys.exc_info() triple) is given, an <exc_info> child
    is added. Its name attribute holds the exception text, and it gets one
    <tb_frame filename=.. line=.. func=.. code=..> child per traceback entry.

    Any error while building the elements is logged and swallowed, so a
    statistics problem never breaks the call being reported on.
    """
    try:
        stats_tags = rspec.xml.xpath('//statistics[@call="%s"]' % callname)
        if stats_tags:
            stats_tag = stats_tags[0]
        else:
            stats_tag = etree.SubElement(rspec.xml.root, "statistics", call=callname)

        stat_tag = etree.SubElement(stats_tag, "aggregate", name=str(aggname),
                                    elapsed=str(elapsed), status=str(status))

        if exc_info:
            exc_tag = etree.SubElement(stat_tag, "exc_info", name=str(exc_info[1]))
            # The traceback is encoded as one element per frame rather than
            # as a single text block, so it stays machine-readable.
            for (filename, lineno, funcname, code) in traceback.extract_tb(exc_info[2]):
                etree.SubElement(exc_tag, "tb_frame", filename=str(filename),
                                 line=str(lineno), func=str(funcname), code=str(code))

    except Exception as e:
        logger.warn("add_slicemgr_stat failed on  %s: %s" % (aggname, str(e)))
109
def ListResources(api, creds, options, call_id):
    """
    Fan a ListResources call out to every known aggregate and merge the replies.

    :param api: the slice manager api object
    :param creds: caller credentials
    :param options: ListResources options. 'geni_slice_urn' selects a slice
        (manifest) listing; 'rspec_version' selects the returned rspec
        format. A 'geni_compressed' key is removed from this dict in place.
    :param call_id: used to detect and ignore repeated calls
    :returns: the merged rspec as an XML string, carrying a
        <statistics call="ListResources"> section with per-aggregate timing
        and status; "" if call_id was already handled. Non-slice listings
        may be served from, and are stored into, api.cache when the
        module-level `caching` flag is set.
    """
    version_manager = VersionManager()

    def _ListResources(aggregate, server, credential, opts, call_id):
        # Runs in a worker thread. Failures are reported in the returned
        # dict (status "exception" + exc_info) instead of being raised.
        my_opts = copy(opts)
        args = [credential, my_opts]
        tStart = time.time()
        try:
            if _call_id_supported(api, server):
                args.append(call_id)
            version = api.get_cached_server_version(server)
            # force ProtoGENI aggregates to give us a v2 RSpec
            if 'sfa' not in version:
                my_opts['rspec_version'] = version_manager.get_version('ProtoGENI 2').to_dict()
            rspec = server.ListResources(*args)
            return {"aggregate": aggregate, "rspec": rspec, "elapsed": time.time()-tStart, "status": "success"}
        except Exception:
            # capture before logging so the recorded traceback is the original one
            exc_info = sys.exc_info()
            api.logger.log_exc("ListResources failed at %s" %(server.url))
            return {"aggregate": aggregate, "elapsed": time.time()-tStart, "status": "exception", "exc_info": exc_info}

    if Callids().already_handled(call_id): return ""

    # get slice's hrn from options
    xrn = options.get('geni_slice_urn', '')
    (hrn, _) = urn_to_hrn(xrn)
    if 'geni_compressed' in options:
        del(options['geni_compressed'])

    # get the rspec's return format from options. This value is reused when
    # building the result: workers only modify their own copy of options.
    rspec_version = version_manager.get_version(options.get('rspec_version'))
    version_string = "rspec_%s" % (rspec_version.to_string())

    # look in cache first
    if caching and api.cache and not xrn:
        rspec = api.cache.get(version_string)
        if rspec:
            return rspec

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'listnodes', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue

        # get the rspec from the aggregate
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        threads.run(_ListResources, aggregate, server, [cred], options, call_id)

    results = threads.get_results()
    # a slice listing yields a manifest, otherwise an advertisement
    content_type = 'manifest' if xrn else 'ad'
    result_version = version_manager._get_version(rspec_version.type, rspec_version.version, content_type)
    rspec = RSpec(version=result_version)
    for result in results:
        add_slicemgr_stat(rspec, "ListResources", result["aggregate"], result["elapsed"], result["status"], result.get("exc_info",None))
        if result["status"]=="success":
            try:
                rspec.version.merge(result["rspec"])
            except Exception:
                api.logger.log_exc("SM.ListResources: Failed to merge aggregate rspec")

    rspec_xml = rspec.toxml()
    # cache the result
    if caching and api.cache and not xrn:
        api.cache.add(version_string, rspec_xml)

    return rspec_xml
189
190
def CreateSliver(api, xrn, creds, rspec_str, users, call_id):
    """
    Create slivers for slice `xrn` at every known aggregate and merge the manifests.

    The request rspec has its <statistics> section stripped and is sent
    as-is to each aggregate. The exception is aggregates whose cached
    version has 'geni_api' but no 'sfa' key. For those, the rspec is
    converted to a ProtoGENI request rspec filtered on the aggregate's
    component manager urn, and `users` is converted with
    sfa_to_pg_users_arg.

    :returns: the merged manifest rspec as an XML string, carrying a
        <statistics call="CreateSliver"> section; "" if call_id was
        already handled.
    """
    version_manager = VersionManager()

    def _CreateSliver(aggregate, server, xrn, credential, rspec, users, call_id):
        # Runs in a worker thread. Failures are reported in the returned
        # dict (status "exception" + exc_info) instead of being raised.
        tStart = time.time()
        try:
            # Need to call GetVersion at an aggregate to determine the supported
            # rspec type/format before calling CreateSliver at an Aggregate.
            server_version = api.get_cached_server_version(server)
            requested_users = users
            if 'sfa' not in server_version and 'geni_api' in server_version:
                # sfa aggregates support both sfa and pg rspecs, no need to convert
                # if aggregate supports sfa rspecs. otherwise convert to pg rspec
                rspec = RSpec(RSpecConverter.to_pg_rspec(rspec, 'request'))
                cm_filter = {'component_manager_id': server_version['urn']}
                rspec.filter(cm_filter)
                rspec = rspec.toxml()
                requested_users = sfa_to_pg_users_arg(users)
            args = [xrn, credential, rspec, requested_users]
            if _call_id_supported(api, server):
                args.append(call_id)
            rspec = server.CreateSliver(*args)
            return {"aggregate": aggregate, "rspec": rspec, "elapsed": time.time()-tStart, "status": "success"}
        except Exception:
            # capture before logging so the recorded traceback is the original one
            exc_info = sys.exc_info()
            logger.log_exc('Something wrong in _CreateSliver with URL %s'%server.url)
            return {"aggregate": aggregate, "elapsed": time.time()-tStart, "status": "exception", "exc_info": exc_info}

    if Callids().already_handled(call_id): return ""
    # Validate the RSpec against PlanetLab's schema --disabled for now
    # The schema used here needs to aggregate the PL and VINI schemas
    # schema = "/var/www/html/schemas/pl.rng"
    rspec = RSpec(rspec_str)
#    schema = None
#    if schema:
#        rspec.validate(schema)

    # if there is a <statistics> section, the aggregates don't care about it,
    # so delete it.
    drop_slicemgr_stats(rspec)

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()

    # get the callers hrn
    hrn, _ = urn_to_hrn(xrn)
    valid_cred = api.auth.checkCredentials(creds, 'createsliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # every aggregate gets the same request: serialize it once, not per aggregate
    rspec_xml = rspec.toxml()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        # Just send entire RSpec to each aggregate
        threads.run(_CreateSliver, aggregate, server, xrn, [cred], rspec_xml, users, call_id)

    results = threads.get_results()
    manifest_version = version_manager._get_version(rspec.version.type, rspec.version.version, 'manifest')
    result_rspec = RSpec(version=manifest_version)
    for result in results:
        add_slicemgr_stat(result_rspec, "CreateSliver", result["aggregate"], result["elapsed"], result["status"], result.get("exc_info",None))
        if result["status"]=="success":
            try:
                result_rspec.version.merge(result["rspec"])
            except Exception:
                api.logger.log_exc("SM.CreateSliver: Failed to merge aggregate rspec")
    return result_rspec.toxml()
262
def RenewSliver(api, xrn, creds, expiration_time, call_id):
    """
    Ask every known aggregate to renew the slivers of slice `xrn`.

    :param expiration_time: new expiration time, passed through unchanged
    :returns: the per-aggregate results combined with ``and`` (True if there
        are none); True right away if call_id was already handled.
    """
    def _RenewSliver(server, xrn, creds, expiration_time, call_id):
        # call_id is an optional trailing argument. Append it only for
        # servers that accept it, and only once (it used to be both
        # hard-coded in args and appended again).
        args = [xrn, creds, expiration_time]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.RenewSliver(*args)

    if Callids().already_handled(call_id): return True

    (hrn, _) = urn_to_hrn(xrn)
    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'renewsliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        threads.run(_RenewSliver, server, xrn, [cred], expiration_time, call_id)
    # 'and' the results
    return reduce(lambda x, y: x and y, threads.get_results(), True)
293
def DeleteSliver(api, xrn, creds, call_id):
    """
    Ask every known aggregate to delete the slivers of slice `xrn`.

    Per-aggregate results are collected but not inspected.

    :returns: 1; "" if call_id was already handled.
    """
    def _DeleteSliver(server, xrn, creds, call_id):
        # _call_id_supported() looks up the (cached) server version itself
        args = [xrn, creds]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.DeleteSliver(*args)

    if Callids().already_handled(call_id): return ""
    (hrn, _) = urn_to_hrn(xrn)
    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'deletesliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        threads.run(_DeleteSliver, server, xrn, [cred], call_id)
    threads.get_results()
    return 1
323
324
# first draft at a merging SliverStatus
def SliverStatus(api, slice_xrn, creds, call_id):
    """
    Query SliverStatus at every known aggregate and merge the replies.

    :returns: {} when call_id was already handled or no aggregate reported
        any resources. Otherwise a dict with:
        - 'geni_urn' taken from the first non-empty reply;
        - 'pl_login' taken from the first reply carrying one (omitted if none);
        - 'geni_resources', the concatenation of all replies' resources;
        - 'status'.
    """
    def _SliverStatus(server, xrn, creds, call_id):
        # _call_id_supported() looks up the (cached) server version itself
        args = [xrn, creds]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.SliverStatus(*args)

    if Callids().already_handled(call_id): return {}
    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        threads.run(_SliverStatus, server, slice_xrn, [cred], call_id)
    results = threads.get_results()

    # get rid of any void result - e.g. when call_id was hit where by convention we return {}
    # (.get: a reply without 'geni_resources' is skipped instead of raising KeyError)
    results = [result for result in results if result and result.get('geni_resources')]

    # do not try to combine if there's no result
    if not results:
        return {}

    # otherwise let's merge stuff
    overall = {}

    # mmh, it is expected that all results carry the same urn
    overall['geni_urn'] = results[0]['geni_urn']
    # Take 'pl_login' from the first reply that has one. Replies lacking it
    # (presumably from non-PlanetLab aggregates -- TODO confirm) used to
    # make the whole call fail with KeyError when they came first.
    for result in results:
        if 'pl_login' in result:
            overall['pl_login'] = result['pl_login']
            break

    # append all geni_resources (linear, unlike repeated list concatenation)
    all_resources = []
    for result in results:
        all_resources.extend(result['geni_resources'])
    overall['geni_resources'] = all_resources

    # NOTE(review): every remaining result has non-empty geni_resources, so
    # this is always 'ready'; per-resource status is not examined yet.
    overall['status'] = 'ready' if overall['geni_resources'] else 'unknown'

    return overall
366
# Module-wide cache switch. While this is truthy (and api.cache is set),
# ListResources (for non-slice listings) and ListSlices read results from
# api.cache and store fresh results back into it. Set to False to disable.
caching = True
def ListSlices(api, creds, call_id):
    """
    Return the concatenated ListSlices replies of the known aggregates.

    When the module-level `caching` flag is set and api.cache is available,
    a non-empty cached 'slices' entry is returned directly. A freshly built
    list is stored back under that key.

    :returns: list of slices; [] if call_id was already handled.
    """
    def _ListSlices(server, creds, call_id):
        # _call_id_supported() looks up the (cached) server version itself
        args = [creds]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.ListSlices(*args)

    if Callids().already_handled(call_id): return []

    # look in cache first
    if caching and api.cache:
        slices = api.cache.get('slices')
        if slices:
            return slices

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'listslices', None)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    cred = api.getDelegatedCredential(creds)
    if not cred:
        cred = api.getCredential()
    threads = ThreadManager()
    # fetch from aggregates
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        interface = api.aggregates[aggregate]
        server = api.get_server(interface, cred)
        threads.run(_ListSlices, server, [cred], call_id)

    # combine results
    results = threads.get_results()
    slices = []
    for result in results:
        slices.extend(result)

    # cache the result
    if caching and api.cache:
        api.cache.add('slices', slices)

    return slices
415
416
def get_ticket(api, xrn, creds, rspec, users):
    """
    Collect GetTicket replies from the aggregates named in `rspec` and
    combine them into a single ticket issued and signed by this slice manager.

    Each <network> child of the rspec root names an aggregate (the element's
    first attribute value), and that aggregate receives the whole client
    rspec. The combined ticket carries:
    - the first non-empty object gid found;
    - the merged rspecs;
    - the concatenated 'initscripts' and 'slivers' attributes.

    Returns the ticket saved as a string, including its parents.
    """
    slice_hrn, _ = urn_to_hrn(xrn)

    # one entry per <network> element; every aggregate gets the full rspec
    per_aggregate = {}
    for network in etree.parse(StringIO(rspec)).findall('./network'):
        per_aggregate[network.values()[0]] = rspec

    # identify the caller from its validated credential
    checked_cred = api.auth.checkCredentials(creds, 'getticket', slice_hrn)[0]
    caller_hrn = Credential(string=checked_cred).get_gid_caller().get_hrn()

    # prefer a delegated credential, fall back to our own
    cred = api.getDelegatedCredential(creds) or api.getCredential()

    workers = ThreadManager()
    for aggregate, aggregate_rspec in per_aggregate.iteritems():
        # never bounce the request back to the caller (infinite loop),
        # unless the caller is this SM's own aggregate
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        proxy = api.get_server(api.aggregates[aggregate], cred)
        workers.run(proxy.GetTicket, xrn, [cred], aggregate_rspec, users)

    # fold every aggregate ticket into the combined state
    merged_rspec = None
    object_gid = None
    initscripts = []
    slivers = []
    for ticket_string in workers.get_results():
        agg_ticket = SfaTicket(string=ticket_string)
        agg_attributes = agg_ticket.get_attributes()
        if not object_gid:
            object_gid = agg_ticket.get_gid_object()
        if merged_rspec:
            merged_rspec.version.merge(agg_ticket.get_rspec())
        else:
            merged_rspec = RSpec(agg_ticket.get_rspec())
        initscripts.extend(agg_attributes.get('initscripts', []))
        slivers.extend(agg_attributes.get('slivers', []))

    # build, encode and sign the combined ticket (encode must precede sign)
    ticket = SfaTicket(subject = slice_hrn)
    ticket.set_gid_caller(api.auth.client_gid)
    ticket.set_issuer(key=api.key, subject=api.hrn)
    ticket.set_gid_object(object_gid)
    ticket.set_pubkey(object_gid.get_pubkey())
    #new_ticket.set_parent(api.auth.hierarchy.get_auth_ticket(auth_hrn))
    ticket.set_attributes({'initscripts': initscripts, 'slivers': slivers})
    ticket.set_rspec(merged_rspec.toxml())
    ticket.encode()
    ticket.sign()
    return ticket.save_to_string(save_parents=True)
481
def start_slice(api, xrn, creds):
    """
    Forward a Start request for slice `xrn` to every known aggregate.

    The per-aggregate results are collected but not inspected; returns 1.
    """
    slice_hrn, _ = urn_to_hrn(xrn)

    # identify the caller from its validated credential
    checked_cred = api.auth.checkCredentials(creds, 'startslice', slice_hrn)[0]
    caller_hrn = Credential(string=checked_cred).get_gid_caller().get_hrn()

    # prefer a delegated credential, fall back to our own
    cred = api.getDelegatedCredential(creds) or api.getCredential()

    workers = ThreadManager()
    for aggregate in api.aggregates:
        # never bounce the request back to the caller (infinite loop),
        # unless the caller is this SM's own aggregate
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        proxy = api.get_server(api.aggregates[aggregate], cred)
        # NOTE(review): Start receives the bare credential, whereas the other
        # fan-out calls in this module pass [cred] -- confirm this is intended.
        workers.run(proxy.Start, xrn, cred)
    workers.get_results()
    return 1
504  
def stop_slice(api, xrn, creds):
    """
    Forward a Stop request for slice `xrn` to every known aggregate.

    The per-aggregate results are collected but ignored; returns 1.
    """
    target_hrn, _ = urn_to_hrn(xrn)

    validated = api.auth.checkCredentials(creds, 'stopslice', target_hrn)[0]
    caller_hrn = Credential(string=validated).get_gid_caller().get_hrn()

    cred = api.getDelegatedCredential(creds)
    if not cred:
        # no delegated credential: act with the slice manager's own
        cred = api.getCredential()

    pool = ThreadManager()
    # skip the caller's aggregate (loop prevention) unless it is our own
    targets = [name for name in api.aggregates
               if not (caller_hrn == name and name != api.hrn)]
    for name in targets:
        proxy = api.get_server(api.aggregates[name], cred)
        # NOTE(review): Stop receives the bare credential, not [cred] -- confirm.
        pool.run(proxy.Stop, xrn, cred)
    pool.get_results()
    return 1
527
def reset_slice(api, xrn):
    """
    Placeholder: resetting a slice is not implemented.

    Does nothing and always reports success (1).
    """
    return 1
533
def shutdown(api, xrn, creds):
    """
    Placeholder: shutting a slice down is not implemented.

    Does nothing and always reports success (1).
    """
    return 1
539
def status(api, xrn, creds):
    """
    Placeholder: slice status reporting is not implemented.

    Does nothing and always returns 1.
    """
    return 1
545
# The standalone test driver that used to live here was broken: among other
# things, it called CreateSliver with four arguments while CreateSliver takes
# six. It had been commented out, leaving the guard below calling an
# undefined main(). This module has no command-line entry point, so exit
# with a clear message instead of a NameError.
if __name__ == "__main__":
    sys.exit("slice_manager.py has no command-line entry point")
555