from PLC.Faults import *
from PLC.Parameter import Parameter


class Row(dict):
    """
    Representation of a row in a database table. To use, optionally
    instantiate with a dict of values. Update as you would a
    dict. Commit to the database with sync().
    """

    # Set this to the name of the table that stores the row.
    table_name = None

    # Set this to the name of the primary key of the table. It is
    # assumed that this key is a sequence if it is not set when
    # sync() is called.
    primary_key = None

    # Set this to the names of tables that reference this table's
    # primary key. An entry may also be a (table, column) tuple when the
    # referencing column is named differently from the primary key
    # (see delete()).
    join_tables = []

    # Set this to a dict of the valid fields of this object and their
    # types. Not all fields (e.g., joined fields) may be updated via
    # sync().
    fields = {}

    def __init__(self, api, fields = None):
        """
        api: object whose .db attribute is used for all database access.
        fields: optional mapping of initial values; its items are copied
        into this row (the mapping itself is not retained).
        """

        # Default to None rather than a shared {} literal; dict.__init__
        # copies the items either way.
        if fields is None:
            fields = {}
        dict.__init__(self, fields)
        self.api = api

    def validate(self):
        """
        Validates values. Will validate a value with a custom function
        if a function named 'validate_[key]' exists; the function's
        return value replaces the stored value.

        Raises PLCInvalidArgument if a column reported by
        api.db.fields(notnull = True, hasdef = False) is missing or None.
        """

        # Refuse to continue if a mandatory column is absent or unset
        mandatory_fields = self.api.db.fields(self.table_name, notnull = True, hasdef = False)
        for field in mandatory_fields:
            if field not in self or self[field] is None:
                raise PLCInvalidArgument(field + " must be specified and cannot be unset")

        # Validate values before committing. Iterate over a snapshot of
        # the keys so that storing validated values (or a validator
        # touching other keys) never mutates the dict being iterated;
        # the current value is read at each step.
        for key in list(self.keys()):
            value = self.get(key)
            if value is not None and hasattr(self, 'validate_' + key):
                validate = getattr(self, 'validate_' + key)
                self[key] = validate(value)

    def sync(self, commit = True, insert = None):
        """
        Flush changes back to the database.

        Validates the row, then INSERTs it when the primary key is not
        set, when the primary key is the table's only column, or when
        insert is True; otherwise UPDATEs the row whose primary key
        matches. If the primary key was not set, its new value is read
        back with api.db.last_insert_id(). Commits if commit is true.
        """

        # Validate all specified fields
        self.validate()

        # Filter out fields that cannot be set or updated directly:
        # keep only real columns of the table, and skip fields that are
        # declared as read-only Parameters.
        all_fields = self.api.db.fields(self.table_name)
        fields = {}
        for key, value in self.items():
            if key not in all_fields:
                continue
            if key in self.fields and \
               isinstance(self.fields[key], Parameter) and \
               self.fields[key].ro:
                continue
            fields[key] = value

        # Parameterize for safety. Derive keys and placeholders from one
        # key list so they are guaranteed to line up.
        keys = list(fields.keys())
        values = [self.api.db.param(key, fields[key]) for key in keys]
        params = fields
        sql = None

        # If the primary key (usually an auto-incrementing serial
        # identifier) has not been specified, or the primary key is the
        # only field in the table, or insert has been forced.
        if self.primary_key not in self or \
           all_fields == [self.primary_key] or \
           insert is True:
            # Insert new row
            sql = "INSERT INTO %s (%s) VALUES (%s)" % \
                  (self.table_name, ", ".join(keys), ", ".join(values))
        else:
            # Update existing row
            columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)]
            if columns:
                sql = "UPDATE %s SET " % self.table_name + \
                      ", ".join(columns) + \
                      " WHERE %s = %s" % \
                      (self.primary_key,
                       self.api.db.param(self.primary_key, self[self.primary_key]))
                # The WHERE placeholder is built from the primary key's
                # name, so make sure the parameter dict carries it even if
                # the filter above dropped it (e.g. a read-only key).
                # NOTE(review): assumes db.param() placeholders are resolved
                # by name from the params dict, as delete() also relies on.
                params = dict(fields)
                params[self.primary_key] = self[self.primary_key]
            # else: nothing updatable; an empty SET clause would be invalid
            # SQL, so there is no statement to run.

        if sql is not None:
            self.api.db.do(sql, params)

        if self.primary_key not in self:
            self[self.primary_key] = self.api.db.last_insert_id(self.table_name, self.primary_key)

        if commit:
            self.api.db.commit()

    def delete(self, commit = True):
        """
        Delete row from its primary table, and from any tables that
        reference it (join_tables). Commits if commit is true.
        """

        assert self.primary_key in self

        # list() so that subclasses may declare join_tables as a tuple
        for table in list(self.join_tables) + [self.table_name]:
            if isinstance(table, tuple):
                # (table, column): the referencing column has its own name
                table, key = table[0], table[1]
            else:
                key = self.primary_key

            sql = "DELETE FROM %s WHERE %s = %s" % \
                  (table, key,
                   self.api.db.param(self.primary_key, self[self.primary_key]))

            # The row itself (a dict) supplies the parameter values
            self.api.db.do(sql, self)

        if commit:
            self.api.db.commit()


class Table(list):
    """
    Representation of row(s) in a database table.
    """

    def __init__(self, api, row, columns = None):
        """
        api: object whose .db attribute is used for database access.
        row: Row subclass used to wrap each fetched record.
        columns: optional names of the fields to return; names not in
        row.fields are dropped. Defaults to row.fields.

        Raises PLCInvalidArgument if columns is given but none of its
        names are valid fields.
        """

        self.api = api
        self.row = row

        if columns is None:
            columns = row.fields
        else:
            # A real list (not a lazy filter object) so that the emptiness
            # check below works on every Python version.
            columns = [x for x in columns if x in row.fields]
            if not columns:
                raise PLCInvalidArgument("No valid return fields specified")

        self.columns = columns

    def sync(self, commit = True):
        """
        Flush changes back to the database by calling sync() on every
        row, passing commit through.
        """

        for row in self:
            row.sync(commit)

    def selectall(self, sql, params = None):
        """
        Run sql with params through api.db.selectall() and append each
        returned record, wrapped as self.row(api, record), to this table.
        """

        for row in self.api.db.selectall(sql, params):
            self.append(self.row(self.api, row))