Merge branch 'master' of ssh://git.onelab.eu/git/sfa
[sfa.git] / sfa / managers / slice_manager.py
1 import sys
2 import time
3 import traceback
4 from StringIO import StringIO
5 from copy import copy
6 from lxml import etree
7
8 from sfa.trust.sfaticket import SfaTicket
9 from sfa.trust.credential import Credential
10
11 from sfa.util.sfalogging import logger
12 from sfa.util.xrn import Xrn, urn_to_hrn
13 from sfa.util.version import version_core
14 from sfa.util.callids import Callids
15 from sfa.server.threadmanager import ThreadManager
16 from sfa.rspecs.rspec_converter import RSpecConverter
17 from sfa.rspecs.version_manager import VersionManager
18 from sfa.rspecs.rspec import RSpec 
19 from sfa.client.client_helper import sfa_to_pg_users_arg
20 from sfa.client.return_value import ReturnValue
21
class SliceManager:
    """
    Slice Manager (SM) implementation of the SFA/GENI slice calls.

    Each call is fanned out in parallel (through ThreadManager) to the
    aggregates listed in api.aggregates, and the per-aggregate answers are
    merged into a single reply for the caller.
    """

    def __init__(self):
        # when True, ListResources (without a slice) and ListSlices results
        # are served from / stored into api.cache, if the api provides one
        self.caching = True

    def _options_supported(self, api, server):
        """
        Return True if `server` accepts the trailing `options` argument.

        Only SFA servers (whose cached GetVersion answer has an 'sfa' key)
        can qualify. Their 'code_tag' is parsed as "<major>.<minor>[...][-...]",
        and options are supported from 1.2 onwards. A missing or unparsable
        code_tag yields False instead of raising.
        """
        server_version = api.get_cached_server_version(server)
        if not server_version or 'sfa' not in server_version:
            return False
        try:
            release = str(server_version['code_tag']).split('-')[0]
            major, minor = [int(part) for part in release.split('.')[0:2]]
        except (KeyError, ValueError, TypeError):
            return False
        # compare as a tuple so that e.g. 2.0 counts as newer than 1.2
        return (major, minor) >= (1, 2)

    def _caller_hrn(self, api, creds, operation, hrn):
        """
        Check `creds` for `operation` on `hrn` through api.auth.checkCredentials
        and return the hrn of the gid caller of the first credential it returns.
        """
        valid_cred = api.auth.checkCredentials(creds, operation, hrn)[0]
        return Credential(string=valid_cred).get_gid_caller().get_hrn()

    def _get_credential(self, api, creds):
        """
        Return the delegated credential found in `creds`, falling back to the
        SM's own credential when there is none.
        """
        cred = api.getDelegatedCredential(creds)
        if not cred:
            cred = api.getCredential()
        return cred

    def _skip_aggregate(self, api, caller_hrn, aggregate):
        """
        Return True when a request must not be forwarded to `aggregate`.

        This prevents infinite loops: never send a request back to its caller,
        unless the caller is this SM's own aggregate (api.hrn).
        """
        return caller_hrn == aggregate and aggregate != api.hrn

    def GetVersion(self, api, options=None):
        """
        Return this SM's GetVersion answer, as built by version_core().

        'peers' maps every configured aggregate hrn except our own to its URL.
        The local aggregate, if configured, is then added with 'localhost'
        replaced by our hostname. `options` is accepted for API symmetry and
        is not used.
        """
        # remote peers: every configured aggregate except ourselves
        peers = dict([(peername, interface.get_url())
                      for (peername, interface) in api.aggregates.items()
                      if peername != api.hrn])
        version_manager = VersionManager()
        ad_rspec_versions = []
        request_rspec_versions = []
        for rspec_version in version_manager.versions:
            if rspec_version.content_type in ['*', 'ad']:
                ad_rspec_versions.append(rspec_version.to_dict())
            if rspec_version.content_type in ['*', 'request']:
                request_rspec_versions.append(rspec_version.to_dict())
        xrn = Xrn(api.hrn, 'authority+sa')
        version_more = {'interface': 'slicemgr',
                        'sfa': 1,
                        'geni_api': api.config.SFA_AGGREGATE_API_VERSION,
                        'hrn': xrn.get_hrn(),
                        'urn': xrn.get_urn(),
                        'peers': peers,
                        'geni_request_rspec_versions': request_rspec_versions,
                        'geni_ad_rspec_versions': ad_rspec_versions,
                        }
        sm_version = version_core(version_more)
        # local aggregate if present needs to have localhost resolved
        if api.hrn in api.aggregates:
            local_am_url = api.aggregates[api.hrn].get_url()
            sm_version['peers'][api.hrn] = local_am_url.replace('localhost', sm_version['hostname'])
        return sm_version

    def drop_slicemgr_stats(self, rspec):
        """
        Remove every <statistics> element from `rspec` in place.

        Best effort: failures are logged as warnings and never raised.
        """
        try:
            stats_elements = rspec.xml.xpath('//statistics')
            for node in stats_elements:
                node.getparent().remove(node)
        except Exception:
            e = sys.exc_info()[1]
            logger.warn("drop_slicemgr_stats failed: %s " % (str(e)))

    def add_slicemgr_stat(self, rspec, callname, aggname, elapsed, status, exc_info=None):
        """
        Record one aggregate's outcome in the <statistics call=callname> element
        of `rspec`, creating that element under the root if needed.

        :param exc_info: optional (type, value, traceback) triple. When given, an
            <exc_info> child holds the exception and one <tb_frame> per frame.
        Best effort: failures are logged as warnings and never raised.
        """
        try:
            stats_tags = rspec.xml.xpath('//statistics[@call="%s"]' % callname)
            if stats_tags:
                stats_tag = stats_tags[0]
            else:
                stats_tag = rspec.xml.root.add_element("statistics", call=callname)

            stat_tag = stats_tag.add_element("aggregate", name=str(aggname),
                                             elapsed=str(elapsed), status=str(status))

            if exc_info:
                exc_tag = stat_tag.add_element("exc_info", name=str(exc_info[1]))
                # formats the traceback as a set of xml elements
                for item in traceback.extract_tb(exc_info[2]):
                    exc_tag.add_element("tb_frame", filename=str(item[0]), line=str(item[1]),
                                        func=str(item[2]), code=str(item[3]))
        except Exception:
            e = sys.exc_info()[1]
            logger.warn("add_slicemgr_stat failed on  %s: %s" % (aggname, str(e)))

    def ListResources(self, api, creds, options=None):
        """
        Merge the ListResources answers of all aggregates into one RSpec.

        :param options: dict; reads 'call_id', 'geni_slice_urn' and
            'geni_rspec_version'. If 'geni_slice_urn' is set, the result is a
            manifest; otherwise it is an advertisement. 'geni_compressed' is
            removed from this dict in place.
        :returns: the merged RSpec as XML, or "" if call_id was already handled.
        """
        if options is None:
            options = {}
        version_manager = VersionManager()

        def _ListResources(aggregate, server, credential, opts):
            # copy so that the per-aggregate 'rspec_version' does not leak
            # into the options dict shared by all threads
            my_opts = copy(opts)
            args = [credential, my_opts]
            tStart = time.time()
            try:
                version = api.get_cached_server_version(server)
                # force ProtoGENI aggregates to give us a v2 RSpec
                if 'sfa' in version:
                    my_opts['rspec_version'] = version_manager.get_version('SFA 1').to_dict()
                else:
                    my_opts['rspec_version'] = version_manager.get_version('ProtoGENI 2').to_dict()
                rspec = server.ListResources(*args)
                return {"aggregate": aggregate, "rspec": rspec,
                        "elapsed": time.time() - tStart, "status": "success"}
            except Exception:
                api.logger.log_exc("ListResources failed at %s" % (server.url))
                return {"aggregate": aggregate, "elapsed": time.time() - tStart,
                        "status": "exception", "exc_info": sys.exc_info()}

        call_id = options.get('call_id')
        if Callids().already_handled(call_id):
            return ""

        # get slice's hrn from options
        xrn = options.get('geni_slice_urn', '')
        hrn, _ = urn_to_hrn(xrn)
        # do not forward the compression request to the aggregates.
        # NOTE(review): this also removes the key from the caller's dict --
        # confirm the caller does not rely on it afterwards
        options.pop('geni_compressed', None)

        # get the rspec's return format from options
        rspec_version = version_manager.get_version(options.get('geni_rspec_version'))
        version_string = "rspec_%s" % (rspec_version)

        # only the full advertisement (no slice given) is cached
        use_cache = self.caching and api.cache and not xrn
        if use_cache:
            rspec = api.cache.get(version_string)
            if rspec:
                return rspec

        caller_hrn = self._caller_hrn(api, creds, 'listnodes', hrn)
        cred = self._get_credential(api, creds)

        threads = ThreadManager()
        for aggregate in api.aggregates:
            if self._skip_aggregate(api, caller_hrn, aggregate):
                continue
            interface = api.aggregates[aggregate]
            server = api.server_proxy(interface, cred)
            threads.run(_ListResources, aggregate, server, [cred], options)

        results = threads.get_results()
        if xrn:
            content_type = 'manifest'
        else:
            content_type = 'ad'
        result_version = version_manager._get_version(rspec_version.type, rspec_version.version, content_type)
        rspec = RSpec(version=result_version)
        for result in results:
            self.add_slicemgr_stat(rspec, "ListResources", result["aggregate"], result["elapsed"],
                                   result["status"], result.get("exc_info", None))
            if result["status"] == "success":
                try:
                    rspec.version.merge(ReturnValue.get_value(result["rspec"]))
                except Exception:
                    api.logger.log_exc("SM.ListResources: Failed to merge aggregate rspec")

        rspec_xml = rspec.toxml()
        if use_cache:
            api.cache.add(version_string, rspec_xml)
        return rspec_xml

    def CreateSliver(self, api, xrn, creds, rspec_str, users, options=None):
        """
        Send the request RSpec to every aggregate and merge the returned manifests.

        Non-SFA GENI aggregates receive the RSpec converted to ProtoGENI format,
        filtered to their own component_manager_id, with `users` converted by
        sfa_to_pg_users_arg.
        :returns: the merged manifest as XML, or "" if call_id was already handled.
        """
        if options is None:
            options = {}
        version_manager = VersionManager()

        def _CreateSliver(aggregate, server, xrn, credential, rspec, users, options):
            tStart = time.time()
            try:
                # Need to call GetVersion at an aggregate to determine the supported
                # rspec type/format before calling CreateSliver at an Aggregate.
                server_version = api.get_cached_server_version(server)
                requested_users = users
                if 'sfa' not in server_version and 'geni_api' in server_version:
                    # sfa aggregates support both sfa and pg rspecs, no need to convert
                    # if aggregate supports sfa rspecs. otherwise convert to pg rspec
                    rspec = RSpec(RSpecConverter.to_pg_rspec(rspec, 'request'))
                    rspec_filter = {'component_manager_id': server_version['urn']}
                    rspec.filter(rspec_filter)
                    rspec = rspec.toxml()
                    requested_users = sfa_to_pg_users_arg(users)
                args = [xrn, credential, rspec, requested_users]
                if self._options_supported(api, server):
                    args.append(options)
                rspec = server.CreateSliver(*args)
                return {"aggregate": aggregate, "rspec": rspec,
                        "elapsed": time.time() - tStart, "status": "success"}
            except Exception:
                logger.log_exc('Something wrong in _CreateSliver with URL %s' % server.url)
                return {"aggregate": aggregate, "elapsed": time.time() - tStart,
                        "status": "exception", "exc_info": sys.exc_info()}

        call_id = options.get('call_id')
        if Callids().already_handled(call_id):
            return ""
        # validation of the RSpec against a schema is currently disabled
        rspec = RSpec(rspec_str)

        # if there is a <statistics> section, the aggregates don't care about it,
        # so delete it.
        self.drop_slicemgr_stats(rspec)

        cred = self._get_credential(api, creds)
        hrn, _ = urn_to_hrn(xrn)
        caller_hrn = self._caller_hrn(api, creds, 'createsliver', hrn)

        # Just send entire RSpec to each aggregate
        rspec_xml = rspec.toxml()
        threads = ThreadManager()
        for aggregate in api.aggregates:
            if self._skip_aggregate(api, caller_hrn, aggregate):
                continue
            interface = api.aggregates[aggregate]
            server = api.server_proxy(interface, cred)
            # forward the options dict (which carries call_id), not call_id alone
            threads.run(_CreateSliver, aggregate, server, xrn, [cred], rspec_xml, users, options)

        results = threads.get_results()
        manifest_version = version_manager._get_version(rspec.version.type, rspec.version.version, 'manifest')
        result_rspec = RSpec(version=manifest_version)
        for result in results:
            self.add_slicemgr_stat(result_rspec, "CreateSliver", result["aggregate"], result["elapsed"],
                                   result["status"], result.get("exc_info", None))
            if result["status"] == "success":
                try:
                    result_rspec.version.merge(ReturnValue.get_value(result["rspec"]))
                except Exception:
                    api.logger.log_exc("SM.CreateSliver: Failed to merge aggregate rspec")
        return result_rspec.toxml()

    def RenewSliver(self, api, xrn, creds, expiration_time, options=None):
        """
        Forward RenewSliver to every aggregate and 'and' their answers.

        :returns: the first falsy aggregate answer, else the last answer
            (True when no aggregate was queried or call_id was already handled).
        """
        if options is None:
            options = {}

        def _RenewSliver(server, xrn, creds, expiration_time, options):
            args = [xrn, creds, expiration_time]
            if self._options_supported(api, server):
                args.append(options)
            return server.RenewSliver(*args)

        call_id = options.get('call_id')
        if Callids().already_handled(call_id):
            return True

        hrn, _ = urn_to_hrn(xrn)
        caller_hrn = self._caller_hrn(api, creds, 'renewsliver', hrn)
        cred = self._get_credential(api, creds)

        threads = ThreadManager()
        for aggregate in api.aggregates:
            if self._skip_aggregate(api, caller_hrn, aggregate):
                continue
            interface = api.aggregates[aggregate]
            server = api.server_proxy(interface, cred)
            threads.run(_RenewSliver, server, xrn, [cred], expiration_time, options)

        # 'and' the results
        results = [ReturnValue.get_value(result) for result in threads.get_results()]
        overall = True
        for value in results:
            overall = overall and value
        return overall

    def DeleteSliver(self, api, xrn, creds, options=None):
        """
        Forward DeleteSliver to every aggregate and wait for all of them.

        :returns: 1, or "" if call_id was already handled. Per-aggregate
            answers are not inspected.
        """
        if options is None:
            options = {}

        def _DeleteSliver(server, xrn, creds, options):
            args = [xrn, creds]
            if self._options_supported(api, server):
                args.append(options)
            return server.DeleteSliver(*args)

        call_id = options.get('call_id')
        if Callids().already_handled(call_id):
            return ""
        hrn, _ = urn_to_hrn(xrn)
        caller_hrn = self._caller_hrn(api, creds, 'deletesliver', hrn)
        cred = self._get_credential(api, creds)

        threads = ThreadManager()
        for aggregate in api.aggregates:
            if self._skip_aggregate(api, caller_hrn, aggregate):
                continue
            interface = api.aggregates[aggregate]
            server = api.server_proxy(interface, cred)
            threads.run(_DeleteSliver, server, xrn, [cred], options)
        threads.get_results()
        return 1

    # first draft at a merging SliverStatus
    def SliverStatus(self, api, slice_xrn, creds, options=None):
        """
        Query SliverStatus on every aggregate and merge the non-empty answers.

        The merged dict takes 'geni_urn' (and 'pl_login', when present) from
        the first answer, concatenates all 'geni_resources', and sets 'status'.
        NOTE: unlike the other calls, there is no credential check or caller
        loop-prevention here.
        :returns: the merged dict, or {} if there is nothing to report.
        """
        if options is None:
            options = {}

        def _SliverStatus(server, xrn, creds, options):
            args = [xrn, creds]
            if self._options_supported(api, server):
                args.append(options)
            return server.SliverStatus(*args)

        call_id = options.get('call_id')
        if Callids().already_handled(call_id):
            return {}
        cred = self._get_credential(api, creds)

        threads = ThreadManager()
        for aggregate in api.aggregates:
            interface = api.aggregates[aggregate]
            server = api.server_proxy(interface, cred)
            threads.run(_SliverStatus, server, slice_xrn, [cred], options)
        results = [ReturnValue.get_value(result) for result in threads.get_results()]

        # get rid of any void result - e.g. when call_id was hit where by convention we return {}
        results = [result for result in results if result and result.get('geni_resources')]

        # do not try to combine if there's no result
        if not results:
            return {}

        # mmh, it is expected that all results carry the same urn
        first = results[0]
        overall = {'geni_urn': first['geni_urn']}
        # pl_login is only reported by some aggregates; don't fail without it
        if 'pl_login' in first:
            overall['pl_login'] = first['pl_login']
        # append all geni_resources
        resources = []
        for result in results:
            resources.extend(result['geni_resources'])
        overall['geni_resources'] = resources
        if resources:
            overall['status'] = 'ready'
        else:
            overall['status'] = 'unknown'
        return overall

    def ListSlices(self, api, creds, options=None):
        """
        Return the concatenated slice lists of all aggregates.

        :returns: a list of slices ([] if call_id was already handled); served
            from / stored into api.cache under 'slices' when caching is on.
        """
        if options is None:
            options = {}

        def _ListSlices(server, creds, options):
            args = [creds]
            if self._options_supported(api, server):
                args.append(options)
            return server.ListSlices(*args)

        call_id = options.get('call_id')
        if Callids().already_handled(call_id):
            return []

        # look in cache first
        if self.caching and api.cache:
            slices = api.cache.get('slices')
            if slices:
                return slices

        caller_hrn = self._caller_hrn(api, creds, 'listslices', None)
        cred = self._get_credential(api, creds)

        threads = ThreadManager()
        # fetch from aggregates
        for aggregate in api.aggregates:
            if self._skip_aggregate(api, caller_hrn, aggregate):
                continue
            interface = api.aggregates[aggregate]
            server = api.server_proxy(interface, cred)
            threads.run(_ListSlices, server, [cred], options)

        # combine results, ignoring aggregates that returned nothing
        slices = []
        for result in threads.get_results():
            value = ReturnValue.get_value(result)
            if value:
                slices.extend(value)

        # cache the result
        if self.caching and api.cache:
            api.cache.add('slices', slices)

        return slices

    def GetTicket(self, api, xrn, creds, rspec, users, options=None):
        """
        Get a ticket from each aggregate named by a top-level <network> element
        of `rspec`, and combine them into one ticket signed by this SM.

        Each such aggregate receives the full `rspec`. The combined ticket
        carries the merged RSpec plus the concatenated 'initscripts' and
        'slivers' attributes of all tickets.
        NOTE(review): a <network> not listed in api.aggregates raises KeyError,
        and getting no ticket at all fails on object_gid being None.
        :returns: the signed ticket as a string.
        """
        if options is None:
            options = {}
        slice_hrn, _ = urn_to_hrn(xrn)
        # get the netspecs contained within the clients rspec
        aggregate_rspecs = {}
        tree = etree.parse(StringIO(rspec))
        for element in tree.findall('./network'):
            aggregate_hrn = element.values()[0]
            aggregate_rspecs[aggregate_hrn] = rspec

        caller_hrn = self._caller_hrn(api, creds, 'getticket', slice_hrn)
        cred = self._get_credential(api, creds)

        threads = ThreadManager()
        for (aggregate, aggregate_rspec) in aggregate_rspecs.items():
            if self._skip_aggregate(api, caller_hrn, aggregate):
                continue
            interface = api.aggregates[aggregate]
            server = api.server_proxy(interface, cred)
            threads.run(server.GetTicket, xrn, [cred], aggregate_rspec, users, options)

        results = threads.get_results()

        # gather information from each ticket
        rspec = None
        initscripts = []
        slivers = []
        object_gid = None
        for result in results:
            agg_ticket = SfaTicket(string=result)
            attrs = agg_ticket.get_attributes()
            if not object_gid:
                object_gid = agg_ticket.get_gid_object()
            if not rspec:
                rspec = RSpec(agg_ticket.get_rspec())
            else:
                rspec.version.merge(agg_ticket.get_rspec())
            initscripts.extend(attrs.get('initscripts', []))
            slivers.extend(attrs.get('slivers', []))

        # merge info
        attributes = {'initscripts': initscripts,
                      'slivers': slivers}

        # create a new ticket
        ticket = SfaTicket(subject=slice_hrn)
        ticket.set_gid_caller(api.auth.client_gid)
        ticket.set_issuer(key=api.key, subject=api.hrn)
        ticket.set_gid_object(object_gid)
        ticket.set_pubkey(object_gid.get_pubkey())
        ticket.set_attributes(attributes)
        ticket.set_rspec(rspec.toxml())
        ticket.encode()
        ticket.sign()
        return ticket.save_to_string(save_parents=True)

    def start_slice(self, api, xrn, creds):
        """
        Call Start on every aggregate (except the caller) and wait for all.

        NOTE(review): passes a single credential where the other calls pass a
        list -- confirm what the aggregates' Start expects.
        :returns: 1 (per-aggregate answers are not inspected).
        """
        hrn, _ = urn_to_hrn(xrn)
        caller_hrn = self._caller_hrn(api, creds, 'startslice', hrn)
        cred = self._get_credential(api, creds)

        threads = ThreadManager()
        for aggregate in api.aggregates:
            if self._skip_aggregate(api, caller_hrn, aggregate):
                continue
            interface = api.aggregates[aggregate]
            server = api.server_proxy(interface, cred)
            threads.run(server.Start, xrn, cred)
        threads.get_results()
        return 1

    def stop_slice(self, api, xrn, creds):
        """
        Call Stop on every aggregate (except the caller) and wait for all.

        NOTE(review): passes a single credential where the other calls pass a
        list -- confirm what the aggregates' Stop expects.
        :returns: 1 (per-aggregate answers are not inspected).
        """
        hrn, _ = urn_to_hrn(xrn)
        caller_hrn = self._caller_hrn(api, creds, 'stopslice', hrn)
        cred = self._get_credential(api, creds)

        threads = ThreadManager()
        for aggregate in api.aggregates:
            if self._skip_aggregate(api, caller_hrn, aggregate):
                continue
            interface = api.aggregates[aggregate]
            server = api.server_proxy(interface, cred)
            threads.run(server.Stop, xrn, cred)
        threads.get_results()
        return 1

    def reset_slice(self, api, xrn):
        """
        Not implemented; always returns 1.
        """
        return 1

    def shutdown(self, api, xrn, creds):
        """
        Not implemented; always returns 1.
        """
        return 1

    def status(self, api, xrn, creds):
        """
        Not implemented; always returns 1.
        """
        return 1
545