Merge remote-tracking branch 'origin/pycurl' into planetlab-4_0-branch
[plcapi.git] / PLC / NodeGroups.py
index 260628e..65b4a41 100644 (file)
@@ -4,13 +4,14 @@
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
-# $Id: NodeGroups.py,v 1.2 2006/09/06 16:03:24 mlhuang Exp $
+# $Id: NodeGroups.py 5666 2007-11-06 21:52:21Z tmack $
 #
 
 from types import StringTypes
 
 from PLC.Faults import *
 #
 
 from types import StringTypes
 
 from PLC.Faults import *
-from PLC.Parameter import Parameter
+from PLC.Parameter import Parameter, Mixed
+from PLC.Filter import Filter
 from PLC.Debug import profile
 from PLC.Table import Row, Table
 from PLC.Nodes import Node, Nodes
 from PLC.Debug import profile
 from PLC.Table import Row, Table
 from PLC.Nodes import Node, Nodes
@@ -19,31 +20,37 @@ class NodeGroup(Row):
     """
     Representation of a row in the nodegroups table. To use, optionally
     instantiate with a dict of values. Update as you would a
     """
     Representation of a row in the nodegroups table. To use, optionally
     instantiate with a dict of values. Update as you would a
-    dict. Commit to the database with flush().
+    dict. Commit to the database with sync().
     """
 
     """
 
+    table_name = 'nodegroups'
+    primary_key = 'nodegroup_id'
+    join_tables = ['nodegroup_node', 'conf_file_nodegroup']
     fields = {
         'nodegroup_id': Parameter(int, "Node group identifier"),
     fields = {
         'nodegroup_id': Parameter(int, "Node group identifier"),
-        'name': Parameter(str, "Node group name"),
-        'description': Parameter(str, "Node group description"),
-        'is_custom': Parameter(bool, "Is a custom node group (i.e., is not a site node group)")
-        }
-
-    # These fields are derived from join tables and are not
-    # actually in the nodegroups table.
-    join_fields = {
+        'name': Parameter(str, "Node group name", max = 50),
+        'description': Parameter(str, "Node group description", max = 200, nullok = True),
         'node_ids': Parameter([int], "List of nodes in this node group"),
         'node_ids': Parameter([int], "List of nodes in this node group"),
+        'conf_file_ids': Parameter([int], "List of configuration files specific to this node group"),
         }
         }
-
-    def __init__(self, api, fields):
-        Row.__init__(self, fields)
-        self.api = api
+    related_fields = {
+       'conf_files': [Parameter(int, "ConfFile identifier")],
+       'nodes': [Mixed(Parameter(int, "Node identifier"),
+                        Parameter(str, "Fully qualified hostname"))]
+       }
 
     def validate_name(self, name):
 
     def validate_name(self, name):
-        conflicts = NodeGroups(self.api, [name])
-        for nodegroup_id in conflicts:
-            if 'nodegroup_id' not in self or self['nodegroup_id'] != nodegroup_id:
-                raise PLCInvalidArgument, "Node group name already in use"
+       # Make sure name is not blank
+        if not len(name):
+                raise PLCInvalidArgument, "Invalid node group name"
+       
+       # Make sure node group does not alredy exist
+       conflicts = NodeGroups(self.api, [name])
+       for nodegroup in conflicts:
+            if 'nodegroup_id' not in self or self['nodegroup_id'] != nodegroup['nodegroup_id']:
+               raise PLCInvalidArgument, "Node group name already in use"
+
+       return name
 
     def add_node(self, node, commit = True):
         """
 
     def add_node(self, node, commit = True):
         """
@@ -56,17 +63,18 @@ class NodeGroup(Row):
 
         node_id = node['node_id']
         nodegroup_id = self['nodegroup_id']
 
         node_id = node['node_id']
         nodegroup_id = self['nodegroup_id']
-        self.api.db.do("INSERT INTO nodegroup_nodes (nodegroup_id, node_id)" \
-                       " VALUES(%(nodegroup_id)d, %(node_id)d)",
-                       locals())
 
 
-        if commit:
-            self.api.db.commit()
+        if node_id not in self['node_ids']:
+            assert nodegroup_id not in node['nodegroup_ids']
 
 
-        if 'node_ids' in self and node_id not in self['node_ids']:
-            self['node_ids'].append(node_id)
+            self.api.db.do("INSERT INTO nodegroup_node (nodegroup_id, node_id)" \
+                           " VALUES(%(nodegroup_id)d, %(node_id)d)",
+                           locals())
+
+            if commit:
+                self.api.db.commit()
 
 
-        if 'nodegroup_ids' in node and nodegroup_id not in node['nodegroup_ids']:
+            self['node_ids'].append(node_id)
             node['nodegroup_ids'].append(nodegroup_id)
 
     def remove_node(self, node, commit = True):
             node['nodegroup_ids'].append(nodegroup_id)
 
     def remove_node(self, node, commit = True):
@@ -80,84 +88,73 @@ class NodeGroup(Row):
 
         node_id = node['node_id']
         nodegroup_id = self['nodegroup_id']
 
         node_id = node['node_id']
         nodegroup_id = self['nodegroup_id']
-        self.api.db.do("INSERT INTO nodegroup_nodes (nodegroup_id, node_id)" \
-                       " VALUES(%(nodegroup_id)d, %(node_id)d)",
-                       locals())
 
 
-        if commit:
-            self.api.db.commit()
+        if node_id in self['node_ids']:
+            assert nodegroup_id in node['nodegroup_ids']
 
 
-        if 'node_ids' in self and node_id not in self['node_ids']:
-            self['node_ids'].append(node_id)
+            self.api.db.do("DELETE FROM nodegroup_node" \
+                           " WHERE nodegroup_id = %(nodegroup_id)d" \
+                           " AND node_id = %(node_id)d",
+                           locals())
 
 
-        if 'nodegroup_ids' in node and nodegroup_id not in node['nodegroup_ids']:
-            node['nodegroup_ids'].append(nodegroup_id)
+            if commit:
+                self.api.db.commit()
+
+            self['node_ids'].remove(node_id)
+            node['nodegroup_ids'].remove(nodegroup_id)
 
 
-    def flush(self, commit = True):
+    def associate_nodes(self, auth, field, value):
         """
         """
-        Flush changes back to the database.
+        Adds nodes found in value list to this nodegroup (using AddNodeToNodeGroup).
+        Deletes nodes not found in value list from this slice (using DeleteNodeFromNodeGroup).
         """
 
         """
 
-        self.validate()
-
-        # Fetch a new nodegroup_id if necessary
-        if 'nodegroup_id' not in self:
-            rows = self.api.db.selectall("SELECT NEXTVAL('nodegroups_nodegroup_id_seq') AS nodegroup_id")
-            if not rows:
-                raise PLCDBError, "Unable to fetch new nodegroup_id"
-            self['nodegroup_id'] = rows[0]['nodegroup_id']
-            insert = True
-        else:
-            insert = False
-
-        # Filter out fields that cannot be set or updated directly
-        fields = dict(filter(lambda (key, value): key in self.fields,
-                             self.items()))
-
-        # Parameterize for safety
-        keys = fields.keys()
-        values = [self.api.db.param(key, value) for (key, value) in fields.items()]
-
-        if insert:
-            # Insert new row in nodegroups table
-            sql = "INSERT INTO nodegroups (%s) VALUES (%s)" % \
-                  (", ".join(keys), ", ".join(values))
-        else:
-            # Update existing row in sites table
-            columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)]
-            sql = "UPDATE nodegroups SET " + \
-                  ", ".join(columns) + \
-                  " WHERE nodegroup_id = %(nodegroup_id)d"
-
-        self.api.db.do(sql, fields)
-
-        if commit:
-            self.api.db.commit()
-
-    def delete(self, commit = True):
+        assert 'node_ids' in self
+        assert 'nodegroup_id' in self
+        assert isinstance(value, list)
+
+        (node_ids, hostnames) = self.separate_types(value)[0:2]
+
+        # Translate hostnames into node_ids
+        if hostnames:
+            nodes = Nodes(self.api, hostnames, ['node_id']).dict('node_id')
+            node_ids += nodes.keys()
+
+        # Add new ids, remove stale ids
+        if self['node_ids'] != node_ids:
+            from PLC.Methods.AddNodeToNodeGroup import AddNodeToNodeGroup
+            from PLC.Methods.DeleteNodeFromNodeGroup import DeleteNodeFromNodeGroup
+            new_nodes = set(node_ids).difference(self['node_ids'])
+            stale_nodes = set(self['node_ids']).difference(node_ids)
+
+            for new_node in new_nodes:
+                AddNodeToNodeGroup.__call__(AddNodeToNodeGroup(self.api), auth, new_node, self['nodegroup_id'])
+            for stale_node in stale_nodes:
+                DeleteNodeFromNodeGroup.__call__(DeleteNodeFromNodeGroup(self.api), auth, stale_node, self['nodegroup_id'])
+
+    def associate_conf_files(self, auth, field, value):
         """
         """
-        Delete existing nodegroup from the database.
+        Add conf_files found in value list (AddConfFileToNodeGroup)
+        Delets conf_files not found in value list (DeleteConfFileFromNodeGroup)
         """
 
         """
 
+        assert 'conf_file_ids' in self
         assert 'nodegroup_id' in self
         assert 'nodegroup_id' in self
+        assert isinstance(value, list)
 
 
-        # Delete ourself
-        tables = ['nodegroup_nodes', 'override_bootscripts',
-                  'conf_assoc', 'node_root_access']
+        conf_file_ids = self.separate_types(value)[0]
 
 
-        if self['is_custom']:
-            tables.append('nodegroups')
-        else:
-            # XXX Cannot delete site node groups yet
-            pass
+        if self['conf_file_ids'] != conf_file_ids:
+            from PLC.Methods.AddConfFileToNodeGroup import AddConfFileToNodeGroup
+            from PLC.Methods.DeleteConfFileFromNodeGroup import DeleteConfFileFromNodeGroup
+            new_conf_files = set(conf_file_ids).difference(self['conf_file_ids'])
+            stale_conf_files = set(self['conf_file_ids']).difference(conf_file_ids)
 
 
-        for table in tables:
-            self.api.db.do("DELETE FROM %s" \
-                           " WHERE nodegroup_id = %(nodegroup_id)" % \
-                           table, self)
+            for new_conf_file in new_conf_files:
+                AddConfFileToNodeGroup.__call__(AddConfFileToNodeGroup(self.api), auth, new_conf_file, self['nodegroup_id'])
+            for stale_conf_file in stale_conf_files:
+                DeleteConfFileFromNodeGroup.__call__(DeleteConfFileFromNodeGroup(self.api), auth, stale_conf_file, self['nodegroup_id'])
 
 
-        if commit:
-            self.api.db.commit()
 
 class NodeGroups(Table):
     """
 
 class NodeGroups(Table):
     """
@@ -165,31 +162,21 @@ class NodeGroups(Table):
     database.
     """
 
     database.
     """
 
-    def __init__(self, api, nodegroup_id_or_name_list = None):
-        self.api = api
-
-        # N.B.: Node IDs returned may be deleted.
-        sql = "SELECT nodegroups.*, nodegroup_nodes.node_id" \
-              " FROM nodegroups" \
-              " LEFT JOIN nodegroup_nodes USING (nodegroup_id)"
-
-        if nodegroup_id_or_name_list:
-            # Separate the list into integers and strings
-            nodegroup_ids = filter(lambda nodegroup_id: isinstance(nodegroup_id, (int, long)),
-                                   nodegroup_id_or_name_list)
-            names = filter(lambda name: isinstance(name, StringTypes),
-                           nodegroup_id_or_name_list)
-            sql += " WHERE (False"
-            if nodegroup_ids:
-                sql += " OR nodegroup_id IN (%s)" % ", ".join(map(str, nodegroup_ids))
-            if names:
-                sql += " OR name IN (%s)" % ", ".join(api.db.quote(names)).lower()
-            sql += ")"
-
-        rows = self.api.db.selectall(sql)
-        for row in rows:
-            if self.has_key(row['nodegroup_id']):
-                nodegroup = self[row['nodegroup_id']]
-                nodegroup.update(row)
-            else:
-                self[row['nodegroup_id']] = NodeGroup(api, row)
+    def __init__(self, api, nodegroup_filter = None, columns = None):
+        Table.__init__(self, api, NodeGroup, columns)
+
+        sql = "SELECT %s FROM view_nodegroups WHERE True" % \
+              ", ".join(self.columns)
+
+        if nodegroup_filter is not None:
+            if isinstance(nodegroup_filter, (list, tuple, set)):
+                # Separate the list into integers and strings
+                ints = filter(lambda x: isinstance(x, (int, long)), nodegroup_filter)
+                strs = filter(lambda x: isinstance(x, StringTypes), nodegroup_filter)
+                nodegroup_filter = Filter(NodeGroup.fields, {'nodegroup_id': ints, 'name': strs})
+                sql += " AND (%s) %s" % nodegroup_filter.sql(api, "OR")
+            elif isinstance(nodegroup_filter, dict):
+                nodegroup_filter = Filter(NodeGroup.fields, nodegroup_filter)
+                sql += " AND (%s) %s" % nodegroup_filter.sql(api, "AND")
+
+        self.selectall(sql)