- try obtaining another cursor once if it appears to be closed
[plcapi.git] / PLC / PostgreSQL.py
1 #
2 # PostgreSQL database interface. Sort of like DBI(3) (Database
3 # independent interface for Perl).
4 #
5 # Mark Huang <mlhuang@cs.princeton.edu>
6 # Copyright (C) 2006 The Trustees of Princeton University
7 #
8 # $Id: PostgreSQL.py,v 1.12 2007/02/08 15:15:21 mlhuang Exp $
9 #
10
11 import psycopg2
12 import psycopg2.extensions
13 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
14 # UNICODEARRAY not exported yet
15 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
16
17 import pgdb
18 from types import StringTypes, NoneType
19 import traceback
20 import commands
21 import re
22 from pprint import pformat
23
24 from PLC.Debug import profile, log
25 from PLC.Faults import *
26
27 if not psycopg2:
28     is8bit = re.compile("[\x80-\xff]").search
29
30     def unicast(typecast):
31         """
32         pgdb returns raw UTF-8 strings. This function casts strings that
33         appear to contain non-ASCII characters to unicode objects.
34         """
35     
36         def wrapper(*args, **kwds):
37             value = typecast(*args, **kwds)
38
39             # pgdb always encodes unicode objects as UTF-8 regardless of
40             # the DB encoding (and gives you no option for overriding
41             # the encoding), so always decode 8-bit objects as UTF-8.
42             if isinstance(value, str) and is8bit(value):
43                 value = unicode(value, "utf-8")
44
45             return value
46
47         return wrapper
48
49     pgdb.pgdbTypeCache.typecast = unicast(pgdb.pgdbTypeCache.typecast)
50
51 class PostgreSQL:
52     def __init__(self, api):
53         self.api = api
54         self.debug = False
55
56         # Initialize database connection
57         if psycopg2:
58             try:
59                 # Try UNIX socket first
60                 self.db = psycopg2.connect(user = api.config.PLC_DB_USER,
61                                            password = api.config.PLC_DB_PASSWORD,
62                                            database = api.config.PLC_DB_NAME)
63             except psycopg2.OperationalError:
64                 # Fall back on TCP
65                 self.db = psycopg2.connect(user = api.config.PLC_DB_USER,
66                                            password = api.config.PLC_DB_PASSWORD,
67                                            database = api.config.PLC_DB_NAME,
68                                            host = api.config.PLC_DB_HOST,
69                                            port = api.config.PLC_DB_PORT)
70             self.db.set_client_encoding("UNICODE")
71         else:
72             self.db = pgdb.connect(user = api.config.PLC_DB_USER,
73                                    password = api.config.PLC_DB_PASSWORD,
74                                    host = "%s:%d" % (api.config.PLC_DB_HOST, api.config.PLC_DB_PORT),
75                                    database = api.config.PLC_DB_NAME)
76
77         self.cursor = self.db.cursor()
78
79         (self.rowcount, self.description, self.lastrowid) = \
80                         (None, None, None)
81
82     def quote(self, value):
83         """
84         Returns quoted version of the specified value.
85         """
86
87         # The pgdb._quote function is good enough for general SQL
88         # quoting, except for array types.
89         if isinstance(value, (list, tuple, set)):
90             return "ARRAY[%s]" % ", ".join(map, self.quote, value)
91         else:
92             return pgdb._quote(value)
93
94     quote = classmethod(quote)
95
96     def param(self, name, value):
97         # None is converted to the unquoted string NULL
98         if isinstance(value, NoneType):
99             conversion = "s"
100         # True and False are also converted to unquoted strings
101         elif isinstance(value, bool):
102             conversion = "s"
103         elif isinstance(value, float):
104             conversion = "f"
105         elif not isinstance(value, StringTypes):
106             conversion = "d"
107         else:
108             conversion = "s"
109
110         return '%(' + name + ')' + conversion
111
112     param = classmethod(param)
113
114     def begin_work(self):
115         # Implicit in pgdb.connect()
116         pass
117
118     def commit(self):
119         self.db.commit()
120
121     def rollback(self):
122         self.db.rollback()
123
124     def do(self, query, params = None):
125         self.execute(query, params)
126         return self.rowcount
127
128     def last_insert_id(self, table_name, primary_key):
129         if isinstance(self.lastrowid, int):
130             sql = "SELECT %s FROM %s WHERE oid = %d" % \
131                   (primary_key, table_name, self.lastrowid)
132             rows = self.selectall(sql, hashref = False)
133             if rows:
134                 return rows[0][0]
135
136         return None
137
138     def execute(self, query, params = None):
139         self.execute_array(query, (params,))
140
141     def execute_array(self, query, param_seq):
142         cursor = self.cursor
143         try:
144             if self.debug:
145                 for params in param_seq:
146                     if params:
147                         print >> log, query % params
148                     else:
149                         print >> log, query
150
151             # psycopg2 requires %()s format for all parameters,
152             # regardless of type.
153             if psycopg2:
154                 query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)
155
156             try:
157                 cursor.executemany(query, param_seq)
158             except InterfaceError:
159                 # Try one more time with another cursor
160                 cursor = self.cursor = self.db.cursor()
161                 cursor.executemany(query, param_seq)
162
163             (self.rowcount, self.description, self.lastrowid) = \
164                             (cursor.rowcount, cursor.description, cursor.lastrowid)
165         except Exception, e:
166             try:
167                 self.rollback()
168             except:
169                 pass
170             uuid = commands.getoutput("uuidgen")
171             print >> log, "Database error %s:" % uuid
172             print >> log, e
173             print >> log, "Query:"
174             print >> log, query
175             print >> log, "Params:"
176             print >> log, pformat(param_seq[0])
177             raise PLCDBError("Please contact " + \
178                              self.api.config.PLC_NAME + " Support " + \
179                              "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
180                              " and reference " + uuid)
181
182     def selectall(self, query, params = None, hashref = True, key_field = None):
183         """
184         Return each row as a dictionary keyed on field name (like DBI
185         selectrow_hashref()). If key_field is specified, return rows
186         as a dictionary keyed on the specified field (like DBI
187         selectall_hashref()).
188
189         If params is specified, the specified parameters will be bound
190         to the query (see PLC.DB.parameterize() and
191         pgdb.cursor.execute()).
192         """
193
194         self.execute(query, params)
195         rows = self.cursor.fetchall()
196
197         if hashref or key_field is not None:
198             # Return each row as a dictionary keyed on field name
199             # (like DBI selectrow_hashref()).
200             labels = [column[0] for column in self.description]
201             rows = [dict(zip(labels, row)) for row in rows]
202
203         if key_field is not None and key_field in labels:
204             # Return rows as a dictionary keyed on the specified field
205             # (like DBI selectall_hashref()).
206             return dict([(row[key_field], row) for row in rows])
207         else:
208             return rows
209
210     def fields(self, table, notnull = None, hasdef = None):
211         """
212         Return the names of the fields of the specified table.
213         """
214
215         if hasattr(self, 'fields_cache'):
216             if self.fields_cache.has_key((table, notnull, hasdef)):
217                 return self.fields_cache[(table, notnull, hasdef)]
218         else:
219             self.fields_cache = {}
220
221         sql = "SELECT attname FROM pg_attribute, pg_class" \
222               " WHERE pg_class.oid = attrelid" \
223               " AND attnum > 0 AND relname = %(table)s"
224
225         if notnull is not None:
226             sql += " AND attnotnull is %(notnull)s"
227
228         if hasdef is not None:
229             sql += " AND atthasdef is %(hasdef)s"
230
231         rows = self.selectall(sql, locals(), hashref = False)
232
233         self.fields_cache[(table, notnull, hasdef)] = [row[0] for row in rows]
234
235         return self.fields_cache[(table, notnull, hasdef)]