fix GetVersion()
[sfa.git] / sfa / managers / slice_manager.py
1 import sys
2 import time
3 from StringIO import StringIO
4 from copy import copy
5 from lxml import etree
6
7 from sfa.trust.sfaticket import SfaTicket
8 from sfa.trust.credential import Credential
9
10 from sfa.util.sfalogging import logger
11 from sfa.util.xrn import Xrn, urn_to_hrn
12 from sfa.util.version import version_core
13 from sfa.util.callids import Callids
14
15 from sfa.server.threadmanager import ThreadManager
16
17 from sfa.rspecs.rspec_converter import RSpecConverter
18 from sfa.rspecs.version_manager import VersionManager
19 from sfa.rspecs.rspec import RSpec 
20 from sfa.client.client_helper import sfa_to_pg_users_arg
21
class SliceManager:
    """
    Slice manager: forwards SFA calls to every aggregate listed in
    api.aggregates (running the per-aggregate calls through a ThreadManager)
    and merges the per-aggregate answers into a single reply.
    """

    def __init__(self):
        # when True, ListResources (advertisement only) and ListSlices results
        # are served from / stored into api.cache when it is available
        self.caching = True

    def _call_id_supported(self, api, server):
        """
        Return True if server supports the optional call_id argument, False otherwise.

        Only servers whose (cached) GetVersion answer contains an 'sfa' key are
        considered.  Their 'code_tag' is parsed as '<major>.<minor>[...]-<rev>'
        and call_id is considered supported for any release later than 1.0-20.
        NOTE(review): the 1.0-20 threshold is inferred from the original
        'rev > 20' test -- confirm against the SFA release notes.
        A missing or unparsable code_tag yields False instead of raising.
        """
        server_version = api.get_cached_server_version(server)
        if not server_version or 'sfa' not in server_version:
            return False
        try:
            code_tag_parts = server_version['code_tag'].split("-")
            version_parts = code_tag_parts[0].split(".")
            major = int(version_parts[0])
            minor = int(version_parts[1])
            rev = int(code_tag_parts[1])
        except (KeyError, IndexError, ValueError, AttributeError):
            return False
        # The previous test (major > 1 and (minor > 0 or rev > 20)) rejected
        # every 1.x release; compare the whole version tuple instead.
        return (major, minor, rev) > (1, 0, 20)

    # we have specialized xmlrpclib.ServerProxy to remember the input url
    # OTOH it's not clear if we're only dealing with XMLRPCServerProxy instances
    def get_serverproxy_url(self, server):
        """
        Return the URL that server talks to.

        Uses server.get_url() when available.  Otherwise it falls back to the
        name-mangled private host/handler attributes of xmlrpclib.ServerProxy.
        """
        try:
            return server.get_url()
        except Exception:
            logger.warning("GetVersion, falling back to xmlrpclib.ServerProxy internals")
            return server._ServerProxy__host + server._ServerProxy__handler

    def _server_url(self, server):
        """
        Best-effort description of server for log messages; never raises.

        Order of preference:
        1. The proxy's 'url' attribute, when present.
        2. get_serverproxy_url().
        3. repr(server).
        """
        url = getattr(server, 'url', None)
        if url is not None:
            return url
        try:
            return self.get_serverproxy_url(server)
        except Exception:
            return repr(server)

    def _get_credential(self, api, creds):
        """
        Return the credential to forward to aggregates.

        This is the caller's delegated credential when one is available,
        otherwise the slice manager's own credential.
        """
        cred = api.getDelegatedCredential(creds)
        if not cred:
            cred = api.getCredential()
        return cred

    def _get_caller_hrn(self, api, creds, privilege, hrn):
        """
        Check creds for privilege on hrn and return the caller's hrn.

        The hrn is taken from the first credential returned by
        api.auth.checkCredentials.
        """
        valid_cred = api.auth.checkCredentials(creds, privilege, hrn)[0]
        return Credential(string=valid_cred).get_gid_caller().get_hrn()

    def _should_skip_aggregate(self, api, caller_hrn, aggregate):
        """
        Return True when a request must not be forwarded to aggregate.

        This prevents an infinite loop: a request is never sent back to its
        caller, unless the caller is the aggregate co-located with this slice
        manager (api.hrn).
        """
        return caller_hrn == aggregate and aggregate != api.hrn

    def GetVersion(self, api):
        """
        Return the GetVersion structure of this slice manager.

        A dict describing this SM is passed to version_core(), and its result is
        returned.  The dict holds:
        - interface, hrn and urn;
        - peers;
        - the request/ad rspec versions known to VersionManager;
        - the default ad rspec ("sfa 1").

        'peers' maps every aggregate in api.aggregates to its URL.  The local
        aggregate (keyed by api.hrn) is added last: its URL has 'localhost'
        replaced by the 'hostname' entry of the version_core() result.
        """
        # peers explicitly in aggregates.xml (the local aggregate is handled below)
        peers = dict((peername, self.get_serverproxy_url(v))
                     for (peername, v) in api.aggregates.iteritems()
                     if peername != api.hrn)
        version_manager = VersionManager()
        ad_rspec_versions = []
        request_rspec_versions = []
        for rspec_version in version_manager.versions:
            if rspec_version.content_type in ['*', 'ad']:
                ad_rspec_versions.append(rspec_version.to_dict())
            if rspec_version.content_type in ['*', 'request']:
                request_rspec_versions.append(rspec_version.to_dict())
        default_rspec_version = version_manager.get_version("sfa 1").to_dict()
        xrn = Xrn(api.hrn, 'authority+sa')
        version_more = {'interface': 'slicemgr',
                        'hrn': xrn.get_hrn(),
                        'urn': xrn.get_urn(),
                        'peers': peers,
                        'request_rspec_versions': request_rspec_versions,
                        'ad_rspec_versions': ad_rspec_versions,
                        'default_ad_rspec': default_rspec_version,
                        }
        sm_version = version_core(version_more)
        # local aggregate if present needs to have localhost resolved
        if api.hrn in api.aggregates:
            local_am_url = self.get_serverproxy_url(api.aggregates[api.hrn])
            sm_version['peers'][api.hrn] = local_am_url.replace('localhost', sm_version['hostname'])
        return sm_version

    def drop_slicemgr_stats(self, rspec):
        """Remove every <statistics> element from rspec (best effort, failures are logged)."""
        try:
            stats_elements = rspec.xml.xpath('//statistics')
            for node in stats_elements:
                node.getparent().remove(node)
        except Exception as e:
            logger.warn("drop_slicemgr_stats failed: %s " % (str(e)))

    def add_slicemgr_stat(self, rspec, callname, aggname, elapsed, status):
        """
        Record one aggregate's timing/status in rspec.

        An <aggregate> element is appended under the <statistics call=callname>
        element, which is created under the root if missing.  This is best
        effort: failures are logged.
        """
        try:
            stats_tags = rspec.xml.xpath('//statistics[@call="%s"]' % callname)
            if stats_tags:
                stats_tag = stats_tags[0]
            else:
                stats_tag = etree.SubElement(rspec.xml.root, "statistics", call=callname)

            etree.SubElement(stats_tag, "aggregate", name=str(aggname), elapsed=str(elapsed), status=str(status))
        except Exception as e:
            logger.warn("add_slicemgr_stat failed on  %s: %s" % (aggname, str(e)))

    def ListResources(self, api, creds, options, call_id):
        """
        Gather resources from the aggregates and merge them into one RSpec.

        Returns an XML string, or "" when call_id was already handled.  The
        RSpec is a manifest when options['geni_slice_urn'] names a slice and an
        advertisement otherwise.  Advertisements are cached in api.cache, keyed
        by the requested rspec version, when self.caching is set.
        Note: 'geni_compressed' is removed from the caller's options dict.
        """
        version_manager = VersionManager()

        def _ListResources(aggregate, server, credential, opts, call_id):
            # work on a copy: the ProtoGENI override below must not leak to other aggregates
            my_opts = copy(opts)
            args = [credential, my_opts]
            tStart = time.time()
            try:
                if self._call_id_supported(api, server):
                    args.append(call_id)
                version = api.get_cached_server_version(server)
                # force ProtoGENI aggregates to give us a v2 RSpec
                if 'sfa' not in version:
                    my_opts['rspec_version'] = version_manager.get_version('ProtoGENI 2').to_dict()
                rspec = server.ListResources(*args)
                return {"aggregate": aggregate, "rspec": rspec, "elapsed": time.time() - tStart, "status": "success"}
            except Exception:
                api.logger.log_exc("ListResources failed at %s" % (self._server_url(server)))
                return {"aggregate": aggregate, "elapsed": time.time() - tStart, "status": "exception"}

        if Callids().already_handled(call_id):
            return ""

        # get slice's hrn from options
        xrn = options.get('geni_slice_urn', '')
        (hrn, _) = urn_to_hrn(xrn)
        options.pop('geni_compressed', None)

        # get the rspec's return format from options
        rspec_version = version_manager.get_version(options.get('rspec_version'))
        version_string = "rspec_%s" % (rspec_version.to_string())

        # look in cache first
        if self.caching and api.cache and not xrn:
            rspec = api.cache.get(version_string)
            if rspec:
                return rspec

        caller_hrn = self._get_caller_hrn(api, creds, 'listnodes', hrn)
        cred = self._get_credential(api, creds)
        threads = ThreadManager()
        for aggregate in api.aggregates:
            if self._should_skip_aggregate(api, caller_hrn, aggregate):
                continue
            # get the rspec from the aggregate
            interface = api.aggregates[aggregate]
            server = api.server_proxy(interface, cred)
            threads.run(_ListResources, aggregate, server, [cred], options, call_id)

        results = threads.get_results()
        content_type = 'manifest' if xrn else 'ad'
        result_version = version_manager._get_version(rspec_version.type, rspec_version.version, content_type)
        rspec = RSpec(version=result_version)
        for result in results:
            self.add_slicemgr_stat(rspec, "ListResources", result["aggregate"], result["elapsed"], result["status"])
            if result["status"] == "success":
                try:
                    rspec.version.merge(result["rspec"])
                except Exception:
                    api.logger.log_exc("SM.ListResources: Failed to merge aggregate rspec")

        rspec_xml = rspec.toxml()
        # cache the result
        if self.caching and api.cache and not xrn:
            api.cache.add(version_string, rspec_xml)
        return rspec_xml

    def CreateSliver(self, api, xrn, creds, rspec_str, users, call_id):
        """
        Send the request RSpec rspec_str to every aggregate and merge the returned manifests.

        Non-SFA (GENI) aggregates receive a ProtoGENI request RSpec filtered to
        their own component_manager_id, and a converted users argument.
        Returns the merged manifest as an XML string, or "" when call_id was
        already handled.
        """
        version_manager = VersionManager()

        def _CreateSliver(aggregate, server, xrn, credential, rspec, users, call_id):
            tStart = time.time()
            try:
                # Need to call GetVersion at an aggregate to determine the supported
                # rspec type/format before calling CreateSliver at an Aggregate.
                server_version = api.get_cached_server_version(server)
                requested_users = users
                if 'sfa' not in server_version and 'geni_api' in server_version:
                    # sfa aggregates support both sfa and pg rspecs, no need to convert
                    # if aggregate supports sfa rspecs. otherwise convert to pg rspec
                    rspec = RSpec(RSpecConverter.to_pg_rspec(rspec, 'request'))
                    rspec_filter = {'component_manager_id': server_version['urn']}
                    rspec.filter(rspec_filter)
                    rspec = rspec.toxml()
                    requested_users = sfa_to_pg_users_arg(users)
                args = [xrn, credential, rspec, requested_users]
                if self._call_id_supported(api, server):
                    args.append(call_id)
                rspec = server.CreateSliver(*args)
                return {"aggregate": aggregate, "rspec": rspec, "elapsed": time.time() - tStart, "status": "success"}
            except Exception:
                logger.log_exc('Something wrong in _CreateSliver with URL %s' % self._server_url(server))
                return {"aggregate": aggregate, "elapsed": time.time() - tStart, "status": "exception"}

        if Callids().already_handled(call_id):
            return ""
        # Validation against PlanetLab's schema (e.g. /var/www/html/schemas/pl.rng)
        # is disabled for now: the schema would need to aggregate the PL and VINI schemas.
        rspec = RSpec(rspec_str)

        # if there is a <statistics> section, the aggregates don't care about it,
        # so delete it.
        self.drop_slicemgr_stats(rspec)

        cred = self._get_credential(api, creds)
        hrn, _ = urn_to_hrn(xrn)
        caller_hrn = self._get_caller_hrn(api, creds, 'createsliver', hrn)
        # Just send entire RSpec to each aggregate (serialized once, it does not change)
        rspec_xml = rspec.toxml()
        threads = ThreadManager()
        for aggregate in api.aggregates:
            if self._should_skip_aggregate(api, caller_hrn, aggregate):
                continue
            interface = api.aggregates[aggregate]
            server = api.server_proxy(interface, cred)
            threads.run(_CreateSliver, aggregate, server, xrn, [cred], rspec_xml, users, call_id)

        results = threads.get_results()
        manifest_version = version_manager._get_version(rspec.version.type, rspec.version.version, 'manifest')
        result_rspec = RSpec(version=manifest_version)
        for result in results:
            self.add_slicemgr_stat(result_rspec, "CreateSliver", result["aggregate"], result["elapsed"], result["status"])
            if result["status"] == "success":
                try:
                    result_rspec.version.merge(result["rspec"])
                except Exception:
                    api.logger.log_exc("SM.CreateSliver: Failed to merge aggregate rspec")
        return result_rspec.toxml()

    def RenewSliver(self, api, xrn, creds, expiration_time, call_id):
        """
        Renew the sliver on every aggregate until expiration_time.

        Returns True only if every aggregate answered with a true value.  Also
        returns True when call_id was already handled.
        """
        def _RenewSliver(server, xrn, creds, expiration_time, call_id):
            # call_id is optional: append it only for servers that support it
            # (it used to be sent unconditionally, and twice for capable servers)
            args = [xrn, creds, expiration_time]
            if self._call_id_supported(api, server):
                args.append(call_id)
            return server.RenewSliver(*args)

        if Callids().already_handled(call_id):
            return True

        (hrn, _) = urn_to_hrn(xrn)
        caller_hrn = self._get_caller_hrn(api, creds, 'renewsliver', hrn)
        cred = self._get_credential(api, creds)
        threads = ThreadManager()
        for aggregate in api.aggregates:
            if self._should_skip_aggregate(api, caller_hrn, aggregate):
                continue
            interface = api.aggregates[aggregate]
            server = api.server_proxy(interface, cred)
            threads.run(_RenewSliver, server, xrn, [cred], expiration_time, call_id)
        # 'and' the results
        return all(threads.get_results())

    def DeleteSliver(self, api, xrn, creds, call_id):
        """
        Delete the sliver on every aggregate.

        Returns 1, or "" when call_id was already handled.
        """
        def _DeleteSliver(server, xrn, creds, call_id):
            args = [xrn, creds]
            if self._call_id_supported(api, server):
                args.append(call_id)
            return server.DeleteSliver(*args)

        if Callids().already_handled(call_id):
            return ""
        (hrn, _) = urn_to_hrn(xrn)
        caller_hrn = self._get_caller_hrn(api, creds, 'deletesliver', hrn)
        cred = self._get_credential(api, creds)
        threads = ThreadManager()
        for aggregate in api.aggregates:
            if self._should_skip_aggregate(api, caller_hrn, aggregate):
                continue
            interface = api.aggregates[aggregate]
            server = api.server_proxy(interface, cred)
            threads.run(_DeleteSliver, server, xrn, [cred], call_id)
        threads.get_results()
        return 1

    @staticmethod
    def _merge_sliver_status(results):
        """
        Merge per-aggregate SliverStatus answers into one dict.

        Falsy answers (e.g. the {} returned when call_id was already handled)
        and answers without geni_resources are ignored; {} is returned if
        nothing remains.  Otherwise the merged dict contains:
        - geni_urn: from the first remaining answer;
        - pl_login: from the first answer that carries one, omitted if none does;
        - geni_resources: all resources, concatenated;
        - status: 'ready' when any resource was collected.
        """
        results = [result for result in results
                   if result and result.get('geni_resources')]
        if not results:
            return {}

        overall = {}
        # mmh, it is expected that all results carry the same urn
        overall['geni_urn'] = results[0]['geni_urn']
        for result in results:
            if 'pl_login' in result:
                overall['pl_login'] = result['pl_login']
                break
        geni_resources = []
        for result in results:
            geni_resources.extend(result['geni_resources'])
        overall['geni_resources'] = geni_resources
        overall['status'] = 'ready' if geni_resources else 'unknown'
        return overall

    # first draft at a merging SliverStatus
    def SliverStatus(self, api, slice_xrn, creds, call_id):
        """
        Query every aggregate for the status of slice_xrn and merge the answers.

        The merge is done by _merge_sliver_status.  Returns {} when call_id was
        already handled or no aggregate reported any resource.
        NOTE(review): unlike the other calls this one neither checks creds
        nor skips the calling aggregate -- confirm that is intended.
        """
        def _SliverStatus(server, xrn, creds, call_id):
            args = [xrn, creds]
            if self._call_id_supported(api, server):
                args.append(call_id)
            return server.SliverStatus(*args)

        if Callids().already_handled(call_id):
            return {}
        cred = self._get_credential(api, creds)
        threads = ThreadManager()
        for aggregate in api.aggregates:
            interface = api.aggregates[aggregate]
            server = api.server_proxy(interface, cred)
            threads.run(_SliverStatus, server, slice_xrn, [cred], call_id)
        return self._merge_sliver_status(threads.get_results())

    def ListSlices(self, api, creds, call_id):
        """
        Return the concatenated slice lists of all aggregates.

        The combined list is cached under 'slices' when self.caching is set and
        api.cache is available.  Returns [] when call_id was already handled.
        """
        def _ListSlices(server, creds, call_id):
            args = [creds]
            if self._call_id_supported(api, server):
                args.append(call_id)
            return server.ListSlices(*args)

        if Callids().already_handled(call_id):
            return []

        # look in cache first
        if self.caching and api.cache:
            slices = api.cache.get('slices')
            if slices:
                return slices

        caller_hrn = self._get_caller_hrn(api, creds, 'listslices', None)
        cred = self._get_credential(api, creds)
        threads = ThreadManager()
        # fetch from aggregates
        for aggregate in api.aggregates:
            if self._should_skip_aggregate(api, caller_hrn, aggregate):
                continue
            interface = api.aggregates[aggregate]
            server = api.server_proxy(interface, cred)
            threads.run(_ListSlices, server, [cred], call_id)

        # combine results
        slices = []
        for result in threads.get_results():
            slices.extend(result)

        # cache the result
        if self.caching and api.cache:
            api.cache.add('slices', slices)

        return slices

    def get_ticket(self, api, xrn, creds, rspec, users):
        """
        Obtain a ticket from each aggregate named in rspec and combine them.

        The aggregates are those named by the rspec's <network> elements.  The
        combined ticket is a new SfaTicket signed by this SM.  It merges the
        rspecs, initscripts and slivers of the aggregate tickets, and takes its
        object gid from the first one.  Returns the ticket serialized with its
        parents.
        NOTE(review): raises AttributeError if no aggregate returned a ticket
        (object_gid and rspec stay None).
        """
        slice_hrn, _ = urn_to_hrn(xrn)
        # get the netspecs contained within the clients rspec
        aggregate_rspecs = {}
        tree = etree.parse(StringIO(rspec))
        elements = tree.findall('./network')
        for element in elements:
            # the first attribute value of <network> is used as the aggregate hrn
            aggregate_hrn = element.values()[0]
            # the whole client rspec is sent to each aggregate
            aggregate_rspecs[aggregate_hrn] = rspec

        caller_hrn = self._get_caller_hrn(api, creds, 'getticket', slice_hrn)
        cred = self._get_credential(api, creds)
        threads = ThreadManager()
        for (aggregate, aggregate_rspec) in aggregate_rspecs.iteritems():
            if self._should_skip_aggregate(api, caller_hrn, aggregate):
                continue
            interface = api.aggregates[aggregate]
            server = api.server_proxy(interface, cred)
            threads.run(server.GetTicket, xrn, [cred], aggregate_rspec, users)

        results = threads.get_results()

        # gather information from each ticket
        rspec = None
        initscripts = []
        slivers = []
        object_gid = None
        for result in results:
            agg_ticket = SfaTicket(string=result)
            attrs = agg_ticket.get_attributes()
            if not object_gid:
                object_gid = agg_ticket.get_gid_object()
            if not rspec:
                rspec = RSpec(agg_ticket.get_rspec())
            else:
                rspec.version.merge(agg_ticket.get_rspec())
            initscripts.extend(attrs.get('initscripts', []))
            slivers.extend(attrs.get('slivers', []))

        # merge info
        attributes = {'initscripts': initscripts,
                      'slivers': slivers}

        # create a new ticket
        ticket = SfaTicket(subject=slice_hrn)
        ticket.set_gid_caller(api.auth.client_gid)
        ticket.set_issuer(key=api.key, subject=api.hrn)
        ticket.set_gid_object(object_gid)
        ticket.set_pubkey(object_gid.get_pubkey())
        ticket.set_attributes(attributes)
        ticket.set_rspec(rspec.toxml())
        ticket.encode()
        ticket.sign()
        return ticket.save_to_string(save_parents=True)

    def start_slice(self, api, xrn, creds):
        """
        Call Start(xrn, cred) on every aggregate; returns 1.

        NOTE(review): the bare credential is passed here, whereas the other
        calls pass [cred] -- confirm the aggregates accept both forms.
        """
        hrn, _ = urn_to_hrn(xrn)
        caller_hrn = self._get_caller_hrn(api, creds, 'startslice', hrn)
        cred = self._get_credential(api, creds)
        threads = ThreadManager()
        for aggregate in api.aggregates:
            if self._should_skip_aggregate(api, caller_hrn, aggregate):
                continue
            interface = api.aggregates[aggregate]
            server = api.server_proxy(interface, cred)
            threads.run(server.Start, xrn, cred)
        threads.get_results()
        return 1

    def stop_slice(self, api, xrn, creds):
        """
        Call Stop(xrn, cred) on every aggregate; returns 1.

        NOTE(review): the bare credential is passed here, whereas the other
        calls pass [cred] -- confirm the aggregates accept both forms.
        """
        hrn, _ = urn_to_hrn(xrn)
        caller_hrn = self._get_caller_hrn(api, creds, 'stopslice', hrn)
        cred = self._get_credential(api, creds)
        threads = ThreadManager()
        for aggregate in api.aggregates:
            if self._should_skip_aggregate(api, caller_hrn, aggregate):
                continue
            interface = api.aggregates[aggregate]
            server = api.server_proxy(interface, cred)
            threads.run(server.Stop, xrn, cred)
        threads.get_results()
        return 1

    def reset_slice(self, api, xrn):
        """
        Not implemented; always returns 1.
        """
        return 1

    def shutdown(self, api, xrn, creds):
        """
        Not implemented; always returns 1.
        """
        return 1

    def status(self, api, xrn, creds):
        """
        Not implemented; always returns 1.
        """
        return 1
537