portal: added wip for PI validation page
[myslice.git] / manifold / core / query.py
index e45844c..764336f 100644 (file)
 #   Thierry Parmentelat <thierry.parmentelat@inria.fr>
 
 from types                      import StringTypes
-from manifold.core.filter         import Filter, Predicate
-from manifold.util.frozendict     import frozendict
-from manifold.util.type           import returns, accepts
+from manifold.core.filter       import Filter, Predicate
+from manifold.util.frozendict   import frozendict
+from manifold.util.type         import returns, accepts
+from manifold.util.clause       import Clause
 import copy
 
 import json
@@ -116,10 +117,13 @@ class Query(object):
         #else:
         #        raise ParameterError, "No valid constructor found for %s : args = %r" % (self.__class__.__name__, args)
 
-        if not self.filters: self.filters = Filter()
-        if not self.params:  self.params  = {}
-        if not self.fields:  self.fields  = set()
-        if not self.timestamp:      self.timestamp      = "now" 
+        self.sanitize()
+
+    def sanitize(self):
+        if not self.filters:   self.filters   = Filter()
+        if not self.params:    self.params    = {}
+        if not self.fields:    self.fields    = set()
+        if not self.timestamp: self.timestamp = "now" 
 
         if isinstance(self.filters, list):
             f = self.filters
@@ -127,6 +131,8 @@ class Query(object):
             for x in f:
                 pred = Predicate(x)
                 self.filters.add(pred)
+        elif isinstance(self.filters, Clause):
+            self.filters = Filter.from_clause(self.filters)
 
         if isinstance(self.fields, list):
             self.fields = set(self.fields)
@@ -148,10 +154,8 @@ class Query(object):
         self.filters = Filter()
         self.params  = {}
         self.fields  = set()
-        self.timestamp  = "now" 
         self.timestamp  = 'now' # ignored for now
 
-
     def to_sql(self, platform='', multiline=False):
         get_params_str = lambda : ', '.join(['%s = %r' % (k, v) for k, v in self.get_params().items()])
         get_select_str = lambda : ', '.join(self.get_select()) 
@@ -159,13 +163,13 @@ class Query(object):
         table  = self.get_from()
         select = 'SELECT %s' % (get_select_str()    if self.get_select()    else '*')
         where  = 'WHERE %s'  % self.get_where()     if self.get_where()     else ''
-        at     = 'AT %s '    % self.get_timestamp() if self.get_timestamp() else ''
+        at     = 'AT %s    % self.get_timestamp() if self.get_timestamp() else ''
         params = 'SET %s'    % get_params_str()     if self.get_params()    else ''
 
         sep = ' ' if not multiline else '\n  '
         if platform: platform = "%s:" % platform
         strmap = {
-            'get'   : '%(select)s%(sep)s%(at)sFROM %(platform)s%(table)s%(sep)s%(where)s%(sep)s',                                           
+            'get'   : '%(select)s%(sep)s%(at)s%(sep)sFROM %(platform)s%(table)s%(sep)s%(where)s%(sep)s',                                           
             'update': 'UPDATE %(platform)s%(table)s%(sep)s%(params)s%(sep)s%(where)s%(sep)s%(select)s',       
             'create': 'INSERT INTO %(platform)s%(table)s%(sep)s%(params)s%(sep)s%(select)s',
             'delete': 'DELETE FROM %(platform)s%(table)s%(sep)s%(where)s'
@@ -196,7 +200,7 @@ class Query(object):
             'action': self.action,
             'object': self.object,
             'timestamp': self.timestamp,
-            'filters': self.filters,
+            'filters': self.filters.to_list(),
             'params': self.params,
             'fields': list(self.fields)
         }
@@ -237,6 +241,7 @@ class Query(object):
             if (debug):
                 import traceback
                 traceback.print_exc()
+        self.sanitize()
 
     #--------------------------------------------------------------------------- 
     # Accessors
@@ -371,6 +376,14 @@ class Query(object):
         return self
 
     def filter_by(self, *args):
+        """
+        Args:
+            args: It may be:
+                - the parts of a Predicate (key, op, value)
+                - None
+                - a Filter instance
+                - a set/list/tuple of Predicate instances
+        """
         if len(args) == 1:
             filters = args[0]
             if filters == None:
@@ -388,17 +401,20 @@ class Query(object):
         return self
             
     def select(self, *fields):
-        if not fields:
-            # Delete all fields
-            self.fields = set()
-            return self
 
         # Accept passing iterables
         if len(fields) == 1:
             tmp, = fields
-            if isinstance(tmp, (list, tuple, set, frozenset)):
+            if not tmp:
+                fields = None
+            elif isinstance(tmp, (list, tuple, set, frozenset)):
                 fields = tuple(tmp)
 
+        if not fields:
+            # Delete all fields
+            self.fields = set()
+            return self
+
         for field in fields:
             self.fields.add(field)
         return self