Added getCredentialFromRegistry() method
[sfa.git] / geni / slicemgr.py
1 import os
2 import sys
3 import datetime
4 import time
5
6 from geni.util.geniserver import *
7 from geni.util.geniclient import *
8 from geni.util.cert import *
9 from geni.util.credential import Credential
10 from geni.util.trustedroot import *
11 from geni.util.excep import *
12 from geni.util.misc import *
13 from geni.util.config import Config
14 from geni.util.rspec import Rspec
15 from geni.util.specdict import *
16 from geni.util.storage import SimpleStorage
17
class SliceMgr(GeniServer):
    """
    Slice manager: forwards slice operations to the aggregates listed in the
    'aggregates' file and keeps a local cache of the combined node list,
    the instantiated slices and the black/whitelist policy.
    """

    hrn = None
    nodes_ttl = None
    nodes = None
    slices = None
    policy = None
    aggregates = None
    timestamp = None
    threshold = None
    shell = None
    registry = None
    key_file = None
    cert_file = None
    credential = None

    ##
    # Create a new slice manager object.
    #
    # @param ip the ip address to listen on
    # @param port the port to listen on
    # @param key_file private key filename of registry
    # @param cert_file certificate filename containing public key (could be a GID file)
    # @param config path of the geni configuration file to load

    def __init__(self, ip, port, key_file, cert_file, config = "/usr/share/geniwrapper/geni/util/geni_config"):
        GeniServer.__init__(self, ip, port, key_file, cert_file)
        self.key_file = key_file
        self.cert_file = cert_file
        self.config = Config(config)
        self.basedir = self.config.GENI_BASE_DIR + os.sep
        # <GENI_BASE_DIR>/geni/ -- the trailing separator is kept because
        # other code may concatenate file names directly onto it
        self.server_basedir = os.path.join(self.basedir, "geni") + os.sep
        self.hrn = self.config.GENI_INTERFACE_HRN
        self.time_format = "%Y-%m-%d %H:%M:%S"

        # Get list of aggregates this sm talks to
        # XX do we use simplestorage to maintain this file manually?
        aggregates_file = os.path.join(self.server_basedir, 'aggregates')
        self.aggregates = SimpleStorage(aggregates_file)

        self.nodes = SimpleStorage(self._state_file('components'))
        self.nodes.load()

        self.slices = SimpleStorage(self._state_file('slices'))
        self.slices.load()

        self.policy = SimpleStorage(self._state_file('policy'), {'whitelist': [], 'blacklist': []})
        self.policy.load()

        # How long (in hours) before we refresh the nodes cache
        self.nodes_ttl = 1

        self.connectRegistry()
        self.loadCredential()
        self.connectAggregates(aggregates_file)

    def _state_file(self, suffix):
        """
        Return the path of a slice-manager state file:
        <server_basedir>/smgr.<hrn>.<suffix>
        """
        return os.path.join(self.server_basedir, 'smgr.' + self.hrn + '.' + suffix)

    def loadCredential(self):
        """
        Load this slice manager's 'sa' credential from its cache file. If the
        file cannot be read (IOError), fetch a fresh credential from the registry.
        """
        ma_cred_filename = self._state_file('sa.cred')
        try:
            self.credential = Credential(filename = ma_cred_filename)
        except IOError:
            # no usable cached credential: ask the registry for one
            self.credential = self.getCredentialFromRegistry()

    def getCredentialFromRegistry(self):
        """
        Get our current credential from the registry.

        First obtains a self credential (type 'ma') and caches it. Then uses it
        to obtain the 'sa' credential, which is cached in the file
        loadCredential reads and is returned.
        """
        # get self credential
        self_cred = self.registry.get_credential(None, 'ma', self.hrn)
        self_cred.save_to_file(self._state_file('cred'), save_parents=True)

        # get sa credential
        ma_cred = self.registry.get_credential(self_cred, 'sa', self.hrn)
        ma_cred.save_to_file(self._state_file('sa.cred'), save_parents=True)
        return ma_cred

    def connectRegistry(self):
        """
        Create a GeniClient connection to the registry named in the config
        (GENI_REGISTRY_HOSTNAME / GENI_REGISTRY_PORT).
        """
        address = self.config.GENI_REGISTRY_HOSTNAME
        port = self.config.GENI_REGISTRY_PORT
        url = 'http://%s:%s' % (address, port)
        self.registry = GeniClient(url, self.key_file, self.cert_file)

    def connectAggregates(self, aggregates_file):
        """
        Read the aggregates file and create a GeniClient connection to each
        aggregate listed in it.

        Each non-comment line must hold exactly three whitespace-separated
        fields, "<hrn> <address> <port>". Lines that don't are skipped.
        An IOError from opening or reading the file propagates to the caller.
        """
        f = open(aggregates_file, 'r')
        try:
            lines = f.readlines()
        finally:
            f.close()

        for line in lines:
            line = line.strip()
            # skip blank lines and comments
            if not line or line.startswith("#"):
                continue

            agg_info = line.split()
            # skip invalid info
            if len(agg_info) != 3:
                continue

            hrn, address, port = agg_info
            url = 'http://%s:%s' % (address, port)
            self.aggregates[hrn] = GeniClient(url, self.key_file, self.cert_file)

    def item_hrns(self, items):
        """
        Group a list of items (components or slices) by authority.

        Returns a dictionary with one key per known aggregate hrn. Each value
        is the list of items whose name starts with that hrn (possibly empty).
        """
        agg_hrns = list(self.aggregates.keys())
        item_hrns = dict((agg_hrn, []) for agg_hrn in agg_hrns)
        for item in items:
            for agg_hrn in agg_hrns:
                if item.startswith(agg_hrn):
                    item_hrns[agg_hrn].append(item)
        return item_hrns

    def hostname_to_hrn(self, login_base, hostname):
        """
        Convert a PlanetLab hostname at the given site (login_base) into an
        hrn: <self.hrn>.<login_base>.<hostname with '.' replaced by '_'>.
        """
        genihostname = "_".join(hostname.split("."))
        return ".".join([self.hrn, login_base, genihostname])

    def slicename_to_hrn(self, slicename):
        """
        Convert a PlanetLab slice name into an hrn: every '_' becomes '.',
        and the result is prefixed with this slice manager's hrn.
        """
        slicename = slicename.replace("_", ".")
        return ".".join([self.hrn, slicename])

    def _parse_time(self, hr_time):
        """Parse a timestamp string written with self.time_format (naive local time)."""
        return datetime.datetime.strptime(hr_time, self.time_format)

    @staticmethod
    def _to_epoch(dt):
        """
        Seconds since the epoch for a naive local datetime. This is the portable
        equivalent of int(dt.strftime("%s")), which only works on glibc.
        """
        return int(time.mktime(dt.timetuple()))

    def _cache_window(self):
        """Return (now, now + nodes_ttl hours) as naive local datetimes."""
        timestamp = datetime.datetime.now()
        return timestamp, timestamp + datetime.timedelta(hours=self.nodes_ttl)

    def refresh_components(self):
        """
        Update the cached list of nodes.

        Asks every aggregate for its nodes and merges their NetSpecs into one
        rspec. Applies the black/whitelist policy, then writes the rspec
        together with a new timestamp/threshold to the nodes cache file.
        Errors from an aggregate are reported and re-raised.
        """
        # start_time/duration of the rspec come from the previously recorded
        # cache window if there is one, otherwise from a window starting now
        if self.nodes.get('timestamp') and self.nodes.get('threshold'):
            timestamp = self._parse_time(self.nodes['timestamp'])
            threshold = self._parse_time(self.nodes['threshold'])
        else:
            timestamp, threshold = self._cache_window()

        start_time = self._to_epoch(timestamp)
        duration = self._to_epoch(threshold) - start_time

        networks = []
        rspec = Rspec()
        for aggregate in self.aggregates.keys():
            try:
                # get the rspec from the aggregate and extract its netspecs
                agg_rspec = self.aggregates[aggregate].list_nodes(self.credential)
                rspec.parseString(agg_rspec)
                networks.append({'NetSpec': rspec.getDictsByTagName('NetSpec')})
            except Exception:
                # XX print out to some error log
                print("Error calling list nodes at aggregate %s" % aggregate)
                raise

        # create the rspec dict and convert it to xml
        resources = {'networks': networks, 'start_time': start_time, 'duration': duration}
        rspec.parseDict({'Rspec': resources})

        # filter according to policy
        rspec.filter('NodeSpec', 'name', blacklist=self.policy['blacklist'], whitelist=self.policy['whitelist'])

        # start a new cache window
        timestamp, threshold = self._cache_window()
        nodedict = {'rspec': rspec.toxml(),
                    'timestamp': timestamp.strftime(self.time_format),
                    'threshold': threshold.strftime(self.time_format)}

        self.nodes = SimpleStorage(self.nodes.db_filename, nodedict)
        self.nodes.write()

    def load_policy(self):
        """
        Read the list of blacklisted and whitelisted nodes.
        """
        self.policy.load()

    def load_slices(self):
        """
        Read current slice instantiation states.
        """
        self.slices.load()

    def getNodes(self, format = 'rspec'):
        """
        Return the cached node list stored under the key `format` ('rspec' by
        default). The cache is refreshed first if it has never been filled or
        if its threshold time has passed.
        """
        threshold = self.nodes.get('threshold')
        if not threshold or not self.nodes.get('timestamp') \
           or datetime.datetime.now() > self._parse_time(threshold):
            self.refresh_components()
        return self.nodes[format]

    def getSlices(self):
        """
        Return the hrns of the slices instantiated at all known aggregates.
        An error from any aggregate propagates to the caller.
        """
        slice_hrns = []
        for aggregate in self.aggregates:
            # XX log failures to some error log before propagating them
            slice_hrns.extend(self.aggregates[aggregate].list_slices(self.credential))
        return slice_hrns

    def getResources(self, slice_hrn):
        """
        Return the current rspec (xml string) for the specified slice, merged
        from the NetSpecs reported by every known aggregate.
        """
        rspec = Rspec()
        networks = []
        for hrn in self.aggregates.keys():
            # ask this aggregate which of its resources the slice holds
            slice_resources = self.aggregates[hrn].get_resources(self.credential, slice_hrn)
            rspec.parseString(slice_resources)
            # append the dict itself (extend() would only add its key)
            networks.append({'NetSpec': rspec.getDictsByTagName('NetSpec')})

        # merge all these rspecs into one, starting now with zero duration
        start_time = self._to_epoch(datetime.datetime.now())
        resources = {'networks': networks, 'start_time': start_time, 'duration': 0}
        rspec.parseDict({'Rspec': resources})
        # XX should this slice's resources be saved in self.slices?
        return rspec.toxml()

    def createSlice(self, cred, slice_hrn, rspec):
        """
        Instantiate the specified slice according to what is defined in the rspec.

        The parsed rspec is saved locally. A separate rspec is then built for
        each known aggregate that has a NetSpec in the request, and that
        aggregate is asked to create the slice. Aggregates are contacted with
        this slice manager's own credential, not `cred`.

        @param cred caller's credential (currently unused here)
        @param slice_hrn hrn of the slice to create
        @param rspec rspec xml string describing the requested resources
        @return 1
        """
        # save slice state locally
        # we can assume that spec object has been validated so its safer to
        # save this instead of the unvalidated rspec the user gave us
        spec = Rspec()
        spec.parseString(rspec)

        self.slices[slice_hrn] = spec.toxml()
        self.slices.write()

        # extract network list from the rspec and create a separate
        # rspec for each network
        specDict = spec.toDict()
        # NOTE(review): toDict() appears to mirror parseDict's input and wrap
        # everything in a top-level 'Rspec' key; unwrap it when present
        specDict = specDict.get('Rspec', specDict)
        start_time = specDict.get('start_time', 0)
        end_time = specDict.get('end_time', 0)

        rspecs = {}
        # only attempt to extract information about the aggregates we know about
        for hrn in self.aggregates.keys():
            netspec = spec.getDictByTagNameValue('NetSpec', hrn)
            if netspec:
                # build an rspec holding only this aggregate's network
                resources = {'start_time': start_time, 'end_time': end_time, 'networks': netspec}
                tempspec = Rspec()
                tempspec.parseDict({'Rspec': resources})
                rspecs[hrn] = tempspec.toxml()

        # notify the aggregates
        for hrn, agg_rspec in rspecs.items():
            self.aggregates[hrn].create_slice(self.credential, slice_hrn, agg_rspec)

        return 1

    def updateSlice(self, slice_hrn, rspec, attributes = None):
        """
        Update the specified slice by re-running createSlice with the new rspec,
        using this slice manager's own credential.

        @param attributes currently unused; kept for interface compatibility
        @return the result of createSlice
        """
        return self.createSlice(self.credential, slice_hrn, rspec)

    def deleteSlice(self, slice_hrn):
        """
        Remove this slice from all components it was previously associated with
        and free up the resources it was using. The local slice record is
        dropped first, then every known aggregate is asked to delete the slice.
        """
        # XX need to get the correct credential
        cred = self.credential

        if slice_hrn in self.slices:
            self.slices.pop(slice_hrn)
            self.slices.write()

        for hrn in self.aggregates.keys():
            self.aggregates[hrn].delete_slice(cred, slice_hrn)

        return 1

    # backward-compatible alias for the previous method name
    deleteSlice_ = deleteSlice

    def startSlice(self, slice_hrn):
        """
        Start the slice at every known aggregate.
        """
        cred = self.credential
        for hrn in self.aggregates.keys():
            self.aggregates[hrn].start_slice(cred, slice_hrn)
        return 1

    def stopSlice(self, slice_hrn):
        """
        Stop the slice at every known aggregate.
        """
        cred = self.credential
        for hrn in self.aggregates.keys():
            self.aggregates[hrn].stop_slice(cred, slice_hrn)
        return 1

    def resetSlice(self, slice_hrn):
        """
        Reset the slice
        """
        # XX not yet implemented
        return 1

    def getPolicy(self):
        """
        Return the policy (black/whitelist storage) of this slice manager.
        """
        return self.policy

    ##############################
    ## Server methods here for now
    ##############################

    def list_nodes(self, cred):
        self.decode_authentication(cred, 'listnodes')
        return self.getNodes()

    def list_slices(self, cred):
        self.decode_authentication(cred, 'listslices')
        return self.getSlices()

    def get_resources(self, cred, hrn):
        self.decode_authentication(cred, 'listnodes')
        return self.getResources(hrn)

    def get_ticket(self, cred, hrn, rspec):
        # NOTE(review): getTicket is not defined in this class and get_ticket
        # is not registered in register_functions below
        self.decode_authentication(cred, 'getticket')
        return self.getTicket(hrn, rspec)

    def get_policy(self, cred):
        self.decode_authentication(cred, 'getpolicy')
        return self.getPolicy()

    def create_slice(self, cred, hrn, rspec):
        self.decode_authentication(cred, 'createslice')
        return self.createSlice(cred, hrn, rspec)

    def delete_slice(self, cred, hrn):
        self.decode_authentication(cred, 'deleteslice')
        return self.deleteSlice(hrn)

    def start_slice(self, cred, hrn):
        self.decode_authentication(cred, 'startslice')
        return self.startSlice(hrn)

    def stop_slice(self, cred, hrn):
        self.decode_authentication(cred, 'stopslice')
        return self.stopSlice(hrn)

    def reset_slice(self, cred, hrn):
        self.decode_authentication(cred, 'resetslice')
        return self.resetSlice(hrn)

    def register_functions(self):
        GeniServer.register_functions(self)

        # Aggregate interface methods
        self.server.register_function(self.list_nodes)
        self.server.register_function(self.list_slices)
        self.server.register_function(self.get_resources)
        self.server.register_function(self.get_policy)
        self.server.register_function(self.create_slice)
        self.server.register_function(self.delete_slice)
        self.server.register_function(self.start_slice)
        self.server.register_function(self.stop_slice)
        self.server.register_function(self.reset_slice)
465