fixed hazelnut checkbox management
[myslice.git] / manifold / core / query.py
index 4aba14b..e45844c 100644 (file)
@@ -90,13 +90,13 @@ class Query(object):
                 self.filters = kwargs["filters"]
                 del kwargs["filters"]
             else:
-                self.filters = Filter([])
+                self.filters = Filter()
 
             if "fields" in kwargs:
                 self.fields = set(kwargs["fields"])
                 del kwargs["fields"]
             else:
-                self.fields = set([])
+                self.fields = set()
 
             # "update table set x = 3" => params == set
             if "params" in kwargs:
@@ -116,14 +116,14 @@ class Query(object):
         #else:
         #        raise ParameterError, "No valid constructor found for %s : args = %r" % (self.__class__.__name__, args)
 
-        if not self.filters: self.filters = Filter([])
+        if not self.filters: self.filters = Filter()
         if not self.params:  self.params  = {}
-        if not self.fields:  self.fields  = set([])
+        if not self.fields:  self.fields  = set()
         if not self.timestamp:      self.timestamp      = "now" 
 
         if isinstance(self.filters, list):
             f = self.filters
-            self.filters = Filter([])
+            self.filters = Filter()
             for x in f:
                 pred = Predicate(x)
                 self.filters.add(pred)
@@ -145,28 +145,41 @@ class Query(object):
     def clear(self):
         self.action = 'get'
         self.object = None
-        self.filters = Filter([])
+        self.filters = Filter()
         self.params  = {}
-        self.fields  = set([])
+        self.fields  = set()
         self.timestamp  = "now" 
         self.timestamp  = 'now' # ignored for now
 
+
+    def to_sql(self, platform='', multiline=False):
+        get_params_str = lambda : ', '.join(['%s = %r' % (k, v) for k, v in self.get_params().items()])
+        get_select_str = lambda : ', '.join(self.get_select()) 
+
+        table  = self.get_from()
+        select = 'SELECT %s' % (get_select_str()    if self.get_select()    else '*')
+        where  = 'WHERE %s'  % self.get_where()     if self.get_where()     else ''
+        at     = 'AT %s '    % self.get_timestamp() if self.get_timestamp() else ''
+        params = 'SET %s'    % get_params_str()     if self.get_params()    else ''
+
+        sep = ' ' if not multiline else '\n  '
+        if platform: platform = "%s:" % platform
+        strmap = {
+            'get'   : '%(select)s%(sep)s%(at)sFROM %(platform)s%(table)s%(sep)s%(where)s%(sep)s',                                           
+            'update': 'UPDATE %(platform)s%(table)s%(sep)s%(params)s%(sep)s%(where)s%(sep)s%(select)s',       
+            'create': 'INSERT INTO %(platform)s%(table)s%(sep)s%(params)s%(sep)s%(select)s',
+            'delete': 'DELETE FROM %(platform)s%(table)s%(sep)s%(where)s'
+        }
+
+        return strmap[self.action] % locals()
+
     @returns(StringTypes)
     def __str__(self):
-        return "SELECT %(select)s%(from)s%(where)s%(at)s" % {
-            "select": ", ".join(self.get_select())          if self.get_select()    else "*",
-            "from"  : "\n  FROM  %s" % self.get_from(),
-            "where" : "\n  WHERE %s" % self.get_where()     if self.get_where()     else "",
-            "at"    : "\n  AT    %s" % self.get_timestamp() if self.get_timestamp() else ""
-        }
+        return self.to_sql(multiline=True)
 
     @returns(StringTypes)
     def __repr__(self):
-        return "SELECT %s FROM %s WHERE %s" % (
-            ", ".join(self.get_select()) if self.get_select() else '*',
-            self.get_from(),
-            self.get_where()
-        )
+        return self.to_sql()
 
     def __key(self):
         return (self.action, self.object, self.filters, frozendict(self.params), frozenset(self.fields))
@@ -268,34 +281,101 @@ class Query(object):
     #--------------------------------------------------------------------------- 
 
     @classmethod
+    #@returns(Query)
     def action(self, action, object):
+        """
+        (Internal usage). Craft a Query according to an action name 
+        See methods: get, update, delete, execute.
+        Args:
+            action: A String among {"get", "update", "delete", "execute"}
+            object: The name of the queried object (String)
+        Returns:
+            The corresponding Query instance
+        """
         query = Query()
         query.action = action
         query.object = object
         return query
 
     @classmethod
-    def get(self, object): return self.action('get', object)
+    #@returns(Query)
+    def get(self, object):
+        """
+        Craft the Query which fetches the records related to a given object
+        Args:
+            object: The name of the queried object (String)
+        Returns:
+            The corresponding Query instance
+        """
+        return self.action("get", object)
 
     @classmethod
-    def update(self, object): return self.action('update', object)
+    #@returns(Query)
+    def update(self, object):
+        """
+        Craft the Query which updates the records related to a given object
+        Args:
+            object: The name of the queried object (String)
+        Returns:
+            The corresponding Query instance
+        """
+        return self.action("update", object)
     
     @classmethod
-    def create(self, object): return self.action('create', object)
+    #@returns(Query)
+    def create(self, object):
+        """
+        Craft the Query which create the records related to a given object
+        Args:
+            object: The name of the queried object (String)
+        Returns:
+            The corresponding Query instance
+        """
+        return self.action("create", object)
     
     @classmethod
-    def delete(self, object): return self.action('delete', object)
+    #@returns(Query)
+    def delete(self, object):
+        """
+        Craft the Query which delete the records related to a given object
+        Args:
+            object: The name of the queried object (String)
+        Returns:
+            The corresponding Query instance
+        """
+        return self.action("delete", object)
     
     @classmethod
-    def execute(self, object): return self.action('execute', object)
-
+    #@returns(Query)
+    def execute(self, object):
+        """
+        Craft the Query which execute a processing related to a given object
+        Args:
+            object: The name of the queried object (String)
+        Returns:
+            The corresponding Query instance
+        """
+        return self.action("execute", object)
+
+    #@returns(Query)
     def at(self, timestamp):
+        """
+        Set the timestamp carried by the query
+        Args:
+            timestamp: The timestamp (it may be a python timestamp, a string
+                respecting the "%Y-%m-%d %H:%M:%S" python format, or "now")
+        Returns:
+            The self Query instance
+        """
         self.timestamp = timestamp
         return self
 
     def filter_by(self, *args):
         if len(args) == 1:
             filters = args[0]
+            if filters == None:
+                self.filters = Filter()
+                return self
             if not isinstance(filters, (set, list, tuple, Filter)):
                 filters = [filters]
             for predicate in filters:
@@ -307,13 +387,18 @@ class Query(object):
             raise Exception, 'Invalid expression for filter'
         return self
             
-    def select(self, fields=None):
+    def select(self, *fields):
         if not fields:
             # Delete all fields
             self.fields = set()
             return self
-        if not isinstance(fields, (set, list, tuple)):
-            fields = [fields]
+
+        # Accept passing iterables
+        if len(fields) == 1:
+            tmp, = fields
+            if isinstance(tmp, (list, tuple, set, frozenset)):
+                fields = tuple(tmp)
+
         for field in fields:
             self.fields.add(field)
         return self
@@ -322,6 +407,31 @@ class Query(object):
         self.params.update(params)
         return self
 
+    def __or__(self, query):
+        assert self.action == query.action
+        assert self.object == query.object
+        assert self.timestamp == query.timestamp # XXX
+        filter = self.filters | query.filters
+        # fast dict union
+        # http://my.safaribooksonline.com/book/programming/python/0596007973/python-shortcuts/pythoncook2-chp-4-sect-17
+        params = dict(self.params, **query.params)
+        fields = self.fields | query.fields
+        return Query.action(self.action, self.object).filter_by(filter).select(fields)
+
+    def __and__(self, query):
+        assert self.action == query.action
+        assert self.object == query.object
+        assert self.timestamp == query.timestamp # XXX
+        filter = self.filters & query.filters
+        # fast dict intersection
+        # http://my.safaribooksonline.com/book/programming/python/0596007973/python-shortcuts/pythoncook2-chp-4-sect-17
+        params =  dict.fromkeys([x for x in self.params if x in query.params])
+        fields = self.fields & query.fields
+        return Query.action(self.action, self.object).filter_by(filter).select(fields)
+
+    def __le__(self, query):
+        return ( self == self & query ) or ( query == self | query )
+
 class AnalyzedQuery(Query):
 
     # XXX we might need to propagate special parameters sur as DEBUG, etc.
@@ -361,12 +471,9 @@ class AnalyzedQuery(Query):
         if not method in self._subqueries:
             analyzed_query = AnalyzedQuery(metadata=self.metadata)
             analyzed_query.action = self.action
-            if self.metadata:
-                try:
-                    type = self.metadata.get_field_type(self.object, method)
-                except ValueError ,e: # backwards 1..N
-                    type = method
-            else:
+            try:
+                type = self.metadata.get_field_type(self.object, method)
+            except ValueError ,e: # backwards 1..N
                 type = method
             analyzed_query.object = type
             self._subqueries[method] = analyzed_query
@@ -381,30 +488,39 @@ class AnalyzedQuery(Query):
     def get_subquery_names(self):
         return set(self._subqueries.keys())
 
+    def get_subqueries(self):
+        return self._subqueries
+
     def subqueries(self):
         for method, subquery in self._subqueries.iteritems():
             yield (method, subquery)
 
     def filter_by(self, filters):
-        if not filters: return self
         if not isinstance(filters, (set, list, tuple, Filter)):
             filters = [filters]
         for predicate in filters:
-            if '.' in predicate.key:
-                method, subkey = pred.key.split('.', 1)
+            if predicate and '.' in predicate.key:
+                method, subkey = predicate.key.split('.', 1)
                 # Method contains the name of the subquery, we need the type
                 # XXX type = self.metadata.get_field_type(self.object, method)
-                sub_pred = Predicate(subkey, pred.op, pred.value)
+                sub_pred = Predicate(subkey, predicate.op, predicate.value)
                 self.subquery(method).filter_by(sub_pred)
             else:
                 super(AnalyzedQuery, self).filter_by(predicate)
         return self
 
-    def select(self, fields):
-        if not isinstance(fields, (set, list, tuple)):
-            fields = [fields]
+    def select(self, *fields):
+
+        # XXX passing None should reset fields in all subqueries
+
+        # Accept passing iterables
+        if len(fields) == 1:
+            tmp, = fields
+            if isinstance(tmp, (list, tuple, set, frozenset)):
+                fields = tuple(tmp)
+
         for field in fields:
-            if '.' in field:
+            if field and '.' in field:
                 method, subfield = field.split('.', 1)
                 # Method contains the name of the subquery, we need the type
                 # XXX type = self.metadata.get_field_type(self.object, method)