1dfced50951822a30a87c17fb6a4903240c55fdf
[plcapi.git] / PLC / PostgreSQL.py
1 #
2 # PostgreSQL database interface. 
3 # Sort of like DBI(3) (Database independent interface for Perl).
4 #
5 # Mark Huang <mlhuang@cs.princeton.edu>
6 # Copyright (C) 2006 The Trustees of Princeton University
7 #
8
9 import psycopg2
10 import psycopg2.extensions
11 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
12 # UNICODEARRAY not exported yet
13 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
14
15 import types
16 from types import StringTypes, NoneType
17 import traceback
18 import commands
19 import re
20 from pprint import pformat
21
22 from PLC.Debug import profile, log
23 from PLC.Faults import *
24 from datetime import datetime as DateTimeType
25
26 class PostgreSQL:
27     def __init__(self, api):
28         self.api = api
29         self.debug = False
30 #        self.debug = True
31         self.connection = None
32
33     def cursor(self):
34         if self.connection is None:
35             # (Re)initialize database connection
36             try:
37                 # Try UNIX socket first
38                 self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
39                                                    password = self.api.config.PLC_DB_PASSWORD,
40                                                    database = self.api.config.PLC_DB_NAME)
41             except psycopg2.OperationalError:
42                 # Fall back on TCP
43                 self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
44                                                    password = self.api.config.PLC_DB_PASSWORD,
45                                                    database = self.api.config.PLC_DB_NAME,
46                                                    host = self.api.config.PLC_DB_HOST,
47                                                    port = self.api.config.PLC_DB_PORT)
48             self.connection.set_client_encoding("UNICODE")
49
50         (self.rowcount, self.description, self.lastrowid) = \
51                         (None, None, None)
52
53         return self.connection.cursor()
54
55     def close(self):
56         if self.connection is not None:
57             self.connection.close()
58             self.connection = None
59
60     @staticmethod
61     # From pgdb, and simplify code
62     def _quote(x):
63         if isinstance(x, DateTimeType):
64             x = str(x)
65         elif isinstance(x, unicode):
66             x = x.encode( 'utf-8' )
67     
68         if isinstance(x, types.StringType):
69             x = "'%s'" % str(x).replace("\\", "\\\\").replace("'", "''")
70         elif isinstance(x, (types.IntType, types.LongType, types.FloatType)):
71             pass
72         elif x is None:
73             x = 'NULL'
74         elif isinstance(x, (types.ListType, types.TupleType, set)):
75             x = 'ARRAY[%s]' % ', '.join(map(lambda x: str(_quote(x)), x))
76         elif hasattr(x, '__pg_repr__'):
77             x = x.__pg_repr__()
78         else:
79             raise PLCDBError, 'Cannot quote type %s' % type(x)
80         return x
81
82
83     def quote(self, value):
84         """
85         Returns quoted version of the specified value.
86         """
87         return PostgreSQL._quote (value)
88
89 # following is an unsuccessful attempt to re-use lib code as much as possible
90 #    def quote(self, value):
91 #        # The pgdb._quote function is good enough for general SQL
92 #        # quoting, except for array types.
93 #        if isinstance (value, (types.ListType, types.TupleType, set)):
94 #            'ARRAY[%s]' % ', '.join( [ str(self.quote(x)) for x in value ] )
95 #        else:
96 #            try:
97 #                # up to PyGreSQL-3.x, function was pgdb._quote
98 #                import pgdb
99 #                return pgdb._quote(value)
100 #            except:
101 #                # with PyGreSQL-4.x, use psycopg2's adapt
102 #                from psycopg2.extensions import adapt
103 #                return adapt (value)
104
105     @classmethod
106     def param(self, name, value):
107         # None is converted to the unquoted string NULL
108         if isinstance(value, NoneType):
109             conversion = "s"
110         # True and False are also converted to unquoted strings
111         elif isinstance(value, bool):
112             conversion = "s"
113         elif isinstance(value, float):
114             conversion = "f"
115         elif not isinstance(value, StringTypes):
116             conversion = "d"
117         else:
118             conversion = "s"
119
120         return '%(' + name + ')' + conversion
121
122     def begin_work(self):
123         # Implicit in pgdb.connect()
124         pass
125
126     def commit(self):
127         self.connection.commit()
128
129     def rollback(self):
130         self.connection.rollback()
131
132     def do(self, query, params = None):
133         cursor = self.execute(query, params)
134         cursor.close()
135         return self.rowcount
136
137     def next_id(self, table_name, primary_key):
138         sequence = "%(table_name)s_%(primary_key)s_seq" % locals()
139         sql = "SELECT nextval('%(sequence)s')" % locals()
140         rows = self.selectall(sql, hashref = False)
141         if rows:
142             return rows[0][0]
143
144         return None
145
146     def last_insert_id(self, table_name, primary_key):
147         if isinstance(self.lastrowid, int):
148             sql = "SELECT %s FROM %s WHERE oid = %d" % \
149                   (primary_key, table_name, self.lastrowid)
150             rows = self.selectall(sql, hashref = False)
151             if rows:
152                 return rows[0][0]
153
154         return None
155
156     # modified for psycopg2-2.0.7
157     # executemany is undefined for SELECT's
158     # see http://www.python.org/dev/peps/pep-0249/
159     # accepts either None, a single dict, a tuple of single dict - in which case it execute's
160     # or a tuple of several dicts, in which case it executemany's
161     def execute(self, query, params = None):
162
163         cursor = self.cursor()
164         try:
165
166             # psycopg2 requires %()s format for all parameters,
167             # regardless of type.
168             # this needs to be done carefully though as with pattern-based filters
169             # we might have percents embedded in the query
170             # so e.g. GetPersons({'email':'*fake*'}) was resulting in .. LIKE '%sake%'
171             if psycopg2:
172                 query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)
173             # rewrite wildcards set by Filter.py as '***' into '%'
174             query = query.replace ('***','%')
175
176             if not params:
177                 if self.debug:
178                     print >> log,'execute0',query
179                 cursor.execute(query)
180             elif isinstance(params,dict):
181                 if self.debug:
182                     print >> log,'execute-dict: params',params,'query',query%params
183                 cursor.execute(query,params)
184             elif isinstance(params,tuple) and len(params)==1:
185                 if self.debug:
186                     print >> log,'execute-tuple',query%params[0]
187                 cursor.execute(query,params[0])
188             else:
189                 param_seq=(params,)
190                 if self.debug:
191                     for params in param_seq:
192                         print >> log,'executemany',query%params
193                 cursor.executemany(query, param_seq)
194             (self.rowcount, self.description, self.lastrowid) = \
195                             (cursor.rowcount, cursor.description, cursor.lastrowid)
196         except Exception, e:
197             try:
198                 self.rollback()
199             except:
200                 pass
201             uuid = commands.getoutput("uuidgen")
202             print >> log, "Database error %s:" % uuid
203             print >> log, e
204             print >> log, "Query:"
205             print >> log, query
206             print >> log, "Params:"
207             print >> log, pformat(params)
208             raise PLCDBError("Please contact " + \
209                              self.api.config.PLC_NAME + " Support " + \
210                              "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
211                              " and reference " + uuid)
212
213         return cursor
214
215     def selectall(self, query, params = None, hashref = True, key_field = None):
216         """
217         Return each row as a dictionary keyed on field name (like DBI
218         selectrow_hashref()). If key_field is specified, return rows
219         as a dictionary keyed on the specified field (like DBI
220         selectall_hashref()).
221
222         If params is specified, the specified parameters will be bound
223         to the query.
224         """
225
226         cursor = self.execute(query, params)
227         rows = cursor.fetchall()
228         cursor.close()
229         self.commit()
230         if hashref or key_field is not None:
231             # Return each row as a dictionary keyed on field name
232             # (like DBI selectrow_hashref()).
233             labels = [column[0] for column in self.description]
234             rows = [dict(zip(labels, row)) for row in rows]
235
236         if key_field is not None and key_field in labels:
237             # Return rows as a dictionary keyed on the specified field
238             # (like DBI selectall_hashref()).
239             return dict([(row[key_field], row) for row in rows])
240         else:
241             return rows
242
243     def fields(self, table, notnull = None, hasdef = None):
244         """
245         Return the names of the fields of the specified table.
246         """
247
248         if hasattr(self, 'fields_cache'):
249             if self.fields_cache.has_key((table, notnull, hasdef)):
250                 return self.fields_cache[(table, notnull, hasdef)]
251         else:
252             self.fields_cache = {}
253
254         sql = "SELECT attname FROM pg_attribute, pg_class" \
255               " WHERE pg_class.oid = attrelid" \
256               " AND attnum > 0 AND relname = %(table)s"
257
258         if notnull is not None:
259             sql += " AND attnotnull is %(notnull)s"
260
261         if hasdef is not None:
262             sql += " AND atthasdef is %(hasdef)s"
263
264         rows = self.selectall(sql, locals(), hashref = False)
265
266         self.fields_cache[(table, notnull, hasdef)] = [row[0] for row in rows]
267
268         return self.fields_cache[(table, notnull, hasdef)]