StringTypes has gone
[plcapi.git] / PLC / PostgreSQL.py
1 #
2 # PostgreSQL database interface.
3 # Sort of like DBI(3) (Database independent interface for Perl).
4 #
5 # Mark Huang <mlhuang@cs.princeton.edu>
6 # Copyright (C) 2006 The Trustees of Princeton University
7 #
8
9 import psycopg2
10 import psycopg2.extensions
11 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
12 # UNICODEARRAY not exported yet
13 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
14
15 import types
16 import traceback
17 import subprocess
18 import re
19 from pprint import pformat
20
21 from PLC.Logger import logger
22 from PLC.Debug import profile
23 from PLC.Faults import *
24 from datetime import datetime as DateTimeType
25
26 class PostgreSQL:
27     def __init__(self, api):
28         self.api = api
29         self.debug = False
30 #        self.debug = True
31         self.connection = None
32
33     def cursor(self):
34         if self.connection is None:
35             # (Re)initialize database connection
36             try:
37                 # Try UNIX socket first
38                 self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
39                                                    password = self.api.config.PLC_DB_PASSWORD,
40                                                    database = self.api.config.PLC_DB_NAME)
41             except psycopg2.OperationalError:
42                 # Fall back on TCP
43                 self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
44                                                    password = self.api.config.PLC_DB_PASSWORD,
45                                                    database = self.api.config.PLC_DB_NAME,
46                                                    host = self.api.config.PLC_DB_HOST,
47                                                    port = self.api.config.PLC_DB_PORT)
48             self.connection.set_client_encoding("UNICODE")
49
50         (self.rowcount, self.description, self.lastrowid) = \
51                         (None, None, None)
52
53         return self.connection.cursor()
54
55     def close(self):
56         if self.connection is not None:
57             self.connection.close()
58             self.connection = None
59
60     @staticmethod
61     # From pgdb, and simplify code
62     def _quote(x):
63         if isinstance(x, DateTimeType):
64             x = str(x)
65         elif isinstance(x, str):
66             x = x.encode( 'utf-8' )
67
68         if isinstance(x, bytes):
69             x = "'%s'" % str(x).replace("\\", "\\\\").replace("'", "''")
70         elif isinstance(x, (int, float)):
71             pass
72         elif x is None:
73             x = 'NULL'
74         elif isinstance(x, (list, tuple, set)):
75             x = 'ARRAY[%s]' % ', '.join([str(_quote(x)) for x in x])
76         elif hasattr(x, '__pg_repr__'):
77             x = x.__pg_repr__()
78         else:
79             raise PLCDBError('Cannot quote type %s' % type(x))
80         return x
81
82
83     def quote(self, value):
84         """
85         Returns quoted version of the specified value.
86         """
87         return PostgreSQL._quote (value)
88
89 # following is an unsuccessful attempt to re-use lib code as much as possible
90 #    def quote(self, value):
91 #        # The pgdb._quote function is good enough for general SQL
92 #        # quoting, except for array types.
93 #        if isinstance (value, (types.ListType, types.TupleType, set)):
94 #            'ARRAY[%s]' % ', '.join( [ str(self.quote(x)) for x in value ] )
95 #        else:
96 #            try:
97 #                # up to PyGreSQL-3.x, function was pgdb._quote
98 #                import pgdb
99 #                return pgdb._quote(value)
100 #            except:
101 #                # with PyGreSQL-4.x, use psycopg2's adapt
102 #                from psycopg2.extensions import adapt
103 #                return adapt (value)
104
105     @classmethod
106     def param(self, name, value):
107         # None is converted to the unquoted string NULL
108         if isinstance(value, type(None)):
109             conversion = "s"
110         # True and False are also converted to unquoted strings
111         elif isinstance(value, bool):
112             conversion = "s"
113         elif isinstance(value, float):
114             conversion = "f"
115         elif not isinstance(value, str):
116             conversion = "d"
117         else:
118             conversion = "s"
119
120         return '%(' + name + ')' + conversion
121
122     def begin_work(self):
123         # Implicit in pgdb.connect()
124         pass
125
126     def commit(self):
127         self.connection.commit()
128
129     def rollback(self):
130         self.connection.rollback()
131
132     def do(self, query, params = None):
133         cursor = self.execute(query, params)
134         cursor.close()
135         return self.rowcount
136
137     def next_id(self, table_name, primary_key):
138         sequence = "%(table_name)s_%(primary_key)s_seq" % locals()
139         sql = "SELECT nextval('%(sequence)s')" % locals()
140         rows = self.selectall(sql, hashref = False)
141         if rows:
142             return rows[0][0]
143
144         return None
145
146     def last_insert_id(self, table_name, primary_key):
147         if isinstance(self.lastrowid, int):
148             sql = "SELECT %s FROM %s WHERE oid = %d" % \
149                   (primary_key, table_name, self.lastrowid)
150             rows = self.selectall(sql, hashref = False)
151             if rows:
152                 return rows[0][0]
153
154         return None
155
156     # modified for psycopg2-2.0.7
157     # executemany is undefined for SELECT's
158     # see http://www.python.org/dev/peps/pep-0249/
159     # accepts either None, a single dict, a tuple of single dict - in which case it execute's
160     # or a tuple of several dicts, in which case it executemany's
161     def execute(self, query, params = None):
162
163         cursor = self.cursor()
164         try:
165
166             # psycopg2 requires %()s format for all parameters,
167             # regardless of type.
168             # this needs to be done carefully though as with pattern-based filters
169             # we might have percents embedded in the query
170             # so e.g. GetPersons({'email':'*fake*'}) was resulting in .. LIKE '%sake%'
171             if psycopg2:
172                 query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)
173             # rewrite wildcards set by Filter.py as '***' into '%'
174             query = query.replace ('***','%')
175
176             if not params:
177                 if self.debug:
178                     logger.debug('execute0: {}'.format(query))
179                 cursor.execute(query)
180             elif isinstance(params, dict):
181                 if self.debug:
182                     logger.debug('execute-dict: params {} query {}'
183                                  .format(params, query%params))
184                 cursor.execute(query, params)
185             elif isinstance(params,tuple) and len(params)==1:
186                 if self.debug:
187                     logger.debug('execute-tuple {}'.format(query%params[0]))
188                 cursor.execute(query,params[0])
189             else:
190                 param_seq=(params,)
191                 if self.debug:
192                     for params in param_seq:
193                         logger.debug('executemany {}'.format(query%params))
194                 cursor.executemany(query, param_seq)
195             (self.rowcount, self.description, self.lastrowid) = \
196                             (cursor.rowcount, cursor.description, cursor.lastrowid)
197         except Exception as e:
198             try:
199                 self.rollback()
200             except:
201                 pass
202             uuid = subprocess.getoutput("uuidgen")
203             message = "Database error {}: - Query {} - Params {}".format(uuid, query, pformat(params))
204             logger.exception(message)
205             raise PLCDBError("Please contact " + \
206                              self.api.config.PLC_NAME + " Support " + \
207                              "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
208                              " and reference " + uuid)
209
210         return cursor
211
212     def selectall(self, query, params = None, hashref = True, key_field = None):
213         """
214         Return each row as a dictionary keyed on field name (like DBI
215         selectrow_hashref()). If key_field is specified, return rows
216         as a dictionary keyed on the specified field (like DBI
217         selectall_hashref()).
218
219         If params is specified, the specified parameters will be bound
220         to the query.
221         """
222
223         cursor = self.execute(query, params)
224         rows = cursor.fetchall()
225         cursor.close()
226         self.commit()
227         if hashref or key_field is not None:
228             # Return each row as a dictionary keyed on field name
229             # (like DBI selectrow_hashref()).
230             labels = [column[0] for column in self.description]
231             rows = [dict(list(zip(labels, row))) for row in rows]
232
233         if key_field is not None and key_field in labels:
234             # Return rows as a dictionary keyed on the specified field
235             # (like DBI selectall_hashref()).
236             return dict([(row[key_field], row) for row in rows])
237         else:
238             return rows
239
240     def fields(self, table, notnull = None, hasdef = None):
241         """
242         Return the names of the fields of the specified table.
243         """
244
245         if hasattr(self, 'fields_cache'):
246             if (table, notnull, hasdef) in self.fields_cache:
247                 return self.fields_cache[(table, notnull, hasdef)]
248         else:
249             self.fields_cache = {}
250
251         sql = "SELECT attname FROM pg_attribute, pg_class" \
252               " WHERE pg_class.oid = attrelid" \
253               " AND attnum > 0 AND relname = %(table)s"
254
255         if notnull is not None:
256             sql += " AND attnotnull is %(notnull)s"
257
258         if hasdef is not None:
259             sql += " AND atthasdef is %(hasdef)s"
260
261         rows = self.selectall(sql, locals(), hashref = False)
262
263         self.fields_cache[(table, notnull, hasdef)] = [row[0] for row in rows]
264
265         return self.fields_cache[(table, notnull, hasdef)]