Cleanup, added whitelisting to XML generation (though this should really go in sfatables)
[sfa.git] / sfa / plc / network.py
1 from __future__ import with_statement
2 import re
3 import socket
4 from sfa.util.namespace import *
5 from sfa.util.faults import *
6 from xmlbuilder import XMLBuilder
7 #from lxml import etree
8 import sys
9 from StringIO import StringIO
10
11 # Taken from bwlimit.py
12 #
13 # See tc_util.c and http://physics.nist.gov/cuu/Units/binary.html. Be
14 # warned that older versions of tc interpret "kbps", "mbps", "mbit",
15 # and "kbit" to mean (in this system) "kibps", "mibps", "mibit", and
16 # "kibit" and that if an older version is installed, all rates will
17 # be off by a small fraction.
18 suffixes = {
19     "":         1,
20     "bit":      1,
21     "kibit":    1024,
22     "kbit":     1000,
23     "mibit":    1024*1024,
24     "mbit":     1000000,
25     "gibit":    1024*1024*1024,
26     "gbit":     1000000000,
27     "tibit":    1024*1024*1024*1024,
28     "tbit":     1000000000000,
29     "bps":      8,
30     "kibps":    8*1024,
31     "kbps":     8000,
32     "mibps":    8*1024*1024,
33     "mbps":     8000000,
34     "gibps":    8*1024*1024*1024,
35     "gbps":     8000000000,
36     "tibps":    8*1024*1024*1024*1024,
37     "tbps":     8000000000000
38 }
39
40
41 def get_tc_rate(s):
42     """
43     Parses an integer or a tc rate string (e.g., 1.5mbit) into bits/second
44     """
45
46     if type(s) == int:
47         return s
48     m = re.match(r"([0-9.]+)(\D*)", s)
49     if m is None:
50         return -1
51     suffix = m.group(2).lower()
52     if suffixes.has_key(suffix):
53         return int(float(m.group(1)) * suffixes[suffix])
54     else:
55         return -1
56
57 def format_tc_rate(rate):
58     """
59     Formats a bits/second rate into a tc rate string
60     """
61
62     if rate >= 1000000000 and (rate % 1000000000) == 0:
63         return "%.0fgbit" % (rate / 1000000000.)
64     elif rate >= 1000000 and (rate % 1000000) == 0:
65         return "%.0fmbit" % (rate / 1000000.)
66     elif rate >= 1000:
67         return "%.0fkbit" % (rate / 1000.)
68     else:
69         return "%.0fbit" % rate 
70
71
72 class Iface:
73     def __init__(self, network, iface):
74         self.network = network
75         self.id = iface['interface_id']
76         self.idtag = "i%s" % self.id
77         self.ipv4 = iface['ip']
78         self.bwlimit = iface['bwlimit']
79         self.hostname = iface['hostname']
80
81     """
82     Just print out bwlimit right now
83     """
84     def toxml(self, xml):
85         if self.bwlimit:
86             with xml.bwlimit:
87                 xml << format_tc_rate(self.bwlimit)
88
89
90 class Node:
91     def __init__(self, network, node, bps = 1000 * 1000000):
92         self.network = network
93         self.id = node['node_id']
94         self.idtag = "n%s" % self.id
95         self.hostname = node['hostname']
96         self.name = self.shortname = self.hostname.replace('.vini-veritas.net', '')
97         self.site_id = node['site_id']
98         #self.ipaddr = socket.gethostbyname(self.hostname)
99         self.bps = bps
100         self.links = set()
101         self.iface_ids = node['interface_ids']
102         self.iface_ids.sort()
103         self.sliver = False
104         self.whitelist = node['slice_ids_whitelist']
105
106     def get_link_id(self, remote):
107         if self.id < remote.id:
108             link = (self.id<<7) + remote.id
109         else:
110             link = (remote.id<<7) + self.id
111         return link
112         
113     def get_iface_id(self, remote):
114         if self.id < remote.id:
115             iface = 1
116         else:
117             iface = 2
118         return iface
119     
120     def get_ifaces(self):
121         i = []
122         for id in self.iface_ids:
123             i.append(self.network.lookupIface(id))
124             # Only return the first interface
125             break
126         return i
127         
128     def get_virt_ip(self, remote):
129         link = self.get_link_id(remote)
130         iface = self.get_iface_id(remote)
131         first = link >> 6
132         second = ((link & 0x3f)<<2) + iface
133         return "192.168.%d.%d" % (first, second)
134
135     def get_virt_net(self, remote):
136         link = self.get_link_id(remote)
137         first = link >> 6
138         second = (link & 0x3f)<<2
139         return "192.168.%d.%d/30" % (first, second)
140         
141     def get_site(self):
142         return self.network.lookupSite(self.site_id)
143     
144     def get_topo_rspec(self, link):
145         if link.end1 == self:
146             remote = link.end2
147         elif link.end2 == self:
148             remote = link.end1
149         else:
150             raise Error("Link does not connect to Node")
151             
152         my_ip = self.get_virt_ip(remote)
153         remote_ip = remote.get_virt_ip(self)
154         net = self.get_virt_net(remote)
155         bw = format_tc_rate(link.bps)
156         return (remote.id, remote.ipaddr, bw, my_ip, remote_ip, net)
157         
158     def add_link(self, link):
159         self.links.add(link)
160         
161     # Assumes there is at most one Link between two sites
162     def get_sitelink(self, node, sites):
163         site1 = sites[self.site_id]
164         site2 = sites[node.site_id]
165         sl = site1.links.intersection(site2.links)
166         if len(sl):
167             return sl.pop()
168         return None
169
170     def add_sliver(self):
171         self.sliver = True
172
173     def toxml(self, xml):
174         slice = self.network.slice
175         if self.whitelist and not self.sliver:
176             if not slice or slice.id not in self.whitelist:
177                 return
178
179         with xml.node(id = self.idtag):
180             with xml.hostname:
181                 xml << self.hostname
182             if self.network.type == "VINI":
183                 with xml.kbps:
184                     xml << str(int(self.bps/1000))
185             for iface in self.get_ifaces():
186                 iface.toxml(xml)
187             if self.sliver:
188                 with xml.sliver:
189                     pass
190     
191
192 class Link:
193     def __init__(self, end1, end2, bps = 1000000000, parent = None):
194         self.end1 = end1
195         self.end2 = end2
196         self.bps = bps
197         self.parent = parent
198         self.children = []
199
200         end1.add_link(self)
201         end2.add_link(self)
202         
203         if self.parent:
204             self.parent.children.append(self)
205             
206     def toxml(self, xml):
207         end_ids = "%s %s" % (self.end1.idtag, self.end2.idtag)
208
209         if self.parent:
210             element = xml.vlink(endpoints=end_ids)
211         else:
212             element = xml.link(endpoints=end_ids)
213
214         with element:
215             with xml.description:
216                 xml << "%s -- %s" % (self.end1.name, self.end2.name)
217             with xml.kbps:
218                 xml << str(int(self.bps/1000))
219             for child in self.children:
220                 child.toxml(xml)
221         
222
223 class Site:
224     def __init__(self, network, site):
225         self.network = network
226         self.id = site['site_id']
227         self.idtag = "s%s" % self.id
228         self.node_ids = site['node_ids']
229         self.node_ids.sort()
230         self.name = site['abbreviated_name']
231         self.tag = site['login_base']
232         self.public = site['is_public']
233         self.enabled = site['enabled']
234         self.links = set()
235         self.whitelist = False
236
237     def get_sitenodes(self):
238         n = []
239         for i in self.node_ids:
240             n.append(self.network.lookupNode(i))
241         return n
242     
243     def add_link(self, link):
244         self.links.add(link)
245
246     def toxml(self, xml):
247         if not (self.public and self.enabled and self.node_ids):
248             return
249         with xml.site(id = self.idtag):
250             with xml.name:
251                 xml << self.name
252                 
253             for node in self.get_sitenodes():
254                 node.toxml(xml)
255    
256     
257 class Slice:
258     def __init__(self, network, hrn, slice):
259         self.hrn = hrn
260         self.network = network
261         self.id = slice['slice_id']
262         self.name = slice['name']
263         self.node_ids = set(slice['node_ids'])
264         self.slice_tag_ids = slice['slice_tag_ids']
265     
266     def get_tag(self, tagname, node = None):
267         for i in self.slice_tag_ids:
268             tag = self.network.lookupSliceTag(i)
269             if tag.tagname == tagname:
270                 if (not node) or (node.id == tag.node_id):
271                     return tag
272         else:
273             return None
274         
275     def get_nodes(self, nodes):
276         n = []
277         for id in self.node_ids:
278             n.append(nodes[id])
279         return n
280   
281     # Add a new slice tag   
282     def add_tag(self, tagname, value, node = None):
283         record = {'slice_tag_id':None, 'slice_id':self.id, 'tagname':tagname, 'value':value}
284         if node:
285             record['node_id'] = node.id
286         else:
287             record['node_id'] = None
288         tag = Slicetag(record)
289         self.network.slicetags[tag.id] = tag
290         self.slice_tag_ids.append(tag.id)
291         tag.changed = True       
292         tag.updated = True
293         return tag
294     
295     # Update a slice tag if it exists, else add it             
296     def update_tag(self, tagname, value, node = None):
297         tag = self.get_tag(tagname, node)
298         if tag and tag.value == value:
299             value = "no change"
300         elif tag:
301             tag.value = value
302             tag.changed = True
303         else:
304             tag = self.add_tag(tagname, value, node)
305         tag.updated = True
306             
307     """
308     Find a free EGRE key
309     """
310     def new_egre_key():
311         slicetags = self.network.slicetags
312         used = set()
313         for i in slicetags:
314             tag = slicetags[i]
315             if tag.tagname == 'egre_key':
316                 used.add(int(tag.value))
317                 
318         for i in range(1, 256):
319             if i not in used:
320                 key = i
321                 break
322         else:
323             raise KeyError("No more EGRE keys available")
324         
325         return "%s" % key
326    
327
328     def assign_egre_key(self):
329         if not self.get_tag('egre_key'):
330             try:
331                 key = self.new_egre_key()
332                 self.update_tag('egre_key', key)
333             except:
334                 # Should handle this case...
335                 pass
336         return
337             
338     def turn_on_netns(self):
339         tag = self.get_tag('netns')
340         if (not tag) or (tag.value != '1'):
341             self.update_tag('netns', '1')
342         return
343    
344     def turn_off_netns(self):
345         tag = self.get_tag('netns')
346         if tag and (tag.value != '0'):
347             tag.delete()
348         return
349     
350     def add_cap_net_admin(self):
351         tag = self.get_tag('capabilities')
352         if tag:
353             caps = tag.value.split(',')
354             for cap in caps:
355                 if cap == "CAP_NET_ADMIN":
356                     return
357             else:
358                 newcaps = "CAP_NET_ADMIN," + tag.value
359                 self.update_tag('capabilities', newcaps)
360         else:
361             self.add_tag('capabilities', 'CAP_NET_ADMIN')
362         return
363     
364     def remove_cap_net_admin(self):
365         tag = self.get_tag('capabilities')
366         if tag:
367             caps = tag.value.split(',')
368             newcaps = []
369             for cap in caps:
370                 if cap != "CAP_NET_ADMIN":
371                     newcaps.append(cap)
372             if newcaps:
373                 value = ','.join(newcaps)
374                 self.update_tag('capabilities', value)
375             else:
376                 tag.delete()
377         return
378
379     # Update the vsys/setup-link and vsys/setup-nat slice tags.
380     def add_vsys_tags(self):
381         link = nat = False
382         for i in self.slice_tag_ids:
383             tag = self.network.lookupSliceTag(i)
384             if tag.tagname == 'vsys':
385                 if tag.value == 'setup-link':
386                     link = True
387                 elif tag.value == 'setup-nat':
388                     nat = True
389         if not link:
390             self.add_tag('vsys', 'setup-link')
391         if not nat:
392             self.add_tag('vsys', 'setup-nat')
393         return
394
395
396 class Slicetag:
397     newid = -1 
398     def __init__(self, tag):
399         self.id = tag['slice_tag_id']
400         if not self.id:
401             # Make one up for the time being...
402             self.id = Slicetag.newid
403             Slicetag.newid -= 1
404         self.slice_id = tag['slice_id']
405         self.tagname = tag['tagname']
406         self.value = tag['value']
407         self.node_id = tag['node_id']
408         self.updated = False
409         self.changed = False
410         self.deleted = False
411     
412     # Mark a tag as deleted
413     def delete(self):
414         self.deleted = True
415         self.updated = True
416     
417     def write(self, api):
418         if self.changed:
419             if int(self.id) > 0:
420                 api.plshell.UpdateSliceTag(api.plauth, self.id, self.value)
421             else:
422                 api.plshell.AddSliceTag(api.plauth, self.slice_id, 
423                                         self.tagname, self.value, self.node_id)
424         elif self.deleted and int(self.id) > 0:
425             api.plshell.DeleteSliceTag(api.plauth, self.id)
426
427
428 """
429 A Network is a compound object consisting of:
430 * a dictionary mapping site IDs to Site objects
431 * a dictionary mapping node IDs to Node objects
432 * a dictionary mapping interface IDs to Iface objects
433 * the Site objects are connected via Link objects representing
434   the physical topology and available bandwidth
435 * the Node objects are connected via Link objects representing
436   the requested or assigned virtual topology of a slice
437 """
438 class Network:
439     def __init__(self, api, type = "PlanetLab", physical_links = [], 
440                  schema = None):
441         self.api = api
442         self.type = type
443         self.sites = self.get_sites(api)
444         self.nodes = self.get_nodes(api)
445         self.ifaces = self.get_ifaces(api)
446         self.tags = self.get_slice_tags(api)
447         self.slice = None
448         self.sitelinks = []
449         self.nodelinks = []
450         self.schema = schema
451     
452         for (s1, s2) in physical_links:
453             self.sitelinks.append(Link(self.sites[s1], self.sites[s2]))
454         
455         for t in self.tags:
456             tag = self.tags[t]
457             if tag.tagname == 'topo_rspec':
458                 node1 = self.nodes[tag.node_id]
459                 l = eval(tag.value)
460                 for (id, realip, bw, lvip, rvip, vnet) in l:
461                     allocbps = get_tc_rate(bw)
462                     node1.bps -= allocbps
463                     try:
464                         node2 = self.nodes[id]
465                         if node1.id < node2.id:
466                             sl = node1.get_sitelink(node2, self.sites)
467                             sl.bps -= allocbps
468                     except:
469                         pass
470
471     
472     """ Lookup site based on id or idtag value """
473     def lookupSite(self, id):
474         val = None
475         if isinstance(id, basestring):
476             id = int(id.lstrip('s'))
477         try:
478             val = self.sites[id]
479         except:
480             raise KeyError("site ID %s not found" % id)
481         return val
482     
483     def getSites(self):
484         sites = []
485         for s in self.sites:
486             sites.append(self.sites[s])
487         return sites
488         
489     """ Lookup node based on id or idtag value """
490     def lookupNode(self, id):
491         val = None
492         if isinstance(id, basestring):
493             id = int(id.lstrip('n'))
494         try:
495             val = self.nodes[id]
496         except:
497             raise KeyError("node ID %s not found" % id)
498         return val
499     
500     def getNodes(self):
501         nodes = []
502         for n in self.nodes:
503             nodes.append(self.nodes[n])
504         return nodes
505     
506     """ Lookup iface based on id or idtag value """
507     def lookupIface(self, id):
508         val = None
509         if isinstance(id, basestring):
510             id = int(id.lstrip('i'))
511         try:
512             val = self.ifaces[id]
513         except:
514             raise KeyError("interface ID %s not found" % id)
515         return val
516     
517     def getIfaces(self):
518         ifaces = []
519         for i in self.ifaces:
520             ifaces.append(self.ifaces[i])
521         return ifaces
522     
523     def nodesInNetwork(self):
524         nodes = []
525         for n in self.nodes:
526             node = self.nodes[n]
527             if node.sliver:
528                 nodes.append(node)
529         return nodes
530             
531     def lookupSliceTag(self, id):
532         val = None
533         try:
534             val = self.tags[id]
535         except:
536             raise KeyError("slicetag ID %s not found" % id)
537         return val
538     
539     def getSliceTags(self):
540         tags = []
541         for t in self.tags:
542             tags.append(self.tags[t])
543         return tags
544     
545     def lookupSiteLink(self, node1, node2):
546         site1 = self.sites[node1.site_id]
547         site2 = self.sites[node2.site_id]
548         for link in self.sitelinks:
549             if site1 == link.end1 and site2 == link.end2:
550                 return link
551             if site2 == link.end1 and site1 == link.end2:
552                 return link
553         return None
554     
555
556     def __add_vlink(self, vlink, slicenodes, parent = None):
557         n1 = n2 = None
558         endpoints = vlink.get("endpoints")
559         if endpoints:
560             (end1, end2) = endpoints.split()
561             n1 = self.lookupNode(end1)
562             n2 = self.lookupNode(end2)
563         elif parent:
564             """ Try to infer the endpoints for the virtual link """
565             site_endpoints = parent.get("endpoints")
566             (n1, n2) = self.__infer_endpoints(site_endpoints, slicenodes)
567         else:
568             raise Error("no endpoints given")
569
570         #print "Added virtual link: %s -- %s" % (n1.tag, n2.tag)
571         bps = int(vlink.findtext("kbps")) * 1000
572         sitelink = self.lookupSiteLink(n1, n2)
573         if not sitelink:
574             raise PermissionError("nodes %s and %s not adjacent" % 
575                                   (n1.idtag, n2.idtag))
576         self.nodelinks.append(Link(n1, n2, bps, sitelink))
577         return
578
579     """ 
580     Infer the endpoints of the virtual link.  If the slice exists on 
581     only a single node at each end of the physical link, we'll assume that
582     the user wants the virtual link to terminate at these nodes.
583     """
584     def __infer_endpoints(self, endpoints, slicenodes):
585         n = []
586         ends = endpoints.split()
587         for end in ends:
588             found = 0
589             site = self.lookupSite(end)
590             for id in site.node_ids:
591                 if id in slicenodes:
592                     n.append(slicenodes[id])
593                     found += 1
594             if found != 1:
595                 raise Error("could not infer endpoint for site %s" % site.id)
596         #print "Inferred endpoints: %s %s" % (n[0].idtag, n[1].idtag)
597         return n
598         
599     def annotateFromRSpec(self, xml):
600         if self.nodelinks:
601             raise Error("virtual topology already present")
602             
603         nodedict = {}
604         for node in self.getNodes():
605             nodedict[node.idtag] = node
606             
607         slicenodes = {}
608
609         tree = etree.parse(StringIO(xml))
610
611         if self.schema:
612             # Validate the incoming request against the RelaxNG schema
613             relaxng_doc = etree.parse(self.schema)
614             relaxng = etree.RelaxNG(relaxng_doc)
615         
616             if not relaxng(tree):
617                 error = relaxng.error_log.last_error
618                 message = "%s (line %s)" % (error.message, error.line)
619                 raise InvalidRSpec(message)
620
621         rspec = tree.getroot()
622
623         """
624         Handle requests where the user has annotated a description of the
625         physical resources (nodes and links) with virtual ones (slivers
626         and vlinks).
627         """
628         # Find slivers under node elements
629         for sliver in rspec.iterfind("./network/site/node/sliver"):
630             elem = sliver.getparent()
631             node = nodedict[elem.get("id")]
632             slicenodes[node.id] = node
633             node.add_sliver()
634
635         # Find vlinks under link elements
636         for vlink in rspec.iterfind("./network/link/vlink"):
637             link = vlink.getparent()
638             self.__add_vlink(vlink, slicenodes, link)
639
640         """
641         Handle requests where the user has listed the virtual resources only
642         """
643         # Find slivers that specify nodeid
644         for sliver in rspec.iterfind("./request/sliver[@nodeid]"):
645             node = nodedict[sliver.get("nodeid")]
646             slicenodes[node.id] = node
647             node.add_sliver()
648
649         # Find vlinks that specify endpoints
650         for vlink in rspec.iterfind("./request/vlink[@endpoints]"):
651             self.__add_vlink(vlink, slicenodes)
652
653         return
654
655     def annotateFromSliceTags(self):
656         slice = self.slice
657         if not slice:
658             raise Error("no slice associated with network")
659
660         if self.nodelinks:
661             raise Error("virtual topology already present")
662             
663         for node in slice.get_nodes(self.nodes):
664             node.sliver = True
665             linktag = slice.get_tag('topo_rspec', node)
666             if linktag:
667                 l = eval(linktag.value)
668                 for (id, realip, bw, lvip, rvip, vnet) in l:
669                     if node.id < id:
670                         bps = get_tc_rate(bw)
671                         remote = self.lookupNode(id)
672                         sitelink = self.lookupSiteLink(node, remote)
673                         self.nodelinks.append(Link(node,remote,bps,sitelink))
674
675     def updateSliceTags(self, slice):
676         if not self.nodelinks:
677             return
678  
679         slice.update_tag('vini_topo', 'manual', self.tags)
680         slice.assign_egre_key(self.tags)
681         slice.turn_on_netns(self.tags)
682         slice.add_cap_net_admin(self.tags)
683
684         for node in slice.get_nodes(self.nodes):
685             linkdesc = []
686             for link in node.links:
687                 linkdesc.append(node.get_topo_rspec(link))
688             if linkdesc:
689                 topo_str = "%s" % linkdesc
690                 slice.update_tag('topo_rspec', topo_str, self.tags, node)
691
692         # Update slice tags in database
693         for tag in self.getSliceTags():
694             if tag.slice_id == slice.id:
695                 if tag.tagname == 'topo_rspec' and not tag.updated:
696                     tag.delete()
697                 tag.write(self.api)
698                 
699     """
700     Check the requested topology against the available topology and capacity
701     """
702     def verifyNodeNetwork(self, hrn, topo):
703         for link in self.nodelinks:
704             if link.bps <= 0:
705                 raise GeniInvalidArgument(bw, "BW")
706                 
707             n1 = link.end1
708             n2 = link.end2
709             sitelink = self.lookupSiteLink(n1, n2)
710             if not sitelink:
711                 raise PermissionError("%s: nodes %s and %s not adjacent" % (hrn, n1.tag, n2.tag))
712             if sitelink.bps < link.bps:
713                 raise PermissionError("%s: insufficient capacity between %s and %s" % (hrn, n1.tag, n2.tag))
714                 
715     """
716     Produce XML directly from the topology specification.
717     """
718     def toxml(self):
719         xml = XMLBuilder(format = True, tab_step = "  ")
720         with xml.RSpec(type=self.type):
721             name = "Public_" + self.type
722             if self.slice:
723                 element = xml.network(name=name, slice=self.slice.hrn)
724             else:
725                 element = xml.network(name=name)
726                 
727             with element:
728                 for site in self.getSites():
729                     site.toxml(xml)
730                 for link in self.sitelinks:
731                     link.toxml(xml)
732
733         header = '<?xml version="1.0"?>\n'
734         return header + str(xml)
735
736     """
737     Create a dictionary of site objects keyed by site ID
738     """
739     def get_sites(self, api):
740         tmp = []
741         for site in api.plshell.GetSites(api.plauth):
742             t = site['site_id'], Site(self, site)
743             tmp.append(t)
744         return dict(tmp)
745
746
747     """
748     Create a dictionary of node objects keyed by node ID
749     """
750     def get_nodes(self, api):
751         tmp = []
752         for node in api.plshell.GetNodes(api.plauth):
753             t = node['node_id'], Node(self, node)
754             tmp.append(t)
755         return dict(tmp)
756
757     """
758     Create a dictionary of node objects keyed by node ID
759     """
760     def get_ifaces(self, api):
761         tmp = []
762         for iface in api.plshell.GetInterfaces(api.plauth):
763             t = iface['interface_id'], Iface(self, iface)
764             tmp.append(t)
765         return dict(tmp)
766
767     """
768     Create a dictionary of slicetag objects keyed by slice tag ID
769     """
770     def get_slice_tags(self, api):
771         tmp = []
772         for tag in api.plshell.GetSliceTags(api.plauth):
773             t = tag['slice_tag_id'], Slicetag(tag)
774             tmp.append(t)
775         return dict(tmp)
776     
777     """
778     Return a Slice object for a single slice
779     """
780     def get_slice(self, api, hrn):
781         slicename = hrn_to_pl_slicename(hrn)
782         slice = api.plshell.GetSlices(api.plauth, [slicename])
783         if slice:
784             self.slice = Slice(self, slicename, slice[0])
785             return self.slice
786         else:
787             return None
788     
789