- support new schema
authorMark Huang <mlhuang@cs.princeton.edu>
Mon, 25 Sep 2006 14:52:01 +0000 (14:52 +0000)
committerMark Huang <mlhuang@cs.princeton.edu>
Mon, 25 Sep 2006 14:52:01 +0000 (14:52 +0000)
- use new table views instead of defining join_fields, extra_fields, and
  all_fields
- rename flush() to sync() to be like shelve
- whitespace nits
- no need to validate description
- fix remove_node()

PLC/NodeGroups.py

index 00a1b20..571e6fa 100644 (file)
@@ -4,7 +4,7 @@
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
-# $Id: NodeGroups.py,v 1.8 2006/09/19 19:08:24 mlhuang Exp $
+# $Id: NodeGroups.py,v 1.9 2006/09/20 14:41:59 mlhuang Exp $
 #
 
 from types import StringTypes
@@ -19,53 +19,35 @@ class NodeGroup(Row):
     """
     Representation of a row in the nodegroups table. To use, optionally
     instantiate with a dict of values. Update as you would a
-    dict. Commit to the database with flush().
+    dict. Commit to the database with sync().
     """
 
     fields = {
         'nodegroup_id': Parameter(int, "Node group identifier"),
-        'name': Parameter(str, "Node group name"),
-        'description': Parameter(str, "Node group description"),
-        'is_custom': Parameter(bool, "Is a custom node group (i.e., is not a site node group)")
-        }
-
-    # These fields are derived from join tables and are not
-    # actually in the nodegroups table.
-    join_fields = {
+        'name': Parameter(str, "Node group name", max = 50),
+        'description': Parameter(str, "Node group description", max = 200),
         'node_ids': Parameter([int], "List of nodes in this node group"),
         }
 
-    all_fields = dict(fields.items() + join_fields.items())
-
     def __init__(self, api, fields):
         Row.__init__(self, fields)
         self.api = api
 
     def validate_name(self, name):
-       #remove leading and trailing spaces
+       # Remove leading and trailing spaces
        name = name.strip()
 
-       #make sure name is not blank after we removed the spaces
+       # Make sure name is not blank after we removed the spaces
         if not len(name) > 0:
-                raise PLCInvalidArgument, "Invalid Node Group Name"
+                raise PLCInvalidArgument, "Invalid node group name"
        
-       #make sure name doenst alredy exist
+       # Make sure node group does not alredy exist
        conflicts = NodeGroups(self.api, [name])
        for nodegroup_id in conflicts:
             if 'nodegroup_id' not in self or self['nodegroup_id'] != nodegroup_id:
                raise PLCInvalidArgument, "Node group name already in use"
 
        return name
-       
-    def validate_description(self, description):
-       #remove trailing and leading spaces
-       description = description.strip()
-       
-       #make sure decription is not blank after we removed the spaces  
-       if not len(description) > 0:
-               raise PLCInvalidArgument, "Invalid Node Group Description"
-
-       return description
 
     def add_node(self, node, commit = True):
         """
@@ -78,7 +60,7 @@ class NodeGroup(Row):
 
         node_id = node['node_id']
         nodegroup_id = self['nodegroup_id']
-        self.api.db.do("INSERT INTO nodegroup_nodes (nodegroup_id, node_id)" \
+        self.api.db.do("INSERT INTO nodegroup_node (nodegroup_id, node_id)" \
                        " VALUES(%(nodegroup_id)d, %(node_id)d)",
                        locals())
 
@@ -102,20 +84,21 @@ class NodeGroup(Row):
 
         node_id = node['node_id']
         nodegroup_id = self['nodegroup_id']
-        self.api.db.do("INSERT INTO nodegroup_nodes (nodegroup_id, node_id)" \
-                       " VALUES(%(nodegroup_id)d, %(node_id)d)",
+        self.api.db.do("DELETE FROM nodegroup_node" \
+                       " WHERE nodegroup_id = %(nodegroup_id)d" \
+                       " AND node_id = %(node_id)d",
                        locals())
 
         if commit:
             self.api.db.commit()
 
-        if 'node_ids' in self and node_id not in self['node_ids']:
-            self['node_ids'].append(node_id)
+        if 'node_ids' in self and node_id in self['node_ids']:
+            self['node_ids'].remove(node_id)
 
-        if 'nodegroup_ids' in node and nodegroup_id not in node['nodegroup_ids']:
-            node['nodegroup_ids'].append(nodegroup_id)
+        if 'nodegroup_ids' in node and nodegroup_id in node['nodegroup_ids']:
+            node['nodegroup_ids'].remove(nodegroup_id)
 
-    def flush(self, commit = True):
+    def sync(self, commit = True):
         """
         Flush changes back to the database.
         """
@@ -133,7 +116,8 @@ class NodeGroup(Row):
             insert = False
 
         # Filter out fields that cannot be set or updated directly
-        fields = dict(filter(lambda (key, value): key in self.fields,
+        nodegroups_fields = self.api.db.fields('nodegroups')
+        fields = dict(filter(lambda (key, value): key in nodegroups_fields,
                              self.items()))
 
         # Parameterize for safety
@@ -163,18 +147,9 @@ class NodeGroup(Row):
         """
 
         assert 'nodegroup_id' in self
-       assert self is not {}
-        # Delete ourself
-        tables = ['nodegroup_nodes', 'override_bootscripts',
-                  'conf_assoc', 'node_root_access']
 
-        if self['is_custom']:
-            tables.append('nodegroups')
-        else:
-            # XXX Cannot delete site node groups yet
-            pass
-
-        for table in tables:
+        # Clean up miscellaneous join tables
+        for table in ['nodegroup_node', 'nodegroups']:
             self.api.db.do("DELETE FROM %s" \
                            " WHERE nodegroup_id = %d" % \
                            (table, self['nodegroup_id']), self)
@@ -191,10 +166,7 @@ class NodeGroups(Table):
     def __init__(self, api, nodegroup_id_or_name_list = None):
        self.api = api
 
-        # N.B.: Node IDs returned may be deleted.
-        sql = "SELECT nodegroups.*, nodegroup_nodes.node_id" \
-              " FROM nodegroups" \
-              " LEFT JOIN nodegroup_nodes USING (nodegroup_id)"
+        sql = "SELECT * FROM view_nodegroups"
 
         if nodegroup_id_or_name_list:
             # Separate the list into integers and strings
@@ -210,9 +182,11 @@ class NodeGroups(Table):
             sql += ")"
 
         rows = self.api.db.selectall(sql)
+
         for row in rows:
-            if self.has_key(row['nodegroup_id']):
-                nodegroup = self[row['nodegroup_id']]
-                nodegroup.update(row)
-            else:
-                self[row['nodegroup_id']] = NodeGroup(api, row)
+            self[row['nodegroup_id']] = nodegroup = NodeGroup(api, row)
+            for aggregate in ['node_ids']:
+                if not nodegroup.has_key(aggregate) or nodegroup[aggregate] is None:
+                    nodegroup[aggregate] = []
+                else:
+                    nodegroup[aggregate] = map(int, nodegroup[aggregate].split(','))