add merge_node methods
[sfa.git] / sfa / rspecs / sfa_rspec.py
index 1f997ee..56e302b 100755 (executable)
@@ -1,17 +1,29 @@
-#!/usr/bin/python 
+#!/usr/bin/python
+from copy import deepcopy
 from lxml import etree
 from StringIO import StringIO
 from sfa.rspecs.rspec import RSpec 
 from sfa.util.xrn import *
 from sfa.util.plxrn import hostname_to_urn
-from sfa.util.config import Config  
+from sfa.util.config import Config
+from sfa.rspecs.rspec_version import RSpecVersion  
 
 
+_version = { 'type': 'SFA', 
+             'version': '1' 
+}
+
+sfa_rspec_version = RSpecVersion(_version)
+
 class SfaRSpec(RSpec):
     xml = None
     header = '<?xml version="1.0"?>\n'
-    namespaces = {}
-    format = 'sfa'
+    version = sfa_rspec_version
+
+    def create(self):
+        RSpec.create(self)
+        self.xml.set('type', 'SFA')
+
     ###################
     # Parser
     ###################
@@ -27,25 +39,33 @@ class SfaRSpec(RSpec):
         else:
             names = self.xml.xpath('//node/hostname')
         for name in names:
-            if name.text == hostname:
+            if str(name.text).strip() == hostname:
                 return name.getparent()
         return None
  
-    def get_node_elements(self):
-        return self.xml.xpath('//node')
+    def get_node_elements(self, network=None):
+        if network:
+            return self.xml.xpath('//network[@name="%s"]//node' % network)
+        else:
+            return self.xml.xpath('//node')
 
     def get_nodes(self, network=None):
         if network == None:
             nodes = self.xml.xpath('//node/hostname/text()')
         else:
             nodes = self.xml.xpath('//network[@name="%s"]//node/hostname/text()' % network)
+
+        nodes = [node.strip() for node in nodes]
         return nodes
 
     def get_nodes_with_slivers(self, network = None):
         if network:
-            return self.xml.xpath('//network[@name="%s"]//node[sliver]/hostname/text()' % network)   
+            nodes =  self.xml.xpath('//network[@name="%s"]//node[sliver]/hostname/text()' % network)   
         else:
-            return self.xml.xpath('//node[sliver]/hostname/text()')
+            nodes = self.xml.xpath('//node[sliver]/hostname/text()')
+
+        nodes = [node.strip() for node in nodes]
+        return nodes     
 
     def get_nodes_without_slivers(self, network=None): 
         xpath_nodes_without_slivers = '//node[not(sliver)]/hostname/text()'
@@ -62,20 +82,39 @@ class SfaRSpec(RSpec):
         opts = []
         if elem is not None:
             for e in elem:
-                opts.append((e.tag, e.text))
+                opts.append((e.tag, str(e.text).strip()))
         return opts
 
     def get_default_sliver_attributes(self, network=None):
         if network:
             defaults = self.xml.xpath("//network[@name='%s']/sliver_defaults" % network)        
         else:
-            defaults = self.xml.xpath("//network/sliver_defaults" % network)
+            defaults = self.xml.xpath("//sliver_defaults")
+        if isinstance(defaults, list) and defaults:
+            defaults = defaults[0]
         return self.attributes_list(defaults)
 
     def get_sliver_attributes(self, hostname, network=None):
+        attributes = [] 
         node = self.get_node_element(hostname, network)
-        sliver = node.find("sliver")
-        return self.attributes_list(sliver)
+        #sliver = node.find("sliver")
+        slivers = node.xpath('./sliver')
+        if isinstance(slivers, list) and slivers:
+            attributes = self.attributes_list(slivers[0])
+        return attributes
+
+    def get_slice_attributes(self, network=None):
+        slice_attributes = []
+        nodes_with_slivers = self.get_nodes_with_slivers(network)
+        for default_attribute in self.get_default_sliver_attributes(network):
+            attribute = {'name': str(default_attribute[0]), 'value': str(default_attribute[1]), 'node_id': None}
+            slice_attributes.append(attribute)
+        for node in nodes_with_slivers:
+            sliver_attributes = self.get_sliver_attributes(node, network)
+            for sliver_attribute in sliver_attributes:
+                attribute = {'name': str(sliver_attribute[0]), 'value': str(sliver_attribute[1]), 'node_id': node}
+                slice_attributes.append(attribute)    
+        return slice_attributes
 
     def get_site_nodes(self, siteid, network=None):
         if network:
@@ -155,7 +194,12 @@ class SfaRSpec(RSpec):
     ##################
 
     def add_network(self, network):
-        network_tag = etree.SubElement(self.xml, 'network', id=network)     
+        network_tags = self.xml.xpath('//network[@name="%s"]' % network)
+        if not network_tags:
+            network_tag = etree.SubElement(self.xml, 'network', name=network)
+        else:
+            network_tag = network_tags[0]
+        return network_tag     
 
     def add_nodes(self, nodes, network = None, no_dupes=False):
         if not isinstance(nodes, list):
@@ -169,66 +213,88 @@ class SfaRSpec(RSpec):
             network_tag = self.xml
             if 'network' in node:
                 network = node['network']
-                network_tags = self.xml.xpath('//network[@name="%s"]' % network)
-                if not network_tags:
-                    network_tag = etree.SubElement(self.xml, 'network', name=network)
-                else:
-                    network_tag = network_tags[0]
-                     
+                network_tag = self.add_network(network)
+     
             node_tag = etree.SubElement(network_tag, 'node')
             if 'network' in node:
-                node_tag.set('component_manager_id', network)
+                node_tag.set('component_manager_id', hrn_to_urn(network, 'authority+sa'))
             if 'urn' in node:
                 node_tag.set('component_id', node['urn']) 
             if 'site_urn' in node:
                 node_tag.set('site_id', node['site_urn'])
             if 'node_id' in node: 
                 node_tag.set('node_id', 'n'+str(node['node_id']))
+            if 'boot_state' in node:
+                node_tag.set('boot_state', node['boot_state']) 
             if 'hostname' in node:
                 hostname_tag = etree.SubElement(node_tag, 'hostname').text = node['hostname']
             if 'interfaces' in node:
                 for interface in node['interfaces']:
-                    if 'bwlimit' in interface:
-                        bwlimit = etree.SubElement(node_tag, 'bwlimit', units='kbps').tet = str(interface['bwlimit']/1000)
+                    if 'bwlimit' in interface and interface['bwlimit']:
+                        bwlimit = etree.SubElement(node_tag, 'bw_limit', units='kbps').text = str(interface['bwlimit']/1000)
             if 'tags' in node:
                 for tag in node['tags']:
-                   if tag['tagname'] in ['fcdistro', 'arch']:
-                        tag_element = etree.SubElement(node_tag, tag['tagname'], value=tag['value'])           
+                   # expose this hard wired list of tags, plus the ones that are marked 'sfa' in their category 
+                   if tag['tagname'] in ['fcdistro', 'arch'] or 'sfa' in tag['category'].split('/'):
+                        tag_element = etree.SubElement(node_tag, tag['tagname']).text=tag['value']
+
+            if 'site' in node:
+                longitude = str(node['site']['longitude'])
+                latitude = str(node['site']['latitude'])
+                location = etree.SubElement(node_tag, 'location', country='unknown', \
+                                            longitude=longitude, latitude=latitude)
+
+    def merge_node(self, source_node_tag, network, no_dupes=False):
+        if no_dupes and self.get_node_element(node['hostname']):
+            # node already exists
+            return
+
+        network_tag = self.add_network(network)
+        network_tag.append(deepcopy(source_node_tag))
+
     def add_interfaces(self, interfaces):
-        pass     
+        pass
 
     def add_links(self, links):
         pass
-    
-    def add_slivers(self, hostnames, network=None, no_dupes=False):
-        if not isinstance(hostnames, list):
-            hostnames = [hostnames]
 
+    def add_slivers(self, slivers, network=None, sliver_urn=None, no_dupes=False):
+        # add slice name to network tag
+        network_tags = self.xml.xpath('//network')
+        if network_tags:
+            network_tag = network_tags[0]
+            network_tag.set('slice', urn_to_hrn(sliver_urn)[0])
+        slivers = self._process_slivers(slivers)
         nodes_with_slivers = self.get_nodes_with_slivers(network)
-        for hostname in hostnames:
-            if hostname in nodes_with_slivers:
+        for sliver in slivers:
+            if sliver['hostname'] in nodes_with_slivers:
                 continue
-            node = self.get_node_element(hostname, network)
-            etree.SubElement(node, 'sliver')
-
-    def remove_slivers(self, hostnames, network=None, no_dupes=False):
-        if not isinstance(hostnames, list):
-            hostnames = [hostnames]
-        for hostname in hostnames:
-            node = self.get_node_element(hostname, network)
-            sliver = node.find('sliver')
-            if sliver != None:
-                node.remove(sliver)                 
+            node_elem = self.get_node_element(sliver['hostname'], network)
+            sliver_elem = etree.SubElement(node_elem, 'sliver')
+            if 'tags' in sliver:
+                for tag in sliver['tags']:
+                    etree.SubElement(sliver_elem, tag['tagname']).text = value=tag['value']
+
+    def remove_slivers(self, slivers, network=None, no_dupes=False):
+        slivers = self._process_slivers(slivers)
+        for sliver in slivers:
+            node_elem = self.get_node_element(sliver['hostname'], network)
+            sliver_elem = node_elem.find('sliver')
+            if sliver_elem != None:
+                node_elem.remove(sliver_elem)
     
     def add_default_sliver_attribute(self, name, value, network=None):
         if network:
             defaults = self.xml.xpath("//network[@name='%s']/sliver_defaults" % network)
         else:
             defaults = self.xml.xpath("//sliver_defaults" % network)
-        if defaults is None:
-            defaults = etree.Element("sliver_defaults")
-            network = self.xml.xpath("//network[@name='%s']" % network)
-            network.insert(0, defaults)
+        if not defaults :
+            network_tag = self.xml.xpath("//network[@name='%s']" % network)
+            if isinstance(network_tag, list):
+                network_tag = network_tag[0]
+            defaults = self.add_element('sliver_defaults', attrs={}, parent=network_tag)
+        elif isinstance(defaults, list):
+            defaults = defaults[0]
         self.add_attribute(defaults, name, value)
 
     def add_sliver_attribute(self, hostname, name, value, network=None):
@@ -272,15 +338,22 @@ class SfaRSpec(RSpec):
         Merge contents for specified rspec with current rspec 
         """
 
+        from sfa.rspecs.rspec_parser import parse_rspec
+        rspec = parse_rspec(in_rspec)
+        if rspec.type.lower() == 'protogeni':
+            from sfa.rspecs.rspec_converter import RSpecConverter
+            in_rspec = RSpecConverter.to_sfa_rspec(in_rspec)
+            
         # just copy over all networks
         current_networks = self.get_networks()
         rspec = SfaRSpec(rspec=in_rspec)
         networks = rspec.get_network_elements()
         for network in networks:
             current_network = network.get('name')
-            if not current_network in current_networks:
+            if current_network and current_network not in current_networks:
                 self.xml.append(network)
                 current_networks.append(current_network)
+