2 # Functions for interacting with the nodegroups table in the database
4 # Mark Huang <mlhuang@cs.princeton.edu>
5 # Copyright (C) 2006 The Trustees of Princeton University
7 # $Id: NodeGroups.py,v 1.9 2006/09/20 14:41:59 mlhuang Exp $
10 from types import StringTypes
12 from PLC.Faults import *
13 from PLC.Parameter import Parameter
14 from PLC.Debug import profile
15 from PLC.Table import Row, Table
16 from PLC.Nodes import Node, Nodes
20 Representation of a row in the nodegroups table. To use, optionally
21 instantiate with a dict of values. Update as you would a
22 dict. Commit to the database with sync().
26 'nodegroup_id': Parameter(int, "Node group identifier"),  # allocated by sync() from nodegroups_nodegroup_id_seq when absent
27 'name': Parameter(str, "Node group name", max = 50),  # checked by validate_name(); max presumably caps the length — confirm in PLC.Parameter
28 'description': Parameter(str, "Node group description", max = 200),  # max presumably caps the length — confirm in PLC.Parameter
29 'node_ids': Parameter([int], "List of nodes in this node group"),  # loaded by NodeGroups from view_nodegroups; kept current in memory by add_node()/remove_node()
32 def __init__(self, api, fields):
# Build the row from a dict of field values via the Row base class.
# NOTE(review): the other methods read self.api, but the lines shown here do
# not assign it — confirm that api is stored on the instance.
33 Row.__init__(self, fields)
36 def validate_name(self, name):
# Validate a candidate node group name before it is stored.
# Raises PLCInvalidArgument if the name is blank once surrounding spaces are
# removed, or if a different node group already uses the name.
37 # Remove leading and trailing spaces
40 # Make sure name is not blank after we removed the spaces
42 raise PLCInvalidArgument, "Invalid node group name"
44 # Make sure node group does not already exist
45 conflicts = NodeGroups(self.api, [name])
# NodeGroups stores rows keyed by nodegroup_id, so (assuming dict-style
# iteration) this visits the id of every group whose name matches.
46 for nodegroup_id in conflicts:
# A match is harmless only when it is this very row keeping its own name;
# a row without a nodegroup_id is new, so any match is a conflict.
47 if 'nodegroup_id' not in self or self['nodegroup_id'] != nodegroup_id:
48 raise PLCInvalidArgument, "Node group name already in use"
52 def add_node(self, node, commit = True):
54 Add node to existing nodegroup.
57 assert 'nodegroup_id' in self
58 assert isinstance(node, Node)
59 assert 'node_id' in node
# Both rows must already have database ids. These asserts only catch
# programming errors; Python strips them when run with -O.
61 node_id = node['node_id']
62 nodegroup_id = self['nodegroup_id']
# Record the membership in the nodegroup_node join table. The %(...)d
# placeholders are presumably bound from the locals above — confirm against
# the call's argument.
63 self.api.db.do("INSERT INTO nodegroup_node (nodegroup_id, node_id)" \
64 " VALUES(%(nodegroup_id)d, %(node_id)d)",
# Append each id to the other object's cached list, if that list was
# loaded, without creating duplicates.
70 if 'node_ids' in self and node_id not in self['node_ids']:
71 self['node_ids'].append(node_id)
73 if 'nodegroup_ids' in node and nodegroup_id not in node['nodegroup_ids']:
74 node['nodegroup_ids'].append(nodegroup_id)
76 def remove_node(self, node, commit = True):
78 Remove node from existing nodegroup.
81 assert 'nodegroup_id' in self
82 assert isinstance(node, Node)
83 assert 'node_id' in node
# Both rows must already have database ids. These asserts only catch
# programming errors; Python strips them when run with -O.
85 node_id = node['node_id']
86 nodegroup_id = self['nodegroup_id']
# Delete the membership row from the nodegroup_node join table. The
# %(...)d placeholders are presumably bound from the locals above — confirm
# against the call's argument.
87 self.api.db.do("DELETE FROM nodegroup_node" \
88 " WHERE nodegroup_id = %(nodegroup_id)d" \
89 " AND node_id = %(node_id)d",
# Drop each id from the other object's cached list, if that list was loaded.
95 if 'node_ids' in self and node_id in self['node_ids']:
96 self['node_ids'].remove(node_id)
98 if 'nodegroup_ids' in node and nodegroup_id in node['nodegroup_ids']:
99 node['nodegroup_ids'].remove(nodegroup_id)
101 def sync(self, commit = True):
103 Flush changes back to the database.
108 # Fetch a new nodegroup_id if necessary
109 if 'nodegroup_id' not in self:
110 rows = self.api.db.selectall("SELECT NEXTVAL('nodegroups_nodegroup_id_seq') AS nodegroup_id")
112 raise PLCDBError, "Unable to fetch new nodegroup_id"
113 self['nodegroup_id'] = rows[0]['nodegroup_id']
118 # Filter out fields that cannot be set or updated directly
# Only keys reported as columns of the nodegroups table are written.
# Derived values such as node_ids are presumably not columns and are
# therefore skipped — confirm against the schema.
119 nodegroups_fields = self.api.db.fields('nodegroups')
120 fields = dict(filter(lambda (key, value): key in nodegroups_fields,
123 # Parameterize for safety
125 values = [self.api.db.param(key, value) for (key, value) in fields.items()]
# NOTE(review): the INSERT/UPDATE branch condition is not shown here;
# presumably INSERT is chosen when the id was just allocated above — confirm.
128 # Insert new row in nodegroups table
129 sql = "INSERT INTO nodegroups (%s) VALUES (%s)" % \
130 (", ".join(keys), ", ".join(values))
132 # Update existing row in nodegroups table
# zip() pairs names with placeholders by position. `keys` must therefore
# iterate in the same order as fields.items() above; presumably
# keys = fields.keys(), which Python 2 guarantees for an unmodified dict.
133 columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)]
134 sql = "UPDATE nodegroups SET " + \
135 ", ".join(columns) + \
136 " WHERE nodegroup_id = %(nodegroup_id)d"
# Named placeholders are bound from the filtered fields dict. The UPDATE's
# WHERE clause therefore needs nodegroup_id to survive the column filter.
138 self.api.db.do(sql, fields)
144 def delete(self, commit = True):
146 Delete existing nodegroup from the database.
149 assert 'nodegroup_id' in self
151 # Remove the join-table rows first, then the nodegroup row itself
# Table names come from this fixed list, and the id goes through %d, so no
# caller-supplied text reaches the SQL string.
# NOTE(review): the SQL is fully %-formatted before the call, so passing self
# as the parameter mapping looks unnecessary — confirm how api.db.do uses it.
152 for table in ['nodegroup_node', 'nodegroups']:
153 self.api.db.do("DELETE FROM %s" \
154 " WHERE nodegroup_id = %d" % \
155 (table, self['nodegroup_id']), self)
160 class NodeGroups(Table):
162 Representation of row(s) from the nodegroups table in the
166 def __init__(self, api, nodegroup_id_or_name_list = None):
# Load node groups from view_nodegroups and store each one as a NodeGroup
# keyed by its nodegroup_id. When nodegroup_id_or_name_list is non-empty,
# only groups matching one of its ids (int/long) or names (strings) are loaded.
169 sql = "SELECT * FROM view_nodegroups"
171 if nodegroup_id_or_name_list:
172 # Separate the list into integers and strings
173 nodegroup_ids = filter(lambda nodegroup_id: isinstance(nodegroup_id, (int, long)),
174 nodegroup_id_or_name_list)
175 names = filter(lambda name: isinstance(name, StringTypes),
176 nodegroup_id_or_name_list)
# Entries that are neither integers nor strings are silently ignored.
# Starting the predicate at False lets each criterion be appended with OR.
177 sql += " WHERE (False"
# Ids are int/long values rendered with str(), so they cannot inject SQL.
179 sql += " OR nodegroup_id IN (%s)" % ", ".join(map(str, nodegroup_ids))
# Names are embedded via api.db.quote, which presumably escapes and quotes
# each one — confirm.
181 sql += " OR name IN (%s)" % ", ".join(api.db.quote(names))
184 rows = self.api.db.selectall(sql)
187 self[row['nodegroup_id']] = nodegroup = NodeGroup(api, row)
# The view returns node_ids as a comma-separated string, or NULL/None
# (presumably when the group has no nodes). Normalize it to a list of ints.
188 for aggregate in ['node_ids']:
189 if not nodegroup.has_key(aggregate) or nodegroup[aggregate] is None:
190 nodegroup[aggregate] = []
192 nodegroup[aggregate] = map(int, nodegroup[aggregate].split(','))