add peer records to registry db
[sfa.git] / sfa / plc / slices.py
1 ### $Id$
2 ### $URL$
3
4 import datetime
5 import time
6 import traceback
7 import sys
8
9 from types import StringTypes
10 from sfa.util.misc import *
11 from sfa.util.rspec import *
12 from sfa.util.specdict import *
13 from sfa.util.faults import *
14 from sfa.util.storage import *
15 from sfa.util.record import GeniRecord
16 from sfa.util.policy import Policy
17 from sfa.util.prefixTree import prefixTree
18 from sfa.util.debug import log
19 from sfa.server.aggregate import Aggregates
20 from sfa.server.registry import Registries
21
class Slices(SimpleStorage):
    """
    Cached list of slice hrns known to this interface, plus the slice
    operations (create / delete / start / stop).

    The cache is persisted through SimpleStorage in a file named
    "<interface>.<hrn>.slices" under SFA_BASE_DIR. Most operations dispatch
    on self.api.interface: 'aggregate' works against the local PLC through
    self.api.plshell, 'slicemgr' fans out to every known aggregate.
    """

    def __init__(self, api, ttl = .5, caller_cred=None):
        """
        @param api SFA api object (config, interface, hrn, plshell, ...)
        @param ttl cache lifetime in hours (used as timedelta(hours=ttl))
        @param caller_cred credential of the original caller; forwarded to
                           aggregates on create/delete
        """
        # os is not imported at module level in this file
        import os
        self.api = api
        self.ttl = ttl
        self.threshold = None
        filename = ".".join([self.api.interface, self.api.hrn, "slices"])
        self.slices_file = os.path.join(self.api.config.SFA_BASE_DIR, filename)
        SimpleStorage.__init__(self, self.slices_file)
        self.policy = Policy(self.api)
        self.load()
        self.caller_cred = caller_cred

    def get_peer(self, hrn):
        """
        Return the PLC peer 'shortname' that owns the site of slice `hrn`, or
        None if the slice is assumed to belong to our local PLC.

        Because of myplc federation the slice may live on a myplc peer. A peer
        matches when any of its string fields (peername, shortname,
        hrn_root), lowercased, equals the slice's site authority.
        """
        peer = None
        # slice authority (site), then the site's authority (root or sub authority)
        slice_authority = get_authority(hrn)
        site_authority = get_authority(slice_authority).lower()

        peers = self.api.plshell.GetPeers(self.api.plauth, {}, ['peer_id', 'peername', 'shortname', 'hrn_root'])
        for peer_record in peers:
            names = [name.lower() for name in peer_record.values() if isinstance(name, StringTypes)]
            if site_authority in names:
                peer = peer_record['shortname']

        return peer

    def get_sfa_peer(self, hrn):
        """
        Return the site authority of `hrn` when it differs from our own hrn
        (i.e. another SFA authority owns it), otherwise None.
        """
        site_authority = get_authority(get_authority(hrn))
        if site_authority != self.api.hrn:
            return site_authority
        return None

    def refresh(self):
        """
        Update the cached list of slices if the cache was never filled or its
        threshold time has passed.
        """
        now = datetime.datetime.now()
        if not self.has_key('threshold') or not self.has_key('timestamp') or \
           now > datetime.datetime.fromtimestamp(time.mktime(time.strptime(self['threshold'], self.api.time_format))):
            if self.api.interface in ['aggregate']:
                self.refresh_slices_aggregate()
            elif self.api.interface in ['slicemgr']:
                self.refresh_slices_smgr()

    def _save_slice_hrns(self, slice_hrns):
        """
        Store slice_hrns in the cache with a fresh timestamp and a threshold
        `ttl` hours from now, then write the cache to disk.
        """
        timestamp = datetime.datetime.now()
        threshold = timestamp + datetime.timedelta(hours=self.ttl)
        self.update({'hrn': slice_hrns,
                     'timestamp': timestamp.strftime(self.api.time_format),
                     'threshold': threshold.strftime(self.api.time_format)})
        self.write()

    def refresh_slices_aggregate(self):
        """Cache the hrns of all local (non-peer) slices on our PLC."""
        slices = self.api.plshell.GetSlices(self.api.plauth, {'peer_id': None}, ['name'])
        slice_hrns = [slicename_to_hrn(self.api.hrn, slice['name']) for slice in slices]
        self._save_slice_hrns(slice_hrns)

    def refresh_slices_smgr(self):
        """
        Cache the union of slices reported by every known aggregate.
        Aggregates that fail are logged and skipped.
        """
        slice_hrns = []
        aggregates = Aggregates(self.api)
        credential = self.api.getCredential()
        for aggregate in aggregates:
            try:
                slices = aggregates[aggregate].get_slices(credential)
                slice_hrns.extend(slices)
            except:
                print >> log, "Error calling slices at aggregate %(aggregate)s" % locals()
                traceback.print_exc(file=log)
        self._save_slice_hrns(slice_hrns)

    def delete_slice(self, hrn):
        """Delete slice `hrn` using the interface-specific implementation."""
        if self.api.interface in ['aggregate']:
            self.delete_slice_aggregate(hrn)
        elif self.api.interface in ['slicemgr']:
            self.delete_slice_smgr(hrn)

    def delete_slice_aggregate(self, hrn):
        """
        Remove slice `hrn` from all of its nodes on the local PLC.
        Returns 1, also when the slice does not exist.
        """
        slicename = hrn_to_pl_slicename(hrn)
        slices = self.api.plshell.GetSlices(self.api.plauth, {'name': slicename})
        if not slices:
            return 1
        slice = slices[0]

        # a peer slice must be unbound from its peer while we modify it,
        # then re-bound afterwards
        peer = self.get_peer(hrn)
        if peer:
            self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'slice', slice['slice_id'], peer)
        self.api.plshell.DeleteSliceFromNodes(self.api.plauth, slicename, slice['node_ids'])
        if peer:
            self.api.plshell.BindObjectToPeer(self.api.plauth, 'slice', slice['slice_id'], peer, slice['peer_slice_id'])
        return 1

    def delete_slice_smgr(self, hrn):
        """
        Ask every known aggregate to delete slice `hrn`.
        Failures are logged per aggregate and do not stop the others.
        """
        credential = self.api.getCredential()
        aggregates = Aggregates(self.api)
        for aggregate in aggregates:
            try:
                aggregates[aggregate].delete_slice(credential, hrn, caller_cred=self.caller_cred)
            except:
                print >> log, "Error calling delete slice at aggregate %s" % aggregate
                # print_exc's first positional argument is `limit`, so the
                # log file must be passed by keyword
                traceback.print_exc(file=log)

    def create_slice(self, hrn, rspec):
        """
        Create or update slice `hrn` from `rspec`, after checking the slice
        whitelist/blacklist policy. Returns 1 when the policy rejects the
        slice (the rejection is logged).
        """
        whitelist = self.policy['slice_whitelist']
        blacklist = self.policy['slice_blacklist']

        if whitelist and hrn not in whitelist or \
           blacklist and hrn in blacklist:
            policy_file = self.policy.policy_file
            print >> log, "Slice %(hrn)s not allowed by policy %(policy_file)s" % locals()
            return 1

        if self.api.interface in ['aggregate']:
            self.create_slice_aggregate(hrn, rspec)
        elif self.api.interface in ['slicemgr']:
            self.create_slice_smgr(hrn, rspec)

    def verify_site(self, registry, credential, slice_hrn, peer, sfa_peer):
        """
        Make sure the site (authority) of slice_hrn exists on the local PLC,
        creating it from the registry's authority record if needed.

        @return (site_id, remote_site_id)
        @raise RecordNotFound if the registry has no authority record
        """
        authority = get_authority(slice_hrn)
        site_records = registry.resolve(credential, authority)
        site = {}
        for site_record in site_records:
            if site_record['type'] == 'authority':
                site = site_record.as_dict()
        if not site:
            raise RecordNotFound(authority)
        remote_site_id = site.pop('site_id')

        login_base = get_leaf(authority)
        sites = self.api.plshell.GetSites(self.api.plauth, login_base)
        if not sites:
            site_id = self.api.plshell.AddSite(self.api.plauth, site)
            if peer:
                self.api.plshell.BindObjectToPeer(self.api.plauth, 'site', site_id, peer, remote_site_id)
            # mark this site as an sfa peer record
            if sfa_peer:
                peer_dict = {'type': 'authority', 'hrn': authority, 'peer_authority': sfa_peer, 'pointer': site_id}
                registry.register_peer_object(credential, peer_dict)
        else:
            site_id = sites[0]['site_id']
            remote_site_id = sites[0]['peer_site_id']

        return (site_id, remote_site_id)

    def verify_slice(self, registry, credential, slice_hrn, site_id, remote_site_id, peer, sfa_peer):
        """
        Make sure slice_hrn exists on the local PLC, creating it from the
        registry's slice record if needed, then verify its researchers.

        @return dict with at least 'slice_id', 'node_ids' and 'peer_slice_id'
        @raise RecordNotFound if the registry has no slice record
        """
        slice_record = None
        slice_records = registry.resolve(credential, slice_hrn)
        for record in slice_records:
            if record['type'] in ['slice']:
                slice_record = record
        if not slice_record:
            raise RecordNotFound(slice_hrn)
        slicename = hrn_to_pl_slicename(slice_hrn)
        slices = self.api.plshell.GetSlices(self.api.plauth, [slicename], ['slice_id', 'node_ids', 'site_id'])
        if not slices:
            # copy only the non-empty fields AddSlice accepts
            slice_fields = {}
            for key in ['name', 'url', 'description']:
                if key in slice_record and slice_record[key]:
                    slice_fields[key] = slice_record[key]

            slice_id = self.api.plshell.AddSlice(self.api.plauth, slice_fields)
            slice = slice_fields
            slice['slice_id'] = slice_id

            # mark this slice as an sfa peer record
            if sfa_peer:
                peer_dict = {'type': 'slice', 'hrn': slice_hrn, 'peer_authority': sfa_peer, 'pointer': slice_id}
                registry.register_peer_object(credential, peer_dict)

            # this slice belongs to a peer
            if peer:
                self.api.plshell.BindObjectToPeer(self.api.plauth, 'slice', slice_id, peer, slice_record['pointer'])
            slice['node_ids'] = []
        else:
            slice = slices[0]
            site_id = slice['site_id']

        slice['peer_slice_id'] = slice_record['pointer']
        self.verify_persons(registry, credential, slice_record, site_id, remote_site_id, peer, sfa_peer)

        return slice

    def verify_persons(self, registry, credential, slice_record, site_id, remote_site_id, peer, sfa_peer):
        """
        Make sure every researcher listed in slice_record exists on the local
        PLC, belongs to the slice and to site_id, and has their keys.
        Researchers with no 'user' record in the registry are logged and
        skipped.
        """
        slicename = hrn_to_pl_slicename(slice_record['hrn'])
        researchers = slice_record.get('researcher', [])

        # the peer id does not change per researcher; look it up once
        peer_id = None
        if peer and researchers:
            peer_id = self.api.plshell.GetPeers(self.api.plauth, {'shortname': peer}, ['peer_id'])[0]['peer_id']

        for researcher in researchers:
            person_record = {}
            person_records = registry.resolve(credential, researcher)
            for record in person_records:
                if record['type'] in ['user']:
                    person_record = record
            if not person_record:
                print >> log, "No user record found for researcher %s, skipping" % researcher
                continue
            person_dict = person_record.as_dict()
            if peer:
                persons = self.api.plshell.GetPersons(self.api.plauth, {'email': [person_dict['email']], 'peer_id': peer_id}, ['person_id', 'key_ids'])
            else:
                persons = self.api.plshell.GetPersons(self.api.plauth, [person_dict['email']], ['person_id', 'key_ids'])

            if not persons:
                person_id = self.api.plshell.AddPerson(self.api.plauth, person_dict)
                self.api.plshell.UpdatePerson(self.api.plauth, person_id, {'enabled' : True})

                # mark this person as an sfa peer record
                if sfa_peer:
                    peer_dict = {'type': 'user', 'hrn': researcher, 'peer_authority': sfa_peer, 'pointer': person_id}
                    registry.register_peer_object(credential, peer_dict)

                if peer:
                    self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', person_id, peer, person_dict['pointer'])
                key_ids = []
            else:
                person_id = persons[0]['person_id']
                key_ids = persons[0]['key_ids']

            # if this is a peer person, we must unbind them from the peer or
            # PLCAPI will throw an error
            if peer:
                self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'person', person_id, peer)
                self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'site', site_id,  peer)

            self.api.plshell.AddPersonToSlice(self.api.plauth, person_dict['email'], slicename)
            self.api.plshell.AddPersonToSite(self.api.plauth, person_dict['email'], site_id)
            if peer:
                self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', person_id, peer, person_dict['pointer'])
                self.api.plshell.BindObjectToPeer(self.api.plauth, 'site', site_id, peer, remote_site_id)

            self.verify_keys(registry, credential, person_dict, key_ids, person_id, peer)

    def verify_keys(self, registry, credential, person_dict, key_ids, person_id,  peer):
        """
        Add to PLC every key in person_dict['keys'] that the person does not
        already have.

        @param key_ids ids of the keys the PLC person currently has
        For peer persons, each new key is bound (best effort) to the next id
        from person_dict['key_ids'].
        """
        keylist = self.api.plshell.GetKeys(self.api.plauth, key_ids, ['key'])
        existing_keys = [key['key'] for key in keylist]

        # copy so that consuming ids below does not mutate person_dict
        remote_key_ids = list(person_dict.get('key_ids') or [])
        for personkey in person_dict['keys']:
            if personkey in existing_keys:
                continue
            key = {'key_type': 'ssh', 'key': personkey}
            if peer:
                self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'person', person_id, peer)
            key_id = self.api.plshell.AddPersonKey(self.api.plauth, person_dict['email'], key)
            if peer:
                self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', person_id, peer, person_dict['pointer'])
                if remote_key_ids:
                    remote_key_id = remote_key_ids.pop(0)
                    try:
                        self.api.plshell.BindObjectToPeer(self.api.plauth, 'key', key_id, peer, remote_key_id)
                    except Exception:
                        # best effort: binding the key is not fatal
                        print >> log, "Error binding key %s to peer %s" % (key_id, peer)
                        traceback.print_exc(file=log)

    def create_slice_aggregate(self, hrn, rspec):
        """
        Create/update slice `hrn` on the local PLC so that it runs on exactly
        the nodes named in the rspec's NodeSpec elements. Returns 1.
        """
        # determine if this is a peer slice
        peer = self.get_peer(hrn)
        sfa_peer = self.get_sfa_peer(hrn)

        spec = Rspec(rspec)
        slicename = hrn_to_pl_slicename(hrn)
        registries = Registries(self.api)
        registry = registries[self.api.hrn]
        credential = self.api.getCredential()

        site_id, remote_site_id = self.verify_site(registry, credential, hrn, peer, sfa_peer)
        slice = self.verify_slice(registry, credential, hrn, site_id, remote_site_id, peer, sfa_peer)

        # find out where this slice is currently running
        nodelist = self.api.plshell.GetNodes(self.api.plauth, slice['node_ids'], ['hostname'])
        hostnames = [node['hostname'] for node in nodelist]

        # collect requested hostnames; 'name' may be a single string or a list
        nodes = []
        for nodespec in spec.getDictsByTagName('NodeSpec'):
            if isinstance(nodespec['name'], list):
                nodes.extend(nodespec['name'])
            elif isinstance(nodespec['name'], StringTypes):
                nodes.append(nodespec['name'])

        deleted_nodes = list(set(hostnames).difference(nodes))
        added_nodes = list(set(nodes).difference(hostnames))

        if peer:
            self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'slice', slice['slice_id'], peer)
        self.api.plshell.AddSliceToNodes(self.api.plauth, slicename, added_nodes)
        self.api.plshell.DeleteSliceFromNodes(self.api.plauth, slicename, deleted_nodes)
        if peer:
            self.api.plshell.BindObjectToPeer(self.api.plauth, 'slice', slice['slice_id'], peer, slice['peer_slice_id'])

        return 1

    def create_slice_smgr(self, hrn, rspec):
        """
        Split `rspec` into one rspec per NetSpec and send each to the
        aggregate for that network, or forward it to any aggregate/slice
        manager that reports knowing the network. Errors are logged per
        network. Returns 1.
        """
        spec = Rspec()
        spec.parseString(rspec)
        specDict = spec.toDict()
        if 'Rspec' in specDict:
            specDict = specDict['Rspec']
        start_time = specDict.get('start_time', 0)
        end_time = specDict.get('end_time', 0)

        rspecs = {}
        aggregates = Aggregates(self.api)
        credential = self.api.getCredential()

        # split the netspecs into individual rspecs; use a fresh Rspec per
        # network so no state carries over between them
        for netspec in spec.getDictsByTagName('NetSpec'):
            net_hrn = netspec['name']
            resources = {'start_time': start_time, 'end_time': end_time, 'networks': netspec}
            tempspec = Rspec()
            tempspec.parseDict({'Rspec': resources})
            rspecs[net_hrn] = tempspec.toxml()

        # send each rspec to the appropriate aggregate/sm
        for net_hrn in rspecs:
            try:
                if net_hrn in aggregates:
                    if net_hrn in [self.api.hrn]:
                        # the local aggregate gets the whole rspec
                        aggregates[net_hrn].create_slice(credential, hrn, rspec, caller_cred=self.caller_cred)
                    else:
                        aggregates[net_hrn].create_slice(credential, hrn, rspecs[net_hrn], caller_cred=self.caller_cred)
                else:
                    # not directly connected: forward to an sm that knows about the network
                    for aggregate in aggregates:
                        network_found = aggregates[aggregate].get_aggregates(credential, net_hrn)
                        if network_found:
                            aggregates[aggregate].create_slice(credential, hrn, rspecs[net_hrn], caller_cred=self.caller_cred)
            except:
                print >> log, "Error creating slice %(hrn)s at aggregate %(net_hrn)s" % locals()
                traceback.print_exc(file=log)
        return 1

    def start_slice(self, hrn):
        """Start slice `hrn` using the interface-specific implementation."""
        if self.api.interface in ['aggregate']:
            self.start_slice_aggregate(hrn)
        elif self.api.interface in ['slicemgr']:
            self.start_slice_smgr(hrn)

    def start_slice_aggregate(self, hrn):
        """
        Set the slice's 'enabled' attribute to "1" on the local PLC.
        @raise RecordNotFound if the slice does not exist
        """
        slicename = hrn_to_pl_slicename(hrn)
        slices = self.api.plshell.GetSlices(self.api.plauth, {'name': slicename}, ['slice_id'])
        if not slices:
            raise RecordNotFound(hrn)
        slice_id = slices[0]['slice_id']
        attributes = self.api.plshell.GetSliceAttributes(self.api.plauth, {'slice_id': slice_id, 'name': 'enabled'}, ['slice_attribute_id'])
        attribute_id = attributes[0]['slice_attribute_id']
        self.api.plshell.UpdateSliceAttribute(self.api.plauth, attribute_id, "1" )
        return 1

    def start_slice_smgr(self, hrn):
        """Ask every known aggregate to start slice `hrn`. Returns 1."""
        credential = self.api.getCredential()
        aggregates = Aggregates(self.api)
        for aggregate in aggregates:
            aggregates[aggregate].start_slice(credential, hrn)
        return 1

    def stop_slice(self, hrn):
        """Stop slice `hrn` using the interface-specific implementation."""
        if self.api.interface in ['aggregate']:
            self.stop_slice_aggregate(hrn)
        elif self.api.interface in ['slicemgr']:
            self.stop_slice_smgr(hrn)

    def stop_slice_aggregate(self, hrn):
        """
        Set the slice's 'enabled' attribute to "0" on the local PLC.
        @raise RecordNotFound if the slice does not exist
        """
        slicename = hrn_to_pl_slicename(hrn)
        slices = self.api.plshell.GetSlices(self.api.plauth, {'name': slicename}, ['slice_id'])
        if not slices:
            raise RecordNotFound(hrn)
        slice_id = slices[0]['slice_id']
        attributes = self.api.plshell.GetSliceAttributes(self.api.plauth, {'slice_id': slice_id, 'name': 'enabled'}, ['slice_attribute_id'])
        attribute_id = attributes[0]['slice_attribute_id']
        self.api.plshell.UpdateSliceAttribute(self.api.plauth, attribute_id, "0")
        return 1

    def stop_slice_smgr(self, hrn):
        """Ask every known aggregate to stop slice `hrn`. Returns 1."""
        credential = self.api.getCredential()
        aggregates = Aggregates(self.api)
        for aggregate in aggregates:
            aggregates[aggregate].stop_slice(credential, hrn)
        return 1
466