Make sure ProtoGENI request RSpecs have the correct schema and format
[sfa.git] / sfa / managers / slice_manager_pl.py
1 #
2 import sys
3 import time,datetime
4 from StringIO import StringIO
5 from types import StringTypes
6 from copy import deepcopy
7 from copy import copy
8 from lxml import etree
9
10 from sfa.util.sfalogging import logger
11 from sfa.util.rspecHelper import merge_rspecs
12 from sfa.util.xrn import Xrn, urn_to_hrn, hrn_to_urn
13 from sfa.util.plxrn import hrn_to_pl_slicename
14 from sfa.util.rspec import *
15 from sfa.util.specdict import *
16 from sfa.util.faults import *
17 from sfa.util.record import SfaRecord
18 from sfa.rspecs.rspec_converter import RSpecConverter
19 from sfa.rspecs.version_manager import VersionManager
20 from sfa.rspecs.rspec import RSpec 
21 from sfa.util.policy import Policy
22 from sfa.util.prefixTree import prefixTree
23 from sfa.util.sfaticket import *
24 from sfa.trust.credential import Credential
25 from sfa.util.threadmanager import ThreadManager
26 import sfa.util.xmlrpcprotocol as xmlrpcprotocol     
27 import sfa.plc.peers as peers
28 from sfa.util.version import version_core
29 from sfa.util.callids import Callids
30
31
32 def _call_id_supported(api, server):
33     """
34     Returns true if server support the optional call_id arg, false otherwise.
35     """
36     server_version = api.get_cached_server_version(server)
37
38     if 'sfa' in server_version:
39         code_tag = server_version['code_tag']
40         code_tag_parts = code_tag.split("-")
41
42         version_parts = code_tag_parts[0].split(".")
43         major, minor = version_parts[0:2]
44         rev = code_tag_parts[1]
45         if int(major) > 1:
46             if int(minor) > 0 or int(rev) > 20:
47                 return True
48     return False
49
# we have specialized xmlrpclib.ServerProxy to remember the input url
# OTOH it's not clear if we're only dealing with XMLRPCServerProxy instances
def get_serverproxy_url(server):
    """
    Return the address a server proxy talks to.

    Prefers the ``url`` attribute remembered by our specialised proxy class.
    When it is absent, the address is rebuilt from the proxy's private,
    name-mangled ``__host`` and ``__handler`` attributes of
    ``xmlrpclib.ServerProxy``.  NOTE(review): those presumably hold host[:port]
    and path only, so the fallback value has no scheme. Confirm this is
    acceptable to callers.
    """
    try:
        return server.url
    except AttributeError:
        # only a missing attribute triggers the fallback; anything else
        # (including KeyboardInterrupt/SystemExit) propagates
        logger.warning("GetVersion, falling back to xmlrpclib.ServerProxy internals")
        return server._ServerProxy__host + server._ServerProxy__handler
58
def GetVersion(api):
    """
    Build the slice manager's GetVersion reply.

    Lists the RSpec versions known to the VersionManager split by content
    type: versions whose content type is 'ad' go to 'ad_rspec_versions', and
    'request' goes to 'request_rspec_versions'.  A '*' version is listed in
    both.  Also advertises the URL of every configured aggregate other than
    our own under 'peers'.  The local aggregate, if configured, is added with
    'localhost' replaced by this host's name.
    """
    # peers explicitly in aggregates.xml (our own local aggregate excluded);
    # named peer_urls so as not to shadow the imported `peers` module
    peer_urls = dict((peername, get_serverproxy_url(server))
                     for (peername, server) in api.aggregates.items()
                     if peername != api.hrn)
    version_manager = VersionManager()
    ad_rspec_versions = []
    request_rspec_versions = []
    for rspec_version in version_manager.versions:
        # classify by the version's content type, not by the version object
        if rspec_version.content_type in ['*', 'ad']:
            ad_rspec_versions.append(rspec_version.to_dict())
        if rspec_version.content_type in ['*', 'request']:
            request_rspec_versions.append(rspec_version.to_dict())
    default_rspec_version = version_manager.get_version("sfa 1").to_dict()
    xrn = Xrn(api.hrn)
    version_more = {'interface': 'slicemgr',
                    'hrn': xrn.get_hrn(),
                    'urn': xrn.get_urn(),
                    'peers': peer_urls,
                    'request_rspec_versions': request_rspec_versions,
                    'ad_rspec_versions': ad_rspec_versions,
                    'default_ad_rspec': default_rspec_version
                    }
    sm_version = version_core(version_more)
    # local aggregate if present needs to have localhost resolved
    if api.hrn in api.aggregates:
        local_am_url = get_serverproxy_url(api.aggregates[api.hrn])
        sm_version['peers'][api.hrn] = local_am_url.replace('localhost', sm_version['hostname'])
    return sm_version
87
def drop_slicemgr_stats(rspec):
    """
    Remove every <statistics> element from *rspec* in place.

    Best effort: any failure is logged as a warning and otherwise ignored.
    """
    try:
        stats_elements = rspec.xml.xpath('//statistics')
        for node in stats_elements:
            node.getparent().remove(node)
    except Exception as e:
        # use the module-level logger: no `api` object is in scope here, so the
        # previous `api.logger` reference raised NameError inside the handler
        logger.warning("drop_slicemgr_stats failed: %s " % (str(e)))
95
def add_slicemgr_stat(rspec, callname, aggname, elapsed, status):
    """
    Record one aggregate's outcome for *callname* inside *rspec*.

    Appends an <aggregate name=... elapsed=... status=...> child to the
    <statistics call="callname"> element, creating that element under the
    RSpec root if it does not exist yet.  Best effort: failures are logged as
    warnings and otherwise ignored.
    """
    try:
        stats_tags = rspec.xml.xpath('//statistics[@call="%s"]' % callname)
        if stats_tags:
            stats_tag = stats_tags[0]
        else:
            stats_tag = etree.SubElement(rspec.xml.root, "statistics", call=callname)

        etree.SubElement(stats_tag, "aggregate", name=str(aggname), elapsed=str(elapsed), status=str(status))
    except Exception as e:
        # use the module-level logger: no `api` object is in scope here, so the
        # previous `api.logger` reference raised NameError inside the handler
        logger.warning("add_slicemgr_stat failed on  %s: %s" % (aggname, str(e)))
107
def ListResources(api, creds, options, call_id):
    """
    Fan a ListResources call out to the known aggregates and merge the replies.

    :param api: the slice manager API object (aggregates, auth, cache, logger).
    :param creds: caller credentials, checked for the 'listnodes' privilege.
    :param options: call options.  'geni_slice_urn' selects a slice (manifest
        mode); 'rspec_version' selects the output format.  'geni_compressed'
        is removed from this dict before the call is forwarded.
    :param call_id: optional token; a call_id already handled returns "".
    :returns: the merged RSpec as an XML string: a manifest when a slice urn
        was given, otherwise an advertisement.  Advertisements are read from /
        stored in api.cache when module-level `caching` is enabled.
    """
    version_manager = VersionManager()

    def _ListResources(aggregate, server, credential, opts, call_id):
        # private copy so the per-aggregate tweak below cannot leak to other threads
        my_opts = copy(opts)
        args = [credential, my_opts]
        tStart = time.time()
        try:
            if _call_id_supported(api, server):
                args.append(call_id)
            version = api.get_cached_server_version(server)
            # force ProtoGENI aggregates to give us a v2 RSpec
            # (my_opts is already referenced by args, so this mutation is seen)
            if 'sfa' not in version:
                my_opts['rspec_version'] = version_manager.get_version('ProtoGENI 2').to_dict()
            rspec = server.ListResources(*args)
            return {"aggregate": aggregate, "rspec": rspec, "elapsed": time.time()-tStart, "status": "success"}
        except Exception as e:
            api.logger.warn("ListResources failed at %s: %s" %(server.url, str(e)))
            return {"aggregate": aggregate, "elapsed": time.time()-tStart, "status": "exception"}

    if Callids().already_handled(call_id):
        return ""

    # get slice's hrn from options
    xrn = options.get('geni_slice_urn', '')
    hrn, _ = urn_to_hrn(xrn)
    if 'geni_compressed' in options:
        del options['geni_compressed']

    # get the rspec's return format from options (looked up once, reused below)
    rspec_version = version_manager.get_version(options.get('rspec_version'))
    version_string = "rspec_%s" % (rspec_version.to_string())

    # look in cache first (advertisements only, never per-slice manifests)
    if caching and api.cache and not xrn:
        rspec = api.cache.get(version_string)
        if rspec:
            return rspec

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'listnodes', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()
    credentials = [credential]
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue

        # get the rspec from the aggregate
        server = api.aggregates[aggregate]
        threads.run(_ListResources, aggregate, server, credentials, options, call_id)

    results = threads.get_results()
    # a slice urn means the caller wants that slice's manifest, else an advertisement
    if xrn:
        content_type = 'manifest'
    else:
        content_type = 'ad'
    result_version = version_manager._get_version(rspec_version.type, rspec_version.version, content_type)
    rspec = RSpec(version=result_version)
    for result in results:
        add_slicemgr_stat(rspec, "ListResources", result["aggregate"], result["elapsed"], result["status"])
        if result["status"] == "success":
            try:
                rspec.version.merge(result["rspec"])
            except Exception:
                api.logger.log_exc("SM.ListResources: Failed to merge aggregate rspec")

    # cache the result
    if caching and api.cache and not xrn:
        api.cache.add(version_string, rspec.toxml())

    return rspec.toxml()
186
187
def CreateSliver(api, xrn, creds, rspec_str, users, call_id):
    """
    Create slivers for slice *xrn* on the known aggregates and merge the manifests.

    The request RSpec is sent whole to each aggregate.  For aggregates that
    report 'geni_api' but not 'sfa' in GetVersion, it is first converted to a
    ProtoGENI request RSpec.  The returned manifests are merged into one
    manifest RSpec of the request's type/version, annotated with
    per-aggregate <statistics>.

    :param api: the slice manager API object.
    :param xrn: urn of the slice; the caller needs 'createsliver' on it.
    :param creds: caller credentials.
    :param rspec_str: the request RSpec as an XML string.
    :param users: forwarded unchanged to each aggregate.
    :param call_id: optional token; a call_id already handled returns "".
    :returns: the merged manifest RSpec as an XML string.
    """
    version_manager = VersionManager()

    def _CreateSliver(aggregate, server, xrn, credential, rspec, users, call_id):
        tStart = time.time()
        try:
            # Need to call GetVersion at an aggregate to determine the supported
            # rspec type/format before calling CreateSliver at an Aggregate.
            server_version = api.get_cached_server_version(server)
            if 'sfa' not in server_version and 'geni_api' in server_version:
                # sfa aggregates support both sfa and pg rspecs, no need to convert
                # if aggregate supports sfa rspecs. otherwise convert to pg rspec
                rspec = RSpecConverter.to_pg_rspec(rspec, 'request')
            args = [xrn, credential, rspec, users]
            if _call_id_supported(api, server):
                args.append(call_id)
            rspec = server.CreateSliver(*args)
            return {"aggregate": aggregate, "rspec": rspec, "elapsed": time.time()-tStart, "status": "success"}
        except Exception:
            logger.log_exc('Something wrong in _CreateSliver with URL %s'%server.url)
            return {"aggregate": aggregate, "elapsed": time.time()-tStart, "status": "exception"}

    if Callids().already_handled(call_id):
        return ""

    # Validate the RSpec against PlanetLab's schema --disabled for now
    # The schema used here needs to aggregate the PL and VINI schemas
    # schema = "/var/www/html/schemas/pl.rng"
    rspec = RSpec(rspec_str)
    schema = None
    if schema:
        rspec.validate(schema)

    # if there is a <statistics> section, the aggregates don't care about it,
    # so delete it.
    drop_slicemgr_stats(rspec)

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()

    # get the callers hrn
    hrn, _ = urn_to_hrn(xrn)
    valid_cred = api.auth.checkCredentials(creds, 'createsliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # serialize the request once; every aggregate gets the same document
    request_xml = rspec.toxml()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = api.aggregates[aggregate]
        # Just send entire RSpec to each aggregate
        threads.run(_CreateSliver, aggregate, server, xrn, credential, request_xml, users, call_id)

    results = threads.get_results()
    manifest_version = version_manager._get_version(rspec.version.type, rspec.version.version, 'manifest')
    result_rspec = RSpec(version=manifest_version)
    for result in results:
        # record stats on the manifest we return, not on the request
        add_slicemgr_stat(result_rspec, "CreateSliver", result["aggregate"], result["elapsed"], result["status"])
        if result["status"] == "success":
            try:
                result_rspec.version.merge(result["rspec"])
            except Exception:
                api.logger.log_exc("SM.CreateSliver: Failed to merge aggregate rspec")
    # return the merged manifest (previously the request rspec was returned)
    return result_rspec.toxml()
253
def RenewSliver(api, xrn, creds, expiration_time, call_id):
    """
    Renew slice *xrn* until *expiration_time* on every aggregate except the caller's.

    The caller needs the 'renewsliver' privilege on the slice.

    :returns: True when *call_id* was already handled; otherwise True only if
        every aggregate reported success (True when no aggregate was called).
    """
    def _RenewSliver(server, xrn, creds, expiration_time, call_id):
        args = [xrn, creds, expiration_time]
        # call_id is optional: send it, once, only to servers that support it
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.RenewSliver(*args)

    if Callids().already_handled(call_id):
        return True

    hrn, _ = urn_to_hrn(xrn)
    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'renewsliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = api.aggregates[aggregate]
        threads.run(_RenewSliver, server, xrn, [credential], expiration_time, call_id)
    # 'and' the results
    return all(threads.get_results())
283
def DeleteSliver(api, xrn, creds, call_id):
    """
    Delete slice *xrn*'s slivers on every aggregate except the caller's.

    The caller needs the 'deletesliver' privilege on the slice.

    :returns: "" when *call_id* was already handled, otherwise 1 once all
        aggregate results have been collected (individual results are ignored).
    """
    def _DeleteSliver(server, xrn, creds, call_id):
        args = [xrn, creds]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.DeleteSliver(*args)

    if Callids().already_handled(call_id):
        return ""
    hrn, _ = urn_to_hrn(xrn)
    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'deletesliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = api.aggregates[aggregate]
        threads.run(_DeleteSliver, server, xrn, credential, call_id)
    threads.get_results()
    return 1
312
313
# first draft at a merging SliverStatus
def SliverStatus(api, slice_xrn, creds, call_id):
    """
    Query SliverStatus on every aggregate and merge the replies into one dict.

    Unlike the other fan-out calls in this module, no credential check or
    caller loop-prevention is done here; every aggregate is queried.

    :returns: {} when *call_id* was already handled or no aggregate reported
        any 'geni_resources'.  Otherwise it returns a dict with:
        - 'geni_urn' taken from the first non-empty reply;
        - 'pl_login' from the first reply that carries one (omitted if none do);
        - 'geni_resources' concatenated across replies;
        - 'status'.
    """
    def _SliverStatus(server, xrn, creds, call_id):
        args = [xrn, creds]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.SliverStatus(*args)

    if Callids().already_handled(call_id):
        return {}
    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        server = api.aggregates[aggregate]
        threads.run(_SliverStatus, server, slice_xrn, credential, call_id)
    results = threads.get_results()

    # get rid of any void result - e.g. when call_id was hit where by convention we return {};
    # .get() so a reply without 'geni_resources' is skipped instead of raising KeyError
    results = [result for result in results if result and result.get('geni_resources')]

    # do not try to combine if there's no result
    if not results:
        return {}

    # otherwise let's merge stuff
    overall = {}

    # mmh, it is expected that all results carry the same urn
    overall['geni_urn'] = results[0]['geni_urn']
    # pl_login is PlanetLab-specific; other aggregates may not send it
    for result in results:
        if 'pl_login' in result:
            overall['pl_login'] = result['pl_login']
            break
    # append all geni_resources
    resources = []
    for result in results:
        resources.extend(result['geni_resources'])
    overall['geni_resources'] = resources
    overall['status'] = 'unknown'
    if overall['geni_resources']:
        overall['status'] = 'ready'

    return overall
354
# Module-wide switch for the api.cache lookups/stores done by ListResources
# and ListSlices; set it to False to always query the aggregates.
caching = True
def ListSlices(api, creds, call_id):
    """
    Return the concatenation of the slice lists reported by the aggregates.

    The caller needs the 'listslices' privilege.  When module-level `caching`
    is enabled and api.cache is set, the combined list is read from / stored
    under the 'slices' cache key.

    :returns: [] when *call_id* was already handled, otherwise the list of slices.
    """
    def _ListSlices(server, creds, call_id):
        args = [creds]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.ListSlices(*args)

    if Callids().already_handled(call_id):
        return []

    # look in cache first
    if caching and api.cache:
        slices = api.cache.get('slices')
        if slices:
            return slices

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'listslices', None)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()
    threads = ThreadManager()
    # fetch from aggregates
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = api.aggregates[aggregate]
        threads.run(_ListSlices, server, credential, call_id)

    # combine results, skipping empty/None replies instead of crashing on extend(None)
    slices = []
    for result in threads.get_results():
        if result:
            slices.extend(result)

    # cache the result
    if caching and api.cache:
        api.cache.add('slices', slices)

    return slices
402
403
def get_ticket(api, xrn, creds, rspec, users):
    """
    Collect tickets for slice *xrn* from the aggregates named in *rspec* and
    merge them into a single ticket signed by this slice manager.

    :param api: the slice manager API object.
    :param xrn: urn of the slice; the caller needs 'getticket' on it.
    :param creds: caller credentials.
    :param rspec: the request RSpec as an XML string.
    :param users: forwarded unchanged to each aggregate's GetTicket.
    :returns: the merged ticket, serialized with save_parents=True.
    """
    slice_hrn, type = urn_to_hrn(xrn)
    # Map each aggregate named by a top-level <network> element to the request.
    # NOTE: every aggregate is sent the WHOLE request rspec, not just its own
    # <network> section.
    aggregate_rspecs = {}
    tree= etree.parse(StringIO(rspec))
    # only direct children of the root are matched by './network'
    elements = tree.findall('./network')
    for element in elements:
        # first attribute value of <network>, presumably its name/hrn -- TODO confirm
        aggregate_hrn = element.values()[0]
        aggregate_rspecs[aggregate_hrn] = rspec

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'getticket', slice_hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()
    threads = ThreadManager()
    for (aggregate, aggregate_rspec) in aggregate_rspecs.iteritems():
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = None
        if aggregate in api.aggregates:
            server = api.aggregates[aggregate]
        else:
            net_urn = hrn_to_urn(aggregate, 'authority')
            # we may have a peer that knows about this aggregate: ask each
            # configured aggregate in turn and use the first one whose first
            # reply entry has an 'hrn'
            for agg in api.aggregates:
                target_aggs = api.aggregates[agg].get_aggregates(credential, net_urn)
                if not target_aggs or not 'hrn' in target_aggs[0]:
                    continue
                # send the request to this address
                # NOTE(review): assumes the entry also carries 'url' (KeyError otherwise)
                url = target_aggs[0]['url']
                server = xmlrpcprotocol.get_server(url, api.key_file, api.cert_file, timeout=30)
                # aggregate found, no need to keep looping
                break
        # nobody knows how to reach this aggregate: skip it silently
        if server is None:
            continue
        threads.run(server.GetTicket, xrn, credential, aggregate_rspec, users)

    results = threads.get_results()

    # gather information from each ticket
    rspecs = []
    initscripts = []
    slivers = []
    # the object gid of the merged ticket is taken from the first ticket
    object_gid = None
    for result in results:
        agg_ticket = SfaTicket(string=result)
        attrs = agg_ticket.get_attributes()
        if not object_gid:
            object_gid = agg_ticket.get_gid_object()
        rspecs.append(agg_ticket.get_rspec())
        initscripts.extend(attrs.get('initscripts', []))
        slivers.extend(attrs.get('slivers', []))

    # merge info
    attributes = {'initscripts': initscripts,
                 'slivers': slivers}
    merged_rspec = merge_rspecs(rspecs)

    # create a new ticket issued and signed by this slice manager
    # NOTE(review): if no aggregate returned a ticket, object_gid is still None
    # and object_gid.get_pubkey() below raises AttributeError
    ticket = SfaTicket(subject = slice_hrn)
    ticket.set_gid_caller(api.auth.client_gid)
    ticket.set_issuer(key=api.key, subject=api.hrn)
    ticket.set_gid_object(object_gid)
    ticket.set_pubkey(object_gid.get_pubkey())
    #new_ticket.set_parent(api.auth.hierarchy.get_auth_ticket(auth_hrn))
    ticket.set_attributes(attributes)
    ticket.set_rspec(merged_rspec)
    ticket.encode()
    ticket.sign()
    return ticket.save_to_string(save_parents=True)
480
def start_slice(api, xrn, creds):
    """
    Ask every aggregate, other than the one that sent us the request, to
    start slice *xrn*.

    The caller must hold the 'startslice' privilege on the slice.  Returns 1
    after the per-aggregate thread results have been collected.
    """
    slice_hrn, _ = urn_to_hrn(xrn)

    # who is calling? used to avoid bouncing the request straight back
    checked = api.auth.checkCredentials(creds, 'startslice', slice_hrn)
    caller_hrn = Credential(string=checked[0]).get_gid_caller().get_hrn()

    # prefer a delegated credential, fall back to our own
    credential = api.getDelegatedCredential(creds) or api.getCredential()

    # every aggregate except the caller itself (our own SM's aggregate is kept)
    targets = [agg for agg in api.aggregates
               if not (caller_hrn == agg and agg != api.hrn)]

    pool = ThreadManager()
    for target in targets:
        pool.run(api.aggregates[target].Start, xrn, credential)
    pool.get_results()
    return 1
502  
def stop_slice(api, xrn, creds):
    """
    Stop slice *xrn* on every aggregate except the one that sent the request.

    Requires the 'stopslice' privilege on the slice.  Returns 1 after the
    per-aggregate thread results have been collected.
    """
    slice_hrn, _ = urn_to_hrn(xrn)
    checked = api.auth.checkCredentials(creds, 'stopslice', slice_hrn)
    caller_hrn = Credential(string=checked[0]).get_gid_caller().get_hrn()

    # delegated credential first, our own as a fallback
    credential = api.getDelegatedCredential(creds) or api.getCredential()

    workers = ThreadManager()
    for aggregate in api.aggregates:
        is_calling_peer = caller_hrn == aggregate and aggregate != api.hrn
        if not is_calling_peer:
            workers.run(api.aggregates[aggregate].Stop, xrn, credential)
    workers.get_results()
    return 1
524
def reset_slice(api, xrn):
    """
    Placeholder: resetting a slice is not implemented by this slice manager.

    Always reports success by returning 1.
    """
    not_implemented_result = 1
    return not_implemented_result
530
def shutdown(api, xrn, creds):
    """
    Placeholder: shutting a slice down is not implemented by this slice manager.

    Always reports success by returning 1.
    """
    not_implemented_result = 1
    return not_implemented_result
536
def status(api, xrn, creds):
    """
    Placeholder: slice status is not implemented by this slice manager.

    Always reports success by returning 1.
    """
    not_implemented_result = 1
    return not_implemented_result
542
def main():
    """
    Manual smoke test.

    Submits the RSpec file named by sys.argv[1] to CreateSliver for the
    hard-coded test slice 'plc.princeton.tmacktestslice'.

    NOTE(review): ``api`` is passed as None and no credentials or users are
    supplied, so CreateSliver fails once it calls api.getDelegatedCredential.
    A real slice-manager API object must be wired in before this is useful.
    """
    # CreateSliver expects the request RSpec as an XML string
    with open(sys.argv[1]) as rspec_file:
        rspec_str = rspec_file.read()
    # CreateSliver(api, xrn, creds, rspec_str, users, call_id)
    CreateSliver(None, 'plc.princeton.tmacktestslice', [], rspec_str, [],
                 'create-slice-tmacktestslice')


if __name__ == "__main__":
    main()
551