- support psycopg2
[plcapi.git] / PLC / PostgreSQL.py
1 #
2 # PostgreSQL database interface. Sort of like DBI(3) (Database
3 # independent interface for Perl).
4 #
5 # Mark Huang <mlhuang@cs.princeton.edu>
6 # Copyright (C) 2006 The Trustees of Princeton University
7 #
8 # $Id: PostgreSQL.py,v 1.8 2006/10/30 16:37:49 mlhuang Exp $
9 #
10
11 import psycopg2
12 import psycopg2.extensions
13 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
14
15 import pgdb
16 from types import StringTypes, NoneType
17 import traceback
18 import commands
19 import re
20 from pprint import pformat
21
22 from PLC.Debug import profile, log
23 from PLC.Faults import *
24
25 if not psycopg2:
26     is8bit = re.compile("[\x80-\xff]").search
27
28     def unicast(typecast):
29         """
30         pgdb returns raw UTF-8 strings. This function casts strings that
31         appear to contain non-ASCII characters to unicode objects.
32         """
33     
34         def wrapper(*args, **kwds):
35             value = typecast(*args, **kwds)
36
37             # pgdb always encodes unicode objects as UTF-8 regardless of
38             # the DB encoding (and gives you no option for overriding
39             # the encoding), so always decode 8-bit objects as UTF-8.
40             if isinstance(value, str) and is8bit(value):
41                 value = unicode(value, "utf-8")
42
43             return value
44
45         return wrapper
46
47     pgdb.pgdbTypeCache.typecast = unicast(pgdb.pgdbTypeCache.typecast)
48
49 class PostgreSQL:
50     def __init__(self, api):
51         self.api = api
52
53         # Initialize database connection
54         if psycopg2:
55             try:
56                 # Try UNIX socket first
57                 self.db = psycopg2.connect(user = api.config.PLC_DB_USER,
58                                            password = api.config.PLC_DB_PASSWORD,
59                                            database = api.config.PLC_DB_NAME)
60             except psycopg2.OperationalError:
61                 # Fall back on TCP
62                 self.db = psycopg2.connect(user = api.config.PLC_DB_USER,
63                                            password = api.config.PLC_DB_PASSWORD,
64                                            database = api.config.PLC_DB_NAME,
65                                            host = api.config.PLC_DB_HOST,
66                                            port = api.config.PLC_DB_PORT)
67             self.db.set_client_encoding("UNICODE")
68         else:
69             self.db = pgdb.connect(user = api.config.PLC_DB_USER,
70                                    password = api.config.PLC_DB_PASSWORD,
71                                    host = "%s:%d" % (api.config.PLC_DB_HOST, api.config.PLC_DB_PORT),
72                                    database = api.config.PLC_DB_NAME)
73
74         self.cursor = self.db.cursor()
75
76         (self.rowcount, self.description, self.lastrowid) = \
77                         (None, None, None)
78
79     def quote(self, value):
80         """
81         Returns quoted version of the specified value.
82         """
83
84         # The pgdb._quote function is good enough for general SQL
85         # quoting, except for array types.
86         if isinstance(value, (list, tuple, set)):
87             return "ARRAY[%s]" % ", ".join(map, self.quote, value)
88         else:
89             return pgdb._quote(value)
90
91     quote = classmethod(quote)
92
93     def param(self, name, value):
94         # None is converted to the unquoted string NULL
95         if isinstance(value, NoneType):
96             conversion = "s"
97         # True and False are also converted to unquoted strings
98         elif isinstance(value, bool):
99             conversion = "s"
100         elif isinstance(value, float):
101             conversion = "f"
102         elif not isinstance(value, StringTypes):
103             conversion = "d"
104         else:
105             conversion = "s"
106
107         return '%(' + name + ')' + conversion
108
109     param = classmethod(param)
110
111     def begin_work(self):
112         # Implicit in pgdb.connect()
113         pass
114
115     def commit(self):
116         self.db.commit()
117
118     def rollback(self):
119         self.db.rollback()
120
121     def do(self, query, params = None):
122         self.execute(query, params)
123         return self.rowcount
124
125     def last_insert_id(self, table_name, primary_key):
126         if isinstance(self.lastrowid, int):
127             sql = "SELECT %s FROM %s WHERE oid = %d" % \
128                   (primary_key, table_name, self.lastrowid)
129             rows = self.selectall(sql, hashref = False)
130             if rows:
131                 return rows[0][0]
132
133         return None
134
135     def execute(self, query, params = None):
136         self.execute_array(query, (params,))
137
138     def execute_array(self, query, param_seq):
139         cursor = self.cursor
140         try:
141             # psycopg2 requires %()s format for all parameters,
142             # regardless of type.
143             if psycopg2:
144                 query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)
145
146             cursor.executemany(query, param_seq)
147             (self.rowcount, self.description, self.lastrowid) = \
148                             (cursor.rowcount, cursor.description, cursor.lastrowid)
149         except Exception, e:
150             try:
151                 self.rollback()
152             except:
153                 pass
154             uuid = commands.getoutput("uuidgen")
155             print >> log, "Database error %s:" % uuid
156             print >> log, e
157             print >> log, "Query:"
158             print >> log, query
159             print >> log, "Params:"
160             print >> log, pformat(param_seq[0])
161             raise PLCDBError("Please contact " + \
162                              self.api.config.PLC_NAME + " Support " + \
163                              "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
164                              " and reference " + uuid)
165
166     def selectall(self, query, params = None, hashref = True, key_field = None):
167         """
168         Return each row as a dictionary keyed on field name (like DBI
169         selectrow_hashref()). If key_field is specified, return rows
170         as a dictionary keyed on the specified field (like DBI
171         selectall_hashref()).
172
173         If params is specified, the specified parameters will be bound
174         to the query (see PLC.DB.parameterize() and
175         pgdb.cursor.execute()).
176         """
177
178         self.execute(query, params)
179         rows = self.cursor.fetchall()
180
181         if hashref:
182             # Return each row as a dictionary keyed on field name
183             # (like DBI selectrow_hashref()).
184             labels = [column[0] for column in self.description]
185             rows = [dict(zip(labels, row)) for row in rows]
186
187         if key_field is not None and key_field in labels:
188             # Return rows as a dictionary keyed on the specified field
189             # (like DBI selectall_hashref()).
190             return dict([(row[key_field], row) for row in rows])
191         else:
192             return rows
193
194     def fields(self, table, notnull = None, hasdef = None):
195         """
196         Return the names of the fields of the specified table.
197         """
198
199         sql = "SELECT attname FROM pg_attribute, pg_class" \
200               " WHERE pg_class.oid = attrelid" \
201               " AND attnum > 0 AND relname = %(table)s"
202
203         if notnull is not None:
204             sql += " AND attnotnull is %(notnull)s"
205
206         if hasdef is not None:
207             sql += " AND atthasdef is %(hasdef)s"
208
209         rows = self.selectall(sql, locals(), hashref = False)
210
211         return [row[0] for row in rows]