obtain integer primary keys by calling nextval() on the sequence instead of quering...
[plcapi.git] / PLC / Table.py
index cb0eb6b..74da3e9 100644 (file)
@@ -1,4 +1,4 @@
-from types import StringTypes
+from types import StringTypes, IntType, LongType
 import time
 import calendar
 
@@ -50,6 +50,35 @@ class Row(dict):
             if value is not None and hasattr(self, 'validate_' + key):
                 validate = getattr(self, 'validate_' + key)
                 self[key] = validate(value)
+       
+    def separate_types(self, items):
+       """
+       Separate a list of different typed objects. 
+       Return a list for each type (ints, strs and dicts)
+       """
+       
+       if isinstance(items, (list, tuple, set)):
+           ints = filter(lambda x: isinstance(x, (int, long)), items)
+           strs = filter(lambda x: isinstance(x, StringTypes), items)
+           dicts = filter(lambda x: isinstance(x, dict), items)
+           return (ints, strs, dicts)          
+       else:
+           raise PLCInvalidArgument, "Can only separate list types" 
+               
+
+    def associate(self, *args):
+       """
+       Provides a means for high lvl api calls to associate objects
+        using low lvl calls.
+       """
+
+       if len(args) < 3:
+           raise PLCInvalidArgumentCount, "auth, field, value must be specified"
+       elif hasattr(self, 'associate_' + args[1]):
+           associate = getattr(self, 'associate_'+args[1])
+           associate(*args)
+       else:
+           raise PLCInvalidArguemnt, "No such associate function associate_%s" % args[1]
 
     def validate_timestamp(self, timestamp, check_future = False):
         """
@@ -92,7 +121,7 @@ class Row(dict):
             assert isinstance(obj, classobj)
             assert isinstance(obj, Row)
             assert obj.primary_key in obj
-            assert join_table in obj.join_tables
+           assert join_table in obj.join_tables
 
             # By default, just insert the primary keys of each object
             # into the join table.
@@ -201,6 +230,14 @@ class Row(dict):
         if not self.has_key(self.primary_key) or \
            keys == [self.primary_key] or \
            insert is True:
+           
+           # If primary key id is a serial int, get next id
+           if self.fields[self.primary_key].type in (IntType, LongType):
+               pk_id = self.api.db.next_id(self.table_name, self.primary_key)
+               self[self.primary_key] = pk_id
+               db_fields[self.primary_key] = pk_id
+               keys = db_fields.keys()
+               values = [self.api.db.param(key, value) for (key, value) in db_fields.items()]
             # Insert new row
             sql = "INSERT INTO %s (%s) VALUES (%s)" % \
                   (self.table_name, ", ".join(keys), ", ".join(values))
@@ -215,9 +252,6 @@ class Row(dict):
 
         self.api.db.do(sql, db_fields)
 
-        if not self.has_key(self.primary_key):
-            self[self.primary_key] = self.api.db.last_insert_id(self.table_name, self.primary_key)
-
         if commit:
             self.api.db.commit()