Initial checkin of new API implementation
[plcapi.git] / PLC / NodeGroups.py
diff --git a/PLC/NodeGroups.py b/PLC/NodeGroups.py
new file mode 100644 (file)
index 0000000..aff2c8f
--- /dev/null
@@ -0,0 +1,144 @@
+#
+# Functions for interacting with the nodegroups table in the database
+#
+# Mark Huang <mlhuang@cs.princeton.edu>
+# Copyright (C) 2006 The Trustees of Princeton University
+#
+# $Id$
+#
+
+from PLC.Faults import *
+from PLC.Parameter import Parameter
+from PLC.Debug import profile
+from PLC.Table import Row, Table
+
+class NodeGroup(Row):
+    """
+    Representation of a row in the nodegroups table. To use, optionally
+    instantiate with a dict of values. Update as you would a
+    dict. Commit to the database with flush().
+    """
+
+    fields = {
+        'nodegroup_id': Parameter(int, "Node group identifier"),
+        'name': Parameter(str, "Node group name"),
+        'description': Parameter(str, "Node group description"),
+        'is_custom': Parameter(bool, "Is a custom node group (i.e., is not a site node group)")
+        }
+
+    # These fields are derived from join tables and are not
+    # actually in the nodegroups table.
+    join_fields = {
+        'node_ids': Parameter([int], "List of nodes in this node group"),
+        }
+
+    def __init__(self, api, fields):
+        Row.__init__(self, fields)
+        self.api = api
+
+    def validate_name(self, name):
+        conflicts = NodeGroups(self.api, [name])
+        for nodegroup_id in conflicts:
+            if 'nodegroup_id' not in self or self['nodegroup_id'] != nodegroup_id:
+                raise PLCInvalidArgument, "Node group name already in use"
+
+    def flush(self, commit = True):
+        """
+        Flush changes back to the database.
+        """
+
+        self.validate()
+
+        # Fetch a new nodegroup_id if necessary
+        if 'nodegroup_id' not in self:
+            rows = self.api.db.selectall("SELECT NEXTVAL('nodegroups_nodegroup_id_seq') AS nodegroup_id")
+            if not rows:
+                raise PLCDBError, "Unable to fetch new nodegroup_id"
+            self['nodegroup_id'] = rows[0]['nodegroup_id']
+            insert = True
+        else:
+            insert = False
+
+        # Filter out fields that cannot be set or updated directly
+        fields = dict(filter(lambda (key, value): key in self.fields,
+                             self.items()))
+
+        # Parameterize for safety
+        keys = fields.keys()
+        values = [self.api.db.param(key, value) for (key, value) in fields.items()]
+
+        if insert:
+            # Insert new row in nodegroups table
+            sql = "INSERT INTO nodegroups (%s) VALUES (%s)" % \
+                  (", ".join(keys), ", ".join(values))
+        else:
+            # Update existing row in sites table
+            columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)]
+            sql = "UPDATE nodegroups SET " + \
+                  ", ".join(columns) + \
+                  " WHERE nodegroup_id = %(nodegroup_id)d"
+
+        self.api.db.do(sql, fields)
+
+        if commit:
+            self.api.db.commit()
+
+    def delete(self, commit = True):
+        """
+        Delete existing nodegroup from the database.
+        """
+
+        assert 'nodegroup_id' in self
+
+        # Delete ourself
+        tables = ['nodegroup_nodes', 'override_bootscripts',
+                  'conf_assoc', 'node_root_access']
+
+        if self['is_custom']:
+            tables.append('nodegroups')
+        else:
+            # XXX Cannot delete site node groups yet
+            pass
+
+        for table in tables:
+            self.api.db.do("DELETE FROM %s" \
+                           " WHERE nodegroup_id = %(nodegroup_id)" % \
+                           table, self)
+
+        if commit:
+            self.api.db.commit()
+
+class NodeGroups(Table):
+    """
+    Representation of row(s) from the nodegroups table in the
+    database.
+    """
+
+    def __init__(self, api, nodegroup_id_or_name_list = None):
+        self.api = api
+
+        # N.B.: Node IDs returned may be deleted.
+        sql = "SELECT nodegroups.*, nodegroup_nodes.node_id" \
+              " FROM nodegroups" \
+              " LEFT JOIN nodegroup_nodes USING (nodegroup_id)"
+
+        if nodegroup_id_or_name_list:
+            # Separate the list into integers and strings
+            nodegroup_ids = filter(lambda nodegroup_id: isinstance(nodegroup_id, (int, long)),
+                                   nodegroup_id_or_name_list)
+            names = filter(lambda name: isinstance(name, StringTypes),
+                           nodegroup_id_or_name_list)
+            sql += " AND (False"
+            if nodegroup_ids:
+                sql += " OR nodegroup_id IN (%s)" % ", ".join(map(str, nodegroup_ids))
+            if names:
+                sql += " OR name IN (%s)" % ", ".join(api.db.quote(names)).lower()
+            sql += ")"
+
+        rows = self.api.db.selectall(sql)
+        for row in rows:
+            if self.has_key(row['nodegroup_id']):
+                nodegroup = self[row['nodegroup_id']]
+                nodegroup.update(row)
+            else:
+                self[row['nodegroup_id']] = NodeGroup(api, row)