Added simple whitelist
[sfa.git] / sfa / rspecs / aggregates / rspec_manager_vini.py
1 from sfa.util.faults import *
2 from sfa.util.misc import *
3 from sfa.util.rspec import Rspec
4 from sfa.server.registry import Registries
5 from sfa.plc.nodes import *
6 import sys
7 import socket
8
9 SFA_VINI_DEFAULT_RSPEC = '/etc/sfa/vini.rspec'
10 SFA_VINI_WHITELIST = '/etc/sfa/vini.whitelist'
11
12 class Node:
13     def __init__(self, node):
14         self.id = node['node_id']
15         self.hostname = node['hostname']
16         self.shortname = self.hostname.replace('.vini-veritas.net', '')
17         self.site_id = node['site_id']
18         self.ipaddr = socket.gethostbyname(self.hostname)
19         self.links = []
20
21     def get_link_id(self, remote):
22         if self.id < remote.id:
23             link = (self.id<<7) + remote.id
24         else:
25             link = (remote.id<<7) + self.id
26         return link
27         
28     def get_iface_id(self, remote):
29         if self.id < remote.id:
30             iface = 1
31         else:
32             iface = 2
33         return iface
34     
35     def get_virt_ip(self, remote):
36         link = self.get_link_id(remote)
37         iface = self.get_iface_id(remote)
38         first = link >> 6
39         second = ((link & 0x3f)<<2) + iface
40         return "192.168.%d.%d" % (first, second)
41
42     def get_virt_net(self, remote):
43         link = self.get_link_id(remote)
44         first = link >> 6
45         second = (link & 0x3f)<<2
46         return "192.168.%d.%d/30" % (first, second)
47         
48     def get_site(self, sites):
49         return sites[self.site_id]
50             
51     def adjacent_nodes(self, sites, nodes, node_ids):
52         mysite = self.get_site(sites)
53         adj_ids = mysite.adj_node_ids.intersection(node_ids)
54         adj_nodes = []
55         for id in adj_ids:
56             adj_nodes.append(nodes[id])
57         return adj_nodes
58     
59     def init_links(self):
60         self.links = []
61         
62     def add_link(self, remote, bw):
63         my_ip = self.get_virt_ip(remote)
64         remote_ip = remote.get_virt_ip(self)
65         net = self.get_virt_net(remote)
66         link = remote.id, remote.ipaddr, bw, my_ip, remote_ip, net
67         self.links.append(link)
68
69         
70 class Site:
71     def __init__(self, site):
72         self.id = site['site_id']
73         self.node_ids = site['node_ids']
74         self.adj_site_ids = set()
75         self.adj_node_ids = set()
76
77     def get_sitenodes(self, nodes):
78         n = []
79         for i in self.node_ids:
80             n.append(nodes[i])
81         return n
82     
83     def add_adjacency(self, site):
84         self.adj_site_ids.add(site.id)
85         for n in site.node_ids:
86             self.adj_node_ids.add(n)
87         
88     
89 class Slice:
90     def __init__(self, slice):
91         self.id = slice['slice_id']
92         self.name = slice['name']
93         self.node_ids = set(slice['node_ids'])
94         self.slice_tag_ids = slice['slice_tag_ids']
95     
96     def get_tag(self, tagname, slicetags, node = None):
97         for i in self.slice_tag_ids:
98             tag = slicetags[i]
99             if tag.tagname == tagname:
100                 if (not node) or (node.id == tag.node_id):
101                     return tag
102         else:
103             return None
104         
105     def get_nodes(self, nodes):
106         n = []
107         for id in self.node_ids:
108             n.append(nodes[id])
109         return n
110              
111     
112     # Add a new slice tag   
113     def add_tag(self, tagname, value, slicetags, node = None):
114         record = {'slice_tag_id':None, 'slice_id':self.id, 'tagname':tagname, 'value':value}
115         if node:
116             record['node_id'] = node.id
117         else:
118             record['node_id'] = None
119         tag = Slicetag(record)
120         slicetags[tag.id] = tag
121         self.slice_tag_ids.append(tag.id)
122         tag.changed = True       
123         tag.updated = True
124         return tag
125     
126     # Update a slice tag if it exists, else add it             
127     def update_tag(self, tagname, value, slicetags, node = None):
128         tag = self.get_tag(tagname, slicetags, node)
129         if tag and tag.value == value:
130             value = "no change"
131         elif tag:
132             tag.value = value
133             tag.changed = True
134         else:
135             tag = self.add_tag(tagname, value, slicetags, node)
136         tag.updated = True
137             
138     def assign_egre_key(self, slicetags):
139         if not self.get_tag('egre_key', slicetags):
140             try:
141                 key = free_egre_key(slicetags)
142                 self.update_tag('egre_key', key, slicetags)
143             except:
144                 # Should handle this case...
145                 pass
146         return
147             
148     def turn_on_netns(self, slicetags):
149         tag = self.get_tag('netns', slicetags)
150         if (not tag) or (tag.value != '1'):
151             self.update_tag('netns', '1', slicetags)
152         return
153    
154     def turn_off_netns(self, slicetags):
155         tag = self.get_tag('netns', slicetags)
156         if tag and (tag.value != '0'):
157             tag.delete()
158         return
159     
160     def add_cap_net_admin(self, slicetags):
161         tag = self.get_tag('capabilities', slicetags)
162         if tag:
163             caps = tag.value.split(',')
164             for cap in caps:
165                 if cap == "CAP_NET_ADMIN":
166                     return
167             else:
168                 newcaps = "CAP_NET_ADMIN," + tag.value
169                 self.update_tag('capabilities', newcaps, slicetags)
170         else:
171             self.add_tag('capabilities', 'CAP_NET_ADMIN', slicetags)
172         return
173     
174     def remove_cap_net_admin(self, slicetags):
175         tag = self.get_tag('capabilities', slicetags)
176         if tag:
177             caps = tag.value.split(',')
178             newcaps = []
179             for cap in caps:
180                 if cap != "CAP_NET_ADMIN":
181                     newcaps.append(cap)
182             if newcaps:
183                 value = ','.join(newcaps)
184                 self.update_tag('capabilities', value, slicetags)
185             else:
186                 tag.delete()
187         return
188
189     # Update the vsys/setup-link and vsys/setup-nat slice tags.
190     def add_vsys_tags(self, slicetags):
191         link = nat = False
192         for i in self.slice_tag_ids:
193             tag = slicetags[i]
194             if tag.tagname == 'vsys':
195                 if tag.value == 'setup-link':
196                     link = True
197                 elif tag.value == 'setup-nat':
198                     nat = True
199         if not link:
200             self.add_tag('vsys', 'setup-link', slicetags)
201         if not nat:
202             self.add_tag('vsys', 'setup-nat', slicetags)
203         return
204
205
206 class Slicetag:
207     newid = -1 
208     def __init__(self, tag):
209         self.id = tag['slice_tag_id']
210         if not self.id:
211             # Make one up for the time being...
212             self.id = Slicetag.newid
213             Slicetag.newid -= 1
214         self.slice_id = tag['slice_id']
215         self.tagname = tag['tagname']
216         self.value = tag['value']
217         self.node_id = tag['node_id']
218         self.updated = False
219         self.changed = False
220         self.deleted = False
221     
222     # Mark a tag as deleted
223     def delete(self):
224         self.deleted = True
225         self.updated = True
226     
227     def write(self, api):
228         if self.changed:
229             if int(self.id) > 0:
230                 api.plshell.UpdateSliceTag(api.plauth, self.id, self.value)
231             else:
232                 api.plshell.AddSliceTag(api.plauth, self.slice_id, 
233                                         self.tagname, self.value, self.node_id)
234         elif self.deleted and int(self.id) > 0:
235             api.plshell.DeleteSliceTag(api.plauth, self.id)
236
237
238 """
239 Create a dictionary of site objects keyed by site ID
240 """
241 def get_sites():
242     tmp = []
243     for site in GetSites():
244         t = site['site_id'], Site(site)
245         tmp.append(t)
246     return dict(tmp)
247
248
249 """
250 Create a dictionary of node objects keyed by node ID
251 """
252 def get_nodes(api):
253     tmp = []
254     for node in api.plshell.GetNodes(api.plauth):
255         t = node['node_id'], Node(node)
256         tmp.append(t)
257     return dict(tmp)
258
259 """
260 Create a dictionary of slice objects keyed by slice ID
261 """
262 def get_slice(api, slicename):
263     slice = api.plshell.GetSlices(api.plauth, [slicename])
264     if slice:
265         return Slice(slice[0])
266     else:
267         return None
268
269 """
270 Create a dictionary of slicetag objects keyed by slice tag ID
271 """
272 def get_slice_tags(api):
273     tmp = []
274     for tag in api.plshell.GetSliceTags(api.plauth):
275         t = tag['slice_tag_id'], Slicetag(tag)
276         tmp.append(t)
277     return dict(tmp)
278     
279 """
280 Find a free EGRE key
281 """
282 def free_egre_key(slicetags):
283     used = set()
284     for i in slicetags:
285         tag = slicetags[i]
286         if tag.tagname == 'egre_key':
287             used.add(int(tag.value))
288                 
289     for i in range(1, 256):
290         if i not in used:
291             key = i
292             break
293     else:
294         raise KeyError("No more EGRE keys available")
295         
296     return "%s" % key
297    
298
299 """
300 Copied from create_slice_aggregate() in sfa.plc.slices
301 """
302 def create_slice_vini_aggregate(api, hrn, nodes):
303     # Get the slice record from geni
304     slice = {}
305     registries = Registries(api)
306     registry = registries[api.hrn]
307     credential = api.getCredential()
308     records = registry.resolve(credential, hrn)
309     for record in records:
310         if record.get_type() in ['slice']:
311             slice = record.as_dict()
312     if not slice:
313         raise RecordNotFound(hrn)   
314
315     # Make sure slice exists at plc, if it doesnt add it
316     slicename = hrn_to_pl_slicename(hrn)
317     slices = api.plshell.GetSlices(api.plauth, [slicename], ['node_ids'])
318     if not slices:
319         parts = slicename.split("_")
320         login_base = parts[0]
321         # if site doesnt exist add it
322         sites = api.plshell.GetSites(api.plauth, [login_base])
323         if not sites:
324             authority = get_authority(hrn)
325             site_records = registry.resolve(credential, authority)
326             site_record = {}
327             if not site_records:
328                 raise RecordNotFound(authority)
329             site_record = site_records[0]
330             site = site_record.as_dict()
331                 
332             # add the site
333             site.pop('site_id')
334             site_id = api.plshell.AddSite(api.plauth, site)
335         else:
336             site = sites[0]
337             
338         slice_fields = {}
339         slice_keys = ['name', 'url', 'description']
340         for key in slice_keys:
341             if key in slice and slice[key]:
342                 slice_fields[key] = slice[key]  
343         api.plshell.AddSlice(api.plauth, slice_fields)
344         slice = slice_fields
345         slice['node_ids'] = 0
346     else:
347         slice = slices[0]    
348
349     # get the list of valid slice users from the registry and make 
350     # they are added to the slice 
351     researchers = record.get('researcher', [])
352     for researcher in researchers:
353         person_record = {}
354         person_records = registry.resolve(credential, researcher)
355         for record in person_records:
356             if record.get_type() in ['user']:
357                 person_record = record
358         if not person_record:
359             pass
360         person_dict = person_record.as_dict()
361         persons = api.plshell.GetPersons(api.plauth, [person_dict['email']],
362                                          ['person_id', 'key_ids'])
363
364         # Create the person record 
365         if not persons:
366             person_id=api.plshell.AddPerson(api.plauth, person_dict)
367
368             # The line below enables the user account on the remote aggregate
369             # soon after it is created.
370             # without this the user key is not transfered to the slice
371             # (as GetSlivers returns key of only enabled users),
372             # which prevents the user from login to the slice.
373             # We may do additional checks before enabling the user.
374
375             api.plshell.UpdatePerson(api.plauth, person_id, {'enabled' : True})
376             key_ids = []
377         else:
378             key_ids = persons[0]['key_ids']
379
380         api.plshell.AddPersonToSlice(api.plauth, person_dict['email'],
381                                      slicename)        
382
383         # Get this users local keys
384         keylist = api.plshell.GetKeys(api.plauth, key_ids, ['key'])
385         keys = [key['key'] for key in keylist]
386
387         # add keys that arent already there 
388         for personkey in person_dict['keys']:
389             if personkey not in keys:
390                 key = {'key_type': 'ssh', 'key': personkey}
391                 api.plshell.AddPersonKey(api.plauth, person_dict['email'], key)
392
393     # find out where this slice is currently running
394     nodelist = api.plshell.GetNodes(api.plauth, slice['node_ids'],
395                                     ['hostname'])
396     hostnames = [node['hostname'] for node in nodelist]
397
398     # remove nodes not in rspec
399     deleted_nodes = list(set(hostnames).difference(nodes))
400     # add nodes from rspec
401     added_nodes = list(set(nodes).difference(hostnames))
402
403     """
404     print >> sys.stderr, "Slice on nodes:"
405     for n in hostnames:
406         print >> sys.stderr, n
407     print >> sys.stderr, "Wants nodes:"
408     for n in nodes:
409         print >> sys.stderr, n
410     print >> sys.stderr, "Deleting nodes:"
411     for n in deleted_nodes:
412         print >> sys.stderr, n
413     print >> sys.stderr, "Adding nodes:"
414     for n in added_nodes:
415         print >> sys.stderr, n
416     """
417
418     api.plshell.AddSliceToNodes(api.plauth, slicename, added_nodes) 
419     api.plshell.DeleteSliceFromNodes(api.plauth, slicename, deleted_nodes)
420
421     return 1
422
423 def get_rspec(api, hrn):
424     # Get default rspec
425     default = Rspec()
426     default.parseFile(SFA_VINI_DEFAULT_RSPEC)
427     
428     if (hrn):
429         slicename = hrn_to_pl_slicename(hrn)
430         defaultrspec = default.toDict()
431         nodedict = get_nodedict(defaultrspec)
432
433         # call the default sfa.plc.nodes.get_rspec() method
434         nodes = Nodes(api)      
435         rspec = nodes.get_rspec(hrn)     
436
437         # Grab all the PLC info we'll need at once
438         slice = get_slice(api, slicename)
439         if slice:
440             nodes = get_nodes(api)
441             tags = get_slice_tags(api)
442
443             # Add the node tags from the Capacity statement to Node objects
444             for (k, v) in nodedict.iteritems():
445                 for id in nodes:
446                     if v == nodes[id].hostname:
447                         nodes[id].tag = k
448
449             endpoints = []
450             for node in slice.get_nodes(nodes):
451                 linktag = slice.get_tag('topo_rspec', tags, node)
452                 if linktag:
453                     l = eval(linktag.value)
454                     for (id, realip, bw, lvip, rvip, vnet) in l:
455                         endpoints.append((node.id, id, bw))
456             
457             if endpoints:
458                 linkspecs = []
459                 for (l, r, bw) in endpoints:
460                     if (r, l, bw) in endpoints:
461                         if l < r:
462                             edict = {}
463                             edict['endpoint'] = [nodes[l].tag, nodes[r].tag]
464                             edict['bw'] = [bw]
465                             linkspecs.append(edict)
466
467                 d = default.toDict()
468                 d['Rspec']['Request'][0]['NetSpec'][0]['LinkSpec'] = linkspecs
469                 d['Rspec']['Request'][0]['NetSpec'][0]['name'] = hrn
470                 new = Rspec()
471                 new.parseDict(d)
472                 rspec = new.toxml()
473     else:
474         # Return canned response for now...
475         rspec = default.toxml()
476
477     return rspec
478
479
480 def create_slice(api, hrn, xml):
481     r = Rspec(xml)
482     rspec = r.toDict()
483
484     ### Check the whitelist
485     ### It consists of lines of the form: <slice hrn> <bw>
486     whitelist = {}
487     f = open(SFA_VINI_WHITELIST)
488     for line in f.readlines():
489         (slice, bw) = line.split()
490         whitelist[slice] = bw
491     if not hrn in whitelist:
492         raise PermissionError("%s not in VINI whitelist" % hrn)
493
494     # Check request against current allocations
495     # Request OK
496
497     nodes = rspec_to_nodeset(rspec)
498     create_slice_vini_aggregate(api, hrn, nodes)
499
500     # Add VINI-specific topology attributes to slice here
501     try:
502         linkspecs = rspec['Rspec']['Request'][0]['NetSpec'][0]['LinkSpec']
503         if linkspecs:
504             slicename = hrn_to_pl_slicename(hrn)
505
506             # Grab all the PLC info we'll need at once
507             slice = get_slice(api, slicename)
508             if slice:
509                 nodes = get_nodes(api)
510                 tags = get_slice_tags(api)
511
512                 slice.update_tag('vini_topo', 'manual', tags)
513                 slice.assign_egre_key(tags)
514                 slice.turn_on_netns(tags)
515                 slice.add_cap_net_admin(tags)
516
517                 nodedict = {}
518                 for (k, v) in get_nodedict(rspec).iteritems():
519                     for id in nodes:
520                         if v == nodes[id].hostname:
521                             nodedict[k] = nodes[id]
522
523                 for l in linkspecs:
524                     n1 = nodedict[l['endpoint'][0]]
525                     n2 = nodedict[l['endpoint'][1]]
526                     bw = l['bw'][0]
527                     n1.add_link(n2, bw)
528                     n2.add_link(n1, bw)
529
530                 for node in slice.get_nodes(nodes):
531                     if node.links:
532                         topo_str = "%s" % node.links
533                         slice.update_tag('topo_rspec', topo_str, tags, node)
534
535                 # Update slice tags in database
536                 for i in tags:
537                     tag = tags[i]
538                     if tag.slice_id == slice.id:
539                         if tag.tagname == 'topo_rspec' and not tag.updated:
540                             tag.delete()
541                         tag.write(api)
542     except KeyError:
543         # Bad Rspec
544         pass
545     
546
547     return True
548
549 def get_nodedict(rspec):
550     nodedict = {}
551     try:    
552         sitespecs = rspec['Rspec']['Capacity'][0]['NetSpec'][0]['SiteSpec']
553         for s in sitespecs:
554             for node in s['NodeSpec']:
555                 nodedict[node['name']] = node['hostname'][0]
556     except KeyError:
557         pass
558
559     return nodedict
560
561         
562 def rspec_to_nodeset(rspec):
563     nodes = set()
564     try:
565         nodedict = get_nodedict(rspec)
566         linkspecs = rspec['Rspec']['Request'][0]['NetSpec'][0]['LinkSpec']
567         for l in linkspecs:
568             for e in l['endpoint']:
569                 nodes.add(nodedict[e])
570     except KeyError:
571         # Bad Rspec
572         pass
573     
574     return nodes
575
576 def main():
577     r = Rspec()
578     r.parseFile(sys.argv[1])
579     rspec = r.toDict()
580     create_slice(None,'plc',rspec)
581     
582 if __name__ == "__main__":
583     main()