fix bug in GetSlices call
[sfa.git] / geni / util / slices.py
1 import datetime
2 import time
3 from geni.util.misc import *
4 from geni.util.rspec import *
5 from geni.util.specdict import *
6 from geni.util.excep import *
7 from geni.util.storage import *
8 from geni.util.policy import Policy
9 from geni.util.debug import log
10 from geni.aggregate import Aggregates
11 from geni.registry import Registries
12
class Slices(SimpleStorage):
    """
    On-disk cache of the slice hrns known to this interface, plus the slice
    operations (create / delete / start / stop).

    Every public operation dispatches on ``api.interface``:
      - 'aggregate': acts directly on the local PLC through ``api.plshell``
      - 'slicemgr':  forwards the request to every known aggregate
    Any other interface value makes the operation a no-op.
    """

    def __init__(self, api, ttl = .5):
        """
        @param api: the GENI API object (config, interface, hrn, plshell,
            plauth and time_format are read from it)
        @param ttl: lifetime of the cached slice list, in hours
        """
        self.api = api
        self.ttl = ttl
        self.threshold = None
        filename = ".".join([self.api.interface, self.api.hrn, "slices"])
        self.slices_file = os.path.join(self.api.config.basepath, filename)
        SimpleStorage.__init__(self, self.slices_file)
        self.policy = Policy(self.api)
        self.load()

    def refresh(self):
        """
        Update the cached list of slices when it has never been stored or
        when its stored 'threshold' (expiry) time has passed.
        """
        now = datetime.datetime.now()
        if self.has_key('threshold') and self.has_key('timestamp'):
            threshold = datetime.datetime.fromtimestamp(time.mktime(
                time.strptime(self['threshold'], self.api.time_format)))
            if now <= threshold:
                # cache is still fresh
                return
        if self.api.interface in ['aggregate']:
            self.refresh_slices_aggregate()
        elif self.api.interface in ['slicemgr']:
            self.refresh_slices_smgr()

    def _store_slice_hrns(self, slice_hrns):
        """
        Store slice_hrns in the cache with the current timestamp and an
        expiry threshold ttl hours from now, then write the cache to disk.
        """
        timestamp = datetime.datetime.now()
        threshold = timestamp + datetime.timedelta(hours=self.ttl)
        self.update({'hrn': slice_hrns,
                     'timestamp': timestamp.strftime(self.api.time_format),
                     'threshold': threshold.strftime(self.api.time_format)})
        self.write()

    def refresh_slices_aggregate(self):
        """Cache the hrns of all local (peer_id None) slices at PLC."""
        slices = self.api.plshell.GetSlices(self.api.plauth, {'peer_id': None}, ['name'])
        slice_hrns = [slicename_to_hrn(self.api.hrn, slice_record['name'])
                      for slice_record in slices]
        self._store_slice_hrns(slice_hrns)

    def refresh_slices_smgr(self):
        """
        Cache the slice hrns reported by every known aggregate. Aggregates
        that fail to answer are logged and skipped.
        """
        slice_hrns = []
        aggregates = Aggregates(self.api)
        credential = self.api.getCredential()
        for aggregate in aggregates:
            try:
                slice_hrns.extend(aggregates[aggregate].get_slices(credential))
            except Exception:
                print >> log, "Error calling slices at aggregate %(aggregate)s" % locals()
        self._store_slice_hrns(slice_hrns)

    def delete_slice(self, hrn):
        """Delete slice hrn using the handler for this interface type."""
        if self.api.interface in ['aggregate']:
            return self.delete_slice_aggregate(hrn)
        elif self.api.interface in ['slicemgr']:
            return self.delete_slice_smgr(hrn)

    def delete_slice_aggregate(self, hrn):
        """
        Remove the local PLC slice for hrn from all of its nodes. A slice
        that does not exist at PLC is treated as already deleted.
        """
        slicename = hrn_to_pl_slicename(hrn)
        slices = self.api.plshell.GetSlices(self.api.plauth, {'peer_id': None, 'name': slicename})
        if not slices:
            return 1
        slice_record = slices[0]
        self.api.plshell.DeleteSliceFromNodes(self.api.plauth, slicename, slice_record['node_ids'])
        return 1

    def delete_slice_smgr(self, hrn):
        """Ask every known aggregate to delete slice hrn."""
        credential = self.api.getCredential()
        aggregates = Aggregates(self.api)
        for aggregate in aggregates:
            aggregates[aggregate].delete_slice(credential, hrn)
        return 1

    def create_slice(self, hrn, rspec):
        """
        Create (or update) slice hrn according to rspec, unless the slice
        policy forbids it: a non-empty whitelist must contain hrn and a
        non-empty blacklist must not. Rejections are logged and return 1.
        """
        whitelist = self.policy['slice_whitelist']
        blacklist = self.policy['slice_blacklist']
        if (whitelist and hrn not in whitelist) or \
           (blacklist and hrn in blacklist):
            policy_file = self.policy.policy_file
            print >> log, "Slice %(hrn)s not allowed by policy %(policy_file)s" % locals()
            return 1
        if self.api.interface in ['aggregate']:
            return self.create_slice_aggregate(hrn, rspec)
        elif self.api.interface in ['slicemgr']:
            return self.create_slice_smgr(hrn, rspec)

    def create_slice_aggregate(self, hrn, rspec):
        """
        Make the local PLC slice for hrn match rspec: create the slice (and
        its site) if needed, add the registry's researchers and their keys,
        and add/remove the slice on nodes so they match rspec's NodeSpecs.

        @raise RecordNotFound: if the registry has no slice record for hrn,
            or no record for its authority when the site must be created
        """
        spec = Rspec(rspec)
        registries = Registries(self.api)
        registry = registries[self.api.hrn]
        credential = self.api.getCredential()

        # Get the slice record from geni
        slice_record = {}
        for record in registry.resolve(credential, hrn):
            if record.get_type() in ['slice']:
                slice_record = record.as_dict()
        if not slice_record:
            raise RecordNotFound(hrn)

        slicename = hrn_to_pl_slicename(hrn)
        node_ids = self._ensure_plc_slice(registry, credential, hrn, slicename, slice_record)
        self._add_slice_researchers(registry, credential, slicename, slice_record)
        self._sync_slice_nodes(slicename, node_ids, spec)
        return 1

    def _ensure_plc_slice(self, registry, credential, hrn, slicename, slice_record):
        """
        Return the PLC node_ids of slicename, creating the slice (and its
        site, when the site is missing) at PLC first if it does not exist.
        """
        slices = self.api.plshell.GetSlices(self.api.plauth, [slicename], ['node_ids'])
        if slices:
            return slices[0]['node_ids'] or []

        # the site login base is the slice name up to the first '_'
        login_base = slicename.split("_")[0]
        sites = self.api.plshell.GetSites(self.api.plauth, [login_base])
        if not sites:
            authority = get_authority(hrn)
            site_records = registry.resolve(credential, authority)
            if not site_records:
                raise RecordNotFound(authority)
            site = site_records[0].as_dict()
            # drop the registry's site_id before AddSite (presumably PLC
            # assigns its own); tolerate records that lack it
            site.pop('site_id', None)
            self.api.plshell.AddSite(self.api.plauth, site)

        self.api.plshell.AddSlice(self.api.plauth, slice_record)
        # a freshly added slice has not been added to any node yet
        return []

    def _add_slice_researchers(self, registry, credential, slicename, slice_record):
        """
        Ensure each researcher of the slice record exists at PLC, belongs
        to slicename, and has all of its registry keys uploaded.
        Researchers without a 'user' record in the registry are skipped.
        """
        for researcher in slice_record.get('researcher', []):
            person_record = None
            for record in registry.resolve(credential, researcher):
                if record.get_type() in ['user']:
                    person_record = record
            if person_record is None:
                continue
            person_dict = person_record.as_dict()
            email = person_dict['email']
            persons = self.api.plshell.GetPersons(self.api.plauth, [email], ['person_id', 'key_ids'])

            # Create the person record
            if not persons:
                self.api.plshell.AddPerson(self.api.plauth, person_dict)
                key_ids = []
            else:
                key_ids = persons[0]['key_ids']

            self.api.plshell.AddPersonToSlice(self.api.plauth, email, slicename)

            # keys this user already has at PLC
            keys = []
            if key_ids:
                keylist = self.api.plshell.GetKeys(self.api.plauth, key_ids, ['key'])
                keys = [key['key'] for key in keylist]

            # add keys that aren't already there
            for personkey in person_dict.get('keys', []):
                if personkey not in keys:
                    key = {'key_type': 'ssh', 'key': personkey}
                    self.api.plshell.AddPersonKey(self.api.plauth, email, key)

    def _sync_slice_nodes(self, slicename, node_ids, spec):
        """
        Add slicename to the hostnames named by spec's NodeSpec elements and
        remove it from its current nodes (node_ids) that spec does not name.
        """
        # find out where this slice is currently running
        hostnames = []
        if node_ids:
            nodelist = self.api.plshell.GetNodes(self.api.plauth, node_ids, ['hostname'])
            hostnames = [node['hostname'] for node in nodelist]

        # hostnames requested by the rspec ('name' is a string or a list)
        nodes = []
        for nodespec in spec.getDictsByTagName('NodeSpec'):
            name = nodespec.get('name')
            if isinstance(name, list):
                nodes.extend(name)
            elif isinstance(name, StringTypes):
                nodes.append(name)

        added_nodes = list(set(nodes).difference(hostnames))
        deleted_nodes = list(set(hostnames).difference(nodes))
        if added_nodes:
            self.api.plshell.AddSliceToNodes(self.api.plauth, slicename, added_nodes)
        if deleted_nodes:
            self.api.plshell.DeleteSliceFromNodes(self.api.plauth, slicename, deleted_nodes)

    def create_slice_smgr(self, hrn, rspec):
        """
        Split rspec into one rspec per known aggregate (taken from the
        NetSpec element whose value names that aggregate) and ask each of
        those aggregates to create slice hrn. Per-aggregate failures are
        logged and skipped.
        """
        spec = Rspec()
        spec.parseString(rspec)
        specDict = spec.toDict()
        if 'Rspec' in specDict:
            specDict = specDict['Rspec']
        start_time = specDict.get('start_time', 0)
        end_time = specDict.get('end_time', 0)

        rspecs = {}
        aggregates = Aggregates(self.api)
        credential = self.api.getCredential()
        # only attempt to extract information about the aggregates we know about
        for aggregate in aggregates:
            netspec = spec.getDictByTagNameValue('NetSpec', aggregate)
            if netspec:
                resources = {'start_time': start_time, 'end_time': end_time, 'networks': netspec}
                # a fresh Rspec per aggregate so nothing carries over between them
                tempspec = Rspec()
                tempspec.parseDict({'Rspec': resources})
                rspecs[aggregate] = tempspec.toxml()

        # notify the aggregates
        for aggregate in rspecs.keys():
            try:
                aggregates[aggregate].create_slice(credential, hrn, rspecs[aggregate])
            except Exception:
                print >> log, "Error creating slice %(hrn)s at aggregate %(aggregate)s" % locals()

        return 1

    def start_slice(self, hrn):
        """Start slice hrn using the handler for this interface type."""
        if self.api.interface in ['aggregate']:
            return self.start_slice_aggregate(hrn)
        elif self.api.interface in ['slicemgr']:
            return self.start_slice_smgr(hrn)

    def _set_slice_enabled(self, hrn, enabled):
        """
        Set the PLC 'enabled' attribute of slice hrn to enabled ("1" or "0"),
        adding the attribute when the slice does not have one yet.

        @raise RecordNotFound: if the slice does not exist at PLC
        """
        slicename = hrn_to_pl_slicename(hrn)
        slices = self.api.plshell.GetSlices(self.api.plauth, {'name': slicename}, ['slice_id'])
        if not slices:
            raise RecordNotFound(hrn)
        slice_id = slices[0]['slice_id']
        attributes = self.api.plshell.GetSliceAttributes(self.api.plauth, {'slice_id': slice_id, 'name': 'enabled'}, ['slice_attribute_id'])
        if attributes:
            attribute_id = attributes[0]['slice_attribute_id']
            self.api.plshell.UpdateSliceAttribute(self.api.plauth, attribute_id, enabled)
        else:
            # no 'enabled' attribute yet: create one instead of failing
            self.api.plshell.AddSliceAttribute(self.api.plauth, slice_id, 'enabled', enabled)
        return 1

    def start_slice_aggregate(self, hrn):
        """Enable the local PLC slice for hrn."""
        return self._set_slice_enabled(hrn, "1")

    def start_slice_smgr(self, hrn):
        """Ask every known aggregate to start slice hrn."""
        credential = self.api.getCredential()
        aggregates = Aggregates(self.api)
        for aggregate in aggregates:
            aggregates[aggregate].start_slice(credential, hrn)
        return 1

    def stop_slice(self, hrn):
        """Stop slice hrn using the handler for this interface type."""
        if self.api.interface in ['aggregate']:
            return self.stop_slice_aggregate(hrn)
        elif self.api.interface in ['slicemgr']:
            return self.stop_slice_smgr(hrn)

    def stop_slice_aggregate(self, hrn):
        """Disable the local PLC slice for hrn."""
        return self._set_slice_enabled(hrn, "0")

    def stop_slice_smgr(self, hrn):
        """Ask every known aggregate to stop slice hrn."""
        credential = self.api.getCredential()
        aggregates = Aggregates(self.api)
        for aggregate in aggregates:
            aggregates[aggregate].stop_slice(credential, hrn)
        return 1
300