X-Git-Url: http://git.onelab.eu/?a=blobdiff_plain;f=PLC%2FTable.py;h=74da3e9c8ac27a1b3bb8dc02e1167d40d4c4dfa6;hb=b74c5345bb5c74c64beeeacad2e9aaba3220455a;hp=4e9ddabb90adf5a466b0e6c07afd8a6468d4023a;hpb=1600712a2312d824ade8c8a688e1dee9a426a68e;p=plcapi.git diff --git a/PLC/Table.py b/PLC/Table.py index 4e9ddab..74da3e9 100644 --- a/PLC/Table.py +++ b/PLC/Table.py @@ -1,4 +1,4 @@ -from types import StringTypes +from types import StringTypes, IntType, LongType import time import calendar @@ -50,6 +50,35 @@ class Row(dict): if value is not None and hasattr(self, 'validate_' + key): validate = getattr(self, 'validate_' + key) self[key] = validate(value) + + def separate_types(self, items): + """ + Separate a list of different typed objects. + Return a list for each type (ints, strs and dicts) + """ + + if isinstance(items, (list, tuple, set)): + ints = filter(lambda x: isinstance(x, (int, long)), items) + strs = filter(lambda x: isinstance(x, StringTypes), items) + dicts = filter(lambda x: isinstance(x, dict), items) + return (ints, strs, dicts) + else: + raise PLCInvalidArgument, "Can only separate list types" + + + def associate(self, *args): + """ + Provides a means for high lvl api calls to associate objects + using low lvl calls. + """ + + if len(args) < 3: + raise PLCInvalidArgumentCount, "auth, field, value must be specified" + elif hasattr(self, 'associate_' + args[1]): + associate = getattr(self, 'associate_'+args[1]) + associate(*args) + else: + raise PLCInvalidArguemnt, "No such associate function associate_%s" % args[1] def validate_timestamp(self, timestamp, check_future = False): """ @@ -201,6 +230,14 @@ class Row(dict): if not self.has_key(self.primary_key) or \ keys == [self.primary_key] or \ insert is True: + + # If primary key id is a serial int, get next id + if self.fields[self.primary_key].type in (IntType, LongType): + pk_id = self.api.db.next_id(self.table_name, self.primary_key) + self[self.primary_key] = pk_id + db_fields[self.primary_key] = pk_id + keys = db_fields.keys() + values = [self.api.db.param(key, value) for (key, value) in db_fields.items()] # Insert new row sql = "INSERT INTO %s (%s) VALUES (%s)" % \ (self.table_name, ", ".join(keys), ", ".join(values)) @@ -215,9 +252,6 @@ class Row(dict): self.api.db.do(sql, db_fields) - if not self.has_key(self.primary_key): - self[self.primary_key] = self.api.db.last_insert_id(self.table_name, self.primary_key) - if commit: self.api.db.commit()