From 397613bdf2e14391a9aaa378146a0a5fe429c7ba Mon Sep 17 00:00:00 2001 From: Mark Huang Date: Mon, 25 Sep 2006 14:52:01 +0000 Subject: [PATCH] - support new schema - use new table views instead of defining join_fields, extra_fields, and all_fields - rename flush() to sync() to be like shelve - whitespace nits - no need to validate description - fix remove_node() --- PLC/NodeGroups.py | 84 ++++++++++++++++------------------------------- 1 file changed, 29 insertions(+), 55 deletions(-) diff --git a/PLC/NodeGroups.py b/PLC/NodeGroups.py index 00a1b20b..571e6fa8 100644 --- a/PLC/NodeGroups.py +++ b/PLC/NodeGroups.py @@ -4,7 +4,7 @@ # Mark Huang # Copyright (C) 2006 The Trustees of Princeton University # -# $Id: NodeGroups.py,v 1.8 2006/09/19 19:08:24 mlhuang Exp $ +# $Id: NodeGroups.py,v 1.9 2006/09/20 14:41:59 mlhuang Exp $ # from types import StringTypes @@ -19,53 +19,35 @@ class NodeGroup(Row): """ Representation of a row in the nodegroups table. To use, optionally instantiate with a dict of values. Update as you would a - dict. Commit to the database with flush(). + dict. Commit to the database with sync(). """ fields = { 'nodegroup_id': Parameter(int, "Node group identifier"), - 'name': Parameter(str, "Node group name"), - 'description': Parameter(str, "Node group description"), - 'is_custom': Parameter(bool, "Is a custom node group (i.e., is not a site node group)") - } - - # These fields are derived from join tables and are not - # actually in the nodegroups table. - join_fields = { + 'name': Parameter(str, "Node group name", max = 50), + 'description': Parameter(str, "Node group description", max = 200), 'node_ids': Parameter([int], "List of nodes in this node group"), } - all_fields = dict(fields.items() + join_fields.items()) - def __init__(self, api, fields): Row.__init__(self, fields) self.api = api def validate_name(self, name): - #remove leading and trailing spaces + # Remove leading and trailing spaces name = name.strip() - #make sure name is not blank after we removed the spaces + # Make sure name is not blank after we removed the spaces if not len(name) > 0: - raise PLCInvalidArgument, "Invalid Node Group Name" + raise PLCInvalidArgument, "Invalid node group name" - #make sure name doenst alredy exist + # Make sure node group does not alredy exist conflicts = NodeGroups(self.api, [name]) for nodegroup_id in conflicts: if 'nodegroup_id' not in self or self['nodegroup_id'] != nodegroup_id: raise PLCInvalidArgument, "Node group name already in use" return name - - def validate_description(self, description): - #remove trailing and leading spaces - description = description.strip() - - #make sure decription is not blank after we removed the spaces - if not len(description) > 0: - raise PLCInvalidArgument, "Invalid Node Group Description" - - return description def add_node(self, node, commit = True): """ @@ -78,7 +60,7 @@ class NodeGroup(Row): node_id = node['node_id'] nodegroup_id = self['nodegroup_id'] - self.api.db.do("INSERT INTO nodegroup_nodes (nodegroup_id, node_id)" \ + self.api.db.do("INSERT INTO nodegroup_node (nodegroup_id, node_id)" \ " VALUES(%(nodegroup_id)d, %(node_id)d)", locals()) @@ -102,20 +84,21 @@ class NodeGroup(Row): node_id = node['node_id'] nodegroup_id = self['nodegroup_id'] - self.api.db.do("INSERT INTO nodegroup_nodes (nodegroup_id, node_id)" \ - " VALUES(%(nodegroup_id)d, %(node_id)d)", + self.api.db.do("DELETE FROM nodegroup_node" \ + " WHERE nodegroup_id = %(nodegroup_id)d" \ + " AND node_id = %(node_id)d", locals()) if commit: self.api.db.commit() - if 'node_ids' in self and node_id not in self['node_ids']: - self['node_ids'].append(node_id) + if 'node_ids' in self and node_id in self['node_ids']: + self['node_ids'].remove(node_id) - if 'nodegroup_ids' in node and nodegroup_id not in node['nodegroup_ids']: - node['nodegroup_ids'].append(nodegroup_id) + if 'nodegroup_ids' in node and nodegroup_id in node['nodegroup_ids']: + node['nodegroup_ids'].remove(nodegroup_id) - def flush(self, commit = True): + def sync(self, commit = True): """ Flush changes back to the database. """ @@ -133,7 +116,8 @@ class NodeGroup(Row): insert = False # Filter out fields that cannot be set or updated directly - fields = dict(filter(lambda (key, value): key in self.fields, + nodegroups_fields = self.api.db.fields('nodegroups') + fields = dict(filter(lambda (key, value): key in nodegroups_fields, self.items())) # Parameterize for safety @@ -163,18 +147,9 @@ class NodeGroup(Row): """ assert 'nodegroup_id' in self - assert self is not {} - # Delete ourself - tables = ['nodegroup_nodes', 'override_bootscripts', - 'conf_assoc', 'node_root_access'] - if self['is_custom']: - tables.append('nodegroups') - else: - # XXX Cannot delete site node groups yet - pass - - for table in tables: + # Clean up miscellaneous join tables + for table in ['nodegroup_node', 'nodegroups']: self.api.db.do("DELETE FROM %s" \ " WHERE nodegroup_id = %d" % \ (table, self['nodegroup_id']), self) @@ -191,10 +166,7 @@ class NodeGroups(Table): def __init__(self, api, nodegroup_id_or_name_list = None): self.api = api - # N.B.: Node IDs returned may be deleted. - sql = "SELECT nodegroups.*, nodegroup_nodes.node_id" \ - " FROM nodegroups" \ - " LEFT JOIN nodegroup_nodes USING (nodegroup_id)" + sql = "SELECT * FROM view_nodegroups" if nodegroup_id_or_name_list: # Separate the list into integers and strings @@ -210,9 +182,11 @@ class NodeGroups(Table): sql += ")" rows = self.api.db.selectall(sql) + for row in rows: - if self.has_key(row['nodegroup_id']): - nodegroup = self[row['nodegroup_id']] - nodegroup.update(row) - else: - self[row['nodegroup_id']] = NodeGroup(api, row) + self[row['nodegroup_id']] = nodegroup = NodeGroup(api, row) + for aggregate in ['node_ids']: + if not nodegroup.has_key(aggregate) or nodegroup[aggregate] is None: + nodegroup[aggregate] = [] + else: + nodegroup[aggregate] = map(int, nodegroup[aggregate].split(',')) -- 2.47.0