new method next_id(...) returns next primary key id
[plcapi.git] / PLC / PostgreSQL.py
1 #
2 # PostgreSQL database interface. Sort of like DBI(3) (Database
3 # independent interface for Perl).
4 #
5 # Mark Huang <mlhuang@cs.princeton.edu>
6 # Copyright (C) 2006 The Trustees of Princeton University
7 #
8 # $Id$
9 #
10
11 import psycopg2
12 import psycopg2.extensions
13 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
14 # UNICODEARRAY not exported yet
15 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
16
17 import pgdb
18 from types import StringTypes, NoneType
19 import traceback
20 import commands
21 import re
22 from pprint import pformat
23
24 from PLC.Debug import profile, log
25 from PLC.Faults import *
26
if not psycopg2:
    # Fallback for the pure-pgdb driver: make string results come back as
    # unicode objects, as psycopg2's UNICODE typecaster does.
    # NOTE(review): psycopg2 is imported unconditionally at the top of this
    # module and a module object is always true, so as written this branch
    # never executes; it only takes effect if that import is made optional.
    is8bit = re.compile("[\x80-\xff]").search

    def unicast(typecast):
        """
        Wrap a pgdb typecast function so that any byte string it returns
        which contains a byte in the range 0x80-0xff is decoded from UTF-8
        into a unicode object. All other values pass through unchanged.

        UTF-8 is used unconditionally because pgdb always encodes unicode
        objects as UTF-8, whatever the database encoding is, and offers no
        way to override that.
        """

        def decode_8bit(*args, **kwds):
            result = typecast(*args, **kwds)
            # Plain ASCII strings and non-strings are returned untouched.
            if not isinstance(result, str) or not is8bit(result):
                return result
            return unicode(result, "utf-8")

        return decode_8bit

    pgdb.pgdbTypeCache.typecast = unicast(pgdb.pgdbTypeCache.typecast)
48
49     pgdb.pgdbTypeCache.typecast = unicast(pgdb.pgdbTypeCache.typecast)
50
51 class PostgreSQL:
52     def __init__(self, api):
53         self.api = api
54         self.debug = False
55         self.connection = None
56
57     def cursor(self):
58         if self.connection is None:
59             # (Re)initialize database connection
60             if psycopg2:
61                 try:
62                     # Try UNIX socket first
63                     self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
64                                                        password = self.api.config.PLC_DB_PASSWORD,
65                                                        database = self.api.config.PLC_DB_NAME)
66                 except psycopg2.OperationalError:
67                     # Fall back on TCP
68                     self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
69                                                        password = self.api.config.PLC_DB_PASSWORD,
70                                                        database = self.api.config.PLC_DB_NAME,
71                                                        host = self.api.config.PLC_DB_HOST,
72                                                        port = self.api.config.PLC_DB_PORT)
73                 self.connection.set_client_encoding("UNICODE")
74             else:
75                 self.connection = pgdb.connect(user = self.api.config.PLC_DB_USER,
76                                                password = self.api.config.PLC_DB_PASSWORD,
77                                                host = "%s:%d" % (api.config.PLC_DB_HOST, api.config.PLC_DB_PORT),
78                                                database = self.api.config.PLC_DB_NAME)
79
80         (self.rowcount, self.description, self.lastrowid) = \
81                         (None, None, None)
82
83         return self.connection.cursor()
84
85     def close(self):
86         if self.connection is not None:
87             self.connection.close()
88             self.connection = None
89
90     def quote(self, value):
91         """
92         Returns quoted version of the specified value.
93         """
94
95         # The pgdb._quote function is good enough for general SQL
96         # quoting, except for array types.
97         if isinstance(value, (list, tuple, set)):
98             return "ARRAY[%s]" % ", ".join(map, self.quote, value)
99         else:
100             return pgdb._quote(value)
101
102     quote = classmethod(quote)
103
104     def param(self, name, value):
105         # None is converted to the unquoted string NULL
106         if isinstance(value, NoneType):
107             conversion = "s"
108         # True and False are also converted to unquoted strings
109         elif isinstance(value, bool):
110             conversion = "s"
111         elif isinstance(value, float):
112             conversion = "f"
113         elif not isinstance(value, StringTypes):
114             conversion = "d"
115         else:
116             conversion = "s"
117
118         return '%(' + name + ')' + conversion
119
120     param = classmethod(param)
121
122     def begin_work(self):
123         # Implicit in pgdb.connect()
124         pass
125
126     def commit(self):
127         self.connection.commit()
128
129     def rollback(self):
130         self.connection.rollback()
131
132     def do(self, query, params = None):
133         cursor = self.execute(query, params)
134         cursor.close()
135         return self.rowcount
136
137     def next_id(self, table_name, primary_key):
138         sequence = "%(table_name)s_%(primary_key)s_seq" % locals()      
139         sql = "SELECT nextval('%(sequence)s')" % locals()
140         rows = self.selectall(sql, hashref = False)
141         if rows: 
142             return rows[0][0]
143                 
144         return None 
145
146     def last_insert_id(self, table_name, primary_key):
147         if isinstance(self.lastrowid, int):
148             sql = "SELECT %s FROM %s WHERE oid = %d" % \
149                   (primary_key, table_name, self.lastrowid)
150             rows = self.selectall(sql, hashref = False)
151             if rows:
152                 return rows[0][0]
153
154         return None
155
156     def execute(self, query, params = None):
157         return self.execute_array(query, (params,))
158
159     def execute_array(self, query, param_seq):
160         cursor = self.cursor()
161         try:
162             if self.debug:
163                 for params in param_seq:
164                     if params:
165                         print >> log, query % params
166                     else:
167                         print >> log, query
168
169             # psycopg2 requires %()s format for all parameters,
170             # regardless of type.
171             if psycopg2:
172                 query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)
173
174             cursor.executemany(query, param_seq)
175             (self.rowcount, self.description, self.lastrowid) = \
176                             (cursor.rowcount, cursor.description, cursor.lastrowid)
177         except Exception, e:
178             try:
179                 self.rollback()
180             except:
181                 pass
182             uuid = commands.getoutput("uuidgen")
183             print >> log, "Database error %s:" % uuid
184             print >> log, e
185             print >> log, "Query:"
186             print >> log, query
187             print >> log, "Params:"
188             print >> log, pformat(param_seq[0])
189             raise PLCDBError("Please contact " + \
190                              self.api.config.PLC_NAME + " Support " + \
191                              "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
192                              " and reference " + uuid)
193
194         return cursor
195
196     def selectall(self, query, params = None, hashref = True, key_field = None):
197         """
198         Return each row as a dictionary keyed on field name (like DBI
199         selectrow_hashref()). If key_field is specified, return rows
200         as a dictionary keyed on the specified field (like DBI
201         selectall_hashref()).
202
203         If params is specified, the specified parameters will be bound
204         to the query.
205         """
206
207         cursor = self.execute(query, params)
208         rows = cursor.fetchall()
209         cursor.close()
210
211         if hashref or key_field is not None:
212             # Return each row as a dictionary keyed on field name
213             # (like DBI selectrow_hashref()).
214             labels = [column[0] for column in self.description]
215             rows = [dict(zip(labels, row)) for row in rows]
216
217         if key_field is not None and key_field in labels:
218             # Return rows as a dictionary keyed on the specified field
219             # (like DBI selectall_hashref()).
220             return dict([(row[key_field], row) for row in rows])
221         else:
222             return rows
223
224     def fields(self, table, notnull = None, hasdef = None):
225         """
226         Return the names of the fields of the specified table.
227         """
228
229         if hasattr(self, 'fields_cache'):
230             if self.fields_cache.has_key((table, notnull, hasdef)):
231                 return self.fields_cache[(table, notnull, hasdef)]
232         else:
233             self.fields_cache = {}
234
235         sql = "SELECT attname FROM pg_attribute, pg_class" \
236               " WHERE pg_class.oid = attrelid" \
237               " AND attnum > 0 AND relname = %(table)s"
238
239         if notnull is not None:
240             sql += " AND attnotnull is %(notnull)s"
241
242         if hasdef is not None:
243             sql += " AND atthasdef is %(hasdef)s"
244
245         rows = self.selectall(sql, locals(), hashref = False)
246
247         self.fields_cache[(table, notnull, hasdef)] = [row[0] for row in rows]
248
249         return self.fields_cache[(table, notnull, hasdef)]