- support new schema
[plcapi.git] / PLC / NodeGroups.py
1 #
2 # Functions for interacting with the nodegroups table in the database
3 #
4 # Mark Huang <mlhuang@cs.princeton.edu>
5 # Copyright (C) 2006 The Trustees of Princeton University
6 #
7 # $Id: NodeGroups.py,v 1.9 2006/09/20 14:41:59 mlhuang Exp $
8 #
9
10 from types import StringTypes
11
12 from PLC.Faults import *
13 from PLC.Parameter import Parameter
14 from PLC.Debug import profile
15 from PLC.Table import Row, Table
16 from PLC.Nodes import Node, Nodes
17
18 class NodeGroup(Row):
19     """
20     Representation of a row in the nodegroups table. To use, optionally
21     instantiate with a dict of values. Update as you would a
22     dict. Commit to the database with sync().
23     """
24
25     fields = {
26         'nodegroup_id': Parameter(int, "Node group identifier"),
27         'name': Parameter(str, "Node group name", max = 50),
28         'description': Parameter(str, "Node group description", max = 200),
29         'node_ids': Parameter([int], "List of nodes in this node group"),
30         }
31
32     def __init__(self, api, fields):
33         Row.__init__(self, fields)
34         self.api = api
35
36     def validate_name(self, name):
37         # Remove leading and trailing spaces
38         name = name.strip()
39
40         # Make sure name is not blank after we removed the spaces
41         if not len(name) > 0:
42                 raise PLCInvalidArgument, "Invalid node group name"
43         
44         # Make sure node group does not alredy exist
45         conflicts = NodeGroups(self.api, [name])
46         for nodegroup_id in conflicts:
47             if 'nodegroup_id' not in self or self['nodegroup_id'] != nodegroup_id:
48                raise PLCInvalidArgument, "Node group name already in use"
49
50         return name
51
52     def add_node(self, node, commit = True):
53         """
54         Add node to existing nodegroup.
55         """
56
57         assert 'nodegroup_id' in self
58         assert isinstance(node, Node)
59         assert 'node_id' in node
60
61         node_id = node['node_id']
62         nodegroup_id = self['nodegroup_id']
63         self.api.db.do("INSERT INTO nodegroup_node (nodegroup_id, node_id)" \
64                        " VALUES(%(nodegroup_id)d, %(node_id)d)",
65                        locals())
66
67         if commit:
68             self.api.db.commit()
69
70         if 'node_ids' in self and node_id not in self['node_ids']:
71             self['node_ids'].append(node_id)
72
73         if 'nodegroup_ids' in node and nodegroup_id not in node['nodegroup_ids']:
74             node['nodegroup_ids'].append(nodegroup_id)
75
76     def remove_node(self, node, commit = True):
77         """
78         Remove node from existing nodegroup.
79         """
80
81         assert 'nodegroup_id' in self
82         assert isinstance(node, Node)
83         assert 'node_id' in node
84
85         node_id = node['node_id']
86         nodegroup_id = self['nodegroup_id']
87         self.api.db.do("DELETE FROM nodegroup_node" \
88                        " WHERE nodegroup_id = %(nodegroup_id)d" \
89                        " AND node_id = %(node_id)d",
90                        locals())
91
92         if commit:
93             self.api.db.commit()
94
95         if 'node_ids' in self and node_id in self['node_ids']:
96             self['node_ids'].remove(node_id)
97
98         if 'nodegroup_ids' in node and nodegroup_id in node['nodegroup_ids']:
99             node['nodegroup_ids'].remove(nodegroup_id)
100
101     def sync(self, commit = True):
102         """
103         Flush changes back to the database.
104         """
105
106         self.validate()
107
108         # Fetch a new nodegroup_id if necessary
109         if 'nodegroup_id' not in self:
110             rows = self.api.db.selectall("SELECT NEXTVAL('nodegroups_nodegroup_id_seq') AS nodegroup_id")
111             if not rows:
112                 raise PLCDBError, "Unable to fetch new nodegroup_id"
113             self['nodegroup_id'] = rows[0]['nodegroup_id']
114             insert = True
115         else:
116             insert = False
117
118         # Filter out fields that cannot be set or updated directly
119         nodegroups_fields = self.api.db.fields('nodegroups')
120         fields = dict(filter(lambda (key, value): key in nodegroups_fields,
121                              self.items()))
122
123         # Parameterize for safety
124         keys = fields.keys()
125         values = [self.api.db.param(key, value) for (key, value) in fields.items()]
126
127         if insert:
128             # Insert new row in nodegroups table
129             sql = "INSERT INTO nodegroups (%s) VALUES (%s)" % \
130                   (", ".join(keys), ", ".join(values))
131         else:
132             # Update existing row in nodegroups table
133             columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)]
134             sql = "UPDATE nodegroups SET " + \
135                   ", ".join(columns) + \
136                   " WHERE nodegroup_id = %(nodegroup_id)d"
137
138         self.api.db.do(sql, fields)
139
140         if commit:
141             self.api.db.commit()
142             
143
144     def delete(self, commit = True):
145         """
146         Delete existing nodegroup from the database.
147         """
148
149         assert 'nodegroup_id' in self
150
151         # Clean up miscellaneous join tables
152         for table in ['nodegroup_node', 'nodegroups']:
153             self.api.db.do("DELETE FROM %s" \
154                            " WHERE nodegroup_id = %d" % \
155                            (table, self['nodegroup_id']), self)
156
157         if commit:
158             self.api.db.commit()
159
160 class NodeGroups(Table):
161     """
162     Representation of row(s) from the nodegroups table in the
163     database.
164     """
165
166     def __init__(self, api, nodegroup_id_or_name_list = None):
167         self.api = api
168
169         sql = "SELECT * FROM view_nodegroups"
170
171         if nodegroup_id_or_name_list:
172             # Separate the list into integers and strings
173             nodegroup_ids = filter(lambda nodegroup_id: isinstance(nodegroup_id, (int, long)),
174                                    nodegroup_id_or_name_list)
175             names = filter(lambda name: isinstance(name, StringTypes),
176                            nodegroup_id_or_name_list)
177             sql += " WHERE (False"
178             if nodegroup_ids:
179                 sql += " OR nodegroup_id IN (%s)" % ", ".join(map(str, nodegroup_ids))
180             if names:
181                 sql += " OR name IN (%s)" % ", ".join(api.db.quote(names))
182             sql += ")"
183
184         rows = self.api.db.selectall(sql)
185
186         for row in rows:
187             self[row['nodegroup_id']] = nodegroup = NodeGroup(api, row)
188             for aggregate in ['node_ids']:
189                 if not nodegroup.has_key(aggregate) or nodegroup[aggregate] is None:
190                     nodegroup[aggregate] = []
191                 else:
192                     nodegroup[aggregate] = map(int, nodegroup[aggregate].split(','))