X-Git-Url: http://git.onelab.eu/?a=blobdiff_plain;f=PLC%2FTable.py;h=1300cf146240eec3136a7de66f18c847230696ed;hb=234a5cd8a1a6526cda2647d17d29e378ff1a052e;hp=ddc5f88a10b31e3fa24e228c06a93b304acab6fa;hpb=13c265f9cbbb598afbae59748d0cbec896d729f1;p=plcapi.git diff --git a/PLC/Table.py b/PLC/Table.py index ddc5f88..1300cf1 100644 --- a/PLC/Table.py +++ b/PLC/Table.py @@ -1,3 +1,6 @@ +from PLC.Faults import * +from PLC.Parameter import Parameter + class Row(dict): """ Representation of a row in a database table. To use, optionally @@ -5,6 +8,14 @@ class Row(dict): dict. Commit to the database with sync(). """ + # Set this to the name of the table that stores the row. + table_name = None + + # Set this to the name of the primary key of the table. It is + # assumed that the this key is a sequence if it is not set when + # sync() is called. + primary_key = None + # Set this to a dict of the valid fields of this object. Not all # fields (e.g., joined fields) may be updated via sync(). fields = {} @@ -15,8 +26,13 @@ class Row(dict): if a function named 'validate_[key]' exists. """ + # Warn about mandatory fields + mandatory_fields = self.api.db.fields(self.table_name, notnull = True, hasdef = False) + for field in mandatory_fields: + if not self.has_key(field) or self[field] is None: + raise PLCInvalidArgument, field + " must be specified and cannot be unset" + # Validate values before committing - # XXX Also truncate strings that are too long for key, value in self.iteritems(): if value is not None and hasattr(self, 'validate_' + key): validate = getattr(self, 'validate_' + key) @@ -27,7 +43,59 @@ class Row(dict): Flush changes back to the database. """ - pass + # Validate all specified fields + self.validate() + + # Filter out fields that cannot be set or updated directly + all_fields = self.api.db.fields(self.table_name) + fields = dict(filter(lambda (key, value): \ + key in all_fields and \ + (key not in self.fields or \ + not isinstance(self.fields[key], Parameter) or \ + not self.fields[key].ro), + self.items())) + + # Parameterize for safety + keys = fields.keys() + values = [self.api.db.param(key, value) for (key, value) in fields.items()] + + if not self.has_key(self.primary_key): + # Insert new row + sql = "INSERT INTO %s (%s) VALUES (%s);" % \ + (self.table_name, ", ".join(keys), ", ".join(values)) + else: + # Update existing row + columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)] + sql = "UPDATE %s SET " % self.table_name + \ + ", ".join(columns) + \ + " WHERE %s = %s" % \ + (self.primary_key, + self.api.db.param(self.primary_key, self[self.primary_key])) + + self.api.db.do(sql, fields) + + if not self.has_key(self.primary_key): + self[self.primary_key] = self.api.db.last_insert_id(self.table_name, self.primary_key) + + if commit: + self.api.db.commit() + + def delete(self, commit = True): + """ + Delete row from its primary table. + """ + + assert self.primary_key in self + + sql = "DELETE FROM %s" % self.table_name + \ + " WHERE %s = %s" % \ + (self.primary_key, + self.api.db.param(self.primary_key, self[self.primary_key])) + + self.api.db.do(sql, self) + + if commit: + self.api.db.commit() class Table(dict): """