+from PLC.Faults import *
+from PLC.Parameter import Parameter
+
class Row(dict):
"""
Representation of a row in a database table. To use, optionally
instantiate with a dict of values. Update as you would a
- dict. Commit to the database with flush().
+ dict. Commit to the database with sync().
"""
- # Set this to a dict of the valid columns in this table. If a column
- # name ends in 's' and the column value is set by referring to the
- # column without the 's', it is assumed that the column values
- # should be aggregated into lists. For example, if fields contains
- # the column 'role_ids' and row['role_id'] is set repeatedly to
- # different values, row['role_ids'] will contain a list of the set
- # values.
- fields = {}
-
- # These fields are derived from join tables and are not actually
- # in the sites table.
- join_fields = {}
+ # Set this to the name of the table that stores the row.
+ table_name = None
- # These fields are derived from join tables and are not returned
- # by default unless specified.
- extra_fields = {}
+ # Set this to the name of the primary key of the table. It is
+ # assumed that the this key is a sequence if it is not set when
+ # sync() is called.
+ primary_key = None
- def __init__(self, fields):
- self.update(fields)
-
- def update(self, fields):
- for key, value in fields.iteritems():
- self.__setitem__(key, value)
+ # Set this to the names of tables that reference this table's
+ # primary key.
+ join_tables = []
- def __setitem__(self, key, value):
- """
- Magically takes care of aggregating certain variables into
- lists.
- """
+ # Set this to a dict of the valid fields of this object and their
+ # types. Not all fields (e.g., joined fields) may be updated via
+ # sync().
+ fields = {}
- # All known keys
- all_fields = self.fields.keys() + \
- self.join_fields.keys() + \
- self.extra_fields.keys()
-
- # Aggregate into lists
- if (key + 's') in all_fields:
- key += 's'
- try:
- if value not in self[key] and value is not None:
- self[key].append(value)
- except KeyError:
- if value is None:
- self[key] = []
- else:
- self[key] = [value]
- return
-
- elif key in all_fields:
- dict.__setitem__(self, key, value)
+ def __init__(self, api, fields = {}):
+ dict.__init__(self, fields)
+ self.api = api
def validate(self):
"""
if a function named 'validate_[key]' exists.
"""
+ # Warn about mandatory fields
+ mandatory_fields = self.api.db.fields(self.table_name, notnull = True, hasdef = False)
+ for field in mandatory_fields:
+ if not self.has_key(field) or self[field] is None:
+ raise PLCInvalidArgument, field + " must be specified and cannot be unset"
+
# Validate values before committing
- # XXX Also truncate strings that are too long
for key, value in self.iteritems():
if value is not None and hasattr(self, 'validate_' + key):
validate = getattr(self, 'validate_' + key)
self[key] = validate(value)
- def flush(self, commit = True):
+ def sync(self, commit = True, insert = None):
"""
Flush changes back to the database.
"""
- pass
+ # Validate all specified fields
+ self.validate()
+
+ # Filter out fields that cannot be set or updated directly
+ all_fields = self.api.db.fields(self.table_name)
+ fields = dict(filter(lambda (key, value): \
+ key in all_fields and \
+ (key not in self.fields or \
+ not isinstance(self.fields[key], Parameter) or \
+ not self.fields[key].ro),
+ self.items()))
+
+ # Parameterize for safety
+ keys = fields.keys()
+ values = [self.api.db.param(key, value) for (key, value) in fields.items()]
+
+ # If the primary key (usually an auto-incrementing serial
+ # identifier) has not been specified, or the primary key is the
+ # only field in the table, or insert has been forced.
+ if not self.has_key(self.primary_key) or \
+ all_fields == [self.primary_key] or \
+ insert is True:
+ # Insert new row
+ sql = "INSERT INTO %s (%s) VALUES (%s)" % \
+ (self.table_name, ", ".join(keys), ", ".join(values))
+ else:
+ # Update existing row
+ columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)]
+ sql = "UPDATE %s SET " % self.table_name + \
+ ", ".join(columns) + \
+ " WHERE %s = %s" % \
+ (self.primary_key,
+ self.api.db.param(self.primary_key, self[self.primary_key]))
+
+ self.api.db.do(sql, fields)
+
+ if not self.has_key(self.primary_key):
+ self[self.primary_key] = self.api.db.last_insert_id(self.table_name, self.primary_key)
+
+ if commit:
+ self.api.db.commit()
+
+ def delete(self, commit = True):
+ """
+ Delete row from its primary table, and from any tables that
+ reference it.
+ """
+
+ assert self.primary_key in self
+
+ for table in self.join_tables + [self.table_name]:
+ if isinstance(table, tuple):
+ key = table[1]
+ table = table[0]
+ else:
+ key = self.primary_key
+
+ sql = "DELETE FROM %s WHERE %s = %s" % \
+ (table, key,
+ self.api.db.param(self.primary_key, self[self.primary_key]))
+
+ self.api.db.do(sql, self)
+
+ if commit:
+ self.api.db.commit()
class Table(dict):
"""
Representation of row(s) in a database table.
"""
- def flush(self, commit = True):
+ def __init__(self, api, row):
+ self.api = api
+ self.row = row
+
+ def sync(self, commit = True):
"""
Flush changes back to the database.
"""
for row in self.values():
- row.flush(commit)
+ row.sync(commit)
+
+ def selectall(self, sql, params = None):
+ """
+ Given a list of rows from the database, fill ourselves with
+ Row objects keyed on the primary key defined by the Row class
+ we were initialized with.
+ """
+
+ for row in self.api.db.selectall(sql, params):
+ self[row[self.row.primary_key]] = self.row(self.api, row)