57d888a488eb7a767047320db4cdc6601229c18a
[sfa.git] / sfa / managers / slice_manager_pl.py
1
2 import sys
3 import time,datetime
4 from StringIO import StringIO
5 from types import StringTypes
6 from copy import deepcopy
7 from copy import copy
8 from lxml import etree
9
10 from sfa.util.sfalogging import logger
11 from sfa.util.rspecHelper import merge_rspecs
12 from sfa.util.xrn import Xrn, urn_to_hrn, hrn_to_urn
13 from sfa.util.plxrn import hrn_to_pl_slicename
14 from sfa.util.rspec import *
15 from sfa.util.specdict import *
16 from sfa.util.faults import *
17 from sfa.util.record import SfaRecord
18 from sfa.rspecs.pg_rspec import PGRSpec
19 from sfa.rspecs.sfa_rspec import SfaRSpec
20 from sfa.rspecs.rspec_converter import RSpecConverter
21 from sfa.rspecs.rspec_parser import parse_rspec    
22 from sfa.rspecs.rspec_version import RSpecVersion
23 from sfa.rspecs.sfa_rspec import sfa_rspec_version
24 from sfa.rspecs.pg_rspec import pg_rspec_ad_version, pg_rspec_request_version   
25 from sfa.util.policy import Policy
26 from sfa.util.prefixTree import prefixTree
27 from sfa.util.sfaticket import *
28 from sfa.trust.credential import Credential
29 from sfa.util.threadmanager import ThreadManager
30 import sfa.util.xmlrpcprotocol as xmlrpcprotocol     
31 import sfa.plc.peers as peers
32 from sfa.util.version import version_core
33 from sfa.util.callids import Callids
34
35
def _get_cached_server_version(api, server):
    """
    Return the GetVersion() answer of a remote server, going through api.cache.

    The answer is stored under the key "<server.url>-version" with a 24 hour
    ttl, so repeated lookups within that window do not query the server again.
    When the api has no cache (api.cache missing or None) the server is simply
    queried on every call.
    """
    # the rest of this module treats api.cache as optional, do the same here
    cache = getattr(api, 'cache', None)
    cache_key = server.url + "-version"
    server_version = None
    if cache is not None:
        server_version = cache.get(cache_key)
    if not server_version:
        server_version = server.GetVersion()
        if cache is not None:
            # cache version for 24 hours
            cache.add(cache_key, server_version, ttl=60 * 60 * 24)
    return server_version
44
def _call_id_supported(api, server):
    """
    Returns true if server support the optional call_id arg, false otherwise.

    The decision is based on the (cached) GetVersion() answer of the server:
    it must advertise an 'sfa' key and a 'code_tag' of the form
    "<major>.<minor>[...]-<rev>[-...]". A server that does not advertise a
    code_tag, or whose code_tag cannot be parsed, is reported as not
    supporting call_id, so that the optional argument is simply left out.
    """
    server_version = _get_cached_server_version(api, server)
    if 'sfa' not in server_version:
        return False

    code_tag = server_version.get('code_tag')
    if not code_tag:
        return False

    try:
        code_tag_parts = code_tag.split("-")
        version_parts = code_tag_parts[0].split(".")
        # NOTE(review): as written, only major versions strictly above 1 ever
        # qualify, and rev is only looked at when minor is 0 - confirm that
        # these thresholds are the intended ones
        if int(version_parts[0]) <= 1:
            return False
        return int(version_parts[1]) > 0 or int(code_tag_parts[1]) > 20
    except (AttributeError, IndexError, TypeError, ValueError):
        # malformed code_tag: be conservative and do not send call_id
        return False
62
# We have specialized xmlrpclib.ServerProxy to remember the input url, but it
# is not clear that only such XMLRPCServerProxy instances ever reach this code.
def get_serverproxy_url (server):
    """
    Return the url a server proxy talks to.

    Uses the proxy's own 'url' attribute when it has one; otherwise logs a
    warning and rebuilds host + handler from the private attributes of
    xmlrpclib.ServerProxy.
    """
    url = getattr(server, 'url', None)
    # A stock xmlrpclib.ServerProxy does not raise AttributeError on unknown
    # attributes: it answers every lookup with a callable remote-method stub.
    # A callable 'url' therefore means the attribute is not really there.
    if url is not None and not callable(url):
        return url
    logger.warning("GetVersion, falling back to xmlrpclib.ServerProxy internals")
    return server._ServerProxy__host + server._ServerProxy__handler
71
def GetVersion(api):
    """
    Build the version description of this slice manager.

    The slice-manager specific fields (interface, hrn, urn, peers and the
    supported rspec versions) are handed to version_core(), whose result is
    returned. When this interface is itself listed in api.aggregates, its url
    is added to the peers with 'localhost' replaced by the 'hostname' entry
    of the version_core() result.
    """
    # peers explicitly in aggregates.xml, this interface excluded
    peer_urls = {}
    for peername, peer_server in api.aggregates.iteritems():
        if peername != api.hrn:
            peer_urls[peername] = get_serverproxy_url(peer_server)

    own_xrn = Xrn(api.hrn)
    sm_version = version_core({
        'interface': 'slicemgr',
        'hrn': own_xrn.get_hrn(),
        'urn': own_xrn.get_urn(),
        'peers': peer_urls,
        # every entry gets its own copy of the module-level version dicts
        'request_rspec_versions': [dict(pg_rspec_request_version),
                                   dict(sfa_rspec_version)],
        'ad_rspec_versions': [dict(pg_rspec_ad_version),
                              dict(sfa_rspec_version)],
        'default_ad_rspec': dict(sfa_rspec_version),
    })

    # local aggregate if present needs to have localhost resolved
    if api.hrn in api.aggregates:
        local_am_url = get_serverproxy_url(api.aggregates[api.hrn])
        sm_version['peers'][api.hrn] = \
            local_am_url.replace('localhost', sm_version['hostname'])
    return sm_version
93
94
def ListResources(api, creds, options, call_id):
    """
    Call ListResources on every aggregate in api.aggregates and return the
    merged rspec as an xml string.

    api     -- the slice manager api object (aggregates, auth, cache, logger)
    creds   -- credentials of the caller, checked for the 'listnodes' right
    options -- request options; 'geni_slice_urn' restricts the request to one
               slice, 'rspec_version' names the format asked for
    call_id -- unique id of this call; "" is returned right away when the
               same call_id has already been handled by this process

    When no slice urn is given, the merged result is looked up in / stored
    into api.cache (if caching is enabled) under a key derived from the
    requested rspec version.
    """
    def _ListResources(server, credential, my_opts, call_id):
        # Forward the call to one aggregate. Any failure, including a failed
        # version probe, is logged and results in None for that aggregate.
        try:
            args = [credential, my_opts]
            if _call_id_supported(api, server):
                args.append(call_id)
            return server.ListResources(*args)
        except Exception as e:
            api.logger.warn("ListResources failed at %s: %s" %(server.url, str(e)))

    if Callids().already_handled(call_id): return ""

    # get slice's hrn from options
    xrn = options.get('geni_slice_urn', '')
    hrn = urn_to_hrn(xrn)[0]
    # the options forwarded to the aggregates: uncompressed answers, and no
    # rspec_version
    my_opts = copy(options)
    my_opts['geni_compressed'] = False
    if 'rspec_version' in my_opts:
        del my_opts['rspec_version']

    # get the rspec's return format from options
    rspec_version = RSpecVersion(options.get('rspec_version'))
    version_string = "rspec_%s" % (rspec_version.get_version_name())

    # look in cache first
    # NOTE(review): the cache is consulted before the credentials are checked;
    # presumably the calling method layer has authenticated already - confirm
    if caching and api.cache and not xrn:
        rspec = api.cache.get(version_string)
        if rspec:
            return rspec

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'listnodes', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()
    credentials = [credential]
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        # get the rspec from the aggregate
        server = api.aggregates[aggregate]
        threads.run(_ListResources, server, credentials, my_opts, call_id)

    results = threads.get_results()
    # NOTE(review): 'rspec_version' was removed from my_opts above, so this
    # always builds the default RSpecVersion, whatever format the caller asked
    # for (the cache key does use the requested one) - confirm this is intended
    merge_version = RSpecVersion(my_opts.get('rspec_version'))
    if merge_version['type'] == pg_rspec_ad_version['type']:
        rspec = PGRSpec()
    else:
        rspec = SfaRSpec()
    for result in results:
        try:
            rspec.merge(result)
        except Exception:
            # one unusable answer must not spoil the answers of the others
            api.logger.info("SM.ListResources: Failed to merge aggregate rspec")

    rspec_xml = rspec.toxml()

    # cache the result
    if caching and api.cache and not xrn:
        api.cache.add(version_string, rspec_xml)

    return rspec_xml
162
163
def CreateSliver(api, xrn, creds, rspec_str, users, call_id):
    """
    Send a CreateSliver request to every aggregate in api.aggregates and
    return the merged answers as the xml of an SfaRSpec.

    api       -- the slice manager api object
    xrn       -- urn of the slice
    creds     -- credentials of the caller, checked for the 'createsliver' right
    rspec_str -- the request rspec, as a string; it is sent in full to
                 every aggregate
    users     -- passed through unchanged to the aggregates
    call_id   -- unique id of this call; "" is returned right away when the
                 same call_id has already been handled by this process
    """
    def _CreateSliver(server, xrn, credential, rspec, users, call_id):
        # Forward the call to one aggregate. Any failure is logged and
        # results in None for that aggregate.
        try:
            # Need to call GetVersion at an aggregate to determine the
            # supported rspec type/format before calling CreateSliver there.
            server_version = _get_cached_server_version(api, server)
            if 'sfa' not in server_version and 'geni_api' in server_version:
                # sfa aggregates support both sfa and pg rspecs, no need to
                # convert for them. Otherwise convert to pg rspec.
                rspec = RSpecConverter.to_pg_rspec(rspec)
            args = [xrn, credential, rspec, users]
            if _call_id_supported(api, server):
                args.append(call_id)
            return server.CreateSliver(*args)
        except Exception as e:
            api.logger.warn("CreateSliver failed at %s: %s" %(server.url, str(e)))

    if Callids().already_handled(call_id): return ""
    # Validate the RSpec against PlanetLab's schema --disabled for now
    # The schema used here needs to aggregate the PL and VINI schemas
    # schema = "/var/www/html/schemas/pl.rng"
    rspec = parse_rspec(rspec_str)
    schema = None
    if schema:
        rspec.validate(schema)

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()

    # get the callers hrn
    hrn = urn_to_hrn(xrn)[0]
    valid_cred = api.auth.checkCredentials(creds, 'createsliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # Just send entire RSpec to each aggregate; serialize it only once
    request_xml = rspec.toxml()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = api.aggregates[aggregate]
        threads.run(_CreateSliver, server, xrn, credential, request_xml, users, call_id)

    results = threads.get_results()
    merged = SfaRSpec()
    for result in results:
        # _CreateSliver yields None for an aggregate whose call failed
        if result is None:
            continue
        merged.merge(result)
    return merged.toxml()
215
def RenewSliver(api, xrn, creds, expiration_time, call_id):
    """
    Send a RenewSliver request to every aggregate in api.aggregates and
    return the 'and' of their answers.

    api             -- the slice manager api object
    xrn             -- urn of the slice
    creds           -- credentials of the caller, checked for 'renewsliver'
    expiration_time -- passed through unchanged to the aggregates
    call_id         -- unique id of this call; True is returned right away
                       when the same call_id has already been handled

    Returns True when no aggregate was contacted.
    """
    def _RenewSliver(server, xrn, creds, expiration_time, call_id):
        # call_id is optional: only send it (once) to servers that know it
        args = [xrn, creds, expiration_time]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.RenewSliver(*args)

    if Callids().already_handled(call_id): return True

    hrn = urn_to_hrn(xrn)[0]
    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'renewsliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = api.aggregates[aggregate]
        threads.run(_RenewSliver, server, xrn, [credential], expiration_time, call_id)

    # 'and' the results
    outcome = True
    for result in threads.get_results():
        outcome = outcome and result
    return outcome
245
def DeleteSliver(api, xrn, creds, call_id):
    """
    Send a DeleteSliver request to every aggregate in api.aggregates.

    api     -- the slice manager api object
    xrn     -- urn of the slice
    creds   -- credentials of the caller, checked for the 'deletesliver' right
    call_id -- unique id of this call; "" is returned right away when the
               same call_id has already been handled by this process

    Returns 1 once all aggregates have answered; their individual answers
    are not looked at.
    """
    def _DeleteSliver(server, xrn, creds, call_id):
        # call_id is optional: only send it to servers that know about it
        args = [xrn, creds]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.DeleteSliver(*args)

    if Callids().already_handled(call_id): return ""
    hrn = urn_to_hrn(xrn)[0]
    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'deletesliver', hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = api.aggregates[aggregate]
        threads.run(_DeleteSliver, server, xrn, credential, call_id)
    # wait for all aggregates to be done
    threads.get_results()
    return 1
274
275
276 # first draft at a merging SliverStatus
def SliverStatus(api, slice_xrn, creds, call_id):
    """
    Ask every aggregate in api.aggregates for the status of a slice and
    combine the answers into a single status dictionary.

    api       -- the slice manager api object
    slice_xrn -- urn of the slice
    creds     -- credentials of the caller, used to look for a delegated
                 credential
    call_id   -- unique id of this call; {} is returned right away when the
                 same call_id has already been handled by this process

    Returns {} when no aggregate reports any resource. Otherwise returns a
    dict with 'geni_urn' and 'pl_login' taken from the first useful answer,
    'geni_resources' holding the resources of all answers, and 'status'.
    """
    def _SliverStatus(server, xrn, creds, call_id):
        # call_id is optional: only send it to servers that know about it
        args = [xrn, creds]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.SliverStatus(*args)

    if Callids().already_handled(call_id): return {}
    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()
    threads = ThreadManager()
    for aggregate in api.aggregates:
        server = api.aggregates[aggregate]
        threads.run (_SliverStatus, server, slice_xrn, credential, call_id)
    results = threads.get_results()

    # get rid of any void result - e.g. when call_id was hit where by convention we return {}
    results = [ result for result in results if result and result['geni_resources']]

    # do not try to combine if there's no result
    if not results : return {}

    # otherwise let's merge stuff
    overall = {}

    # mmh, it is expected that all results carry the same urn
    overall['geni_urn'] = results[0]['geni_urn']
    # NOTE(review): assumes the first answer carries a 'pl_login' key - confirm
    # for aggregates that are not PlanetLab based
    overall['pl_login'] = results[0]['pl_login']
    # append all geni_resources, into a fresh list
    all_resources = []
    for result in results:
        all_resources.extend(result['geni_resources'])
    overall['geni_resources'] = all_resources
    overall['status'] = 'unknown'
    if overall['geni_resources']:
        overall['status'] = 'ready'

    return overall
316
# Module-wide switch read by ListResources and ListSlices: when true (and the
# api has a cache) their merged results are served from and stored into
# api.cache. Set to False to turn this off.
caching = True
def ListSlices(api, creds, call_id):
    """
    Call ListSlices on every aggregate in api.aggregates and return the
    concatenation of their answers.

    api     -- the slice manager api object
    creds   -- credentials of the caller, checked for the 'listslices' right
    call_id -- unique id of this call; [] is returned right away when the
               same call_id has already been handled by this process

    When caching is enabled and the api has a cache, the combined list is
    looked up in / stored into api.cache under the key 'slices'.
    """
    def _ListSlices(server, creds, call_id):
        # call_id is optional: only send it to servers that know about it
        args = [creds]
        if _call_id_supported(api, server):
            args.append(call_id)
        return server.ListSlices(*args)

    if Callids().already_handled(call_id): return []

    # look in cache first
    # NOTE(review): the cache is consulted before the credentials are checked;
    # presumably the calling method layer has authenticated already - confirm
    if caching and api.cache:
        slices = api.cache.get('slices')
        if slices:
            return slices

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'listslices', None)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds)
    if not credential:
        credential = api.getCredential()
    threads = ThreadManager()
    # fetch from aggregates
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        server = api.aggregates[aggregate]
        threads.run(_ListSlices, server, credential, call_id)

    # combine results
    results = threads.get_results()
    slices = []
    for result in results:
        slices.extend(result)

    # cache the result
    if caching and api.cache:
        api.cache.add('slices', slices)

    return slices
364
365
def get_ticket(api, xrn, creds, rspec, users):
    """
    Request a ticket from every aggregate named in the rspec and combine the
    answers into one new ticket, signed by this interface.

    api   -- the slice manager api object
    xrn   -- urn of the slice
    creds -- credentials of the caller, checked for the 'getticket' right
    rspec -- the request rspec, as an xml string
    users -- passed through unchanged to the aggregates

    Returns the new ticket as a string, parents included.
    """
    slice_hrn = urn_to_hrn(xrn)[0]

    # get the netspecs contained within the clients rspec: one entry per
    # ./network element, keyed by the element's first attribute value; every
    # aggregate is given the complete rspec
    rspec_tree = etree.parse(StringIO(rspec))
    aggregate_rspecs = {}
    for network in rspec_tree.findall('./network'):
        aggregate_rspecs[network.values()[0]] = rspec

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'getticket', slice_hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds) or api.getCredential()

    threads = ThreadManager()
    for aggregate, aggregate_rspec in aggregate_rspecs.items():
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        if caller_hrn == aggregate and aggregate != api.hrn:
            continue
        if aggregate in api.aggregates:
            server = api.aggregates[aggregate]
        else:
            server = None
            net_urn = hrn_to_urn(aggregate, 'authority')
            # we may have a peer that knows about this aggregate
            for agg in api.aggregates:
                target_aggs = api.aggregates[agg].get_aggregates(credential, net_urn)
                if target_aggs and 'hrn' in target_aggs[0]:
                    # send the request to this address; aggregate found,
                    # no need to keep looping
                    url = target_aggs[0]['url']
                    server = xmlrpcprotocol.get_server(url, api.key_file, api.cert_file)
                    break
        if server is not None:
            threads.run(server.GetTicket, xrn, credential, aggregate_rspec, users)

    # gather information from each ticket
    rspecs, initscripts, slivers = [], [], []
    object_gid = None
    for result in threads.get_results():
        agg_ticket = SfaTicket(string=result)
        attrs = agg_ticket.get_attributes()
        # the object gid of the first ticket is the one that gets reused
        if not object_gid:
            object_gid = agg_ticket.get_gid_object()
        rspecs.append(agg_ticket.get_rspec())
        initscripts.extend(attrs.get('initscripts', []))
        slivers.extend(attrs.get('slivers', []))

    # merge info
    attributes = {'initscripts': initscripts, 'slivers': slivers}
    merged_rspec = merge_rspecs(rspecs)

    # create a new ticket
    ticket = SfaTicket(subject = slice_hrn)
    ticket.set_gid_caller(api.auth.client_gid)
    ticket.set_issuer(key=api.key, subject=api.hrn)
    ticket.set_gid_object(object_gid)
    ticket.set_pubkey(object_gid.get_pubkey())
    # no parent ticket is attached (set_parent is not called)
    ticket.set_attributes(attributes)
    ticket.set_rspec(merged_rspec)
    ticket.encode()
    ticket.sign()
    return ticket.save_to_string(save_parents=True)
442
def start_slice(api, xrn, creds):
    """
    Call Start for the slice on every aggregate in api.aggregates.

    api   -- the slice manager api object
    xrn   -- urn of the slice
    creds -- credentials of the caller, checked for the 'startslice' right

    Returns 1 once all aggregates have answered.
    """
    slice_hrn = urn_to_hrn(xrn)[0]

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'startslice', slice_hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds) or api.getCredential()

    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        sends_back_to_caller = caller_hrn == aggregate and aggregate != api.hrn
        if not sends_back_to_caller:
            threads.run(api.aggregates[aggregate].Start, xrn, credential)
    # wait for all aggregates to be done
    threads.get_results()
    return 1
464  
def stop_slice(api, xrn, creds):
    """
    Call Stop for the slice on every aggregate in api.aggregates.

    api   -- the slice manager api object
    xrn   -- urn of the slice
    creds -- credentials of the caller, checked for the 'stopslice' right

    Returns 1 once all aggregates have answered.
    """
    slice_hrn = urn_to_hrn(xrn)[0]

    # get the callers hrn
    valid_cred = api.auth.checkCredentials(creds, 'stopslice', slice_hrn)[0]
    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()

    # attempt to use delegated credential first
    credential = api.getDelegatedCredential(creds) or api.getCredential()

    threads = ThreadManager()
    for aggregate in api.aggregates:
        # prevent infinite loop. Dont send request back to caller
        # unless the caller is the aggregate's SM
        sends_back_to_caller = caller_hrn == aggregate and aggregate != api.hrn
        if not sends_back_to_caller:
            threads.run(api.aggregates[aggregate].Stop, xrn, credential)
    # wait for all aggregates to be done
    threads.get_results()
    return 1
486
def reset_slice(api, xrn):
    """
    Placeholder: resetting a slice is not implemented.

    Both arguments are ignored and 1 is always returned.
    """
    return 1
492
def shutdown(api, xrn, creds):
    """
    Placeholder: shutting down a slice is not implemented.

    All arguments are ignored and 1 is always returned.
    """
    return 1
498
def status(api, xrn, creds):
    """
    Placeholder: reporting the status of a slice is not implemented.

    All arguments are ignored and 1 is always returned.
    """
    return 1
504
def main():
    """
    Manual test entry point: parse the rspec file named by the first command
    line argument and hand its dictionary form to CreateSliver.
    """
    # NOTE(review): this call passes 4 positional arguments (with api=None)
    # while CreateSliver in this module takes 6 (api, xrn, creds, rspec_str,
    # users, call_id), so it looks stale - confirm and update or remove
    parsed = RSpec()
    parsed.parseFile(sys.argv[1])
    CreateSliver(None, 'plc.princeton.tmacktestslice', parsed.toDict(),
                 'create-slice-tmacktestslice')


if __name__ == "__main__":
    main()
513