076a8643dd6e3d7ea1e8e2360171e13244734b0c
[plcapi.git] / PLC / PostgreSQL.py
1 #
2 # PostgreSQL database interface. Sort of like DBI(3) (Database
3 # independent interface for Perl).
4 #
5 # Mark Huang <mlhuang@cs.princeton.edu>
6 # Copyright (C) 2006 The Trustees of Princeton University
7 #
8 # $Id$
9 # $URL$
10 #
11
12 import psycopg2
13 import psycopg2.extensions
14 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
15 # UNICODEARRAY not exported yet
16 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
17
18 import pgdb
19 from types import StringTypes, NoneType
20 import traceback
21 import commands
22 import re
23 from pprint import pformat
24
25 from PLC.Debug import profile, log
26 from PLC.Faults import *
27
28 if not psycopg2:
29     is8bit = re.compile("[\x80-\xff]").search
30
31     def unicast(typecast):
32         """
33         pgdb returns raw UTF-8 strings. This function casts strings that
34         appear to contain non-ASCII characters to unicode objects.
35         """
36     
37         def wrapper(*args, **kwds):
38             value = typecast(*args, **kwds)
39
40             # pgdb always encodes unicode objects as UTF-8 regardless of
41             # the DB encoding (and gives you no option for overriding
42             # the encoding), so always decode 8-bit objects as UTF-8.
43             if isinstance(value, str) and is8bit(value):
44                 value = unicode(value, "utf-8")
45
46             return value
47
48         return wrapper
49
50     pgdb.pgdbTypeCache.typecast = unicast(pgdb.pgdbTypeCache.typecast)
51
52 class PostgreSQL:
53     def __init__(self, api):
54         self.api = api
55         self.debug = False
56 #        self.debug = True
57         self.connection = None
58
59     def cursor(self):
60         if self.connection is None:
61             # (Re)initialize database connection
62             if psycopg2:
63                 try:
64                     # Try UNIX socket first
65                     self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
66                                                        password = self.api.config.PLC_DB_PASSWORD,
67                                                        database = self.api.config.PLC_DB_NAME)
68                 except psycopg2.OperationalError:
69                     # Fall back on TCP
70                     self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
71                                                        password = self.api.config.PLC_DB_PASSWORD,
72                                                        database = self.api.config.PLC_DB_NAME,
73                                                        host = self.api.config.PLC_DB_HOST,
74                                                        port = self.api.config.PLC_DB_PORT)
75                 self.connection.set_client_encoding("UNICODE")
76             else:
77                 self.connection = pgdb.connect(user = self.api.config.PLC_DB_USER,
78                                                password = self.api.config.PLC_DB_PASSWORD,
79                                                host = "%s:%d" % (api.config.PLC_DB_HOST, api.config.PLC_DB_PORT),
80                                                database = self.api.config.PLC_DB_NAME)
81
82         (self.rowcount, self.description, self.lastrowid) = \
83                         (None, None, None)
84
85         return self.connection.cursor()
86
87     def close(self):
88         if self.connection is not None:
89             self.connection.close()
90             self.connection = None
91
92     # join insists on getting strings
93     @classmethod
94     def quote_string(self, value):
95         return str(PostgreSQL.quote(value))
96
97     @classmethod
98     def quote(self, value):
99         """
100         Returns quoted version of the specified value.
101         """
102
103         # The pgdb._quote function is good enough for general SQL
104         # quoting, except for array types.
105         if isinstance(value, (list, tuple, set)):
106             return "ARRAY[%s]" % ", ".join(map (PostgreSQL.quote_string, value))
107         else:
108             return pgdb._quote(value)
109
110     @classmethod
111     def param(self, name, value):
112         # None is converted to the unquoted string NULL
113         if isinstance(value, NoneType):
114             conversion = "s"
115         # True and False are also converted to unquoted strings
116         elif isinstance(value, bool):
117             conversion = "s"
118         elif isinstance(value, float):
119             conversion = "f"
120         elif not isinstance(value, StringTypes):
121             conversion = "d"
122         else:
123             conversion = "s"
124
125         return '%(' + name + ')' + conversion
126
127     def begin_work(self):
128         # Implicit in pgdb.connect()
129         pass
130
131     def commit(self):
132         self.connection.commit()
133
134     def rollback(self):
135         self.connection.rollback()
136
137     def do(self, query, params = None):
138         cursor = self.execute(query, params)
139         cursor.close()
140         return self.rowcount
141
142     def next_id(self, table_name, primary_key):
143         sequence = "%(table_name)s_%(primary_key)s_seq" % locals()      
144         sql = "SELECT nextval('%(sequence)s')" % locals()
145         rows = self.selectall(sql, hashref = False)
146         if rows: 
147             return rows[0][0]
148                 
149         return None 
150
151     def last_insert_id(self, table_name, primary_key):
152         if isinstance(self.lastrowid, int):
153             sql = "SELECT %s FROM %s WHERE oid = %d" % \
154                   (primary_key, table_name, self.lastrowid)
155             rows = self.selectall(sql, hashref = False)
156             if rows:
157                 return rows[0][0]
158
159         return None
160
161     # modified for psycopg2-2.0.7 
162     # executemany is undefined for SELECT's
163     # see http://www.python.org/dev/peps/pep-0249/
164     # accepts either None, a single dict, a tuple of single dict - in which case it execute's
165     # or a tuple of several dicts, in which case it executemany's
166     def execute(self, query, params = None):
167
168         cursor = self.cursor()
169         try:
170
171             # psycopg2 requires %()s format for all parameters,
172             # regardless of type.
173             # this needs to be done carefully though as with pattern-based filters
174             # we might have percents embedded in the query
175             # so e.g. GetPersons({'email':'*fake*'}) was resulting in .. LIKE '%sake%'
176             if psycopg2:
177                 query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)
178             # rewrite wildcards set by Filter.py as '***' into '%'
179             query = query.replace ('***','%')
180
181             if not params:
182                 if self.debug:
183                     print >> log,'execute0',query
184                 cursor.execute(query)
185             elif isinstance(params,dict):
186                 if self.debug:
187                     print >> log,'execute-dict: params',params,'query',query%params
188                 cursor.execute(query,params)
189             elif isinstance(params,tuple) and len(params)==1:
190                 if self.debug:
191                     print >> log,'execute-tuple',query%params[0]
192                 cursor.execute(query,params[0])
193             else:
194                 param_seq=(params,)
195                 if self.debug:
196                     for params in param_seq:
197                         print >> log,'executemany',query%params
198                 cursor.executemany(query, param_seq)
199             (self.rowcount, self.description, self.lastrowid) = \
200                             (cursor.rowcount, cursor.description, cursor.lastrowid)
201         except Exception, e:
202             try:
203                 self.rollback()
204             except:
205                 pass
206             uuid = commands.getoutput("uuidgen")
207             print >> log, "Database error %s:" % uuid
208             print >> log, e
209             print >> log, "Query:"
210             print >> log, query
211             print >> log, "Params:"
212             print >> log, pformat(params)
213             raise PLCDBError("Please contact " + \
214                              self.api.config.PLC_NAME + " Support " + \
215                              "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
216                              " and reference " + uuid)
217
218         return cursor
219
220     def selectall(self, query, params = None, hashref = True, key_field = None):
221         """
222         Return each row as a dictionary keyed on field name (like DBI
223         selectrow_hashref()). If key_field is specified, return rows
224         as a dictionary keyed on the specified field (like DBI
225         selectall_hashref()).
226
227         If params is specified, the specified parameters will be bound
228         to the query.
229         """
230
231         cursor = self.execute(query, params)
232         rows = cursor.fetchall()
233         cursor.close()
234         self.commit()
235         if hashref or key_field is not None:
236             # Return each row as a dictionary keyed on field name
237             # (like DBI selectrow_hashref()).
238             labels = [column[0] for column in self.description]
239             rows = [dict(zip(labels, row)) for row in rows]
240
241         if key_field is not None and key_field in labels:
242             # Return rows as a dictionary keyed on the specified field
243             # (like DBI selectall_hashref()).
244             return dict([(row[key_field], row) for row in rows])
245         else:
246             return rows
247
248     def fields(self, table, notnull = None, hasdef = None):
249         """
250         Return the names of the fields of the specified table.
251         """
252
253         if hasattr(self, 'fields_cache'):
254             if self.fields_cache.has_key((table, notnull, hasdef)):
255                 return self.fields_cache[(table, notnull, hasdef)]
256         else:
257             self.fields_cache = {}
258
259         sql = "SELECT attname FROM pg_attribute, pg_class" \
260               " WHERE pg_class.oid = attrelid" \
261               " AND attnum > 0 AND relname = %(table)s"
262
263         if notnull is not None:
264             sql += " AND attnotnull is %(notnull)s"
265
266         if hasdef is not None:
267             sql += " AND atthasdef is %(hasdef)s"
268
269         rows = self.selectall(sql, locals(), hashref = False)
270
271         self.fields_cache[(table, notnull, hasdef)] = [row[0] for row in rows]
272
273         return self.fields_cache[(table, notnull, hasdef)]