Make sure slice doesn't request more than allowed bandwidth
[sfa.git] / sfa / rspecs / aggregates / rspec_manager_vini.py
1 from sfa.util.faults import *
2 from sfa.util.misc import *
3 from sfa.util.rspec import Rspec
4 from sfa.server.registry import Registries
5 from sfa.plc.nodes import *
6 import sys
7 import socket
8 import re
9
10 SFA_VINI_DEFAULT_RSPEC = '/etc/sfa/vini.rspec'
11 SFA_VINI_WHITELIST = '/etc/sfa/vini.whitelist'
12
13 # Taken from bwlimit.py
14 #
15 # See tc_util.c and http://physics.nist.gov/cuu/Units/binary.html. Be
16 # warned that older versions of tc interpret "kbps", "mbps", "mbit",
17 # and "kbit" to mean (in this system) "kibps", "mibps", "mibit", and
18 # "kibit" and that if an older version is installed, all rates will
19 # be off by a small fraction.
20 suffixes = {
21     "":         1,
22     "bit":      1,
23     "kibit":    1024,
24     "kbit":     1000,
25     "mibit":    1024*1024,
26     "mbit":     1000000,
27     "gibit":    1024*1024*1024,
28     "gbit":     1000000000,
29     "tibit":    1024*1024*1024*1024,
30     "tbit":     1000000000000,
31     "bps":      8,
32     "kibps":    8*1024,
33     "kbps":     8000,
34     "mibps":    8*1024*1024,
35     "mbps":     8000000,
36     "gibps":    8*1024*1024*1024,
37     "gbps":     8000000000,
38     "tibps":    8*1024*1024*1024*1024,
39     "tbps":     8000000000000
40 }
41
42
43 def get_tc_rate(s):
44     """
45     Parses an integer or a tc rate string (e.g., 1.5mbit) into bits/second
46     """
47
48     if type(s) == int:
49         return s
50     m = re.match(r"([0-9.]+)(\D*)", s)
51     if m is None:
52         return -1
53     suffix = m.group(2).lower()
54     if suffixes.has_key(suffix):
55         return int(float(m.group(1)) * suffixes[suffix])
56     else:
57         return -1
58
59
60 class Node:
61     def __init__(self, node):
62         self.id = node['node_id']
63         self.hostname = node['hostname']
64         self.shortname = self.hostname.replace('.vini-veritas.net', '')
65         self.site_id = node['site_id']
66         self.ipaddr = socket.gethostbyname(self.hostname)
67         self.links = []
68
69     def get_link_id(self, remote):
70         if self.id < remote.id:
71             link = (self.id<<7) + remote.id
72         else:
73             link = (remote.id<<7) + self.id
74         return link
75         
76     def get_iface_id(self, remote):
77         if self.id < remote.id:
78             iface = 1
79         else:
80             iface = 2
81         return iface
82     
83     def get_virt_ip(self, remote):
84         link = self.get_link_id(remote)
85         iface = self.get_iface_id(remote)
86         first = link >> 6
87         second = ((link & 0x3f)<<2) + iface
88         return "192.168.%d.%d" % (first, second)
89
90     def get_virt_net(self, remote):
91         link = self.get_link_id(remote)
92         first = link >> 6
93         second = (link & 0x3f)<<2
94         return "192.168.%d.%d/30" % (first, second)
95         
96     def get_site(self, sites):
97         return sites[self.site_id]
98             
99     def adjacent_nodes(self, sites, nodes, node_ids):
100         mysite = self.get_site(sites)
101         adj_ids = mysite.adj_node_ids.intersection(node_ids)
102         adj_nodes = []
103         for id in adj_ids:
104             adj_nodes.append(nodes[id])
105         return adj_nodes
106     
107     def init_links(self):
108         self.links = []
109         
110     def add_link(self, remote, bw):
111         my_ip = self.get_virt_ip(remote)
112         remote_ip = remote.get_virt_ip(self)
113         net = self.get_virt_net(remote)
114         link = remote.id, remote.ipaddr, bw, my_ip, remote_ip, net
115         self.links.append(link)
116
117         
118 class Site:
119     def __init__(self, site):
120         self.id = site['site_id']
121         self.node_ids = site['node_ids']
122         self.adj_site_ids = set()
123         self.adj_node_ids = set()
124
125     def get_sitenodes(self, nodes):
126         n = []
127         for i in self.node_ids:
128             n.append(nodes[i])
129         return n
130     
131     def add_adjacency(self, site):
132         self.adj_site_ids.add(site.id)
133         for n in site.node_ids:
134             self.adj_node_ids.add(n)
135         
136     
137 class Slice:
138     def __init__(self, slice):
139         self.id = slice['slice_id']
140         self.name = slice['name']
141         self.node_ids = set(slice['node_ids'])
142         self.slice_tag_ids = slice['slice_tag_ids']
143     
144     def get_tag(self, tagname, slicetags, node = None):
145         for i in self.slice_tag_ids:
146             tag = slicetags[i]
147             if tag.tagname == tagname:
148                 if (not node) or (node.id == tag.node_id):
149                     return tag
150         else:
151             return None
152         
153     def get_nodes(self, nodes):
154         n = []
155         for id in self.node_ids:
156             n.append(nodes[id])
157         return n
158              
159     
160     # Add a new slice tag   
161     def add_tag(self, tagname, value, slicetags, node = None):
162         record = {'slice_tag_id':None, 'slice_id':self.id, 'tagname':tagname, 'value':value}
163         if node:
164             record['node_id'] = node.id
165         else:
166             record['node_id'] = None
167         tag = Slicetag(record)
168         slicetags[tag.id] = tag
169         self.slice_tag_ids.append(tag.id)
170         tag.changed = True       
171         tag.updated = True
172         return tag
173     
174     # Update a slice tag if it exists, else add it             
175     def update_tag(self, tagname, value, slicetags, node = None):
176         tag = self.get_tag(tagname, slicetags, node)
177         if tag and tag.value == value:
178             value = "no change"
179         elif tag:
180             tag.value = value
181             tag.changed = True
182         else:
183             tag = self.add_tag(tagname, value, slicetags, node)
184         tag.updated = True
185             
186     def assign_egre_key(self, slicetags):
187         if not self.get_tag('egre_key', slicetags):
188             try:
189                 key = free_egre_key(slicetags)
190                 self.update_tag('egre_key', key, slicetags)
191             except:
192                 # Should handle this case...
193                 pass
194         return
195             
196     def turn_on_netns(self, slicetags):
197         tag = self.get_tag('netns', slicetags)
198         if (not tag) or (tag.value != '1'):
199             self.update_tag('netns', '1', slicetags)
200         return
201    
202     def turn_off_netns(self, slicetags):
203         tag = self.get_tag('netns', slicetags)
204         if tag and (tag.value != '0'):
205             tag.delete()
206         return
207     
208     def add_cap_net_admin(self, slicetags):
209         tag = self.get_tag('capabilities', slicetags)
210         if tag:
211             caps = tag.value.split(',')
212             for cap in caps:
213                 if cap == "CAP_NET_ADMIN":
214                     return
215             else:
216                 newcaps = "CAP_NET_ADMIN," + tag.value
217                 self.update_tag('capabilities', newcaps, slicetags)
218         else:
219             self.add_tag('capabilities', 'CAP_NET_ADMIN', slicetags)
220         return
221     
222     def remove_cap_net_admin(self, slicetags):
223         tag = self.get_tag('capabilities', slicetags)
224         if tag:
225             caps = tag.value.split(',')
226             newcaps = []
227             for cap in caps:
228                 if cap != "CAP_NET_ADMIN":
229                     newcaps.append(cap)
230             if newcaps:
231                 value = ','.join(newcaps)
232                 self.update_tag('capabilities', value, slicetags)
233             else:
234                 tag.delete()
235         return
236
237     # Update the vsys/setup-link and vsys/setup-nat slice tags.
238     def add_vsys_tags(self, slicetags):
239         link = nat = False
240         for i in self.slice_tag_ids:
241             tag = slicetags[i]
242             if tag.tagname == 'vsys':
243                 if tag.value == 'setup-link':
244                     link = True
245                 elif tag.value == 'setup-nat':
246                     nat = True
247         if not link:
248             self.add_tag('vsys', 'setup-link', slicetags)
249         if not nat:
250             self.add_tag('vsys', 'setup-nat', slicetags)
251         return
252
253
254 class Slicetag:
255     newid = -1 
256     def __init__(self, tag):
257         self.id = tag['slice_tag_id']
258         if not self.id:
259             # Make one up for the time being...
260             self.id = Slicetag.newid
261             Slicetag.newid -= 1
262         self.slice_id = tag['slice_id']
263         self.tagname = tag['tagname']
264         self.value = tag['value']
265         self.node_id = tag['node_id']
266         self.updated = False
267         self.changed = False
268         self.deleted = False
269     
270     # Mark a tag as deleted
271     def delete(self):
272         self.deleted = True
273         self.updated = True
274     
275     def write(self, api):
276         if self.changed:
277             if int(self.id) > 0:
278                 api.plshell.UpdateSliceTag(api.plauth, self.id, self.value)
279             else:
280                 api.plshell.AddSliceTag(api.plauth, self.slice_id, 
281                                         self.tagname, self.value, self.node_id)
282         elif self.deleted and int(self.id) > 0:
283             api.plshell.DeleteSliceTag(api.plauth, self.id)
284
285
286 """
287 Create a dictionary of site objects keyed by site ID
288 """
289 def get_sites():
290     tmp = []
291     for site in GetSites():
292         t = site['site_id'], Site(site)
293         tmp.append(t)
294     return dict(tmp)
295
296
297 """
298 Create a dictionary of node objects keyed by node ID
299 """
300 def get_nodes(api):
301     tmp = []
302     for node in api.plshell.GetNodes(api.plauth):
303         t = node['node_id'], Node(node)
304         tmp.append(t)
305     return dict(tmp)
306
307 """
308 Create a dictionary of slice objects keyed by slice ID
309 """
310 def get_slice(api, slicename):
311     slice = api.plshell.GetSlices(api.plauth, [slicename])
312     if slice:
313         return Slice(slice[0])
314     else:
315         return None
316
317 """
318 Create a dictionary of slicetag objects keyed by slice tag ID
319 """
320 def get_slice_tags(api):
321     tmp = []
322     for tag in api.plshell.GetSliceTags(api.plauth):
323         t = tag['slice_tag_id'], Slicetag(tag)
324         tmp.append(t)
325     return dict(tmp)
326     
327 """
328 Find a free EGRE key
329 """
330 def free_egre_key(slicetags):
331     used = set()
332     for i in slicetags:
333         tag = slicetags[i]
334         if tag.tagname == 'egre_key':
335             used.add(int(tag.value))
336                 
337     for i in range(1, 256):
338         if i not in used:
339             key = i
340             break
341     else:
342         raise KeyError("No more EGRE keys available")
343         
344     return "%s" % key
345    
346
347 """
348 Copied from create_slice_aggregate() in sfa.plc.slices
349 """
350 def create_slice_vini_aggregate(api, hrn, nodes):
351     # Get the slice record from geni
352     slice = {}
353     registries = Registries(api)
354     registry = registries[api.hrn]
355     credential = api.getCredential()
356     records = registry.resolve(credential, hrn)
357     for record in records:
358         if record.get_type() in ['slice']:
359             slice = record.as_dict()
360     if not slice:
361         raise RecordNotFound(hrn)   
362
363     # Make sure slice exists at plc, if it doesnt add it
364     slicename = hrn_to_pl_slicename(hrn)
365     slices = api.plshell.GetSlices(api.plauth, [slicename], ['node_ids'])
366     if not slices:
367         parts = slicename.split("_")
368         login_base = parts[0]
369         # if site doesnt exist add it
370         sites = api.plshell.GetSites(api.plauth, [login_base])
371         if not sites:
372             authority = get_authority(hrn)
373             site_records = registry.resolve(credential, authority)
374             site_record = {}
375             if not site_records:
376                 raise RecordNotFound(authority)
377             site_record = site_records[0]
378             site = site_record.as_dict()
379                 
380             # add the site
381             site.pop('site_id')
382             site_id = api.plshell.AddSite(api.plauth, site)
383         else:
384             site = sites[0]
385             
386         slice_fields = {}
387         slice_keys = ['name', 'url', 'description']
388         for key in slice_keys:
389             if key in slice and slice[key]:
390                 slice_fields[key] = slice[key]  
391         api.plshell.AddSlice(api.plauth, slice_fields)
392         slice = slice_fields
393         slice['node_ids'] = 0
394     else:
395         slice = slices[0]    
396
397     # get the list of valid slice users from the registry and make 
398     # they are added to the slice 
399     researchers = record.get('researcher', [])
400     for researcher in researchers:
401         person_record = {}
402         person_records = registry.resolve(credential, researcher)
403         for record in person_records:
404             if record.get_type() in ['user']:
405                 person_record = record
406         if not person_record:
407             pass
408         person_dict = person_record.as_dict()
409         persons = api.plshell.GetPersons(api.plauth, [person_dict['email']],
410                                          ['person_id', 'key_ids'])
411
412         # Create the person record 
413         if not persons:
414             person_id=api.plshell.AddPerson(api.plauth, person_dict)
415
416             # The line below enables the user account on the remote aggregate
417             # soon after it is created.
418             # without this the user key is not transfered to the slice
419             # (as GetSlivers returns key of only enabled users),
420             # which prevents the user from login to the slice.
421             # We may do additional checks before enabling the user.
422
423             api.plshell.UpdatePerson(api.plauth, person_id, {'enabled' : True})
424             key_ids = []
425         else:
426             key_ids = persons[0]['key_ids']
427
428         api.plshell.AddPersonToSlice(api.plauth, person_dict['email'],
429                                      slicename)        
430
431         # Get this users local keys
432         keylist = api.plshell.GetKeys(api.plauth, key_ids, ['key'])
433         keys = [key['key'] for key in keylist]
434
435         # add keys that arent already there 
436         for personkey in person_dict['keys']:
437             if personkey not in keys:
438                 key = {'key_type': 'ssh', 'key': personkey}
439                 api.plshell.AddPersonKey(api.plauth, person_dict['email'], key)
440
441     # find out where this slice is currently running
442     nodelist = api.plshell.GetNodes(api.plauth, slice['node_ids'],
443                                     ['hostname'])
444     hostnames = [node['hostname'] for node in nodelist]
445
446     # remove nodes not in rspec
447     deleted_nodes = list(set(hostnames).difference(nodes))
448     # add nodes from rspec
449     added_nodes = list(set(nodes).difference(hostnames))
450
451     """
452     print >> sys.stderr, "Slice on nodes:"
453     for n in hostnames:
454         print >> sys.stderr, n
455     print >> sys.stderr, "Wants nodes:"
456     for n in nodes:
457         print >> sys.stderr, n
458     print >> sys.stderr, "Deleting nodes:"
459     for n in deleted_nodes:
460         print >> sys.stderr, n
461     print >> sys.stderr, "Adding nodes:"
462     for n in added_nodes:
463         print >> sys.stderr, n
464     """
465
466     api.plshell.AddSliceToNodes(api.plauth, slicename, added_nodes) 
467     api.plshell.DeleteSliceFromNodes(api.plauth, slicename, deleted_nodes)
468
469     return 1
470
471 def get_rspec(api, hrn):
472     # Get default rspec
473     default = Rspec()
474     default.parseFile(SFA_VINI_DEFAULT_RSPEC)
475     
476     if (hrn):
477         slicename = hrn_to_pl_slicename(hrn)
478         defaultrspec = default.toDict()
479         nodedict = get_nodedict(defaultrspec)
480
481         # call the default sfa.plc.nodes.get_rspec() method
482         nodes = Nodes(api)      
483         rspec = nodes.get_rspec(hrn)     
484
485         # Grab all the PLC info we'll need at once
486         slice = get_slice(api, slicename)
487         if slice:
488             nodes = get_nodes(api)
489             tags = get_slice_tags(api)
490
491             # Add the node tags from the Capacity statement to Node objects
492             for (k, v) in nodedict.iteritems():
493                 for id in nodes:
494                     if v == nodes[id].hostname:
495                         nodes[id].tag = k
496
497             endpoints = []
498             for node in slice.get_nodes(nodes):
499                 linktag = slice.get_tag('topo_rspec', tags, node)
500                 if linktag:
501                     l = eval(linktag.value)
502                     for (id, realip, bw, lvip, rvip, vnet) in l:
503                         endpoints.append((node.id, id, bw))
504             
505             if endpoints:
506                 linkspecs = []
507                 for (l, r, bw) in endpoints:
508                     if (r, l, bw) in endpoints:
509                         if l < r:
510                             edict = {}
511                             edict['endpoint'] = [nodes[l].tag, nodes[r].tag]
512                             edict['bw'] = [bw]
513                             linkspecs.append(edict)
514
515                 d = default.toDict()
516                 d['Rspec']['Request'][0]['NetSpec'][0]['LinkSpec'] = linkspecs
517                 d['Rspec']['Request'][0]['NetSpec'][0]['name'] = hrn
518                 new = Rspec()
519                 new.parseDict(d)
520                 rspec = new.toxml()
521     else:
522         # Return canned response for now...
523         rspec = default.toxml()
524
525     return rspec
526
527
528 def create_slice(api, hrn, xml):
529     r = Rspec(xml)
530     rspec = r.toDict()
531
532     ### Check the whitelist
533     ### It consists of lines of the form: <slice hrn> <bw>
534     whitelist = {}
535     f = open(SFA_VINI_WHITELIST)
536     for line in f.readlines():
537         (slice, maxbw) = line.split()
538         whitelist[slice] = maxbw
539         
540     if hrn in whitelist:
541         maxbps = get_tc_rate(whitelist[hrn])
542     else:
543         raise PermissionError("%s not in VINI whitelist" % hrn)
544         
545     ### Check to make sure that the slice isn't requesting more
546     ### than its maximum bandwidth.
547     linkspecs = rspec['Rspec']['Request'][0]['NetSpec'][0]['LinkSpec']
548     if linkspecs:
549         for l in linkspecs:
550             bw = l['bw'][0]
551             bps = get_tc_rate(bw)
552             if bps <= 0:
553                 raise GeniInvalidArgument(bw, "BW")
554             if bps > maxbps:
555                 raise PermissionError(" %s requested %s but max BW is %s" % (hrn, bw, whitelist[hrn]))
556
557     # Check request against current allocations
558     # Request OK
559
560     nodes = rspec_to_nodeset(rspec)
561     create_slice_vini_aggregate(api, hrn, nodes)
562
563     # Add VINI-specific topology attributes to slice here
564     try:
565         linkspecs = rspec['Rspec']['Request'][0]['NetSpec'][0]['LinkSpec']
566         if linkspecs:
567             slicename = hrn_to_pl_slicename(hrn)
568
569             # Grab all the PLC info we'll need at once
570             slice = get_slice(api, slicename)
571             if slice:
572                 nodes = get_nodes(api)
573                 tags = get_slice_tags(api)
574
575                 slice.update_tag('vini_topo', 'manual', tags)
576                 slice.assign_egre_key(tags)
577                 slice.turn_on_netns(tags)
578                 slice.add_cap_net_admin(tags)
579
580                 nodedict = {}
581                 for (k, v) in get_nodedict(rspec).iteritems():
582                     for id in nodes:
583                         if v == nodes[id].hostname:
584                             nodedict[k] = nodes[id]
585
586                 for l in linkspecs:
587                     n1 = nodedict[l['endpoint'][0]]
588                     n2 = nodedict[l['endpoint'][1]]
589                     bw = l['bw'][0]
590                     n1.add_link(n2, bw)
591                     n2.add_link(n1, bw)
592
593                 for node in slice.get_nodes(nodes):
594                     if node.links:
595                         topo_str = "%s" % node.links
596                         slice.update_tag('topo_rspec', topo_str, tags, node)
597
598                 # Update slice tags in database
599                 for i in tags:
600                     tag = tags[i]
601                     if tag.slice_id == slice.id:
602                         if tag.tagname == 'topo_rspec' and not tag.updated:
603                             tag.delete()
604                         tag.write(api)
605     except KeyError:
606         # Bad Rspec
607         pass
608     
609
610     return True
611
612 def get_nodedict(rspec):
613     nodedict = {}
614     try:    
615         sitespecs = rspec['Rspec']['Capacity'][0]['NetSpec'][0]['SiteSpec']
616         for s in sitespecs:
617             for node in s['NodeSpec']:
618                 nodedict[node['name']] = node['hostname'][0]
619     except KeyError:
620         pass
621
622     return nodedict
623
624         
625 def rspec_to_nodeset(rspec):
626     nodes = set()
627     try:
628         nodedict = get_nodedict(rspec)
629         linkspecs = rspec['Rspec']['Request'][0]['NetSpec'][0]['LinkSpec']
630         for l in linkspecs:
631             for e in l['endpoint']:
632                 nodes.add(nodedict[e])
633     except KeyError:
634         # Bad Rspec
635         pass
636     
637     return nodes
638
639 def main():
640     r = Rspec()
641     r.parseFile(sys.argv[1])
642     rspec = r.toDict()
643     create_slice(None,'plc',rspec)
644     
645 if __name__ == "__main__":
646     main()