get rid of svn keywords once and for good
[plcapi.git] / PLC / PostgreSQL.py
1 #
2 # PostgreSQL database interface. 
3 # Sort of like DBI(3) (Database independent interface for Perl).
4 #
5 # Mark Huang <mlhuang@cs.princeton.edu>
6 # Copyright (C) 2006 The Trustees of Princeton University
7 #
8
9 import psycopg2
10 import psycopg2.extensions
11 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
12 # UNICODEARRAY not exported yet
13 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
14
15 import pgdb
16 import types
17 from types import StringTypes, NoneType
18 import traceback
19 import commands
20 import re
21 from pprint import pformat
22
23 from PLC.Debug import profile, log
24 from PLC.Faults import *
25 from datetime import datetime as DateTimeType
26
27 class PostgreSQL:
28     def __init__(self, api):
29         self.api = api
30         self.debug = False
31 #        self.debug = True
32         self.connection = None
33
34     def cursor(self):
35         if self.connection is None:
36             # (Re)initialize database connection
37             try:
38                 # Try UNIX socket first
39                 self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
40                                                    password = self.api.config.PLC_DB_PASSWORD,
41                                                    database = self.api.config.PLC_DB_NAME)
42             except psycopg2.OperationalError:
43                 # Fall back on TCP
44                 self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
45                                                    password = self.api.config.PLC_DB_PASSWORD,
46                                                    database = self.api.config.PLC_DB_NAME,
47                                                    host = self.api.config.PLC_DB_HOST,
48                                                    port = self.api.config.PLC_DB_PORT)
49             self.connection.set_client_encoding("UNICODE")
50
51         (self.rowcount, self.description, self.lastrowid) = \
52                         (None, None, None)
53
54         return self.connection.cursor()
55
56     def close(self):
57         if self.connection is not None:
58             self.connection.close()
59             self.connection = None
60
61     def quote(self, value):
62         """
63         Returns quoted version of the specified value.
64         """
65         # The pgdb._quote function is good enough for general SQL
66         # quoting, except for array types.
67         if isinstance (value, (types.ListType, types.TupleType, set)):
68             'ARRAY[%s]' % ', '.join( [ str(self.quote(x)) for x in value ] )
69         else:
70             try:
71                 # up to PyGreSQL-3.x, function was pgdb._quote
72                 return pgdb._quote(value)
73             except:
74                 # from PyGreSQL-4.0, it's a cursor method
75                 return self.cursor()._quote(value)
76
77     @classmethod
78     def param(self, name, value):
79         # None is converted to the unquoted string NULL
80         if isinstance(value, NoneType):
81             conversion = "s"
82         # True and False are also converted to unquoted strings
83         elif isinstance(value, bool):
84             conversion = "s"
85         elif isinstance(value, float):
86             conversion = "f"
87         elif not isinstance(value, StringTypes):
88             conversion = "d"
89         else:
90             conversion = "s"
91
92         return '%(' + name + ')' + conversion
93
94     def begin_work(self):
95         # Implicit in pgdb.connect()
96         pass
97
98     def commit(self):
99         self.connection.commit()
100
101     def rollback(self):
102         self.connection.rollback()
103
104     def do(self, query, params = None):
105         cursor = self.execute(query, params)
106         cursor.close()
107         return self.rowcount
108
109     def next_id(self, table_name, primary_key):
110         sequence = "%(table_name)s_%(primary_key)s_seq" % locals()
111         sql = "SELECT nextval('%(sequence)s')" % locals()
112         rows = self.selectall(sql, hashref = False)
113         if rows:
114             return rows[0][0]
115
116         return None
117
118     def last_insert_id(self, table_name, primary_key):
119         if isinstance(self.lastrowid, int):
120             sql = "SELECT %s FROM %s WHERE oid = %d" % \
121                   (primary_key, table_name, self.lastrowid)
122             rows = self.selectall(sql, hashref = False)
123             if rows:
124                 return rows[0][0]
125
126         return None
127
128     # modified for psycopg2-2.0.7
129     # executemany is undefined for SELECT's
130     # see http://www.python.org/dev/peps/pep-0249/
131     # accepts either None, a single dict, a tuple of single dict - in which case it execute's
132     # or a tuple of several dicts, in which case it executemany's
133     def execute(self, query, params = None):
134
135         cursor = self.cursor()
136         try:
137
138             # psycopg2 requires %()s format for all parameters,
139             # regardless of type.
140             # this needs to be done carefully though as with pattern-based filters
141             # we might have percents embedded in the query
142             # so e.g. GetPersons({'email':'*fake*'}) was resulting in .. LIKE '%sake%'
143             if psycopg2:
144                 query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)
145             # rewrite wildcards set by Filter.py as '***' into '%'
146             query = query.replace ('***','%')
147
148             if not params:
149                 if self.debug:
150                     print >> log,'execute0',query
151                 cursor.execute(query)
152             elif isinstance(params,dict):
153                 if self.debug:
154                     print >> log,'execute-dict: params',params,'query',query%params
155                 cursor.execute(query,params)
156             elif isinstance(params,tuple) and len(params)==1:
157                 if self.debug:
158                     print >> log,'execute-tuple',query%params[0]
159                 cursor.execute(query,params[0])
160             else:
161                 param_seq=(params,)
162                 if self.debug:
163                     for params in param_seq:
164                         print >> log,'executemany',query%params
165                 cursor.executemany(query, param_seq)
166             (self.rowcount, self.description, self.lastrowid) = \
167                             (cursor.rowcount, cursor.description, cursor.lastrowid)
168         except Exception, e:
169             try:
170                 self.rollback()
171             except:
172                 pass
173             uuid = commands.getoutput("uuidgen")
174             print >> log, "Database error %s:" % uuid
175             print >> log, e
176             print >> log, "Query:"
177             print >> log, query
178             print >> log, "Params:"
179             print >> log, pformat(params)
180             raise PLCDBError("Please contact " + \
181                              self.api.config.PLC_NAME + " Support " + \
182                              "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
183                              " and reference " + uuid)
184
185         return cursor
186
187     def selectall(self, query, params = None, hashref = True, key_field = None):
188         """
189         Return each row as a dictionary keyed on field name (like DBI
190         selectrow_hashref()). If key_field is specified, return rows
191         as a dictionary keyed on the specified field (like DBI
192         selectall_hashref()).
193
194         If params is specified, the specified parameters will be bound
195         to the query.
196         """
197
198         cursor = self.execute(query, params)
199         rows = cursor.fetchall()
200         cursor.close()
201         self.commit()
202         if hashref or key_field is not None:
203             # Return each row as a dictionary keyed on field name
204             # (like DBI selectrow_hashref()).
205             labels = [column[0] for column in self.description]
206             rows = [dict(zip(labels, row)) for row in rows]
207
208         if key_field is not None and key_field in labels:
209             # Return rows as a dictionary keyed on the specified field
210             # (like DBI selectall_hashref()).
211             return dict([(row[key_field], row) for row in rows])
212         else:
213             return rows
214
215     def fields(self, table, notnull = None, hasdef = None):
216         """
217         Return the names of the fields of the specified table.
218         """
219
220         if hasattr(self, 'fields_cache'):
221             if self.fields_cache.has_key((table, notnull, hasdef)):
222                 return self.fields_cache[(table, notnull, hasdef)]
223         else:
224             self.fields_cache = {}
225
226         sql = "SELECT attname FROM pg_attribute, pg_class" \
227               " WHERE pg_class.oid = attrelid" \
228               " AND attnum > 0 AND relname = %(table)s"
229
230         if notnull is not None:
231             sql += " AND attnotnull is %(notnull)s"
232
233         if hasdef is not None:
234             sql += " AND atthasdef is %(hasdef)s"
235
236         rows = self.selectall(sql, locals(), hashref = False)
237
238         self.fields_cache[(table, notnull, hasdef)] = [row[0] for row in rows]
239
240         return self.fields_cache[(table, notnull, hasdef)]