Interfaces can handle tags through Add/Get/Update
[plcapi.git] / PLC / PostgreSQL.py
1 #
2 # PostgreSQL database interface. Sort of like DBI(3) (Database
3 # independent interface for Perl).
4 #
5 # Mark Huang <mlhuang@cs.princeton.edu>
6 # Copyright (C) 2006 The Trustees of Princeton University
7 #
8 # $Id$
9 #
10
11 import psycopg2
12 import psycopg2.extensions
13 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
14 # UNICODEARRAY not exported yet
15 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
16
17 import pgdb
18 from types import StringTypes, NoneType
19 import traceback
20 import commands
21 import re
22 from pprint import pformat
23
24 from PLC.Debug import profile, log
25 from PLC.Faults import *
26
if not psycopg2:
    # Fallback used only when psycopg2 is falsy, i.e. pgdb is the driver.
    # NOTE(review): psycopg2 is imported unconditionally earlier in this
    # file, so as written this branch looks unreachable -- confirm whether
    # the import was meant to be optional.

    # Search function that finds the first byte outside 7-bit ASCII.
    is8bit = re.compile("[\x80-\xff]").search

    def unicast(typecast):
        """
        Wrap a pgdb typecast function so that byte strings containing
        8-bit characters come back as unicode objects.

        typecast is the callable to wrap; the returned wrapper accepts
        the same arguments and forwards them unchanged.
        """

        def wrapper(*args, **kwds):
            result = typecast(*args, **kwds)

            # Anything that is not a plain byte string is passed through.
            if not isinstance(result, str):
                return result

            # Pure ASCII strings are left as they are.
            if not is8bit(result):
                return result

            # pgdb always encodes unicode objects as UTF-8 regardless of
            # the DB encoding (and offers no way to override that), so
            # 8-bit data is always decoded as UTF-8.
            return unicode(result, "utf-8")

        return wrapper

    # Install the wrapper in place of pgdb's own typecast.
    pgdb.pgdbTypeCache.typecast = unicast(pgdb.pgdbTypeCache.typecast)
50
class PostgreSQL:
    """
    PostgreSQL database interface bound to one API object.

    The connection is opened lazily by cursor() using the PLC_DB_*
    settings found on api.config. After every execute() the attributes
    rowcount, description and lastrowid hold the values reported by the
    cursor that ran the query.
    """

    def __init__(self, api):
        """
        Remember the API object; no database connection is opened here.
        """
        self.api = api
        # Set to True to print every query to the debug log.
        self.debug = False
        self.connection = None

        # Results of the most recent execute(). Defined up front so that
        # last_insert_id() can be called before any query has been run.
        (self.rowcount, self.description, self.lastrowid) = \
                        (None, None, None)

    def cursor(self):
        """
        Return a new cursor, (re)opening the connection if necessary.

        Also resets rowcount, description and lastrowid to None.
        """
        if self.connection is None:
            # (Re)initialize database connection
            config = self.api.config
            if psycopg2:
                try:
                    # Try UNIX socket first
                    self.connection = psycopg2.connect(user = config.PLC_DB_USER,
                                                       password = config.PLC_DB_PASSWORD,
                                                       database = config.PLC_DB_NAME)
                except psycopg2.OperationalError:
                    # Fall back on TCP
                    self.connection = psycopg2.connect(user = config.PLC_DB_USER,
                                                       password = config.PLC_DB_PASSWORD,
                                                       database = config.PLC_DB_NAME,
                                                       host = config.PLC_DB_HOST,
                                                       port = config.PLC_DB_PORT)
                self.connection.set_client_encoding("UNICODE")
            else:
                # The configuration lives on self.api; there is no
                # module-level "api" name to read it from.
                self.connection = pgdb.connect(user = config.PLC_DB_USER,
                                               password = config.PLC_DB_PASSWORD,
                                               host = "%s:%d" % (config.PLC_DB_HOST, config.PLC_DB_PORT),
                                               database = config.PLC_DB_NAME)

        (self.rowcount, self.description, self.lastrowid) = \
                        (None, None, None)

        return self.connection.cursor()

    def close(self):
        """
        Close the connection if one is open; the next cursor() reopens it.
        """
        if self.connection is not None:
            self.connection.close()
            self.connection = None

    @classmethod
    def quote(cls, value):
        """
        Returns quoted version of the specified value.

        Lists, tuples and sets become an ARRAY[...] expression whose
        elements are quoted recursively; anything else is handed to
        pgdb._quote.
        """

        # The pgdb._quote function is good enough for general SQL
        # quoting, except for array types.
        if isinstance(value, (list, tuple, set)):
            # Each element is formatted with %s so that join() only ever
            # sees strings.
            # NOTE(review): pgdb._quote presumably returns numbers
            # unchanged rather than as strings -- confirm.
            return "ARRAY[%s]" % ", ".join(["%s" % cls.quote(item) for item in value])

        return pgdb._quote(value)

    @classmethod
    def param(cls, name, value):
        """
        Return the '%(name)X' placeholder for a parameter, choosing the
        conversion character X from the type of value.
        """
        if isinstance(value, (NoneType, bool)):
            # None is converted to the unquoted string NULL; True and
            # False are also converted to unquoted strings
            conversion = "s"
        elif isinstance(value, float):
            conversion = "f"
        elif isinstance(value, StringTypes):
            conversion = "s"
        else:
            # Every remaining type is formatted as an integer
            conversion = "d"

        return '%(' + name + ')' + conversion

    def begin_work(self):
        """
        Start a transaction; a no-op kept for interface symmetry.
        """
        # Implicit in pgdb.connect()
        pass

    def commit(self):
        """
        Commit the current transaction on the open connection.
        """
        self.connection.commit()

    def rollback(self):
        """
        Roll back the current transaction on the open connection.
        """
        self.connection.rollback()

    def do(self, query, params = None):
        """
        Execute a statement, discard its cursor and return self.rowcount
        as recorded by execute().
        """
        cursor = self.execute(query, params)
        cursor.close()
        return self.rowcount

    def next_id(self, table_name, primary_key):
        """
        Return nextval() of the sequence named
        <table_name>_<primary_key>_seq, or None if no row comes back.
        """
        sequence = "%s_%s_seq" % (table_name, primary_key)
        sql = "SELECT nextval('%s')" % sequence
        rows = self.selectall(sql, hashref = False)
        if rows:
            return rows[0][0]

        return None

    def last_insert_id(self, table_name, primary_key):
        """
        Return the primary_key value of the row in table_name whose oid
        equals the lastrowid recorded by the previous execute().

        Returns None when lastrowid is not an int or no row matches.
        """
        if isinstance(self.lastrowid, int):
            sql = "SELECT %s FROM %s WHERE oid = %d" % \
                  (primary_key, table_name, self.lastrowid)
            rows = self.selectall(sql, hashref = False)
            if rows:
                return rows[0][0]

        return None

    # modified for psycopg2-2.0.7
    # executemany is undefined for SELECT's
    # see http://www.python.org/dev/peps/pep-0249/
    def execute(self, query, params = None):
        """
        Run query on a fresh cursor and return that cursor.

        params may be:
          - None (or any false value): the query is executed as is;
          - a dict: executed once with that dict;
          - a tuple holding exactly one item: executed once with that item;
          - anything else: wrapped in a one-element tuple and passed to
            executemany().

        On success rowcount, description and lastrowid are copied from
        the cursor onto self. On any failure the transaction is rolled
        back, the cursor is closed, the error is written to the log
        under a freshly generated id, and PLCDBError is raised quoting
        that id.
        """

        cursor = self.cursor()
        try:

            # psycopg2 requires %()s format for all parameters,
            # regardless of type.
            # NOTE: the substitution is purely textual, so a literal
            # "%d" or "%f" elsewhere in the query is rewritten as well.
            if psycopg2:
                query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)

            if not params:
                if self.debug:
                    print >> log,'execute0',query
                cursor.execute(query)
            elif isinstance(params,dict):
                if self.debug:
                    print >> log,'execute-dict: params',params,'query',query%params
                cursor.execute(query,params)
            elif isinstance(params,tuple) and len(params)==1:
                if self.debug:
                    print >> log,'execute-tuple',query%params[0]
                cursor.execute(query,params[0])
            else:
                param_seq=(params,)
                if self.debug:
                    for debug_params in param_seq:
                        print >> log,'executemany',query%debug_params
                cursor.executemany(query, param_seq)
            (self.rowcount, self.description, self.lastrowid) = \
                            (cursor.rowcount, cursor.description, cursor.lastrowid)
        except Exception:
            # Fetch the exception before the cleanup below can replace
            # it. sys.exc_info() is used instead of binding the
            # exception in the except clause so this parses on every
            # Python version.
            import sys
            error = sys.exc_info()[1]

            # Best-effort cleanup: a failure here must not mask the
            # original database error.
            for cleanup in (self.rollback, cursor.close):
                try:
                    cleanup()
                except Exception:
                    pass

            uuid = commands.getoutput("uuidgen")
            print >> log, "Database error %s:" % uuid
            print >> log, error
            print >> log, "Query:"
            print >> log, query
            print >> log, "Params:"
            print >> log, pformat(params)
            raise PLCDBError("Please contact " + \
                             self.api.config.PLC_NAME + " Support " + \
                             "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
                             " and reference " + uuid)

        return cursor

    def selectall(self, query, params = None, hashref = True, key_field = None):
        """
        Return each row as a dictionary keyed on field name (like DBI
        selectrow_hashref()). If key_field is specified, return rows
        as a dictionary keyed on the specified field (like DBI
        selectall_hashref()).

        With hashref false and no key_field, the rows are returned
        exactly as fetched. A key_field that is not one of the result
        columns is ignored and the list of row dictionaries is returned.

        If params is specified, the specified parameters will be bound
        to the query.
        """

        cursor = self.execute(query, params)
        try:
            rows = cursor.fetchall()
        finally:
            # Release the cursor even when fetching fails.
            cursor.close()

        if not hashref and key_field is None:
            return rows

        # Return each row as a dictionary keyed on field name
        # (like DBI selectrow_hashref()).
        labels = [column[0] for column in self.description]
        rows = [dict(zip(labels, row)) for row in rows]

        if key_field is not None and key_field in labels:
            # Return rows as a dictionary keyed on the specified field
            # (like DBI selectall_hashref()).
            return dict([(row[key_field], row) for row in rows])

        return rows

    def fields(self, table, notnull = None, hasdef = None):
        """
        Return the names of the fields of the specified table.

        notnull and hasdef, when not None, restrict the result to
        columns whose attnotnull / atthasdef flag matches the given
        value. Results are cached per (table, notnull, hasdef).
        """

        # The cache is created on first use.
        if not hasattr(self, 'fields_cache'):
            self.fields_cache = {}

        cache_key = (table, notnull, hasdef)
        if cache_key in self.fields_cache:
            return self.fields_cache[cache_key]

        sql = "SELECT attname FROM pg_attribute, pg_class" \
              " WHERE pg_class.oid = attrelid" \
              " AND attnum > 0 AND relname = %(table)s"

        if notnull is not None:
            sql += " AND attnotnull is %(notnull)s"

        if hasdef is not None:
            sql += " AND atthasdef is %(hasdef)s"

        # Bind only the values the query refers to, rather than every
        # local name in this function.
        params = {'table': table, 'notnull': notnull, 'hasdef': hasdef}
        rows = self.selectall(sql, params, hashref = False)

        self.fields_cache[cache_key] = [row[0] for row in rows]

        return self.fields_cache[cache_key]