2c2d0b31debf298b43b07104c6b902cb2d3df301
[plcapi.git] / PLC / PostgreSQL.py
1 #
2 # PostgreSQL database interface. Sort of like DBI(3) (Database
3 # independent interface for Perl).
4 #
5 # Mark Huang <mlhuang@cs.princeton.edu>
6 # Copyright (C) 2006 The Trustees of Princeton University
7 #
8 # $Id$
9 #
10
11 import psycopg2
12 import psycopg2.extensions
13 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
14 # UNICODEARRAY not exported yet
15 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
16
17 import pgdb
18 from types import StringTypes, NoneType
19 import traceback
20 import commands
21 import re
22 from pprint import pformat
23
24 from PLC.Debug import profile, log
25 from PLC.Faults import *
26
27 if not psycopg2:
28     is8bit = re.compile("[\x80-\xff]").search
29
30     def unicast(typecast):
31         """
32         pgdb returns raw UTF-8 strings. This function casts strings that
33         appear to contain non-ASCII characters to unicode objects.
34         """
35     
36         def wrapper(*args, **kwds):
37             value = typecast(*args, **kwds)
38
39             # pgdb always encodes unicode objects as UTF-8 regardless of
40             # the DB encoding (and gives you no option for overriding
41             # the encoding), so always decode 8-bit objects as UTF-8.
42             if isinstance(value, str) and is8bit(value):
43                 value = unicode(value, "utf-8")
44
45             return value
46
47         return wrapper
48
49     pgdb.pgdbTypeCache.typecast = unicast(pgdb.pgdbTypeCache.typecast)
50
51 class PostgreSQL:
52     def __init__(self, api):
53         self.api = api
54         self.debug = False
55 #        self.debug = True
56         self.connection = None
57
58     def cursor(self):
59         if self.connection is None:
60             # (Re)initialize database connection
61             if psycopg2:
62                 try:
63                     # Try UNIX socket first
64                     self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
65                                                        password = self.api.config.PLC_DB_PASSWORD,
66                                                        database = self.api.config.PLC_DB_NAME)
67                 except psycopg2.OperationalError:
68                     # Fall back on TCP
69                     self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
70                                                        password = self.api.config.PLC_DB_PASSWORD,
71                                                        database = self.api.config.PLC_DB_NAME,
72                                                        host = self.api.config.PLC_DB_HOST,
73                                                        port = self.api.config.PLC_DB_PORT)
74                 self.connection.set_client_encoding("UNICODE")
75             else:
76                 self.connection = pgdb.connect(user = self.api.config.PLC_DB_USER,
77                                                password = self.api.config.PLC_DB_PASSWORD,
78                                                host = "%s:%d" % (api.config.PLC_DB_HOST, api.config.PLC_DB_PORT),
79                                                database = self.api.config.PLC_DB_NAME)
80
81         (self.rowcount, self.description, self.lastrowid) = \
82                         (None, None, None)
83
84         return self.connection.cursor()
85
86     def close(self):
87         if self.connection is not None:
88             self.connection.close()
89             self.connection = None
90
91     def quote(self, value):
92         """
93         Returns quoted version of the specified value.
94         """
95
96         # The pgdb._quote function is good enough for general SQL
97         # quoting, except for array types.
98         if isinstance(value, (list, tuple, set)):
99             return "ARRAY[%s]" % ", ".join(map, self.quote, value)
100         else:
101             return pgdb._quote(value)
102
103     quote = classmethod(quote)
104
105     def param(self, name, value):
106         # None is converted to the unquoted string NULL
107         if isinstance(value, NoneType):
108             conversion = "s"
109         # True and False are also converted to unquoted strings
110         elif isinstance(value, bool):
111             conversion = "s"
112         elif isinstance(value, float):
113             conversion = "f"
114         elif not isinstance(value, StringTypes):
115             conversion = "d"
116         else:
117             conversion = "s"
118
119         return '%(' + name + ')' + conversion
120
121     param = classmethod(param)
122
123     def begin_work(self):
124         # Implicit in pgdb.connect()
125         pass
126
127     def commit(self):
128         self.connection.commit()
129
130     def rollback(self):
131         self.connection.rollback()
132
133     def do(self, query, params = None):
134         cursor = self.execute(query, params)
135         cursor.close()
136         return self.rowcount
137
138     def next_id(self, table_name, primary_key):
139         sequence = "%(table_name)s_%(primary_key)s_seq" % locals()      
140         sql = "SELECT nextval('%(sequence)s')" % locals()
141         rows = self.selectall(sql, hashref = False)
142         if rows: 
143             return rows[0][0]
144                 
145         return None 
146
147     def last_insert_id(self, table_name, primary_key):
148         if isinstance(self.lastrowid, int):
149             sql = "SELECT %s FROM %s WHERE oid = %d" % \
150                   (primary_key, table_name, self.lastrowid)
151             rows = self.selectall(sql, hashref = False)
152             if rows:
153                 return rows[0][0]
154
155         return None
156
157     # modified for psycopg2-2.0.7 
158     # executemany is undefined for SELECT's
159     # see http://www.python.org/dev/peps/pep-0249/
160     # accepts either None, a single dict, a tuple of single dict - in which case it execute's
161     # or a tuple of several dicts, in which case it executemany's
162     def execute(self, query, params = None):
163
164         cursor = self.cursor()
165         try:
166
167             # psycopg2 requires %()s format for all parameters,
168             # regardless of type.
169             # this needs to be done carefully though as with pattern-based filters
170             # we might have percents embedded in the query
171             # so e.g. GetPersons({'email':'*fake*'}) was resulting in .. LIKE '%sake%'
172             if psycopg2:
173                 query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)
174             # rewrite wildcards set by Filter.py as '***' into '%'
175             query = query.replace ('***','%')
176
177             if not params:
178                 if self.debug:
179                     print >> log,'execute0',query
180                 cursor.execute(query)
181             elif isinstance(params,dict):
182                 if self.debug:
183                     print >> log,'execute-dict: params',params,'query',query%params
184                 cursor.execute(query,params)
185             elif isinstance(params,tuple) and len(params)==1:
186                 if self.debug:
187                     print >> log,'execute-tuple',query%params[0]
188                 cursor.execute(query,params[0])
189             else:
190                 param_seq=(params,)
191                 if self.debug:
192                     for params in param_seq:
193                         print >> log,'executemany',query%params
194                 cursor.executemany(query, param_seq)
195             (self.rowcount, self.description, self.lastrowid) = \
196                             (cursor.rowcount, cursor.description, cursor.lastrowid)
197         except Exception, e:
198             try:
199                 self.rollback()
200             except:
201                 pass
202             uuid = commands.getoutput("uuidgen")
203             print >> log, "Database error %s:" % uuid
204             print >> log, e
205             print >> log, "Query:"
206             print >> log, query
207             print >> log, "Params:"
208             print >> log, pformat(params)
209             raise PLCDBError("Please contact " + \
210                              self.api.config.PLC_NAME + " Support " + \
211                              "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
212                              " and reference " + uuid)
213
214         return cursor
215
216     def selectall(self, query, params = None, hashref = True, key_field = None):
217         """
218         Return each row as a dictionary keyed on field name (like DBI
219         selectrow_hashref()). If key_field is specified, return rows
220         as a dictionary keyed on the specified field (like DBI
221         selectall_hashref()).
222
223         If params is specified, the specified parameters will be bound
224         to the query.
225         """
226
227         cursor = self.execute(query, params)
228         rows = cursor.fetchall()
229         cursor.close()
230         self.commit()
231         if hashref or key_field is not None:
232             # Return each row as a dictionary keyed on field name
233             # (like DBI selectrow_hashref()).
234             labels = [column[0] for column in self.description]
235             rows = [dict(zip(labels, row)) for row in rows]
236
237         if key_field is not None and key_field in labels:
238             # Return rows as a dictionary keyed on the specified field
239             # (like DBI selectall_hashref()).
240             return dict([(row[key_field], row) for row in rows])
241         else:
242             return rows
243
244     def fields(self, table, notnull = None, hasdef = None):
245         """
246         Return the names of the fields of the specified table.
247         """
248
249         if hasattr(self, 'fields_cache'):
250             if self.fields_cache.has_key((table, notnull, hasdef)):
251                 return self.fields_cache[(table, notnull, hasdef)]
252         else:
253             self.fields_cache = {}
254
255         sql = "SELECT attname FROM pg_attribute, pg_class" \
256               " WHERE pg_class.oid = attrelid" \
257               " AND attnum > 0 AND relname = %(table)s"
258
259         if notnull is not None:
260             sql += " AND attnotnull is %(notnull)s"
261
262         if hasdef is not None:
263             sql += " AND atthasdef is %(hasdef)s"
264
265         rows = self.selectall(sql, locals(), hashref = False)
266
267         self.fields_cache[(table, notnull, hasdef)] = [row[0] for row in rows]
268
269         return self.fields_cache[(table, notnull, hasdef)]