fix bug when removing researchers
[sfa.git] / geni / slicemgr.py
1 import os
2 import sys
3 import datetime
4 import time
5
6 from geni.util.geniserver import *
7 from geni.util.geniclient import *
8 from geni.util.cert import *
9 from geni.util.credential import Credential
10 from geni.util.trustedroot import *
11 from geni.util.excep import *
12 from geni.util.misc import *
13 from geni.util.config import Config
14 from geni.util.rspec import Rspec
15 from geni.util.specdict import *
16 from geni.util.storage import SimpleStorage
17
18 class SliceMgr(GeniServer):
19
20     hrn = None
21     nodes_ttl = None
22     nodes = None
23     slices = None
24     policy = None
25     aggregates = None
26     timestamp = None
27     threshold = None    
28     shell = None
29     registry = None
30     key_file = None
31     cert_file = None
32     credential = None 
33   
34     ##
35     # Create a new slice manager object.
36     #
37     # @param ip the ip address to listen on
38     # @param port the port to listen on
39     # @param key_file private key filename of registry
40     # @param cert_file certificate filename containing public key (could be a GID file)     
41
42     def __init__(self, ip, port, key_file, cert_file, config = "/usr/share/geniwrapper/geni/util/geni_config"):
43         GeniServer.__init__(self, ip, port, key_file, cert_file)
44         self.key_file = key_file
45         self.cert_file = cert_file
46         self.config = Config(config)
47         self.basedir = self.config.GENI_BASE_DIR + os.sep
48         self.server_basedir = self.basedir + os.sep + "geni" + os.sep
49         self.hrn = self.config.GENI_INTERFACE_HRN    
50
51         # Get list of aggregates this sm talks to
52         # XX do we use simplestorage to maintain this file manually?
53         aggregates_file = self.server_basedir + os.sep + 'aggregates'
54         self.aggregates = SimpleStorage(aggregates_file)
55         
56         nodes_file = os.sep.join([self.server_basedir, 'smgr.' + self.hrn + '.components'])
57         self.nodes = SimpleStorage(nodes_file)
58         self.nodes.load()
59         
60         slices_file = os.sep.join([self.server_basedir, 'smgr' + self.hrn + '.slices'])
61         self.slices = SimpleStorage(slices_file)
62         self.slices.load()
63
64         policy_file = os.sep.join([self.server_basedir, 'smgr.' + self.hrn + '.policy'])
65         self.policy = SimpleStorage(policy_file)
66         self.policy.load()
67
68         timestamp_file = os.sep.join([self.server_basedir, 'smgr.' + self.hrn + '.timestamp'])
69         self.timestamp = SimpleStorage(timestamp_file)
70         
71         # How long before we refresh nodes cache  
72         self.nodes_ttl = 1
73
74         self.connectRegistry()
75         self.loadCredential()
76         self.connectAggregates(aggregates_file)
77
78
79     def loadCredential(self):
80         """
81         Attempt to load credential from file if it exists. If it doesnt get
82         credential from registry.
83         """
84
85         self_cred_filename = self.server_basedir + os.sep + "smgr." + self.hrn + ".cred"
86         ma_cred_filename = self.server_basedir + os.sep + "smgr." + self.hrn + ".sa.cred"
87         
88         # see if this file exists
89         try:
90             self.credential = Credential(filename = ma_cred_filename)
91         except IOError:
92             # get self credential
93             self_cred = self.registry.get_credential(None, 'ma', self.hrn)
94             self_cred.save_to_file(self_cred_filename, save_parents=True)
95
96             # get ma credential
97             ma_cred = self.registry.get_credential(self_cred, 'sa', self.hrn)
98             ma_cred.save_to_file(ma_cred_filename, save_parents=True)
99             self.credential = ma_cred
100
101     def connectRegistry(self):
102         """
103         Connect to the registry
104         """
105         # connect to registry using GeniClient
106         address = self.config.GENI_REGISTRY_HOSTNAME
107         port = self.config.GENI_REGISTRY_PORT
108         url = 'http://%(address)s:%(port)s' % locals()
109         self.registry = GeniClient(url, self.key_file, self.cert_file)
110
111     def connectAggregates(self, aggregates_file):
112         """
113         Get info about the aggregates available to us from file and create 
114         an xmlrpc connection to each. If any info is invalid, skip it. 
115         """
116         lines = []
117         try:
118             f = open(aggregates_file, 'r')
119             lines = f.readlines()
120             f.close()
121         except: raise 
122         
123         for line in lines:
124             # Skip comments
125             if line.strip().startswith("#"):
126                 continue
127             line = line.replace("\t", " ").replace("\n", "").replace("\r", "").strip()
128             agg_info = line.split(" ")
129         
130             # skip invalid info
131             if len(agg_info) != 3:
132                 continue
133
134             # create xmlrpc connection using GeniClient
135             hrn, address, port = agg_info[0], agg_info[1], agg_info[2]
136             url = 'http://%(address)s:%(port)s' % locals()
137             self.aggregates[hrn] = GeniClient(url, self.key_file, self.cert_file)
138             self.aggregates[hrn].list_nodes(self.credential)
139
140     def item_hrns(self, items):
141         """
142         Take a list of items (components or slices) and return a dictionary where
143         the key is the authoritative hrn and the value is a list of items at that 
144         hrn.
145         """
146         item_hrns = {}
147         agg_hrns = self.aggregates.keys()
148         for agg_hrn in agg_hrns:
149             item_hrns[agg_hrn] = []
150         for item in items:
151             for agg_hrn in agg_hrns:
152                 if item.startswith(agg_hrn):
153                     item_hrns[agg_hrn] = item
154
155         return item_hrns    
156              
157
158     def hostname_to_hrn(self, login_base, hostname):
159         """
160         Convert hrn to plantelab name.
161         """
162         genihostname = "_".join(hostname.split("."))
163         return ".".join([self.hrn, login_base, genihostname])
164
165     def slicename_to_hrn(self, slicename):
166         """
167         Convert hrn to planetlab name.
168         """
169         slicename = slicename.replace("_", ".")
170         return ".".join([self.hrn, slicename])
171
172     def refresh_components(self):
173         """
174         Update the cached list of nodes.
175         """
176     
177          # convert and threshold to ints
178         if self.timestamp.has_key('timestamp') and self.timestamp['timestamp']:
179             timestamp = self.timestamp['timestamp']
180             threshold = self.threshold
181         else:
182             timestamp = datetime.datetime.now()
183             delta = datetime.timedelta(hours=self.nodes_ttl)
184             threshold = timestamp + delta
185
186         start_time = int(timestamp.strftime("%s"))
187         end_time = int(threshold.strftime("%s"))
188         duration = end_time - start_time
189
190         aggregates = self.aggregates.keys()
191         rspecs = {}
192         for aggregate in aggregates:
193             try:
194                 # get the rspec from the aggregate
195                 agg_server = self.aggregates[aggregate]
196                 nodes = self.aggregates[aggregate].list_nodes(self.credential)
197                 rspecs[aggregate] = nodes
198                 
199                 # XX apply policy whitelist, balcklist here
200             except:
201                 # XX print out to some error log
202                 print "Error calling list nodes at aggregate %s" % aggregate
203                 raise    
204    
205         # extract the netspec from each aggregates rspec
206         networks = []
207         for rs in rspecs:
208             r = Rspec()
209             r.parseString(rspecs[rs])
210             networks.extend(r.getDictsByTagName('NetSpec'))
211         
212         # create the plc dict
213         resources = {'networks': networks, 'start_time': start_time, 'duration': duration}
214         
215         # convert plc dict to rspec dict
216         resourceDict = RspecDict(resources)
217         
218         # convert rspec dict to xml
219         rspec = Rspec()
220         rspec.parseDict(resourceDict)
221         
222         #for node in all_nodes:
223         #    if self.polciy['whitelist'] and node not in self.polciy['whitelist']:
224         #        continue
225         #    if self.polciy['blacklist'] and node in self.policy['blacklist']:
226         #        continue
227         #    nodedict[node] = node
228
229         nodedict = {'rspec': rspec.toxml()}
230         self.nodes = SimpleStorage(self.nodes.db_filename, nodedict)
231         self.nodes.write()
232
233         # update timestamp and threshold
234         self.timestamp['timestamp'] = datetime.datetime.now()
235         delta = datetime.timedelta(hours=self.nodes_ttl)
236         self.threshold = self.timestamp['timestamp'] + delta
237         self.timestamp.write()
238         
239  
240     def load_components(self):
241         """
242         Read cached list of nodes.
243         """
244         # Read component list from cached file 
245         self.nodes.load()
246         self.timestamp.load()
247         time_format = "%Y-%m-%d %H:%M:%S"
248         timestamp = self.timestamp['timestamp']
249         self.timestamp['timestamp'] = datetime.datetime.fromtimestamp(time.mktime(time.strptime(timestamp, time_format)))
250         delta = datetime.timedelta(hours=self.nodes_ttl)
251         self.threshold = self.timestamp['timestamp'] + delta
252
253     def load_policy(self):
254         """
255         Read the list of blacklisted and whitelisted nodes.
256         """
257         self.policy.load()
258  
259     def load_slices(self):
260         """
261         Read current slice instantiation states.
262         """
263         self.slices.load()
264
265
266     def getNodes(self, format = 'rspec'):
267         """
268         Return a list of components managed by this slice manager.
269         """
270         # Reload components list
271         now = datetime.datetime.now()
272         #self.load_components()
273         if not self.threshold or not self.timestamp or now > self.threshold:
274             self.refresh_components()
275         elif now < self.threshold and not self.nodes: 
276             self.load_components()
277         return self.nodes[format]
278    
279      
280     def getSlices(self):
281         """
282         Return a list of instnatiated managed by this slice manager.
283         """
284         # XX return only the slices at the specified hrn
285         return dict(self.slices)
286
287     def getResources(self, slice_hrn):
288         """
289         Return the current rspec for the specified slice.
290         """
291
292         if slice_hrn in self.slices.keys():
293             # check if we alreay have this slices state saved
294             rspec = self.slices[slice_hrn]
295         else:
296             # request this slices state from all  known aggregates
297             rspecdicts = []
298             for hrn in self.aggregates.keys():
299                 # XX need to use the right credentials for this call
300                 # check if the slice has resources at this hrn
301                 tempresources = self.aggregates[hrn].resources(self.credential, slice_hrn)
302                 temprspec = Rspec()
303                 temprspec.parseString(temprspec)
304                 if temprspec.getDictsByTagName('NodeSpec'):
305                     # append this rspec to the list of rspecs
306                     rspecdicts.append(temprspec.toDict())
307                 
308             # merge all these rspecs into one
309             start_time = int(self.timestamp['timestamp'].strftime("%s"))
310             end_time = int(self.duration.strftime("%s"))
311             duration = end_time - start_time
312                 
313             # create a plc dict 
314             networks = [rspecdict['networks'][0] for rspecdict in rspecdicts]
315             resources = {'networks': networks, 'start_time': start_time, 'duration': duration}
316             # convert the plc dict to an rspec dict
317             resourceDict = RspecDict(resources)
318             resourceSpec = Rspec()
319             resourceSpec.parseDict(resourceDict)
320             rspec = resourceSpec.toxml() 
321             # save this slices resources
322             self.slices[slice_hrn] = rspec
323             self.slices.write()
324          
325         return rspec
326  
327     def createSlice(self, slice_hrn, rspec, attributes):
328         """
329         Instantiate the specified slice according to whats defined in the rspec.
330         """
331         # XX need to gget the correct credentials
332         cred = self.credential
333
334         # save slice state locally
335         # we can assume that spec object has been validated so its safer to
336         # save this instead of the unvalidated rspec the user gave us
337         self.slices[slice_hrn] = spec.toxml()
338         self.slices.write()
339
340         # extract network list from the rspec and create a separate
341         # rspec for each network
342         slicename = self.hrn_to_plcslicename(slice_hrn)
343         spec = Rspec()
344         spec.parseString(rspec)
345         specDict = spec.toDict()
346         start_time = specDict['start_time']
347         end_time = specDict['end_time']
348
349         rspecs = {}
350         # only attempt to extract information about the aggregates we know about
351         for hrn in self.aggregates.keys():
352             netspec = spec.getDictByTagNameValue('NetSpec', 'hrn')
353             if netspec:
354                 # creat a plc dict 
355                 tempdict = {'start_time': star_time, 'end_time': end_time, 'networks': netspec}
356                 #convert the plc dict to rpsec dict
357                 resourceDict = RspecDict(tempdict)
358                 # parse rspec dict
359                 tempspec = Rspec()
360                 tempspec.parseDict(resourceDict)
361                 rspecs[hrn] = tempspec.toxml()
362
363         # notify the aggregates
364         for hrn in self.rspecs.keys():
365             self.aggregates[hrn].createSlice(cred, rspecs[hrn])
366             
367         return 1
368
369     def updateSlice(self, slice_hrn, rspec, attributes = []):
370         """
371         Update the specifed slice
372         """
373         self.create_slice(slice_hrn, rspec, attributes)
374     
375     def deleteSlice_(self, slice_hrn):
376         """
377         Remove this slice from all components it was previouly associated with and 
378         free up the resources it was using.
379         """
380         # XX need to get the correct credential
381         cred = self.credential
382         
383         if self.slices.has_key(slice_hrn):
384             self.slices.pop(slice_hrn)
385             self.slices.write()
386
387         for hrn in self.aggregates.keys():
388             self.aggregates[hrn].deleteSlice(cred, slice_hrn)
389
390         return 1
391
392     def startSlice(self, slice_hrn):
393         """
394         Stop the slice at plc.
395         """
396         cred = self.credential
397
398         for hrn in self.aggregates.keys():
399             self.aggregates[hrn].startSlice(cred, slice_hrn)
400         return 1
401
402     def stopSlice(self, slice_hrn):
403         """
404         Stop the slice at plc
405         """
406         cred = self.credential
407         for hrn in self.aggregates.keys():
408             self.aggregates[hrn].startSlice(cred, slice_hrn)
409         return 1
410
411     def resetSlice(self, slice_hrn):
412         """
413         Reset the slice
414         """
415         # XX not yet implemented
416         return 1
417
418     def getPolicy(self):
419         """
420         Return the policy of this slice manager.
421         """
422     
423         return self.policy
424         
425     
426
427 ##############################
428 ## Server methods here for now
429 ##############################
430
431     def list_nodes(self, cred):
432         self.decode_authentication(cred, 'listnodes')
433         return self.getNodes()
434
435     def list_slices(self, cred, hrn):
436         self.decode_authentication(cred, 'listslices')
437         return self.getSlices(hrn)
438
439     def get_resources(self, cred, hrn):
440         self.decode_authentication(cred, 'listnodes')
441         return self.getResources(hrn)
442
443     def get_ticket(self, cred, hrn, rspec):
444         self.decode_authentication(cred, 'getticket')
445         return self.getTicket(hrn, rspec)
446
447     def get_policy(self, cred):
448         self.decode_authentication(cred, 'getpolicy')
449         return self.getPolicy()
450
451     def create_slice(self, cred, hrn, rspec):
452         self.decode_authentication(cred, 'creatslice')
453         return self.createSlice(hrn)
454
455     def delete_slice(self, cred, hrn):
456         self.decode_authentication(cred, 'deleteslice')
457         return self.deleteSlice(hrn)
458
459     def start_slice(self, cred, hrn):
460         self.decode_authentication(cred, 'startslice')
461         return self.startSlice(hrn)
462
463     def stop_slice(self, cred, hrn):
464         self.decode_authentication(cred, 'stopslice')
465         return self.stopSlice(hrn)
466
467     def reset_slice(self, cred, hrn):
468         self.decode_authentication(cred, 'resetslice')
469         return self.resetSlice(hrn)
470
471     def register_functions(self):
472         GeniServer.register_functions(self)
473
474         # Aggregate interface methods
475         self.server.register_function(self.list_nodes)
476         self.server.register_function(self.list_slices)
477         self.server.register_function(self.get_resources)
478         self.server.register_function(self.get_policy)
479         self.server.register_function(self.create_slice)
480         self.server.register_function(self.delete_slice)
481         self.server.register_function(self.start_slice)
482         self.server.register_function(self.stop_slice)
483         self.server.register_function(self.reset_slice)
484