use BindObjectToPeer when necessary
[sfa.git] / sfa / plc / slices.py
1 ### $Id$
2 ### $URL$
3
4 import datetime
5 import time
6
7 from sfa.util.misc import *
8 from sfa.util.rspec import *
9 from sfa.util.specdict import *
10 from sfa.util.faults import *
11 from sfa.util.storage import *
12 from sfa.util.policy import Policy
13 from sfa.util.debug import log
14 from sfa.server.aggregate import Aggregates
15 from sfa.server.registry import Registries
16
class Slices(SimpleStorage):
    """
    Cached list of the slices known to this SFA interface, plus the slice
    operations (create, delete, start, stop) the interface performs.

    The cache is a SimpleStorage backed by the file
    "<api.interface>.<api.hrn>.slices" in SFA_BASE_DIR and holds the keys
    'hrn' (list of slice hrns), 'timestamp' and 'threshold' (both strings
    formatted with api.time_format).

    Each operation dispatches on api.interface: an 'aggregate' works on the
    local PLC through api.plshell, a 'slicemgr' forwards the call to every
    aggregate returned by Aggregates(api). Any other interface value performs
    no operation.
    """

    def __init__(self, api, ttl = .5):
        """
        @param api the SFA api object (config, interface, hrn, plshell, ...)
        @param ttl number of hours a refreshed slice list stays valid
        """
        # os was previously only reachable through the wildcard imports
        import os
        self.api = api
        self.ttl = ttl
        self.threshold = None
        path = self.api.config.SFA_BASE_DIR
        filename = ".".join([self.api.interface, self.api.hrn, "slices"])
        self.slices_file = path + os.sep + filename
        SimpleStorage.__init__(self, self.slices_file)
        self.policy = Policy(self.api)
        self.load()

    def refresh(self):
        """
        Update the cached list of slices when the cache has no
        'timestamp'/'threshold' yet or the stored threshold has passed.
        """
        now = datetime.datetime.now()
        if self.has_key('threshold') and self.has_key('timestamp'):
            threshold = datetime.datetime.fromtimestamp(time.mktime(
                time.strptime(self['threshold'], self.api.time_format)))
            if now <= threshold:
                # cached list is still valid
                return
        if self.api.interface in ['aggregate']:
            self.refresh_slices_aggregate()
        elif self.api.interface in ['slicemgr']:
            self.refresh_slices_smgr()

    def _store_slice_hrns(self, slice_hrns):
        """
        Save slice_hrns in the cache together with the current time and a
        threshold self.ttl hours in the future, then write the cache file.
        """
        timestamp = datetime.datetime.now()
        threshold = timestamp + datetime.timedelta(hours=self.ttl)
        self.update({'hrn': slice_hrns,
                     'timestamp': timestamp.strftime(self.api.time_format),
                     'threshold': threshold.strftime(self.api.time_format)})
        self.write()

    def refresh_slices_aggregate(self):
        """Cache the hrns of the slices that belong to the local PLC (peer_id None)."""
        slices = self.api.plshell.GetSlices(self.api.plauth, {'peer_id': None}, ['name'])
        slice_hrns = [slicename_to_hrn(self.api.hrn, s['name']) for s in slices]
        self._store_slice_hrns(slice_hrns)

    def refresh_slices_smgr(self):
        """
        Cache the slice hrns reported by every known aggregate. An aggregate
        whose get_slices call fails is logged and skipped.
        """
        slice_hrns = []
        aggregates = Aggregates(self.api)
        credential = self.api.getCredential()
        for aggregate in aggregates:
            try:
                slices = aggregates[aggregate].get_slices(credential)
                slice_hrns.extend(slices)
            except Exception:
                print >> log, "Error calling slices at aggregate %(aggregate)s" % locals()
        self._store_slice_hrns(slice_hrns)

    def delete_slice(self, hrn):
        """Delete slice hrn through the interface-specific implementation and return its result."""
        if self.api.interface in ['aggregate']:
            return self.delete_slice_aggregate(hrn)
        elif self.api.interface in ['slicemgr']:
            return self.delete_slice_smgr(hrn)

    def delete_slice_aggregate(self, hrn):
        """
        Remove the local (non-peer) PLC slice for hrn from all of its nodes.
        Only the node assignments are removed; the PLC slice itself is kept.
        @return 1, also when no such slice exists
        """
        slicename = hrn_to_pl_slicename(hrn)
        slices = self.api.plshell.GetSlices(self.api.plauth, {'peer_id': None, 'name': slicename})
        if not slices:
            return 1
        pl_slice = slices[0]
        self.api.plshell.DeleteSliceFromNodes(self.api.plauth, slicename, pl_slice['node_ids'])
        return 1

    def delete_slice_smgr(self, hrn):
        """
        Ask every known aggregate to delete slice hrn. An exception from any
        aggregate propagates and stops the remaining calls.
        @return 1
        """
        credential = self.api.getCredential()
        aggregates = Aggregates(self.api)
        for aggregate in aggregates:
            aggregates[aggregate].delete_slice(credential, hrn)
        return 1

    def create_slice(self, hrn, rspec, peer = None):
        """
        Create or update slice hrn according to rspec, if slice policy allows it.

        @param peer peer name forwarded to create_slice_aggregate on an
            'aggregate' interface (used there for BindObjectToPeer); ignored
            on a 'slicemgr' interface
        @return 1 when the policy rejects the slice, otherwise the result of
            the interface-specific implementation
        """
        # check our slice policy before we proceed
        whitelist = self.policy['slice_whitelist']
        blacklist = self.policy['slice_blacklist']
        if (whitelist and hrn not in whitelist) or \
           (blacklist and hrn in blacklist):
            policy_file = self.policy.policy_file
            print >> log, "Slice %(hrn)s not allowed by policy %(policy_file)s" % locals()
            return 1
        if self.api.interface in ['aggregate']:
            return self.create_slice_aggregate(hrn, rspec, peer)
        elif self.api.interface in ['slicemgr']:
            return self.create_slice_smgr(hrn, rspec)

    def create_slice_aggregate(self, hrn, rspec, peer = None):
        """
        Make the local PLC reflect slice hrn as described by rspec.

        Creates the PLC site and slice when the slice does not exist yet,
        adds the slice's researchers (and their keys) to the slice, and
        adds/removes slice nodes so they match the NodeSpec names in rspec.

        @param hrn human readable name of the slice
        @param rspec rspec string listing the requested nodes
        @param peer peer name passed to BindObjectToPeer for the site, slice
            and persons created here; None when no binding is wanted
        @return 1
        @raise RecordNotFound if the registry has no slice record for hrn, or
            no record for its authority when the site has to be created
        """
        spec = Rspec(rspec)
        registries = Registries(self.api)
        registry = registries[self.api.hrn]
        credential = self.api.getCredential()

        slice_record = self._get_slice_record(registry, credential, hrn)
        slicename = hrn_to_pl_slicename(hrn)
        pl_slice = self._verify_slice(registry, credential, hrn, slicename, slice_record, peer)
        self._verify_researchers(registry, credential, slicename, slice_record, peer)
        self._verify_slice_nodes(spec, slicename, pl_slice['node_ids'])
        return 1

    def _get_slice_record(self, registry, credential, hrn):
        """Return the registry's slice record for hrn as a dict; raise RecordNotFound if there is none."""
        slice_record = None
        for record in registry.resolve(credential, hrn):
            if record.get_type() in ['slice']:
                slice_record = record.as_dict()
        if not slice_record:
            raise RecordNotFound(hrn)
        return slice_record

    def _verify_site(self, registry, credential, hrn, login_base, peer):
        """Add the PLC site login_base from the registry's authority record if it does not exist."""
        sites = self.api.plshell.GetSites(self.api.plauth, [login_base])
        if sites:
            return
        authority = get_authority(hrn)
        site_records = registry.resolve(credential, authority)
        if not site_records:
            raise RecordNotFound(authority)
        site = site_records[0].as_dict()
        # the registry's site_id is not sent to AddSite; it is kept as the
        # remote id for BindObjectToPeer
        remote_site_id = site.pop('site_id')
        site_id = self.api.plshell.AddSite(self.api.plauth, site)
        if peer:
            self.api.plshell.BindObjectToPeer(self.api.plauth, 'site', site_id, peer, remote_site_id)

    def _verify_slice(self, registry, credential, hrn, slicename, slice_record, peer):
        """Return the PLC slice slicename (with 'node_ids'), creating it and its site if needed."""
        slices = self.api.plshell.GetSlices(self.api.plauth, [slicename], ['node_ids'])
        if slices:
            return slices[0]

        # the owning site's login_base is the slicename prefix before the first '_'
        login_base = slicename.split("_")[0]
        self._verify_site(registry, credential, hrn, login_base, peer)

        # build the new slice from the registry record; the name is always
        # the PLC slicename used by every other call in this method
        slice_fields = {'name': slicename}
        for key in ['url', 'description']:
            if key in slice_record and slice_record[key]:
                slice_fields[key] = slice_record[key]
        slice_id = self.api.plshell.AddSlice(self.api.plauth, slice_fields)
        if peer:
            self.api.plshell.BindObjectToPeer(self.api.plauth, 'slice', slice_id, peer, slice_record['pointer'])
        # a freshly created slice has no nodes yet
        slice_fields['node_ids'] = []
        return slice_fields

    def _verify_researchers(self, registry, credential, slicename, slice_record, peer):
        """Make sure each researcher of slice_record is a PLC person on slicename and has their keys."""
        for researcher in slice_record.get('researcher') or []:
            person_record = None
            for record in registry.resolve(credential, researcher):
                if record.get_type() in ['user']:
                    person_record = record
            if not person_record:
                # no user record for this researcher; nothing to add
                continue
            person_dict = person_record.as_dict()
            email = person_dict['email']
            persons = self.api.plshell.GetPersons(self.api.plauth, [email], ['person_id', 'key_ids'])

            if not persons:
                person_id = self.api.plshell.AddPerson(self.api.plauth, person_dict)
                # Enable the account right after creation: GetSlivers only
                # returns keys of enabled users, so without this the user's
                # key never reaches the slice and login fails. Additional
                # checks could be done before enabling the user.
                self.api.plshell.UpdatePerson(self.api.plauth, person_id, {'enabled': True})
                if peer:
                    self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', person_id, peer, person_dict['pointer'])
                key_ids = []
            else:
                key_ids = persons[0]['key_ids']

            self.api.plshell.AddPersonToSlice(self.api.plauth, email, slicename)

            # keys this person already has at the local PLC
            if key_ids:
                keylist = self.api.plshell.GetKeys(self.api.plauth, key_ids, ['key'])
            else:
                keylist = []
            keys = [key['key'] for key in keylist]

            # add keys that aren't already there
            for personkey in person_dict.get('keys') or []:
                if personkey in keys:
                    continue
                if peer:
                    # TODO: peer-owned keys must be bound with BindObjectToPeer,
                    # which needs the key_id from the remote registry; they
                    # are not added for now.
                    continue
                key = {'key_type': 'ssh', 'key': personkey}
                self.api.plshell.AddPersonKey(self.api.plauth, email, key)

    def _verify_slice_nodes(self, spec, slicename, node_ids):
        """Add/remove nodes of slicename so they match the NodeSpec names in spec."""
        # find out where this slice is currently running
        if node_ids:
            nodelist = self.api.plshell.GetNodes(self.api.plauth, node_ids, ['hostname'])
        else:
            nodelist = []
        hostnames = [node['hostname'] for node in nodelist]

        # hostnames requested by the rspec
        requested = []
        for nodespec in spec.getDictsByTagName('NodeSpec'):
            name = nodespec['name']
            if isinstance(name, list):
                requested.extend(name)
            elif isinstance(name, StringTypes):
                requested.append(name)

        added_nodes = list(set(requested).difference(hostnames))
        deleted_nodes = list(set(hostnames).difference(requested))
        self.api.plshell.AddSliceToNodes(self.api.plauth, slicename, added_nodes)
        self.api.plshell.DeleteSliceFromNodes(self.api.plauth, slicename, deleted_nodes)

    def create_slice_smgr(self, hrn, rspec):
        """
        Forward create_slice for hrn to every known aggregate that has a
        NetSpec in rspec.

        The aggregate whose hrn equals ours receives the whole rspec; every
        other one receives an rspec holding only its own NetSpec plus the
        start/end times. When an aggregate name appears among our PLC peers,
        our hrn is passed along as the peer name so the aggregate can call
        BindObjectToPeer. Failures are logged per aggregate and skipped.
        @return 1
        """
        spec = Rspec()
        spec.parseString(rspec)
        specDict = spec.toDict()
        if specDict.has_key('Rspec'):
            specDict = specDict['Rspec']
        start_time = specDict.get('start_time', 0)
        end_time = specDict.get('end_time', 0)

        rspecs = {}
        aggregates = Aggregates(self.api)
        credential = self.api.getCredential()
        # only attempt to extract information about the aggregates we know about
        for aggregate in aggregates:
            netspec = spec.getDictByTagNameValue('NetSpec', aggregate)
            if netspec:
                resources = {'start_time': start_time, 'end_time': end_time, 'networks': netspec}
                # a fresh Rspec per aggregate so nothing carries over between them
                tempspec = Rspec()
                tempspec.parseDict({'Rspec': resources})
                rspecs[aggregate] = tempspec.toxml()

        if not rspecs:
            return 1

        # every name (peername, shortname, hrn_root) our PLC peers are known by;
        # queried once instead of once per aggregate
        peers = self.api.plshell.GetPeers(self.api.plauth, {}, ['peername', 'shortname', 'hrn_root'])
        peer_names = []
        for peer in peers:
            peer_names.extend(peer.values())

        # notify the aggregates
        for aggregate in rspecs.keys():
            # if we are already federated with this aggregate using plc federation,
            # pass our peer name so it can call BindObjectToPeer
            if aggregate in peer_names:
                local_peer_name = self.api.hrn
            else:
                local_peer_name = None
            try:
                if aggregate in [self.api.hrn]:
                    # send the whole rspec to the local aggregate
                    aggregates[aggregate].create_slice(credential, hrn, rspec, local_peer_name)
                else:
                    aggregates[aggregate].create_slice(credential, hrn, rspecs[aggregate], local_peer_name)
            except Exception:
                print >> log, "Error creating slice %(hrn)s at aggregate %(aggregate)s" % locals()
        return 1

    def start_slice(self, hrn):
        """Start slice hrn through the interface-specific implementation and return its result."""
        if self.api.interface in ['aggregate']:
            return self.start_slice_aggregate(hrn)
        elif self.api.interface in ['slicemgr']:
            return self.start_slice_smgr(hrn)

    def _set_slice_enabled(self, hrn, value):
        """
        Set the 'enabled' slice attribute of the PLC slice for hrn to value,
        creating the attribute when the slice has none.
        @return 1
        @raise RecordNotFound if there is no PLC slice for hrn
        """
        slicename = hrn_to_pl_slicename(hrn)
        slices = self.api.plshell.GetSlices(self.api.plauth, {'name': slicename}, ['slice_id'])
        if not slices:
            raise RecordNotFound(hrn)
        slice_id = slices[0]['slice_id']
        attributes = self.api.plshell.GetSliceAttributes(self.api.plauth, {'slice_id': slice_id, 'name': 'enabled'}, ['slice_attribute_id'])
        if attributes:
            attribute_id = attributes[0]['slice_attribute_id']
            self.api.plshell.UpdateSliceAttribute(self.api.plauth, attribute_id, value)
        else:
            self.api.plshell.AddSliceAttribute(self.api.plauth, slice_id, 'enabled', value)
        return 1

    def start_slice_aggregate(self, hrn):
        """Enable the local PLC slice for hrn ('enabled' = "1"). @return 1"""
        return self._set_slice_enabled(hrn, "1")

    def start_slice_smgr(self, hrn):
        """Ask every known aggregate to start slice hrn. @return 1"""
        credential = self.api.getCredential()
        aggregates = Aggregates(self.api)
        for aggregate in aggregates:
            aggregates[aggregate].start_slice(credential, hrn)
        return 1

    def stop_slice(self, hrn):
        """Stop slice hrn through the interface-specific implementation and return its result."""
        if self.api.interface in ['aggregate']:
            return self.stop_slice_aggregate(hrn)
        elif self.api.interface in ['slicemgr']:
            return self.stop_slice_smgr(hrn)

    def stop_slice_aggregate(self, hrn):
        """Disable the local PLC slice for hrn ('enabled' = "0"). @return 1"""
        return self._set_slice_enabled(hrn, "0")

    def stop_slice_smgr(self, hrn):
        """Ask every known aggregate to stop slice hrn. @return 1"""
        credential = self.api.getCredential()
        aggregates = Aggregates(self.api)
        for aggregate in aggregates:
            aggregates[aggregate].stop_slice(credential, hrn)
        return 1