Port _quote() from old version of pgdb. Simpliy code.
[plcapi.git] / PLC / PostgreSQL.py
1 #
2 # PostgreSQL database interface. Sort of like DBI(3) (Database
3 # independent interface for Perl).
4 #
5 # Mark Huang <mlhuang@cs.princeton.edu>
6 # Copyright (C) 2006 The Trustees of Princeton University
7 #
8 # $Id$
9 # $URL$
10 #
11
12 import psycopg2
13 import psycopg2.extensions
14 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
15 # UNICODEARRAY not exported yet
16 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
17
18 import pgdb
19 import types
20 from types import StringTypes, NoneType
21 import traceback
22 import commands
23 import re
24 from pprint import pformat
25
26 from PLC.Debug import profile, log
27 from PLC.Faults import *
28 from datetime import datetime as DateTimeType
29
30 # From pgdb
31 def _quote(x):
32     if isinstance(x, DateTimeType):
33         x = str(x)
34     elif isinstance(x, unicode):
35         x = x.encode( 'utf-8' )
36
37     if isinstance(x, types.StringType):
38         x = "'%s'" % str(x).replace("\\", "\\\\").replace("'", "''")
39     elif isinstance(x, (types.IntType, types.LongType, types.FloatType)):
40         pass
41     elif x is None:
42         x = 'NULL'
43     elif isinstance(x, (types.ListType, types.TupleType, set)):
44         x = 'ARRAY[%s]' % ', '.join(map(lambda x: str(_quote(x)), x))
45     elif hasattr(x, '__pg_repr__'):
46         x = x.__pg_repr__()
47     else:
48         raise pgdb.InterfaceError, 'do not know how to handle type %s' % type(x)
49     return x
50
51 class PostgreSQL:
52     def __init__(self, api):
53         self.api = api
54         self.debug = False
55 #        self.debug = True
56         self.connection = None
57
58     def cursor(self):
59         if self.connection is None:
60             # (Re)initialize database connection
61             try:
62                 # Try UNIX socket first
63                 self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
64                                                    password = self.api.config.PLC_DB_PASSWORD,
65                                                    database = self.api.config.PLC_DB_NAME)
66             except psycopg2.OperationalError:
67                 # Fall back on TCP
68                 self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
69                                                    password = self.api.config.PLC_DB_PASSWORD,
70                                                    database = self.api.config.PLC_DB_NAME,
71                                                    host = self.api.config.PLC_DB_HOST,
72                                                    port = self.api.config.PLC_DB_PORT)
73             self.connection.set_client_encoding("UNICODE")
74
75         (self.rowcount, self.description, self.lastrowid) = \
76                         (None, None, None)
77
78         return self.connection.cursor()
79
80     def close(self):
81         if self.connection is not None:
82             self.connection.close()
83             self.connection = None
84
85     @classmethod
86     def quote(self, value):
87         """
88         Returns quoted version of the specified value.
89         """
90         return _quote(value)
91
92     @classmethod
93     def param(self, name, value):
94         # None is converted to the unquoted string NULL
95         if isinstance(value, NoneType):
96             conversion = "s"
97         # True and False are also converted to unquoted strings
98         elif isinstance(value, bool):
99             conversion = "s"
100         elif isinstance(value, float):
101             conversion = "f"
102         elif not isinstance(value, StringTypes):
103             conversion = "d"
104         else:
105             conversion = "s"
106
107         return '%(' + name + ')' + conversion
108
109     def begin_work(self):
110         # Implicit in pgdb.connect()
111         pass
112
113     def commit(self):
114         self.connection.commit()
115
116     def rollback(self):
117         self.connection.rollback()
118
119     def do(self, query, params = None):
120         cursor = self.execute(query, params)
121         cursor.close()
122         return self.rowcount
123
124     def next_id(self, table_name, primary_key):
125         sequence = "%(table_name)s_%(primary_key)s_seq" % locals()
126         sql = "SELECT nextval('%(sequence)s')" % locals()
127         rows = self.selectall(sql, hashref = False)
128         if rows:
129             return rows[0][0]
130
131         return None
132
133     def last_insert_id(self, table_name, primary_key):
134         if isinstance(self.lastrowid, int):
135             sql = "SELECT %s FROM %s WHERE oid = %d" % \
136                   (primary_key, table_name, self.lastrowid)
137             rows = self.selectall(sql, hashref = False)
138             if rows:
139                 return rows[0][0]
140
141         return None
142
143     # modified for psycopg2-2.0.7
144     # executemany is undefined for SELECT's
145     # see http://www.python.org/dev/peps/pep-0249/
146     # accepts either None, a single dict, a tuple of single dict - in which case it execute's
147     # or a tuple of several dicts, in which case it executemany's
148     def execute(self, query, params = None):
149
150         cursor = self.cursor()
151         try:
152
153             # psycopg2 requires %()s format for all parameters,
154             # regardless of type.
155             # this needs to be done carefully though as with pattern-based filters
156             # we might have percents embedded in the query
157             # so e.g. GetPersons({'email':'*fake*'}) was resulting in .. LIKE '%sake%'
158             if psycopg2:
159                 query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)
160             # rewrite wildcards set by Filter.py as '***' into '%'
161             query = query.replace ('***','%')
162
163             if not params:
164                 if self.debug:
165                     print >> log,'execute0',query
166                 cursor.execute(query)
167             elif isinstance(params,dict):
168                 if self.debug:
169                     print >> log,'execute-dict: params',params,'query',query%params
170                 cursor.execute(query,params)
171             elif isinstance(params,tuple) and len(params)==1:
172                 if self.debug:
173                     print >> log,'execute-tuple',query%params[0]
174                 cursor.execute(query,params[0])
175             else:
176                 param_seq=(params,)
177                 if self.debug:
178                     for params in param_seq:
179                         print >> log,'executemany',query%params
180                 cursor.executemany(query, param_seq)
181             (self.rowcount, self.description, self.lastrowid) = \
182                             (cursor.rowcount, cursor.description, cursor.lastrowid)
183         except Exception, e:
184             try:
185                 self.rollback()
186             except:
187                 pass
188             uuid = commands.getoutput("uuidgen")
189             print >> log, "Database error %s:" % uuid
190             print >> log, e
191             print >> log, "Query:"
192             print >> log, query
193             print >> log, "Params:"
194             print >> log, pformat(params)
195             raise PLCDBError("Please contact " + \
196                              self.api.config.PLC_NAME + " Support " + \
197                              "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
198                              " and reference " + uuid)
199
200         return cursor
201
202     def selectall(self, query, params = None, hashref = True, key_field = None):
203         """
204         Return each row as a dictionary keyed on field name (like DBI
205         selectrow_hashref()). If key_field is specified, return rows
206         as a dictionary keyed on the specified field (like DBI
207         selectall_hashref()).
208
209         If params is specified, the specified parameters will be bound
210         to the query.
211         """
212
213         cursor = self.execute(query, params)
214         rows = cursor.fetchall()
215         cursor.close()
216         self.commit()
217         if hashref or key_field is not None:
218             # Return each row as a dictionary keyed on field name
219             # (like DBI selectrow_hashref()).
220             labels = [column[0] for column in self.description]
221             rows = [dict(zip(labels, row)) for row in rows]
222
223         if key_field is not None and key_field in labels:
224             # Return rows as a dictionary keyed on the specified field
225             # (like DBI selectall_hashref()).
226             return dict([(row[key_field], row) for row in rows])
227         else:
228             return rows
229
230     def fields(self, table, notnull = None, hasdef = None):
231         """
232         Return the names of the fields of the specified table.
233         """
234
235         if hasattr(self, 'fields_cache'):
236             if self.fields_cache.has_key((table, notnull, hasdef)):
237                 return self.fields_cache[(table, notnull, hasdef)]
238         else:
239             self.fields_cache = {}
240
241         sql = "SELECT attname FROM pg_attribute, pg_class" \
242               " WHERE pg_class.oid = attrelid" \
243               " AND attnum > 0 AND relname = %(table)s"
244
245         if notnull is not None:
246             sql += " AND attnotnull is %(notnull)s"
247
248         if hasdef is not None:
249             sql += " AND atthasdef is %(hasdef)s"
250
251         rows = self.selectall(sql, locals(), hashref = False)
252
253         self.fields_cache[(table, notnull, hasdef)] = [row[0] for row in rows]
254
255         return self.fields_cache[(table, notnull, hasdef)]