Generalizing object model used by VINI's Aggregate Manager for PLC
[sfa.git] / sfa / plc / network.py
1 from __future__ import with_statement
2 import re
3 import socket
4 from sfa.util.faults import *
5 from xmlbuilder import XMLBuilder
6 #from lxml import etree
7 import sys
8 from StringIO import StringIO
9
10 # Taken from bwlimit.py
11 #
12 # See tc_util.c and http://physics.nist.gov/cuu/Units/binary.html. Be
13 # warned that older versions of tc interpret "kbps", "mbps", "mbit",
14 # and "kbit" to mean (in this system) "kibps", "mibps", "mibit", and
15 # "kibit" and that if an older version is installed, all rates will
16 # be off by a small fraction.
17 suffixes = {
18     "":         1,
19     "bit":      1,
20     "kibit":    1024,
21     "kbit":     1000,
22     "mibit":    1024*1024,
23     "mbit":     1000000,
24     "gibit":    1024*1024*1024,
25     "gbit":     1000000000,
26     "tibit":    1024*1024*1024*1024,
27     "tbit":     1000000000000,
28     "bps":      8,
29     "kibps":    8*1024,
30     "kbps":     8000,
31     "mibps":    8*1024*1024,
32     "mbps":     8000000,
33     "gibps":    8*1024*1024*1024,
34     "gbps":     8000000000,
35     "tibps":    8*1024*1024*1024*1024,
36     "tbps":     8000000000000
37 }
38
39
40 def get_tc_rate(s):
41     """
42     Parses an integer or a tc rate string (e.g., 1.5mbit) into bits/second
43     """
44
45     if type(s) == int:
46         return s
47     m = re.match(r"([0-9.]+)(\D*)", s)
48     if m is None:
49         return -1
50     suffix = m.group(2).lower()
51     if suffixes.has_key(suffix):
52         return int(float(m.group(1)) * suffixes[suffix])
53     else:
54         return -1
55
56 def format_tc_rate(rate):
57     """
58     Formats a bits/second rate into a tc rate string
59     """
60
61     if rate >= 1000000000 and (rate % 1000000000) == 0:
62         return "%.0fgbit" % (rate / 1000000000.)
63     elif rate >= 1000000 and (rate % 1000000) == 0:
64         return "%.0fmbit" % (rate / 1000000.)
65     elif rate >= 1000:
66         return "%.0fkbit" % (rate / 1000.)
67     else:
68         return "%.0fbit" % rate 
69
70
71 class Iface:
72     def __init__(self, network, iface):
73         self.network = network
74         self.id = iface['interface_id']
75         self.idtag = "i%s" % self.id
76         self.ipv4 = iface['ip']
77         self.bwlimit = iface['bwlimit']
78         self.hostname = iface['hostname']
79
80     def toxml(self, xml):
81         with xml.interface(id = self.idtag):
82             if self.hostname:
83                 with xml.hostname:
84                     xml << self.hostname
85             with xml.ipv4:
86                 xml << self.ipv4
87             if self.bwlimit:
88                 with xml.bwlimit:
89                     xml << format_tc_rate(self.bwlimit)
90
91 class Node:
92     def __init__(self, network, node, bps = 1000 * 1000000):
93         self.network = network
94         self.id = node['node_id']
95         self.idtag = "n%s" % self.id
96         self.hostname = node['hostname']
97         self.name = self.shortname = self.hostname.replace('.vini-veritas.net', '')
98         self.site_id = node['site_id']
99         #self.ipaddr = socket.gethostbyname(self.hostname)
100         self.bps = bps
101         self.links = set()
102         self.iface_ids = node['interface_ids']
103         self.iface_ids.sort()
104         self.sliver = False
105
106     def get_link_id(self, remote):
107         if self.id < remote.id:
108             link = (self.id<<7) + remote.id
109         else:
110             link = (remote.id<<7) + self.id
111         return link
112         
113     def get_iface_id(self, remote):
114         if self.id < remote.id:
115             iface = 1
116         else:
117             iface = 2
118         return iface
119     
120     def get_ifaces(self):
121         i = []
122         for id in self.iface_ids:
123             i.append(self.network.lookupIface(id))
124         return i
125         
126     def get_virt_ip(self, remote):
127         link = self.get_link_id(remote)
128         iface = self.get_iface_id(remote)
129         first = link >> 6
130         second = ((link & 0x3f)<<2) + iface
131         return "192.168.%d.%d" % (first, second)
132
133     def get_virt_net(self, remote):
134         link = self.get_link_id(remote)
135         first = link >> 6
136         second = (link & 0x3f)<<2
137         return "192.168.%d.%d/30" % (first, second)
138         
139     def get_site(self):
140         return self.network.lookupSite(self.site_id)
141     
142     def get_topo_rspec(self, link):
143         if link.end1 == self:
144             remote = link.end2
145         elif link.end2 == self:
146             remote = link.end1
147         else:
148             raise Error("Link does not connect to Node")
149             
150         my_ip = self.get_virt_ip(remote)
151         remote_ip = remote.get_virt_ip(self)
152         net = self.get_virt_net(remote)
153         bw = format_tc_rate(link.bps)
154         return (remote.id, remote.ipaddr, bw, my_ip, remote_ip, net)
155         
156     def add_link(self, link):
157         self.links.add(link)
158         
159     # Assumes there is at most one Link between two sites
160     def get_sitelink(self, node, sites):
161         site1 = sites[self.site_id]
162         site2 = sites[node.site_id]
163         sl = site1.links.intersection(site2.links)
164         if len(sl):
165             return sl.pop()
166         return None
167
168     def add_sliver(self):
169         self.sliver = True
170
171     def toxml(self, xml, hrn):
172         with xml.node(id = self.idtag):
173             with xml.hostname:
174                 xml << self.hostname
175             if self.network.type == "VINI":
176                 with xml.kbps:
177                     xml << str(int(self.bps/1000))
178             for iface in self.get_ifaces():
179                 iface.toxml(xml)
180             if self.sliver:
181                 with xml.sliver:
182                     pass
183     
184
185 class Link:
186     def __init__(self, end1, end2, bps = 1000000000, parent = None):
187         self.end1 = end1
188         self.end2 = end2
189         self.bps = bps
190         self.parent = parent
191         self.children = []
192
193         end1.add_link(self)
194         end2.add_link(self)
195         
196         if self.parent:
197             self.parent.children.append(self)
198             
199     def toxml(self, xml):
200         end_ids = "%s %s" % (self.end1.idtag, self.end2.idtag)
201
202         if self.parent:
203             element = xml.vlink(endpoints=end_ids)
204         else:
205             element = xml.link(endpoints=end_ids)
206
207         with element:
208             with xml.description:
209                 xml << "%s -- %s" % (self.end1.name, self.end2.name)
210             with xml.kbps:
211                 xml << str(int(self.bps/1000))
212             for child in self.children:
213                 child.toxml(xml)
214         
215
216 class Site:
217     def __init__(self, network, site):
218         self.network = network
219         self.id = site['site_id']
220         self.idtag = "s%s" % self.id
221         self.node_ids = site['node_ids']
222         self.node_ids.sort()
223         self.name = site['abbreviated_name']
224         self.tag = site['login_base']
225         self.public = site['is_public']
226         self.enabled = site['enabled']
227         self.links = set()
228
229     def get_sitenodes(self):
230         n = []
231         for i in self.node_ids:
232             n.append(self.network.lookupNode(i))
233         return n
234     
235     def add_link(self, link):
236         self.links.add(link)
237
238     def toxml(self, xml, hrn, nodes):
239         if not (self.public and self.enabled and self.node_ids):
240             return
241         with xml.site(id = self.idtag):
242             with xml.name:
243                 xml << self.name
244                 
245             for node in self.get_sitenodes():
246                 node.toxml(xml, hrn)
247    
248     
249 class Slice:
250     def __init__(self, network, slice):
251         self.network = network
252         self.id = slice['slice_id']
253         self.name = slice['name']
254         self.node_ids = set(slice['node_ids'])
255         self.slice_tag_ids = slice['slice_tag_ids']
256     
257     def get_tag(self, tagname, node = None):
258         for i in self.slice_tag_ids:
259             tag = self.network.lookupSliceTag(i)
260             if tag.tagname == tagname:
261                 if (not node) or (node.id == tag.node_id):
262                     return tag
263         else:
264             return None
265         
266     def get_nodes(self, nodes):
267         n = []
268         for id in self.node_ids:
269             n.append(nodes[id])
270         return n
271   
272     # Add a new slice tag   
273     def add_tag(self, tagname, value, node = None):
274         record = {'slice_tag_id':None, 'slice_id':self.id, 'tagname':tagname, 'value':value}
275         if node:
276             record['node_id'] = node.id
277         else:
278             record['node_id'] = None
279         tag = Slicetag(record)
280         self.network.slicetags[tag.id] = tag
281         self.slice_tag_ids.append(tag.id)
282         tag.changed = True       
283         tag.updated = True
284         return tag
285     
286     # Update a slice tag if it exists, else add it             
287     def update_tag(self, tagname, value, node = None):
288         tag = self.get_tag(tagname, node)
289         if tag and tag.value == value:
290             value = "no change"
291         elif tag:
292             tag.value = value
293             tag.changed = True
294         else:
295             tag = self.add_tag(tagname, value, node)
296         tag.updated = True
297             
298     """
299     Find a free EGRE key
300     """
301     def new_egre_key():
302         slicetags = self.network.slicetags
303         used = set()
304         for i in slicetags:
305             tag = slicetags[i]
306             if tag.tagname == 'egre_key':
307                 used.add(int(tag.value))
308                 
309         for i in range(1, 256):
310             if i not in used:
311                 key = i
312                 break
313         else:
314             raise KeyError("No more EGRE keys available")
315         
316         return "%s" % key
317    
318
319     def assign_egre_key(self):
320         if not self.get_tag('egre_key'):
321             try:
322                 key = self.new_egre_key()
323                 self.update_tag('egre_key', key)
324             except:
325                 # Should handle this case...
326                 pass
327         return
328             
329     def turn_on_netns(self):
330         tag = self.get_tag('netns')
331         if (not tag) or (tag.value != '1'):
332             self.update_tag('netns', '1')
333         return
334    
335     def turn_off_netns(self):
336         tag = self.get_tag('netns')
337         if tag and (tag.value != '0'):
338             tag.delete()
339         return
340     
341     def add_cap_net_admin(self):
342         tag = self.get_tag('capabilities')
343         if tag:
344             caps = tag.value.split(',')
345             for cap in caps:
346                 if cap == "CAP_NET_ADMIN":
347                     return
348             else:
349                 newcaps = "CAP_NET_ADMIN," + tag.value
350                 self.update_tag('capabilities', newcaps)
351         else:
352             self.add_tag('capabilities', 'CAP_NET_ADMIN')
353         return
354     
355     def remove_cap_net_admin(self):
356         tag = self.get_tag('capabilities')
357         if tag:
358             caps = tag.value.split(',')
359             newcaps = []
360             for cap in caps:
361                 if cap != "CAP_NET_ADMIN":
362                     newcaps.append(cap)
363             if newcaps:
364                 value = ','.join(newcaps)
365                 self.update_tag('capabilities', value)
366             else:
367                 tag.delete()
368         return
369
370     # Update the vsys/setup-link and vsys/setup-nat slice tags.
371     def add_vsys_tags(self):
372         link = nat = False
373         for i in self.slice_tag_ids:
374             tag = self.network.lookupSliceTag(i)
375             if tag.tagname == 'vsys':
376                 if tag.value == 'setup-link':
377                     link = True
378                 elif tag.value == 'setup-nat':
379                     nat = True
380         if not link:
381             self.add_tag('vsys', 'setup-link')
382         if not nat:
383             self.add_tag('vsys', 'setup-nat')
384         return
385
386
387 class Slicetag:
388     newid = -1 
389     def __init__(self, tag):
390         self.id = tag['slice_tag_id']
391         if not self.id:
392             # Make one up for the time being...
393             self.id = Slicetag.newid
394             Slicetag.newid -= 1
395         self.slice_id = tag['slice_id']
396         self.tagname = tag['tagname']
397         self.value = tag['value']
398         self.node_id = tag['node_id']
399         self.updated = False
400         self.changed = False
401         self.deleted = False
402     
403     # Mark a tag as deleted
404     def delete(self):
405         self.deleted = True
406         self.updated = True
407     
408     def write(self, api):
409         if self.changed:
410             if int(self.id) > 0:
411                 api.plshell.UpdateSliceTag(api.plauth, self.id, self.value)
412             else:
413                 api.plshell.AddSliceTag(api.plauth, self.slice_id, 
414                                         self.tagname, self.value, self.node_id)
415         elif self.deleted and int(self.id) > 0:
416             api.plshell.DeleteSliceTag(api.plauth, self.id)
417
418
419 """
420 A Network is a compound object consisting of:
421 * a dictionary mapping site IDs to Site objects
422 * a dictionary mapping node IDs to Node objects
423 * a dictionary mapping interface IDs to Iface objects
424 * the Site objects are connected via Link objects representing
425   the physical topology and available bandwidth
426 * the Node objects are connected via Link objects representing
427   the requested or assigned virtual topology of a slice
428 """
429 class Network:
430     def __init__(self, api, type = "PlanetLab", physical_links = [], 
431                  schema = None):
432         self.api = api
433         self.type = type
434         self.sites = self.get_sites(api)
435         self.nodes = self.get_nodes(api)
436         self.ifaces = self.get_ifaces(api)
437         self.tags = self.get_slice_tags(api)
438         self.slice = None
439         self.sitelinks = []
440         self.nodelinks = []
441         self.schema = schema
442     
443         for (s1, s2) in physical_links:
444             self.sitelinks.append(Link(self.sites[s1], self.sites[s2]))
445         
446         for t in self.tags:
447             tag = self.tags[t]
448             if tag.tagname == 'topo_rspec':
449                 node1 = self.nodes[tag.node_id]
450                 l = eval(tag.value)
451                 for (id, realip, bw, lvip, rvip, vnet) in l:
452                     allocbps = get_tc_rate(bw)
453                     node1.bps -= allocbps
454                     try:
455                         node2 = self.nodes[id]
456                         if node1.id < node2.id:
457                             sl = node1.get_sitelink(node2, self.sites)
458                             sl.bps -= allocbps
459                     except:
460                         pass
461
462     
463     """ Lookup site based on id or idtag value """
464     def lookupSite(self, id):
465         val = None
466         if isinstance(id, basestring):
467             id = int(id.lstrip('s'))
468         try:
469             val = self.sites[id]
470         except:
471             raise KeyError("site ID %s not found" % id)
472         return val
473     
474     def getSites(self):
475         sites = []
476         for s in self.sites:
477             sites.append(self.sites[s])
478         return sites
479         
480     """ Lookup node based on id or idtag value """
481     def lookupNode(self, id):
482         val = None
483         if isinstance(id, basestring):
484             id = int(id.lstrip('n'))
485         try:
486             val = self.nodes[id]
487         except:
488             raise KeyError("node ID %s not found" % id)
489         return val
490     
491     def getNodes(self):
492         nodes = []
493         for n in self.nodes:
494             nodes.append(self.nodes[n])
495         return nodes
496     
497     """ Lookup iface based on id or idtag value """
498     def lookupIface(self, id):
499         val = None
500         if isinstance(id, basestring):
501             id = int(id.lstrip('i'))
502         try:
503             val = self.ifaces[id]
504         except:
505             raise KeyError("interface ID %s not found" % id)
506         return val
507     
508     def getIfaces(self):
509         ifaces = []
510         for i in self.ifaces:
511             ifaces.append(self.ifaces[i])
512         return ifaces
513     
514     def nodesInNetwork(self):
515         nodes = []
516         for n in self.nodes:
517             node = self.nodes[n]
518             if node.sliver:
519                 nodes.append(node)
520         return nodes
521             
522     def lookupSliceTag(self, id):
523         val = None
524         try:
525             val = self.tags[id]
526         except:
527             raise KeyError("slicetag ID %s not found" % id)
528         return val
529     
530     def getSliceTags(self):
531         tags = []
532         for t in self.tags:
533             tags.append(self.tags[t])
534         return tags
535     
536     def lookupSiteLink(self, node1, node2):
537         site1 = self.sites[node1.site_id]
538         site2 = self.sites[node2.site_id]
539         for link in self.sitelinks:
540             if site1 == link.end1 and site2 == link.end2:
541                 return link
542             if site2 == link.end1 and site1 == link.end2:
543                 return link
544         return None
545     
546
547     def __add_vlink(self, vlink, slicenodes, parent = None):
548         n1 = n2 = None
549         endpoints = vlink.get("endpoints")
550         if endpoints:
551             (end1, end2) = endpoints.split()
552             n1 = self.lookupNode(end1)
553             n2 = self.lookupNode(end2)
554         elif parent:
555             """ Try to infer the endpoints for the virtual link """
556             site_endpoints = parent.get("endpoints")
557             (n1, n2) = self.__infer_endpoints(site_endpoints, slicenodes)
558         else:
559             raise Error("no endpoints given")
560
561         #print "Added virtual link: %s -- %s" % (n1.tag, n2.tag)
562         bps = int(vlink.findtext("kbps")) * 1000
563         sitelink = self.lookupSiteLink(n1, n2)
564         if not sitelink:
565             raise PermissionError("nodes %s and %s not adjacent" % 
566                                   (n1.idtag, n2.idtag))
567         self.nodelinks.append(Link(n1, n2, bps, sitelink))
568         return
569
570     """ 
571     Infer the endpoints of the virtual link.  If the slice exists on 
572     only a single node at each end of the physical link, we'll assume that
573     the user wants the virtual link to terminate at these nodes.
574     """
575     def __infer_endpoints(self, endpoints, slicenodes):
576         n = []
577         ends = endpoints.split()
578         for end in ends:
579             found = 0
580             site = self.lookupSite(end)
581             for id in site.node_ids:
582                 if id in slicenodes:
583                     n.append(slicenodes[id])
584                     found += 1
585             if found != 1:
586                 raise Error("could not infer endpoint for site %s" % site.id)
587         #print "Inferred endpoints: %s %s" % (n[0].idtag, n[1].idtag)
588         return n
589         
590     def annotateFromRSpec(self, xml):
591         if self.nodelinks:
592             raise Error("virtual topology already present")
593             
594         nodedict = {}
595         for node in self.getNodes():
596             nodedict[node.idtag] = node
597             
598         slicenodes = {}
599
600         tree = etree.parse(StringIO(xml))
601
602         if self.schema:
603             # Validate the incoming request against the RelaxNG schema
604             relaxng_doc = etree.parse(self.schema)
605             relaxng = etree.RelaxNG(relaxng_doc)
606         
607             if not relaxng(tree):
608                 error = relaxng.error_log.last_error
609                 message = "%s (line %s)" % (error.message, error.line)
610                 raise InvalidRSpec(message)
611
612         rspec = tree.getroot()
613
614         """
615         Handle requests where the user has annotated a description of the
616         physical resources (nodes and links) with virtual ones (slivers
617         and vlinks).
618         """
619         # Find slivers under node elements
620         for sliver in rspec.iterfind("./network/site/node/sliver"):
621             elem = sliver.getparent()
622             node = nodedict[elem.get("id")]
623             slicenodes[node.id] = node
624             node.add_sliver()
625
626         # Find vlinks under link elements
627         for vlink in rspec.iterfind("./network/link/vlink"):
628             link = vlink.getparent()
629             self.__add_vlink(vlink, slicenodes, link)
630
631         """
632         Handle requests where the user has listed the virtual resources only
633         """
634         # Find slivers that specify nodeid
635         for sliver in rspec.iterfind("./request/sliver[@nodeid]"):
636             node = nodedict[sliver.get("nodeid")]
637             slicenodes[node.id] = node
638             node.add_sliver()
639
640         # Find vlinks that specify endpoints
641         for vlink in rspec.iterfind("./request/vlink[@endpoints]"):
642             self.__add_vlink(vlink, slicenodes)
643
644         return
645
646     def annotateFromSliceTags(self, slice):
647         if self.nodelinks:
648             raise Error("virtual topology already present")
649             
650         for node in slice.get_nodes(self.nodes):
651             node.sliver = True
652             linktag = slice.get_tag('topo_rspec', self.tags, node)
653             if linktag:
654                 l = eval(linktag.value)
655                 for (id, realip, bw, lvip, rvip, vnet) in l:
656                     if node.id < id:
657                         bps = get_tc_rate(bw)
658                         remote = self.lookupNode(id)
659                         sitelink = self.lookupSiteLink(node, remote)
660                         self.nodelinks.append(Link(node,remote,bps,sitelink))
661
662     def updateSliceTags(self, slice):
663         if not self.nodelinks:
664             return
665  
666         slice.update_tag('vini_topo', 'manual', self.tags)
667         slice.assign_egre_key(self.tags)
668         slice.turn_on_netns(self.tags)
669         slice.add_cap_net_admin(self.tags)
670
671         for node in slice.get_nodes(self.nodes):
672             linkdesc = []
673             for link in node.links:
674                 linkdesc.append(node.get_topo_rspec(link))
675             if linkdesc:
676                 topo_str = "%s" % linkdesc
677                 slice.update_tag('topo_rspec', topo_str, self.tags, node)
678
679         # Update slice tags in database
680         for tag in self.getSliceTags():
681             if tag.slice_id == slice.id:
682                 if tag.tagname == 'topo_rspec' and not tag.updated:
683                     tag.delete()
684                 tag.write(self.api)
685                 
686     """
687     Check the requested topology against the available topology and capacity
688     """
689     def verifyNodeNetwork(self, hrn, topo):
690         for link in self.nodelinks:
691             if link.bps <= 0:
692                 raise GeniInvalidArgument(bw, "BW")
693                 
694             n1 = link.end1
695             n2 = link.end2
696             sitelink = self.lookupSiteLink(n1, n2)
697             if not sitelink:
698                 raise PermissionError("%s: nodes %s and %s not adjacent" % (hrn, n1.tag, n2.tag))
699             if sitelink.bps < link.bps:
700                 raise PermissionError("%s: insufficient capacity between %s and %s" % (hrn, n1.tag, n2.tag))
701                 
702     """
703     Produce XML directly from the topology specification.
704     """
705     def toxml(self, hrn = None):
706         xml = XMLBuilder(format = True, tab_step = "  ")
707         with xml.RSpec(type="VINI"):
708             if hrn:
709                 element = xml.network(name="Public_VINI", slice=hrn)
710             else:
711                 element = xml.network(name="Public_VINI")
712                 
713             with element:
714                 for site in self.getSites():
715                     site.toxml(xml, hrn, self.nodes)
716                 for link in self.sitelinks:
717                     link.toxml(xml)
718
719         header = '<?xml version="1.0"?>\n'
720         return header + str(xml)
721
722     """
723     Create a dictionary of site objects keyed by site ID
724     """
725     def get_sites(self, api):
726         tmp = []
727         for site in api.plshell.GetSites(api.plauth):
728             t = site['site_id'], Site(self, site)
729             tmp.append(t)
730         return dict(tmp)
731
732
733     """
734     Create a dictionary of node objects keyed by node ID
735     """
736     def get_nodes(self, api):
737         tmp = []
738         for node in api.plshell.GetNodes(api.plauth):
739             t = node['node_id'], Node(self, node)
740             tmp.append(t)
741         return dict(tmp)
742
743     """
744     Create a dictionary of node objects keyed by node ID
745     """
746     def get_ifaces(self, api):
747         tmp = []
748         for iface in api.plshell.GetInterfaces(api.plauth):
749             t = iface['interface_id'], Iface(self, iface)
750             tmp.append(t)
751         return dict(tmp)
752
753     """
754     Create a dictionary of slicetag objects keyed by slice tag ID
755     """
756     def get_slice_tags(self, api):
757         tmp = []
758         for tag in api.plshell.GetSliceTags(api.plauth):
759             t = tag['slice_tag_id'], Slicetag(tag)
760             tmp.append(t)
761         return dict(tmp)
762     
763     """
764     Return a Slice object for a single slice
765     """
766     def get_slice(self, api, slicename):
767         slice = api.plshell.GetSlices(api.plauth, [slicename])
768         if slice:
769             return Slice(self, slice[0])
770         else:
771             return None
772     
773