f15ebf9c9570a0b4590cbf1c707b31467afb2154
[plcapi.git] / PLC / PostgreSQL.py
1 #
2 # PostgreSQL database interface. Sort of like DBI(3) (Database
3 # independent interface for Perl).
4 #
5 # Mark Huang <mlhuang@cs.princeton.edu>
6 # Copyright (C) 2006 The Trustees of Princeton University
7 #
8 # $Id: PostgreSQL.py,v 1.13 2007/02/11 04:53:40 mlhuang Exp $
9 #
10
11 import psycopg2
12 import psycopg2.extensions
13 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
14 # UNICODEARRAY not exported yet
15 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
16
17 import pgdb
18 from types import StringTypes, NoneType
19 import traceback
20 import commands
21 import re
22 from pprint import pformat
23
24 from PLC.Debug import profile, log
25 from PLC.Faults import *
26
27 if not psycopg2:
28     is8bit = re.compile("[\x80-\xff]").search
29
30     def unicast(typecast):
31         """
32         pgdb returns raw UTF-8 strings. This function casts strings that
33         appear to contain non-ASCII characters to unicode objects.
34         """
35     
36         def wrapper(*args, **kwds):
37             value = typecast(*args, **kwds)
38
39             # pgdb always encodes unicode objects as UTF-8 regardless of
40             # the DB encoding (and gives you no option for overriding
41             # the encoding), so always decode 8-bit objects as UTF-8.
42             if isinstance(value, str) and is8bit(value):
43                 value = unicode(value, "utf-8")
44
45             return value
46
47         return wrapper
48
49     pgdb.pgdbTypeCache.typecast = unicast(pgdb.pgdbTypeCache.typecast)
50
51 class PostgreSQL:
52     def __init__(self, api):
53         self.api = api
54         self.debug = False
55         self.connection = None
56
57     def cursor(self):
58         if self.connection is None:
59             # (Re)initialize database connection
60             if psycopg2:
61                 try:
62                     # Try UNIX socket first
63                     self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
64                                                        password = self.api.config.PLC_DB_PASSWORD,
65                                                        database = self.api.config.PLC_DB_NAME)
66                 except psycopg2.OperationalError:
67                     # Fall back on TCP
68                     self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
69                                                        password = self.api.config.PLC_DB_PASSWORD,
70                                                        database = self.api.config.PLC_DB_NAME,
71                                                        host = self.api.config.PLC_DB_HOST,
72                                                        port = self.api.config.PLC_DB_PORT)
73                 self.connection.set_client_encoding("UNICODE")
74             else:
75                 self.connection = pgdb.connect(user = self.api.config.PLC_DB_USER,
76                                                password = self.api.config.PLC_DB_PASSWORD,
77                                                host = "%s:%d" % (api.config.PLC_DB_HOST, api.config.PLC_DB_PORT),
78                                                database = self.api.config.PLC_DB_NAME)
79
80         (self.rowcount, self.description, self.lastrowid) = \
81                         (None, None, None)
82
83         return self.connection.cursor()
84
85     def close(self):
86         if self.connection is not None:
87             self.connection.close()
88             self.connection = None
89
90     def quote(self, value):
91         """
92         Returns quoted version of the specified value.
93         """
94
95         # The pgdb._quote function is good enough for general SQL
96         # quoting, except for array types.
97         if isinstance(value, (list, tuple, set)):
98             return "ARRAY[%s]" % ", ".join(map, self.quote, value)
99         else:
100             return pgdb._quote(value)
101
102     quote = classmethod(quote)
103
104     def param(self, name, value):
105         # None is converted to the unquoted string NULL
106         if isinstance(value, NoneType):
107             conversion = "s"
108         # True and False are also converted to unquoted strings
109         elif isinstance(value, bool):
110             conversion = "s"
111         elif isinstance(value, float):
112             conversion = "f"
113         elif not isinstance(value, StringTypes):
114             conversion = "d"
115         else:
116             conversion = "s"
117
118         return '%(' + name + ')' + conversion
119
120     param = classmethod(param)
121
122     def begin_work(self):
123         # Implicit in pgdb.connect()
124         pass
125
126     def commit(self):
127         self.connection.commit()
128
129     def rollback(self):
130         self.connection.rollback()
131
132     def do(self, query, params = None):
133         self.execute(query, params)
134         return self.rowcount
135
136     def last_insert_id(self, table_name, primary_key):
137         if isinstance(self.lastrowid, int):
138             sql = "SELECT %s FROM %s WHERE oid = %d" % \
139                   (primary_key, table_name, self.lastrowid)
140             rows = self.selectall(sql, hashref = False)
141             if rows:
142                 return rows[0][0]
143
144         return None
145
146     def execute(self, query, params = None):
147         return self.execute_array(query, (params,))
148
149     def execute_array(self, query, param_seq):
150         cursor = self.cursor()
151         try:
152             if self.debug:
153                 for params in param_seq:
154                     if params:
155                         print >> log, query % params
156                     else:
157                         print >> log, query
158
159             # psycopg2 requires %()s format for all parameters,
160             # regardless of type.
161             if psycopg2:
162                 query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)
163
164             cursor.executemany(query, param_seq)
165             (self.rowcount, self.description, self.lastrowid) = \
166                             (cursor.rowcount, cursor.description, cursor.lastrowid)
167         except Exception, e:
168             try:
169                 self.rollback()
170             except:
171                 pass
172             uuid = commands.getoutput("uuidgen")
173             print >> log, "Database error %s:" % uuid
174             print >> log, e
175             print >> log, "Query:"
176             print >> log, query
177             print >> log, "Params:"
178             print >> log, pformat(param_seq[0])
179             raise PLCDBError("Please contact " + \
180                              self.api.config.PLC_NAME + " Support " + \
181                              "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
182                              " and reference " + uuid)
183
184         return cursor
185
186     def selectall(self, query, params = None, hashref = True, key_field = None):
187         """
188         Return each row as a dictionary keyed on field name (like DBI
189         selectrow_hashref()). If key_field is specified, return rows
190         as a dictionary keyed on the specified field (like DBI
191         selectall_hashref()).
192
193         If params is specified, the specified parameters will be bound
194         to the query.
195         """
196
197         rows = self.execute(query, params).fetchall()
198
199         if hashref or key_field is not None:
200             # Return each row as a dictionary keyed on field name
201             # (like DBI selectrow_hashref()).
202             labels = [column[0] for column in self.description]
203             rows = [dict(zip(labels, row)) for row in rows]
204
205         if key_field is not None and key_field in labels:
206             # Return rows as a dictionary keyed on the specified field
207             # (like DBI selectall_hashref()).
208             return dict([(row[key_field], row) for row in rows])
209         else:
210             return rows
211
212     def fields(self, table, notnull = None, hasdef = None):
213         """
214         Return the names of the fields of the specified table.
215         """
216
217         if hasattr(self, 'fields_cache'):
218             if self.fields_cache.has_key((table, notnull, hasdef)):
219                 return self.fields_cache[(table, notnull, hasdef)]
220         else:
221             self.fields_cache = {}
222
223         sql = "SELECT attname FROM pg_attribute, pg_class" \
224               " WHERE pg_class.oid = attrelid" \
225               " AND attnum > 0 AND relname = %(table)s"
226
227         if notnull is not None:
228             sql += " AND attnotnull is %(notnull)s"
229
230         if hasdef is not None:
231             sql += " AND atthasdef is %(hasdef)s"
232
233         rows = self.selectall(sql, locals(), hashref = False)
234
235         self.fields_cache[(table, notnull, hasdef)] = [row[0] for row in rows]
236
237         return self.fields_cache[(table, notnull, hasdef)]