Manually generate XML rather than using Rspec() class, so that XSugar works
[sfa.git] / sfa / rspecs / aggregates / vini / utils.py
1 import re
2 import socket
3 from sfa.rspecs.aggregates.vini.topology import *
4
5 default_topo_xml = """
6             <LinkSpec>
7                 <endpoint>i2atla1</endpoint>
8                 <endpoint>i2chic1</endpoint>
9                 <bw>1Mbit</bw>
10             </LinkSpec>
11             <LinkSpec>
12                 <endpoint>i2atla1</endpoint>
13                 <endpoint>i2hous1</endpoint>
14                 <bw>1Mbit</bw>
15             </LinkSpec>
16             <LinkSpec>
17                 <endpoint>i2atla1</endpoint>
18                 <endpoint>i2wash1</endpoint>
19                 <bw>1Mbit</bw>
20             </LinkSpec>
21             <LinkSpec>
22                 <endpoint>i2chic1</endpoint>
23                 <endpoint>i2kans1</endpoint>
24                 <bw>1Mbit</bw>
25             </LinkSpec>
26             <LinkSpec>
27                 <endpoint>i2chic1</endpoint>
28                 <endpoint>i2wash1</endpoint>
29                 <bw>1Mbit</bw>
30             </LinkSpec>
31             <LinkSpec>
32                 <endpoint>i2hous1</endpoint>
33                 <endpoint>i2kans1</endpoint>
34                 <bw>1Mbit</bw>
35             </LinkSpec>
36             <LinkSpec>
37                 <endpoint>i2hous1</endpoint>
38                 <endpoint>i2losa1</endpoint>
39                 <bw>1Mbit</bw>
40             </LinkSpec>
41             <LinkSpec>
42                 <endpoint>i2kans1</endpoint>
43                 <endpoint>i2salt1</endpoint>
44                 <bw>1Mbit</bw>
45             </LinkSpec>
46             <LinkSpec>
47                 <endpoint>i2losa1</endpoint>
48                 <endpoint>i2salt1</endpoint>
49                 <bw>1Mbit</bw>
50             </LinkSpec>
51             <LinkSpec>
52                 <endpoint>i2losa1</endpoint>
53                 <endpoint>i2seat1</endpoint>
54                 <bw>1Mbit</bw>
55             </LinkSpec>
56             <LinkSpec>
57                 <endpoint>i2newy1</endpoint>
58                 <endpoint>i2wash1</endpoint>
59                 <bw>1Mbit</bw>
60             </LinkSpec>
61             <LinkSpec>
62                 <endpoint>i2salt1</endpoint>
63                 <endpoint>i2seat1</endpoint>
64                 <bw>1Mbit</bw>
65             </LinkSpec>"""
66       
67 # Taken from bwlimit.py
68 #
69 # See tc_util.c and http://physics.nist.gov/cuu/Units/binary.html. Be
70 # warned that older versions of tc interpret "kbps", "mbps", "mbit",
71 # and "kbit" to mean (in this system) "kibps", "mibps", "mibit", and
72 # "kibit" and that if an older version is installed, all rates will
73 # be off by a small fraction.
74 suffixes = {
75     "":         1,
76     "bit":      1,
77     "kibit":    1024,
78     "kbit":     1000,
79     "mibit":    1024*1024,
80     "mbit":     1000000,
81     "gibit":    1024*1024*1024,
82     "gbit":     1000000000,
83     "tibit":    1024*1024*1024*1024,
84     "tbit":     1000000000000,
85     "bps":      8,
86     "kibps":    8*1024,
87     "kbps":     8000,
88     "mibps":    8*1024*1024,
89     "mbps":     8000000,
90     "gibps":    8*1024*1024*1024,
91     "gbps":     8000000000,
92     "tibps":    8*1024*1024*1024*1024,
93     "tbps":     8000000000000
94 }
95
96
97 def get_tc_rate(s):
98     """
99     Parses an integer or a tc rate string (e.g., 1.5mbit) into bits/second
100     """
101
102     if type(s) == int:
103         return s
104     m = re.match(r"([0-9.]+)(\D*)", s)
105     if m is None:
106         return -1
107     suffix = m.group(2).lower()
108     if suffixes.has_key(suffix):
109         return int(float(m.group(1)) * suffixes[suffix])
110     else:
111         return -1
112
113 def format_tc_rate(rate):
114     """
115     Formats a bits/second rate into a tc rate string
116     """
117
118     if rate >= 1000000000 and (rate % 1000000000) == 0:
119         return "%.0fgbit" % (rate / 1000000000.)
120     elif rate >= 1000000 and (rate % 1000000) == 0:
121         return "%.0fmbit" % (rate / 1000000.)
122     elif rate >= 1000:
123         return "%.0fkbit" % (rate / 1000.)
124     else:
125         return "%.0fbit" % rate
126
127
128 class Node:
129     def __init__(self, node, bps = 1000 * 1000000):
130         self.id = node['node_id']
131         self.hostname = node['hostname']
132         self.shortname = self.hostname.replace('.vini-veritas.net', '')
133         self.site_id = node['site_id']
134         self.ipaddr = socket.gethostbyname(self.hostname)
135         self.bps = bps
136         self.links = set()
137
138     def get_link_id(self, remote):
139         if self.id < remote.id:
140             link = (self.id<<7) + remote.id
141         else:
142             link = (remote.id<<7) + self.id
143         return link
144         
145     def get_iface_id(self, remote):
146         if self.id < remote.id:
147             iface = 1
148         else:
149             iface = 2
150         return iface
151     
152     def get_virt_ip(self, remote):
153         link = self.get_link_id(remote)
154         iface = self.get_iface_id(remote)
155         first = link >> 6
156         second = ((link & 0x3f)<<2) + iface
157         return "192.168.%d.%d" % (first, second)
158
159     def get_virt_net(self, remote):
160         link = self.get_link_id(remote)
161         first = link >> 6
162         second = (link & 0x3f)<<2
163         return "192.168.%d.%d/30" % (first, second)
164         
165     def get_site(self, sites):
166         return sites[self.site_id]
167     
168     def get_topo_rspec(self, link):
169         if link.end1 == self:
170             remote = link.end2
171         elif link.end2 == self:
172             remote = link.end1
173         else:
174             raise Error("Link does not connect to Node")
175             
176         my_ip = self.get_virt_ip(remote)
177         remote_ip = remote.get_virt_ip(self)
178         net = self.get_virt_net(remote)
179         bw = format_tc_rate(link.bps)
180         return (remote.id, remote.ipaddr, bw, my_ip, remote_ip, net)
181         
182     def add_link(self, link):
183         self.links.add(link)
184         
185     def add_tag(self, sites):
186         s = self.get_site(sites)
187         words = self.hostname.split(".")
188         index = words[0].replace("node", "")
189         if index.isdigit():
190             self.tag = s.tag + index
191         else:
192             self.tag = None
193
194     # Assumes there is at most one Link between two sites
195     def get_sitelink(self, node, sites):
196         site1 = sites[self.site_id]
197         site2 = sites[node.site_id]
198         sl = site1.links.intersection(site2.links)
199         if len(sl):
200             return sl.pop()
201         return None
202     
203
204 class Link:
205     def __init__(self, end1, end2, bps = 1000 * 1000000):
206         self.end1 = end1
207         self.end2 = end2
208         self.bps = bps
209         
210         end1.add_link(self)
211         end2.add_link(self)
212         
213         
214 class Site:
215     def __init__(self, site):
216         self.id = site['site_id']
217         self.node_ids = site['node_ids']
218         self.name = site['abbreviated_name'].replace(" ", "_")
219         self.tag = site['login_base']
220         self.public = site['is_public']
221         self.links = set()
222
223     def get_sitenodes(self, nodes):
224         n = []
225         for i in self.node_ids:
226             n.append(nodes[i])
227         return n
228     
229     def add_link(self, link):
230         self.links.add(link)
231     
232     
233 class Slice:
234     def __init__(self, slice):
235         self.id = slice['slice_id']
236         self.name = slice['name']
237         self.node_ids = set(slice['node_ids'])
238         self.slice_tag_ids = slice['slice_tag_ids']
239     
240     def get_tag(self, tagname, slicetags, node = None):
241         for i in self.slice_tag_ids:
242             tag = slicetags[i]
243             if tag.tagname == tagname:
244                 if (not node) or (node.id == tag.node_id):
245                     return tag
246         else:
247             return None
248         
249     def get_nodes(self, nodes):
250         n = []
251         for id in self.node_ids:
252             n.append(nodes[id])
253         return n
254              
255     
256     # Add a new slice tag   
257     def add_tag(self, tagname, value, slicetags, node = None):
258         record = {'slice_tag_id':None, 'slice_id':self.id, 'tagname':tagname, 'value':value}
259         if node:
260             record['node_id'] = node.id
261         else:
262             record['node_id'] = None
263         tag = Slicetag(record)
264         slicetags[tag.id] = tag
265         self.slice_tag_ids.append(tag.id)
266         tag.changed = True       
267         tag.updated = True
268         return tag
269     
270     # Update a slice tag if it exists, else add it             
271     def update_tag(self, tagname, value, slicetags, node = None):
272         tag = self.get_tag(tagname, slicetags, node)
273         if tag and tag.value == value:
274             value = "no change"
275         elif tag:
276             tag.value = value
277             tag.changed = True
278         else:
279             tag = self.add_tag(tagname, value, slicetags, node)
280         tag.updated = True
281             
282     def assign_egre_key(self, slicetags):
283         if not self.get_tag('egre_key', slicetags):
284             try:
285                 key = free_egre_key(slicetags)
286                 self.update_tag('egre_key', key, slicetags)
287             except:
288                 # Should handle this case...
289                 pass
290         return
291             
292     def turn_on_netns(self, slicetags):
293         tag = self.get_tag('netns', slicetags)
294         if (not tag) or (tag.value != '1'):
295             self.update_tag('netns', '1', slicetags)
296         return
297    
298     def turn_off_netns(self, slicetags):
299         tag = self.get_tag('netns', slicetags)
300         if tag and (tag.value != '0'):
301             tag.delete()
302         return
303     
304     def add_cap_net_admin(self, slicetags):
305         tag = self.get_tag('capabilities', slicetags)
306         if tag:
307             caps = tag.value.split(',')
308             for cap in caps:
309                 if cap == "CAP_NET_ADMIN":
310                     return
311             else:
312                 newcaps = "CAP_NET_ADMIN," + tag.value
313                 self.update_tag('capabilities', newcaps, slicetags)
314         else:
315             self.add_tag('capabilities', 'CAP_NET_ADMIN', slicetags)
316         return
317     
318     def remove_cap_net_admin(self, slicetags):
319         tag = self.get_tag('capabilities', slicetags)
320         if tag:
321             caps = tag.value.split(',')
322             newcaps = []
323             for cap in caps:
324                 if cap != "CAP_NET_ADMIN":
325                     newcaps.append(cap)
326             if newcaps:
327                 value = ','.join(newcaps)
328                 self.update_tag('capabilities', value, slicetags)
329             else:
330                 tag.delete()
331         return
332
333     # Update the vsys/setup-link and vsys/setup-nat slice tags.
334     def add_vsys_tags(self, slicetags):
335         link = nat = False
336         for i in self.slice_tag_ids:
337             tag = slicetags[i]
338             if tag.tagname == 'vsys':
339                 if tag.value == 'setup-link':
340                     link = True
341                 elif tag.value == 'setup-nat':
342                     nat = True
343         if not link:
344             self.add_tag('vsys', 'setup-link', slicetags)
345         if not nat:
346             self.add_tag('vsys', 'setup-nat', slicetags)
347         return
348
349
350 class Slicetag:
351     newid = -1 
352     def __init__(self, tag):
353         self.id = tag['slice_tag_id']
354         if not self.id:
355             # Make one up for the time being...
356             self.id = Slicetag.newid
357             Slicetag.newid -= 1
358         self.slice_id = tag['slice_id']
359         self.tagname = tag['tagname']
360         self.value = tag['value']
361         self.node_id = tag['node_id']
362         self.updated = False
363         self.changed = False
364         self.deleted = False
365     
366     # Mark a tag as deleted
367     def delete(self):
368         self.deleted = True
369         self.updated = True
370     
371     def write(self, api):
372         if self.changed:
373             if int(self.id) > 0:
374                 api.plshell.UpdateSliceTag(api.plauth, self.id, self.value)
375             else:
376                 api.plshell.AddSliceTag(api.plauth, self.slice_id, 
377                                         self.tagname, self.value, self.node_id)
378         elif self.deleted and int(self.id) > 0:
379             api.plshell.DeleteSliceTag(api.plauth, self.id)
380
381
382 """
383 A topology is a compound object consisting of:
384 * a dictionary mapping site IDs to Site objects
385 * a dictionary mapping node IDs to Node objects
386 * the Site objects are connected via SiteLink objects representing
387   the physical topology and available bandwidth
388 * the Node objects are connected via Link objects representing
389   the requested or assigned virtual topology of a slice
390 """
391 class Topology:
392     def __init__(self, api):
393         self.api = api
394         self.sites = get_sites(api)
395         self.nodes = get_nodes(api)
396         self.tags = get_slice_tags(api)
397         self.sitelinks = []
398         self.nodelinks = []
399     
400         for (s1, s2) in PhysicalLinks:
401             self.sitelinks.append(Link(self.sites[s1], self.sites[s2]))
402         
403         for id in self.nodes:
404             self.nodes[id].add_tag(self.sites)
405         
406         for t in self.tags:
407             tag = self.tags[t]
408             if tag.tagname == 'topo_rspec':
409                 node1 = self.nodes[tag.node_id]
410                 l = eval(tag.value)
411                 for (id, realip, bw, lvip, rvip, vnet) in l:
412                     allocbps = get_tc_rate(bw)
413                     node1.bps -= allocbps
414                     try:
415                         node2 = self.nodes[id]
416                         if node1.id < node2.id:
417                             sl = node1.get_sitelink(node2, self.sites)
418                             sl.bps -= allocbps
419                     except:
420                         pass
421
422     
423     def lookupSite(self, id):
424         val = None
425         try:
426             val = self.sites[id]
427         except:
428             raise KeyError("site ID %s not found" % id)
429         return val
430     
431     def getSites(self):
432         sites = []
433         for s in self.sites:
434             sites.append(self.sites[s])
435         return sites
436         
437     def lookupNode(self, id):
438         val = None
439         try:
440             val = self.nodes[id]
441         except:
442             raise KeyError("node ID %s not found" % id)
443         return val
444     
445     def getNodes(self):
446         nodes = []
447         for n in self.nodes:
448             nodes.append(self.nodes[n])
449         return nodes
450     
451     def nodesInTopo(self):
452         nodes = []
453         for n in self.nodes:
454             if self.nodes[n].links:
455                 nodes.append(self.nodes[n])
456         return nodes
457             
458     def lookupSliceTag(self, id):
459         val = None
460         try:
461             val = self.tags[id]
462         except:
463             raise KeyError("slicetag ID %s not found" % id)
464         return val
465     
466     def getSliceTags(self):
467         tags = []
468         for t in self.tags:
469             tags.append(self.tags[t])
470         return tags
471     
472     def nodeTopoFromRspec(self, rspec):
473         if self.nodelinks:
474             raise Error("virtual topology already present")
475             
476         rspecdict = rspec.toDict()
477         nodedict = {}
478         for node in self.getNodes():
479             nodedict[node.tag] = node
480             
481         linkspecs = rspecdict['Rspec']['Request'][0]['NetSpec'][0]['LinkSpec']    
482         for l in linkspecs:
483             n1 = nodedict[l['endpoint'][0]]
484             n2 = nodedict[l['endpoint'][1]]
485             bps = get_tc_rate(l['bw'][0])
486             self.nodelinks.append(Link(n1, n2, bps))
487  
488     def nodeTopoFromSliceTags(self, slice):
489         if self.nodelinks:
490             raise Error("virtual topology already present")
491             
492         for node in slice.get_nodes(self.nodes):
493             linktag = slice.get_tag('topo_rspec', self.tags, node)
494             if linktag:
495                 l = eval(linktag.value)
496                 for (id, realip, bw, lvip, rvip, vnet) in l:
497                     if node.id < id:
498                         bps = get_tc_rate(bw)
499                         remote = self.lookupNode(id)
500                         self.nodelinks.append(Link(node, remote, bps))
501
502     def updateSliceTags(self, slice):
503         if not self.nodelinks:
504             return
505  
506         slice.update_tag('vini_topo', 'manual', self.tags)
507         slice.assign_egre_key(self.tags)
508         slice.turn_on_netns(self.tags)
509         slice.add_cap_net_admin(self.tags)
510
511         for node in slice.get_nodes(self.nodes):
512             linkdesc = []
513             for link in node.links:
514                 linkdesc.append(node.get_topo_rspec(link))
515             if linkdesc:
516                 topo_str = "%s" % linkdesc
517                 slice.update_tag('topo_rspec', topo_str, self.tags, node)
518
519         # Update slice tags in database
520         for tag in self.getSliceTags():
521             if tag.slice_id == slice.id:
522                 if tag.tagname == 'topo_rspec' and not tag.updated:
523                     tag.delete()
524                 tag.write(self.api)
525                 
526     def toxml(self, hrn = None):
527         xml = """<?xml version="1.0"?>
528 <Rspec xmlns="http://www.planet-lab.org/sfa/rspec/" name="vini">
529     <Capacity>
530         <NetSpec name="physical_topology">"""
531
532         for site in self.getSites():
533             if not site.public:
534                 continue
535             
536             xml += """
537             <SiteSpec name="%s"> """ % site.name
538
539             for node in site.get_sitenodes(self.nodes):
540                 if not node.tag:
541                     continue
542                 
543                 xml += """
544                 <NodeSpec name="%s">
545                     <hostname>%s</hostname>
546                     <bw>%s</bw>
547                 </NodeSpec>""" % (node.tag, node.hostname, format_tc_rate(node.bps))
548             xml += """
549             </SiteSpec>"""
550             
551         for link in self.sitelinks:
552             xml += """
553             <SiteLinkSpec>
554                 <endpoint>%s</endpoint>
555                 <endpoint>%s</endpoint> 
556                 <bw>%s</bw>
557             </SiteLinkSpec>""" % (link.end1.name, link.end2.name, format_tc_rate(link.bps))
558             
559         
560         if hrn:
561             name = hrn
562         else:
563             name = 'default_topology'
564         xml += """
565         </NetSpec>
566     </Capacity>
567     <Request>
568         <NetSpec name="%s">""" % name
569         
570         if hrn:
571             for link in self.nodelinks:
572                 xml += """
573             <LinkSpec>
574                 <endpoint>%s</endpoint>
575                 <endpoint>%s</endpoint> 
576                 <bw>%s</bw>
577             </LinkSpec>""" % (link.end1.tag, link.end2.tag, format_tc_rate(link.bps))
578         else:
579             xml += default_topo_xml
580             
581         xml += """
582         </NetSpec>
583     </Request>
584 </Rspec>"""
585
586         # Remove all leading whitespace and newlines
587         lines = xml.split("\n")
588         noblanks = ""
589         for line in lines:
590             noblanks += line.strip()
591         return noblanks
592
593
594 """
595 Create a dictionary of site objects keyed by site ID
596 """
597 def get_sites(api):
598     tmp = []
599     for site in api.plshell.GetSites(api.plauth):
600         t = site['site_id'], Site(site)
601         tmp.append(t)
602     return dict(tmp)
603
604
605 """
606 Create a dictionary of node objects keyed by node ID
607 """
608 def get_nodes(api):
609     tmp = []
610     for node in api.plshell.GetNodes(api.plauth):
611         t = node['node_id'], Node(node)
612         tmp.append(t)
613     return dict(tmp)
614
615 """
616 Create a dictionary of slice objects keyed by slice ID
617 """
618 def get_slice(api, slicename):
619     slice = api.plshell.GetSlices(api.plauth, [slicename])
620     if slice:
621         return Slice(slice[0])
622     else:
623         return None
624
625 """
626 Create a dictionary of slicetag objects keyed by slice tag ID
627 """
628 def get_slice_tags(api):
629     tmp = []
630     for tag in api.plshell.GetSliceTags(api.plauth):
631         t = tag['slice_tag_id'], Slicetag(tag)
632         tmp.append(t)
633     return dict(tmp)
634     
635 """
636 Find a free EGRE key
637 """
638 def free_egre_key(slicetags):
639     used = set()
640     for i in slicetags:
641         tag = slicetags[i]
642         if tag.tagname == 'egre_key':
643             used.add(int(tag.value))
644                 
645     for i in range(1, 256):
646         if i not in used:
647             key = i
648             break
649     else:
650         raise KeyError("No more EGRE keys available")
651         
652     return "%s" % key
653