- fields(): cache table field names
[plcapi.git] / PLC / PostgreSQL.py
1 #
2 # PostgreSQL database interface. Sort of like DBI(3) (Database
3 # independent interface for Perl).
4 #
5 # Mark Huang <mlhuang@cs.princeton.edu>
6 # Copyright (C) 2006 The Trustees of Princeton University
7 #
8 # $Id: PostgreSQL.py,v 1.11 2006/12/04 19:10:47 mlhuang Exp $
9 #
10
11 import psycopg2
12 import psycopg2.extensions
13 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
14 # UNICODEARRAY not exported yet
15 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
16
17 import pgdb
18 from types import StringTypes, NoneType
19 import traceback
20 import commands
21 import re
22 from pprint import pformat
23
24 from PLC.Debug import profile, log
25 from PLC.Faults import *
26
27 if not psycopg2:
28     is8bit = re.compile("[\x80-\xff]").search
29
30     def unicast(typecast):
31         """
32         pgdb returns raw UTF-8 strings. This function casts strings that
33         appear to contain non-ASCII characters to unicode objects.
34         """
35     
36         def wrapper(*args, **kwds):
37             value = typecast(*args, **kwds)
38
39             # pgdb always encodes unicode objects as UTF-8 regardless of
40             # the DB encoding (and gives you no option for overriding
41             # the encoding), so always decode 8-bit objects as UTF-8.
42             if isinstance(value, str) and is8bit(value):
43                 value = unicode(value, "utf-8")
44
45             return value
46
47         return wrapper
48
49     pgdb.pgdbTypeCache.typecast = unicast(pgdb.pgdbTypeCache.typecast)
50
51 class PostgreSQL:
52     def __init__(self, api):
53         self.api = api
54         self.debug = False
55
56         # Initialize database connection
57         if psycopg2:
58             try:
59                 # Try UNIX socket first
60                 self.db = psycopg2.connect(user = api.config.PLC_DB_USER,
61                                            password = api.config.PLC_DB_PASSWORD,
62                                            database = api.config.PLC_DB_NAME)
63             except psycopg2.OperationalError:
64                 # Fall back on TCP
65                 self.db = psycopg2.connect(user = api.config.PLC_DB_USER,
66                                            password = api.config.PLC_DB_PASSWORD,
67                                            database = api.config.PLC_DB_NAME,
68                                            host = api.config.PLC_DB_HOST,
69                                            port = api.config.PLC_DB_PORT)
70             self.db.set_client_encoding("UNICODE")
71         else:
72             self.db = pgdb.connect(user = api.config.PLC_DB_USER,
73                                    password = api.config.PLC_DB_PASSWORD,
74                                    host = "%s:%d" % (api.config.PLC_DB_HOST, api.config.PLC_DB_PORT),
75                                    database = api.config.PLC_DB_NAME)
76
77         self.cursor = self.db.cursor()
78
79         (self.rowcount, self.description, self.lastrowid) = \
80                         (None, None, None)
81
82     def quote(self, value):
83         """
84         Returns quoted version of the specified value.
85         """
86
87         # The pgdb._quote function is good enough for general SQL
88         # quoting, except for array types.
89         if isinstance(value, (list, tuple, set)):
90             return "ARRAY[%s]" % ", ".join(map, self.quote, value)
91         else:
92             return pgdb._quote(value)
93
94     quote = classmethod(quote)
95
96     def param(self, name, value):
97         # None is converted to the unquoted string NULL
98         if isinstance(value, NoneType):
99             conversion = "s"
100         # True and False are also converted to unquoted strings
101         elif isinstance(value, bool):
102             conversion = "s"
103         elif isinstance(value, float):
104             conversion = "f"
105         elif not isinstance(value, StringTypes):
106             conversion = "d"
107         else:
108             conversion = "s"
109
110         return '%(' + name + ')' + conversion
111
112     param = classmethod(param)
113
114     def begin_work(self):
115         # Implicit in pgdb.connect()
116         pass
117
118     def commit(self):
119         self.db.commit()
120
121     def rollback(self):
122         self.db.rollback()
123
124     def do(self, query, params = None):
125         self.execute(query, params)
126         return self.rowcount
127
128     def last_insert_id(self, table_name, primary_key):
129         if isinstance(self.lastrowid, int):
130             sql = "SELECT %s FROM %s WHERE oid = %d" % \
131                   (primary_key, table_name, self.lastrowid)
132             rows = self.selectall(sql, hashref = False)
133             if rows:
134                 return rows[0][0]
135
136         return None
137
138     def execute(self, query, params = None):
139         self.execute_array(query, (params,))
140
141     def execute_array(self, query, param_seq):
142         cursor = self.cursor
143         try:
144             if self.debug:
145                 for params in param_seq:
146                     if params:
147                         print >> log, query % params
148                     else:
149                         print >> log, query
150
151             # psycopg2 requires %()s format for all parameters,
152             # regardless of type.
153             if psycopg2:
154                 query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)
155
156             cursor.executemany(query, param_seq)
157             (self.rowcount, self.description, self.lastrowid) = \
158                             (cursor.rowcount, cursor.description, cursor.lastrowid)
159         except Exception, e:
160             try:
161                 self.rollback()
162             except:
163                 pass
164             uuid = commands.getoutput("uuidgen")
165             print >> log, "Database error %s:" % uuid
166             print >> log, e
167             print >> log, "Query:"
168             print >> log, query
169             print >> log, "Params:"
170             print >> log, pformat(param_seq[0])
171             raise PLCDBError("Please contact " + \
172                              self.api.config.PLC_NAME + " Support " + \
173                              "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
174                              " and reference " + uuid)
175
176     def selectall(self, query, params = None, hashref = True, key_field = None):
177         """
178         Return each row as a dictionary keyed on field name (like DBI
179         selectrow_hashref()). If key_field is specified, return rows
180         as a dictionary keyed on the specified field (like DBI
181         selectall_hashref()).
182
183         If params is specified, the specified parameters will be bound
184         to the query (see PLC.DB.parameterize() and
185         pgdb.cursor.execute()).
186         """
187
188         self.execute(query, params)
189         rows = self.cursor.fetchall()
190
191         if hashref or key_field is not None:
192             # Return each row as a dictionary keyed on field name
193             # (like DBI selectrow_hashref()).
194             labels = [column[0] for column in self.description]
195             rows = [dict(zip(labels, row)) for row in rows]
196
197         if key_field is not None and key_field in labels:
198             # Return rows as a dictionary keyed on the specified field
199             # (like DBI selectall_hashref()).
200             return dict([(row[key_field], row) for row in rows])
201         else:
202             return rows
203
204     def fields(self, table, notnull = None, hasdef = None):
205         """
206         Return the names of the fields of the specified table.
207         """
208
209         if hasattr(self, 'fields_cache'):
210             if self.fields_cache.has_key((table, notnull, hasdef)):
211                 return self.fields_cache[(table, notnull, hasdef)]
212         else:
213             self.fields_cache = {}
214
215         sql = "SELECT attname FROM pg_attribute, pg_class" \
216               " WHERE pg_class.oid = attrelid" \
217               " AND attnum > 0 AND relname = %(table)s"
218
219         if notnull is not None:
220             sql += " AND attnotnull is %(notnull)s"
221
222         if hasdef is not None:
223             sql += " AND atthasdef is %(hasdef)s"
224
225         rows = self.selectall(sql, locals(), hashref = False)
226
227         self.fields_cache[(table, notnull, hasdef)] = [row[0] for row in rows]
228
229         return self.fields_cache[(table, notnull, hasdef)]