If the remote aggregate does not have a person with the given peer_id, try to use the local account...
[sfa.git] / sfa / plc / slices.py
1 ### $Id$
2 ### $URL$
3
4 import datetime
5 import time
6 import traceback
7 import sys
8
9 from types import StringTypes
10 from sfa.util.misc import *
11 from sfa.util.rspec import *
12 from sfa.util.specdict import *
13 from sfa.util.faults import *
14 from sfa.util.storage import *
15 from sfa.util.record import GeniRecord
16 from sfa.util.policy import Policy
17 from sfa.util.prefixTree import prefixTree
18 from sfa.util.debug import log
19 from sfa.server.aggregate import Aggregates
20 from sfa.server.registry import Registries
21
class Slices(SimpleStorage):
    """
    Slice operations for an SFA aggregate or slice manager.

    Every public operation dispatches on self.api.interface: 'aggregate'
    works directly against the local PLC through self.api.plshell, while
    'slicemgr' forwards the request to each aggregate in Aggregates(api).
    The hrns of known slices are cached through SimpleStorage in the file
    <SFA_BASE_DIR>/<interface>.<hrn>.slices, together with a 'timestamp'
    and an expiry 'threshold' (see refresh).
    """

    def __init__(self, api, ttl = .5, caller_cred=None):
        """
        @param api: SFA api object (config, plshell, key, credentials, ...)
        @param ttl: lifetime of the cached slice list, in hours
        @param caller_cred: credential of the original caller; forwarded to
            aggregates by create_slice_smgr and delete_slice_smgr
        """
        # os is not imported at module level; import it explicitly instead of
        # relying on one of the star imports to provide it
        import os
        self.api = api
        self.ttl = ttl
        self.threshold = None
        path = self.api.config.SFA_BASE_DIR
        filename = ".".join([self.api.interface, self.api.hrn, "slices"])
        self.slices_file = path + os.sep + filename
        SimpleStorage.__init__(self, self.slices_file)
        self.policy = Policy(self.api)
        self.load()
        self.caller_cred = caller_cred

    def get_peer(self, hrn):
        """
        Return the shortname of the myplc peer that hrn belongs to, or None.

        Because of myplc federation a slice may belong to our local PLC or to
        a myplc peer; it is assumed local unless a peer matches.  The site
        authority of hrn (the authority of its authority) is compared,
        case-insensitively, with every string field (peername, shortname,
        hrn_root) of each peer; if several peers match, the last one wins.
        """
        peer = None
        # this slice's authority (site), then that site's authority
        slice_authority = get_authority(hrn)
        site_authority = get_authority(slice_authority).lower()

        # check if we are already peered with this site_authority
        peers = self.api.plshell.GetPeers(self.api.plauth, {}, ['peer_id', 'peername', 'shortname', 'hrn_root'])
        for peer_record in peers:
            names = [name.lower() for name in peer_record.values() if isinstance(name, StringTypes)]
            if site_authority in names:
                peer = peer_record['shortname']
        return peer

    def get_sfa_peer(self, hrn):
        """
        Return the site authority of hrn when it differs from our own hrn
        (i.e. another SFA authority owns it), otherwise None.
        """
        site_authority = get_authority(get_authority(hrn))
        if site_authority != self.api.hrn:
            return site_authority
        return None

    def refresh(self):
        """
        Update the cached list of slices.

        The cache is rebuilt when it has no 'threshold' or 'timestamp' yet, or
        when the stored threshold (formatted with api.time_format) is in the
        past.  Interfaces other than 'aggregate' and 'slicemgr' are ignored.
        """
        now = datetime.datetime.now()
        if not self.has_key('threshold') or not self.has_key('timestamp') or \
           now > datetime.datetime.fromtimestamp(time.mktime(time.strptime(self['threshold'], self.api.time_format))):
            if self.api.interface in ['aggregate']:
                self.refresh_slices_aggregate()
            elif self.api.interface in ['slicemgr']:
                self.refresh_slices_smgr()

    def _store_slice_hrns(self, slice_hrns):
        """Cache slice_hrns with a fresh timestamp and a threshold ttl hours ahead, then persist."""
        timestamp = datetime.datetime.now()
        threshold = timestamp + datetime.timedelta(hours=self.ttl)
        slice_details = {'hrn': slice_hrns,
                         'timestamp': timestamp.strftime(self.api.time_format),
                         'threshold': threshold.strftime(self.api.time_format)
                        }
        self.update(slice_details)
        self.write()

    def refresh_slices_aggregate(self):
        """Cache the hrns of all local (peer_id None) PLC slices."""
        slices = self.api.plshell.GetSlices(self.api.plauth, {'peer_id': None}, ['name'])
        slice_hrns = [slicename_to_hrn(self.api.hrn, s['name']) for s in slices]
        self._store_slice_hrns(slice_hrns)

    def refresh_slices_smgr(self):
        """
        Cache the concatenation of the slice hrns reported by every aggregate.

        An aggregate that fails is logged (with traceback) and skipped.
        """
        slice_hrns = []
        aggregates = Aggregates(self.api)
        credential = self.api.getCredential()
        arg_list = [credential]
        request_hash = self.api.key.compute_hash(arg_list)
        for aggregate in aggregates:
            try:
                slices = aggregates[aggregate].get_slices(credential, request_hash)
                slice_hrns.extend(slices)
            except Exception:
                print >> log, "Error calling slices at aggregate %(aggregate)s" % locals()
                traceback.print_exc(file=log)
        self._store_slice_hrns(slice_hrns)

    def delete_slice(self, hrn):
        """Delete slice hrn on this interface ('aggregate' or 'slicemgr')."""
        if self.api.interface in ['aggregate']:
            self.delete_slice_aggregate(hrn)
        elif self.api.interface in ['slicemgr']:
            self.delete_slice_smgr(hrn)

    def delete_slice_aggregate(self, hrn):
        """
        Remove the local PLC slice for hrn from all of its nodes.

        Only the node assignments are removed (DeleteSliceFromNodes); the
        PLC slice itself is kept.  Returns 1, also when the slice is unknown.
        """
        slicename = hrn_to_pl_slicename(hrn)
        slices = self.api.plshell.GetSlices(self.api.plauth, {'name': slicename})
        if not slices:
            return 1
        plc_slice = slices[0]

        # a peer slice is unbound from its peer while it is modified
        peer = self.get_peer(hrn)
        if peer:
            self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'slice', plc_slice['slice_id'], peer)
        try:
            self.api.plshell.DeleteSliceFromNodes(self.api.plauth, slicename, plc_slice['node_ids'])
        finally:
            # re-bind even on failure so the slice is not left detached
            if peer:
                self.api.plshell.BindObjectToPeer(self.api.plauth, 'slice', plc_slice['slice_id'], peer, plc_slice['peer_slice_id'])
        return 1

    def delete_slice_smgr(self, hrn):
        """
        Ask every aggregate to delete slice hrn, forwarding caller_cred.

        A failing aggregate is logged (with traceback) and does not prevent
        the remaining aggregates from being contacted.
        """
        credential = self.api.getCredential()
        caller_cred = self.caller_cred
        aggregates = Aggregates(self.api)
        arg_list = [credential, hrn]
        request_hash = self.api.key.compute_hash(arg_list)
        for aggregate in aggregates:
            try:
                aggregates[aggregate].delete_slice(credential, hrn, request_hash, caller_cred)
            except Exception:
                print >> log, "Error deleting slice %s at aggregate %s" % (hrn, aggregate)
                traceback.print_exc(file=log)

    def create_slice(self, hrn, rspec):
        """
        Create or update slice hrn from rspec, subject to the slice policy.

        If the policy's slice_whitelist is non-empty and lacks hrn, or its
        slice_blacklist contains hrn, the request is logged and ignored
        (returns 1).  Otherwise it is dispatched on api.interface.
        """
        whitelist = self.policy['slice_whitelist']
        blacklist = self.policy['slice_blacklist']
        if (whitelist and hrn not in whitelist) or \
           (blacklist and hrn in blacklist):
            policy_file = self.policy.policy_file
            print >> log, "Slice %(hrn)s not allowed by policy %(policy_file)s" % locals()
            return 1

        if self.api.interface in ['aggregate']:
            self.create_slice_aggregate(hrn, rspec)
        elif self.api.interface in ['slicemgr']:
            self.create_slice_smgr(hrn, rspec)

    def _register_peer_object(self, registry, credential, record_type, hrn, sfa_peer, pointer):
        """Mark hrn at registry as an sfa peer record of sfa_peer whose local PLC id is pointer."""
        peer_dict = {'type': record_type, 'hrn': hrn, 'peer_authority': sfa_peer, 'pointer': pointer}
        arg_list = [credential]
        request_hash = self.api.key.compute_hash(arg_list)
        registry.register_peer_object(credential, peer_dict, request_hash)

    def verify_site(self, registry, credential, slice_hrn, peer, sfa_peer):
        """
        Make sure a PLC site exists for the authority of slice_hrn.

        The authority record is resolved at registry.  If GetSites finds no
        site for the authority's leaf name (used as login_base), a site is
        added from that record, bound to peer (when given) and registered as
        an sfa peer object (when sfa_peer is given).

        @return: (site_id, remote_site_id).  For a newly added site
            remote_site_id is the registry record's site_id; for an existing
            site it is that site's 'peer_site_id'.
        @raise RecordNotFound: if the registry returns no authority record.
        """
        authority = get_authority(slice_hrn)
        arg_list = [credential, authority]
        request_hash = self.api.key.compute_hash(arg_list)
        site_records = registry.resolve(credential, authority, request_hash)
        site = {}
        for site_record in site_records:
            if site_record['type'] == 'authority':
                site = site_record
        if not site:
            raise RecordNotFound(authority)
        # pop it so the remote id is not sent to AddSite below
        remote_site_id = site.pop('site_id')

        login_base = get_leaf(authority)
        sites = self.api.plshell.GetSites(self.api.plauth, login_base)
        if not sites:
            site_id = self.api.plshell.AddSite(self.api.plauth, site)
            if peer:
                self.api.plshell.BindObjectToPeer(self.api.plauth, 'site', site_id, peer, remote_site_id)
            # mark this site as an sfa peer record
            if sfa_peer:
                self._register_peer_object(registry, credential, 'authority', authority, sfa_peer, site_id)
        else:
            site_id = sites[0]['site_id']
            remote_site_id = sites[0]['peer_site_id']

        return (site_id, remote_site_id)

    def verify_slice(self, registry, credential, slice_hrn, site_id, remote_site_id, peer, sfa_peer):
        """
        Make sure a PLC slice exists for slice_hrn and its researchers are
        members of it.

        The slice record is resolved at registry.  A missing PLC slice is
        added from the record's non-empty name/url/description, registered
        as an sfa peer object (when sfa_peer is given) and bound to peer
        (when given).  For an existing slice, its own site_id replaces the
        one passed in before the researchers are verified.

        @return: the slice dict, including 'slice_id', 'node_ids' and
            'peer_slice_id' (the registry record's pointer).
        @raise RecordNotFound: if the registry returns no slice record.
        """
        slice_record = None
        arg_list = [credential, slice_hrn]
        request_hash = self.api.key.compute_hash(arg_list)
        slice_records = registry.resolve(credential, slice_hrn, request_hash)
        for record in slice_records:
            if record['type'] in ['slice']:
                slice_record = record
        if not slice_record:
            raise RecordNotFound(slice_hrn)

        slicename = hrn_to_pl_slicename(slice_hrn)
        slices = self.api.plshell.GetSlices(self.api.plauth, [slicename], ['slice_id', 'node_ids', 'site_id'])
        if not slices:
            slice_fields = {}
            for key in ['name', 'url', 'description']:
                if key in slice_record and slice_record[key]:
                    slice_fields[key] = slice_record[key]

            slice_id = self.api.plshell.AddSlice(self.api.plauth, slice_fields)
            plc_slice = slice_fields
            plc_slice['slice_id'] = slice_id

            # mark this slice as an sfa peer record
            if sfa_peer:
                self._register_peer_object(registry, credential, 'slice', slice_hrn, sfa_peer, slice_id)

            # this slice belongs to a peer
            if peer:
                self.api.plshell.BindObjectToPeer(self.api.plauth, 'slice', slice_id, peer, slice_record['pointer'])
            plc_slice['node_ids'] = []
        else:
            plc_slice = slices[0]
            site_id = plc_slice['site_id']

        plc_slice['peer_slice_id'] = slice_record['pointer']
        self.verify_persons(registry, credential, slice_record, site_id, remote_site_id, peer, sfa_peer)
        return plc_slice

    def verify_persons(self, registry, credential, slice_record, site_id, remote_site_id, peer, sfa_peer):
        """
        Make sure every researcher of slice_record is a PLC person that is a
        member of the slice and of site_id, and holds the registry's keys.

        For a peer slice the person is first looked up among the peer's
        persons; if the peer has none with that email, an existing account
        with the same email is used instead (local_person).  Persons found
        nowhere are added from their registry record, enabled, registered as
        sfa peer objects (when sfa_peer is given) and bound to peer (when
        given).  Researchers without a 'user' record are logged and skipped.
        """
        slicename = hrn_to_pl_slicename(slice_record['hrn'])
        researchers = slice_record.get('researcher', [])

        # the peer id is the same for every researcher; look it up once
        peer_id = None
        if peer and researchers:
            peer_id = self.api.plshell.GetPeers(self.api.plauth, {'shortname': peer}, ['peer_id'])[0]['peer_id']

        for researcher in researchers:
            arg_list = [credential, researcher]
            request_hash = self.api.key.compute_hash(arg_list)
            person_records = registry.resolve(credential, researcher, request_hash)
            person_dict = {}
            for record in person_records:
                if record['type'] in ['user']:
                    person_dict = record
            if not person_dict:
                print >> log, "No user record found for researcher %s; skipping" % researcher
                continue

            email = person_dict['email']
            local_person = False
            if peer:
                persons = self.api.plshell.GetPersons(self.api.plauth, {'email': [email], 'peer_id': peer_id}, ['person_id', 'key_ids'])
                if not persons:
                    # the peer does not know this person: fall back to the local account
                    persons = self.api.plshell.GetPersons(self.api.plauth, [email], ['person_id', 'key_ids'])
                    if persons:
                        local_person = True
            else:
                persons = self.api.plshell.GetPersons(self.api.plauth, [email], ['person_id', 'key_ids'])

            if not persons:
                person_id = self.api.plshell.AddPerson(self.api.plauth, person_dict)
                self.api.plshell.UpdatePerson(self.api.plauth, person_id, {'enabled': True})

                # mark this person as an sfa peer record
                if sfa_peer:
                    self._register_peer_object(registry, credential, 'user', researcher, sfa_peer, person_id)

                if peer:
                    self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', person_id, peer, person_dict['pointer'])
                key_ids = []
            else:
                person_id = persons[0]['person_id']
                key_ids = persons[0]['key_ids']

            # peer-bound objects must be unbound from the peer or PLCAPI will
            # throw an error.  An account found only by the local fallback was
            # not returned by the peer-filtered lookup, so it is left alone.
            peer_person = bool(peer) and not local_person
            if peer_person:
                self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'person', person_id, peer)
            if peer:
                self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'site', site_id, peer)
            try:
                self.api.plshell.AddPersonToSlice(self.api.plauth, email, slicename)
                self.api.plshell.AddPersonToSite(self.api.plauth, email, site_id)
            finally:
                # re-bind even on failure so nothing is left detached
                if peer_person:
                    self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', person_id, peer, person_dict['pointer'])
                if peer:
                    self.api.plshell.BindObjectToPeer(self.api.plauth, 'site', site_id, peer, remote_site_id)

            self.verify_keys(registry, credential, person_dict, key_ids, person_id, peer, local_person)

    def verify_keys(self, registry, credential, person_dict, key_ids, person_id, peer, local_person):
        """
        Add to PLC every key in person_dict['keys'] that is not already among
        the PLC keys identified by key_ids.

        For a peer person (peer given and not local_person) the person is
        unbound from the peer while a key is added and bound back afterwards.
        When peer is given, each newly added key is bound to the peer using
        the next id of person_dict['key_ids'] -- presumably the remote key ids
        in the same order as person_dict['keys']; TODO confirm.  Key binding
        is best effort: failures are logged, not raised.
        registry and credential are currently unused.
        """
        keylist = self.api.plshell.GetKeys(self.api.plauth, key_ids, ['key'])
        keys = [k['key'] for k in keylist]

        # copy so consuming ids below does not mutate the registry record
        remote_key_ids = list(person_dict.get('key_ids') or [])
        peer_person = bool(peer) and not local_person

        # add keys that aren't already there
        for personkey in person_dict['keys']:
            if personkey in keys:
                continue
            key = {'key_type': 'ssh', 'key': personkey}
            if peer_person:
                self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'person', person_id, peer)
            try:
                key_id = self.api.plshell.AddPersonKey(self.api.plauth, person_dict['email'], key)
            finally:
                if peer_person:
                    self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', person_id, peer, person_dict['pointer'])
            if peer and remote_key_ids:
                remote_key_id = remote_key_ids.pop(0)
                try:
                    self.api.plshell.BindObjectToPeer(self.api.plauth, 'key', key_id, peer, remote_key_id)
                except Exception:
                    print >> log, "Could not bind key %s of %s to peer %s" % (key_id, person_dict['email'], peer)
                    traceback.print_exc(file=log)

    def create_slice_aggregate(self, hrn, rspec):
        """
        Create or update the local PLC slice for hrn so that it runs on
        exactly the nodes named by the NodeSpec elements of rspec.

        The site and slice (and, through verify_slice, its researchers) are
        first checked against the local registry.  Returns 1.
        """
        # determine if this is a peer slice
        peer = self.get_peer(hrn)
        sfa_peer = self.get_sfa_peer(hrn)

        spec = RSpec(rspec)
        slicename = hrn_to_pl_slicename(hrn)
        registries = Registries(self.api)
        registry = registries[self.api.hrn]
        credential = self.api.getCredential()

        site_id, remote_site_id = self.verify_site(registry, credential, hrn, peer, sfa_peer)
        plc_slice = self.verify_slice(registry, credential, hrn, site_id, remote_site_id, peer, sfa_peer)

        # hostnames the slice currently runs on
        nodelist = self.api.plshell.GetNodes(self.api.plauth, plc_slice['node_ids'], ['hostname'])
        hostnames = [node['hostname'] for node in nodelist]

        # hostnames requested by the rspec; a NodeSpec name is a list or a string
        nodes = []
        for nodespec in spec.getDictsByTagName('NodeSpec'):
            if isinstance(nodespec['name'], list):
                nodes.extend(nodespec['name'])
            elif isinstance(nodespec['name'], StringTypes):
                nodes.append(nodespec['name'])

        deleted_nodes = list(set(hostnames).difference(nodes))
        added_nodes = list(set(nodes).difference(hostnames))

        if peer:
            self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'slice', plc_slice['slice_id'], peer)
        try:
            self.api.plshell.AddSliceToNodes(self.api.plauth, slicename, added_nodes)
            self.api.plshell.DeleteSliceFromNodes(self.api.plauth, slicename, deleted_nodes)
        finally:
            # re-bind even on failure so the slice is not left detached
            if peer:
                self.api.plshell.BindObjectToPeer(self.api.plauth, 'slice', plc_slice['slice_id'], peer, plc_slice['peer_slice_id'])

        return 1

    def create_slice_smgr(self, hrn, rspec):
        """
        Split rspec into one rspec per NetSpec and send each to its network.

        A network served by a directly known aggregate receives its own
        rspec, except our own network (api.hrn), which receives the whole
        original rspec.  Other networks are forwarded to every aggregate whose
        get_aggregates reports the network.  Failures are logged per network.
        Returns 1.
        """
        spec = RSpec()
        spec.parseString(rspec)
        specDict = spec.toDict()
        if 'RSpec' in specDict:
            specDict = specDict['RSpec']
        start_time = specDict.get('start_time', 0)
        end_time = specDict.get('end_time', 0)

        rspecs = {}
        aggregates = Aggregates(self.api)
        credential = self.api.getCredential()

        # split the netspecs into individual rspecs (fresh RSpec per network)
        for netspec in spec.getDictsByTagName('NetSpec'):
            net_hrn = netspec['name']
            resources = {'start_time': start_time, 'end_time': end_time, 'networks': netspec}
            tempspec = RSpec()
            tempspec.parseDict({'RSpec': resources})
            rspecs[net_hrn] = tempspec.toxml()

        # send each rspec to the appropriate aggregate/sm
        caller_cred = self.caller_cred
        for net_hrn in rspecs:
            try:
                if net_hrn in aggregates:
                    # our own aggregate gets the whole rspec
                    if net_hrn in [self.api.hrn]:
                        net_rspec = rspec
                    else:
                        net_rspec = rspecs[net_hrn]
                    arg_list = [credential, hrn, net_rspec]
                    request_hash = self.api.key.compute_hash(arg_list)
                    aggregates[net_hrn].create_slice(credential, hrn, net_rspec, request_hash, caller_cred)
                else:
                    # forward this rspec to a sm that knows about the network
                    arg_list = [credential, net_hrn]
                    request_hash = self.api.key.compute_hash(arg_list)
                    for aggregate in aggregates:
                        network_found = aggregates[aggregate].get_aggregates(credential, net_hrn, request_hash)
                        if network_found:
                            arg_list = [credential, hrn, rspecs[net_hrn]]
                            create_hash = self.api.key.compute_hash(arg_list)
                            aggregates[aggregate].create_slice(credential, hrn, rspecs[net_hrn], create_hash, caller_cred)
            except Exception:
                print >> log, "Error creating slice %(hrn)s at aggregate %(net_hrn)s" % locals()
                traceback.print_exc(file=log)
        return 1

    def start_slice(self, hrn):
        """Start slice hrn on this interface ('aggregate' or 'slicemgr')."""
        if self.api.interface in ['aggregate']:
            self.start_slice_aggregate(hrn)
        elif self.api.interface in ['slicemgr']:
            self.start_slice_smgr(hrn)

    def _set_slice_enabled(self, hrn, value):
        """
        Set the 'enabled' slice attribute of the local PLC slice for hrn to
        value ("1" or "0").  Returns 1.

        @raise RecordNotFound: if no PLC slice exists for hrn.
        """
        slicename = hrn_to_pl_slicename(hrn)
        slices = self.api.plshell.GetSlices(self.api.plauth, {'name': slicename}, ['slice_id'])
        if not slices:
            raise RecordNotFound(hrn)
        slice_id = slices[0]['slice_id']
        attributes = self.api.plshell.GetSliceAttributes(self.api.plauth, {'slice_id': slice_id, 'name': 'enabled'}, ['slice_attribute_id'])
        # NOTE(review): raises IndexError if the slice has no 'enabled' attribute
        attribute_id = attributes[0]['slice_attribute_id']
        self.api.plshell.UpdateSliceAttribute(self.api.plauth, attribute_id, value)
        return 1

    def start_slice_aggregate(self, hrn):
        """Enable the local PLC slice for hrn.  Returns 1."""
        return self._set_slice_enabled(hrn, "1")

    def start_slice_smgr(self, hrn):
        """Ask every aggregate to start slice hrn.  Returns 1."""
        credential = self.api.getCredential()
        aggregates = Aggregates(self.api)
        arg_list = [credential, hrn]
        request_hash = self.api.key.compute_hash(arg_list)
        for aggregate in aggregates:
            aggregates[aggregate].start_slice(credential, hrn, request_hash)
        return 1

    def stop_slice(self, hrn):
        """Stop slice hrn on this interface ('aggregate' or 'slicemgr')."""
        if self.api.interface in ['aggregate']:
            self.stop_slice_aggregate(hrn)
        elif self.api.interface in ['slicemgr']:
            self.stop_slice_smgr(hrn)

    def stop_slice_aggregate(self, hrn):
        """Disable the local PLC slice for hrn.  Returns 1."""
        return self._set_slice_enabled(hrn, "0")

    def stop_slice_smgr(self, hrn):
        """Ask every aggregate to stop slice hrn.  Returns 1."""
        credential = self.api.getCredential()
        aggregates = Aggregates(self.api)
        arg_list = [credential, hrn]
        request_hash = self.api.key.compute_hash(arg_list)
        for aggregate in aggregates:
            aggregates[aggregate].stop_slice(credential, hrn, request_hash)
        return 1