merged 9820 in trunk
[plcapi.git] / PLC / PostgreSQL.py
1 #
2 # PostgreSQL database interface. Sort of like DBI(3) (Database
3 # independent interface for Perl).
4 #
5 # Mark Huang <mlhuang@cs.princeton.edu>
6 # Copyright (C) 2006 The Trustees of Princeton University
7 #
8 # $Id$
9 #
10
11 import psycopg2
12 import psycopg2.extensions
13 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
14 # UNICODEARRAY not exported yet
15 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
16
17 import pgdb
18 from types import StringTypes, NoneType
19 import traceback
20 import commands
21 import re
22 from pprint import pformat
23
24 from PLC.Debug import profile, log
25 from PLC.Faults import *
26
if not psycopg2:
    # NOTE(review): psycopg2 is imported unconditionally at the top of
    # this file, so this branch looks unreachable unless that import is
    # made optional -- confirm before relying on the pgdb fallback.

    # Search function that finds the first byte outside the 7-bit range.
    is8bit = re.compile("[\x80-\xff]").search

    def unicast(typecast):
        """
        Wrap a pgdb typecast callable. pgdb hands back raw UTF-8 byte
        strings; the wrapper converts any returned str that contains
        8-bit characters into a unicode object and passes every other
        value through untouched.
        """

        def wrapper(*args, **kwds):
            result = typecast(*args, **kwds)

            # Only plain byte strings are candidates for decoding.
            if not isinstance(result, str):
                return result

            # Pure ASCII strings are returned as they are.
            if not is8bit(result):
                return result

            # pgdb always encodes unicode objects as UTF-8, whatever
            # the DB encoding is, and offers no way to override that,
            # so 8-bit strings are always decoded as UTF-8.
            return unicode(result, "utf-8")

        return wrapper

    # Install the decoding wrapper over pgdb's own typecast routine.
    pgdb.pgdbTypeCache.typecast = unicast(pgdb.pgdbTypeCache.typecast)
50
class PostgreSQL:
    """
    Thin wrapper around a single PostgreSQL connection (psycopg2 when
    available, pgdb otherwise). The connection is opened lazily by
    cursor() and the rowcount, description and lastrowid of the most
    recently executed statement are kept on the instance.
    """

    def __init__(self, api):
        """
        api is the object whose config attribute supplies the PLC_DB_*
        connection settings and the PLC_NAME / PLC_MAIL_SUPPORT_ADDRESS
        values used in error messages.
        """

        self.api = api
        self.debug = False
        self.connection = None

        # Results of the last statement. Initialized here so that
        # last_insert_id() can be called before any cursor was created.
        (self.rowcount, self.description, self.lastrowid) = \
                        (None, None, None)

    def cursor(self):
        """
        Return a new cursor, (re)opening the database connection first
        if there is none. Resets rowcount, description and lastrowid.
        """

        if self.connection is None:
            # (Re)initialize database connection
            config = self.api.config
            if psycopg2:
                credentials = dict(user = config.PLC_DB_USER,
                                   password = config.PLC_DB_PASSWORD,
                                   database = config.PLC_DB_NAME)
                try:
                    # Try UNIX socket first
                    self.connection = psycopg2.connect(**credentials)
                except psycopg2.OperationalError:
                    # Fall back on TCP
                    self.connection = psycopg2.connect(host = config.PLC_DB_HOST,
                                                       port = config.PLC_DB_PORT,
                                                       **credentials)
                self.connection.set_client_encoding("UNICODE")
            else:
                # pgdb takes host and port as a single "host:port" string
                self.connection = pgdb.connect(user = config.PLC_DB_USER,
                                               password = config.PLC_DB_PASSWORD,
                                               host = "%s:%d" % (config.PLC_DB_HOST, config.PLC_DB_PORT),
                                               database = config.PLC_DB_NAME)

        (self.rowcount, self.description, self.lastrowid) = \
                        (None, None, None)

        return self.connection.cursor()

    def close(self):
        """
        Close the connection, if any. The next cursor() call reconnects.
        """

        if self.connection is not None:
            self.connection.close()
            self.connection = None

    @classmethod
    def quote(cls, value):
        """
        Returns quoted version of the specified value.
        """

        # The pgdb._quote function is good enough for general SQL
        # quoting, except for array types.
        if isinstance(value, (list, tuple, set)):
            return "ARRAY[%s]" % ", ".join([cls.quote(item) for item in value])
        else:
            return pgdb._quote(value)

    @classmethod
    def param(cls, name, value):
        """
        Returns the named format specifier ("%(name)s", "%(name)f" or
        "%(name)d") appropriate for the type of the specified value.
        """

        # None is converted to the unquoted string NULL; True and
        # False are also converted to unquoted strings. bool must be
        # tested before the numeric cases since it is a subclass of int.
        if value is None or isinstance(value, bool):
            conversion = "s"
        elif isinstance(value, float):
            conversion = "f"
        elif isinstance(value, StringTypes):
            conversion = "s"
        else:
            conversion = "d"

        return '%(' + name + ')' + conversion

    def begin_work(self):
        # Implicit in pgdb.connect()
        pass

    def commit(self):
        """
        Commit the current transaction on the open connection.
        """

        self.connection.commit()

    def rollback(self):
        """
        Roll back the current transaction on the open connection.
        """

        self.connection.rollback()

    def do(self, query, params = None):
        """
        Execute a statement whose result rows are not needed and
        return the row count reported by the cursor.
        """

        cursor = self.execute(query, params)
        cursor.close()
        return self.rowcount

    def next_id(self, table_name, primary_key):
        """
        Return the next value of the sequence named
        <table_name>_<primary_key>_seq, or None if the query returned
        no rows.
        """

        sequence = "%s_%s_seq" % (table_name, primary_key)
        sql = "SELECT nextval('%s')" % sequence
        rows = self.selectall(sql, hashref = False)
        if rows:
            return rows[0][0]

        return None

    def last_insert_id(self, table_name, primary_key):
        """
        Look up the primary key of the row whose oid is the lastrowid
        recorded by the last statement. Returns None if no integer
        lastrowid is available or no such row is found.
        """

        if isinstance(self.lastrowid, int):
            sql = "SELECT %s FROM %s WHERE oid = %d" % \
                  (primary_key, table_name, self.lastrowid)
            rows = self.selectall(sql, hashref = False)
            if rows:
                return rows[0][0]

        return None

    # modified for psycopg2-2.0.7
    # executemany is undefined for SELECT's
    # see http://www.python.org/dev/peps/pep-0249/
    # accepts either None, a single dict, a tuple of single dict - in which case it execute's
    # or any other non-empty parameter object, in which case it executemany's
    def execute(self, query, params = None):
        """
        Execute the query on a fresh cursor and return that cursor;
        the caller is responsible for closing it. rowcount,
        description and lastrowid are copied onto the instance.

        On any failure the transaction is rolled back (best effort),
        the error, query and parameters are written to the log under
        a generated reference id, and PLCDBError is raised with a
        message quoting that id.
        """

        cursor = self.cursor()
        try:

            # psycopg2 requires %()s format for all parameters,
            # regardless of type.
            if psycopg2:
                query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)

            if not params:
                if self.debug:
                    print >> log,'execute0',query
                cursor.execute(query)
            elif isinstance(params,dict):
                if self.debug:
                    print >> log,'execute-dict: params',params,'query',query%params
                cursor.execute(query,params)
            elif isinstance(params,tuple) and len(params)==1:
                if self.debug:
                    print >> log,'execute-tuple',query%params[0]
                cursor.execute(query,params[0])
            else:
                # NOTE(review): params is passed to executemany() as ONE
                # parameter set, so a tuple of several dicts is not run
                # once per dict here -- confirm what callers expect
                # before changing this.
                param_seq=(params,)
                if self.debug:
                    for param_set in param_seq:
                        print >> log,'executemany',query%param_set
                cursor.executemany(query, param_seq)
            (self.rowcount, self.description, self.lastrowid) = \
                            (cursor.rowcount, cursor.description, cursor.lastrowid)
        except Exception:
            # Capture the failure first: the cleanup below may raise
            # (and swallow) exceptions of its own.
            details = traceback.format_exc()

            # Best-effort cleanup; a broken connection must not mask
            # the original error.
            try:
                self.rollback()
            except Exception:
                pass
            try:
                cursor.close()
            except Exception:
                pass

            uuid = commands.getoutput("uuidgen")
            print >> log, "Database error %s:" % uuid
            print >> log, details
            print >> log, "Query:"
            print >> log, query
            print >> log, "Params:"
            print >> log, pformat(params)
            raise PLCDBError("Please contact " + \
                             self.api.config.PLC_NAME + " Support " + \
                             "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
                             " and reference " + uuid)

        return cursor

    def selectall(self, query, params = None, hashref = True, key_field = None):
        """
        Return each row as a dictionary keyed on field name (like DBI
        selectrow_hashref()). If key_field is specified, return rows
        as a dictionary keyed on the specified field (like DBI
        selectall_hashref()). If hashref is False and key_field is not
        specified, rows are returned as fetched from the cursor.

        If params is specified, the specified parameters will be bound
        to the query.
        """

        cursor = self.execute(query, params)
        try:
            rows = cursor.fetchall()
        finally:
            # Release the cursor even if fetching fails.
            cursor.close()

        if not hashref and key_field is None:
            return rows

        # Return each row as a dictionary keyed on field name
        # (like DBI selectrow_hashref()).
        labels = [column[0] for column in self.description]
        rows = [dict(zip(labels, row)) for row in rows]

        if key_field is not None and key_field in labels:
            # Return rows as a dictionary keyed on the specified field
            # (like DBI selectall_hashref()).
            return dict([(row[key_field], row) for row in rows])

        # An unknown key_field falls back to the list of dictionaries.
        return rows

    def fields(self, table, notnull = None, hasdef = None):
        """
        Return the names of the fields of the specified table,
        optionally restricted by their attnotnull / atthasdef flags.
        Results are cached per (table, notnull, hasdef) on the
        instance.
        """

        # The cache is created on first use.
        if not hasattr(self, 'fields_cache'):
            self.fields_cache = {}

        cache_key = (table, notnull, hasdef)
        if cache_key in self.fields_cache:
            return self.fields_cache[cache_key]

        sql = "SELECT attname FROM pg_attribute, pg_class" \
              " WHERE pg_class.oid = attrelid" \
              " AND attnum > 0 AND relname = %(table)s"

        if notnull is not None:
            sql += " AND attnotnull is %(notnull)s"

        if hasdef is not None:
            sql += " AND atthasdef is %(hasdef)s"

        # Bind only the values the query can reference rather than
        # every local name (which would include self).
        bindings = {'table': table, 'notnull': notnull, 'hasdef': hasdef}
        rows = self.selectall(sql, bindings, hashref = False)

        self.fields_cache[cache_key] = [row[0] for row in rows]

        return self.fields_cache[cache_key]