undo previous merge because
[plcapi.git] / PLC / PostgreSQL.py
1 #
2 # PostgreSQL database interface. Sort of like DBI(3) (Database
3 # independent interface for Perl).
4 #
5 # Mark Huang <mlhuang@cs.princeton.edu>
6 # Copyright (C) 2006 The Trustees of Princeton University
7 #
8 # $Id$
9 #
10
11 import psycopg2
12 import psycopg2.extensions
13 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
14 # UNICODEARRAY not exported yet
15 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
16
17 import pgdb
18 from types import StringTypes, NoneType
19 import traceback
20 import commands
21 import re
22 from pprint import pformat
23
24 from PLC.Debug import profile, log
25 from PLC.Faults import *
26
if not psycopg2:
    # Legacy pgdb-only fix-up of string decoding.
    # NOTE(review): psycopg2 is imported unconditionally at the top of this
    # module and a module object is always truthy, so as written this branch
    # never runs; it only matters if the psycopg2 import is made optional.

    # Finds any byte in the 0x80-0xff range, i.e. a non-ASCII character.
    is8bit = re.compile("[\x80-\xff]").search

    def unicast(typecast):
        """
        Decorate a pgdb typecast function so that the 8-bit strings it
        produces come back as unicode objects.

        pgdb hands back raw UTF-8 byte strings. Any str result containing a
        byte >= 0x80 is decoded as UTF-8; every other result is returned
        untouched.
        """

        def wrapper(*args, **kwds):
            result = typecast(*args, **kwds)
            # pgdb always encodes unicode objects as UTF-8 no matter what
            # the database encoding is (and offers no way to override it),
            # so UTF-8 is the right codec for decoding 8-bit strings.
            if not isinstance(result, str) or not is8bit(result):
                return result
            return unicode(result, "utf-8")

        return wrapper

    # Route every pgdb type conversion through the UTF-8 decoder.
    pgdb.pgdbTypeCache.typecast = unicast(pgdb.pgdbTypeCache.typecast)
50
51 class PostgreSQL:
52     def __init__(self, api):
53         self.api = api
54         self.debug = False
55         self.connection = None
56
57     def cursor(self):
58         if self.connection is None:
59             # (Re)initialize database connection
60             if psycopg2:
61                 try:
62                     # Try UNIX socket first
63                     self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
64                                                        password = self.api.config.PLC_DB_PASSWORD,
65                                                        database = self.api.config.PLC_DB_NAME)
66                 except psycopg2.OperationalError:
67                     # Fall back on TCP
68                     self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
69                                                        password = self.api.config.PLC_DB_PASSWORD,
70                                                        database = self.api.config.PLC_DB_NAME,
71                                                        host = self.api.config.PLC_DB_HOST,
72                                                        port = self.api.config.PLC_DB_PORT)
73                 self.connection.set_client_encoding("UNICODE")
74             else:
75                 self.connection = pgdb.connect(user = self.api.config.PLC_DB_USER,
76                                                password = self.api.config.PLC_DB_PASSWORD,
77                                                host = "%s:%d" % (api.config.PLC_DB_HOST, api.config.PLC_DB_PORT),
78                                                database = self.api.config.PLC_DB_NAME)
79
80         (self.rowcount, self.description, self.lastrowid) = \
81                         (None, None, None)
82
83         return self.connection.cursor()
84
85     def close(self):
86         if self.connection is not None:
87             self.connection.close()
88             self.connection = None
89
90     def quote(self, value):
91         """
92         Returns quoted version of the specified value.
93         """
94
95         # The pgdb._quote function is good enough for general SQL
96         # quoting, except for array types.
97         if isinstance(value, (list, tuple, set)):
98             return "ARRAY[%s]" % ", ".join(map, self.quote, value)
99         else:
100             return pgdb._quote(value)
101
102     quote = classmethod(quote)
103
104     def param(self, name, value):
105         # None is converted to the unquoted string NULL
106         if isinstance(value, NoneType):
107             conversion = "s"
108         # True and False are also converted to unquoted strings
109         elif isinstance(value, bool):
110             conversion = "s"
111         elif isinstance(value, float):
112             conversion = "f"
113         elif not isinstance(value, StringTypes):
114             conversion = "d"
115         else:
116             conversion = "s"
117
118         return '%(' + name + ')' + conversion
119
120     param = classmethod(param)
121
122     def begin_work(self):
123         # Implicit in pgdb.connect()
124         pass
125
126     def commit(self):
127         self.connection.commit()
128
129     def rollback(self):
130         self.connection.rollback()
131
132     def do(self, query, params = None):
133         cursor = self.execute(query, params)
134         cursor.close()
135         return self.rowcount
136
137     def last_insert_id(self, table_name, primary_key):
138         if isinstance(self.lastrowid, int):
139             sql = "SELECT %s FROM %s WHERE oid = %d" % \
140                   (primary_key, table_name, self.lastrowid)
141             rows = self.selectall(sql, hashref = False)
142             if rows:
143                 return rows[0][0]
144
145         return None
146
147     # modified for psycopg2-2.0.7 
148     # executemany is undefined for SELECT's
149     # see http://www.python.org/dev/peps/pep-0249/
150     # accepts either None, a single dict, a tuple of single dict - in which case it execute's
151     # or a tuple of several dicts, in which case it executemany's
152     def execute(self, query, params = None):
153
154         cursor = self.cursor()
155         try:
156
157             # psycopg2 requires %()s format for all parameters,
158             # regardless of type.
159             if psycopg2:
160                 query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)
161
162             if not params:
163                 if self.debug:
164                     print >> log,'execute0',query
165                 cursor.execute(query)
166             elif isinstance(params,dict):
167                 if self.debug:
168                     print >> log,'execute-dict: params',params,'query',query%params
169                 cursor.execute(query,params)
170             elif isinstance(params,tuple) and len(params)==1:
171                 if self.debug:
172                     print >> log,'execute-tuple',query%params[0]
173                 cursor.execute(query,params[0])
174             else:
175                 param_seq=(params,)
176                 if self.debug:
177                     for params in param_seq:
178                         print >> log,'executemany',query%params
179                 cursor.executemany(query, param_seq)
180             (self.rowcount, self.description, self.lastrowid) = \
181                             (cursor.rowcount, cursor.description, cursor.lastrowid)
182         except Exception, e:
183             try:
184                 self.rollback()
185             except:
186                 pass
187             uuid = commands.getoutput("uuidgen")
188             print >> log, "Database error %s:" % uuid
189             print >> log, e
190             print >> log, "Query:"
191             print >> log, query
192             print >> log, "Params:"
193             print >> log, pformat(params)
194             raise PLCDBError("Please contact " + \
195                              self.api.config.PLC_NAME + " Support " + \
196                              "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
197                              " and reference " + uuid)
198
199         return cursor
200
201     def selectall(self, query, params = None, hashref = True, key_field = None):
202         """
203         Return each row as a dictionary keyed on field name (like DBI
204         selectrow_hashref()). If key_field is specified, return rows
205         as a dictionary keyed on the specified field (like DBI
206         selectall_hashref()).
207
208         If params is specified, the specified parameters will be bound
209         to the query.
210         """
211
212         cursor = self.execute(query, params)
213         rows = cursor.fetchall()
214         cursor.close()
215
216         if hashref or key_field is not None:
217             # Return each row as a dictionary keyed on field name
218             # (like DBI selectrow_hashref()).
219             labels = [column[0] for column in self.description]
220             rows = [dict(zip(labels, row)) for row in rows]
221
222         if key_field is not None and key_field in labels:
223             # Return rows as a dictionary keyed on the specified field
224             # (like DBI selectall_hashref()).
225             return dict([(row[key_field], row) for row in rows])
226         else:
227             return rows
228
229     def fields(self, table, notnull = None, hasdef = None):
230         """
231         Return the names of the fields of the specified table.
232         """
233
234         if hasattr(self, 'fields_cache'):
235             if self.fields_cache.has_key((table, notnull, hasdef)):
236                 return self.fields_cache[(table, notnull, hasdef)]
237         else:
238             self.fields_cache = {}
239
240         sql = "SELECT attname FROM pg_attribute, pg_class" \
241               " WHERE pg_class.oid = attrelid" \
242               " AND attnum > 0 AND relname = %(table)s"
243
244         if notnull is not None:
245             sql += " AND attnotnull is %(notnull)s"
246
247         if hasdef is not None:
248             sql += " AND atthasdef is %(hasdef)s"
249
250         rows = self.selectall(sql, locals(), hashref = False)
251
252         self.fields_cache[(table, notnull, hasdef)] = [row[0] for row in rows]
253
254         return self.fields_cache[(table, notnull, hasdef)]