Moved helper functions to vini_utils.py
[sfa.git] / sfa / rspecs / aggregates / rspec_manager_vini.py
1 from sfa.util.faults import *
2 from sfa.util.misc import *
3 from sfa.util.rspec import Rspec
4 from sfa.server.registry import Registries
5 from sfa.plc.nodes import *
6 from sfa.rspecs.aggregates.vini_utils import *
7 import sys
8
9 SFA_VINI_DEFAULT_RSPEC = '/etc/sfa/vini.rspec'
10 SFA_VINI_WHITELIST = '/etc/sfa/vini.whitelist'
11
12 """
13 Copied from create_slice_aggregate() in sfa.plc.slices
14 """
15 def create_slice_vini_aggregate(api, hrn, nodes):
16     # Get the slice record from geni
17     slice = {}
18     registries = Registries(api)
19     registry = registries[api.hrn]
20     credential = api.getCredential()
21     records = registry.resolve(credential, hrn)
22     for record in records:
23         if record.get_type() in ['slice']:
24             slice = record.as_dict()
25     if not slice:
26         raise RecordNotFound(hrn)   
27
28     # Make sure slice exists at plc, if it doesnt add it
29     slicename = hrn_to_pl_slicename(hrn)
30     slices = api.plshell.GetSlices(api.plauth, [slicename], ['node_ids'])
31     if not slices:
32         parts = slicename.split("_")
33         login_base = parts[0]
34         # if site doesnt exist add it
35         sites = api.plshell.GetSites(api.plauth, [login_base])
36         if not sites:
37             authority = get_authority(hrn)
38             site_records = registry.resolve(credential, authority)
39             site_record = {}
40             if not site_records:
41                 raise RecordNotFound(authority)
42             site_record = site_records[0]
43             site = site_record.as_dict()
44                 
45             # add the site
46             site.pop('site_id')
47             site_id = api.plshell.AddSite(api.plauth, site)
48         else:
49             site = sites[0]
50             
51         slice_fields = {}
52         slice_keys = ['name', 'url', 'description']
53         for key in slice_keys:
54             if key in slice and slice[key]:
55                 slice_fields[key] = slice[key]  
56         api.plshell.AddSlice(api.plauth, slice_fields)
57         slice = slice_fields
58         slice['node_ids'] = 0
59     else:
60         slice = slices[0]    
61
62     # get the list of valid slice users from the registry and make 
63     # they are added to the slice 
64     researchers = record.get('researcher', [])
65     for researcher in researchers:
66         person_record = {}
67         person_records = registry.resolve(credential, researcher)
68         for record in person_records:
69             if record.get_type() in ['user']:
70                 person_record = record
71         if not person_record:
72             pass
73         person_dict = person_record.as_dict()
74         persons = api.plshell.GetPersons(api.plauth, [person_dict['email']],
75                                          ['person_id', 'key_ids'])
76
77         # Create the person record 
78         if not persons:
79             person_id=api.plshell.AddPerson(api.plauth, person_dict)
80
81             # The line below enables the user account on the remote aggregate
82             # soon after it is created.
83             # without this the user key is not transfered to the slice
84             # (as GetSlivers returns key of only enabled users),
85             # which prevents the user from login to the slice.
86             # We may do additional checks before enabling the user.
87
88             api.plshell.UpdatePerson(api.plauth, person_id, {'enabled' : True})
89             key_ids = []
90         else:
91             key_ids = persons[0]['key_ids']
92
93         api.plshell.AddPersonToSlice(api.plauth, person_dict['email'],
94                                      slicename)        
95
96         # Get this users local keys
97         keylist = api.plshell.GetKeys(api.plauth, key_ids, ['key'])
98         keys = [key['key'] for key in keylist]
99
100         # add keys that arent already there 
101         for personkey in person_dict['keys']:
102             if personkey not in keys:
103                 key = {'key_type': 'ssh', 'key': personkey}
104                 api.plshell.AddPersonKey(api.plauth, person_dict['email'], key)
105
106     # find out where this slice is currently running
107     nodelist = api.plshell.GetNodes(api.plauth, slice['node_ids'],
108                                     ['hostname'])
109     hostnames = [node['hostname'] for node in nodelist]
110
111     # remove nodes not in rspec
112     deleted_nodes = list(set(hostnames).difference(nodes))
113     # add nodes from rspec
114     added_nodes = list(set(nodes).difference(hostnames))
115
116     """
117     print >> sys.stderr, "Slice on nodes:"
118     for n in hostnames:
119         print >> sys.stderr, n
120     print >> sys.stderr, "Wants nodes:"
121     for n in nodes:
122         print >> sys.stderr, n
123     print >> sys.stderr, "Deleting nodes:"
124     for n in deleted_nodes:
125         print >> sys.stderr, n
126     print >> sys.stderr, "Adding nodes:"
127     for n in added_nodes:
128         print >> sys.stderr, n
129     """
130
131     api.plshell.AddSliceToNodes(api.plauth, slicename, added_nodes) 
132     api.plshell.DeleteSliceFromNodes(api.plauth, slicename, deleted_nodes)
133
134     return 1
135
136 def get_rspec(api, hrn):
137     # Get default rspec
138     default = Rspec()
139     default.parseFile(SFA_VINI_DEFAULT_RSPEC)
140     
141     if (hrn):
142         slicename = hrn_to_pl_slicename(hrn)
143         defaultrspec = default.toDict()
144         nodedict = get_nodedict(defaultrspec)
145
146         # call the default sfa.plc.nodes.get_rspec() method
147         nodes = Nodes(api)      
148         rspec = nodes.get_rspec(hrn)     
149
150         # Grab all the PLC info we'll need at once
151         slice = get_slice(api, slicename)
152         if slice:
153             nodes = get_nodes(api)
154             tags = get_slice_tags(api)
155
156             # Add the node tags from the Capacity statement to Node objects
157             for (k, v) in nodedict.iteritems():
158                 for id in nodes:
159                     if v == nodes[id].hostname:
160                         nodes[id].tag = k
161
162             endpoints = []
163             for node in slice.get_nodes(nodes):
164                 linktag = slice.get_tag('topo_rspec', tags, node)
165                 if linktag:
166                     l = eval(linktag.value)
167                     for (id, realip, bw, lvip, rvip, vnet) in l:
168                         endpoints.append((node.id, id, bw))
169             
170             if endpoints:
171                 linkspecs = []
172                 for (l, r, bw) in endpoints:
173                     if (r, l, bw) in endpoints:
174                         if l < r:
175                             edict = {}
176                             edict['endpoint'] = [nodes[l].tag, nodes[r].tag]
177                             edict['bw'] = [bw]
178                             linkspecs.append(edict)
179
180                 d = default.toDict()
181                 d['Rspec']['Request'][0]['NetSpec'][0]['LinkSpec'] = linkspecs
182                 d['Rspec']['Request'][0]['NetSpec'][0]['name'] = hrn
183                 new = Rspec()
184                 new.parseDict(d)
185                 rspec = new.toxml()
186     else:
187         # Return canned response for now...
188         rspec = default.toxml()
189
190     return rspec
191
192
193 def create_slice(api, hrn, xml):
194     r = Rspec(xml)
195     rspec = r.toDict()
196
197     ### Check the whitelist
198     ### It consists of lines of the form: <slice hrn> <bw>
199     whitelist = {}
200     f = open(SFA_VINI_WHITELIST)
201     for line in f.readlines():
202         (slice, maxbw) = line.split()
203         whitelist[slice] = maxbw
204         
205     if hrn in whitelist:
206         maxbps = get_tc_rate(whitelist[hrn])
207     else:
208         raise PermissionError("%s not in VINI whitelist" % hrn)
209         
210     ### Check to make sure that the slice isn't requesting more
211     ### than its maximum bandwidth.
212     linkspecs = rspec['Rspec']['Request'][0]['NetSpec'][0]['LinkSpec']
213     if linkspecs:
214         for l in linkspecs:
215             bw = l['bw'][0]
216             bps = get_tc_rate(bw)
217             if bps <= 0:
218                 raise GeniInvalidArgument(bw, "BW")
219             if bps > maxbps:
220                 raise PermissionError(" %s requested %s but max BW is %s" % (hrn, bw, whitelist[hrn]))
221
222     # Check request against current allocations
223     # Request OK
224
225     nodes = rspec_to_nodeset(rspec)
226     create_slice_vini_aggregate(api, hrn, nodes)
227
228     # Add VINI-specific topology attributes to slice here
229     try:
230         linkspecs = rspec['Rspec']['Request'][0]['NetSpec'][0]['LinkSpec']
231         if linkspecs:
232             slicename = hrn_to_pl_slicename(hrn)
233
234             # Grab all the PLC info we'll need at once
235             slice = get_slice(api, slicename)
236             if slice:
237                 nodes = get_nodes(api)
238                 tags = get_slice_tags(api)
239
240                 slice.update_tag('vini_topo', 'manual', tags)
241                 slice.assign_egre_key(tags)
242                 slice.turn_on_netns(tags)
243                 slice.add_cap_net_admin(tags)
244
245                 nodedict = {}
246                 for (k, v) in get_nodedict(rspec).iteritems():
247                     for id in nodes:
248                         if v == nodes[id].hostname:
249                             nodedict[k] = nodes[id]
250
251                 for l in linkspecs:
252                     n1 = nodedict[l['endpoint'][0]]
253                     n2 = nodedict[l['endpoint'][1]]
254                     bw = l['bw'][0]
255                     n1.add_link(n2, bw)
256                     n2.add_link(n1, bw)
257
258                 for node in slice.get_nodes(nodes):
259                     if node.links:
260                         topo_str = "%s" % node.links
261                         slice.update_tag('topo_rspec', topo_str, tags, node)
262
263                 # Update slice tags in database
264                 for i in tags:
265                     tag = tags[i]
266                     if tag.slice_id == slice.id:
267                         if tag.tagname == 'topo_rspec' and not tag.updated:
268                             tag.delete()
269                         tag.write(api)
270     except KeyError:
271         # Bad Rspec
272         pass
273     
274
275     return True
276
277 def get_nodedict(rspec):
278     nodedict = {}
279     try:    
280         sitespecs = rspec['Rspec']['Capacity'][0]['NetSpec'][0]['SiteSpec']
281         for s in sitespecs:
282             for node in s['NodeSpec']:
283                 nodedict[node['name']] = node['hostname'][0]
284     except KeyError:
285         pass
286
287     return nodedict
288
289         
290 def rspec_to_nodeset(rspec):
291     nodes = set()
292     try:
293         nodedict = get_nodedict(rspec)
294         linkspecs = rspec['Rspec']['Request'][0]['NetSpec'][0]['LinkSpec']
295         for l in linkspecs:
296             for e in l['endpoint']:
297                 nodes.add(nodedict[e])
298     except KeyError:
299         # Bad Rspec
300         pass
301     
302     return nodes
303
304 def main():
305     r = Rspec()
306     r.parseFile(sys.argv[1])
307     rspec = r.toDict()
308     create_slice(None,'plc',rspec)
309     
310 if __name__ == "__main__":
311     main()