dac3856fc404f7a3355713752721b6346daef2bd
[sfa.git] / geni / slicemgr.py
1 import os
2 import sys
3 import datetime
4 import time
5
6 from geni.util.geniserver import *
7 from geni.util.geniclient import *
8 from geni.util.cert import *
9 from geni.util.credential import Credential
10 from geni.util.trustedroot import *
11 from geni.util.excep import *
12 from geni.util.misc import *
13 from geni.util.config import Config
14 from geni.util.rspec import Rspec
15 from geni.util.specdict import *
16 from geni.util.storage import SimpleStorage
17
class SliceMgr(GeniServer):
    """
    Slice manager server.

    Talks to the aggregates listed in the 'aggregates' file under the server
    base directory, merges their node and slice information, and forwards
    slice operations (create/delete/start/stop) to them.
    """

    hrn = None
    nodes_ttl = None
    nodes = None
    slices = None
    policy = None
    aggregates = None
    timestamp = None
    threshold = None
    shell = None
    registry = None
    key_file = None
    cert_file = None
    credential = None

    ##
    # Create a new slice manager object.
    #
    # @param ip the ip address to listen on
    # @param port the port to listen on
    # @param key_file private key filename used by this server and its clients
    # @param cert_file certificate filename containing public key (could be a GID file)
    # @param config path of the geni config file to load

    def __init__(self, ip, port, key_file, cert_file, config = "/usr/share/geniwrapper/geni/util/geni_config"):
        GeniServer.__init__(self, ip, port, key_file, cert_file)
        self.key_file = key_file
        self.cert_file = cert_file
        self.config = Config(config)
        self.basedir = self.config.GENI_BASE_DIR + os.sep
        self.server_basedir = self.basedir + os.sep + "geni" + os.sep
        self.hrn = self.config.GENI_INTERFACE_HRN
        # format used for the cached 'timestamp' / 'threshold' strings
        self.time_format = "%Y-%m-%d %H:%M:%S"

        # Aggregates this slice manager talks to, keyed by aggregate hrn; filled
        # with GeniClient connections by connectAggregates().
        # XX do we use simplestorage to maintain this file manually?
        aggregates_file = os.path.join(self.server_basedir, 'aggregates')
        self.aggregates = SimpleStorage(aggregates_file)

        nodes_file = os.path.join(self.server_basedir, 'smgr.' + self.hrn + '.components')
        self.nodes = SimpleStorage(nodes_file)
        self.nodes.load()

        slices_file = os.path.join(self.server_basedir, 'smgr.' + self.hrn + '.slices')
        self.slices = SimpleStorage(slices_file)
        self.slices.load()

        policy_file = os.path.join(self.server_basedir, 'smgr.' + self.hrn + '.policy')
        self.policy = SimpleStorage(policy_file, {'whitelist': [], 'blacklist': []})
        self.policy.load()

        # How long (in hours) the cached node list stays valid.
        self.nodes_ttl = 1

        self.connectRegistry()
        self.loadCredential()
        self.connectAggregates(aggregates_file)

    def loadCredential(self):
        """
        Load this slice manager's 'sa' credential from its cache file.

        If the cache file cannot be read (IOError), first obtain a self
        credential from the registry, use it to obtain the 'sa' credential,
        and save both to the server base directory.
        """
        self_cred_filename = os.path.join(self.server_basedir, "smgr." + self.hrn + ".cred")
        sa_cred_filename = os.path.join(self.server_basedir, "smgr." + self.hrn + ".sa.cred")

        try:
            self.credential = Credential(filename = sa_cred_filename)
        except IOError:
            # bootstrap: get a self credential (requested with type 'ma') ...
            self_cred = self.registry.get_credential(None, 'ma', self.hrn)
            self_cred.save_to_file(self_cred_filename, save_parents=True)

            # ... and use it to obtain the 'sa' credential we act with
            sa_cred = self.registry.get_credential(self_cred, 'sa', self.hrn)
            sa_cred.save_to_file(sa_cred_filename, save_parents=True)
            self.credential = sa_cred

    def connectRegistry(self):
        """
        Create a GeniClient connection to the registry named in the config.
        """
        address = self.config.GENI_REGISTRY_HOSTNAME
        port = self.config.GENI_REGISTRY_PORT
        url = 'http://%(address)s:%(port)s' % locals()
        self.registry = GeniClient(url, self.key_file, self.cert_file)

    def connectAggregates(self, aggregates_file):
        """
        Read aggregate descriptions from aggregates_file and create a
        GeniClient connection to each.

        Each non-comment line must hold exactly three whitespace-separated
        fields: "<hrn> <address> <port>". Blank lines, lines starting with '#'
        and malformed lines are skipped. An unreadable file raises IOError.
        """
        f = open(aggregates_file, 'r')
        try:
            lines = f.readlines()
        finally:
            f.close()

        for line in lines:
            line = line.strip()
            # skip blank lines and comments
            if not line or line.startswith("#"):
                continue

            # split() treats any run of spaces/tabs as a single separator
            agg_info = line.split()
            if len(agg_info) != 3:
                continue

            hrn, address, port = agg_info
            url = 'http://%(address)s:%(port)s' % locals()
            self.aggregates[hrn] = GeniClient(url, self.key_file, self.cert_file)

    def item_hrns(self, items):
        """
        Group items (component or slice hrns) by the aggregate they belong to.

        Returns a dict mapping every known aggregate hrn to the list of items
        that start with that hrn (an empty list when none do). An item that
        prefix-matches several aggregate hrns is listed under each of them.
        """
        agg_hrns = list(self.aggregates.keys())
        grouped = {}
        for agg_hrn in agg_hrns:
            grouped[agg_hrn] = []
        for item in items:
            for agg_hrn in agg_hrns:
                if item.startswith(agg_hrn):
                    grouped[agg_hrn].append(item)
        return grouped

    def hostname_to_hrn(self, login_base, hostname):
        """
        Convert a PlanetLab site login_base and node hostname into an hrn
        under this slice manager's hrn; dots in the hostname become '_'.
        """
        genihostname = "_".join(hostname.split("."))
        return ".".join([self.hrn, login_base, genihostname])

    def slicename_to_hrn(self, slicename):
        """
        Convert a PlanetLab slice name into an hrn under this slice manager's
        hrn; every '_' in the slice name becomes '.'.
        """
        slicename = slicename.replace("_", ".")
        return ".".join([self.hrn, slicename])

    def _parse_time(self, value):
        """Parse a string written with self.time_format into a naive datetime."""
        return datetime.datetime.strptime(value, self.time_format)

    @staticmethod
    def _to_unix(dt):
        """Convert a naive local datetime into integer seconds since the epoch."""
        return int(time.mktime(dt.timetuple()))

    def _new_cache_window(self):
        """Return (now, now + nodes_ttl hours) as naive local datetimes."""
        now = datetime.datetime.now()
        return now, now + datetime.timedelta(hours=self.nodes_ttl)

    def refresh_components(self):
        """
        Rebuild the cached node list.

        Calls list_nodes() on every known aggregate, merges the NetSpecs of
        their rspecs into one rspec, filters NodeSpecs by the policy
        black/whitelists and stores the resulting XML in self.nodes together
        with a fresh timestamp and an expiry threshold nodes_ttl hours ahead.
        Any aggregate error is printed and re-raised.
        """
        # start_time/duration come from the previous cache window when there
        # is one, otherwise from a new window starting now
        if self.nodes.get('timestamp'):
            timestamp = self._parse_time(self.nodes['timestamp'])
            threshold = self._parse_time(self.nodes['threshold'])
        else:
            timestamp, threshold = self._new_cache_window()
        start_time = self._to_unix(timestamp)
        duration = self._to_unix(threshold) - start_time

        rspec = Rspec()
        networks = []
        for aggregate in self.aggregates.keys():
            try:
                # get the rspec from the aggregate and keep its NetSpecs
                agg_rspec = self.aggregates[aggregate].list_nodes(self.credential)
                rspec.parseString(agg_rspec)
                networks.append({'NetSpec': rspec.getDictsByTagName('NetSpec')})
            except Exception:
                # XX print out to some error log
                print("Error calling list nodes at aggregate %s" % aggregate)
                raise

        # build the merged rspec and convert it to xml
        resources = {'networks': networks, 'start_time': start_time, 'duration': duration}
        rspec.parseDict({'Rspec': resources})

        # filter according to policy
        rspec.filter('NodeSpec', 'name', blacklist=self.policy['blacklist'], whitelist=self.policy['whitelist'])

        # store the result with a new timestamp and expiry threshold
        timestamp, threshold = self._new_cache_window()
        nodedict = {'rspec': rspec.toxml(),
                    'timestamp': timestamp.strftime(self.time_format),
                    'threshold': threshold.strftime(self.time_format)}

        self.nodes = SimpleStorage(self.nodes.db_filename, nodedict)
        self.nodes.write()

    def load_policy(self):
        """
        Reload the black/whitelisted nodes from the policy file.
        """
        self.policy.load()

    def load_slices(self):
        """
        Reload the saved slice states from the slices file.
        """
        self.slices.load()

    def getNodes(self, format = 'rspec'):
        """
        Return the cached node list under key `format` (default 'rspec'),
        refreshing the cache first if it is missing or past its threshold.
        """
        timestamp = self.nodes.get('timestamp')
        threshold = self.nodes.get('threshold')
        if not timestamp or not threshold or datetime.datetime.now() > self._parse_time(threshold):
            self.refresh_components()
        return self.nodes[format]

    def getSlices(self):
        """
        Return the concatenated list_slices() results of all known
        aggregates. Errors from any aggregate propagate to the caller.
        """
        slice_hrns = []
        for aggregate in self.aggregates.keys():
            slice_hrns.extend(self.aggregates[aggregate].list_slices(self.credential))
        return slice_hrns

    def getResources(self, slice_hrn):
        """
        Return the current rspec (as xml) for the specified slice, merged from
        the get_resources() results of all known aggregates.
        """
        rspec = Rspec()
        networks = []
        for hrn in self.aggregates.keys():
            slice_resources = self.aggregates[hrn].get_resources(self.credential, slice_hrn)
            rspec.parseString(slice_resources)
            # append one {'NetSpec': ...} entry per aggregate
            networks.append({'NetSpec': rspec.getDictsByTagName('NetSpec')})

        # the merged rspec starts now and has zero duration
        start_time = self._to_unix(datetime.datetime.now())
        resources = {'networks': networks, 'start_time': start_time, 'duration': 0}
        rspec.parseDict({'Rspec': resources})
        # XX the merged rspec is not saved to self.slices
        return rspec.toxml()

    def createSlice(self, cred, slice_hrn, rspec):
        """
        Instantiate the specified slice according to what the rspec defines.

        The parsed rspec is saved in self.slices; then, for every known
        aggregate whose hrn has a NetSpec in the rspec, a per-aggregate rspec
        is built and sent to that aggregate's create_slice(). Returns 1.
        """
        # Parse into a separate object so the input string is not shadowed.
        spec = Rspec()
        spec.parseString(rspec)

        # save the parsed (validated) rspec rather than the raw user input
        self.slices[slice_hrn] = spec.toxml()
        self.slices.write()

        specDict = spec.toDict()
        # NOTE(review): assumes toDict() exposes 'start_time' and 'end_time'
        # at the top level — confirm against Rspec.toDict
        start_time = specDict['start_time']
        end_time = specDict['end_time']

        # only extract information about the aggregates we know about
        rspecs = {}
        for hrn in self.aggregates.keys():
            netspec = spec.getDictByTagNameValue('NetSpec', hrn)
            if netspec:
                resources = {'start_time': start_time, 'end_time': end_time, 'networks': netspec}
                tempspec = Rspec()
                tempspec.parseDict({'Rspec': resources})
                rspecs[hrn] = tempspec.toxml()

        # notify the aggregates
        for hrn in rspecs.keys():
            self.aggregates[hrn].create_slice(self.credential, slice_hrn, rspecs[hrn])

        return 1

    def updateSlice(self, slice_hrn, rspec, attributes = None):
        """
        Update the specified slice by re-running createSlice() with the new
        rspec. `attributes` is accepted but currently ignored. Returns 1.
        """
        return self.createSlice(self.credential, slice_hrn, rspec)

    def deleteSlice(self, slice_hrn):
        """
        Remove this slice from the local slice store and ask every known
        aggregate to delete it. Returns 1.
        """
        # XX need to get the correct credential
        cred = self.credential

        if slice_hrn in self.slices:
            self.slices.pop(slice_hrn)
            self.slices.write()

        for hrn in self.aggregates.keys():
            self.aggregates[hrn].delete_slice(cred, slice_hrn)

        return 1

    # Backward-compatible alias for the former method name.
    deleteSlice_ = deleteSlice

    def startSlice(self, slice_hrn):
        """
        Ask every known aggregate to start the slice. Returns 1.
        """
        cred = self.credential
        for hrn in self.aggregates.keys():
            self.aggregates[hrn].start_slice(cred, slice_hrn)
        return 1

    def stopSlice(self, slice_hrn):
        """
        Ask every known aggregate to stop the slice. Returns 1.
        """
        cred = self.credential
        for hrn in self.aggregates.keys():
            self.aggregates[hrn].stop_slice(cred, slice_hrn)
        return 1

    def resetSlice(self, slice_hrn):
        """
        Reset the slice.
        """
        # XX not yet implemented
        return 1

    def getPolicy(self):
        """
        Return the policy (black/whitelists) as a plain dict, which
        xmlrpclib can marshal (it refuses dict subclasses such as
        SimpleStorage).
        """
        return dict(self.policy)

    ##############################
    ## Server methods here for now
    ##############################

    def list_nodes(self, cred):
        self.decode_authentication(cred, 'listnodes')
        return self.getNodes()

    def list_slices(self, cred):
        self.decode_authentication(cred, 'listslices')
        return self.getSlices()

    def get_resources(self, cred, hrn):
        self.decode_authentication(cred, 'listnodes')
        return self.getResources(hrn)

    def get_ticket(self, cred, hrn, rspec):
        # NOTE(review): no getTicket method is defined in this class and
        # get_ticket is not registered below — confirm before exposing it.
        self.decode_authentication(cred, 'getticket')
        return self.getTicket(hrn, rspec)

    def get_policy(self, cred):
        self.decode_authentication(cred, 'getpolicy')
        return self.getPolicy()

    def create_slice(self, cred, hrn, rspec):
        self.decode_authentication(cred, 'createslice')
        return self.createSlice(cred, hrn, rspec)

    def delete_slice(self, cred, hrn):
        self.decode_authentication(cred, 'deleteslice')
        return self.deleteSlice(hrn)

    def start_slice(self, cred, hrn):
        self.decode_authentication(cred, 'startslice')
        return self.startSlice(hrn)

    def stop_slice(self, cred, hrn):
        self.decode_authentication(cred, 'stopslice')
        return self.stopSlice(hrn)

    def reset_slice(self, cred, hrn):
        self.decode_authentication(cred, 'resetslice')
        return self.resetSlice(hrn)

    def register_functions(self):
        """Register the aggregate-interface methods with the XML-RPC server."""
        GeniServer.register_functions(self)

        # Aggregate interface methods
        self.server.register_function(self.list_nodes)
        self.server.register_function(self.list_slices)
        self.server.register_function(self.get_resources)
        self.server.register_function(self.get_policy)
        self.server.register_function(self.create_slice)
        self.server.register_function(self.delete_slice)
        self.server.register_function(self.start_slice)
        self.server.register_function(self.stop_slice)
        self.server.register_function(self.reset_slice)
458