From: Mark Huang Date: Mon, 8 Jan 2007 18:14:55 +0000 (+0000) Subject: - add class methods add_object() and remove_object() which can be used X-Git-Tag: pycurl-7_13_1~158 X-Git-Url: http://git.onelab.eu/?a=commitdiff_plain;h=f5c08abcdf23aaf3d3c133b54758fb99913278b4;p=plcapi.git - add class methods add_object() and remove_object() which can be used as generic join/unjoin functions - add db_fields() function to return only those fields that can be set or updated directly (i.e., intrinsic fields) - add __eq__() operator to compare two objects (just intrinsic fields) --- diff --git a/PLC/Table.py b/PLC/Table.py index 9041c260..12290073 100644 --- a/PLC/Table.py +++ b/PLC/Table.py @@ -1,3 +1,4 @@ +from types import StringTypes import time import calendar @@ -53,7 +54,7 @@ class Row(dict): time_format = "%Y-%m-%d %H:%M:%S" def validate_timestamp (self, timestamp, check_future=False): # in case we try to sync the same object twice - if isinstance(timestamp,str): + if isinstance(timestamp, StringTypes): # calendar.timegm is the inverse of time.gmtime, in that it computes in UTC # surprisingly enough, no other method in the time module behaves this way # this method is documented in the time module's documentation @@ -63,6 +64,110 @@ class Row(dict): raise PLCInvalidArgument, "%s: date must be in the future"%human return human + @classmethod + def add_object(self, classobj, join_table, columns = None): + """ + Returns a function that can be used to associate this object + with another. + """ + + def add(self, obj, columns = None, commit = True): + """ + Associate with the specified object. + """ + + # Various sanity checks + assert isinstance(self, Row) + assert self.primary_key in self + assert join_table in self.join_tables + assert isinstance(obj, classobj) + assert isinstance(obj, Row) + assert obj.primary_key in obj + assert join_table in obj.join_tables + + # By default, just insert the primary keys of each object + # into the join table. + if columns is None: + columns = {self.primary_key: self[self.primary_key], + obj.primary_key: obj[obj.primary_key]} + + params = [] + for name, value in columns.iteritems(): + params.append(self.api.db.param(name, value)) + + self.api.db.do("INSERT INTO %s (%s) VALUES(%s)" % \ + (join_table, ", ".join(columns), ", ".join(params)), + columns) + + if commit: + self.api.db.commit() + + return add + + @classmethod + def remove_object(self, classobj, join_table): + """ + Returns a function that can be used to disassociate this + object with another. + """ + + def remove(self, obj, commit = True): + """ + Disassociate from the specified object. + """ + + assert isinstance(self, Row) + assert self.primary_key in self + assert join_table in self.join_tables + assert isinstance(obj, classobj) + assert isinstance(obj, Row) + assert obj.primary_key in obj + assert join_table in obj.join_tables + + self_id = self[self.primary_key] + obj_id = obj[obj.primary_key] + + self.api.db.do("DELETE FROM %s WHERE %s = %s AND %s = %s" % \ + (join_table, + self.primary_key, self.api.db.param('self_id', self_id), + obj.primary_key, self.api.db.param('obj_id', obj_id)), + locals()) + + if commit: + self.api.db.commit() + + return remove + + def db_fields(self, obj = None): + """ + Return only those fields that can be set or updated directly + (i.e., those fields that are in the primary table (table_name) + for this object, and are not marked as a read-only Parameter. + """ + + if obj is None: + obj = self + + db_fields = self.api.db.fields(self.table_name) + return dict(filter(lambda (key, value): \ + key in db_fields and \ + (key not in self.fields or \ + not isinstance(self.fields[key], Parameter) or \ + not self.fields[key].ro), + obj.items())) + + def __eq__(self, y): + """ + Compare two objects. + """ + + # Filter out fields that cannot be set or updated directly + # (and thus would not affect equality for the purposes of + # deciding if we should sync() or not). + x = self.db_fields() + y = self.db_fields(y) + return dict.__eq__(x, y) + def sync(self, commit = True, insert = None): """ Flush changes back to the database. @@ -72,23 +177,17 @@ class Row(dict): self.validate() # Filter out fields that cannot be set or updated directly - all_fields = self.api.db.fields(self.table_name) - fields = dict(filter(lambda (key, value): \ - key in all_fields and \ - (key not in self.fields or \ - not isinstance(self.fields[key], Parameter) or \ - not self.fields[key].ro), - self.items())) + db_fields = self.db_fields() # Parameterize for safety - keys = fields.keys() - values = [self.api.db.param(key, value) for (key, value) in fields.items()] + keys = db_fields.keys() + values = [self.api.db.param(key, value) for (key, value) in db_fields.items()] # If the primary key (usually an auto-incrementing serial # identifier) has not been specified, or the primary key is the # only field in the table, or insert has been forced. if not self.has_key(self.primary_key) or \ - all_fields == [self.primary_key] or \ + keys == [self.primary_key] or \ insert is True: # Insert new row sql = "INSERT INTO %s (%s) VALUES (%s)" % \ @@ -102,7 +201,7 @@ class Row(dict): (self.primary_key, self.api.db.param(self.primary_key, self[self.primary_key])) - self.api.db.do(sql, fields) + self.api.db.do(sql, db_fields) if not self.has_key(self.primary_key): self[self.primary_key] = self.api.db.last_insert_id(self.table_name, self.primary_key)