911833bca8e508c82ac0c05a443a67ea0f8d3947
[sfa.git] / geni / slicemgr.py
1 import os
2 import sys
3 import datetime
4 import time
5
6 from geni.util.geniserver import *
7 from geni.util.geniclient import *
8 from geni.util.cert import *
9 from geni.util.credential import Credential
10 from geni.util.trustedroot import *
11 from geni.util.excep import *
12 from geni.util.misc import *
13 from geni.util.config import Config
14 from geni.util.rspec import Rspec
15 from geni.util.specdict import *
16 from geni.util.storage import SimpleStorage, XmlStorage
17
class SliceMgr(GeniServer):

    hrn = None
    nodes_ttl = None
    nodes = None
    slices = None
    policy = None
    aggregates = None
    timestamp = None
    threshold = None
    shell = None
    registry = None
    key_file = None
    cert_file = None
    credential = None

    ##
    # Create a new slice manager object.
    #
    # Loads the configuration and the on-disk caches (aggregate list, node
    # rspec, slice states, policy), connects to the registry, loads or fetches
    # this manager's credential and opens a client to every known aggregate.
    #
    # @param ip the ip address to listen on
    # @param port the port to listen on
    # @param key_file private key filename of this slice manager
    # @param cert_file certificate filename containing public key (could be a GID file)
    # @param config path of the geni configuration file

    def __init__(self, ip, port, key_file, cert_file, config = "/usr/share/geniwrapper/geni/util/geni_config"):
        GeniServer.__init__(self, ip, port, key_file, cert_file)
        self.key_file = key_file
        self.cert_file = cert_file
        self.config = Config(config)
        self.basedir = self.config.GENI_BASE_DIR + os.sep
        self.server_basedir = self.basedir + os.sep + "geni" + os.sep
        self.hrn = self.config.GENI_INTERFACE_HRN
        self.time_format = "%Y-%m-%d %H:%M:%S"

        # Get list of aggregates this sm talks to
        aggregates_file = self.server_basedir + os.sep + 'aggregates.xml'
        self.aggregate_info = XmlStorage(aggregates_file, {'aggregates': {'aggregate': []}})
        self.aggregate_info.load()

        # Get cached list of nodes (rspec)
        self.nodes = SimpleStorage(self._smgr_file('.components'))
        self.nodes.load()

        # Get cached slice states
        self.slices = SimpleStorage(self._smgr_file('.slices'))
        self.slices.load()

        # Get the policy
        self.policy = SimpleStorage(self._smgr_file('.policy'), {'whitelist': [], 'blacklist': []})
        self.policy.load()

        # How long (in hours) before we refresh the nodes cache
        self.nodes_ttl = 1

        self.connectRegistry()
        self.loadCredential()
        self.connectAggregates()

    def _smgr_file(self, suffix):
        """
        Return the path of a per-hrn file in the server directory:
        <server_basedir>/smgr.<hrn><suffix>.
        """
        return self.server_basedir + os.sep + "smgr." + self.hrn + suffix

    def loadCredential(self):
        """
        Load this slice manager's 'sa' credential from its cache file. If the
        file cannot be read (IOError), fetch the credential from the registry
        instead; that call also rewrites the cache file.
        """
        ma_cred_filename = self._smgr_file(".sa.cred")
        try:
            self.credential = Credential(filename = ma_cred_filename)
        except IOError:
            self.credential = self.getCredentialFromRegistry()

    def getCredentialFromRegistry(self):
        """
        Get our current credential from the registry.

        First requests a self credential (type 'ma') for this hrn. Then it uses
        that credential to request the 'sa' credential for the same hrn. Both
        are saved, with their parent chains, next to the other smgr files.

        @return the 'sa' credential
        """
        # get self credential
        self_cred_filename = self._smgr_file(".cred")
        self_cred = self.registry.get_credential(None, 'ma', self.hrn)
        self_cred.save_to_file(self_cred_filename, save_parents=True)

        # exchange it for the 'sa' credential this manager presents to aggregates
        sa_cred_filename = self._smgr_file(".sa.cred")
        sa_cred = self.registry.get_credential(self_cred, 'sa', self.hrn)
        sa_cred.save_to_file(sa_cred_filename, save_parents=True)
        return sa_cred

    def connectRegistry(self):
        """
        Create a GeniClient for the registry named in the configuration
        (GENI_REGISTRY_HOSTNAME / GENI_REGISTRY_PORT) and store it in
        self.registry.
        """
        address = self.config.GENI_REGISTRY_HOSTNAME
        port = self.config.GENI_REGISTRY_PORT
        url = 'http://%s:%s' % (address, port)
        self.registry = GeniClient(url, self.key_file, self.cert_file)

    def connectAggregates(self):
        """
        Create a GeniClient for every aggregate listed in aggregates.xml and
        store them in self.aggregates, keyed by aggregate hrn.

        Entries that lack an 'hrn', 'addr' or 'port' field are skipped.
        """
        self.aggregates = {}
        aggregates = self.aggregate_info['aggregates']['aggregate']
        # presumably XmlStorage yields a bare dict when the file holds exactly
        # one <aggregate> element - normalise to a list
        if isinstance(aggregates, dict):
            aggregates = [aggregates]
        if not isinstance(aggregates, list):
            return
        for aggregate in aggregates:
            try:
                hrn, address, port = aggregate['hrn'], aggregate['addr'], aggregate['port']
            except (KeyError, TypeError):
                # XX log the malformed aggregate entry
                continue
            url = 'http://%s:%s' % (address, port)
            self.aggregates[hrn] = GeniClient(url, self.key_file, self.cert_file)

    def item_hrns(self, items):
        """
        Group items (component or slice hrns) by the aggregate that is
        authoritative for them.

        @param items iterable of hrn strings
        @return dict mapping every known aggregate hrn to the list of items
                that fall under it (an empty list when none do)
        """
        item_hrns = {}
        for agg_hrn in self.aggregates.keys():
            item_hrns[agg_hrn] = []
        for item in items:
            for agg_hrn in item_hrns:
                # match whole hrn components so 'plc.a' does not claim 'plc.ab.x'
                if item == agg_hrn or item.startswith(agg_hrn + '.'):
                    item_hrns[agg_hrn].append(item)
        return item_hrns

    def hostname_to_hrn(self, login_base, hostname):
        """
        Convert a PlanetLab node hostname into an hrn under this manager's hrn.
        The dots in the hostname become underscores, e.g.
        ('princeton', 'node1.cs.princeton.edu') ->
        '<hrn>.princeton.node1_cs_princeton_edu'.
        """
        genihostname = "_".join(hostname.split("."))
        return ".".join([self.hrn, login_base, genihostname])

    def slicename_to_hrn(self, slicename):
        """
        Convert a PlanetLab slice name into an hrn under this manager's hrn.
        The underscores become dots, e.g. 'princeton_test' ->
        '<hrn>.princeton.test'.
        """
        slicename = slicename.replace("_", ".")
        return ".".join([self.hrn, slicename])

    def _parse_time(self, value):
        """Parse a cache timestamp string written with self.time_format."""
        return datetime.datetime.strptime(value, self.time_format)

    def _new_window(self):
        """Return (now, now + nodes_ttl hours) as naive local datetimes."""
        now = datetime.datetime.now()
        return now, now + datetime.timedelta(hours=self.nodes_ttl)

    def _cached_window(self):
        """
        Return the (timestamp, threshold) pair stored in the nodes cache. If
        either value is missing, return a fresh window starting now.
        """
        hr_timestamp = self.nodes.get('timestamp')
        hr_threshold = self.nodes.get('threshold')
        if hr_timestamp and hr_threshold:
            return self._parse_time(hr_timestamp), self._parse_time(hr_threshold)
        return self._new_window()

    @staticmethod
    def _epoch_seconds(dt):
        """
        Convert a naive local datetime to integer seconds since the epoch.

        Portable replacement for dt.strftime("%s"), which is a glibc extension
        and is not supported on every platform.
        """
        return int(time.mktime(dt.timetuple()))

    def refresh_components(self):
        """
        Update the cached list of nodes.

        The method asks every aggregate for its node rspec and collects each
        one's NetSpec elements into a single merged rspec. It then filters that
        rspec through the local black/white lists. Finally it writes the result
        to the nodes cache file, with a fresh timestamp and threshold
        (now + nodes_ttl hours).

        Errors from an aggregate are reported on stderr and re-raised.
        """
        # the merged rspec advertises the cached validity window, or a new one
        timestamp, threshold = self._cached_window()
        start_time = self._epoch_seconds(timestamp)
        duration = self._epoch_seconds(threshold) - start_time

        networks = []
        rspec = Rspec()
        for aggregate in self.aggregates.keys():
            try:
                # get the rspec from the aggregate and extract its netspec
                agg_rspec = self.aggregates[aggregate].list_nodes(self.credential)
                rspec.parseString(agg_rspec)
                networks.append({'NetSpec': rspec.getDictsByTagName('NetSpec')})
            except Exception:
                # XX send this to a real error log
                sys.stderr.write("Error calling list nodes at aggregate %s\n" % aggregate)
                raise

        # create the rspec dict and convert it to xml
        resources = {'networks': networks, 'start_time': start_time, 'duration': duration}
        rspec.parseDict({'Rspec': resources})

        # filter according to policy
        rspec.filter('NodeSpec', 'name', blacklist=self.policy['blacklist'], whitelist=self.policy['whitelist'])

        # update timestamp and threshold
        timestamp, threshold = self._new_window()
        nodedict = {'rspec': rspec.toxml(),
                    'timestamp': timestamp.strftime(self.time_format),
                    'threshold': threshold.strftime(self.time_format)}

        self.nodes = SimpleStorage(self.nodes.db_filename, nodedict)
        self.nodes.write()

    def load_policy(self):
        """
        Read the list of blacklisted and whitelisted nodes.
        """
        self.policy.load()

    def load_slices(self):
        """
        Read current slice instantiation states.
        """
        self.slices.load()

    def getNodes(self, format = 'rspec'):
        """
        Return a list of components managed by this slice manager.

        The nodes cache is rebuilt first if it has no timestamp or threshold,
        or if its threshold time has passed.

        @param format key of the cached entry to return ('rspec' by default)
        """
        hr_threshold = self.nodes.get('threshold')
        if not hr_threshold or not self.nodes.get('timestamp') \
                or datetime.datetime.now() > self._parse_time(hr_threshold):
            self.refresh_components()
        return self.nodes[format]

    def getSlices(self):
        """
        Return the hrns of all slices instantiated at the aggregates this
        slice manager knows about. The lists are concatenated, not
        de-duplicated.
        """
        slice_hrns = []
        for aggregate in self.aggregates.keys():
            slice_hrns.extend(self.aggregates[aggregate].list_slices(self.credential))
        return slice_hrns

    def getResources(self, slice_hrn):
        """
        Return the current rspec for the specified slice.

        The rspec merges the NetSpec elements reported by every known
        aggregate. It has start_time = now and duration = 0.
        """
        rspec = Rspec()
        networks = []
        for hrn in self.aggregates.keys():
            slice_resources = self.aggregates[hrn].get_resources(self.credential, slice_hrn)
            rspec.parseString(slice_resources)
            # append the dict itself; list.extend(dict) would add only its key
            networks.append({'NetSpec': rspec.getDictsByTagName('NetSpec')})

        start_time = self._epoch_seconds(datetime.datetime.now())
        resources = {'networks': networks, 'start_time': start_time, 'duration': 0}
        # convert rspec dict to xml
        rspec.parseDict({'Rspec': resources})
        return rspec.toxml()

    def createSlice(self, cred, slice_hrn, rspec):
        """
        Instantiate the specified slice according to what is defined in the rspec.

        The parsed rspec is saved in the local slice cache. It is then split
        into one rspec per known aggregate whose NetSpec it mentions, and each
        of those aggregates is asked to create the slice.

        @param cred caller's credential (not forwarded; aggregates receive
                    this manager's own credential)
        @param slice_hrn hrn of the slice to create
        @param rspec rspec xml string describing the requested resources
        @return 1
        """
        # parse the user's rspec; saving the re-serialised form is safer than
        # storing the raw, unvalidated string
        spec = Rspec()
        spec.parseString(rspec)

        self.slices[slice_hrn] = spec.toxml()
        self.slices.write()

        spec_dict = spec.toDict()
        # NOTE(review): toDict() presumably mirrors parseDict()'s
        # {'Rspec': {...}} shape; accept either form.
        if 'Rspec' in spec_dict:
            spec_dict = spec_dict['Rspec']
        start_time = spec_dict.get('start_time', 0)
        end_time = spec_dict.get('end_time', 0)

        # only attempt to extract information about the aggregates we know about
        rspecs = {}
        for hrn in self.aggregates.keys():
            netspec = spec.getDictByTagNameValue('NetSpec', hrn)
            if netspec:
                resources = {'start_time': start_time, 'end_time': end_time, 'networks': netspec}
                tempspec = Rspec()
                tempspec.parseDict({'Rspec': resources})
                rspecs[hrn] = tempspec.toxml()

        # notify the aggregates
        for hrn, agg_rspec in rspecs.items():
            self.aggregates[hrn].create_slice(self.credential, slice_hrn, agg_rspec)

        return 1

    def updateSlice(self, slice_hrn, rspec, attributes = None):
        """
        Update the specified slice by re-creating it from the new rspec.

        @param attributes accepted for interface compatibility; currently unused
        @return result of createSlice (1)
        """
        return self.createSlice(self.credential, slice_hrn, rspec)

    def deleteSlice(self, slice_hrn):
        """
        Remove this slice from all components it was previously associated
        with and free up the resources it was using.

        The slice is dropped from the local slice cache, and every known
        aggregate is asked to delete it.
        """
        # XX need to get the correct credential
        cred = self.credential

        if slice_hrn in self.slices:
            self.slices.pop(slice_hrn)
            self.slices.write()

        for hrn in self.aggregates.keys():
            self.aggregates[hrn].delete_slice(cred, slice_hrn)

        return 1

    # backward-compatible alias for the original (misspelled) method name
    deleteSlice_ = deleteSlice

    def startSlice(self, slice_hrn):
        """
        Start the slice at every known aggregate.
        """
        cred = self.credential
        for hrn in self.aggregates.keys():
            self.aggregates[hrn].start_slice(cred, slice_hrn)
        return 1

    def stopSlice(self, slice_hrn):
        """
        Stop the slice at every known aggregate.
        """
        cred = self.credential
        for hrn in self.aggregates.keys():
            self.aggregates[hrn].stop_slice(cred, slice_hrn)
        return 1

    def resetSlice(self, slice_hrn):
        """
        Reset the slice.
        """
        # XX not yet implemented
        return 1

    def getPolicy(self):
        """
        Return the policy (black/white lists) of this slice manager.
        """
        return self.policy

##############################
## Server methods here for now
##############################

    def list_nodes(self, cred):
        self.decode_authentication(cred, 'listnodes')
        return self.getNodes()

    def list_slices(self, cred):
        self.decode_authentication(cred, 'listslices')
        return self.getSlices()

    def get_resources(self, cred, hrn):
        self.decode_authentication(cred, 'listnodes')
        return self.getResources(hrn)

    def get_ticket(self, cred, hrn, rspec):
        self.decode_authentication(cred, 'getticket')
        # NOTE(review): getTicket is not defined in this class and get_ticket
        # is not registered below; this raises AttributeError unless a base
        # class provides getTicket.
        return self.getTicket(hrn, rspec)

    def get_policy(self, cred):
        self.decode_authentication(cred, 'getpolicy')
        return self.getPolicy()

    def create_slice(self, cred, hrn, rspec):
        self.decode_authentication(cred, 'createslice')
        return self.createSlice(cred, hrn, rspec)

    def delete_slice(self, cred, hrn):
        self.decode_authentication(cred, 'deleteslice')
        return self.deleteSlice(hrn)

    def start_slice(self, cred, hrn):
        self.decode_authentication(cred, 'startslice')
        return self.startSlice(hrn)

    def stop_slice(self, cred, hrn):
        self.decode_authentication(cred, 'stopslice')
        return self.stopSlice(hrn)

    def reset_slice(self, cred, hrn):
        self.decode_authentication(cred, 'resetslice')
        return self.resetSlice(hrn)

    def register_functions(self):
        """Register the aggregate-interface methods with the XML-RPC server."""
        GeniServer.register_functions(self)

        # Aggregate interface methods
        self.server.register_function(self.list_nodes)
        self.server.register_function(self.list_slices)
        self.server.register_function(self.get_resources)
        self.server.register_function(self.get_policy)
        self.server.register_function(self.create_slice)
        self.server.register_function(self.delete_slice)
        self.server.register_function(self.start_slice)
        self.server.register_function(self.stop_slice)
        self.server.register_function(self.reset_slice)
456