merge from trunk
[plcapi.git] / PLC / NodeGroups.py
index b27c3a0..65b4a41 100644 (file)
@@ -4,13 +4,14 @@
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
-# $Id: NodeGroups.py,v 1.15 2006/10/25 14:29:13 mlhuang Exp $
+# $Id: NodeGroups.py 5666 2007-11-06 21:52:21Z tmack $
 #
 
 from types import StringTypes
 
 from PLC.Faults import *
 #
 
 from types import StringTypes
 
 from PLC.Faults import *
-from PLC.Parameter import Parameter
+from PLC.Parameter import Parameter, Mixed
+from PLC.Filter import Filter
 from PLC.Debug import profile
 from PLC.Table import Row, Table
 from PLC.Nodes import Node, Nodes
 from PLC.Debug import profile
 from PLC.Table import Row, Table
 from PLC.Nodes import Node, Nodes
@@ -32,6 +33,11 @@ class NodeGroup(Row):
         'node_ids': Parameter([int], "List of nodes in this node group"),
         'conf_file_ids': Parameter([int], "List of configuration files specific to this node group"),
         }
         'node_ids': Parameter([int], "List of nodes in this node group"),
         'conf_file_ids': Parameter([int], "List of configuration files specific to this node group"),
         }
+    related_fields = {
+       'conf_files': [Parameter(int, "ConfFile identifier")],
+       'nodes': [Mixed(Parameter(int, "Node identifier"),
+                        Parameter(str, "Fully qualified hostname"))]
+       }
 
     def validate_name(self, name):
        # Make sure name is not blank
 
     def validate_name(self, name):
        # Make sure name is not blank
@@ -40,8 +46,8 @@ class NodeGroup(Row):
        
        # Make sure node group does not alredy exist
        conflicts = NodeGroups(self.api, [name])
        
        # Make sure node group does not alredy exist
        conflicts = NodeGroups(self.api, [name])
-       for nodegroup_id in conflicts:
-            if 'nodegroup_id' not in self or self['nodegroup_id'] != nodegroup_id:
+       for nodegroup in conflicts:
+            if 'nodegroup_id' not in self or self['nodegroup_id'] != nodegroup['nodegroup_id']:
                raise PLCInvalidArgument, "Node group name already in use"
 
        return name
                raise PLCInvalidArgument, "Node group name already in use"
 
        return name
@@ -97,37 +103,80 @@ class NodeGroup(Row):
             self['node_ids'].remove(node_id)
             node['nodegroup_ids'].remove(nodegroup_id)
 
             self['node_ids'].remove(node_id)
             node['nodegroup_ids'].remove(nodegroup_id)
 
+    def associate_nodes(self, auth, field, value):
+        """
+        Adds nodes found in value list to this nodegroup (using AddNodeToNodeGroup).
+        Deletes nodes not found in value list from this slice (using DeleteNodeFromNodeGroup).
+        """
+
+        assert 'node_ids' in self
+        assert 'nodegroup_id' in self
+        assert isinstance(value, list)
+
+        (node_ids, hostnames) = self.separate_types(value)[0:2]
+
+        # Translate hostnames into node_ids
+        if hostnames:
+            nodes = Nodes(self.api, hostnames, ['node_id']).dict('node_id')
+            node_ids += nodes.keys()
+
+        # Add new ids, remove stale ids
+        if self['node_ids'] != node_ids:
+            from PLC.Methods.AddNodeToNodeGroup import AddNodeToNodeGroup
+            from PLC.Methods.DeleteNodeFromNodeGroup import DeleteNodeFromNodeGroup
+            new_nodes = set(node_ids).difference(self['node_ids'])
+            stale_nodes = set(self['node_ids']).difference(node_ids)
+
+            for new_node in new_nodes:
+                AddNodeToNodeGroup.__call__(AddNodeToNodeGroup(self.api), auth, new_node, self['nodegroup_id'])
+            for stale_node in stale_nodes:
+                DeleteNodeFromNodeGroup.__call__(DeleteNodeFromNodeGroup(self.api), auth, stale_node, self['nodegroup_id'])
+
+    def associate_conf_files(self, auth, field, value):
+        """
+        Add conf_files found in value list (AddConfFileToNodeGroup)
+        Delets conf_files not found in value list (DeleteConfFileFromNodeGroup)
+        """
+
+        assert 'conf_file_ids' in self
+        assert 'nodegroup_id' in self
+        assert isinstance(value, list)
+
+        conf_file_ids = self.separate_types(value)[0]
+
+        if self['conf_file_ids'] != conf_file_ids:
+            from PLC.Methods.AddConfFileToNodeGroup import AddConfFileToNodeGroup
+            from PLC.Methods.DeleteConfFileFromNodeGroup import DeleteConfFileFromNodeGroup
+            new_conf_files = set(conf_file_ids).difference(self['conf_file_ids'])
+            stale_conf_files = set(self['conf_file_ids']).difference(conf_file_ids)
+
+            for new_conf_file in new_conf_files:
+                AddConfFileToNodeGroup.__call__(AddConfFileToNodeGroup(self.api), auth, new_conf_file, self['nodegroup_id'])
+            for stale_conf_file in stale_conf_files:
+                DeleteConfFileFromNodeGroup.__call__(DeleteConfFileFromNodeGroup(self.api), auth, stale_conf_file, self['nodegroup_id'])
+
+
 class NodeGroups(Table):
     """
     Representation of row(s) from the nodegroups table in the
     database.
     """
 
 class NodeGroups(Table):
     """
     Representation of row(s) from the nodegroups table in the
     database.
     """
 
-    def __init__(self, api, nodegroup_id_or_name_list = None):
-       self.api = api
-
-        sql = "SELECT %s FROM view_nodegroups" % \
-              ", ".join(NodeGroup.fields)
-
-        if nodegroup_id_or_name_list:
-            # Separate the list into integers and strings
-            nodegroup_ids = filter(lambda nodegroup_id: isinstance(nodegroup_id, (int, long)),
-                                   nodegroup_id_or_name_list)
-            names = filter(lambda name: isinstance(name, StringTypes),
-                           nodegroup_id_or_name_list)
-            sql += " WHERE (False"
-            if nodegroup_ids:
-                sql += " OR nodegroup_id IN (%s)" % ", ".join(map(str, nodegroup_ids))
-            if names:
-                sql += " OR name IN (%s)" % ", ".join(api.db.quote(names))
-            sql += ")"
-
-        rows = self.api.db.selectall(sql)
-
-        for row in rows:
-            self[row['nodegroup_id']] = nodegroup = NodeGroup(api, row)
-            for aggregate in ['node_ids', 'conf_file_ids']:
-                if not nodegroup.has_key(aggregate) or nodegroup[aggregate] is None:
-                    nodegroup[aggregate] = []
-                else:
-                    nodegroup[aggregate] = map(int, nodegroup[aggregate].split(','))
+    def __init__(self, api, nodegroup_filter = None, columns = None):
+        Table.__init__(self, api, NodeGroup, columns)
+
+        sql = "SELECT %s FROM view_nodegroups WHERE True" % \
+              ", ".join(self.columns)
+
+        if nodegroup_filter is not None:
+            if isinstance(nodegroup_filter, (list, tuple, set)):
+                # Separate the list into integers and strings
+                ints = filter(lambda x: isinstance(x, (int, long)), nodegroup_filter)
+                strs = filter(lambda x: isinstance(x, StringTypes), nodegroup_filter)
+                nodegroup_filter = Filter(NodeGroup.fields, {'nodegroup_id': ints, 'name': strs})
+                sql += " AND (%s) %s" % nodegroup_filter.sql(api, "OR")
+            elif isinstance(nodegroup_filter, dict):
+                nodegroup_filter = Filter(NodeGroup.fields, nodegroup_filter)
+                sql += " AND (%s) %s" % nodegroup_filter.sql(api, "AND")
+
+        self.selectall(sql)