use self.get_auth_info
[sfa.git] / sfa / util / slices.py
1 ### $Id$
2 ### $URL$
3
4 import datetime
5 import time
6
7 from sfa.util.misc import *
8 from sfa.util.rspec import *
9 from sfa.util.specdict import *
10 from sfa.util.faults import *
11 from sfa.util.storage import *
12 from sfa.util.policy import Policy
13 from sfa.util.debug import log
14 from sfa.server.aggregate import Aggregates
15 from sfa.server.registry import Registries
16
class Slices(SimpleStorage):
    """
    Cached list of the slice hrns known to this interface, plus the slice
    management operations (create / delete / start / stop).

    The cache is persisted through SimpleStorage to the file
    "<interface>.<hrn>.slices" under api.config.basepath.

    Every public operation dispatches on api.interface:
      - 'aggregate': works directly against the local PLC via api.plshell;
      - 'slicemgr':  forwards the request to every known aggregate.
    """

    def __init__(self, api, ttl = .5):
        """
        @param api server API object (config, interface, hrn, time_format,
            plshell/plauth, getCredential)
        @param ttl lifetime of the cached slice list, in hours
        """
        self.api = api
        self.ttl = ttl
        self.threshold = None
        filename = ".".join([self.api.interface, self.api.hrn, "slices"])
        self.slices_file = os.path.join(self.api.config.basepath, filename)
        SimpleStorage.__init__(self, self.slices_file)
        self.policy = Policy(self.api)
        self.load()

    def refresh(self):
        """
        Update the cached list of slices if it is missing or stale.

        The cache is stale once the current time is past the stored
        'threshold' (written by _save_slice_hrns).
        """
        now = datetime.datetime.now()
        if 'threshold' in self and 'timestamp' in self:
            threshold = datetime.datetime.strptime(self['threshold'], self.api.time_format)
            if now <= threshold:
                # cache still fresh
                return
        if self.api.interface in ['aggregate']:
            self.refresh_slices_aggregate()
        elif self.api.interface in ['slicemgr']:
            self.refresh_slices_smgr()

    def _save_slice_hrns(self, slice_hrns):
        """
        Store slice_hrns in the cache, stamp it with the current time and a
        freshness threshold ttl hours later, and write it to disk.
        """
        timestamp = datetime.datetime.now()
        threshold = timestamp + datetime.timedelta(hours=self.ttl)
        self.update({'hrn': slice_hrns,
                     'timestamp': timestamp.strftime(self.api.time_format),
                     'threshold': threshold.strftime(self.api.time_format)})
        self.write()

    def refresh_slices_aggregate(self):
        """Rebuild the cache from the local (peer_id None) slices at PLC."""
        slices = self.api.plshell.GetSlices(self.api.plauth, {'peer_id': None}, ['name'])
        slice_hrns = [slicename_to_hrn(self.api.hrn, slice['name']) for slice in slices]
        self._save_slice_hrns(slice_hrns)

    def refresh_slices_smgr(self):
        """
        Rebuild the cache from the slices reported by every known aggregate.
        An aggregate that fails is logged and skipped so the rest still count.
        """
        slice_hrns = []
        aggregates = Aggregates(self.api)
        credential = self.api.getCredential()
        for aggregate in aggregates:
            try:
                slices = aggregates[aggregate].get_slices(credential)
                slice_hrns.extend(slices)
            except Exception:
                print >> log, "Error calling slices at aggregate %(aggregate)s" % locals()
        self._save_slice_hrns(slice_hrns)

    def delete_slice(self, hrn):
        """
        Remove slice hrn from its nodes (aggregate) or from every known
        aggregate (slicemgr).
        @return 1, or None if api.interface is neither of the above
        """
        if self.api.interface in ['aggregate']:
            return self.delete_slice_aggregate(hrn)
        elif self.api.interface in ['slicemgr']:
            return self.delete_slice_smgr(hrn)

    def delete_slice_aggregate(self, hrn):
        """
        Remove the local PLC slice for hrn from all of its nodes. The PLC
        slice itself is not deleted. A missing slice is not an error.
        @return 1
        """
        slicename = hrn_to_pl_slicename(hrn)
        slices = self.api.plshell.GetSlices(self.api.plauth, {'peer_id': None, 'name': slicename})
        if not slices:
            return 1
        node_ids = slices[0]['node_ids']
        if node_ids:
            self.api.plshell.DeleteSliceFromNodes(self.api.plauth, slicename, node_ids)
        return 1

    def delete_slice_smgr(self, hrn):
        """
        Ask every known aggregate to delete slice hrn. Errors from an
        aggregate propagate to the caller.
        @return 1
        """
        credential = self.api.getCredential()
        aggregates = Aggregates(self.api)
        for aggregate in aggregates:
            aggregates[aggregate].delete_slice(credential, hrn)
        return 1

    def create_slice(self, hrn, rspec):
        """
        Create or update slice hrn so it matches rspec, after checking the
        slice whitelist / blacklist policy.
        @return 1 -- also when the policy rejects the slice (the rejection is
            only logged); None if api.interface is neither aggregate nor slicemgr
        """
        whitelist = self.policy['slice_whitelist']
        blacklist = self.policy['slice_blacklist']

        if (whitelist and hrn not in whitelist) or (blacklist and hrn in blacklist):
            policy_file = self.policy.policy_file
            print >> log, "Slice %(hrn)s not allowed by policy %(policy_file)s" % locals()
            return 1
        if self.api.interface in ['aggregate']:
            return self.create_slice_aggregate(hrn, rspec)
        elif self.api.interface in ['slicemgr']:
            return self.create_slice_smgr(hrn, rspec)

    def create_slice_aggregate(self, hrn, rspec):
        """
        Make the local PLC slice for hrn match rspec.

        Creates the PLC site and slice when missing, adds the slice's
        researchers (and their ssh keys) from the registry, then adds and
        removes nodes so the slice runs exactly on the hostnames listed in
        the rspec's NodeSpec elements.
        @raise RecordNotFound if the registry has no slice record for hrn,
            or (when the site must be created) no record for its authority
        @return 1
        """
        spec = Rspec(rspec)
        registries = Registries(self.api)
        registry = registries[self.api.hrn]
        credential = self.api.getCredential()

        # The registry's view of the slice. Kept separate from the PLC slice
        # because its 'researcher' list is needed after the PLC lookup.
        slice_record = {}
        for record in registry.resolve(credential, hrn):
            if record.get_type() in ['slice']:
                slice_record = record.as_dict()
        if not slice_record:
            raise RecordNotFound(hrn)

        slicename = hrn_to_pl_slicename(hrn)
        slice = self._get_or_add_plc_slice(hrn, slicename, slice_record, registry, credential)
        self._add_slice_researchers(slicename, slice_record, registry, credential)
        self._sync_slice_nodes(slicename, slice['node_ids'], spec)
        return 1

    def _get_or_add_plc_slice(self, hrn, slicename, slice_record, registry, credential):
        """Return the PLC slice (with 'node_ids'), adding it and its site if missing."""
        slices = self.api.plshell.GetSlices(self.api.plauth, [slicename], ['node_ids'])
        if slices:
            return slices[0]

        # the PLC login_base is the part of the slice name before the first "_"
        login_base = slicename.split("_")[0]
        sites = self.api.plshell.GetSites(self.api.plauth, [login_base])
        if not sites:
            authority = get_authority(hrn)
            site_records = registry.resolve(credential, authority)
            if not site_records:
                raise RecordNotFound(authority)
            site = site_records[0].as_dict()
            # the registry's site_id is not sent to AddSite
            site.pop('site_id', None)
            self.api.plshell.AddSite(self.api.plauth, site)

        slice_fields = {}
        for key in ['name', 'url', 'description']:
            if slice_record.get(key):
                slice_fields[key] = slice_record[key]
        self.api.plshell.AddSlice(self.api.plauth, slice_fields)
        # a freshly added slice is not on any node yet
        return {'node_ids': []}

    def _add_slice_researchers(self, slicename, slice_record, registry, credential):
        """
        Make sure every researcher on the registry slice record exists at PLC,
        is a member of slicename and has all of its registry ssh keys.
        Researchers with no registry 'user' record are skipped.
        """
        for researcher in slice_record.get('researcher', []):
            person_record = None
            for record in registry.resolve(credential, researcher):
                if record.get_type() in ['user']:
                    person_record = record
            if not person_record:
                continue
            person_dict = person_record.as_dict()
            email = person_dict['email']
            persons = self.api.plshell.GetPersons(self.api.plauth, [email], ['person_id', 'key_ids'])

            # create the person if PLC does not know them yet
            if not persons:
                self.api.plshell.AddPerson(self.api.plauth, person_dict)
                key_ids = []
            else:
                key_ids = persons[0]['key_ids']

            self.api.plshell.AddPersonToSlice(self.api.plauth, email, slicename)

            # upload only the keys PLC does not already have for this person
            keys = []
            if key_ids:
                keylist = self.api.plshell.GetKeys(self.api.plauth, key_ids, ['key'])
                keys = [key['key'] for key in keylist]
            for personkey in person_dict.get('keys', []):
                if personkey not in keys:
                    key = {'key_type': 'ssh', 'key': personkey}
                    self.api.plshell.AddPersonKey(self.api.plauth, email, key)

    def _sync_slice_nodes(self, slicename, node_ids, spec):
        """Add/remove slicename to/from nodes so it matches the rspec's NodeSpec names."""
        # where the slice currently runs
        hostnames = []
        if node_ids:
            nodelist = self.api.plshell.GetNodes(self.api.plauth, node_ids, ['hostname'])
            hostnames = [node['hostname'] for node in nodelist]

        # where the rspec wants it to run; 'name' may be one hostname or a list
        nodes = []
        for nodespec in spec.getDictsByTagName('NodeSpec'):
            name = nodespec['name']
            if isinstance(name, list):
                nodes.extend(name)
            elif isinstance(name, StringTypes):
                nodes.append(name)

        added_nodes = list(set(nodes).difference(hostnames))
        deleted_nodes = list(set(hostnames).difference(nodes))
        if added_nodes:
            self.api.plshell.AddSliceToNodes(self.api.plauth, slicename, added_nodes)
        if deleted_nodes:
            self.api.plshell.DeleteSliceFromNodes(self.api.plauth, slicename, deleted_nodes)

    def create_slice_smgr(self, hrn, rspec):
        """
        Split rspec by NetSpec and send each part to the matching known
        aggregate. Failures at an individual aggregate are logged, not raised.
        @return 1
        """
        spec = Rspec()
        spec.parseString(rspec)
        specDict = spec.toDict()
        if 'Rspec' in specDict:
            specDict = specDict['Rspec']
        start_time = specDict.get('start_time', 0)
        end_time = specDict.get('end_time', 0)

        rspecs = {}
        aggregates = Aggregates(self.api)
        credential = self.api.getCredential()
        # only build rspecs for the aggregates we know about
        for aggregate in aggregates:
            netspec = spec.getDictByTagNameValue('NetSpec', aggregate)
            if netspec:
                resources = {'start_time': start_time, 'end_time': end_time, 'networks': netspec}
                # a fresh Rspec per aggregate, so nothing parsed for one
                # aggregate can carry over into another's rspec
                tempspec = Rspec()
                tempspec.parseDict({'Rspec': resources})
                rspecs[aggregate] = tempspec.toxml()

        # notify the aggregates
        for aggregate in rspecs:
            try:
                aggregates[aggregate].create_slice(credential, hrn, rspecs[aggregate])
            except Exception:
                print >> log, "Error creating slice %(hrn)s at aggregate %(aggregate)s" % locals()
        return 1

    def start_slice(self, hrn):
        """
        (Re)enable slice hrn locally (aggregate) or at every known aggregate
        (slicemgr).
        @return 1, or None if api.interface is neither of the above
        """
        if self.api.interface in ['aggregate']:
            return self.start_slice_aggregate(hrn)
        elif self.api.interface in ['slicemgr']:
            return self.start_slice_smgr(hrn)

    def start_slice_aggregate(self, hrn):
        """
        Enable the local PLC slice for hrn ('enabled' attribute = "1").
        @raise RecordNotFound if PLC has no slice for hrn
        @return 1
        """
        return self._set_slice_enabled(hrn, "1")

    def start_slice_smgr(self, hrn):
        """Ask every known aggregate to start slice hrn. @return 1"""
        credential = self.api.getCredential()
        aggregates = Aggregates(self.api)
        for aggregate in aggregates:
            aggregates[aggregate].start_slice(credential, hrn)
        return 1

    def stop_slice(self, hrn):
        """
        Disable slice hrn locally (aggregate) or at every known aggregate
        (slicemgr).
        @return 1, or None if api.interface is neither of the above
        """
        if self.api.interface in ['aggregate']:
            return self.stop_slice_aggregate(hrn)
        elif self.api.interface in ['slicemgr']:
            return self.stop_slice_smgr(hrn)

    def stop_slice_aggregate(self, hrn):
        """
        Disable the local PLC slice for hrn ('enabled' attribute = "0").
        @raise RecordNotFound if PLC has no slice for hrn
        @return 1
        """
        return self._set_slice_enabled(hrn, "0")

    def _set_slice_enabled(self, hrn, value):
        """Set the 'enabled' slice attribute of the PLC slice for hrn to value ("0"/"1")."""
        slicename = hrn_to_pl_slicename(hrn)
        slices = self.api.plshell.GetSlices(self.api.plauth, {'name': slicename}, ['slice_id'])
        if not slices:
            raise RecordNotFound(hrn)
        slice_id = slices[0]['slice_id']
        attributes = self.api.plshell.GetSliceAttributes(self.api.plauth, {'slice_id': slice_id, 'name': 'enabled'}, ['slice_attribute_id'])
        if attributes:
            attribute_id = attributes[0]['slice_attribute_id']
            self.api.plshell.UpdateSliceAttribute(self.api.plauth, attribute_id, value)
        elif value != "1":
            # no 'enabled' attribute yet: add one carrying the new value
            self.api.plshell.AddSliceAttribute(self.api.plauth, slice_id, 'enabled', value)
        # NOTE(review): when enabling a slice that has no 'enabled' attribute
        # nothing is sent; this assumes such a slice is already enabled at
        # PLC -- confirm.
        return 1

    def stop_slice_smgr(self, hrn):
        """Ask every known aggregate to stop slice hrn. @return 1"""
        credential = self.api.getCredential()
        aggregates = Aggregates(self.api)
        for aggregate in aggregates:
            aggregates[aggregate].stop_slice(credential, hrn)
        return 1