Initial checkin of new API implementation
[plcapi.git] / PLC / NodeGroups.py
1 #
2 # Functions for interacting with the nodegroups table in the database
3 #
4 # Mark Huang <mlhuang@cs.princeton.edu>
5 # Copyright (C) 2006 The Trustees of Princeton University
6 #
7 # $Id$
8 #
9
10 from PLC.Faults import *
11 from PLC.Parameter import Parameter
12 from PLC.Debug import profile
13 from PLC.Table import Row, Table
14
15 class NodeGroup(Row):
16     """
17     Representation of a row in the nodegroups table. To use, optionally
18     instantiate with a dict of values. Update as you would a
19     dict. Commit to the database with flush().
20     """
21
22     fields = {
23         'nodegroup_id': Parameter(int, "Node group identifier"),
24         'name': Parameter(str, "Node group name"),
25         'description': Parameter(str, "Node group description"),
26         'is_custom': Parameter(bool, "Is a custom node group (i.e., is not a site node group)")
27         }
28
29     # These fields are derived from join tables and are not
30     # actually in the nodegroups table.
31     join_fields = {
32         'node_ids': Parameter([int], "List of nodes in this node group"),
33         }
34
35     def __init__(self, api, fields):
36         Row.__init__(self, fields)
37         self.api = api
38
39     def validate_name(self, name):
40         conflicts = NodeGroups(self.api, [name])
41         for nodegroup_id in conflicts:
42             if 'nodegroup_id' not in self or self['nodegroup_id'] != nodegroup_id:
43                 raise PLCInvalidArgument, "Node group name already in use"
44
45     def flush(self, commit = True):
46         """
47         Flush changes back to the database.
48         """
49
50         self.validate()
51
52         # Fetch a new nodegroup_id if necessary
53         if 'nodegroup_id' not in self:
54             rows = self.api.db.selectall("SELECT NEXTVAL('nodegroups_nodegroup_id_seq') AS nodegroup_id")
55             if not rows:
56                 raise PLCDBError, "Unable to fetch new nodegroup_id"
57             self['nodegroup_id'] = rows[0]['nodegroup_id']
58             insert = True
59         else:
60             insert = False
61
62         # Filter out fields that cannot be set or updated directly
63         fields = dict(filter(lambda (key, value): key in self.fields,
64                              self.items()))
65
66         # Parameterize for safety
67         keys = fields.keys()
68         values = [self.api.db.param(key, value) for (key, value) in fields.items()]
69
70         if insert:
71             # Insert new row in nodegroups table
72             sql = "INSERT INTO nodegroups (%s) VALUES (%s)" % \
73                   (", ".join(keys), ", ".join(values))
74         else:
75             # Update existing row in sites table
76             columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)]
77             sql = "UPDATE nodegroups SET " + \
78                   ", ".join(columns) + \
79                   " WHERE nodegroup_id = %(nodegroup_id)d"
80
81         self.api.db.do(sql, fields)
82
83         if commit:
84             self.api.db.commit()
85
86     def delete(self, commit = True):
87         """
88         Delete existing nodegroup from the database.
89         """
90
91         assert 'nodegroup_id' in self
92
93         # Delete ourself
94         tables = ['nodegroup_nodes', 'override_bootscripts',
95                   'conf_assoc', 'node_root_access']
96
97         if self['is_custom']:
98             tables.append('nodegroups')
99         else:
100             # XXX Cannot delete site node groups yet
101             pass
102
103         for table in tables:
104             self.api.db.do("DELETE FROM %s" \
105                            " WHERE nodegroup_id = %(nodegroup_id)" % \
106                            table, self)
107
108         if commit:
109             self.api.db.commit()
110
111 class NodeGroups(Table):
112     """
113     Representation of row(s) from the nodegroups table in the
114     database.
115     """
116
117     def __init__(self, api, nodegroup_id_or_name_list = None):
118         self.api = api
119
120         # N.B.: Node IDs returned may be deleted.
121         sql = "SELECT nodegroups.*, nodegroup_nodes.node_id" \
122               " FROM nodegroups" \
123               " LEFT JOIN nodegroup_nodes USING (nodegroup_id)"
124
125         if nodegroup_id_or_name_list:
126             # Separate the list into integers and strings
127             nodegroup_ids = filter(lambda nodegroup_id: isinstance(nodegroup_id, (int, long)),
128                                    nodegroup_id_or_name_list)
129             names = filter(lambda name: isinstance(name, StringTypes),
130                            nodegroup_id_or_name_list)
131             sql += " AND (False"
132             if nodegroup_ids:
133                 sql += " OR nodegroup_id IN (%s)" % ", ".join(map(str, nodegroup_ids))
134             if names:
135                 sql += " OR name IN (%s)" % ", ".join(api.db.quote(names)).lower()
136             sql += ")"
137
138         rows = self.api.db.selectall(sql)
139         for row in rows:
140             if self.has_key(row['nodegroup_id']):
141                 nodegroup = self[row['nodegroup_id']]
142                 nodegroup.update(row)
143             else:
144                 self[row['nodegroup_id']] = NodeGroup(api, row)