add merge_node methods
[sfa.git] / sfa / rspecs / sfa_rspec.py
index 3f85618..56e302b 100755 (executable)
@@ -1,4 +1,5 @@
-#!/usr/bin/python 
+#!/usr/bin/python
+from copy import deepcopy
 from lxml import etree
 from StringIO import StringIO
 from sfa.rspecs.rspec import RSpec 
@@ -38,7 +39,7 @@ class SfaRSpec(RSpec):
         else:
             names = self.xml.xpath('//node/hostname')
         for name in names:
-            if name.text == hostname:
+            if str(name.text).strip() == hostname:
                 return name.getparent()
         return None
  
@@ -53,13 +54,18 @@ class SfaRSpec(RSpec):
             nodes = self.xml.xpath('//node/hostname/text()')
         else:
             nodes = self.xml.xpath('//network[@name="%s"]//node/hostname/text()' % network)
+
+        nodes = [node.strip() for node in nodes]
         return nodes
 
     def get_nodes_with_slivers(self, network = None):
         if network:
-            return self.xml.xpath('//network[@name="%s"]//node[sliver]/hostname/text()' % network)   
+            nodes =  self.xml.xpath('//network[@name="%s"]//node[sliver]/hostname/text()' % network)   
         else:
-            return self.xml.xpath('//node[sliver]/hostname/text()')
+            nodes = self.xml.xpath('//node[sliver]/hostname/text()')
+
+        nodes = [node.strip() for node in nodes]
+        return nodes     
 
     def get_nodes_without_slivers(self, network=None): 
         xpath_nodes_without_slivers = '//node[not(sliver)]/hostname/text()'
@@ -76,24 +82,39 @@ class SfaRSpec(RSpec):
         opts = []
         if elem is not None:
             for e in elem:
-                opts.append((e.tag, e.text))
+                opts.append((e.tag, str(e.text).strip()))
         return opts
 
     def get_default_sliver_attributes(self, network=None):
         if network:
             defaults = self.xml.xpath("//network[@name='%s']/sliver_defaults" % network)        
         else:
-            defaults = self.xml.xpath("//network/sliver_defaults" % network)
+            defaults = self.xml.xpath("//sliver_defaults")
+        if isinstance(defaults, list) and defaults:
+            defaults = defaults[0]
         return self.attributes_list(defaults)
 
     def get_sliver_attributes(self, hostname, network=None):
+        attributes = [] 
         node = self.get_node_element(hostname, network)
-        sliver = node.find("sliver")
-        return self.attributes_list(sliver)
+        #sliver = node.find("sliver")
+        slivers = node.xpath('./sliver')
+        if isinstance(slivers, list) and slivers:
+            attributes = self.attributes_list(slivers[0])
+        return attributes
 
     def get_slice_attributes(self, network=None):
-        # TODO: FINISH
-        return []
+        slice_attributes = []
+        nodes_with_slivers = self.get_nodes_with_slivers(network)
+        for default_attribute in self.get_default_sliver_attributes(network):
+            attribute = {'name': str(default_attribute[0]), 'value': str(default_attribute[1]), 'node_id': None}
+            slice_attributes.append(attribute)
+        for node in nodes_with_slivers:
+            sliver_attributes = self.get_sliver_attributes(node, network)
+            for sliver_attribute in sliver_attributes:
+                attribute = {'name': str(sliver_attribute[0]), 'value': str(sliver_attribute[1]), 'node_id': node}
+                slice_attributes.append(attribute)    
+        return slice_attributes
 
     def get_site_nodes(self, siteid, network=None):
         if network:
@@ -173,7 +194,12 @@ class SfaRSpec(RSpec):
     ##################
 
     def add_network(self, network):
-        network_tag = etree.SubElement(self.xml, 'network', id=network)     
+        network_tags = self.xml.xpath('//network[@name="%s"]' % network)
+        if not network_tags:
+            network_tag = etree.SubElement(self.xml, 'network', name=network)
+        else:
+            network_tag = network_tags[0]
+        return network_tag     
 
     def add_nodes(self, nodes, network = None, no_dupes=False):
         if not isinstance(nodes, list):
@@ -187,21 +213,19 @@ class SfaRSpec(RSpec):
             network_tag = self.xml
             if 'network' in node:
                 network = node['network']
-                network_tags = self.xml.xpath('//network[@name="%s"]' % network)
-                if not network_tags:
-                    network_tag = etree.SubElement(self.xml, 'network', name=network)
-                else:
-                    network_tag = network_tags[0]
-                     
+                network_tag = self.add_network(network)
+     
             node_tag = etree.SubElement(network_tag, 'node')
             if 'network' in node:
-                node_tag.set('component_manager_id', network)
+                node_tag.set('component_manager_id', hrn_to_urn(network, 'authority+sa'))
             if 'urn' in node:
                 node_tag.set('component_id', node['urn']) 
             if 'site_urn' in node:
                 node_tag.set('site_id', node['site_urn'])
             if 'node_id' in node: 
                 node_tag.set('node_id', 'n'+str(node['node_id']))
+            if 'boot_state' in node:
+                node_tag.set('boot_state', node['boot_state']) 
             if 'hostname' in node:
                 hostname_tag = etree.SubElement(node_tag, 'hostname').text = node['hostname']
             if 'interfaces' in node:
@@ -212,21 +236,34 @@ class SfaRSpec(RSpec):
                 for tag in node['tags']:
                    # expose this hard wired list of tags, plus the ones that are marked 'sfa' in their category 
                    if tag['tagname'] in ['fcdistro', 'arch'] or 'sfa' in tag['category'].split('/'):
-                        tag_element = etree.SubElement(node_tag, tag['tagname'], value=tag['value'])
+                        tag_element = etree.SubElement(node_tag, tag['tagname']).text=tag['value']
 
             if 'site' in node:
                 longitude = str(node['site']['longitude'])
                 latitude = str(node['site']['latitude'])
                 location = etree.SubElement(node_tag, 'location', country='unknown', \
-                                            longitude=longitude, latitude=latitude)                
+                                            longitude=longitude, latitude=latitude)
+
+    def merge_node(self, source_node_tag, network, no_dupes=False):
+        if no_dupes and self.get_node_element(node['hostname']):
+            # node already exists
+            return
+
+        network_tag = self.add_network(network)
+        network_tag.append(deepcopy(source_node_tag))
 
     def add_interfaces(self, interfaces):
-        pass     
+        pass
 
     def add_links(self, links):
         pass
-    
+
     def add_slivers(self, slivers, network=None, sliver_urn=None, no_dupes=False):
+        # add slice name to network tag
+        network_tags = self.xml.xpath('//network')
+        if network_tags:
+            network_tag = network_tags[0]
+            network_tag.set('slice', urn_to_hrn(sliver_urn)[0])
         slivers = self._process_slivers(slivers)
         nodes_with_slivers = self.get_nodes_with_slivers(network)
         for sliver in slivers:
@@ -236,7 +273,7 @@ class SfaRSpec(RSpec):
             sliver_elem = etree.SubElement(node_elem, 'sliver')
             if 'tags' in sliver:
                 for tag in sliver['tags']:
-                    etree.SubElement(sliver_elem, tag['tagname'], value=tag['value'])
+                    etree.SubElement(sliver_elem, tag['tagname']).text = value=tag['value']
 
     def remove_slivers(self, slivers, network=None, no_dupes=False):
         slivers = self._process_slivers(slivers)
@@ -251,9 +288,13 @@ class SfaRSpec(RSpec):
             defaults = self.xml.xpath("//network[@name='%s']/sliver_defaults" % network)
         else:
             defaults = self.xml.xpath("//sliver_defaults" % network)
-        if defaults is None:
+        if not defaults :
             network_tag = self.xml.xpath("//network[@name='%s']" % network)
+            if isinstance(network_tag, list):
+                network_tag = network_tag[0]
             defaults = self.add_element('sliver_defaults', attrs={}, parent=network_tag)
+        elif isinstance(defaults, list):
+            defaults = defaults[0]
         self.add_attribute(defaults, name, value)
 
     def add_sliver_attribute(self, hostname, name, value, network=None):
@@ -297,15 +338,22 @@ class SfaRSpec(RSpec):
         Merge contents for specified rspec with current rspec 
         """
 
+        from sfa.rspecs.rspec_parser import parse_rspec
+        rspec = parse_rspec(in_rspec)
+        if rspec.type.lower() == 'protogeni':
+            from sfa.rspecs.rspec_converter import RSpecConverter
+            in_rspec = RSpecConverter.to_sfa_rspec(in_rspec)
+            
         # just copy over all networks
         current_networks = self.get_networks()
         rspec = SfaRSpec(rspec=in_rspec)
         networks = rspec.get_network_elements()
         for network in networks:
             current_network = network.get('name')
-            if not current_network in current_networks:
+            if current_network and current_network not in current_networks:
                 self.xml.append(network)
                 current_networks.append(current_network)
+