svn keywords
[plcapi.git] / PLC / PostgreSQL.py
1 #
2 # PostgreSQL database interface. Sort of like DBI(3) (Database
3 # independent interface for Perl).
4 #
5 # Mark Huang <mlhuang@cs.princeton.edu>
6 # Copyright (C) 2006 The Trustees of Princeton University
7 #
8 # $Id$
9 # $URL$
10 #
11
12 import psycopg2
13 import psycopg2.extensions
14 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
15 # UNICODEARRAY not exported yet
16 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
17
18 import pgdb
19 from types import StringTypes, NoneType
20 import traceback
21 import commands
22 import re
23 from pprint import pformat
24
25 from PLC.Debug import profile, log
26 from PLC.Faults import *
27
28 if not psycopg2:
29     is8bit = re.compile("[\x80-\xff]").search
30
31     def unicast(typecast):
32         """
33         pgdb returns raw UTF-8 strings. This function casts strings that
34         appear to contain non-ASCII characters to unicode objects.
35         """
36     
37         def wrapper(*args, **kwds):
38             value = typecast(*args, **kwds)
39
40             # pgdb always encodes unicode objects as UTF-8 regardless of
41             # the DB encoding (and gives you no option for overriding
42             # the encoding), so always decode 8-bit objects as UTF-8.
43             if isinstance(value, str) and is8bit(value):
44                 value = unicode(value, "utf-8")
45
46             return value
47
48         return wrapper
49
50     pgdb.pgdbTypeCache.typecast = unicast(pgdb.pgdbTypeCache.typecast)
51
52 class PostgreSQL:
53     def __init__(self, api):
54         self.api = api
55         self.debug = False
56 #        self.debug = True
57         self.connection = None
58
59     def cursor(self):
60         if self.connection is None:
61             # (Re)initialize database connection
62             if psycopg2:
63                 try:
64                     # Try UNIX socket first
65                     self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
66                                                        password = self.api.config.PLC_DB_PASSWORD,
67                                                        database = self.api.config.PLC_DB_NAME)
68                 except psycopg2.OperationalError:
69                     # Fall back on TCP
70                     self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
71                                                        password = self.api.config.PLC_DB_PASSWORD,
72                                                        database = self.api.config.PLC_DB_NAME,
73                                                        host = self.api.config.PLC_DB_HOST,
74                                                        port = self.api.config.PLC_DB_PORT)
75                 self.connection.set_client_encoding("UNICODE")
76             else:
77                 self.connection = pgdb.connect(user = self.api.config.PLC_DB_USER,
78                                                password = self.api.config.PLC_DB_PASSWORD,
79                                                host = "%s:%d" % (api.config.PLC_DB_HOST, api.config.PLC_DB_PORT),
80                                                database = self.api.config.PLC_DB_NAME)
81
82         (self.rowcount, self.description, self.lastrowid) = \
83                         (None, None, None)
84
85         return self.connection.cursor()
86
87     def close(self):
88         if self.connection is not None:
89             self.connection.close()
90             self.connection = None
91
92     def quote(self, value):
93         """
94         Returns quoted version of the specified value.
95         """
96
97         # The pgdb._quote function is good enough for general SQL
98         # quoting, except for array types.
99         if isinstance(value, (list, tuple, set)):
100             return "ARRAY[%s]" % ", ".join(map, self.quote, value)
101         else:
102             return pgdb._quote(value)
103
104     quote = classmethod(quote)
105
106     def param(self, name, value):
107         # None is converted to the unquoted string NULL
108         if isinstance(value, NoneType):
109             conversion = "s"
110         # True and False are also converted to unquoted strings
111         elif isinstance(value, bool):
112             conversion = "s"
113         elif isinstance(value, float):
114             conversion = "f"
115         elif not isinstance(value, StringTypes):
116             conversion = "d"
117         else:
118             conversion = "s"
119
120         return '%(' + name + ')' + conversion
121
122     param = classmethod(param)
123
124     def begin_work(self):
125         # Implicit in pgdb.connect()
126         pass
127
128     def commit(self):
129         self.connection.commit()
130
131     def rollback(self):
132         self.connection.rollback()
133
134     def do(self, query, params = None):
135         cursor = self.execute(query, params)
136         cursor.close()
137         return self.rowcount
138
139     def next_id(self, table_name, primary_key):
140         sequence = "%(table_name)s_%(primary_key)s_seq" % locals()      
141         sql = "SELECT nextval('%(sequence)s')" % locals()
142         rows = self.selectall(sql, hashref = False)
143         if rows: 
144             return rows[0][0]
145                 
146         return None 
147
148     def last_insert_id(self, table_name, primary_key):
149         if isinstance(self.lastrowid, int):
150             sql = "SELECT %s FROM %s WHERE oid = %d" % \
151                   (primary_key, table_name, self.lastrowid)
152             rows = self.selectall(sql, hashref = False)
153             if rows:
154                 return rows[0][0]
155
156         return None
157
158     # modified for psycopg2-2.0.7 
159     # executemany is undefined for SELECT's
160     # see http://www.python.org/dev/peps/pep-0249/
161     # accepts either None, a single dict, a tuple of single dict - in which case it execute's
162     # or a tuple of several dicts, in which case it executemany's
163     def execute(self, query, params = None):
164
165         cursor = self.cursor()
166         try:
167
168             # psycopg2 requires %()s format for all parameters,
169             # regardless of type.
170             # this needs to be done carefully though as with pattern-based filters
171             # we might have percents embedded in the query
172             # so e.g. GetPersons({'email':'*fake*'}) was resulting in .. LIKE '%sake%'
173             if psycopg2:
174                 query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)
175             # rewrite wildcards set by Filter.py as '***' into '%'
176             query = query.replace ('***','%')
177
178             if not params:
179                 if self.debug:
180                     print >> log,'execute0',query
181                 cursor.execute(query)
182             elif isinstance(params,dict):
183                 if self.debug:
184                     print >> log,'execute-dict: params',params,'query',query%params
185                 cursor.execute(query,params)
186             elif isinstance(params,tuple) and len(params)==1:
187                 if self.debug:
188                     print >> log,'execute-tuple',query%params[0]
189                 cursor.execute(query,params[0])
190             else:
191                 param_seq=(params,)
192                 if self.debug:
193                     for params in param_seq:
194                         print >> log,'executemany',query%params
195                 cursor.executemany(query, param_seq)
196             (self.rowcount, self.description, self.lastrowid) = \
197                             (cursor.rowcount, cursor.description, cursor.lastrowid)
198         except Exception, e:
199             try:
200                 self.rollback()
201             except:
202                 pass
203             uuid = commands.getoutput("uuidgen")
204             print >> log, "Database error %s:" % uuid
205             print >> log, e
206             print >> log, "Query:"
207             print >> log, query
208             print >> log, "Params:"
209             print >> log, pformat(params)
210             raise PLCDBError("Please contact " + \
211                              self.api.config.PLC_NAME + " Support " + \
212                              "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
213                              " and reference " + uuid)
214
215         return cursor
216
217     def selectall(self, query, params = None, hashref = True, key_field = None):
218         """
219         Return each row as a dictionary keyed on field name (like DBI
220         selectrow_hashref()). If key_field is specified, return rows
221         as a dictionary keyed on the specified field (like DBI
222         selectall_hashref()).
223
224         If params is specified, the specified parameters will be bound
225         to the query.
226         """
227
228         cursor = self.execute(query, params)
229         rows = cursor.fetchall()
230         cursor.close()
231         self.commit()
232         if hashref or key_field is not None:
233             # Return each row as a dictionary keyed on field name
234             # (like DBI selectrow_hashref()).
235             labels = [column[0] for column in self.description]
236             rows = [dict(zip(labels, row)) for row in rows]
237
238         if key_field is not None and key_field in labels:
239             # Return rows as a dictionary keyed on the specified field
240             # (like DBI selectall_hashref()).
241             return dict([(row[key_field], row) for row in rows])
242         else:
243             return rows
244
245     def fields(self, table, notnull = None, hasdef = None):
246         """
247         Return the names of the fields of the specified table.
248         """
249
250         if hasattr(self, 'fields_cache'):
251             if self.fields_cache.has_key((table, notnull, hasdef)):
252                 return self.fields_cache[(table, notnull, hasdef)]
253         else:
254             self.fields_cache = {}
255
256         sql = "SELECT attname FROM pg_attribute, pg_class" \
257               " WHERE pg_class.oid = attrelid" \
258               " AND attnum > 0 AND relname = %(table)s"
259
260         if notnull is not None:
261             sql += " AND attnotnull is %(notnull)s"
262
263         if hasdef is not None:
264             sql += " AND atthasdef is %(hasdef)s"
265
266         rows = self.selectall(sql, locals(), hashref = False)
267
268         self.fields_cache[(table, notnull, hasdef)] = [row[0] for row in rows]
269
270         return self.fields_cache[(table, notnull, hasdef)]