return only the rspec for the slice
[sfa.git] / sfa / plc / network.py
index 1054725..b3e0d05 100644 (file)
@@ -115,6 +115,7 @@ class Slice:
         self.network = network
         self.id = slice['slice_id']
         self.name = slice['name']
+        self.peer_id = slice['peer_id']
         self.node_ids = set(slice['node_ids'])
         self.slice_tag_ids = slice['slice_tag_ids']
     
@@ -221,7 +222,10 @@ class Slicetag:
         self.slice_id = network.slice.id
         self.tagname = tagname
         self.value = value
-        self.node_id = node.id
+        if node:
+            self.node_id = node.id
+        else:
+            self.node_id = None
         self.category = tt.category
         self.min_role_id = tt.min_role_id
         self.status = "new"
@@ -292,7 +296,7 @@ A Network is a compound object consisting of:
 * a dictionary mapping interface IDs to Iface objects
 """
 class Network:
-    def __init__(self, api, type = "PlanetLab"):
+    def __init__(self, api, type = "SFA"):
         self.api = api
         self.type = type
         self.sites = self.get_sites(api)
@@ -417,6 +421,15 @@ class Network:
             message = str(sys.exc_info()[1])
             raise InvalidRSpec(message)
 
+        # Filter out stuff that's not for us
+        rspec = tree.getroot()
+        for network in rspec.iterfind("./network"):
+            if network.get("name") != self.api.hrn:
+                rspec.remove(network)
+        for request in rspec.iterfind("./request"):
+            if request.get("name") != self.api.hrn:
+                rspec.remove(request)
+
         if schema:
             # Validate the incoming request against the RelaxNG schema
             relaxng_doc = etree.parse(schema)
@@ -427,10 +440,9 @@ class Network:
                 message = "%s (line %s)" % (error.message, error.line)
                 raise InvalidRSpec(message)
 
-        rspec = tree.getroot()
         self.rspec = rspec
 
-        defaults = rspec.find("./network/sliver_defaults")
+        defaults = rspec.find(".//sliver_defaults")
         self.__process_attributes(defaults)
 
         # Find slivers under node elements
@@ -489,17 +501,17 @@ class Network:
     def toxml(self):
         xml = XMLBuilder(format = True, tab_step = "  ")
         with xml.RSpec(type=self.type):
-            name = "Public_" + self.type
             if self.slice:
-                element = xml.network(name=name, slice=self.slice.hrn)
+                element = xml.network(name=self.api.hrn, slice=self.slice.hrn)
             else:
-                element = xml.network(name=name)
+                element = xml.network(name=self.api.hrn)
                 
             with element:
                 if self.slice:
                     self.slice.toxml(xml)
-                for site in self.getSites():
-                    site.toxml(xml)
+                else:
+                    for site in self.getSites():
+                        site.toxml(xml)
 
         header = '<?xml version="1.0"?>\n'
         return header + str(xml)
@@ -561,7 +573,7 @@ class Network:
     def get_slice(self, api, hrn):
         slicename = hrn_to_pl_slicename(hrn)
         slice = api.plshell.GetSlices(api.plauth, [slicename])
-        if slice:
+        if len(slice):
             self.slice = Slice(self, slicename, slice[0])
             return self.slice
         else: