bbf8eda8d1a4d17399dcbd24a6c8898bbe9c8954
[sfa.git] / sfa / managers / slice_manager_pl.py
1 #
2 import sys
3 import time,datetime
4 from StringIO import StringIO
5 from types import StringTypes
6 from copy import deepcopy
7 from copy import copy
8 from lxml import etree
9
10 from sfa.util.sfalogging import logger
11 from sfa.util.rspecHelper import merge_rspecs
12 from sfa.util.xrn import Xrn, urn_to_hrn, hrn_to_urn
13 from sfa.util.plxrn import hrn_to_pl_slicename
14 from sfa.util.rspec import *
15 from sfa.util.specdict import *
16 from sfa.util.faults import *
17 from sfa.util.record import SfaRecord
18 from sfa.rspecs.rspec_converter import RSpecConverter
19 from sfa.rspecs.version_manager import VersionManager
20 from sfa.rspecs.rspec import RSpec 
21 from sfa.util.policy import Policy
22 from sfa.util.prefixTree import prefixTree
23 from sfa.util.sfaticket import *
24 from sfa.trust.credential import Credential
25 from sfa.util.threadmanager import ThreadManager
26 import sfa.util.xmlrpcprotocol as xmlrpcprotocol     
27 import sfa.plc.peers as peers
28 from sfa.util.version import version_core
29 from sfa.util.callids import Callids
30
31
32 def _call_id_supported(api, server):
33     """
34     Returns true if server support the optional call_id arg, false otherwise.
35     """
36     server_version = api.get_cached_server_version(server)
37
38     if 'sfa' in server_version:
39         code_tag = server_version['code_tag']
40         code_tag_parts = code_tag.split("-")
41
42         version_parts = code_tag_parts[0].split(".")
43         major, minor = version_parts[0:2]
44         rev = code_tag_parts[1]
45         if int(major) > 1:
46             if int(minor) > 0 or int(rev) > 20:
47                 return True
48     return False
49
# We have specialized xmlrpclib.ServerProxy to remember the input url.
# OTOH it's not clear whether we only ever deal with XMLRPCServerProxy
# instances, so fall back to plain xmlrpclib.ServerProxy internals.
def get_serverproxy_url (server):
    """
    Return the URL a server proxy was created for.

    Prefers the ``url`` attribute set by the specialized proxy class. When
    the object has no such attribute, the value is rebuilt from the
    name-mangled ``xmlrpclib.ServerProxy`` private fields (host + handler).
    """
    try:
        return server.url
    except AttributeError:
        # Only a missing attribute should trigger the fallback. A bare
        # ``except:`` would also swallow KeyboardInterrupt/SystemExit.
        logger.warning("GetVersion, falling back to xmlrpclib.ServerProxy internals")
        return server._ServerProxy__host + server._ServerProxy__handler
58
def GetVersion(api):
    """
    Build the slice manager's GetVersion answer.

    ``version_core`` is fed a dict describing this slice manager:
      - its hrn/urn;
      - the peer aggregates from ``api.aggregates`` (name -> url);
      - the RSpec versions known to ``VersionManager``, split into
        advertisement and request lists;
      - the default ("sfa 1") advertisement RSpec version.

    If this slice manager's own hrn is among the aggregates, it is then
    added to the result's ``peers`` with ``localhost`` replaced by the
    reported hostname.
    """
    # peers explicitly in aggregates.xml (the local aggregate is handled below)
    peers = dict([(peername, get_serverproxy_url(v))
                  for (peername, v) in api.aggregates.iteritems()
                  if peername != api.hrn])
    version_manager = VersionManager()
    ad_rspec_versions = []
    request_rspec_versions = []
    for rspec_version in version_manager.versions:
        # The version object itself was previously compared to strings,
        # which never matched.
        # NOTE(review): assumes version entries expose a ``content_type``
        # of 'ad', 'request' or '*' -- confirm against sfa.rspecs.version.
        # Entries without one are skipped, as before.
        content_type = getattr(rspec_version, 'content_type', None)
        if content_type in ['*', 'ad']:
            ad_rspec_versions.append(rspec_version.to_dict())
        if content_type in ['*', 'request']:
            request_rspec_versions.append(rspec_version.to_dict())
    default_rspec_version = version_manager.get_version("sfa 1").to_dict()
    xrn = Xrn(api.hrn)
    version_more = {'interface': 'slicemgr',
                    'hrn': xrn.get_hrn(),
                    'urn': xrn.get_urn(),
                    'peers': peers,
                    'request_rspec_versions': request_rspec_versions,
                    'ad_rspec_versions': ad_rspec_versions,
                    # was dict(sfa_rspec_version): an undefined name that made
                    # every GetVersion call fail with NameError
                    'default_ad_rspec': default_rspec_version,
                    }
    sm_version = version_core(version_more)
    # local aggregate if present needs to have localhost resolved
    if api.hrn in api.aggregates:
        local_am_url = get_serverproxy_url(api.aggregates[api.hrn])
        sm_version['peers'][api.hrn] = local_am_url.replace('localhost', sm_version['hostname'])
    return sm_version
87
def drop_slicemgr_stats(rspec):
    """
    Remove every ``<statistics>`` element from ``rspec`` in place.

    ``rspec.xml`` must support ``xpath`` and whose nodes support
    ``getparent`` (presumably an lxml element). Failures are logged and
    otherwise ignored, because stripping statistics is best-effort.
    """
    try:
        stats_elements = rspec.xml.xpath('//statistics')
        for node in stats_elements:
            node.getparent().remove(node)
    except Exception as e:
        # There is no ``api`` in this scope: the former ``api.logger`` call
        # turned any failure into a NameError. Use the module logger.
        logger.warning("drop_slicemgr_stats failed: %s " % (str(e)))
95
def add_slicemgr_stat(rspec, callname, aggname, elapsed, status):
    """
    Record one aggregate's outcome for ``callname`` in ``rspec``.

    Appends ``<aggregate name=.. elapsed=.. status=../>`` to the first
    ``<statistics call="callname">`` element found in ``rspec.xml``. If no
    such element exists, one is first created as a child of ``rspec.xml``.
    All attribute values are stringified. Failures are logged and
    otherwise ignored.
    """
    try:
        stats_tags = rspec.xml.xpath('//statistics[@call="%s"]' % callname)
        if stats_tags:
            stats_tag = stats_tags[0]
        else:
            stats_tag = etree.SubElement(rspec.xml, "statistics", call=callname)

        etree.SubElement(stats_tag, "aggregate", name=str(aggname), elapsed=str(elapsed), status=str(status))
    except Exception as e:
        # There is no ``api`` in this scope: the former ``api.logger`` call
        # turned any failure into a NameError. Use the module logger.
        logger.warning("add_slicemgr_stat failed on  %s: %s" % (aggname, str(e)))
107
def ListResources(api, creds, options, call_id):
    """
    Fan ListResources out to the aggregates and merge their RSpecs.

    :param api: slice manager API object (aggregates, auth, cache, logger)
    :param creds: caller credentials, checked for the 'listnodes' privilege
    :param options: call options. 'geni_slice_urn' restricts the listing to
        one slice and 'rspec_version' selects the returned format.
        'geni_compressed' is deleted from this dict before forwarding.
    :param call_id: duplicate-call marker; "" is returned if already seen
    :returns: the merged RSpec as XML, with a <statistics> section giving
        each contacted aggregate's elapsed time and status. When the module
        ``caching`` flag is on and no slice urn is given, the result is
        served from / stored in ``api.cache``.
    """
    version_manager = VersionManager()

    def _ListResources(aggregate, server, credential, opts, call_id):
        # Runs in a worker thread: never raise, report a status instead.
        my_opts = copy(opts)
        args = [credential, my_opts]
        tStart = time.time()
        try:
            if _call_id_supported(api, server):
                args.append(call_id)
            version = api.get_cached_server_version(server)
            # force ProtoGENI aggregates to give us a v2 RSpec
            # NOTE(review): the comment says v2 but 'ProtoGENI 1' is
            # requested -- confirm which one is intended.
            # Send the version as a plain dict (as GetVersion does), not the
            # version object, so it can travel over XML-RPC.
            if 'sfa' not in version:
                my_opts['rspec_version'] = version_manager.get_version('ProtoGENI 1').to_dict()
            rspec = server.ListResources(*args)
            return {"aggregate": aggregate, "rspec": rspec, "elapsed": time.time()-tStart, "status": "success"}
        except Exception as e:
            # get_serverproxy_url tolerates proxies without a ``url`` attribute;
            # a plain ``server.url`` could raise again inside this handler.
            api.logger.warn("ListResources failed at %s: %s" % (get_serverproxy_url(server), str(e)))
            return {"aggregate": aggregate, "elapsed": time.time()-tStart, "status": "exception"}

    if Callids().already_handled(call_id): return ""

    # get slice's hrn from options
    xrn = options.get('geni_slice_urn', '')
    (hrn, _) = urn_to_hrn(xrn)
    if 'geni_compressed' in options:
        del(options['geni_compressed'])

    # get the rspec's return format from options
    rspec_version = version_manager.get_version(options.get('rspec_version'))
    version_string = "rspec_%s" % (rspec_version.to_string())

    # look in cache first
    if caching and api.cache and not xrn:
        rspec = api.cache.get(version_string)
        if rspec:
            return rspec

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'listnodes', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()
    credentials = [credential]
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Don't send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue

        # get the rspec from the aggregate
        server = api.aggregates[aggregate]
        threads.run(_ListResources, aggregate, server, credentials, options, call_id)

    results = threads.get_results()
    # rspec_version from above is reused: options['rspec_version'] has not
    # changed since, because workers only modify their own copies.
    rspec = RSpec(version=rspec_version)
    for result in results:
        add_slicemgr_stat(rspec, "ListResources", result["aggregate"], result["elapsed"], result["status"])
        if result["status"] == "success":
            try:
                rspec.version.merge(result["rspec"])
            except Exception:
                api.logger.log_exc("SM.ListResources: Failed to merge aggregate rspec")

    # cache the result
    if caching and api.cache and not xrn:
        api.cache.add(version_string, rspec.toxml())

    return rspec.toxml()
182
183
def CreateSliver(api, xrn, creds, rspec_str, users, call_id):
    """
    Forward CreateSliver to the aggregates and merge the returned RSpecs.

    The whole request RSpec, with any <statistics> section stripped, is
    sent to each aggregate except a calling peer SM. For aggregates that
    advertise 'geni_api' but not 'sfa', it is first converted to a
    ProtoGENI RSpec. The answers are merged into a fresh RSpec built with
    the request's version, annotated with per-aggregate statistics, and
    returned as XML. Returns "" if ``call_id`` was already handled.
    """
    def _CreateSliver(aggregate, server, xrn, credential, rspec, users, call_id):
        # Runs in a worker thread: never raise, report a status instead.
        tStart = time.time()
        try:
            # Need to call GetVersion at an aggregate to determine the supported
            # rspec type/format before calling CreateSliver at an Aggregate.
            server_version = api.get_cached_server_version(server)
            if 'sfa' not in server_version and 'geni_api' in server_version:
                # sfa aggregates support both sfa and pg rspecs, no need to convert
                # if aggregate supports sfa rspecs. otherwise convert to pg rspec
                rspec = RSpecConverter.to_pg_rspec(rspec)
            args = [xrn, credential, rspec, users]
            if _call_id_supported(api, server):
                args.append(call_id)
            rspec = server.CreateSliver(*args)
            return {"aggregate": aggregate, "rspec": rspec, "elapsed": time.time()-tStart, "status": "success"}
        except Exception:
            # get_serverproxy_url tolerates proxies without a ``url`` attribute.
            logger.log_exc('Something wrong in _CreateSliver with URL %s' % get_serverproxy_url(server))
            return {"aggregate": aggregate, "elapsed": time.time()-tStart, "status": "exception"}

    if Callids().already_handled(call_id): return ""
    # Validate the RSpec against PlanetLab's schema --disabled for now
    # The schema used here needs to aggregate the PL and VINI schemas
    # schema = "/var/www/html/schemas/pl.rng"
    rspec = RSpec(rspec_str)
    schema = None
    if schema:
        rspec.validate(schema)

    # if there is a <statistics> section, the aggregates don't care about it,
    # so delete it.
    drop_slicemgr_stats(rspec)

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()

    # get the callers hrn
    hrn, _ = urn_to_hrn(xrn)
    valid_cred = api.auth.checkCredentials(creds, 'createsliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
    # every aggregate receives the same request document; serialize it once
    rspec_xml = rspec.toxml()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Don't send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = api.aggregates[aggregate]
        # Just send entire RSpec to each aggregate
        threads.run(_CreateSliver, aggregate, server, xrn, credential, rspec_xml, users, call_id)

    results = threads.get_results()
    # NOTE(review): manifests are merged into an RSpec built with the
    # request's version; a manifest-typed version may fit better -- confirm.
    result_rspec = RSpec(version=rspec.version)
    for result in results:
        # Statistics and merged answers both belong to result_rspec. The
        # request rspec used to be annotated and returned instead, which
        # discarded every aggregate's answer.
        add_slicemgr_stat(result_rspec, "CreateSliver", result["aggregate"], result["elapsed"], result["status"])
        if result["status"] == "success":
            try:
                result_rspec.version.merge(result["rspec"])
            except Exception:
                api.logger.log_exc("SM.CreateSliver: Failed to merge aggregate rspec")
    return result_rspec.toxml()
248
def RenewSliver(api, xrn, creds, expiration_time, call_id):
    """
    Ask the aggregates to renew sliver ``xrn`` until ``expiration_time``.

    Every aggregate except a calling peer SM is contacted with the
    delegated credential, or the SM's own credential if there is none.
    Returns True when every result collected from the worker threads is
    truthy (vacuously True when none was collected), else False. Returns
    True immediately if ``call_id`` was already handled.
    """
    def _RenewSliver(server, xrn, creds, expiration_time, call_id):
        # call_id is only appended for servers known to accept it. It used
        # to be in the base argument list as well, so supporting servers got
        # it twice and older servers got an unexpected extra argument.
        args = [xrn, creds, expiration_time]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.RenewSliver(*args)

    if Callids().already_handled(call_id): return True

    (hrn, _) = urn_to_hrn(xrn)
    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'renewsliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Don't send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = api.aggregates[aggregate]
        threads.run(_RenewSliver, server, xrn, [credential], expiration_time, call_id)
    # 'and' the results
    return all(threads.get_results())
278
def DeleteSliver(api, xrn, creds, call_id):
    """
    Ask the aggregates (except a calling peer SM) to delete sliver ``xrn``.

    The caller's credentials must grant 'deletesliver' on the slice. The
    individual aggregate answers are not inspected: 1 is returned once all
    worker threads are done. Returns "" if ``call_id`` was already handled.
    """
    def _DeleteSliver(server, xrn, creds, call_id):
        args = [xrn, creds]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.DeleteSliver(*args)

    if Callids().already_handled(call_id): return ""
    (hrn, _) = urn_to_hrn(xrn)
    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'deletesliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Don't send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = api.aggregates[aggregate]
        threads.run(_DeleteSliver, server, xrn, credential, call_id)
    threads.get_results()
    return 1
307
308
# first draft at a merging SliverStatus
def SliverStatus(api, slice_xrn, creds, call_id):
    """
    Query every aggregate for the status of ``slice_xrn`` and merge them.

    Results that are empty or have no (or empty) 'geni_resources' are
    dropped. The merged dict is built as follows:
      - 'geni_urn' is taken from the first remaining result;
      - 'pl_login' is also taken from it, when that result has one;
      - all 'geni_resources' lists are concatenated;
      - 'status' is set to 'ready'.
    Returns {} when nothing is left to merge or when ``call_id`` was
    already handled.
    """
    def _SliverStatus(server, xrn, creds, call_id):
        args = [xrn, creds]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.SliverStatus(*args)

    if Callids().already_handled(call_id): return {}
    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        server = api.aggregates[aggregate]
        threads.run(_SliverStatus, server, slice_xrn, credential, call_id)
    results = threads.get_results()

    # get rid of any void result - e.g. when call_id was hit where by convention we return {}
    # (.get: an answer without 'geni_resources' counts as void instead of
    # raising KeyError)
    results = [result for result in results if result and result.get('geni_resources')]

    # do not try to combine if there's no result
    if not results: return {}

    # otherwise let's merge stuff
    overall = {}

    # mmh, it is expected that all results carry the same urn
    overall['geni_urn'] = results[0]['geni_urn']
    # 'pl_login' is presumably PlanetLab-specific; copy it only when present
    # instead of raising KeyError for other aggregates
    if 'pl_login' in results[0]:
        overall['pl_login'] = results[0]['pl_login']
    # append all geni_resources (extend is linear; repeated '+' was quadratic)
    geni_resources = []
    for result in results:
        geni_resources.extend(result['geni_resources'])
    overall['geni_resources'] = geni_resources
    # Every kept result has non-empty 'geni_resources', so this ends up
    # 'ready'; the 'unknown' default is kept as a safeguard.
    overall['status'] = 'unknown'
    if overall['geni_resources']:
        overall['status'] = 'ready'

    return overall
349
# Module-wide switch for the api.cache lookups done by ListResources and
# ListSlices; set it to False to always query the aggregates.
caching = True
def ListSlices(api, creds, call_id):
    """
    Return the slice lists reported by the aggregates, concatenated.

    Every aggregate except a calling peer SM is queried. When the module
    ``caching`` flag is on, a non-empty 'slices' entry in ``api.cache`` is
    returned directly, and fresh results are stored there. Returns [] if
    ``call_id`` was already handled.
    """
    def _ListSlices(server, creds, call_id):
        args = [creds]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.ListSlices(*args)

    if Callids().already_handled(call_id): return []

    # look in cache first
    if caching and api.cache:
        slices = api.cache.get('slices')
        if slices:
            return slices

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'listslices', None)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()
    threads = ThreadManager()
    # fetch from aggregates
    for aggregate in api.aggregates:
        # prevent infinite loop. Don't send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = api.aggregates[aggregate]
        threads.run(_ListSlices, server, credential, call_id)

    # combine results
    results = threads.get_results()
    slices = []
    for result in results:
        slices.extend(result)

    # cache the result
    if caching and api.cache:
        api.cache.add('slices', slices)

    return slices
397
398
def get_ticket(api, xrn, creds, rspec, users):
    """
    Collect GetTicket answers from the aggregates named in ``rspec`` and
    combine them into one signed ticket.

    Aggregates are found via the ``<network>`` children of the RSpec root.
    The first attribute value of each is taken as an aggregate hrn
    (presumably its name attribute). The whole ``rspec`` string is sent to
    each of them.

    Aggregates missing from ``api.aggregates`` are looked up through each
    configured peer's ``get_aggregates``. Returns a new SfaTicket for the
    slice as a string, carrying the merged rspecs and the concatenated
    'initscripts'/'slivers' attributes.
    """
    slice_hrn, _ = urn_to_hrn(xrn)
    # get the netspecs contained within the clients rspec
    aggregate_rspecs = {}
    tree = etree.parse(StringIO(rspec))
    elements = tree.findall('./network')
    for element in elements:
        attribute_values = element.values()
        if not attribute_values:
            # a <network> without attributes names no aggregate; skip it
            # instead of failing with IndexError
            continue
        aggregate_hrn = attribute_values[0]
        aggregate_rspecs[aggregate_hrn] = rspec

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'getticket', slice_hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()
    threads = ThreadManager()
    for (aggregate, aggregate_rspec) in aggregate_rspecs.iteritems():
        # prevent infinite loop. Don't send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = None
        if aggregate in api.aggregates:
            server = api.aggregates[aggregate]
        else:
            net_urn = hrn_to_urn(aggregate, 'authority')
            # we may have a peer that knows about this aggregate
            for agg in api.aggregates:
                target_aggs = api.aggregates[agg].get_aggregates(credential, net_urn)
                if not target_aggs:
                    continue
                target = target_aggs[0]
                # Require the 'url' we are about to use, not only the 'hrn'.
                # Checking 'hrn' alone led to a KeyError on a missing 'url'.
                if 'hrn' not in target or 'url' not in target:
                    continue
                # send the request to this address
                url = target['url']
                server = xmlrpcprotocol.get_server(url, api.key_file, api.cert_file, timeout=30)
                # aggregate found, no need to keep looping
                break
        if server is None:
            continue
        threads.run(server.GetTicket, xrn, credential, aggregate_rspec, users)

    results = threads.get_results()

    # gather information from each ticket
    rspecs = []
    initscripts = []
    slivers = []
    object_gid = None
    for result in results:
        agg_ticket = SfaTicket(string=result)
        attrs = agg_ticket.get_attributes()
        if not object_gid:
            object_gid = agg_ticket.get_gid_object()
        rspecs.append(agg_ticket.get_rspec())
        initscripts.extend(attrs.get('initscripts', []))
        slivers.extend(attrs.get('slivers', []))

    # merge info
    attributes = {'initscripts': initscripts,
                  'slivers': slivers}
    merged_rspec = merge_rspecs(rspecs)

    # create a new ticket
    # NOTE(review): if no aggregate returned a ticket, object_gid is still
    # None here and object_gid.get_pubkey() raises AttributeError.
    ticket = SfaTicket(subject=slice_hrn)
    ticket.set_gid_caller(api.auth.client_gid)
    ticket.set_issuer(key=api.key, subject=api.hrn)
    ticket.set_gid_object(object_gid)
    ticket.set_pubkey(object_gid.get_pubkey())
    #new_ticket.set_parent(api.auth.hierarchy.get_auth_ticket(auth_hrn))
    ticket.set_attributes(attributes)
    ticket.set_rspec(merged_rspec)
    ticket.encode()
    ticket.sign()
    return ticket.save_to_string(save_parents=True)
475
def start_slice(api, xrn, creds):
    """
    Ask every aggregate, except a calling peer SM, to start slice ``xrn``.

    The caller's credentials must grant 'startslice' on the slice. The
    delegated credential is forwarded when there is one, otherwise the
    SM's own credential. Waits for all worker threads, then returns 1.
    """
    slice_hrn, _ = urn_to_hrn(xrn)

    # identify the caller from its validated credential
    checked_cred = api.auth.checkCredentials(creds, 'startslice', slice_hrn)[0]
    caller = Credential(string=checked_cred).get_gid_caller().get_hrn()

    # forward the delegated credential if any, else our own
    forward_cred = api.getDelegatedCredential(creds) or api.getCredential()

    workers = ThreadManager()
    for agg_hrn in api.aggregates:
        # never bounce the request back to the peer SM that sent it,
        # but our own hrn is always served
        bounced = agg_hrn == caller and agg_hrn != api.hrn
        if not bounced:
            workers.run(api.aggregates[agg_hrn].Start, xrn, forward_cred)
    workers.get_results()
    return 1
497  
def stop_slice(api, xrn, creds):
    """
    Ask every aggregate, except a calling peer SM, to stop slice ``xrn``.

    The caller's credentials must grant 'stopslice' on the slice. The
    delegated credential is forwarded when there is one, otherwise the
    SM's own credential. Waits for all worker threads, then returns 1.
    """
    slice_hrn, _ = urn_to_hrn(xrn)

    # identify the caller from its validated credential
    checked_cred = api.auth.checkCredentials(creds, 'stopslice', slice_hrn)[0]
    caller = Credential(string=checked_cred).get_gid_caller().get_hrn()

    # forward the delegated credential if any, else our own
    forward_cred = api.getDelegatedCredential(creds) or api.getCredential()

    workers = ThreadManager()
    for agg_hrn in api.aggregates:
        # never bounce the request back to the peer SM that sent it,
        # but our own hrn is always served
        bounced = agg_hrn == caller and agg_hrn != api.hrn
        if not bounced:
            workers.run(api.aggregates[agg_hrn].Stop, xrn, forward_cred)
    workers.get_results()
    return 1
519
def reset_slice(api, xrn):
    """Reset slice ``xrn`` -- not implemented; always reports success (1)."""
    # stub: arguments are ignored
    return 1
525
def shutdown(api, xrn, creds):
    """Shut down slice ``xrn`` -- not implemented; always reports success (1)."""
    # stub: arguments are ignored
    return 1
531
def status(api, xrn, creds):
    """Report status of slice ``xrn`` -- not implemented; always returns 1."""
    # stub: arguments are ignored
    return 1
537
def main():
    """
    Manual debugging entry point: send the RSpec file named on the command
    line to CreateSliver for a hard-coded test slice.

    NOTE(review): CreateSliver dereferences ``api`` (e.g.
    ``api.getDelegatedCredential``), so with ``api=None`` it still fails
    after parsing the RSpec and stripping its statistics. A real api
    object is needed for this to do anything useful.
    """
    if len(sys.argv) < 2:
        sys.stderr.write("usage: %s <rspec-file>\n" % sys.argv[0])
        return 1
    # CreateSliver expects the request RSpec as an XML string, so pass the
    # file contents rather than a parsed/dict form.
    with open(sys.argv[1]) as rspec_file:
        rspec_str = rspec_file.read()
    # Arguments follow CreateSliver(api, xrn, creds, rspec_str, users,
    # call_id). The old 4-argument call always raised TypeError.
    return CreateSliver(None, 'plc.princeton.tmacktestslice', [], rspec_str, [],
                        'create-slice-tmacktestslice')

if __name__ == "__main__":
    main()
546