remove PLC.Debug.log, use PLC.Logger.logger instead
[plcapi.git] / PLC / PostgreSQL.py
1 #
2 # PostgreSQL database interface. 
3 # Sort of like DBI(3) (Database independent interface for Perl).
4 #
5 # Mark Huang <mlhuang@cs.princeton.edu>
6 # Copyright (C) 2006 The Trustees of Princeton University
7 #
8
9 import psycopg2
10 import psycopg2.extensions
11 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
12 # UNICODEARRAY not exported yet
13 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
14
15 import types
16 from types import StringTypes, NoneType
17 import traceback
18 import commands
19 import re
20 from pprint import pformat
21
22 from PLC.Logger import logger
23 from PLC.Debug import profile
24 from PLC.Faults import *
25 from datetime import datetime as DateTimeType
26
27 class PostgreSQL:
28     def __init__(self, api):
29         self.api = api
30         self.debug = False
31 #        self.debug = True
32         self.connection = None
33
34     def cursor(self):
35         if self.connection is None:
36             # (Re)initialize database connection
37             try:
38                 # Try UNIX socket first
39                 self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
40                                                    password = self.api.config.PLC_DB_PASSWORD,
41                                                    database = self.api.config.PLC_DB_NAME)
42             except psycopg2.OperationalError:
43                 # Fall back on TCP
44                 self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
45                                                    password = self.api.config.PLC_DB_PASSWORD,
46                                                    database = self.api.config.PLC_DB_NAME,
47                                                    host = self.api.config.PLC_DB_HOST,
48                                                    port = self.api.config.PLC_DB_PORT)
49             self.connection.set_client_encoding("UNICODE")
50
51         (self.rowcount, self.description, self.lastrowid) = \
52                         (None, None, None)
53
54         return self.connection.cursor()
55
56     def close(self):
57         if self.connection is not None:
58             self.connection.close()
59             self.connection = None
60
61     @staticmethod
62     # From pgdb, and simplify code
63     def _quote(x):
64         if isinstance(x, DateTimeType):
65             x = str(x)
66         elif isinstance(x, unicode):
67             x = x.encode( 'utf-8' )
68     
69         if isinstance(x, types.StringType):
70             x = "'%s'" % str(x).replace("\\", "\\\\").replace("'", "''")
71         elif isinstance(x, (types.IntType, types.LongType, types.FloatType)):
72             pass
73         elif x is None:
74             x = 'NULL'
75         elif isinstance(x, (types.ListType, types.TupleType, set)):
76             x = 'ARRAY[%s]' % ', '.join(map(lambda x: str(_quote(x)), x))
77         elif hasattr(x, '__pg_repr__'):
78             x = x.__pg_repr__()
79         else:
80             raise PLCDBError, 'Cannot quote type %s' % type(x)
81         return x
82
83
84     def quote(self, value):
85         """
86         Returns quoted version of the specified value.
87         """
88         return PostgreSQL._quote (value)
89
90 # following is an unsuccessful attempt to re-use lib code as much as possible
91 #    def quote(self, value):
92 #        # The pgdb._quote function is good enough for general SQL
93 #        # quoting, except for array types.
94 #        if isinstance (value, (types.ListType, types.TupleType, set)):
95 #            'ARRAY[%s]' % ', '.join( [ str(self.quote(x)) for x in value ] )
96 #        else:
97 #            try:
98 #                # up to PyGreSQL-3.x, function was pgdb._quote
99 #                import pgdb
100 #                return pgdb._quote(value)
101 #            except:
102 #                # with PyGreSQL-4.x, use psycopg2's adapt
103 #                from psycopg2.extensions import adapt
104 #                return adapt (value)
105
106     @classmethod
107     def param(self, name, value):
108         # None is converted to the unquoted string NULL
109         if isinstance(value, NoneType):
110             conversion = "s"
111         # True and False are also converted to unquoted strings
112         elif isinstance(value, bool):
113             conversion = "s"
114         elif isinstance(value, float):
115             conversion = "f"
116         elif not isinstance(value, StringTypes):
117             conversion = "d"
118         else:
119             conversion = "s"
120
121         return '%(' + name + ')' + conversion
122
123     def begin_work(self):
124         # Implicit in pgdb.connect()
125         pass
126
127     def commit(self):
128         self.connection.commit()
129
130     def rollback(self):
131         self.connection.rollback()
132
133     def do(self, query, params = None):
134         cursor = self.execute(query, params)
135         cursor.close()
136         return self.rowcount
137
138     def next_id(self, table_name, primary_key):
139         sequence = "%(table_name)s_%(primary_key)s_seq" % locals()
140         sql = "SELECT nextval('%(sequence)s')" % locals()
141         rows = self.selectall(sql, hashref = False)
142         if rows:
143             return rows[0][0]
144
145         return None
146
147     def last_insert_id(self, table_name, primary_key):
148         if isinstance(self.lastrowid, int):
149             sql = "SELECT %s FROM %s WHERE oid = %d" % \
150                   (primary_key, table_name, self.lastrowid)
151             rows = self.selectall(sql, hashref = False)
152             if rows:
153                 return rows[0][0]
154
155         return None
156
157     # modified for psycopg2-2.0.7
158     # executemany is undefined for SELECT's
159     # see http://www.python.org/dev/peps/pep-0249/
160     # accepts either None, a single dict, a tuple of single dict - in which case it execute's
161     # or a tuple of several dicts, in which case it executemany's
162     def execute(self, query, params = None):
163
164         cursor = self.cursor()
165         try:
166
167             # psycopg2 requires %()s format for all parameters,
168             # regardless of type.
169             # this needs to be done carefully though as with pattern-based filters
170             # we might have percents embedded in the query
171             # so e.g. GetPersons({'email':'*fake*'}) was resulting in .. LIKE '%sake%'
172             if psycopg2:
173                 query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)
174             # rewrite wildcards set by Filter.py as '***' into '%'
175             query = query.replace ('***','%')
176
177             if not params:
178                 if self.debug:
179                     logger.debug('execute0: {}'.format(query))
180                 cursor.execute(query)
181             elif isinstance(params, dict):
182                 if self.debug:
183                     logger.debug('execute-dict: params {} query {}'
184                                  .format(params, query%params))
185                 cursor.execute(query, params)
186             elif isinstance(params,tuple) and len(params)==1:
187                 if self.debug:
188                     logger.debug('execute-tuple {}'.format(query%params[0]))
189                 cursor.execute(query,params[0])
190             else:
191                 param_seq=(params,)
192                 if self.debug:
193                     for params in param_seq:
194                         logger.debug('executemany {}'.format(query%params))
195                 cursor.executemany(query, param_seq)
196             (self.rowcount, self.description, self.lastrowid) = \
197                             (cursor.rowcount, cursor.description, cursor.lastrowid)
198         except Exception, e:
199             try:
200                 self.rollback()
201             except:
202                 pass
203             uuid = commands.getoutput("uuidgen")
204             message = "Database error {}: - Query {} - Params {}".format(uuid, query, pformat(params))
205             logger.exception(message)
206             raise PLCDBError("Please contact " + \
207                              self.api.config.PLC_NAME + " Support " + \
208                              "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
209                              " and reference " + uuid)
210
211         return cursor
212
213     def selectall(self, query, params = None, hashref = True, key_field = None):
214         """
215         Return each row as a dictionary keyed on field name (like DBI
216         selectrow_hashref()). If key_field is specified, return rows
217         as a dictionary keyed on the specified field (like DBI
218         selectall_hashref()).
219
220         If params is specified, the specified parameters will be bound
221         to the query.
222         """
223
224         cursor = self.execute(query, params)
225         rows = cursor.fetchall()
226         cursor.close()
227         self.commit()
228         if hashref or key_field is not None:
229             # Return each row as a dictionary keyed on field name
230             # (like DBI selectrow_hashref()).
231             labels = [column[0] for column in self.description]
232             rows = [dict(zip(labels, row)) for row in rows]
233
234         if key_field is not None and key_field in labels:
235             # Return rows as a dictionary keyed on the specified field
236             # (like DBI selectall_hashref()).
237             return dict([(row[key_field], row) for row in rows])
238         else:
239             return rows
240
241     def fields(self, table, notnull = None, hasdef = None):
242         """
243         Return the names of the fields of the specified table.
244         """
245
246         if hasattr(self, 'fields_cache'):
247             if self.fields_cache.has_key((table, notnull, hasdef)):
248                 return self.fields_cache[(table, notnull, hasdef)]
249         else:
250             self.fields_cache = {}
251
252         sql = "SELECT attname FROM pg_attribute, pg_class" \
253               " WHERE pg_class.oid = attrelid" \
254               " AND attnum > 0 AND relname = %(table)s"
255
256         if notnull is not None:
257             sql += " AND attnotnull is %(notnull)s"
258
259         if hasdef is not None:
260             sql += " AND atthasdef is %(hasdef)s"
261
262         rows = self.selectall(sql, locals(), hashref = False)
263
264         self.fields_cache[(table, notnull, hasdef)] = [row[0] for row in rows]
265
266         return self.fields_cache[(table, notnull, hasdef)]