# Mark Huang <mlhuang@cs.princeton.edu>
# Copyright (C) 2006 The Trustees of Princeton University
#
-# $Id: PostgreSQL.py,v 1.8 2006/10/30 16:37:49 mlhuang Exp $
+# $Id: PostgreSQL.py,v 1.15 2007/02/12 18:41:27 mlhuang Exp $
#
import psycopg2
import psycopg2.extensions
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
+# UNICODEARRAY not exported yet
+psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
import pgdb
from types import StringTypes, NoneType
class PostgreSQL:
def __init__(self, api):
self.api = api
+ self.debug = False
+ self.connection = None
- # Initialize database connection
- if psycopg2:
- try:
- # Try UNIX socket first
- self.db = psycopg2.connect(user = api.config.PLC_DB_USER,
- password = api.config.PLC_DB_PASSWORD,
- database = api.config.PLC_DB_NAME)
- except psycopg2.OperationalError:
- # Fall back on TCP
- self.db = psycopg2.connect(user = api.config.PLC_DB_USER,
- password = api.config.PLC_DB_PASSWORD,
- database = api.config.PLC_DB_NAME,
- host = api.config.PLC_DB_HOST,
- port = api.config.PLC_DB_PORT)
- self.db.set_client_encoding("UNICODE")
- else:
- self.db = pgdb.connect(user = api.config.PLC_DB_USER,
- password = api.config.PLC_DB_PASSWORD,
- host = "%s:%d" % (api.config.PLC_DB_HOST, api.config.PLC_DB_PORT),
- database = api.config.PLC_DB_NAME)
-
- self.cursor = self.db.cursor()
+ def cursor(self):
+ if self.connection is None:
+ # (Re)initialize database connection
+ if psycopg2:
+ try:
+ # Try UNIX socket first
+ self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
+ password = self.api.config.PLC_DB_PASSWORD,
+ database = self.api.config.PLC_DB_NAME)
+ except psycopg2.OperationalError:
+ # Fall back on TCP
+ self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
+ password = self.api.config.PLC_DB_PASSWORD,
+ database = self.api.config.PLC_DB_NAME,
+ host = self.api.config.PLC_DB_HOST,
+ port = self.api.config.PLC_DB_PORT)
+ self.connection.set_client_encoding("UNICODE")
+ else:
+ self.connection = pgdb.connect(user = self.api.config.PLC_DB_USER,
+ password = self.api.config.PLC_DB_PASSWORD,
+ host = "%s:%d" % (api.config.PLC_DB_HOST, api.config.PLC_DB_PORT),
+ database = self.api.config.PLC_DB_NAME)
(self.rowcount, self.description, self.lastrowid) = \
(None, None, None)
+ return self.connection.cursor()
+
+ def close(self):
+ if self.connection is not None:
+ self.connection.close()
+ self.connection = None
+
def quote(self, value):
"""
Returns quoted version of the specified value.
pass
def commit(self):
- self.db.commit()
+ self.connection.commit()
def rollback(self):
- self.db.rollback()
+ self.connection.rollback()
def do(self, query, params = None):
- self.execute(query, params)
+ cursor = self.execute(query, params)
+ cursor.close()
return self.rowcount
def last_insert_id(self, table_name, primary_key):
return None
def execute(self, query, params = None):
- self.execute_array(query, (params,))
+ return self.execute_array(query, (params,))
def execute_array(self, query, param_seq):
- cursor = self.cursor
+ cursor = self.cursor()
try:
+ if self.debug:
+ for params in param_seq:
+ if params:
+ print >> log, query % params
+ else:
+ print >> log, query
+
# psycopg2 requires %()s format for all parameters,
# regardless of type.
if psycopg2:
"<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
" and reference " + uuid)
+ return cursor
+
def selectall(self, query, params = None, hashref = True, key_field = None):
"""
Return each row as a dictionary keyed on field name (like DBI
selectall_hashref()).
If params is specified, the specified parameters will be bound
- to the query (see PLC.DB.parameterize() and
- pgdb.cursor.execute()).
+ to the query.
"""
- self.execute(query, params)
- rows = self.cursor.fetchall()
+ cursor = self.execute(query, params)
+ rows = cursor.fetchall()
+ cursor.close()
- if hashref:
+ if hashref or key_field is not None:
# Return each row as a dictionary keyed on field name
# (like DBI selectrow_hashref()).
labels = [column[0] for column in self.description]
Return the names of the fields of the specified table.
"""
+ if hasattr(self, 'fields_cache'):
+ if self.fields_cache.has_key((table, notnull, hasdef)):
+ return self.fields_cache[(table, notnull, hasdef)]
+ else:
+ self.fields_cache = {}
+
sql = "SELECT attname FROM pg_attribute, pg_class" \
" WHERE pg_class.oid = attrelid" \
" AND attnum > 0 AND relname = %(table)s"
rows = self.selectall(sql, locals(), hashref = False)
- return [row[0] for row in rows]
+ self.fields_cache[(table, notnull, hasdef)] = [row[0] for row in rows]
+
+ return self.fields_cache[(table, notnull, hasdef)]