from PLC.Faults import * from PLC.Parameter import Parameter class Row(dict): """ Representation of a row in a database table. To use, optionally instantiate with a dict of values. Update as you would a dict. Commit to the database with sync(). """ # Set this to the name of the table that stores the row. table_name = None # Set this to the name of the primary key of the table. It is # assumed that the this key is a sequence if it is not set when # sync() is called. primary_key = None # Set this to a dict of the valid fields of this object. Not all # fields (e.g., joined fields) may be updated via sync(). fields = {} def validate(self): """ Validates values. Will validate a value with a custom function if a function named 'validate_[key]' exists. """ # Warn about mandatory fields mandatory_fields = self.api.db.fields(self.table_name, notnull = True, hasdef = False) for field in mandatory_fields: if not self.has_key(field) or self[field] is None: raise PLCInvalidArgument, field + " must be specified and cannot be unset" # Validate values before committing for key, value in self.iteritems(): if value is not None and hasattr(self, 'validate_' + key): validate = getattr(self, 'validate_' + key) self[key] = validate(value) def sync(self, commit = True): """ Flush changes back to the database. """ # Validate all specified fields self.validate() # Filter out fields that cannot be set or updated directly all_fields = self.api.db.fields(self.table_name) fields = dict(filter(lambda (key, value): \ key in all_fields and \ (key not in self.fields or \ not isinstance(self.fields[key], Parameter) or \ not self.fields[key].ro), self.items())) # Parameterize for safety keys = fields.keys() values = [self.api.db.param(key, value) for (key, value) in fields.items()] if not self.has_key(self.primary_key): # Insert new row sql = "INSERT INTO %s (%s) VALUES (%s);" % \ (self.table_name, ", ".join(keys), ", ".join(values)) else: # Update existing row columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)] sql = "UPDATE %s SET " % self.table_name + \ ", ".join(columns) + \ " WHERE %s = %s" % \ (self.primary_key, self.api.db.param(self.primary_key, self[self.primary_key])) self.api.db.do(sql, fields) if not self.has_key(self.primary_key): self[self.primary_key] = self.api.db.last_insert_id(self.table_name, self.primary_key) if commit: self.api.db.commit() def delete(self, commit = True): """ Delete row from its primary table. """ assert self.primary_key in self sql = "DELETE FROM %s" % self.table_name + \ " WHERE %s = %s" % \ (self.primary_key, self.api.db.param(self.primary_key, self[self.primary_key])) self.api.db.do(sql, self) if commit: self.api.db.commit() class Table(dict): """ Representation of row(s) in a database table. """ def sync(self, commit = True): """ Flush changes back to the database. """ for row in self.values(): row.sync(commit)