cleanup for config - again
[sfa.git] / sfa / plc / slices.py
1 ### $Id$
2 ### $URL$
3
4 import datetime
5 import time
6
7 from sfa.util.misc import *
8 from sfa.util.rspec import *
9 from sfa.util.specdict import *
10 from sfa.util.faults import *
11 from sfa.util.storage import *
12 from sfa.util.policy import Policy
13 from sfa.util.debug import log
14 from sfa.server.aggregate import Aggregates
15 from sfa.server.registry import Registries
16
17 class Slices(SimpleStorage):
18
19     def __init__(self, api, ttl = .5):
20         self.api = api
21         self.ttl = ttl
22         self.threshold = None
23         path = self.api.config.SFA_BASE_DIR
24         filename = ".".join([self.api.interface, self.api.hrn, "slices"])
25         filepath = path + os.sep + filename
26         self.slices_file = filepath
27         SimpleStorage.__init__(self, self.slices_file)
28         self.policy = Policy(self.api)    
29         self.load()
30
31
32     def refresh(self):
33         """
34         Update the cached list of slices
35         """
36         # Reload components list
37         now = datetime.datetime.now()
38         if not self.has_key('threshold') or not self.has_key('timestamp') or \
39            now > datetime.datetime.fromtimestamp(time.mktime(time.strptime(self['threshold'], self.api.time_format))):
40             if self.api.interface in ['aggregate']:
41                 self.refresh_slices_aggregate()
42             elif self.api.interface in ['slicemgr']:
43                 self.refresh_slices_smgr()
44
45     def refresh_slices_aggregate(self):
46         slices = self.api.plshell.GetSlices(self.api.plauth, {'peer_id': None}, ['name'])
47         slice_hrns = [slicename_to_hrn(self.api.hrn, slice['name']) for slice in slices]
48
49          # update timestamp and threshold
50         timestamp = datetime.datetime.now()
51         hr_timestamp = timestamp.strftime(self.api.time_format)
52         delta = datetime.timedelta(hours=self.ttl)
53         threshold = timestamp + delta
54         hr_threshold = threshold.strftime(self.api.time_format)
55         
56         slice_details = {'hrn': slice_hrns,
57                          'timestamp': hr_timestamp,
58                          'threshold': hr_threshold
59                         }
60         self.update(slice_details)
61         self.write()     
62         
63
64     def refresh_slices_smgr(self):
65         slice_hrns = []
66         aggregates = Aggregates(self.api)
67         credential = self.api.getCredential()
68         for aggregate in aggregates:
69             try:
70                 slices = aggregates[aggregate].get_slices(credential)
71                 slice_hrns.extend(slices)
72             except:
73                 print >> log, "Error calling slices at aggregate %(aggregate)s" % locals()
74          # update timestamp and threshold
75         timestamp = datetime.datetime.now()
76         hr_timestamp = timestamp.strftime(self.api.time_format)
77         delta = datetime.timedelta(hours=self.ttl)
78         threshold = timestamp + delta
79         hr_threshold = threshold.strftime(self.api.time_format)
80
81         slice_details = {'hrn': slice_hrns,
82                          'timestamp': hr_timestamp,
83                          'threshold': hr_threshold
84                         }
85         self.update(slice_details)
86         self.write()
87
88
89     def delete_slice(self, hrn):
90         if self.api.interface in ['aggregate']:
91             self.delete_slice_aggregate(hrn)
92         elif self.api.interface in ['slicemgr']:
93             self.delete_slice_smgr(hrn)
94         
95     def delete_slice_aggregate(self, hrn):
96         slicename = hrn_to_pl_slicename(hrn)
97         slices = self.api.plshell.GetSlices(self.api.plauth, {'peer_id': None, 'name': slicename})
98         if not slices:
99             return 1        
100         slice = slices[0]
101
102         self.api.plshell.DeleteSliceFromNodes(self.api.plauth, slicename, slice['node_ids'])
103         return 1
104
105     def delete_slice_smgr(self, hrn):
106         credential = self.api.getCredential()
107         aggregates = Aggregates(self.api)
108         for aggregate in aggregates:
109             aggregates[aggregate].delete_slice(credential, hrn)
110
111     def create_slice(self, hrn, rspec):
112         # check our slice policy before we procede
113         whitelist = self.policy['slice_whitelist']     
114         blacklist = self.policy['slice_blacklist']
115         
116         if whitelist and hrn not in whitelist or \
117            blacklist and hrn in blacklist:
118             policy_file = self.policy.policy_file
119             print >> log, "Slice %(hrn)s not allowed by policy %(policy_file)s" % locals()
120             return 1
121         if self.api.interface in ['aggregate']:     
122             self.create_slice_aggregate(hrn, rspec)
123         elif self.api.interface in ['slicemgr']:
124             self.create_slice_smgr(hrn, rspec)
125  
126     def create_slice_aggregate(self, hrn, rspec):    
127         spec = Rspec(rspec)
128         # Get the slice record from geni
129         slice = {}
130         registries = Registries(self.api)
131         registry = registries[self.api.hrn]
132         credential = self.api.getCredential()
133         records = registry.resolve(credential, hrn)
134         for record in records:
135             if record.get_type() in ['slice']:
136                 slice = record.as_dict()
137         if not slice:
138             raise RecordNotFound(hrn)   
139
140         # Make sure slice exists at plc, if it doesnt add it
141         slicename = hrn_to_pl_slicename(hrn)
142         slices = self.api.plshell.GetSlices(self.api.plauth, [slicename], ['node_ids'])
143         if not slices:
144             parts = slicename.split("_")
145             login_base = parts[0]
146             # if site doesnt exist add it
147             sites = self.api.plshell.GetSites(self.api.plauth, [login_base])
148             if not sites:
149                 authority = get_authority(hrn)
150                 site_records = registry.resolve(credential, authority)
151                 site_record = {}
152                 if not site_records:
153                     raise RecordNotFound(authority)
154                 site_record = site_records[0]
155                 site = site_record.as_dict()
156                 
157                  # add the site
158                 site.pop('site_id')
159                 site_id = self.api.plshell.AddSite(self.api.plauth, site)
160             else:
161                 site = sites[0]
162             
163             slice_fields = {}
164             slice_keys = ['name', 'url', 'description']
165             for key in slice_keys:
166                 if key in slice and slice[key]:
167                     slice_fields[key] = slice[key]  
168             self.api.plshell.AddSlice(self.api.plauth, slice_fields)
169             slice = slice_fields
170             slice['node_ids'] = 0
171         else:
172             slice = slices[0]    
173         # get the list of valid slice users from the registry and make 
174         # they are added to the slice 
175         researchers = record.get('researcher', [])
176         for researcher in researchers:
177             person_record = {}
178             person_records = registry.resolve(credential, researcher)
179             for record in person_records:
180                 if record.get_type() in ['user']:
181                     person_record = record
182             if not person_record:
183                 pass
184             person_dict = person_record.as_dict()
185             persons = self.api.plshell.GetPersons(self.api.plauth, [person_dict['email']], ['person_id', 'key_ids'])
186
187             # Create the person record 
188             if not persons:
189                 person_id=self.api.plshell.AddPerson(self.api.plauth, person_dict)
190
191                 # The line below enables the user account on the remote aggregate soon after it is created.
192                 # without this the user key is not transfered to the slice (as GetSlivers returns key of only enabled users),
193                 # which prevents the user from login to the slice. We may do additional checks before enabling the user.
194
195                 self.api.plshell.UpdatePerson(self.api.plauth, person_id, {'enabled' : True})
196                 key_ids = []
197             else:
198                 key_ids = persons[0]['key_ids']
199
200             self.api.plshell.AddPersonToSlice(self.api.plauth, person_dict['email'], slicename)        
201
202             # Get this users local keys
203             keylist = self.api.plshell.GetKeys(self.api.plauth, key_ids, ['key'])
204             keys = [key['key'] for key in keylist]
205
206             # add keys that arent already there 
207             for personkey in person_dict['keys']:
208                 if personkey not in keys:
209                     key = {'key_type': 'ssh', 'key': personkey}
210                     self.api.plshell.AddPersonKey(self.api.plauth, person_dict['email'], key)
211
212         # find out where this slice is currently running
213         nodelist = self.api.plshell.GetNodes(self.api.plauth, slice['node_ids'], ['hostname'])
214         hostnames = [node['hostname'] for node in nodelist]
215
216         # get netspec details
217         nodespecs = spec.getDictsByTagName('NodeSpec')
218         nodes = []
219         for nodespec in nodespecs:
220             if isinstance(nodespec['name'], list):
221                 nodes.extend(nodespec['name'])
222             elif isinstance(nodespec['name'], StringTypes):
223                 nodes.append(nodespec['name'])
224
225         # remove nodes not in rspec
226         deleted_nodes = list(set(hostnames).difference(nodes))
227         # add nodes from rspec
228         added_nodes = list(set(nodes).difference(hostnames))
229
230         self.api.plshell.AddSliceToNodes(self.api.plauth, slicename, added_nodes) 
231         self.api.plshell.DeleteSliceFromNodes(self.api.plauth, slicename, deleted_nodes)
232
233         return 1
234
235     def create_slice_smgr(self, hrn, rspec):
236         spec = Rspec()
237         tempspec = Rspec()
238         spec.parseString(rspec)
239         slicename = hrn_to_pl_slicename(hrn)
240         specDict = spec.toDict()
241         if specDict.has_key('Rspec'): specDict = specDict['Rspec']
242         if specDict.has_key('start_time'): start_time = specDict['start_time']
243         else: start_time = 0
244         if specDict.has_key('end_time'): end_time = specDict['end_time']
245         else: end_time = 0
246
247         rspecs = {}
248         aggregates = Aggregates(self.api)
249         credential = self.api.getCredential()
250         # only attempt to extract information about the aggregates we know about
251         for aggregate in aggregates:
252             netspec = spec.getDictByTagNameValue('NetSpec', aggregate)
253             if netspec:
254                 # creat a plc dict 
255                 resources = {'start_time': start_time, 'end_time': end_time, 'networks': netspec}
256                 resourceDict = {'Rspec': resources}
257                 tempspec.parseDict(resourceDict)
258                 rspecs[aggregate] = tempspec.toxml()
259
260         # notify the aggregates
261         for aggregate in rspecs.keys():
262             try:
263                 aggregates[aggregate].create_slice(credential, hrn, rspecs[aggregate])
264             except:
265                 print >> log, "Error creating slice %(hrn)s at aggregate %(aggregate)s" % locals()
266         return 1
267
268
269     def start_slice(self, hrn):
270         if self.api.interface in ['aggregate']:
271             self.start_slice_aggregate(hrn)
272         elif self.api.interface in ['slicemgr']:
273             self.start_slice_smgr(hrn)
274
275     def start_slice_aggregate(self, hrn):
276         slicename = hrn_to_pl_slicename(hrn)
277         slices = self.api.plshell.GetSlices(self.api.plauth, {'name': slicename}, ['slice_id'])
278         if not slices:
279             raise RecordNotFound(hrn)
280         slice_id = slices[0]
281         attributes = self.api.plshell.GetSliceAttributes(self.api.plauth, {'slice_id': slice_id, 'name': 'enabled'}, ['slice_attribute_id'])
282         attribute_id = attreibutes[0]['slice_attribute_id']
283         self.api.plshell.UpdateSliceAttribute(self.api.plauth, attribute_id, "1" )
284         return 1
285
286     def start_slice_smgr(self, hrn):
287         credential = self.api.getCredential()
288         aggregates = Aggregates(self.api)
289         for aggregate in aggregates:
290             aggregates[aggregate].start_slice(credential, hrn)
291         return 1
292
293
294     def stop_slice(self, hrn):
295         if self.api.interface in ['aggregate']:
296             self.stop_slice_aggregate(hrn)
297         elif self.api.interface in ['slicemgr']:
298             self.stop_slice_smgr(hrn)
299
300     def stop_slice_aggregate(self, hrn):
301         slicename = hrn_to_pl_slicename(hrn)
302         slices = self.api.plshell.GetSlices(self.api.plauth, {'name': slicename}, ['slice_id'])
303         if not slices:
304             raise RecordNotFound(hrn)
305         slice_id = slices[0]['slice_id']
306         attributes = self.api.plshell.GetSliceAttributes(self.api.plauth, {'slice_id': slice_id, 'name': 'enabled'}, ['slice_attribute_id'])
307         attribute_id = attributes[0]['slice_attribute_id']
308         self.api.plshell.UpdateSliceAttribute(self.api.plauth, attribute_id, "0")
309         return 1
310
311     def stop_slice_smgr(self, hrn):
312         credential = self.api.getCredential()
313         aggregates = Aggregates(self.api)
314         for aggregate in aggregates:
315             aggregates[aggregate].stop_slice(credential, hrn)  
316