065687d8b2b5a809d43099ba53f36818cc699470
[sfa.git] / sfa / managers / slice_manager_pl.py
1
2 import sys
3 import time,datetime
4 from StringIO import StringIO
5 from types import StringTypes
6 from copy import deepcopy
7 from copy import copy
8 from lxml import etree
9
10 from sfa.util.sfalogging import sfa_logger
11 from sfa.util.rspecHelper import merge_rspecs
12 from sfa.util.xrn import Xrn, urn_to_hrn, hrn_to_urn
13 from sfa.util.plxrn import hrn_to_pl_slicename
14 from sfa.util.rspec import *
15 from sfa.util.specdict import *
16 from sfa.util.faults import *
17 from sfa.util.record import SfaRecord
18 from sfa.rspecs.pg_rspec import PGRSpec
19 from sfa.rspecs.sfa_rspec import SfaRSpec
20 from sfa.rspecs.rspec_converter import RSpecConverter
21 from sfa.rspecs.rspec_parser import parse_rspec    
22 from sfa.rspecs.rspec_version import RSpecVersion
23 from sfa.rspecs.sfa_rspec import sfa_rspec_version
24 from sfa.rspecs.pg_rspec import pg_rspec_version    
25 from sfa.util.policy import Policy
26 from sfa.util.prefixTree import prefixTree
27 from sfa.util.sfaticket import *
28 from sfa.trust.credential import Credential
29 from sfa.util.threadmanager import ThreadManager
30 import sfa.util.xmlrpcprotocol as xmlrpcprotocol     
31 import sfa.plc.peers as peers
32 from sfa.util.version import version_core
33 from sfa.util.callids import Callids
34
35 # we have specialized xmlrpclib.ServerProxy to remember the input url
36 # OTOH it's not clear if we're only dealing with XMLRPCServerProxy instances
37 def get_serverproxy_url (server):
38     try:
39         return server.url
40     except:
41         sfa_logger().warning("GetVersion, falling back to xmlrpclib.ServerProxy internals")
42         return server._ServerProxy__host + server._ServerProxy__handler 
43
44 def GetVersion(api):
45     # peers explicitly in aggregates.xml
46     peers =dict ([ (peername,get_serverproxy_url(v)) for (peername,v) in api.aggregates.iteritems() 
47                    if peername != api.hrn])
48     xrn=Xrn (api.hrn)
49     supported_rspecs = [dict(pg_rspec_version), dict(sfa_rspec_version)]
50     version_more = {'interface':'slicemgr',
51                     'hrn' : xrn.get_hrn(),
52                     'urn' : xrn.get_urn(),
53                     'peers': peers,
54                     'request_rspec_versions': supported_rspecs,
55                     'ad_rspec_versions': supported_rspecs,
56                     'default_ad_rspec': dict(sfa_rspec_version)
57                     }
58     sm_version=version_core(version_more)
59     # local aggregate if present needs to have localhost resolved
60     if api.hrn in api.aggregates:
61         local_am_url=get_serverproxy_url(api.aggregates[api.hrn])
62         sm_version['peers'][api.hrn]=local_am_url.replace('localhost',sm_version['hostname'])
63     return sm_version
64
65 def CreateSliver(api, xrn, creds, rspec_str, users, call_id):
66
67     def _CreateSliver(aggregate, xrn, credential, rspec, users, call_id):
68             # Need to call GetVersion at an aggregate to determine the supported 
69             # rspec type/format beofre calling CreateSliver at an Aggregate. 
70             # The Aggregate's verion info is cached 
71             server = api.aggregates[aggregate]
72             # get cached aggregate version
73             aggregate_version_key = 'version_'+ aggregate
74             aggregate_version = api.cache.get(aggregate_version_key)
75             if not aggregate_version:
76                 # get current aggregate version anc cache it for 24 hours
77                 aggregate_version = server.GetVersion()
78                 api.cache.add(aggregate_version_key, aggregate_version, 60 * 60 * 24)
79                 
80             if 'sfa' not in aggregate_version and 'geni_api' in aggregate_version:
81                 # sfa aggregtes support both sfa and pg rspecs, no need to convert
82                 # if aggregate supports sfa rspecs. othewise convert to pg rspec
83                 rspec = RSpecConverter.to_pg_rspec(rspec)
84
85             return server.CreateSliver(xrn, credential, rspec, users, call_id)
86                 
87
88     if Callids().already_handled(call_id): return ""
89
90     # Validate the RSpec against PlanetLab's schema --disabled for now
91     # The schema used here needs to aggregate the PL and VINI schemas
92     # schema = "/var/www/html/schemas/pl.rng"
93     rspec = parse_rspec(rspec_str)
94     schema = None
95     if schema:
96         rspec.validate(schema)
97
98     # attempt to use delegated credential first
99     credential = api.getDelegatedCredential(creds)
100     if not credential:
101         credential = api.getCredential()
102
103     # get the callers hrn
104     hrn, type = urn_to_hrn(xrn)
105     valid_cred = api.auth.checkCredentials(creds, 'createsliver', hrn)[0]
106     caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
107     threads = ThreadManager()
108     for aggregate in api.aggregates:
109         # prevent infinite loop. Dont send request back to caller
110         # unless the caller is the aggregate's SM 
111         if caller_hrn == aggregate and aggregate != api.hrn:
112             continue
113             
114         # Just send entire RSpec to each aggregate
115         threads.run(_CreateSliver, aggregate, xrn, credential, rspec.toxml(), users, call_id)
116             
117     results = threads.get_results()
118     rspec = SfaRSpec()
119     for result in results:
120         rspec.merge(result)     
121     return rspec.toxml()
122
123 def RenewSliver(api, xrn, creds, expiration_time, call_id):
124     if Callids().already_handled(call_id): return True
125
126     (hrn, type) = urn_to_hrn(xrn)
127     # get the callers hrn
128     valid_cred = api.auth.checkCredentials(creds, 'renewsliver', hrn)[0]
129     caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
130
131     # attempt to use delegated credential first
132     credential = api.getDelegatedCredential(creds)
133     if not credential:
134         credential = api.getCredential()
135     threads = ThreadManager()
136     for aggregate in api.aggregates:
137         # prevent infinite loop. Dont send request back to caller
138         # unless the caller is the aggregate's SM
139         if caller_hrn == aggregate and aggregate != api.hrn:
140             continue
141
142         server = api.aggregates[aggregate]
143         threads.run(server.RenewSliver, xrn, [credential], expiration_time, call_id)
144     # 'and' the results
145     return reduce (lambda x,y: x and y, threads.get_results() , True)
146
147 def get_ticket(api, xrn, creds, rspec, users):
148     slice_hrn, type = urn_to_hrn(xrn)
149     # get the netspecs contained within the clients rspec
150     aggregate_rspecs = {}
151     tree= etree.parse(StringIO(rspec))
152     elements = tree.findall('./network')
153     for element in elements:
154         aggregate_hrn = element.values()[0]
155         aggregate_rspecs[aggregate_hrn] = rspec 
156
157     # get the callers hrn
158     valid_cred = api.auth.checkCredentials(creds, 'getticket', slice_hrn)[0]
159     caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
160
161     # attempt to use delegated credential first
162     credential = api.getDelegatedCredential(creds)
163     if not credential:
164         credential = api.getCredential() 
165     threads = ThreadManager()
166     for (aggregate, aggregate_rspec) in aggregate_rspecs.iteritems():
167         # prevent infinite loop. Dont send request back to caller
168         # unless the caller is the aggregate's SM
169         if caller_hrn == aggregate and aggregate != api.hrn:
170             continue
171         server = None
172         if aggregate in api.aggregates:
173             server = api.aggregates[aggregate]
174         else:
175             net_urn = hrn_to_urn(aggregate, 'authority')     
176             # we may have a peer that knows about this aggregate
177             for agg in api.aggregates:
178                 target_aggs = api.aggregates[agg].get_aggregates(credential, net_urn)
179                 if not target_aggs or not 'hrn' in target_aggs[0]:
180                     continue
181                 # send the request to this address 
182                 url = target_aggs[0]['url']
183                 server = xmlrpcprotocol.get_server(url, api.key_file, api.cert_file)
184                 # aggregate found, no need to keep looping
185                 break   
186         if server is None:
187             continue 
188         threads.run(server.GetTicket, xrn, credential, aggregate_rspec, users)
189
190     results = threads.get_results()
191     
192     # gather information from each ticket 
193     rspecs = []
194     initscripts = []
195     slivers = [] 
196     object_gid = None  
197     for result in results:
198         agg_ticket = SfaTicket(string=result)
199         attrs = agg_ticket.get_attributes()
200         if not object_gid:
201             object_gid = agg_ticket.get_gid_object()
202         rspecs.append(agg_ticket.get_rspec())
203         initscripts.extend(attrs.get('initscripts', [])) 
204         slivers.extend(attrs.get('slivers', [])) 
205     
206     # merge info
207     attributes = {'initscripts': initscripts,
208                  'slivers': slivers}
209     merged_rspec = merge_rspecs(rspecs) 
210
211     # create a new ticket
212     ticket = SfaTicket(subject = slice_hrn)
213     ticket.set_gid_caller(api.auth.client_gid)
214     ticket.set_issuer(key=api.key, subject=api.hrn)
215     ticket.set_gid_object(object_gid)
216     ticket.set_pubkey(object_gid.get_pubkey())
217     #new_ticket.set_parent(api.auth.hierarchy.get_auth_ticket(auth_hrn))
218     ticket.set_attributes(attributes)
219     ticket.set_rspec(merged_rspec)
220     ticket.encode()
221     ticket.sign()          
222     return ticket.save_to_string(save_parents=True)
223
224
225 def DeleteSliver(api, xrn, creds, call_id):
226     if Callids().already_handled(call_id): return ""
227     (hrn, type) = urn_to_hrn(xrn)
228     # get the callers hrn
229     valid_cred = api.auth.checkCredentials(creds, 'deletesliver', hrn)[0]
230     caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
231
232     # attempt to use delegated credential first
233     credential = api.getDelegatedCredential(creds)
234     if not credential:
235         credential = api.getCredential()
236     threads = ThreadManager()
237     for aggregate in api.aggregates:
238         # prevent infinite loop. Dont send request back to caller
239         # unless the caller is the aggregate's SM
240         if caller_hrn == aggregate and aggregate != api.hrn:
241             continue
242         server = api.aggregates[aggregate]
243         threads.run(server.DeleteSliver, xrn, credential, call_id)
244     threads.get_results()
245     return 1
246
247 def start_slice(api, xrn, creds):
248     hrn, type = urn_to_hrn(xrn)
249
250     # get the callers hrn
251     valid_cred = api.auth.checkCredentials(creds, 'startslice', hrn)[0]
252     caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
253
254     # attempt to use delegated credential first
255     credential = api.getDelegatedCredential(creds)
256     if not credential:
257         credential = api.getCredential()
258     threads = ThreadManager()
259     for aggregate in api.aggregates:
260         # prevent infinite loop. Dont send request back to caller
261         # unless the caller is the aggregate's SM
262         if caller_hrn == aggregate and aggregate != api.hrn:
263             continue
264         server = api.aggregates[aggregate]
265         threads.run(server.Start, xrn, credential)
266     threads.get_results()    
267     return 1
268  
269 def stop_slice(api, xrn, creds):
270     hrn, type = urn_to_hrn(xrn)
271
272     # get the callers hrn
273     valid_cred = api.auth.checkCredentials(creds, 'stopslice', hrn)[0]
274     caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
275
276     # attempt to use delegated credential first
277     credential = api.getDelegatedCredential(creds)
278     if not credential:
279         credential = api.getCredential()
280     threads = ThreadManager()
281     for aggregate in api.aggregates:
282         # prevent infinite loop. Dont send request back to caller
283         # unless the caller is the aggregate's SM
284         if caller_hrn == aggregate and aggregate != api.hrn:
285             continue
286         server = api.aggregates[aggregate]
287         threads.run(server.Stop, xrn, credential)
288     threads.get_results()    
289     return 1
290
291 def reset_slice(api, xrn):
292     """
293     Not implemented
294     """
295     return 1
296
297 def shutdown(api, xrn, creds):
298     """
299     Not implemented   
300     """
301     return 1
302
303 def status(api, xrn, creds):
304     """
305     Not implemented 
306     """
307     return 1
308
309 # Thierry : caching at the slicemgr level makes sense to some extent
310 caching=True
311 #caching=False
312 def ListSlices(api, creds, call_id):
313
314     if Callids().already_handled(call_id): return []
315
316     # look in cache first
317     if caching and api.cache:
318         slices = api.cache.get('slices')
319         if slices:
320             return slices    
321
322     # get the callers hrn
323     valid_cred = api.auth.checkCredentials(creds, 'listslices', None)[0]
324     caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
325
326     # attempt to use delegated credential first
327     credential = api.getDelegatedCredential(creds)
328     if not credential:
329         credential = api.getCredential()
330     threads = ThreadManager()
331     # fetch from aggregates
332     for aggregate in api.aggregates:
333         # prevent infinite loop. Dont send request back to caller
334         # unless the caller is the aggregate's SM
335         if caller_hrn == aggregate and aggregate != api.hrn:
336             continue
337         server = api.aggregates[aggregate]
338         threads.run(server.ListSlices, credential, call_id)
339
340     # combime results
341     results = threads.get_results()
342     slices = []
343     for result in results:
344         slices.extend(result)
345     
346     # cache the result
347     if caching and api.cache:
348         api.cache.add('slices', slices)
349
350     return slices
351
352
353 def ListResources(api, creds, options, call_id):
354
355     if Callids().already_handled(call_id): return ""
356
357     # get slice's hrn from options
358     xrn = options.get('geni_slice_urn', '')
359     (hrn, type) = urn_to_hrn(xrn)
360
361     # get the rspec's return format from options
362     rspec_version = RSpecVersion(options.get('rspec_version'))
363     version_string = "rspec_%s" % (rspec_version.get_version_name())
364
365     # look in cache first
366     if caching and api.cache and not xrn:
367         rspec =  api.cache.get(version_string)
368         if rspec:
369             return rspec
370
371     # get the callers hrn
372     valid_cred = api.auth.checkCredentials(creds, 'listnodes', hrn)[0]
373     caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
374
375     # attempt to use delegated credential first
376     credential = api.getDelegatedCredential(creds)
377     if not credential:
378         credential = api.getCredential()
379     threads = ThreadManager()
380     for aggregate in api.aggregates:
381         # prevent infinite loop. Dont send request back to caller
382         # unless the caller is the aggregate's SM
383         if caller_hrn == aggregate and aggregate != api.hrn:
384             continue
385         # get the rspec from the aggregate
386         server = api.aggregates[aggregate]
387         my_opts = copy(options)
388         my_opts['geni_compressed'] = False
389         threads.run(server.ListResources, credential, my_opts, call_id)
390                     
391     results = threads.get_results()
392     #results.append(open('/root/protogeni.rspec', 'r').read())
393     rspec_version = RSpecVersion(my_opts.get('rspec_version'))
394     if rspec_version['type'].lower() == 'protogeni':
395         rspec = PGRSpec()
396     else:
397         rspec = SfaRSpec()
398
399     for result in results:
400         print "RESULT"
401         try:
402             rspec.merge(result)
403         except:
404             raise
405             api.logger.info("SM.ListResources: Failed to merge aggregate rspec")
406
407     # cache the result
408     if caching and api.cache and not xrn:
409         api.cache.add(version_string, rspec.toxml())
410  
411     return rspec.toxml()
412
413 # first draft at a merging SliverStatus
414 def SliverStatus(api, slice_xrn, creds, call_id):
415     if Callids().already_handled(call_id): return {}
416     # attempt to use delegated credential first
417     credential = api.getDelegatedCredential(creds)
418     if not credential:
419         credential = api.getCredential()
420     threads = ThreadManager()
421     for aggregate in api.aggregates:
422         server = api.aggregates[aggregate]
423         threads.run (server.SliverStatus, slice_xrn, credential, call_id)
424     results = threads.get_results()
425
426     # get rid of any void result - e.g. when call_id was hit where by convention we return {}
427     results = [ result for result in results if result and result['geni_resources']]
428
429     # do not try to combine if there's no result
430     if not results : return {}
431
432     # otherwise let's merge stuff
433     overall = {}
434
435     # mmh, it is expected that all results carry the same urn
436     overall['geni_urn'] = results[0]['geni_urn']
437
438     # consolidate geni_status - simple model using max on a total order
439     states = [ 'ready', 'configuring', 'failed', 'unknown' ]
440     # hash name to index
441     shash = dict ( zip ( states, range(len(states)) ) )
442     def combine_status (x,y):
443         return shash [ max (shash(x),shash(y)) ]
444     overall['geni_status'] = reduce (combine_status, [ result['geni_status'] for result in results], 'ready' )
445
446     # {'ready':0,'configuring':1,'failed':2,'unknown':3}
447     # append all geni_resources
448     overall['geni_resources'] = \
449         reduce (lambda x,y: x+y, [ result['geni_resources'] for result in results] , [])
450
451     return overall
452
453 def main():
454     r = RSpec()
455     r.parseFile(sys.argv[1])
456     rspec = r.toDict()
457     CreateSliver(None,'plc.princeton.tmacktestslice',rspec,'create-slice-tmacktestslice')
458
459 if __name__ == "__main__":
460     main()
461