First draft for leases
[plcapi.git] / PLC / PostgreSQL.py
1 #
2 # PostgreSQL database interface. Sort of like DBI(3) (Database
3 # independent interface for Perl).
4 #
5 # Mark Huang <mlhuang@cs.princeton.edu>
6 # Copyright (C) 2006 The Trustees of Princeton University
7 #
8 # $Id$
9 # $URL$
10 #
11
12 import psycopg2
13 import psycopg2.extensions
14 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
15 # UNICODEARRAY not exported yet
16 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
17
18 import pgdb
19 from types import StringTypes, NoneType
20 import traceback
21 import commands
22 import re
23 from pprint import pformat
24
25 from PLC.Debug import profile, log
26 from PLC.Faults import *
27
28 class PostgreSQL:
29     def __init__(self, api):
30         self.api = api
31         self.debug = False
32 #        self.debug = True
33         self.connection = None
34
35     def cursor(self):
36         if self.connection is None:
37             # (Re)initialize database connection
38             try:
39                 # Try UNIX socket first
40                 self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
41                                                    password = self.api.config.PLC_DB_PASSWORD,
42                                                    database = self.api.config.PLC_DB_NAME)
43             except psycopg2.OperationalError:
44                 # Fall back on TCP
45                 self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
46                                                    password = self.api.config.PLC_DB_PASSWORD,
47                                                    database = self.api.config.PLC_DB_NAME,
48                                                    host = self.api.config.PLC_DB_HOST,
49                                                    port = self.api.config.PLC_DB_PORT)
50             self.connection.set_client_encoding("UNICODE")
51
52         (self.rowcount, self.description, self.lastrowid) = \
53                         (None, None, None)
54
55         return self.connection.cursor()
56
57     def close(self):
58         if self.connection is not None:
59             self.connection.close()
60             self.connection = None
61
62     # join insists on getting strings
63     @classmethod
64     def quote_string(self, value):
65         return str(PostgreSQL.quote(value))
66
67     @classmethod
68     def quote(self, value):
69         """
70         Returns quoted version of the specified value.
71         """
72
73         # The pgdb._quote function is good enough for general SQL
74         # quoting, except for array types.
75         if isinstance(value, (list, tuple, set)):
76             return "ARRAY[%s]" % ", ".join(map (PostgreSQL.quote_string, value))
77         else:
78             return pgdb._quote(value)
79
80     @classmethod
81     def param(self, name, value):
82         # None is converted to the unquoted string NULL
83         if isinstance(value, NoneType):
84             conversion = "s"
85         # True and False are also converted to unquoted strings
86         elif isinstance(value, bool):
87             conversion = "s"
88         elif isinstance(value, float):
89             conversion = "f"
90         elif not isinstance(value, StringTypes):
91             conversion = "d"
92         else:
93             conversion = "s"
94
95         return '%(' + name + ')' + conversion
96
97     def begin_work(self):
98         # Implicit in pgdb.connect()
99         pass
100
101     def commit(self):
102         self.connection.commit()
103
104     def rollback(self):
105         self.connection.rollback()
106
107     def do(self, query, params = None):
108         cursor = self.execute(query, params)
109         cursor.close()
110         return self.rowcount
111
112     def next_id(self, table_name, primary_key):
113         sequence = "%(table_name)s_%(primary_key)s_seq" % locals()      
114         sql = "SELECT nextval('%(sequence)s')" % locals()
115         rows = self.selectall(sql, hashref = False)
116         if rows: 
117             return rows[0][0]
118                 
119         return None 
120
121     def last_insert_id(self, table_name, primary_key):
122         if isinstance(self.lastrowid, int):
123             sql = "SELECT %s FROM %s WHERE oid = %d" % \
124                   (primary_key, table_name, self.lastrowid)
125             rows = self.selectall(sql, hashref = False)
126             if rows:
127                 return rows[0][0]
128
129         return None
130
131     # modified for psycopg2-2.0.7 
132     # executemany is undefined for SELECT's
133     # see http://www.python.org/dev/peps/pep-0249/
134     # accepts either None, a single dict, a tuple of single dict - in which case it execute's
135     # or a tuple of several dicts, in which case it executemany's
136     def execute(self, query, params = None):
137
138         cursor = self.cursor()
139         try:
140
141             # psycopg2 requires %()s format for all parameters,
142             # regardless of type.
143             # this needs to be done carefully though as with pattern-based filters
144             # we might have percents embedded in the query
145             # so e.g. GetPersons({'email':'*fake*'}) was resulting in .. LIKE '%sake%'
146             if psycopg2:
147                 query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)
148             # rewrite wildcards set by Filter.py as '***' into '%'
149             query = query.replace ('***','%')
150
151             if not params:
152                 if self.debug:
153                     print >> log,'execute0',query
154                 cursor.execute(query)
155             elif isinstance(params,dict):
156                 if self.debug:
157                     print >> log,'execute-dict: params',params,'query',query%params
158                 cursor.execute(query,params)
159             elif isinstance(params,tuple) and len(params)==1:
160                 if self.debug:
161                     print >> log,'execute-tuple',query%params[0]
162                 cursor.execute(query,params[0])
163             else:
164                 param_seq=(params,)
165                 if self.debug:
166                     for params in param_seq:
167                         print >> log,'executemany',query%params
168                 cursor.executemany(query, param_seq)
169             (self.rowcount, self.description, self.lastrowid) = \
170                             (cursor.rowcount, cursor.description, cursor.lastrowid)
171         except Exception, e:
172             try:
173                 self.rollback()
174             except:
175                 pass
176             uuid = commands.getoutput("uuidgen")
177             print >> log, "Database error %s:" % uuid
178             print >> log, e
179             print >> log, "Query:"
180             print >> log, query
181             print >> log, "Params:"
182             print >> log, pformat(params)
183             raise PLCDBError("Please contact " + \
184                              self.api.config.PLC_NAME + " Support " + \
185                              "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
186                              " and reference " + uuid)
187
188         return cursor
189
190     def selectall(self, query, params = None, hashref = True, key_field = None):
191         """
192         Return each row as a dictionary keyed on field name (like DBI
193         selectrow_hashref()). If key_field is specified, return rows
194         as a dictionary keyed on the specified field (like DBI
195         selectall_hashref()).
196
197         If params is specified, the specified parameters will be bound
198         to the query.
199         """
200
201         cursor = self.execute(query, params)
202         rows = cursor.fetchall()
203         cursor.close()
204         self.commit()
205         if hashref or key_field is not None:
206             # Return each row as a dictionary keyed on field name
207             # (like DBI selectrow_hashref()).
208             labels = [column[0] for column in self.description]
209             rows = [dict(zip(labels, row)) for row in rows]
210
211         if key_field is not None and key_field in labels:
212             # Return rows as a dictionary keyed on the specified field
213             # (like DBI selectall_hashref()).
214             return dict([(row[key_field], row) for row in rows])
215         else:
216             return rows
217
218     def fields(self, table, notnull = None, hasdef = None):
219         """
220         Return the names of the fields of the specified table.
221         """
222
223         if hasattr(self, 'fields_cache'):
224             if self.fields_cache.has_key((table, notnull, hasdef)):
225                 return self.fields_cache[(table, notnull, hasdef)]
226         else:
227             self.fields_cache = {}
228
229         sql = "SELECT attname FROM pg_attribute, pg_class" \
230               " WHERE pg_class.oid = attrelid" \
231               " AND attnum > 0 AND relname = %(table)s"
232
233         if notnull is not None:
234             sql += " AND attnotnull is %(notnull)s"
235
236         if hasdef is not None:
237             sql += " AND atthasdef is %(hasdef)s"
238
239         rows = self.selectall(sql, locals(), hashref = False)
240
241         self.fields_cache[(table, notnull, hasdef)] = [row[0] for row in rows]
242
243         return self.fields_cache[(table, notnull, hasdef)]