- move common sync() functionality to Table.Row
authorMark Huang <mlhuang@cs.princeton.edu>
Tue, 3 Oct 2006 19:27:16 +0000 (19:27 +0000)
committerMark Huang <mlhuang@cs.princeton.edu>
Tue, 3 Oct 2006 19:27:16 +0000 (19:27 +0000)
PLC/Addresses.py
PLC/NodeGroups.py
PLC/NodeNetworks.py
PLC/Nodes.py
PLC/Persons.py
PLC/Sites.py

index 22860c4..776cb7d 100644 (file)
@@ -8,6 +8,8 @@ class Address(Row):
     with a dict of values.
     """
 
+    table_name = 'addresses'
+    primary_key = 'address_id'
     fields = {
         'address_id': Parameter(int, "Address identifier"),
         'line1': Parameter(str, "Address line 1"),
@@ -20,13 +22,9 @@ class Address(Row):
         'address_type': Parameter(str, "Address type"),
         }
 
-    def __init__(self, api, fields):
+    def __init__(self, api, fields = {}):
         self.api = api
-        Row.__init__(fields)
-
-    def sync(self, commit = True):
-        # XXX
-        pass
+        Row.__init__(self, fields)
 
     def delete(self, commit = True):
         # XXX
index 571e6fa..7e20350 100644 (file)
@@ -4,7 +4,7 @@
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
-# $Id: NodeGroups.py,v 1.9 2006/09/20 14:41:59 mlhuang Exp $
+# $Id: NodeGroups.py,v 1.10 2006/09/25 14:52:01 mlhuang Exp $
 #
 
 from types import StringTypes
@@ -22,6 +22,8 @@ class NodeGroup(Row):
     dict. Commit to the database with sync().
     """
 
+    table_name = 'nodegroups'
+    primary_key = 'nodegroup_id'
     fields = {
         'nodegroup_id': Parameter(int, "Node group identifier"),
         'name': Parameter(str, "Node group name", max = 50),
@@ -29,7 +31,7 @@ class NodeGroup(Row):
         'node_ids': Parameter([int], "List of nodes in this node group"),
         }
 
-    def __init__(self, api, fields):
+    def __init__(self, api, fields = {}):
         Row.__init__(self, fields)
         self.api = api
 
@@ -98,49 +100,6 @@ class NodeGroup(Row):
         if 'nodegroup_ids' in node and nodegroup_id in node['nodegroup_ids']:
             node['nodegroup_ids'].remove(nodegroup_id)
 
-    def sync(self, commit = True):
-        """
-        Flush changes back to the database.
-        """
-
-        self.validate()
-
-        # Fetch a new nodegroup_id if necessary
-        if 'nodegroup_id' not in self:
-            rows = self.api.db.selectall("SELECT NEXTVAL('nodegroups_nodegroup_id_seq') AS nodegroup_id")
-            if not rows:
-                raise PLCDBError, "Unable to fetch new nodegroup_id"
-            self['nodegroup_id'] = rows[0]['nodegroup_id']
-            insert = True
-        else:
-            insert = False
-
-        # Filter out fields that cannot be set or updated directly
-        nodegroups_fields = self.api.db.fields('nodegroups')
-        fields = dict(filter(lambda (key, value): key in nodegroups_fields,
-                             self.items()))
-
-        # Parameterize for safety
-        keys = fields.keys()
-        values = [self.api.db.param(key, value) for (key, value) in fields.items()]
-
-        if insert:
-            # Insert new row in nodegroups table
-            sql = "INSERT INTO nodegroups (%s) VALUES (%s)" % \
-                  (", ".join(keys), ", ".join(values))
-        else:
-            # Update existing row in nodegroups table
-            columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)]
-            sql = "UPDATE nodegroups SET " + \
-                  ", ".join(columns) + \
-                  " WHERE nodegroup_id = %(nodegroup_id)d"
-
-        self.api.db.do(sql, fields)
-
-        if commit:
-            self.api.db.commit()
-           
-
     def delete(self, commit = True):
         """
         Delete existing nodegroup from the database.
@@ -166,7 +125,8 @@ class NodeGroups(Table):
     def __init__(self, api, nodegroup_id_or_name_list = None):
        self.api = api
 
-        sql = "SELECT * FROM view_nodegroups"
+        sql = "SELECT %s FROM view_nodegroups" % \
+              ", ".join(NodeGroup.fields)
 
         if nodegroup_id_or_name_list:
             # Separate the list into integers and strings
index 7f83448..4c36e0d 100644 (file)
@@ -4,7 +4,7 @@
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
-# $Id: NodeNetworks.py,v 1.3 2006/09/19 19:35:05 mlhuang Exp $
+# $Id: NodeNetworks.py,v 1.4 2006/09/25 14:55:43 mlhuang Exp $
 #
 
 from types import StringTypes
@@ -36,6 +36,8 @@ class NodeNetwork(Row):
     dict. Commit to the database with sync().
     """
 
+    table_name = 'nodenetworks'
+    primary_key = 'nodenetwork_id'
     fields = {
         'nodenetwork_id': Parameter(int, "Node interface identifier"),
         'method': Parameter(str, "Addressing method (e.g., 'static' or 'dhcp')"),
@@ -65,7 +67,7 @@ class NodeNetwork(Row):
                 '10mbit', '20mbit', '50mbit',
                 '100mbit']
 
-    def __init__(self, api, fields):
+    def __init__(self, api, fields = {}):
         Row.__init__(self, fields)
         self.api = api
 
@@ -124,19 +126,16 @@ class NodeNetwork(Row):
         # Validate hostname, and check for conflicts with a node hostname
         return PLC.Nodes.Node.validate_hostname(self, hostname)
 
-    def sync(self, commit = True):
+    def validate(self):
         """
         Flush changes back to the database.
         """
 
-        # Validate all specified fields
-        self.validate()
+        # Basic validation
+        Row.validate(self)
 
-        try:
-            method = self['method']
-            self['type']
-        except KeyError:
-            raise PLCInvalidArgument, "method and type must both be specified"
+        assert 'method' in self
+        method = self['method']
 
         if method == "proxy" or method == "tap":
             if 'mac' in self and self['mac']:
@@ -169,41 +168,6 @@ class NodeNetwork(Row):
             if 'ip' not in self or not self['ip']:
                 raise PLCInvalidArgument, "For ipmi method, ip is required"
 
-        # Fetch a new nodenetwork_id if necessary
-        if 'nodenetwork_id' not in self:
-            rows = self.api.db.selectall("SELECT NEXTVAL('nodenetworks_nodenetwork_id_seq') AS nodenetwork_id")
-            if not rows:
-                raise PLCDBError("Unable to fetch new nodenetwork_id")
-            self['nodenetwork_id'] = rows[0]['nodenetwork_id']
-            insert = True
-        else:
-            insert = False
-
-        # Filter out fields that cannot be set or updated directly
-        nodenetworks_fields = self.api.db.fields('nodenetworks')
-        fields = dict(filter(lambda (key, value): key in nodenetworks_fields,
-                             self.items()))
-
-        # Parameterize for safety
-        keys = fields.keys()
-        values = [self.api.db.param(key, value) for (key, value) in fields.items()]
-
-        if insert:
-            # Insert new row in nodenetworks table
-            sql = "INSERT INTO nodenetworks (%s) VALUES (%s)" % \
-                  (", ".join(keys), ", ".join(values))
-        else:
-            # Update existing row in sites table
-            columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)]
-            sql = "UPDATE nodenetworks SET " + \
-                  ", ".join(columns) + \
-                  " WHERE nodenetwork_id = %(nodenetwork_id)d"
-
-        self.api.db.do(sql, fields)
-
-        if commit:
-            self.api.db.commit()
-
     def delete(self, commit = True):
         """
         Delete existing nodenetwork.
@@ -228,8 +192,8 @@ class NodeNetworks(Table):
     def __init__(self, api, nodenetwork_id_or_hostname_list = None):
         self.api = api
 
-        # N.B.: Node IDs returned may be deleted.
-        sql = "SELECT * FROM nodenetworks"
+        sql = "SELECT %s FROM nodenetworks" % \
+              ", ".join(NodeNetwork.fields)
 
         if nodenetwork_id_or_hostname_list:
             # Separate the list into integers and strings
index 7930ab1..325e872 100644 (file)
@@ -4,7 +4,7 @@
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
-# $Id: Nodes.py,v 1.7 2006/10/02 16:04:42 mlhuang Exp $
+# $Id: Nodes.py,v 1.8 2006/10/02 18:32:31 mlhuang Exp $
 #
 
 from types import StringTypes
@@ -24,6 +24,8 @@ class Node(Row):
     dict. Commit to the database with sync().
     """
 
+    table_name = 'nodes'
+    primary_key = 'node_id'
     fields = {
         'node_id': Parameter(int, "Node identifier"),
         'hostname': Parameter(str, "Fully qualified hostname", max = 255),
@@ -80,50 +82,6 @@ class Node(Row):
 
         return boot_state
 
-    def sync(self, commit = True):
-        """
-        Flush changes back to the database.
-        """
-
-        self.validate()
-
-        # Fetch a new node_id if necessary
-        if 'node_id' not in self:
-            rows = self.api.db.selectall("SELECT NEXTVAL('nodes_node_id_seq') AS node_id")
-            if not rows:
-                raise PLCDBError, "Unable to fetch new node_id"
-            self['node_id'] = rows[0]['node_id']
-            insert = True
-        else:
-            insert = False
-
-        # Filter out fields that cannot be set or updated directly
-        nodes_fields = self.api.db.fields('nodes')
-        fields = dict(filter(lambda (key, value): \
-                             key in nodes_fields and \
-                             (key not in self.fields or not self.fields[key].ro),
-                             self.items()))
-
-        # Parameterize for safety
-        keys = fields.keys()
-        values = [self.api.db.param(key, value) for (key, value) in fields.items()]
-
-        if insert:
-            # Insert new row in nodes table
-            sql = "INSERT INTO nodes (%s) VALUES (%s)" % \
-                  (", ".join(keys), ", ".join(values))
-        else:
-            # Update existing row in nodes table
-            columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)]
-            sql = "UPDATE nodes SET " + \
-                  ", ".join(columns) + \
-                  " WHERE node_id = %(node_id)d"
-
-        self.api.db.do(sql, fields)
-
-        if commit:
-            self.api.db.commit()
-
     def delete(self, commit = True):
         """
         Delete existing node.
@@ -152,11 +110,11 @@ class Nodes(Table):
     database.
     """
 
-    def __init__(self, api, node_id_or_hostname_list = None, fields = Node.fields.keys()):
+    def __init__(self, api, node_id_or_hostname_list = None):
         self.api = api
 
         sql = "SELECT %s FROM view_nodes WHERE deleted IS False" % \
-              ", ".join(fields)
+              ", ".join(Node.fields)
 
         if node_id_or_hostname_list:
             # Separate the list into integers and strings
index ac97192..a3c38b4 100644 (file)
@@ -4,7 +4,7 @@
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
-# $Id: Persons.py,v 1.6 2006/10/02 16:04:22 mlhuang Exp $
+# $Id: Persons.py,v 1.7 2006/10/02 18:32:31 mlhuang Exp $
 #
 
 from types import StringTypes
@@ -31,6 +31,8 @@ class Person(Row):
     dict. Commit to the database with sync().
     """
 
+    table_name = 'persons'
+    primary_key = 'person_id'
     fields = {
         'person_id': Parameter(int, "Account identifier"),
         'first_name': Parameter(str, "Given name", max = 128),
@@ -51,7 +53,7 @@ class Person(Row):
         'slice_ids': Parameter([int], "List of slice identifiers", ro = True),
         }
 
-    def __init__(self, api, fields):
+    def __init__(self, api, fields = {}):
         Row.__init__(self, fields)
         self.api = api
 
@@ -223,50 +225,6 @@ class Person(Row):
         self['site_ids'].remove(site_id)
         self['site_ids'].insert(0, site_id)
 
-    def sync(self, commit = True):
-        """
-        Commit changes back to the database.
-        """
-
-        self.validate()
-
-        # Fetch a new person_id if necessary
-        if 'person_id' not in self:
-            rows = self.api.db.selectall("SELECT NEXTVAL('persons_person_id_seq') AS person_id")
-            if not rows:
-                raise PLCDBError, "Unable to fetch new person_id"
-            self['person_id'] = rows[0]['person_id']
-            insert = True
-        else:
-            insert = False
-
-        # Filter out fields that cannot be set or updated directly
-        persons_fields = self.api.db.fields('persons')
-        fields = dict(filter(lambda (key, value): \
-                             key in persons_fields and \
-                             (key not in self.fields or not self.fields[key].ro),
-                             self.items()))
-
-        # Parameterize for safety
-        keys = fields.keys()
-        values = [self.api.db.param(key, value) for (key, value) in fields.items()]
-
-        if insert:
-            # Insert new row in persons table
-            sql = "INSERT INTO persons (%s) VALUES (%s)" % \
-                  (", ".join(keys), ", ".join(values))
-        else:
-            # Update existing row in persons table
-            columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)]
-            sql = "UPDATE persons SET " + \
-                  ", ".join(columns) + \
-                  " WHERE person_id = %(person_id)d"
-
-        self.api.db.do(sql, fields)
-
-        if commit:
-            self.api.db.commit()
-
     def delete(self, commit = True):
         """
         Delete existing account.
@@ -295,11 +253,11 @@ class Persons(Table):
     non-deleted accounts.
     """
 
-    def __init__(self, api, person_id_or_email_list = None, fields = Person.fields, enabled = None):
+    def __init__(self, api, person_id_or_email_list = None, enabled = None):
         self.api = api
 
         sql = "SELECT %s FROM view_persons WHERE deleted IS False" % \
-              ", ".join(fields)
+              ", ".join(Person.fields)
 
         if enabled is not None:
             sql += " AND enabled IS %(enabled)s"
index 2fbcd1b..cb8d05e 100644 (file)
@@ -19,6 +19,8 @@ class Site(Row):
     dict. Commit to the database with sync().
     """
 
+    table_name = 'sites'
+    primary_key = 'site_id'
     fields = {
         'site_id': Parameter(int, "Site identifier"),
         'name': Parameter(str, "Full site name", max = 254),
@@ -39,15 +41,28 @@ class Site(Row):
         'node_ids': Parameter([int], "List of site node identifiers", ro = True),
         }
 
-    def __init__(self, api, fields):
+    def __init__(self, api, fields = {}):
         Row.__init__(self, fields)
         self.api = api
 
+    def validate_name(self, name):
+        name = name.strip()
+        if not name:
+            raise PLCInvalidArgument, "Name must be specified"
+
+        return name
+
+    validate_abbreviated_name = validate_name
+
     def validate_login_base(self, login_base):
+        login_base = login_base.strip().lower()
+
+        if not login_base:
+            raise PLCInvalidArgument, "Login base must be specified"
+
         if not set(login_base).issubset(string.ascii_letters):
             raise PLCInvalidArgument, "Login base must consist only of ASCII letters"
 
-        login_base = login_base.lower()
         conflicts = Sites(self.api, [login_base])
         for site_id, site in conflicts.iteritems():
             if 'site_id' not in self or self['site_id'] != site_id:
@@ -118,58 +133,6 @@ class Site(Row):
         if 'site_ids' in person and site_id in person['site_ids']:
             person['site_ids'].remove(site_id)
 
-    def sync(self, commit = True):
-        """
-        Flush changes back to the database.
-        """
-
-        self.validate()
-
-        try:
-            if not self['name'] or \
-               not self['abbreviated_name'] or \
-               not self['login_base']:
-                raise KeyError
-        except KeyError:
-            raise PLCInvalidArgument, "name, abbreviated_name, and login_base must all be specified"
-
-        # Fetch a new site_id if necessary
-        if 'site_id' not in self:
-            rows = self.api.db.selectall("SELECT NEXTVAL('sites_site_id_seq') AS site_id")
-            if not rows:
-                raise PLCDBError, "Unable to fetch new site_id"
-            self['site_id'] = rows[0]['site_id']
-            insert = True
-        else:
-            insert = False
-
-        # Filter out fields that cannot be set or updated directly
-        sites_fields = self.api.db.fields('sites')
-        fields = dict(filter(lambda (key, value): \
-                             key in sites_fields and \
-                             (key not in self.fields or not self.fields[key].ro),
-                             self.items()))
-
-        # Parameterize for safety
-        keys = fields.keys()
-        values = [self.api.db.param(key, value) for (key, value) in fields.items()]
-
-        if insert:
-            # Insert new row in sites table
-            sql = "INSERT INTO sites (%s) VALUES (%s)" % \
-                  (", ".join(keys), ", ".join(values))
-        else:
-            # Update existing row in sites table
-            columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)]
-            sql = "UPDATE sites SET " + \
-                  ", ".join(columns) + \
-                  " WHERE site_id = %(site_id)d"
-
-        self.api.db.do(sql, fields)
-
-        if commit:
-            self.api.db.commit()
-
     def delete(self, commit = True):
         """
         Delete existing site.
@@ -229,11 +192,11 @@ class Sites(Table):
     fields.
     """
 
-    def __init__(self, api, site_id_or_login_base_list = None, fields = Site.fields):
+    def __init__(self, api, site_id_or_login_base_list = None):
         self.api = api
 
         sql = "SELECT %s FROM view_sites WHERE deleted IS False" % \
-              ", ".join(fields)
+              ", ".join(Site.fields)
 
         if site_id_or_login_base_list:
             # Separate the list into integers and strings