- move common sync() functionality to Table.Row
[plcapi.git] / PLC / NodeGroups.py
index 571e6fa..7e20350 100644 (file)
@@ -4,7 +4,7 @@
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
-# $Id: NodeGroups.py,v 1.9 2006/09/20 14:41:59 mlhuang Exp $
+# $Id: NodeGroups.py,v 1.10 2006/09/25 14:52:01 mlhuang Exp $
 #
 
 from types import StringTypes
@@ -22,6 +22,8 @@ class NodeGroup(Row):
     dict. Commit to the database with sync().
     """
 
+    table_name = 'nodegroups'
+    primary_key = 'nodegroup_id'
     fields = {
         'nodegroup_id': Parameter(int, "Node group identifier"),
         'name': Parameter(str, "Node group name", max = 50),
@@ -29,7 +31,7 @@ class NodeGroup(Row):
         'node_ids': Parameter([int], "List of nodes in this node group"),
         }
 
-    def __init__(self, api, fields):
+    def __init__(self, api, fields = {}):
         Row.__init__(self, fields)
         self.api = api
 
@@ -98,49 +100,6 @@ class NodeGroup(Row):
         if 'nodegroup_ids' in node and nodegroup_id in node['nodegroup_ids']:
             node['nodegroup_ids'].remove(nodegroup_id)
 
-    def sync(self, commit = True):
-        """
-        Flush changes back to the database.
-        """
-
-        self.validate()
-
-        # Fetch a new nodegroup_id if necessary
-        if 'nodegroup_id' not in self:
-            rows = self.api.db.selectall("SELECT NEXTVAL('nodegroups_nodegroup_id_seq') AS nodegroup_id")
-            if not rows:
-                raise PLCDBError, "Unable to fetch new nodegroup_id"
-            self['nodegroup_id'] = rows[0]['nodegroup_id']
-            insert = True
-        else:
-            insert = False
-
-        # Filter out fields that cannot be set or updated directly
-        nodegroups_fields = self.api.db.fields('nodegroups')
-        fields = dict(filter(lambda (key, value): key in nodegroups_fields,
-                             self.items()))
-
-        # Parameterize for safety
-        keys = fields.keys()
-        values = [self.api.db.param(key, value) for (key, value) in fields.items()]
-
-        if insert:
-            # Insert new row in nodegroups table
-            sql = "INSERT INTO nodegroups (%s) VALUES (%s)" % \
-                  (", ".join(keys), ", ".join(values))
-        else:
-            # Update existing row in nodegroups table
-            columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)]
-            sql = "UPDATE nodegroups SET " + \
-                  ", ".join(columns) + \
-                  " WHERE nodegroup_id = %(nodegroup_id)d"
-
-        self.api.db.do(sql, fields)
-
-        if commit:
-            self.api.db.commit()
-           
-
     def delete(self, commit = True):
         """
         Delete existing nodegroup from the database.
@@ -166,7 +125,8 @@ class NodeGroups(Table):
     def __init__(self, api, nodegroup_id_or_name_list = None):
        self.api = api
 
-        sql = "SELECT * FROM view_nodegroups"
+        sql = "SELECT %s FROM view_nodegroups" % \
+              ", ".join(NodeGroup.fields)
 
         if nodegroup_id_or_name_list:
             # Separate the list into integers and strings