da490391c724312e121cf00933b5a70820b69cc5
[sfa.git] / geni / slicemgr.py
1 import os
2 import sys
3 import datetime
4 import time
5
6 from geni.util.geniserver import *
7 from geni.util.geniclient import *
8 from geni.util.cert import *
9 from geni.util.credential import Credential
10 from geni.util.trustedroot import *
11 from geni.util.excep import *
12 from geni.util.misc import *
13 from geni.util.config import Config
14 from geni.util.rspec import Rspec
15 from geni.util.specdict import *
16 from geni.util.storage import SimpleStorage
17
class SliceMgr(GeniServer):
    """
    Slice manager server: a GeniServer that connects to one registry and to
    the aggregates listed in its 'aggregates' file, and relays node, slice
    and policy requests to them.
    """

    # Class-level defaults; the instance values are filled in by __init__
    # and the connect*/load* methods.

    # hrn of this interface (read from config.GENI_INTERFACE_HRN)
    hrn = None
    # hours before the cached node list is considered stale
    nodes_ttl = None
    # SimpleStorage holding the cached node rspec, 'timestamp' and 'threshold'
    nodes = None
    # SimpleStorage mapping slice hrn -> rspec xml saved by createSlice
    slices = None
    # SimpleStorage holding the 'whitelist' and 'blacklist' lists
    policy = None
    # storage mapping aggregate hrn -> GeniClient connection
    aggregates = None
    # NOTE(review): timestamp, threshold and shell are never assigned by the
    # methods of this class -- confirm whether they are still needed
    timestamp = None
    threshold = None    
    shell = None
    # GeniClient connection to the registry
    registry = None
    # private key / certificate filenames used for every outgoing connection
    key_file = None
    cert_file = None
    # credential presented to the aggregates (see loadCredential)
    credential = None 
33   
34     ##
35     # Create a new slice manager object.
36     #
37     # @param ip the ip address to listen on
38     # @param port the port to listen on
39     # @param key_file private key filename of registry
40     # @param cert_file certificate filename containing public key (could be a GID file)     
41
42     def __init__(self, ip, port, key_file, cert_file, config = "/usr/share/geniwrapper/geni/util/geni_config"):
43         GeniServer.__init__(self, ip, port, key_file, cert_file)
44         self.key_file = key_file
45         self.cert_file = cert_file
46         self.config = Config(config)
47         self.basedir = self.config.GENI_BASE_DIR + os.sep
48         self.server_basedir = self.basedir + os.sep + "geni" + os.sep
49         self.hrn = self.config.GENI_INTERFACE_HRN    
50         self.time_format = "%Y-%m-%d %H:%M:%S"
51         
52         # Get list of aggregates this sm talks to
53         # XX do we use simplestorage to maintain this file manually?
54         aggregates_file = self.server_basedir + os.sep + 'aggregates'
55         self.aggregates = SimpleStorage(aggregates_file)
56         
57         nodes_file = os.sep.join([self.server_basedir, 'smgr.' + self.hrn + '.components'])
58         self.nodes = SimpleStorage(nodes_file)
59         self.nodes.load()
60         
61         slices_file = os.sep.join([self.server_basedir, 'smgr' + self.hrn + '.slices'])
62         self.slices = SimpleStorage(slices_file)
63         self.slices.load()
64
65         policy_file = os.sep.join([self.server_basedir, 'smgr.' + self.hrn + '.policy'])
66         self.policy = SimpleStorage(policy_file, {'whitelist': [], 'blacklist': []})
67         self.policy.load()
68
69         # How long before we refresh nodes cache  
70         self.nodes_ttl = 1
71
72         self.connectRegistry()
73         self.loadCredential()
74         self.connectAggregates(aggregates_file)
75
76
77     def loadCredential(self):
78         """
79         Attempt to load credential from file if it exists. If it doesnt get
80         credential from registry.
81         """
82
83         self_cred_filename = self.server_basedir + os.sep + "smgr." + self.hrn + ".cred"
84         ma_cred_filename = self.server_basedir + os.sep + "smgr." + self.hrn + ".sa.cred"
85         
86         # see if this file exists
87         try:
88             self.credential = Credential(filename = ma_cred_filename)
89         except IOError:
90             # get self credential
91             self_cred = self.registry.get_credential(None, 'ma', self.hrn)
92             self_cred.save_to_file(self_cred_filename, save_parents=True)
93
94             # get ma credential
95             ma_cred = self.registry.get_credential(self_cred, 'sa', self.hrn)
96             ma_cred.save_to_file(ma_cred_filename, save_parents=True)
97             self.credential = ma_cred
98
99     def connectRegistry(self):
100         """
101         Connect to the registry
102         """
103         # connect to registry using GeniClient
104         address = self.config.GENI_REGISTRY_HOSTNAME
105         port = self.config.GENI_REGISTRY_PORT
106         url = 'http://%(address)s:%(port)s' % locals()
107         self.registry = GeniClient(url, self.key_file, self.cert_file)
108
109     def connectAggregates(self, aggregates_file):
110         """
111         Get info about the aggregates available to us from file and create 
112         an xmlrpc connection to each. If any info is invalid, skip it. 
113         """
114         lines = []
115         try:
116             f = open(aggregates_file, 'r')
117             lines = f.readlines()
118             f.close()
119         except: raise 
120         
121         for line in lines:
122             # Skip comments
123             if line.strip().startswith("#"):
124                 continue
125             line = line.replace("\t", " ").replace("\n", "").replace("\r", "").strip()
126             agg_info = line.split(" ")
127         
128             # skip invalid info
129             if len(agg_info) != 3:
130                 continue
131
132             # create xmlrpc connection using GeniClient
133             hrn, address, port = agg_info[0], agg_info[1], agg_info[2]
134             url = 'http://%(address)s:%(port)s' % locals()
135             self.aggregates[hrn] = GeniClient(url, self.key_file, self.cert_file)
136
137     def item_hrns(self, items):
138         """
139         Take a list of items (components or slices) and return a dictionary where
140         the key is the authoritative hrn and the value is a list of items at that 
141         hrn.
142         """
143         item_hrns = {}
144         agg_hrns = self.aggregates.keys()
145         for agg_hrn in agg_hrns:
146             item_hrns[agg_hrn] = []
147         for item in items:
148             for agg_hrn in agg_hrns:
149                 if item.startswith(agg_hrn):
150                     item_hrns[agg_hrn] = item
151
152         return item_hrns    
153              
154
155     def hostname_to_hrn(self, login_base, hostname):
156         """
157         Convert hrn to plantelab name.
158         """
159         genihostname = "_".join(hostname.split("."))
160         return ".".join([self.hrn, login_base, genihostname])
161
162     def slicename_to_hrn(self, slicename):
163         """
164         Convert hrn to planetlab name.
165         """
166         slicename = slicename.replace("_", ".")
167         return ".".join([self.hrn, slicename])
168
169     def refresh_components(self):
170         """
171         Update the cached list of nodes.
172         """
173     
174         # convert and threshold to ints
175         if self.nodes.has_key('timestamp') and self.nodes['timestamp']:
176             hr_timestamp = self.nodes['timestamp']
177             timestamp = datetime.datetime.fromtimestamp(time.mktime(time.strptime(hr_timestamp, self.time_format)))
178             hr_threshold = self.nodes['threshold']
179             threshold = datetime.datetime.fromtimestamp(time.mktime(time.strptime(hr_threshold, self.time_format)))
180         else:
181             timestamp = datetime.datetime.now()
182             hr_timestamp = timestamp.strftime(self.time_format)
183             delta = datetime.timedelta(hours=self.nodes_ttl)
184             threshold = timestamp + delta
185             hr_threshold = threshold.strftime(self.time_format)
186
187         start_time = int(timestamp.strftime("%s"))
188         end_time = int(threshold.strftime("%s"))
189         duration = end_time - start_time
190
191         aggregates = self.aggregates.keys()
192         rspecs = {}
193         networks = []
194         rspec = Rspec()
195         for aggregate in aggregates:
196             try:
197                 # get the rspec from the aggregate
198                 agg_server = self.aggregates[aggregate]
199                 agg_rspec = self.aggregates[aggregate].list_nodes(self.credential)
200                 
201                 # extract the netspec from each aggregates rspec
202                 rspec.parseString(agg_rspec)
203                 networks.extend({'NetSpec': rspec.getDictsByTagName('NetSpec')})
204             except:
205                 # XX print out to some error log
206                 print "Error calling list nodes at aggregate %s" % aggregate
207                 raise    
208    
209         # create the rspec dict
210         resources = {'networks': networks, 'start_time': start_time, 'duration': duration}
211         resourceDict = {'Rspec': resources} 
212         # convert rspec dict to xml
213         rspec.parseDict(resourceDict)
214        
215         # filter according to policy
216         rspec.filter('NodeSpec', 'name', blacklist=self.policy['blacklist'], whitelist=self.policy['whitelist'])
217
218         # update timestamp and threshold
219         timestamp = datetime.datetime.now()
220         hr_timestamp = timestamp.strftime(self.time_format)
221         delta = datetime.timedelta(hours=self.nodes_ttl)
222         threshold = timestamp + delta
223         hr_threshold = threshold.strftime(self.time_format)
224         
225         nodedict = {'rspec': rspec.toxml(),
226                     'timestamp': hr_timestamp,
227                     'threshold':  hr_threshold}
228
229         self.nodes = SimpleStorage(self.nodes.db_filename, nodedict)
230         self.nodes.write()
231
232     def load_policy(self):
233         """
234         Read the list of blacklisted and whitelisted nodes.
235         """
236         self.policy.load()
237  
238     def load_slices(self):
239         """
240         Read current slice instantiation states.
241         """
242         self.slices.load()
243
244
245     def getNodes(self, format = 'rspec'):
246         """
247         Return a list of components managed by this slice manager.
248         """
249         # Reload components list
250         if not self.nodes.has_key('threshold') or not self.nodes['threshold'] or not self.nodes.has_key('timestamp') or not self.nodes['timestamp']:
251             self.refresh_components()
252         else:
253             now = datetime.datetime.now()
254             threshold = datetime.datetime.fromtimestamp(time.mktime(time.strptime(self.nodes['threshold'], self.time_format)))
255             if  now > threshold:
256                 self.refresh_components()
257         return self.nodes[format]
258    
259      
260     def getSlices(self):
261         """
262         Return a list of instnatiated managed by this slice manager.
263         """
264         slice_hrns = []
265         for aggregate in self.aggregates:
266             try:
267                 slices = self.aggregates[aggregate].list_slices(self.credential)
268                 slice_hrns.extend(slices)
269             except:
270                 raise
271                 # print to some error log
272                 pass
273
274         return slice_hrns
275
276     def getResources(self, slice_hrn):
277         """
278         Return the current rspec for the specified slice.
279         """
280
281         # request this slices state from all known aggregates
282         rspec = Rspec()
283         rspecdicts = []
284         networks = []
285         for hrn in self.aggregates.keys():
286             # check if the slice has resources at this hrn
287             slice_resources = self.aggregates[hrn].get_resources(self.credential, slice_hrn)
288             rspec.parseString(slice_resources)
289             networks.extend({'NetSpec': rspec.getDictsByTagName('NetSpec')})
290             
291         # merge all these rspecs into one
292         start_time = int(datetime.datetime.now().strftime("%s"))
293         end_time = start_time
294         duration = end_time - start_time
295     
296         resources = {'networks': networks, 'start_time': start_time, 'duration': duration}
297         resourceDict = {'Rspec': resources}
298         # convert rspec dict to xml
299         rspec.parseDict(resourceDict)
300         # save this slices resources
301         #self.slices[slice_hrn] = rspec.toxml()
302         #self.slices.write()
303          
304         return rspec.toxml()
305  
306     def createSlice(self, cred, slice_hrn, rspec):
307         """
308         Instantiate the specified slice according to whats defined in the rspec.
309         """
310
311         # save slice state locally
312         # we can assume that spec object has been validated so its safer to
313         # save this instead of the unvalidated rspec the user gave us
314         rspec = Rspec()
315         tempspec = Rspec()
316         rspec.parseString(rspec)
317
318         self.slices[slice_hrn] = rspec.toxml()
319         self.slices.write()
320
321         # extract network list from the rspec and create a separate
322         # rspec for each network
323         slicename = self.hrn_to_plcslicename(slice_hrn)
324         specDict = rspec.toDict()
325         start_time = specDict['start_time']
326         end_time = specDict['end_time']
327
328         rspecs = {}
329         # only attempt to extract information about the aggregates we know about
330         for hrn in self.aggregates.keys():
331             netspec = spec.getDictByTagNameValue('NetSpec', hrn)
332             if netspec:
333                 # creat a plc dict 
334                 resources = {'start_time': star_time, 'end_time': end_time, 'networks': netspec}
335                 resourceDict = {'Rspec': resources}
336                 tempspec.parseDict(resourceDict)
337                 rspecs[hrn] = tempspec.toxml()
338
339         # notify the aggregates
340         for hrn in self.rspecs.keys():
341             self.aggregates[hrn].createSlice(self.credential, rspecs[hrn])
342             
343         return 1
344
345     def updateSlice(self, slice_hrn, rspec, attributes = []):
346         """
347         Update the specifed slice
348         """
349         self.create_slice(slice_hrn, rspec, attributes)
350     
351     def deleteSlice_(self, slice_hrn):
352         """
353         Remove this slice from all components it was previouly associated with and 
354         free up the resources it was using.
355         """
356         # XX need to get the correct credential
357         cred = self.credential
358         
359         if self.slices.has_key(slice_hrn):
360             self.slices.pop(slice_hrn)
361             self.slices.write()
362
363         for hrn in self.aggregates.keys():
364             self.aggregates[hrn].deleteSlice(cred, slice_hrn)
365
366         return 1
367
368     def startSlice(self, slice_hrn):
369         """
370         Stop the slice at plc.
371         """
372         cred = self.credential
373
374         for hrn in self.aggregates.keys():
375             self.aggregates[hrn].startSlice(cred, slice_hrn)
376         return 1
377
378     def stopSlice(self, slice_hrn):
379         """
380         Stop the slice at plc
381         """
382         cred = self.credential
383         for hrn in self.aggregates.keys():
384             self.aggregates[hrn].startSlice(cred, slice_hrn)
385         return 1
386
387     def resetSlice(self, slice_hrn):
388         """
389         Reset the slice
390         """
391         # XX not yet implemented
392         return 1
393
394     def getPolicy(self):
395         """
396         Return the policy of this slice manager.
397         """
398     
399         return self.policy
400         
401     
402
403 ##############################
404 ## Server methods here for now
405 ##############################
406
407     def list_nodes(self, cred):
408         self.decode_authentication(cred, 'listnodes')
409         return self.getNodes()
410
411     def list_slices(self, cred):
412         self.decode_authentication(cred, 'listslices')
413         return self.getSlices()
414
415     def get_resources(self, cred, hrn):
416         self.decode_authentication(cred, 'listnodes')
417         return self.getResources(hrn)
418
419     def get_ticket(self, cred, hrn, rspec):
420         self.decode_authentication(cred, 'getticket')
421         return self.getTicket(hrn, rspec)
422
423     def get_policy(self, cred):
424         self.decode_authentication(cred, 'getpolicy')
425         return self.getPolicy()
426
427     def create_slice(self, cred, hrn, rspec):
428         self.decode_authentication(cred, 'creatslice')
429         return self.createSlice(cred, hrn, rspec)
430
431     def delete_slice(self, cred, hrn):
432         self.decode_authentication(cred, 'deleteslice')
433         return self.deleteSlice(hrn)
434
435     def start_slice(self, cred, hrn):
436         self.decode_authentication(cred, 'startslice')
437         return self.startSlice(hrn)
438
439     def stop_slice(self, cred, hrn):
440         self.decode_authentication(cred, 'stopslice')
441         return self.stopSlice(hrn)
442
443     def reset_slice(self, cred, hrn):
444         self.decode_authentication(cred, 'resetslice')
445         return self.resetSlice(hrn)
446
447     def register_functions(self):
448         GeniServer.register_functions(self)
449
450         # Aggregate interface methods
451         self.server.register_function(self.list_nodes)
452         self.server.register_function(self.list_slices)
453         self.server.register_function(self.get_resources)
454         self.server.register_function(self.get_policy)
455         self.server.register_function(self.create_slice)
456         self.server.register_function(self.delete_slice)
457         self.server.register_function(self.start_slice)
458         self.server.register_function(self.stop_slice)
459         self.server.register_function(self.reset_slice)
460