From 105179c4e8cdb5de0d620e149155fa3c816ba024 Mon Sep 17 00:00:00 2001 From: Tony Mack Date: Wed, 25 May 2011 14:07:43 -0400 Subject: [PATCH] add_slivers() accepts a string, list of strings or list of dicts --- sfa/rspecs/pg_rspec.py | 4 +--- sfa/rspecs/rspec.py | 17 ++++++++++++++++- sfa/rspecs/sfa_rspec.py | 4 +--- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/sfa/rspecs/pg_rspec.py b/sfa/rspecs/pg_rspec.py index 2fd8b766..0feeb671 100755 --- a/sfa/rspecs/pg_rspec.py +++ b/sfa/rspecs/pg_rspec.py @@ -122,9 +122,7 @@ class PGRSpec(RSpec): def add_slivers(self, slivers, sliver_urn=None, no_dupes=False): - if not isinstance(slivers, list): - slivers = [slivers] - + slivers = self.__process_slivers(slivers) nodes_with_slivers = self.get_nodes_with_slivers() for sliver in slivers: hostname = sliver['hostname'] diff --git a/sfa/rspecs/rspec.py b/sfa/rspecs/rspec.py index ad2075b0..af33d411 100755 --- a/sfa/rspecs/rspec.py +++ b/sfa/rspecs/rspec.py @@ -125,13 +125,28 @@ class RSpec: raise InvalidRSpec(message) return True - def cleanup(self): """ Optional method which inheriting classes can choose to implent. """ pass + def __process_slivers(self, slivers): + """ + Creates a dict of sliver details for each sliver host + + @param slivers a single hostname, list of hostanmes or list of dicts keys on hostname, + Returns a list of dicts + """ + if not isinstance(slivers, list): + slivers = [slivers] + dicts = [] + for sliver in slivers: + if isinstance(sliver, dict): + dicts.append(sliver) + elif isinstance(sliver, basestring): + dicts.append({'hostname': sliver}) + return dicts def __str__(self): return self.toxml() diff --git a/sfa/rspecs/sfa_rspec.py b/sfa/rspecs/sfa_rspec.py index 4a9dc492..16d81c12 100755 --- a/sfa/rspecs/sfa_rspec.py +++ b/sfa/rspecs/sfa_rspec.py @@ -223,9 +223,7 @@ class SfaRSpec(RSpec): pass def add_slivers(self, slivers, network=None, sliver_urn=None, no_dupes=False): - if not isinstance(slivers, list): - slivers = [slivers] - + slivers = self.__process_slivers(slivers) nodes_with_slivers = self.get_nodes_with_slivers(network) for sliver in slivers: if sliver['hostname'] in nodes_with_slivers: -- 2.47.0