Merge remote-tracking branch 'origin/pycurl' into planetlab-4_0-branch
[plcapi.git] / trunk / psycopg2 / lib / extras.py
1 """Miscellaneous goodies for psycopg2
2
3 This module is a generic place used to hold little helper functions
4 and classes untill a better place in the distribution is found.
5 """
6 # psycopg/extras.py - miscellaneous extra goodies for psycopg
7 #
8 # Copyright (C) 2003-2004 Federico Di Gregorio  <fog@debian.org>
9 #
10 # This program is free software; you can redistribute it and/or modify
11 # it under the terms of the GNU General Public License as published by the
12 # Free Software Foundation; either version 2, or (at your option) any later
13 # version.
14 #
15 # This program is distributed in the hope that it will be useful, but
16 # WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTIBILITY
17 # or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
18 # for more details.
19
20 import os
21 import time
22
23 try:
24     import logging
25 except:
26     logging = None
27     
28 from psycopg2.extensions import cursor as _cursor
29 from psycopg2.extensions import connection as _connection
30 from psycopg2.extensions import register_adapter as _RA
31 from psycopg2.extensions import adapt as _A
32
33
34 class DictConnection(_connection):
35     """A connection that uses DictCursor automatically."""
36     def cursor(self):
37         return _connection.cursor(self, cursor_factory=DictCursor)
38
39 class DictCursor(_cursor):
40     """A cursor that keeps a list of column name -> index mappings."""
41
42     __query_executed = 0
43     
44     def execute(self, query, vars=None, async=0):
45         self.row_factory = DictRow
46         self.index = {}
47         self.__query_executed = 1
48         return _cursor.execute(self, query, vars, async)
49     
50     def callproc(self, procname, vars=None):
51         self.row_factory = DictRow
52         self.index = {}
53         self.__query_executed = 1
54         return _cursor.callproc(self, procname, vars)   
55
56     def _build_index(self):
57         if self.__query_executed == 1 and self.description:
58             for i in range(len(self.description)):
59                 self.index[self.description[i][0]] = i
60             self.__query_executed = 0
61             
62     def fetchone(self):
63         res = _cursor.fetchone(self)
64         if self.__query_executed:
65             self._build_index()
66         return res
67
68     def fetchmany(self, size=None):
69         res = _cursor.fetchmany(self, size)
70         if self.__query_executed:
71             self._build_index()
72         return res
73
74     def fetchall(self):
75         res = _cursor.fetchall(self)
76         if self.__query_executed:
77             self._build_index()
78         return res
79     
80     def next(self):
81         res = _cursor.fetchone(self)
82         if res is None:
83             raise StopIteration()
84         if self.__query_executed:
85             self._build_index()
86         return res
87
88 class DictRow(list):
89     """A row object that allow by-colun-name access to data."""
90
91     def __init__(self, cursor):
92         self._index = cursor.index
93         self[:] = [None] * len(cursor.description)
94
95     def __getitem__(self, x):
96         if type(x) != int:
97             x = self._index[x]
98         return list.__getitem__(self, x)
99
100     def items(self):
101         res = []
102         for n, v in self._index.items():
103             res.append((n, list.__getitem__(self, v)))
104         return res
105     
106     def keys(self):
107         return self._index.keys()
108
109     def values(self):
110         return tuple(self[:])
111
112     def has_key(self, x):
113         return self._index.has_key(x)
114
115     def get(self, x, default=None):
116         try:
117             return self[x]
118         except:
119             return default
120
121
122 class SQL_IN(object):
123     """Adapt any iterable to an SQL quotable object."""
124     
125     def __init__(self, seq):
126         self._seq = seq
127
128     def prepare(self, conn):
129         self._conn = conn
130     
131     def getquoted(self):
132         # this is the important line: note how every object in the
133         # list is adapted and then how getquoted() is called on it
134         pobjs = [_A(o) for o in self._seq]
135         for obj in pobjs:
136             if hasattr(obj, 'prepare'):
137                 obj.prepare(self._conn)
138         qobjs = [str(o.getquoted()) for o in pobjs]
139         return '(' + ', '.join(qobjs) + ')'
140
141     __str__ = getquoted
142     
143 _RA(tuple, SQL_IN)
144
145     
146 class LoggingConnection(_connection):
147     """A connection that logs all queries to a file or logger object."""
148
149     def initialize(self, logobj):
150         """Initialize the connection to log to `logobj`.
151         
152         The `logobj` parameter can be an open file object or a Logger instance
153         from the standard logging module.
154         """
155         self._logobj = logobj
156         if logging and isinstance(logobj, logging.Logger):
157             self.log = self._logtologger
158         else:
159             self.log = self._logtofile
160     
161     def filter(self, msg, curs):
162         """Filter the query before logging it.
163         
164         This is the method to overwrite to filter unwanted queries out of the
165         log or to add some extra data to the output. The default implementation
166         just does nothing.
167         """
168         return msg
169     
170     def _logtofile(self, msg, curs):
171         msg = self.filter(msg, curs)
172         if msg: self._logobj.write(msg + os.linesep)
173         
174     def _logtologger(self, msg, curs):
175         msg = self.filter(msg, curs)
176         if msg: self._logobj.debug(msg)
177     
178     def _check(self):
179         if not hasattr(self, '_logobj'):
180             raise self.ProgrammingError(
181                 "LoggingConnection object has not been initialize()d")
182             
183     def cursor(self):
184         self._check()
185         return _connection.cursor(self, cursor_factory=LoggingCursor)
186     
187 class LoggingCursor(_cursor):
188     """A cursor that logs queries using its connection logging facilities."""
189
190     def execute(self, query, vars=None, async=0):
191         try:
192             return _cursor.execute(self, query, vars, async)
193         finally:
194             self.connection.log(self.query, self)
195
196     def callproc(self, procname, vars=None):
197         try:
198             return _cursor.callproc(self, procname, vars)  
199         finally:
200             self.connection.log(self.query, self)
201
202             
203 class MinTimeLoggingConnection(LoggingConnection):
204     """A connection that logs queries based on execution time.
205     
206     This is just an example of how to sub-class LoggingConnection to provide
207     some extra filtering for the logged queries. Both the `.inizialize()` and
208     `.filter()` methods are overwritten to make sure that only queries
209     executing for more than `mintime` ms are logged.
210     
211     Note that this connection uses the specialized cursor MinTimeLoggingCursor.
212     """
213     def initialize(self, logobj, mintime=0):
214         LoggingConnection.initialize(self, logobj)
215         self._mintime = mintime
216         
217     def filter(self, msg, curs):
218         t = (time.time() - curs.timestamp) * 1000
219         if t > self._mintime:
220             return msg + os.linesep + "  (execution time: %d ms)" % t
221
222     def cursor(self):
223         self._check()
224         return _connection.cursor(self, cursor_factory=MinTimeLoggingCursor)
225     
226 class MinTimeLoggingCursor(LoggingCursor):
227     """The cursor sub-class companion to MinTimeLoggingConnection."""
228
229     def execute(self, query, vars=None, async=0):
230         self.timestamp = time.time()
231         return LoggingCursor.execute(self, query, vars, async)
232     
233     def callproc(self, procname, vars=None):
234         self.timestamp = time.time()
235         return LoggingCursor.execute(self, procname, var)