Convert Rspecs and SliceTags to data structure with Node objects connected by Link...
[sfa.git] / sfa / rspecs / aggregates / vini / utils.py
1 import re
2 import socket
3 from sfa.rspecs.aggregates.vini.topology import *
4
5 # Taken from bwlimit.py
6 #
7 # See tc_util.c and http://physics.nist.gov/cuu/Units/binary.html. Be
8 # warned that older versions of tc interpret "kbps", "mbps", "mbit",
9 # and "kbit" to mean (in this system) "kibps", "mibps", "mibit", and
10 # "kibit" and that if an older version is installed, all rates will
11 # be off by a small fraction.
12 suffixes = {
13     "":         1,
14     "bit":      1,
15     "kibit":    1024,
16     "kbit":     1000,
17     "mibit":    1024*1024,
18     "mbit":     1000000,
19     "gibit":    1024*1024*1024,
20     "gbit":     1000000000,
21     "tibit":    1024*1024*1024*1024,
22     "tbit":     1000000000000,
23     "bps":      8,
24     "kibps":    8*1024,
25     "kbps":     8000,
26     "mibps":    8*1024*1024,
27     "mbps":     8000000,
28     "gibps":    8*1024*1024*1024,
29     "gbps":     8000000000,
30     "tibps":    8*1024*1024*1024*1024,
31     "tbps":     8000000000000
32 }
33
34
35 def get_tc_rate(s):
36     """
37     Parses an integer or a tc rate string (e.g., 1.5mbit) into bits/second
38     """
39
40     if type(s) == int:
41         return s
42     m = re.match(r"([0-9.]+)(\D*)", s)
43     if m is None:
44         return -1
45     suffix = m.group(2).lower()
46     if suffixes.has_key(suffix):
47         return int(float(m.group(1)) * suffixes[suffix])
48     else:
49         return -1
50
51 def format_tc_rate(rate):
52     """
53     Formats a bits/second rate into a tc rate string
54     """
55
56     if rate >= 1000000000 and (rate % 1000000000) == 0:
57         return "%.0fgbit" % (rate / 1000000000.)
58     elif rate >= 1000000 and (rate % 1000000) == 0:
59         return "%.0fmbit" % (rate / 1000000.)
60     elif rate >= 1000:
61         return "%.0fkbit" % (rate / 1000.)
62     else:
63         return "%.0fbit" % rate
64
65
66 class Node:
67     def __init__(self, node, bps = 1000 * 1000000):
68         self.id = node['node_id']
69         self.hostname = node['hostname']
70         self.shortname = self.hostname.replace('.vini-veritas.net', '')
71         self.site_id = node['site_id']
72         self.ipaddr = socket.gethostbyname(self.hostname)
73         self.bps = bps
74         self.links = set()
75
76     def get_link_id(self, remote):
77         if self.id < remote.id:
78             link = (self.id<<7) + remote.id
79         else:
80             link = (remote.id<<7) + self.id
81         return link
82         
83     def get_iface_id(self, remote):
84         if self.id < remote.id:
85             iface = 1
86         else:
87             iface = 2
88         return iface
89     
90     def get_virt_ip(self, remote):
91         link = self.get_link_id(remote)
92         iface = self.get_iface_id(remote)
93         first = link >> 6
94         second = ((link & 0x3f)<<2) + iface
95         return "192.168.%d.%d" % (first, second)
96
97     def get_virt_net(self, remote):
98         link = self.get_link_id(remote)
99         first = link >> 6
100         second = (link & 0x3f)<<2
101         return "192.168.%d.%d/30" % (first, second)
102         
103     def get_site(self, sites):
104         return sites[self.site_id]
105     
106     def get_topo_rspec(self, link):
107         if link.end1 == self:
108             remote = link.end2
109         elif link.end2 == self:
110             remote = link.end1
111         else:
112             raise Error("Link does not connect to Node")
113             
114         my_ip = self.get_virt_ip(remote)
115         remote_ip = remote.get_virt_ip(self)
116         net = self.get_virt_net(remote)
117         bw = format_tc_rate(link.bps)
118         return (remote.id, remote.ipaddr, bw, my_ip, remote_ip, net)
119         
120     def add_link(self, link):
121         self.links.add(link)
122         
123     def add_tag(self, sites):
124         s = self.get_site(sites)
125         words = self.hostname.split(".")
126         index = words[0].replace("node", "")
127         if index.isdigit():
128             self.tag = s.tag + index
129         else:
130             self.tag = None
131
132     # Assumes there is at most one Link between two sites
133     def get_sitelink(self, node, sites):
134         site1 = sites[self.site_id]
135         site2 = sites[node.site_id]
136         sl = site1.links.intersection(site2.links)
137         if len(sl):
138             return sl.pop()
139         return None
140     
141
142 class Link:
143     def __init__(self, end1, end2, bps = 1000 * 1000000):
144         self.end1 = end1
145         self.end2 = end2
146         self.bps = bps
147         
148         end1.add_link(self)
149         end2.add_link(self)
150         
151         
152 class Site:
153     def __init__(self, site):
154         self.id = site['site_id']
155         self.node_ids = site['node_ids']
156         self.name = site['abbreviated_name']
157         self.tag = site['login_base']
158         self.public = site['is_public']
159         self.links = set()
160
161     def get_sitenodes(self, nodes):
162         n = []
163         for i in self.node_ids:
164             n.append(nodes[i])
165         return n
166     
167     def add_link(self, link):
168         self.links.add(link)
169     
170     
171 class Slice:
172     def __init__(self, slice):
173         self.id = slice['slice_id']
174         self.name = slice['name']
175         self.node_ids = set(slice['node_ids'])
176         self.slice_tag_ids = slice['slice_tag_ids']
177     
178     def get_tag(self, tagname, slicetags, node = None):
179         for i in self.slice_tag_ids:
180             tag = slicetags[i]
181             if tag.tagname == tagname:
182                 if (not node) or (node.id == tag.node_id):
183                     return tag
184         else:
185             return None
186         
187     def get_nodes(self, nodes):
188         n = []
189         for id in self.node_ids:
190             n.append(nodes[id])
191         return n
192              
193     
194     # Add a new slice tag   
195     def add_tag(self, tagname, value, slicetags, node = None):
196         record = {'slice_tag_id':None, 'slice_id':self.id, 'tagname':tagname, 'value':value}
197         if node:
198             record['node_id'] = node.id
199         else:
200             record['node_id'] = None
201         tag = Slicetag(record)
202         slicetags[tag.id] = tag
203         self.slice_tag_ids.append(tag.id)
204         tag.changed = True       
205         tag.updated = True
206         return tag
207     
208     # Update a slice tag if it exists, else add it             
209     def update_tag(self, tagname, value, slicetags, node = None):
210         tag = self.get_tag(tagname, slicetags, node)
211         if tag and tag.value == value:
212             value = "no change"
213         elif tag:
214             tag.value = value
215             tag.changed = True
216         else:
217             tag = self.add_tag(tagname, value, slicetags, node)
218         tag.updated = True
219             
220     def assign_egre_key(self, slicetags):
221         if not self.get_tag('egre_key', slicetags):
222             try:
223                 key = free_egre_key(slicetags)
224                 self.update_tag('egre_key', key, slicetags)
225             except:
226                 # Should handle this case...
227                 pass
228         return
229             
230     def turn_on_netns(self, slicetags):
231         tag = self.get_tag('netns', slicetags)
232         if (not tag) or (tag.value != '1'):
233             self.update_tag('netns', '1', slicetags)
234         return
235    
236     def turn_off_netns(self, slicetags):
237         tag = self.get_tag('netns', slicetags)
238         if tag and (tag.value != '0'):
239             tag.delete()
240         return
241     
242     def add_cap_net_admin(self, slicetags):
243         tag = self.get_tag('capabilities', slicetags)
244         if tag:
245             caps = tag.value.split(',')
246             for cap in caps:
247                 if cap == "CAP_NET_ADMIN":
248                     return
249             else:
250                 newcaps = "CAP_NET_ADMIN," + tag.value
251                 self.update_tag('capabilities', newcaps, slicetags)
252         else:
253             self.add_tag('capabilities', 'CAP_NET_ADMIN', slicetags)
254         return
255     
256     def remove_cap_net_admin(self, slicetags):
257         tag = self.get_tag('capabilities', slicetags)
258         if tag:
259             caps = tag.value.split(',')
260             newcaps = []
261             for cap in caps:
262                 if cap != "CAP_NET_ADMIN":
263                     newcaps.append(cap)
264             if newcaps:
265                 value = ','.join(newcaps)
266                 self.update_tag('capabilities', value, slicetags)
267             else:
268                 tag.delete()
269         return
270
271     # Update the vsys/setup-link and vsys/setup-nat slice tags.
272     def add_vsys_tags(self, slicetags):
273         link = nat = False
274         for i in self.slice_tag_ids:
275             tag = slicetags[i]
276             if tag.tagname == 'vsys':
277                 if tag.value == 'setup-link':
278                     link = True
279                 elif tag.value == 'setup-nat':
280                     nat = True
281         if not link:
282             self.add_tag('vsys', 'setup-link', slicetags)
283         if not nat:
284             self.add_tag('vsys', 'setup-nat', slicetags)
285         return
286
287
288 class Slicetag:
289     newid = -1 
290     def __init__(self, tag):
291         self.id = tag['slice_tag_id']
292         if not self.id:
293             # Make one up for the time being...
294             self.id = Slicetag.newid
295             Slicetag.newid -= 1
296         self.slice_id = tag['slice_id']
297         self.tagname = tag['tagname']
298         self.value = tag['value']
299         self.node_id = tag['node_id']
300         self.updated = False
301         self.changed = False
302         self.deleted = False
303     
304     # Mark a tag as deleted
305     def delete(self):
306         self.deleted = True
307         self.updated = True
308     
309     def write(self, api):
310         if self.changed:
311             if int(self.id) > 0:
312                 api.plshell.UpdateSliceTag(api.plauth, self.id, self.value)
313             else:
314                 api.plshell.AddSliceTag(api.plauth, self.slice_id, 
315                                         self.tagname, self.value, self.node_id)
316         elif self.deleted and int(self.id) > 0:
317             api.plshell.DeleteSliceTag(api.plauth, self.id)
318
319
320 """
321 A topology is a compound object consisting of:
322 * a dictionary mapping site IDs to Site objects
323 * a dictionary mapping node IDs to Node objects
324 * the Site objects are connected via SiteLink objects representing
325   the physical topology and available bandwidth
326 * the Node objects are connected via Link objects representing
327   the requested or assigned virtual topology of a slice
328 """
329 class Topology:
330     def __init__(self, api):
331         self.api = api
332         self.sites = get_sites(api)
333         self.nodes = get_nodes(api)
334         self.tags = get_slice_tags(api)
335         self.sitelinks = []
336         self.nodelinks = []
337     
338         for (s1, s2) in PhysicalLinks:
339             self.sitelinks.append(Link(self.sites[s1], self.sites[s2]))
340         
341         for id in self.nodes:
342             self.nodes[id].add_tag(self.sites)
343         
344         for t in self.tags:
345             tag = self.tags[t]
346             if tag.tagname == 'topo_rspec':
347                 node1 = self.nodes[tag.node_id]
348                 l = eval(tag.value)
349                 for (id, realip, bw, lvip, rvip, vnet) in l:
350                     allocbps = get_tc_rate(bw)
351                     node1.bps -= allocbps
352                     try:
353                         node2 = self.nodes[id]
354                         if node1.id < node2.id:
355                             sl = node1.get_sitelink(node2, self.sites)
356                             sl.bps -= allocbps
357                     except:
358                         pass
359
360     
361     def lookupSite(self, id):
362         val = None
363         try:
364             val = self.sites[id]
365         except:
366             raise KeyError("site ID %s not found" % id)
367         return val
368     
369     def getSites(self):
370         sites = []
371         for s in self.sites:
372             sites.append(self.sites[s])
373         return sites
374         
375     def lookupNode(self, id):
376         val = None
377         try:
378             val = self.nodes[id]
379         except:
380             raise KeyError("node ID %s not found" % id)
381         return val
382     
383     def getNodes(self):
384         nodes = []
385         for n in self.nodes:
386             nodes.append(self.nodes[n])
387         return nodes
388     
389     def nodesInTopo(self):
390         nodes = []
391         for n in self.nodes:
392             if self.nodes[n].links:
393                 nodes.append(self.nodes[n])
394         return nodes
395             
396     def lookupSliceTag(self, id):
397         val = None
398         try:
399             val = self.tags[id]
400         except:
401             raise KeyError("slicetag ID %s not found" % id)
402         return val
403     
404     def getSliceTags(self):
405         tags = []
406         for t in self.tags:
407             tags.append(self.tags[t])
408         return tags
409     
410     def nodeTopoFromRspec(self, rspec):
411         if self.nodelinks:
412             raise Error("virtual topology already present")
413             
414         rspecdict = rspec.toDict()
415         nodedict = {}
416         for node in self.getNodes():
417             nodedict[node.tag] = node
418             
419         linkspecs = rspecdict['Rspec']['Request'][0]['NetSpec'][0]['LinkSpec']    
420         for l in linkspecs:
421             n1 = nodedict[l['endpoint'][0]]
422             n2 = nodedict[l['endpoint'][1]]
423             bps = get_tc_rate(l['bw'][0])
424             self.nodelinks.append(Link(n1, n2, bps))
425  
426     def nodeTopoFromSliceTags(self, slice):
427         if self.nodelinks:
428             raise Error("virtual topology already present")
429             
430         for node in slice.get_nodes(self.nodes):
431             linktag = slice.get_tag('topo_rspec', self.tags, node)
432             if linktag:
433                 l = eval(linktag.value)
434                 for (id, realip, bw, lvip, rvip, vnet) in l:
435                     if node.id < id:
436                         bps = get_tc_rate(bw)
437                         remote = self.lookupNode(id)
438                         self.nodelinks.append(Link(node, remote, bps))
439
440     def updateSliceTags(self, slice):
441         if not self.nodelinks:
442             return
443  
444         slice.update_tag('vini_topo', 'manual', self.tags)
445         slice.assign_egre_key(self.tags)
446         slice.turn_on_netns(self.tags)
447         slice.add_cap_net_admin(self.tags)
448
449         for node in slice.get_nodes(self.nodes):
450             linkdesc = []
451             for link in node.links:
452                 linkdesc.append(node.get_topo_rspec(link))
453             if linkdesc:
454                 topo_str = "%s" % linkdesc
455                 slice.update_tag('topo_rspec', topo_str, self.tags, node)
456
457         # Update slice tags in database
458         for tag in self.getSliceTags():
459             if tag.slice_id == slice.id:
460                 if tag.tagname == 'topo_rspec' and not tag.updated:
461                     tag.delete()
462                 tag.write(self.api)
463
464
465 """
466 Create a dictionary of site objects keyed by site ID
467 """
468 def get_sites(api):
469     tmp = []
470     for site in api.plshell.GetSites(api.plauth):
471         t = site['site_id'], Site(site)
472         tmp.append(t)
473     return dict(tmp)
474
475
476 """
477 Create a dictionary of node objects keyed by node ID
478 """
479 def get_nodes(api):
480     tmp = []
481     for node in api.plshell.GetNodes(api.plauth):
482         t = node['node_id'], Node(node)
483         tmp.append(t)
484     return dict(tmp)
485
486 """
487 Create a dictionary of slice objects keyed by slice ID
488 """
489 def get_slice(api, slicename):
490     slice = api.plshell.GetSlices(api.plauth, [slicename])
491     if slice:
492         return Slice(slice[0])
493     else:
494         return None
495
496 """
497 Create a dictionary of slicetag objects keyed by slice tag ID
498 """
499 def get_slice_tags(api):
500     tmp = []
501     for tag in api.plshell.GetSliceTags(api.plauth):
502         t = tag['slice_tag_id'], Slicetag(tag)
503         tmp.append(t)
504     return dict(tmp)
505     
506 """
507 Find a free EGRE key
508 """
509 def free_egre_key(slicetags):
510     used = set()
511     for i in slicetags:
512         tag = slicetags[i]
513         if tag.tagname == 'egre_key':
514             used.add(int(tag.value))
515                 
516     for i in range(1, 256):
517         if i not in used:
518             key = i
519             break
520     else:
521         raise KeyError("No more EGRE keys available")
522         
523     return "%s" % key
524