1934316f8e741af94fad2e58b2110c880b694eff
[plcapi.git] / PLC / PostgreSQL.py
1 #
2 # PostgreSQL database interface. Sort of like DBI(3) (Database
3 # independent interface for Perl).
4 #
5 # Mark Huang <mlhuang@cs.princeton.edu>
6 # Copyright (C) 2006 The Trustees of Princeton University
7 #
8 # $Id: PostgreSQL.py,v 1.7 2006/10/24 13:47:05 mlhuang Exp $
9 #
10
11 import pgdb
12 from types import StringTypes, NoneType
13 import traceback
14 import commands
15 import re
16 from pprint import pformat
17
18 from PLC.Debug import profile, log
19 from PLC.Faults import *
20
21 is8bit = re.compile("[\x80-\xff]").search
22
23 def unicast(typecast):
24     """
25     pgdb returns raw UTF-8 strings. This function casts strings that
26     appear to contain non-ASCII characters to unicode objects.
27     """
28     
29     def wrapper(*args, **kwds):
30         value = typecast(*args, **kwds)
31
32         # pgdb always encodes unicode objects as UTF-8 regardless of
33         # the DB encoding (and gives you no option for overriding
34         # the encoding), so always decode 8-bit objects as UTF-8.
35         if isinstance(value, str) and is8bit(value):
36             value = unicode(value, "utf-8")
37
38         return value
39
40     return wrapper
41
42 pgdb.pgdbTypeCache.typecast = unicast(pgdb.pgdbTypeCache.typecast)
43
44 class PostgreSQL:
45     def __init__(self, api):
46         self.api = api
47
48         # Initialize database connection
49         self.db = pgdb.connect(user = api.config.PLC_DB_USER,
50                                password = api.config.PLC_DB_PASSWORD,
51                                host = "%s:%d" % (api.config.PLC_DB_HOST, api.config.PLC_DB_PORT),
52                                database = api.config.PLC_DB_NAME)
53         self.cursor = self.db.cursor()
54
55         (self.rowcount, self.description, self.lastrowid) = \
56                         (None, None, None)
57
58     def quote(self, params):
59         """
60         Returns quoted version(s) of the specified parameter(s).
61         """
62
63         # pgdb._quote functions are good enough for general SQL quoting
64         if hasattr(params, 'has_key'):
65             params = pgdb._quoteitem(params)
66         elif isinstance(params, list) or isinstance(params, tuple) or isinstance(params, set):
67             params = map(pgdb._quote, params)
68         else:
69             params = pgdb._quote(params)
70
71         return params
72
73     quote = classmethod(quote)
74
75     def param(self, name, value):
76         # None is converted to the unquoted string NULL
77         if isinstance(value, NoneType):
78             conversion = "s"
79         # True and False are also converted to unquoted strings
80         elif isinstance(value, bool):
81             conversion = "s"
82         elif isinstance(value, float):
83             conversion = "f"
84         elif not isinstance(value, StringTypes):
85             conversion = "d"
86         else:
87             conversion = "s"
88
89         return '%(' + name + ')' + conversion
90
91     param = classmethod(param)
92
93     def begin_work(self):
94         # Implicit in pgdb.connect()
95         pass
96
97     def commit(self):
98         self.db.commit()
99
100     def rollback(self):
101         self.db.rollback()
102
103     def do(self, query, params = None):
104         self.execute(query, params)
105         return self.rowcount
106
107     def last_insert_id(self, table_name, primary_key):
108         if isinstance(self.lastrowid, int):
109             sql = "SELECT %s FROM %s WHERE oid = %d" % \
110                   (primary_key, table_name, self.lastrowid)
111             rows = self.selectall(sql, hashref = False)
112             if rows:
113                 return rows[0][0]
114
115         return None
116
117     def execute(self, query, params = None):
118         self.execute_array(query, (params,))
119
120     def execute_array(self, query, param_seq):
121         cursor = self.cursor
122         try:
123             cursor.executemany(query, param_seq)
124             (self.rowcount, self.description, self.lastrowid) = \
125                             (cursor.rowcount, cursor.description, cursor.lastrowid)
126         except pgdb.DatabaseError, e:
127             try:
128                 self.rollback()
129             except:
130                 pass
131             uuid = commands.getoutput("uuidgen")
132             print >> log, "Database error %s:" % uuid
133             print >> log, e
134             print >> log, "Query:"
135             print >> log, query
136             print >> log, "Params:"
137             print >> log, pformat(param_seq[0])
138             raise PLCDBError("Please contact " + \
139                              self.api.config.PLC_NAME + " Support " + \
140                              "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
141                              " and reference " + uuid)
142
143     def selectall(self, query, params = None, hashref = True, key_field = None):
144         """
145         Return each row as a dictionary keyed on field name (like DBI
146         selectrow_hashref()). If key_field is specified, return rows
147         as a dictionary keyed on the specified field (like DBI
148         selectall_hashref()).
149
150         If params is specified, the specified parameters will be bound
151         to the query (see PLC.DB.parameterize() and
152         pgdb.cursor.execute()).
153         """
154
155         self.execute(query, params)
156         rows = self.cursor.fetchall()
157
158         if hashref:
159             # Return each row as a dictionary keyed on field name
160             # (like DBI selectrow_hashref()).
161             labels = [column[0] for column in self.description]
162             rows = [dict(zip(labels, row)) for row in rows]
163
164         if key_field is not None and key_field in labels:
165             # Return rows as a dictionary keyed on the specified field
166             # (like DBI selectall_hashref()).
167             return dict([(row[key_field], row) for row in rows])
168         else:
169             return rows
170
171     def fields(self, table, notnull = None, hasdef = None):
172         """
173         Return the names of the fields of the specified table.
174         """
175
176         sql = "SELECT attname FROM pg_attribute, pg_class" \
177               " WHERE pg_class.oid = attrelid" \
178               " AND attnum > 0 AND relname = %(table)s"
179
180         if notnull is not None:
181             sql += " AND attnotnull is %(notnull)s"
182
183         if hasdef is not None:
184             sql += " AND atthasdef is %(hasdef)s"
185
186         rows = self.selectall(sql, locals(), hashref = False)
187
188         return [row[0] for row in rows]