X-Git-Url: http://git.onelab.eu/?p=myslice.git;a=blobdiff_plain;f=manifold%2Fcore%2Fquery.py;h=764336fa77518a1fcf46db67d351682c549124c3;hp=e45844c9d8b9d63ee85a7acf03b8dd4b3b9e15aa;hb=d68dcefd28c832608cdb359a07a8b871cbe612ae;hpb=ac2dda758798f7e44de062d370763c639cc6a375 diff --git a/manifold/core/query.py b/manifold/core/query.py index e45844c9..764336fa 100644 --- a/manifold/core/query.py +++ b/manifold/core/query.py @@ -10,9 +10,10 @@ # Thierry Parmentelat from types import StringTypes -from manifold.core.filter import Filter, Predicate -from manifold.util.frozendict import frozendict -from manifold.util.type import returns, accepts +from manifold.core.filter import Filter, Predicate +from manifold.util.frozendict import frozendict +from manifold.util.type import returns, accepts +from manifold.util.clause import Clause import copy import json @@ -116,10 +117,13 @@ class Query(object): #else: # raise ParameterError, "No valid constructor found for %s : args = %r" % (self.__class__.__name__, args) - if not self.filters: self.filters = Filter() - if not self.params: self.params = {} - if not self.fields: self.fields = set() - if not self.timestamp: self.timestamp = "now" + self.sanitize() + + def sanitize(self): + if not self.filters: self.filters = Filter() + if not self.params: self.params = {} + if not self.fields: self.fields = set() + if not self.timestamp: self.timestamp = "now" if isinstance(self.filters, list): f = self.filters @@ -127,6 +131,8 @@ class Query(object): for x in f: pred = Predicate(x) self.filters.add(pred) + elif isinstance(self.filters, Clause): + self.filters = Filter.from_clause(self.filters) if isinstance(self.fields, list): self.fields = set(self.fields) @@ -148,10 +154,8 @@ class Query(object): self.filters = Filter() self.params = {} self.fields = set() - self.timestamp = "now" self.timestamp = 'now' # ignored for now - def to_sql(self, platform='', multiline=False): get_params_str = lambda : ', '.join(['%s = %r' % (k, v) for k, v in self.get_params().items()]) get_select_str = lambda : ', '.join(self.get_select()) @@ -159,13 +163,13 @@ class Query(object): table = self.get_from() select = 'SELECT %s' % (get_select_str() if self.get_select() else '*') where = 'WHERE %s' % self.get_where() if self.get_where() else '' - at = 'AT %s ' % self.get_timestamp() if self.get_timestamp() else '' + at = 'AT %s' % self.get_timestamp() if self.get_timestamp() else '' params = 'SET %s' % get_params_str() if self.get_params() else '' sep = ' ' if not multiline else '\n ' if platform: platform = "%s:" % platform strmap = { - 'get' : '%(select)s%(sep)s%(at)sFROM %(platform)s%(table)s%(sep)s%(where)s%(sep)s', + 'get' : '%(select)s%(sep)s%(at)s%(sep)sFROM %(platform)s%(table)s%(sep)s%(where)s%(sep)s', 'update': 'UPDATE %(platform)s%(table)s%(sep)s%(params)s%(sep)s%(where)s%(sep)s%(select)s', 'create': 'INSERT INTO %(platform)s%(table)s%(sep)s%(params)s%(sep)s%(select)s', 'delete': 'DELETE FROM %(platform)s%(table)s%(sep)s%(where)s' @@ -196,7 +200,7 @@ class Query(object): 'action': self.action, 'object': self.object, 'timestamp': self.timestamp, - 'filters': self.filters, + 'filters': self.filters.to_list(), 'params': self.params, 'fields': list(self.fields) } @@ -237,6 +241,7 @@ class Query(object): if (debug): import traceback traceback.print_exc() + self.sanitize() #--------------------------------------------------------------------------- # Accessors @@ -371,6 +376,14 @@ class Query(object): return self def filter_by(self, *args): + """ + Args: + args: It may be: + - the parts of a Predicate (key, op, value) + - None + - a Filter instance + - a set/list/tuple of Predicate instances + """ if len(args) == 1: filters = args[0] if filters == None: @@ -388,17 +401,20 @@ class Query(object): return self def select(self, *fields): - if not fields: - # Delete all fields - self.fields = set() - return self # Accept passing iterables if len(fields) == 1: tmp, = fields - if isinstance(tmp, (list, tuple, set, frozenset)): + if not tmp: + fields = None + elif isinstance(tmp, (list, tuple, set, frozenset)): fields = tuple(tmp) + if not fields: + # Delete all fields + self.fields = set() + return self + for field in fields: self.fields.add(field) return self