Setting tag plcapi-5.4-2
[plcapi.git] / PLC / PostgreSQL.py
index fa5f1cf..e04b02b 100644 (file)
 #
 #
-# PostgreSQL database interface. Sort of like DBI(3) (Database
-# independent interface for Perl).
+# PostgreSQL database interface. 
+# Sort of like DBI(3) (Database independent interface for Perl).
 #
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
 #
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
-# $Id: PostgreSQL.py,v 1.8 2006/10/30 16:37:49 mlhuang Exp $
-#
 
 import psycopg2
 import psycopg2.extensions
 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
 
 import psycopg2
 import psycopg2.extensions
 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
+# UNICODEARRAY not exported yet
+psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
 
 
-import pgdb
+import types
 from types import StringTypes, NoneType
 import traceback
 import commands
 import re
 from pprint import pformat
 
 from types import StringTypes, NoneType
 import traceback
 import commands
 import re
 from pprint import pformat
 
-from PLC.Debug import profile, log
+from PLC.Logger import logger
+from PLC.Debug import profile
 from PLC.Faults import *
 from PLC.Faults import *
-
-if not psycopg2:
-    is8bit = re.compile("[\x80-\xff]").search
-
-    def unicast(typecast):
-        """
-        pgdb returns raw UTF-8 strings. This function casts strings that
-        appear to contain non-ASCII characters to unicode objects.
-        """
-    
-        def wrapper(*args, **kwds):
-            value = typecast(*args, **kwds)
-
-            # pgdb always encodes unicode objects as UTF-8 regardless of
-            # the DB encoding (and gives you no option for overriding
-            # the encoding), so always decode 8-bit objects as UTF-8.
-            if isinstance(value, str) and is8bit(value):
-                value = unicode(value, "utf-8")
-
-            return value
-
-        return wrapper
-
-    pgdb.pgdbTypeCache.typecast = unicast(pgdb.pgdbTypeCache.typecast)
+from datetime import datetime as DateTimeType
 
 class PostgreSQL:
     def __init__(self, api):
         self.api = api
 
 class PostgreSQL:
     def __init__(self, api):
         self.api = api
+        self.debug = False
+#        self.debug = True
+        self.connection = None
 
 
-        # Initialize database connection
-        if psycopg2:
+    def cursor(self):
+        if self.connection is None:
+            # (Re)initialize database connection
             try:
                 # Try UNIX socket first
             try:
                 # Try UNIX socket first
-                self.db = psycopg2.connect(user = api.config.PLC_DB_USER,
-                                           password = api.config.PLC_DB_PASSWORD,
-                                           database = api.config.PLC_DB_NAME)
+                self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
+                                                   password = self.api.config.PLC_DB_PASSWORD,
+                                                   database = self.api.config.PLC_DB_NAME)
             except psycopg2.OperationalError:
                 # Fall back on TCP
             except psycopg2.OperationalError:
                 # Fall back on TCP
-                self.db = psycopg2.connect(user = api.config.PLC_DB_USER,
-                                           password = api.config.PLC_DB_PASSWORD,
-                                           database = api.config.PLC_DB_NAME,
-                                           host = api.config.PLC_DB_HOST,
-                                           port = api.config.PLC_DB_PORT)
-            self.db.set_client_encoding("UNICODE")
-        else:
-            self.db = pgdb.connect(user = api.config.PLC_DB_USER,
-                                   password = api.config.PLC_DB_PASSWORD,
-                                   host = "%s:%d" % (api.config.PLC_DB_HOST, api.config.PLC_DB_PORT),
-                                   database = api.config.PLC_DB_NAME)
-
-        self.cursor = self.db.cursor()
+                self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
+                                                   password = self.api.config.PLC_DB_PASSWORD,
+                                                   database = self.api.config.PLC_DB_NAME,
+                                                   host = self.api.config.PLC_DB_HOST,
+                                                   port = self.api.config.PLC_DB_PORT)
+            self.connection.set_client_encoding("UNICODE")
 
         (self.rowcount, self.description, self.lastrowid) = \
                         (None, None, None)
 
 
         (self.rowcount, self.description, self.lastrowid) = \
                         (None, None, None)
 
-    def quote(self, value):
-        """
-        Returns quoted version of the specified value.
-        """
+        return self.connection.cursor()
+
+    def close(self):
+        if self.connection is not None:
+            self.connection.close()
+            self.connection = None
 
 
-        # The pgdb._quote function is good enough for general SQL
-        # quoting, except for array types.
-        if isinstance(value, (list, tuple, set)):
-            return "ARRAY[%s]" % ", ".join(map, self.quote, value)
+    @staticmethod
+    # Adapted from pgdb, with the code simplified
+    def _quote(x):
+        if isinstance(x, DateTimeType):
+            x = str(x)
+        elif isinstance(x, unicode):
+            x = x.encode( 'utf-8' )
+    
+        if isinstance(x, types.StringType):
+            x = "'%s'" % str(x).replace("\\", "\\\\").replace("'", "''")
+        elif isinstance(x, (types.IntType, types.LongType, types.FloatType)):
+            pass
+        elif x is None:
+            x = 'NULL'
+        elif isinstance(x, (types.ListType, types.TupleType, set)):
+            x = 'ARRAY[%s]' % ', '.join(map(lambda x: str(_quote(x)), x))
+        elif hasattr(x, '__pg_repr__'):
+            x = x.__pg_repr__()
         else:
         else:
-            return pgdb._quote(value)
+            raise PLCDBError, 'Cannot quote type %s' % type(x)
+        return x
 
 
-    quote = classmethod(quote)
 
 
+    def quote(self, value):
+        """
+        Returns quoted version of the specified value.
+        """
+        return PostgreSQL._quote (value)
+
+# following is an unsuccessful attempt to re-use lib code as much as possible
+#    def quote(self, value):
+#        # The pgdb._quote function is good enough for general SQL
+#        # quoting, except for array types.
+#        if isinstance (value, (types.ListType, types.TupleType, set)):
+#            'ARRAY[%s]' % ', '.join( [ str(self.quote(x)) for x in value ] )
+#        else:
+#            try:
+#                # up to PyGreSQL-3.x, function was pgdb._quote
+#                import pgdb
+#                return pgdb._quote(value)
+#            except:
+#                # with PyGreSQL-4.x, use psycopg2's adapt
+#                from psycopg2.extensions import adapt
+#                return adapt (value)
+
+    @classmethod
     def param(self, name, value):
         # None is converted to the unquoted string NULL
         if isinstance(value, NoneType):
     def param(self, name, value):
         # None is converted to the unquoted string NULL
         if isinstance(value, NoneType):
@@ -106,22 +120,30 @@ class PostgreSQL:
 
         return '%(' + name + ')' + conversion
 
 
         return '%(' + name + ')' + conversion
 
-    param = classmethod(param)
-
     def begin_work(self):
         # Implicit in pgdb.connect()
         pass
 
     def commit(self):
     def begin_work(self):
         # Implicit in pgdb.connect()
         pass
 
     def commit(self):
-        self.db.commit()
+        self.connection.commit()
 
     def rollback(self):
 
     def rollback(self):
-        self.db.rollback()
+        self.connection.rollback()
 
     def do(self, query, params = None):
 
     def do(self, query, params = None):
-        self.execute(query, params)
+        cursor = self.execute(query, params)
+        cursor.close()
         return self.rowcount
 
         return self.rowcount
 
+    def next_id(self, table_name, primary_key):
+        sequence = "%(table_name)s_%(primary_key)s_seq" % locals()
+        sql = "SELECT nextval('%(sequence)s')" % locals()
+        rows = self.selectall(sql, hashref = False)
+        if rows:
+            return rows[0][0]
+
+        return None
+
     def last_insert_id(self, table_name, primary_key):
         if isinstance(self.lastrowid, int):
             sql = "SELECT %s FROM %s WHERE oid = %d" % \
     def last_insert_id(self, table_name, primary_key):
         if isinstance(self.lastrowid, int):
             sql = "SELECT %s FROM %s WHERE oid = %d" % \
@@ -132,18 +154,45 @@ class PostgreSQL:
 
         return None
 
 
         return None
 
+    # modified for psycopg2-2.0.7
+    # executemany is undefined for SELECTs
+    # see http://www.python.org/dev/peps/pep-0249/
+    # accepts None, a single dict, or a tuple holding a single dict - in which case it calls execute -
+    # or a tuple of several dicts, in which case it calls executemany
     def execute(self, query, params = None):
     def execute(self, query, params = None):
-        self.execute_array(query, (params,))
 
 
-    def execute_array(self, query, param_seq):
-        cursor = self.cursor
+        cursor = self.cursor()
         try:
         try:
+
             # psycopg2 requires %()s format for all parameters,
             # regardless of type.
             # psycopg2 requires %()s format for all parameters,
             # regardless of type.
+            # this needs to be done carefully though: with pattern-based filters
+            # we might have percent signs embedded in the query, so e.g.
+            # GetPersons({'email':'*fake*'}) was resulting in .. LIKE '%sake%' (the '%f' got rewritten to '%s')
             if psycopg2:
                 query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)
             if psycopg2:
                 query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)
-
-            cursor.executemany(query, param_seq)
+            # rewrite wildcards set by Filter.py as '***' into '%'
+            query = query.replace ('***','%')
+
+            if not params:
+                if self.debug:
+                    logger.debug('execute0: {}'.format(query))
+                cursor.execute(query)
+            elif isinstance(params, dict):
+                if self.debug:
+                    logger.debug('execute-dict: params {} query {}'
+                                 .format(params, query%params))
+                cursor.execute(query, params)
+            elif isinstance(params,tuple) and len(params)==1:
+                if self.debug:
+                    logger.debug('execute-tuple {}'.format(query%params[0]))
+                cursor.execute(query,params[0])
+            else:
+                param_seq=(params,)
+                if self.debug:
+                    for params in param_seq:
+                        logger.debug('executemany {}'.format(query%params))
+                cursor.executemany(query, param_seq)
             (self.rowcount, self.description, self.lastrowid) = \
                             (cursor.rowcount, cursor.description, cursor.lastrowid)
         except Exception, e:
             (self.rowcount, self.description, self.lastrowid) = \
                             (cursor.rowcount, cursor.description, cursor.lastrowid)
         except Exception, e:
@@ -152,17 +201,15 @@ class PostgreSQL:
             except:
                 pass
             uuid = commands.getoutput("uuidgen")
             except:
                 pass
             uuid = commands.getoutput("uuidgen")
-            print >> log, "Database error %s:" % uuid
-            print >> log, e
-            print >> log, "Query:"
-            print >> log, query
-            print >> log, "Params:"
-            print >> log, pformat(param_seq[0])
+            message = "Database error {}: - Query {} - Params {}".format(uuid, query, pformat(params))
+            logger.exception(message)
             raise PLCDBError("Please contact " + \
                              self.api.config.PLC_NAME + " Support " + \
                              "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
                              " and reference " + uuid)
 
             raise PLCDBError("Please contact " + \
                              self.api.config.PLC_NAME + " Support " + \
                              "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
                              " and reference " + uuid)
 
+        return cursor
+
     def selectall(self, query, params = None, hashref = True, key_field = None):
         """
         Return each row as a dictionary keyed on field name (like DBI
     def selectall(self, query, params = None, hashref = True, key_field = None):
         """
         Return each row as a dictionary keyed on field name (like DBI
@@ -171,14 +218,14 @@ class PostgreSQL:
         selectall_hashref()).
 
         If params is specified, the specified parameters will be bound
         selectall_hashref()).
 
         If params is specified, the specified parameters will be bound
-        to the query (see PLC.DB.parameterize() and
-        pgdb.cursor.execute()).
+        to the query.
         """
 
         """
 
-        self.execute(query, params)
-        rows = self.cursor.fetchall()
-
-        if hashref:
+        cursor = self.execute(query, params)
+        rows = cursor.fetchall()
+        cursor.close()
+        self.commit()
+        if hashref or key_field is not None:
             # Return each row as a dictionary keyed on field name
             # (like DBI selectrow_hashref()).
             labels = [column[0] for column in self.description]
             # Return each row as a dictionary keyed on field name
             # (like DBI selectrow_hashref()).
             labels = [column[0] for column in self.description]
@@ -196,6 +243,12 @@ class PostgreSQL:
         Return the names of the fields of the specified table.
         """
 
         Return the names of the fields of the specified table.
         """
 
+        if hasattr(self, 'fields_cache'):
+            if self.fields_cache.has_key((table, notnull, hasdef)):
+                return self.fields_cache[(table, notnull, hasdef)]
+        else:
+            self.fields_cache = {}
+
         sql = "SELECT attname FROM pg_attribute, pg_class" \
               " WHERE pg_class.oid = attrelid" \
               " AND attnum > 0 AND relname = %(table)s"
         sql = "SELECT attname FROM pg_attribute, pg_class" \
               " WHERE pg_class.oid = attrelid" \
               " AND attnum > 0 AND relname = %(table)s"
@@ -208,4 +261,6 @@ class PostgreSQL:
 
         rows = self.selectall(sql, locals(), hashref = False)
 
 
         rows = self.selectall(sql, locals(), hashref = False)
 
-        return [row[0] for row in rows]
+        self.fields_cache[(table, notnull, hasdef)] = [row[0] for row in rows]
+
+        return self.fields_cache[(table, notnull, hasdef)]