- do not reuse cursors!
authorMark Huang <mlhuang@cs.princeton.edu>
Sun, 11 Feb 2007 18:34:06 +0000 (18:34 +0000)
committerMark Huang <mlhuang@cs.princeton.edu>
Sun, 11 Feb 2007 18:34:06 +0000 (18:34 +0000)
- add .close() method to shut down DB connection

PLC/PostgreSQL.py

index 57321f5..f15ebf9 100644 (file)
@@ -5,7 +5,7 @@
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
-# $Id: PostgreSQL.py,v 1.12 2007/02/08 15:15:21 mlhuang Exp $
+# $Id: PostgreSQL.py,v 1.13 2007/02/11 04:53:40 mlhuang Exp $
 #
 
 import psycopg2
@@ -52,33 +52,41 @@ class PostgreSQL:
     def __init__(self, api):
         self.api = api
         self.debug = False
+        self.connection = None
 
-        # Initialize database connection
-        if psycopg2:
-            try:
-                # Try UNIX socket first
-                self.db = psycopg2.connect(user = api.config.PLC_DB_USER,
-                                           password = api.config.PLC_DB_PASSWORD,
-                                           database = api.config.PLC_DB_NAME)
-            except psycopg2.OperationalError:
-                # Fall back on TCP
-                self.db = psycopg2.connect(user = api.config.PLC_DB_USER,
-                                           password = api.config.PLC_DB_PASSWORD,
-                                           database = api.config.PLC_DB_NAME,
-                                           host = api.config.PLC_DB_HOST,
-                                           port = api.config.PLC_DB_PORT)
-            self.db.set_client_encoding("UNICODE")
-        else:
-            self.db = pgdb.connect(user = api.config.PLC_DB_USER,
-                                   password = api.config.PLC_DB_PASSWORD,
-                                   host = "%s:%d" % (api.config.PLC_DB_HOST, api.config.PLC_DB_PORT),
-                                   database = api.config.PLC_DB_NAME)
-
-        self.cursor = self.db.cursor()
+    def cursor(self):
+        if self.connection is None:
+            # (Re)initialize database connection
+            if psycopg2:
+                try:
+                    # Try UNIX socket first
+                    self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
+                                                       password = self.api.config.PLC_DB_PASSWORD,
+                                                       database = self.api.config.PLC_DB_NAME)
+                except psycopg2.OperationalError:
+                    # Fall back on TCP
+                    self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
+                                                       password = self.api.config.PLC_DB_PASSWORD,
+                                                       database = self.api.config.PLC_DB_NAME,
+                                                       host = self.api.config.PLC_DB_HOST,
+                                                       port = self.api.config.PLC_DB_PORT)
+                self.connection.set_client_encoding("UNICODE")
+            else:
+                self.connection = pgdb.connect(user = self.api.config.PLC_DB_USER,
+                                               password = self.api.config.PLC_DB_PASSWORD,
+                                               host = "%s:%d" % (api.config.PLC_DB_HOST, api.config.PLC_DB_PORT),
+                                               database = self.api.config.PLC_DB_NAME)
 
         (self.rowcount, self.description, self.lastrowid) = \
                         (None, None, None)
 
+        return self.connection.cursor()
+
+    def close(self):
+        if self.connection is not None:
+            self.connection.close()
+            self.connection = None
+
     def quote(self, value):
         """
         Returns quoted version of the specified value.
@@ -116,10 +124,10 @@ class PostgreSQL:
         pass
 
     def commit(self):
-        self.db.commit()
+        self.connection.commit()
 
     def rollback(self):
-        self.db.rollback()
+        self.connection.rollback()
 
     def do(self, query, params = None):
         self.execute(query, params)
@@ -136,10 +144,10 @@ class PostgreSQL:
         return None
 
     def execute(self, query, params = None):
-        self.execute_array(query, (params,))
+        return self.execute_array(query, (params,))
 
     def execute_array(self, query, param_seq):
-        cursor = self.cursor
+        cursor = self.cursor()
         try:
             if self.debug:
                 for params in param_seq:
@@ -153,13 +161,7 @@ class PostgreSQL:
             if psycopg2:
                 query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)
 
-            try:
-                cursor.executemany(query, param_seq)
-            except InterfaceError:
-                # Try one more time with another cursor
-                cursor = self.cursor = self.db.cursor()
-                cursor.executemany(query, param_seq)
-
+            cursor.executemany(query, param_seq)
             (self.rowcount, self.description, self.lastrowid) = \
                             (cursor.rowcount, cursor.description, cursor.lastrowid)
         except Exception, e:
@@ -179,6 +181,8 @@ class PostgreSQL:
                              "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
                              " and reference " + uuid)
 
+        return cursor
+
     def selectall(self, query, params = None, hashref = True, key_field = None):
         """
         Return each row as a dictionary keyed on field name (like DBI
@@ -187,12 +191,10 @@ class PostgreSQL:
         selectall_hashref()).
 
         If params is specified, the specified parameters will be bound
-        to the query (see PLC.DB.parameterize() and
-        pgdb.cursor.execute()).
+        to the query.
         """
 
-        self.execute(query, params)
-        rows = self.cursor.fetchall()
+        rows = self.execute(query, params).fetchall()
 
         if hashref or key_field is not None:
             # Return each row as a dictionary keyed on field name