- fix returns documentation
[plcapi.git] / PLC / PostgreSQL.py
1 #
2 # PostgreSQL database interface. Sort of like DBI(3) (Database
3 # independent interface for Perl).
4 #
5 # Mark Huang <mlhuang@cs.princeton.edu>
6 # Copyright (C) 2006 The Trustees of Princeton University
7 #
8 # $Id: PostgreSQL.py,v 1.10 2006/11/09 19:34:04 mlhuang Exp $
9 #
10
11 import psycopg2
12 import psycopg2.extensions
13 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
14 # UNICODEARRAY not exported yet
15 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
16
17 import pgdb
18 from types import StringTypes, NoneType
19 import traceback
20 import commands
21 import re
22 from pprint import pformat
23
24 from PLC.Debug import profile, log
25 from PLC.Faults import *
26
27 if not psycopg2:
28     is8bit = re.compile("[\x80-\xff]").search
29
30     def unicast(typecast):
31         """
32         pgdb returns raw UTF-8 strings. This function casts strings that
33         appear to contain non-ASCII characters to unicode objects.
34         """
35     
36         def wrapper(*args, **kwds):
37             value = typecast(*args, **kwds)
38
39             # pgdb always encodes unicode objects as UTF-8 regardless of
40             # the DB encoding (and gives you no option for overriding
41             # the encoding), so always decode 8-bit objects as UTF-8.
42             if isinstance(value, str) and is8bit(value):
43                 value = unicode(value, "utf-8")
44
45             return value
46
47         return wrapper
48
49     pgdb.pgdbTypeCache.typecast = unicast(pgdb.pgdbTypeCache.typecast)
50
51 class PostgreSQL:
52     def __init__(self, api):
53         self.api = api
54
55         # Initialize database connection
56         if psycopg2:
57             try:
58                 # Try UNIX socket first
59                 self.db = psycopg2.connect(user = api.config.PLC_DB_USER,
60                                            password = api.config.PLC_DB_PASSWORD,
61                                            database = api.config.PLC_DB_NAME)
62             except psycopg2.OperationalError:
63                 # Fall back on TCP
64                 self.db = psycopg2.connect(user = api.config.PLC_DB_USER,
65                                            password = api.config.PLC_DB_PASSWORD,
66                                            database = api.config.PLC_DB_NAME,
67                                            host = api.config.PLC_DB_HOST,
68                                            port = api.config.PLC_DB_PORT)
69             self.db.set_client_encoding("UNICODE")
70         else:
71             self.db = pgdb.connect(user = api.config.PLC_DB_USER,
72                                    password = api.config.PLC_DB_PASSWORD,
73                                    host = "%s:%d" % (api.config.PLC_DB_HOST, api.config.PLC_DB_PORT),
74                                    database = api.config.PLC_DB_NAME)
75
76         self.cursor = self.db.cursor()
77
78         (self.rowcount, self.description, self.lastrowid) = \
79                         (None, None, None)
80
81     def quote(self, value):
82         """
83         Returns quoted version of the specified value.
84         """
85
86         # The pgdb._quote function is good enough for general SQL
87         # quoting, except for array types.
88         if isinstance(value, (list, tuple, set)):
89             return "ARRAY[%s]" % ", ".join(map, self.quote, value)
90         else:
91             return pgdb._quote(value)
92
93     quote = classmethod(quote)
94
95     def param(self, name, value):
96         # None is converted to the unquoted string NULL
97         if isinstance(value, NoneType):
98             conversion = "s"
99         # True and False are also converted to unquoted strings
100         elif isinstance(value, bool):
101             conversion = "s"
102         elif isinstance(value, float):
103             conversion = "f"
104         elif not isinstance(value, StringTypes):
105             conversion = "d"
106         else:
107             conversion = "s"
108
109         return '%(' + name + ')' + conversion
110
111     param = classmethod(param)
112
113     def begin_work(self):
114         # Implicit in pgdb.connect()
115         pass
116
117     def commit(self):
118         self.db.commit()
119
120     def rollback(self):
121         self.db.rollback()
122
123     def do(self, query, params = None):
124         self.execute(query, params)
125         return self.rowcount
126
127     def last_insert_id(self, table_name, primary_key):
128         if isinstance(self.lastrowid, int):
129             sql = "SELECT %s FROM %s WHERE oid = %d" % \
130                   (primary_key, table_name, self.lastrowid)
131             rows = self.selectall(sql, hashref = False)
132             if rows:
133                 return rows[0][0]
134
135         return None
136
137     def execute(self, query, params = None):
138         self.execute_array(query, (params,))
139
140     def execute_array(self, query, param_seq):
141         cursor = self.cursor
142         try:
143             # psycopg2 requires %()s format for all parameters,
144             # regardless of type.
145             if psycopg2:
146                 query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)
147
148             cursor.executemany(query, param_seq)
149             (self.rowcount, self.description, self.lastrowid) = \
150                             (cursor.rowcount, cursor.description, cursor.lastrowid)
151         except Exception, e:
152             try:
153                 self.rollback()
154             except:
155                 pass
156             uuid = commands.getoutput("uuidgen")
157             print >> log, "Database error %s:" % uuid
158             print >> log, e
159             print >> log, "Query:"
160             print >> log, query
161             print >> log, "Params:"
162             print >> log, pformat(param_seq[0])
163             raise PLCDBError("Please contact " + \
164                              self.api.config.PLC_NAME + " Support " + \
165                              "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
166                              " and reference " + uuid)
167
168     def selectall(self, query, params = None, hashref = True, key_field = None):
169         """
170         Return each row as a dictionary keyed on field name (like DBI
171         selectrow_hashref()). If key_field is specified, return rows
172         as a dictionary keyed on the specified field (like DBI
173         selectall_hashref()).
174
175         If params is specified, the specified parameters will be bound
176         to the query (see PLC.DB.parameterize() and
177         pgdb.cursor.execute()).
178         """
179
180         self.execute(query, params)
181         rows = self.cursor.fetchall()
182
183         if hashref or key_field is not None:
184             # Return each row as a dictionary keyed on field name
185             # (like DBI selectrow_hashref()).
186             labels = [column[0] for column in self.description]
187             rows = [dict(zip(labels, row)) for row in rows]
188
189         if key_field is not None and key_field in labels:
190             # Return rows as a dictionary keyed on the specified field
191             # (like DBI selectall_hashref()).
192             return dict([(row[key_field], row) for row in rows])
193         else:
194             return rows
195
196     def fields(self, table, notnull = None, hasdef = None):
197         """
198         Return the names of the fields of the specified table.
199         """
200
201         sql = "SELECT attname FROM pg_attribute, pg_class" \
202               " WHERE pg_class.oid = attrelid" \
203               " AND attnum > 0 AND relname = %(table)s"
204
205         if notnull is not None:
206             sql += " AND attnotnull is %(notnull)s"
207
208         if hasdef is not None:
209             sql += " AND atthasdef is %(hasdef)s"
210
211         rows = self.selectall(sql, locals(), hashref = False)
212
213         return [row[0] for row in rows]