fix quoting - was totally wrong on f16
[plcapi.git] / PLC / PostgreSQL.py
1 #
2 # PostgreSQL database interface. 
3 # Sort of like DBI(3) (Database independent interface for Perl).
4 #
5 # Mark Huang <mlhuang@cs.princeton.edu>
6 # Copyright (C) 2006 The Trustees of Princeton University
7 #
8
9 import psycopg2
10 import psycopg2.extensions
11 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
12 # UNICODEARRAY not exported yet
13 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
14
15 import types
16 from types import StringTypes, NoneType
17 import traceback
18 import commands
19 import re
20 from pprint import pformat
21
22 from PLC.Debug import profile, log
23 from PLC.Faults import *
24 from datetime import datetime as DateTimeType
25
26 class PostgreSQL:
27     def __init__(self, api):
28         self.api = api
29         self.debug = False
30 #        self.debug = True
31         self.connection = None
32
33     def cursor(self):
34         if self.connection is None:
35             # (Re)initialize database connection
36             try:
37                 # Try UNIX socket first
38                 self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
39                                                    password = self.api.config.PLC_DB_PASSWORD,
40                                                    database = self.api.config.PLC_DB_NAME)
41             except psycopg2.OperationalError:
42                 # Fall back on TCP
43                 self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
44                                                    password = self.api.config.PLC_DB_PASSWORD,
45                                                    database = self.api.config.PLC_DB_NAME,
46                                                    host = self.api.config.PLC_DB_HOST,
47                                                    port = self.api.config.PLC_DB_PORT)
48             self.connection.set_client_encoding("UNICODE")
49
50         (self.rowcount, self.description, self.lastrowid) = \
51                         (None, None, None)
52
53         return self.connection.cursor()
54
55     def close(self):
56         if self.connection is not None:
57             self.connection.close()
58             self.connection = None
59
60     def quote(self, value):
61         """
62         Returns quoted version of the specified value.
63         """
64         # The pgdb._quote function is good enough for general SQL
65         # quoting, except for array types.
66         if isinstance (value, (types.ListType, types.TupleType, set)):
67             'ARRAY[%s]' % ', '.join( [ str(self.quote(x)) for x in value ] )
68         else:
69             try:
70                 # up to PyGreSQL-3.x, function was pgdb._quote
71                 import pgdb
72                 return pgdb._quote(value)
73             except:
74                 # with PyGreSQL-4.x, use psycopg2's adapt
75                 from psycopg2.extensions import adapt
76                 return adapt (value)
77
78     @classmethod
79     def param(self, name, value):
80         # None is converted to the unquoted string NULL
81         if isinstance(value, NoneType):
82             conversion = "s"
83         # True and False are also converted to unquoted strings
84         elif isinstance(value, bool):
85             conversion = "s"
86         elif isinstance(value, float):
87             conversion = "f"
88         elif not isinstance(value, StringTypes):
89             conversion = "d"
90         else:
91             conversion = "s"
92
93         return '%(' + name + ')' + conversion
94
95     def begin_work(self):
96         # Implicit in pgdb.connect()
97         pass
98
99     def commit(self):
100         self.connection.commit()
101
102     def rollback(self):
103         self.connection.rollback()
104
105     def do(self, query, params = None):
106         cursor = self.execute(query, params)
107         cursor.close()
108         return self.rowcount
109
110     def next_id(self, table_name, primary_key):
111         sequence = "%(table_name)s_%(primary_key)s_seq" % locals()
112         sql = "SELECT nextval('%(sequence)s')" % locals()
113         rows = self.selectall(sql, hashref = False)
114         if rows:
115             return rows[0][0]
116
117         return None
118
119     def last_insert_id(self, table_name, primary_key):
120         if isinstance(self.lastrowid, int):
121             sql = "SELECT %s FROM %s WHERE oid = %d" % \
122                   (primary_key, table_name, self.lastrowid)
123             rows = self.selectall(sql, hashref = False)
124             if rows:
125                 return rows[0][0]
126
127         return None
128
129     # modified for psycopg2-2.0.7
130     # executemany is undefined for SELECT's
131     # see http://www.python.org/dev/peps/pep-0249/
132     # accepts either None, a single dict, a tuple of single dict - in which case it execute's
133     # or a tuple of several dicts, in which case it executemany's
134     def execute(self, query, params = None):
135
136         cursor = self.cursor()
137         try:
138
139             # psycopg2 requires %()s format for all parameters,
140             # regardless of type.
141             # this needs to be done carefully though as with pattern-based filters
142             # we might have percents embedded in the query
143             # so e.g. GetPersons({'email':'*fake*'}) was resulting in .. LIKE '%sake%'
144             if psycopg2:
145                 query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)
146             # rewrite wildcards set by Filter.py as '***' into '%'
147             query = query.replace ('***','%')
148
149             if not params:
150                 if self.debug:
151                     print >> log,'execute0',query
152                 cursor.execute(query)
153             elif isinstance(params,dict):
154                 if self.debug:
155                     print >> log,'execute-dict: params',params,'query',query%params
156                 cursor.execute(query,params)
157             elif isinstance(params,tuple) and len(params)==1:
158                 if self.debug:
159                     print >> log,'execute-tuple',query%params[0]
160                 cursor.execute(query,params[0])
161             else:
162                 param_seq=(params,)
163                 if self.debug:
164                     for params in param_seq:
165                         print >> log,'executemany',query%params
166                 cursor.executemany(query, param_seq)
167             (self.rowcount, self.description, self.lastrowid) = \
168                             (cursor.rowcount, cursor.description, cursor.lastrowid)
169         except Exception, e:
170             try:
171                 self.rollback()
172             except:
173                 pass
174             uuid = commands.getoutput("uuidgen")
175             print >> log, "Database error %s:" % uuid
176             print >> log, e
177             print >> log, "Query:"
178             print >> log, query
179             print >> log, "Params:"
180             print >> log, pformat(params)
181             raise PLCDBError("Please contact " + \
182                              self.api.config.PLC_NAME + " Support " + \
183                              "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
184                              " and reference " + uuid)
185
186         return cursor
187
188     def selectall(self, query, params = None, hashref = True, key_field = None):
189         """
190         Return each row as a dictionary keyed on field name (like DBI
191         selectrow_hashref()). If key_field is specified, return rows
192         as a dictionary keyed on the specified field (like DBI
193         selectall_hashref()).
194
195         If params is specified, the specified parameters will be bound
196         to the query.
197         """
198
199         cursor = self.execute(query, params)
200         rows = cursor.fetchall()
201         cursor.close()
202         self.commit()
203         if hashref or key_field is not None:
204             # Return each row as a dictionary keyed on field name
205             # (like DBI selectrow_hashref()).
206             labels = [column[0] for column in self.description]
207             rows = [dict(zip(labels, row)) for row in rows]
208
209         if key_field is not None and key_field in labels:
210             # Return rows as a dictionary keyed on the specified field
211             # (like DBI selectall_hashref()).
212             return dict([(row[key_field], row) for row in rows])
213         else:
214             return rows
215
216     def fields(self, table, notnull = None, hasdef = None):
217         """
218         Return the names of the fields of the specified table.
219         """
220
221         if hasattr(self, 'fields_cache'):
222             if self.fields_cache.has_key((table, notnull, hasdef)):
223                 return self.fields_cache[(table, notnull, hasdef)]
224         else:
225             self.fields_cache = {}
226
227         sql = "SELECT attname FROM pg_attribute, pg_class" \
228               " WHERE pg_class.oid = attrelid" \
229               " AND attnum > 0 AND relname = %(table)s"
230
231         if notnull is not None:
232             sql += " AND attnotnull is %(notnull)s"
233
234         if hasdef is not None:
235             sql += " AND atthasdef is %(hasdef)s"
236
237         rows = self.selectall(sql, locals(), hashref = False)
238
239         self.fields_cache[(table, notnull, hasdef)] = [row[0] for row in rows]
240
241         return self.fields_cache[(table, notnull, hasdef)]