Port _quote() from old version of pgdb. Simpliy code.
[plcapi.git] / PLC / PostgreSQL.py
index 09132e2..c08dd71 100644 (file)
@@ -16,6 +16,7 @@ psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
 
 import pgdb
+import types
 from types import StringTypes, NoneType
 import traceback
 import commands
@@ -24,6 +25,28 @@ from pprint import pformat
 
 from PLC.Debug import profile, log
 from PLC.Faults import *
+from datetime import datetime as DateTimeType
+
+# From pgdb
+def _quote(x):
+    if isinstance(x, DateTimeType):
+        x = str(x)
+    elif isinstance(x, unicode):
+        x = x.encode( 'utf-8' )
+
+    if isinstance(x, types.StringType):
+        x = "'%s'" % str(x).replace("\\", "\\\\").replace("'", "''")
+    elif isinstance(x, (types.IntType, types.LongType, types.FloatType)):
+        pass
+    elif x is None:
+        x = 'NULL'
+    elif isinstance(x, (types.ListType, types.TupleType, set)):
+        x = 'ARRAY[%s]' % ', '.join(map(lambda x: str(_quote(x)), x))
+    elif hasattr(x, '__pg_repr__'):
+        x = x.__pg_repr__()
+    else:
+        raise pgdb.InterfaceError, 'do not know how to handle type %s' % type(x)
+    return x
 
 class PostgreSQL:
     def __init__(self, api):
@@ -59,23 +82,12 @@ class PostgreSQL:
             self.connection.close()
             self.connection = None
 
-    # join insists on getting strings
-    @classmethod
-    def quote_string(self, value):
-        return str(PostgreSQL.quote(value))
-
     @classmethod
     def quote(self, value):
         """
         Returns quoted version of the specified value.
         """
-
-        # The pgdb._quote function is good enough for general SQL
-        # quoting, except for array types.
-        if isinstance(value, (list, tuple, set)):
-            return "ARRAY[%s]" % ", ".join(map (PostgreSQL.quote_string, value))
-        else:
-            return pgdb._quote(value)
+        return _quote(value)
 
     @classmethod
     def param(self, name, value):
@@ -110,13 +122,13 @@ class PostgreSQL:
         return self.rowcount
 
     def next_id(self, table_name, primary_key):
-       sequence = "%(table_name)s_%(primary_key)s_seq" % locals()      
-       sql = "SELECT nextval('%(sequence)s')" % locals()
-       rows = self.selectall(sql, hashref = False)
-       if rows: 
-           return rows[0][0]
-               
-       return None 
+        sequence = "%(table_name)s_%(primary_key)s_seq" % locals()
+        sql = "SELECT nextval('%(sequence)s')" % locals()
+        rows = self.selectall(sql, hashref = False)
+        if rows:
+            return rows[0][0]
+
+        return None
 
     def last_insert_id(self, table_name, primary_key):
         if isinstance(self.lastrowid, int):
@@ -128,7 +140,7 @@ class PostgreSQL:
 
         return None
 
-    # modified for psycopg2-2.0.7 
+    # modified for psycopg2-2.0.7
     # executemany is undefined for SELECT's
     # see http://www.python.org/dev/peps/pep-0249/
     # accepts either None, a single dict, a tuple of single dict - in which case it execute's