Merge remote-tracking branch 'origin/pycurl' into planetlab-4_0-branch
[plcapi.git] / PLC / PostgreSQL.py
1 #
2 # PostgreSQL database interface. Sort of like DBI(3) (Database
3 # independent interface for Perl).
4 #
5 # Mark Huang <mlhuang@cs.princeton.edu>
6 # Copyright (C) 2006 The Trustees of Princeton University
7 #
8 # $Id: PostgreSQL.py 5574 2007-10-25 20:33:17Z thierry $
9 #
10
11 import psycopg2
12 import psycopg2.extensions
13 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
14 # UNICODEARRAY not exported yet
15 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
16
17 import pgdb
18 from types import StringTypes, NoneType
19 import traceback
20 import commands
21 import re
22 from pprint import pformat
23
24 from PLC.Debug import profile, log
25 from PLC.Faults import *
26
27 if not psycopg2:
28     is8bit = re.compile("[\x80-\xff]").search
29
30     def unicast(typecast):
31         """
32         pgdb returns raw UTF-8 strings. This function casts strings that
33         appear to contain non-ASCII characters to unicode objects.
34         """
35     
36         def wrapper(*args, **kwds):
37             value = typecast(*args, **kwds)
38
39             # pgdb always encodes unicode objects as UTF-8 regardless of
40             # the DB encoding (and gives you no option for overriding
41             # the encoding), so always decode 8-bit objects as UTF-8.
42             if isinstance(value, str) and is8bit(value):
43                 value = unicode(value, "utf-8")
44
45             return value
46
47         return wrapper
48
49     pgdb.pgdbTypeCache.typecast = unicast(pgdb.pgdbTypeCache.typecast)
50
51 class PostgreSQL:
52     def __init__(self, api):
53         self.api = api
54         self.debug = False
55         self.connection = None
56
57     def cursor(self):
58         if self.connection is None:
59             # (Re)initialize database connection
60             if psycopg2:
61                 try:
62                     # Try UNIX socket first
63                     self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
64                                                        password = self.api.config.PLC_DB_PASSWORD,
65                                                        database = self.api.config.PLC_DB_NAME)
66                 except psycopg2.OperationalError:
67                     # Fall back on TCP
68                     self.connection = psycopg2.connect(user = self.api.config.PLC_DB_USER,
69                                                        password = self.api.config.PLC_DB_PASSWORD,
70                                                        database = self.api.config.PLC_DB_NAME,
71                                                        host = self.api.config.PLC_DB_HOST,
72                                                        port = self.api.config.PLC_DB_PORT)
73                 self.connection.set_client_encoding("UNICODE")
74             else:
75                 self.connection = pgdb.connect(user = self.api.config.PLC_DB_USER,
76                                                password = self.api.config.PLC_DB_PASSWORD,
77                                                host = "%s:%d" % (api.config.PLC_DB_HOST, api.config.PLC_DB_PORT),
78                                                database = self.api.config.PLC_DB_NAME)
79
80         (self.rowcount, self.description, self.lastrowid) = \
81                         (None, None, None)
82
83         return self.connection.cursor()
84
85     def close(self):
86         if self.connection is not None:
87             self.connection.close()
88             self.connection = None
89
90     def quote(self, value):
91         """
92         Returns quoted version of the specified value.
93         """
94
95         # The pgdb._quote function is good enough for general SQL
96         # quoting, except for array types.
97         if isinstance(value, (list, tuple, set)):
98             return "ARRAY[%s]" % ", ".join(map, self.quote, value)
99         else:
100             return pgdb._quote(value)
101
102     quote = classmethod(quote)
103
104     def param(self, name, value):
105         # None is converted to the unquoted string NULL
106         if isinstance(value, NoneType):
107             conversion = "s"
108         # True and False are also converted to unquoted strings
109         elif isinstance(value, bool):
110             conversion = "s"
111         elif isinstance(value, float):
112             conversion = "f"
113         elif not isinstance(value, StringTypes):
114             conversion = "d"
115         else:
116             conversion = "s"
117
118         return '%(' + name + ')' + conversion
119
120     param = classmethod(param)
121
122     def begin_work(self):
123         # Implicit in pgdb.connect()
124         pass
125
126     def commit(self):
127         self.connection.commit()
128
129     def rollback(self):
130         self.connection.rollback()
131
132     def do(self, query, params = None):
133         cursor = self.execute(query, params)
134         cursor.close()
135         return self.rowcount
136
137     def last_insert_id(self, table_name, primary_key):
138         if isinstance(self.lastrowid, int):
139             sql = "SELECT %s FROM %s WHERE oid = %d" % \
140                   (primary_key, table_name, self.lastrowid)
141             rows = self.selectall(sql, hashref = False)
142             if rows:
143                 return rows[0][0]
144
145         return None
146
147     def execute(self, query, params = None):
148         return self.execute_array(query, (params,))
149
150     def execute_array(self, query, param_seq):
151         cursor = self.cursor()
152         try:
153             if self.debug:
154                 for params in param_seq:
155                     if params:
156                         print >> log, query % params
157                     else:
158                         print >> log, query
159
160             # psycopg2 requires %()s format for all parameters,
161             # regardless of type.
162             if psycopg2:
163                 query = re.sub(r'(%\([^)]*\)|%)[df]', r'\1s', query)
164
165             cursor.executemany(query, param_seq)
166             (self.rowcount, self.description, self.lastrowid) = \
167                             (cursor.rowcount, cursor.description, cursor.lastrowid)
168         except Exception, e:
169             try:
170                 self.rollback()
171             except:
172                 pass
173             uuid = commands.getoutput("uuidgen")
174             print >> log, "Database error %s:" % uuid
175             print >> log, e
176             print >> log, "Query:"
177             print >> log, query
178             print >> log, "Params:"
179             print >> log, pformat(param_seq[0])
180             raise PLCDBError("Please contact " + \
181                              self.api.config.PLC_NAME + " Support " + \
182                              "<" + self.api.config.PLC_MAIL_SUPPORT_ADDRESS + ">" + \
183                              " and reference " + uuid)
184
185         return cursor
186
187     def selectall(self, query, params = None, hashref = True, key_field = None):
188         """
189         Return each row as a dictionary keyed on field name (like DBI
190         selectrow_hashref()). If key_field is specified, return rows
191         as a dictionary keyed on the specified field (like DBI
192         selectall_hashref()).
193
194         If params is specified, the specified parameters will be bound
195         to the query.
196         """
197
198         cursor = self.execute(query, params)
199         rows = cursor.fetchall()
200         cursor.close()
201
202         if hashref or key_field is not None:
203             # Return each row as a dictionary keyed on field name
204             # (like DBI selectrow_hashref()).
205             labels = [column[0] for column in self.description]
206             rows = [dict(zip(labels, row)) for row in rows]
207
208         if key_field is not None and key_field in labels:
209             # Return rows as a dictionary keyed on the specified field
210             # (like DBI selectall_hashref()).
211             return dict([(row[key_field], row) for row in rows])
212         else:
213             return rows
214
215     def fields(self, table, notnull = None, hasdef = None):
216         """
217         Return the names of the fields of the specified table.
218         """
219
220         if hasattr(self, 'fields_cache'):
221             if self.fields_cache.has_key((table, notnull, hasdef)):
222                 return self.fields_cache[(table, notnull, hasdef)]
223         else:
224             self.fields_cache = {}
225
226         sql = "SELECT attname FROM pg_attribute, pg_class" \
227               " WHERE pg_class.oid = attrelid" \
228               " AND attnum > 0 AND relname = %(table)s"
229
230         if notnull is not None:
231             sql += " AND attnotnull is %(notnull)s"
232
233         if hasdef is not None:
234             sql += " AND atthasdef is %(hasdef)s"
235
236         rows = self.selectall(sql, locals(), hashref = False)
237
238         self.fields_cache[(table, notnull, hasdef)] = [row[0] for row in rows]
239
240         return self.fields_cache[(table, notnull, hasdef)]