- make last_insert_id() return the primary key value of the last
authorMark Huang <mlhuang@cs.princeton.edu>
Tue, 3 Oct 2006 19:27:07 +0000 (19:27 +0000)
committerMark Huang <mlhuang@cs.princeton.edu>
Tue, 3 Oct 2006 19:27:07 +0000 (19:27 +0000)
  inserted row, like DBI
- add notnull and hasdef keywords to fields() so that we know which
  columns must be set

PLC/PostgreSQL.py

index 9eff584..7fe5ad2 100644 (file)
@@ -5,7 +5,7 @@
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
-# $Id: PostgreSQL.py,v 1.1 2006/09/06 15:36:07 mlhuang Exp $
+# $Id: PostgreSQL.py,v 1.2 2006/09/25 15:12:20 mlhuang Exp $
 #
 
 import pgdb
@@ -80,8 +80,15 @@ class PostgreSQL:
         self.execute(query, params)
         return self.rowcount
 
-    def last_insert_id(self):
-        return self.lastrowid
+    def last_insert_id(self, table_name, primary_key):
+        if isinstance(self.lastrowid, int):
+            sql = "SELECT %s FROM %s WHERE oid = %d" % \
+                  (primary_key, table_name, self.lastrowid)
+            rows = self.selectall(sql, hashref = False)
+            if rows:
+                return rows[0][0]
+
+        return None
 
     def execute(self, query, params = None):
         self.execute_array(query, (params,))
@@ -134,15 +141,21 @@ class PostgreSQL:
         else:
             return rows
 
-    def fields(self, table):
+    def fields(self, table, notnull = None, hasdef = None):
         """
         Return the names of the fields of the specified table.
         """
 
-        rows = self.selectall("SELECT attname FROM pg_attribute, pg_class" \
-                              " WHERE pg_class.oid = attrelid" \
-                              " AND attnum > 0 AND relname = %(table)s",
-                              locals(),
-                              hashref = False)
+        sql = "SELECT attname FROM pg_attribute, pg_class" \
+              " WHERE pg_class.oid = attrelid" \
+              " AND attnum > 0 AND relname = %(table)s"
+
+        if notnull is not None:
+            sql += " AND attnotnull is %(notnull)s"
+
+        if hasdef is not None:
+            sql += " AND atthasdef is %(hasdef)s"
+
+        rows = self.selectall(sql, locals(), hashref = False)
 
         return [row[0] for row in rows]