check for myplc peers in create_slice_aggregate, not create_slice_smgr
[sfa.git] / sfa / plc / slices.py
1 ### $Id$
2 ### $URL$
3
4 import datetime
5 import time
6
7 from sfa.util.misc import *
8 from sfa.util.rspec import *
9 from sfa.util.specdict import *
10 from sfa.util.faults import *
11 from sfa.util.storage import *
12 from sfa.util.policy import Policy
13 from sfa.util.debug import log
14 from sfa.server.aggregate import Aggregates
15 from sfa.server.registry import Registries
16
class Slices(SimpleStorage):
    """
    Cached view of the slices known to this SFA interface, plus the
    operations (create/delete/start/stop) that act on them.

    The cache is persisted through SimpleStorage to the file
    "<interface>.<hrn>.slices" under SFA_BASE_DIR and holds three keys:
    'hrn' (list of slice hrns), and 'timestamp' / 'threshold' (strings
    formatted with api.time_format).

    Every public operation dispatches on api.interface: an 'aggregate'
    works directly against the local PLC through api.plshell/api.plauth,
    while a 'slicemgr' forwards the request to every known aggregate.
    """

    def __init__(self, api, ttl = .5):
        """
        @param api: the SFA API object (config, interface, hrn, plshell,
            plauth, getCredential, time_format)
        @param ttl: lifetime of the cached slice list, in hours
        """
        self.api = api
        self.ttl = ttl
        self.threshold = None
        filename = ".".join([self.api.interface, self.api.hrn, "slices"])
        self.slices_file = os.path.join(self.api.config.SFA_BASE_DIR, filename)
        SimpleStorage.__init__(self, self.slices_file)
        self.policy = Policy(self.api)
        self.load()

    def refresh(self):
        """
        Update the cached list of slices if it was never stored or if its
        threshold (expiry time) has passed; otherwise do nothing.
        """
        if 'threshold' in self and 'timestamp' in self:
            threshold = datetime.datetime.fromtimestamp(
                time.mktime(time.strptime(self['threshold'], self.api.time_format)))
            if datetime.datetime.now() <= threshold:
                # cache is still fresh
                return

        if self.api.interface in ['aggregate']:
            self.refresh_slices_aggregate()
        elif self.api.interface in ['slicemgr']:
            self.refresh_slices_smgr()

    def _store_slice_hrns(self, slice_hrns):
        """
        Save slice_hrns in the cache together with the current time and a
        threshold self.ttl hours in the future, then write it to disk.
        """
        timestamp = datetime.datetime.now()
        threshold = timestamp + datetime.timedelta(hours=self.ttl)
        self.update({'hrn': slice_hrns,
                     'timestamp': timestamp.strftime(self.api.time_format),
                     'threshold': threshold.strftime(self.api.time_format)})
        self.write()

    def refresh_slices_aggregate(self):
        """Cache the hrns of all local (non-peer) slices at this PLC."""
        slices = self.api.plshell.GetSlices(self.api.plauth, {'peer_id': None}, ['name'])
        slice_hrns = [slicename_to_hrn(self.api.hrn, pl_slice['name']) for pl_slice in slices]
        self._store_slice_hrns(slice_hrns)

    def refresh_slices_smgr(self):
        """
        Cache the slice hrns reported by every known aggregate.  An
        aggregate that fails is logged and skipped (best effort).
        """
        slice_hrns = []
        aggregates = Aggregates(self.api)
        credential = self.api.getCredential()
        for aggregate in aggregates:
            try:
                slice_hrns.extend(aggregates[aggregate].get_slices(credential))
            except Exception:
                print >> log, "Error calling slices at aggregate %s" % aggregate
        self._store_slice_hrns(slice_hrns)

    def delete_slice(self, hrn):
        """Delete slice hrn using the handler for this interface."""
        if self.api.interface in ['aggregate']:
            return self.delete_slice_aggregate(hrn)
        elif self.api.interface in ['slicemgr']:
            return self.delete_slice_smgr(hrn)

    def delete_slice_aggregate(self, hrn):
        """
        Remove the local PLC slice for hrn from all of its nodes.  The PLC
        slice record itself is kept.  Returns 1, also when no such slice
        exists.
        """
        slicename = hrn_to_pl_slicename(hrn)
        slices = self.api.plshell.GetSlices(self.api.plauth, {'peer_id': None, 'name': slicename})
        if not slices:
            return 1
        node_ids = slices[0]['node_ids']
        if node_ids:
            self.api.plshell.DeleteSliceFromNodes(self.api.plauth, slicename, node_ids)
        return 1

    def delete_slice_smgr(self, hrn):
        """Ask every known aggregate to delete slice hrn."""
        credential = self.api.getCredential()
        aggregates = Aggregates(self.api)
        for aggregate in aggregates:
            aggregates[aggregate].delete_slice(credential, hrn)
        return 1

    def create_slice(self, hrn, rspec):
        """
        Create or update slice hrn from rspec, subject to the slice
        white/blacklist policy.  A slice rejected by policy is logged and
        1 is returned without touching anything.
        """
        whitelist = self.policy['slice_whitelist']
        blacklist = self.policy['slice_blacklist']
        if (whitelist and hrn not in whitelist) or (blacklist and hrn in blacklist):
            policy_file = self.policy.policy_file
            print >> log, "Slice %s not allowed by policy %s" % (hrn, policy_file)
            return 1

        if self.api.interface in ['aggregate']:
            return self.create_slice_aggregate(hrn, rspec)
        elif self.api.interface in ['slicemgr']:
            return self.create_slice_smgr(hrn, rspec)

    def _get_peer(self, hrn):
        """
        Return the shortname of the myplc peer whose record mentions the
        site authority of hrn, or None if the slice is local.
        """
        # slice authority is the site; its authority is the sfa root/sub authority
        site_authority = get_authority(get_authority(hrn))
        peers = self.api.plshell.GetPeers(self.api.plauth, {}, ['peer_id', 'peername', 'shortname', 'hrn_root'])
        peer = None
        for peer_record in peers:
            if site_authority in peer_record.values():
                peer = peer_record['shortname']
        return peer

    def _get_slice_record(self, registry, credential, hrn):
        """Return the 'slice' record for hrn at registry; RecordNotFound if absent."""
        slice_record = None
        for record in registry.resolve(credential, hrn):
            if record.get_type() in ['slice']:
                slice_record = record
        if slice_record is None:
            raise RecordNotFound(hrn)
        return slice_record

    def _verify_slice(self, registry, credential, hrn, slicename, slice_dict, peer):
        """
        Return the PLC slice (a dict with at least 'node_ids') for
        slicename, first creating its site and the slice itself at PLC when
        they do not exist (bound to peer if one was found).
        """
        slices = self.api.plshell.GetSlices(self.api.plauth, [slicename], ['node_ids'])
        if slices:
            return slices[0]

        # the site is identified by the login_base prefix of the slice name
        login_base = slicename.split("_")[0]
        if not self.api.plshell.GetSites(self.api.plauth, [login_base]):
            authority = get_authority(hrn)
            site_records = registry.resolve(credential, authority)
            if not site_records:
                raise RecordNotFound(authority)
            site = site_records[0].as_dict()
            remote_site_id = site.pop('site_id')
            site_id = self.api.plshell.AddSite(self.api.plauth, site)
            if peer:
                self.api.plshell.BindObjectToPeer(self.api.plauth, 'site', site_id, peer, remote_site_id)

        slice_fields = {}
        for key in ['name', 'url', 'description']:
            if slice_dict.get(key):
                slice_fields[key] = slice_dict[key]
        slice_id = self.api.plshell.AddSlice(self.api.plauth, slice_fields)
        if peer:
            self.api.plshell.BindObjectToPeer(self.api.plauth, 'slice', slice_id, peer, slice_dict['pointer'])

        # a freshly created slice is not on any node yet
        pl_slice = dict(slice_fields)
        pl_slice['node_ids'] = []
        return pl_slice

    def _verify_persons(self, registry, credential, slicename, researchers, peer):
        """
        Make sure each researcher hrn exists as an enabled PLC person who
        is a member of slicename and, for local slices, has every key the
        registry lists for them.  Researchers with no 'user' record at the
        registry are skipped.
        """
        for researcher in researchers:
            person_record = None
            for record in registry.resolve(credential, researcher):
                if record.get_type() in ['user']:
                    person_record = record
            if person_record is None:
                continue
            person_dict = person_record.as_dict()
            email = person_dict['email']
            persons = self.api.plshell.GetPersons(self.api.plauth, [email], ['person_id', 'key_ids'])

            if not persons:
                person_id = self.api.plshell.AddPerson(self.api.plauth, person_dict)
                # Enable the account right away: GetSlivers only returns keys
                # of enabled users, so without this the user's key never
                # reaches the slice and the user cannot log in.  Additional
                # checks may be added before enabling.
                self.api.plshell.UpdatePerson(self.api.plauth, person_id, {'enabled': True})
                if peer:
                    self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', person_id, peer, person_record['pointer'])
                key_ids = []
            else:
                key_ids = persons[0]['key_ids']

            self.api.plshell.AddPersonToSlice(self.api.plauth, email, slicename)

            # keys this person already has at PLC
            keys = []
            if key_ids:
                keylist = self.api.plshell.GetKeys(self.api.plauth, key_ids, ['key'])
                keys = [key['key'] for key in keylist]

            for personkey in person_dict.get('keys') or []:
                if personkey in keys:
                    continue
                if peer:
                    # XX Need to get the key_id from remote registry somehow
                    #self.api.plshell.BindObjectToPeer(self.api.plauth, 'key', None, peer, key_id)
                    continue
                key = {'key_type': 'ssh', 'key': personkey}
                self.api.plshell.AddPersonKey(self.api.plauth, email, key)

    def _get_rspec_hostnames(self, spec):
        """Return the hostnames named by the NodeSpec elements of spec."""
        hostnames = []
        for nodespec in spec.getDictsByTagName('NodeSpec'):
            name = nodespec['name']
            if isinstance(name, list):
                hostnames.extend(name)
            elif isinstance(name, StringTypes):
                hostnames.append(name)
        return hostnames

    def create_slice_aggregate(self, hrn, rspec):
        """
        Create (if needed) the local PLC slice for hrn, sync its users from
        the registry, and make its node set match the NodeSpecs in rspec.
        Returns 1.
        """
        # Because of myplc federation, first determine whether this slice
        # belongs to our local plc or to a myplc peer (None means local).
        peer = self._get_peer(hrn)
        spec = Rspec(rspec)

        registries = Registries(self.api)
        registry = registries[self.api.hrn]
        credential = self.api.getCredential()
        slice_record = self._get_slice_record(registry, credential, hrn)
        slice_dict = slice_record.as_dict()

        slicename = hrn_to_pl_slicename(hrn)
        pl_slice = self._verify_slice(registry, credential, hrn, slicename, slice_dict, peer)

        researchers = slice_record.get('researcher', []) or []
        self._verify_persons(registry, credential, slicename, researchers, peer)

        # hostnames the slice currently runs on
        current = []
        if pl_slice['node_ids']:
            nodelist = self.api.plshell.GetNodes(self.api.plauth, pl_slice['node_ids'], ['hostname'])
            current = [node['hostname'] for node in nodelist]
        requested = self._get_rspec_hostnames(spec)

        added_nodes = list(set(requested).difference(current))
        deleted_nodes = list(set(current).difference(requested))
        if added_nodes:
            self.api.plshell.AddSliceToNodes(self.api.plauth, slicename, added_nodes)
        if deleted_nodes:
            self.api.plshell.DeleteSliceFromNodes(self.api.plauth, slicename, deleted_nodes)
        return 1

    def create_slice_smgr(self, hrn, rspec):
        """
        Split rspec per known aggregate (by NetSpec) and ask each aggregate
        to create slice hrn.  The local aggregate receives the whole rspec.
        Failures at an aggregate are logged and skipped.  Returns 1.
        """
        spec = Rspec()
        spec.parseString(rspec)
        specDict = spec.toDict()
        if 'Rspec' in specDict:
            specDict = specDict['Rspec']
        start_time = specDict.get('start_time', 0)
        end_time = specDict.get('end_time', 0)

        rspecs = {}
        aggregates = Aggregates(self.api)
        credential = self.api.getCredential()
        # only attempt to extract information about the aggregates we know about
        for aggregate in aggregates:
            netspec = spec.getDictByTagNameValue('NetSpec', aggregate)
            if not netspec:
                continue
            resources = {'start_time': start_time, 'end_time': end_time, 'networks': netspec}
            # fresh Rspec per aggregate so nothing leaks between them
            tempspec = Rspec()
            tempspec.parseDict({'Rspec': resources})
            rspecs[aggregate] = tempspec.toxml()

        # notify the aggregates
        for aggregate in rspecs:
            try:
                if aggregate in [self.api.hrn]:
                    # send the whole rspec to the local aggregate
                    aggregates[aggregate].create_slice(credential, hrn, rspec)
                else:
                    aggregates[aggregate].create_slice(credential, hrn, rspecs[aggregate])
            except Exception:
                print >> log, "Error creating slice %s at aggregate %s" % (hrn, aggregate)
        return 1

    def start_slice(self, hrn):
        """Start slice hrn using the handler for this interface."""
        if self.api.interface in ['aggregate']:
            return self.start_slice_aggregate(hrn)
        elif self.api.interface in ['slicemgr']:
            return self.start_slice_smgr(hrn)

    def _set_slice_enabled(self, hrn, value):
        """
        Set the 'enabled' slice attribute of the PLC slice for hrn to value
        ("1" or "0").  Raises RecordNotFound if the slice does not exist.
        """
        slicename = hrn_to_pl_slicename(hrn)
        slices = self.api.plshell.GetSlices(self.api.plauth, {'name': slicename}, ['slice_id'])
        if not slices:
            raise RecordNotFound(hrn)
        slice_id = slices[0]['slice_id']
        attributes = self.api.plshell.GetSliceAttributes(self.api.plauth, {'slice_id': slice_id, 'name': 'enabled'}, ['slice_attribute_id'])
        # NOTE(review): raises IndexError if the slice has no 'enabled'
        # attribute yet -- confirm whether it should be added instead.
        attribute_id = attributes[0]['slice_attribute_id']
        self.api.plshell.UpdateSliceAttribute(self.api.plauth, attribute_id, value)
        return 1

    def start_slice_aggregate(self, hrn):
        """Enable the local PLC slice for hrn."""
        return self._set_slice_enabled(hrn, "1")

    def start_slice_smgr(self, hrn):
        """Ask every known aggregate to start slice hrn."""
        credential = self.api.getCredential()
        aggregates = Aggregates(self.api)
        for aggregate in aggregates:
            aggregates[aggregate].start_slice(credential, hrn)
        return 1

    def stop_slice(self, hrn):
        """Stop slice hrn using the handler for this interface."""
        if self.api.interface in ['aggregate']:
            return self.stop_slice_aggregate(hrn)
        elif self.api.interface in ['slicemgr']:
            return self.stop_slice_smgr(hrn)

    def stop_slice_aggregate(self, hrn):
        """Disable the local PLC slice for hrn."""
        return self._set_slice_enabled(hrn, "0")

    def stop_slice_smgr(self, hrn):
        """Ask every known aggregate to stop slice hrn."""
        credential = self.api.getCredential()
        aggregates = Aggregates(self.api)
        for aggregate in aggregates:
            aggregates[aggregate].stop_slice(credential, hrn)
        return 1
