portal: added wip for PI validation page
[myslice.git] / manifold / core / query.py
index e3118e0..764336f 100644 (file)
 #   Thierry Parmentelat <thierry.parmentelat@inria.fr>
 
 from types                      import StringTypes
-from manifold.core.filter         import Filter, Predicate
-from manifold.util.frozendict     import frozendict
-from manifold.util.type           import returns, accepts
+from manifold.core.filter       import Filter, Predicate
+from manifold.util.frozendict   import frozendict
+from manifold.util.type         import returns, accepts
+from manifold.util.clause       import Clause
 import copy
 
 import json
@@ -90,13 +91,13 @@ class Query(object):
                 self.filters = kwargs["filters"]
                 del kwargs["filters"]
             else:
-                self.filters = Filter([])
+                self.filters = Filter()
 
             if "fields" in kwargs:
                 self.fields = set(kwargs["fields"])
                 del kwargs["fields"]
             else:
-                self.fields = set([])
+                self.fields = set()
 
             # "update table set x = 3" => params == set
             if "params" in kwargs:
@@ -116,17 +117,22 @@ class Query(object):
         #else:
         #        raise ParameterError, "No valid constructor found for %s : args = %r" % (self.__class__.__name__, args)
 
-        if not self.filters: self.filters = Filter([])
-        if not self.params:  self.params  = {}
-        if not self.fields:  self.fields  = set([])
-        if not self.timestamp:      self.timestamp      = "now" 
+        self.sanitize()
+
+    def sanitize(self):
+        if not self.filters:   self.filters   = Filter()
+        if not self.params:    self.params    = {}
+        if not self.fields:    self.fields    = set()
+        if not self.timestamp: self.timestamp = "now" 
 
         if isinstance(self.filters, list):
             f = self.filters
-            self.filters = Filter([])
+            self.filters = Filter()
             for x in f:
                 pred = Predicate(x)
                 self.filters.add(pred)
+        elif isinstance(self.filters, Clause):
+            self.filters = Filter.from_clause(self.filters)
 
         if isinstance(self.fields, list):
             self.fields = set(self.fields)
@@ -145,23 +151,39 @@ class Query(object):
     def clear(self):
         self.action = 'get'
         self.object = None
-        self.filters = Filter([])
+        self.filters = Filter()
         self.params  = {}
-        self.fields  = set([])
-        self.timestamp      = "now" 
+        self.fields  = set()
         self.timestamp  = 'now' # ignored for now
 
+    def to_sql(self, platform='', multiline=False):
+        get_params_str = lambda : ', '.join(['%s = %r' % (k, v) for k, v in self.get_params().items()])
+        get_select_str = lambda : ', '.join(self.get_select()) 
+
+        table  = self.get_from()
+        select = 'SELECT %s' % (get_select_str()    if self.get_select()    else '*')
+        where  = 'WHERE %s'  % self.get_where()     if self.get_where()     else ''
+        at     = 'AT %s'     % self.get_timestamp() if self.get_timestamp() else ''
+        params = 'SET %s'    % get_params_str()     if self.get_params()    else ''
+
+        sep = ' ' if not multiline else '\n  '
+        if platform: platform = "%s:" % platform
+        strmap = {
+            'get'   : '%(select)s%(sep)s%(at)s%(sep)sFROM %(platform)s%(table)s%(sep)s%(where)s%(sep)s',                                           
+            'update': 'UPDATE %(platform)s%(table)s%(sep)s%(params)s%(sep)s%(where)s%(sep)s%(select)s',       
+            'create': 'INSERT INTO %(platform)s%(table)s%(sep)s%(params)s%(sep)s%(select)s',
+            'delete': 'DELETE FROM %(platform)s%(table)s%(sep)s%(where)s'
+        }
+
+        return strmap[self.action] % locals()
+
     @returns(StringTypes)
     def __str__(self):
-        return "SELECT %s FROM %s WHERE %s" % (
-            ", ".join(self.get_select()) if self.get_select() else '*',
-            self.get_from(),
-            self.get_where()
-        )
+        return self.to_sql(multiline=True)
 
     @returns(StringTypes)
     def __repr__(self):
-        return self.__str__()
+        return self.to_sql()
 
     def __key(self):
         return (self.action, self.object, self.filters, frozendict(self.params), frozenset(self.fields))
@@ -178,7 +200,7 @@ class Query(object):
             'action': self.action,
             'object': self.object,
             'timestamp': self.timestamp,
-            'filters': self.filters,
+            'filters': self.filters.to_list(),
             'params': self.params,
             'fields': list(self.fields)
         }
@@ -219,6 +241,7 @@ class Query(object):
             if (debug):
                 import traceback
                 traceback.print_exc()
+        self.sanitize()
 
     #--------------------------------------------------------------------------- 
     # Accessors
@@ -263,30 +286,109 @@ class Query(object):
     #--------------------------------------------------------------------------- 
 
     @classmethod
+    #@returns(Query)
     def action(self, action, object):
+        """
+        (Internal usage). Craft a Query according to an action name 
+        See methods: get, update, delete, execute.
+        Args:
+            action: A String among {"get", "update", "delete", "execute"}
+            object: The name of the queried object (String)
+        Returns:
+            The corresponding Query instance
+        """
         query = Query()
         query.action = action
         query.object = object
         return query
 
     @classmethod
-    def get(self, object): return self.action('get', object)
+    #@returns(Query)
+    def get(self, object):
+        """
+        Craft the Query which fetches the records related to a given object
+        Args:
+            object: The name of the queried object (String)
+        Returns:
+            The corresponding Query instance
+        """
+        return self.action("get", object)
 
     @classmethod
-    def update(self, object): return self.action('update', object)
+    #@returns(Query)
+    def update(self, object):
+        """
+        Craft the Query which updates the records related to a given object
+        Args:
+            object: The name of the queried object (String)
+        Returns:
+            The corresponding Query instance
+        """
+        return self.action("update", object)
     
     @classmethod
-    def create(self, object): return self.action('create', object)
+    #@returns(Query)
+    def create(self, object):
+        """
+        Craft the Query which create the records related to a given object
+        Args:
+            object: The name of the queried object (String)
+        Returns:
+            The corresponding Query instance
+        """
+        return self.action("create", object)
     
     @classmethod
-    def delete(self, object): return self.action('delete', object)
+    #@returns(Query)
+    def delete(self, object):
+        """
+        Craft the Query which delete the records related to a given object
+        Args:
+            object: The name of the queried object (String)
+        Returns:
+            The corresponding Query instance
+        """
+        return self.action("delete", object)
     
     @classmethod
-    def execute(self, object): return self.action('execute', object)
+    #@returns(Query)
+    def execute(self, object):
+        """
+        Craft the Query which execute a processing related to a given object
+        Args:
+            object: The name of the queried object (String)
+        Returns:
+            The corresponding Query instance
+        """
+        return self.action("execute", object)
+
+    #@returns(Query)
+    def at(self, timestamp):
+        """
+        Set the timestamp carried by the query
+        Args:
+            timestamp: The timestamp (it may be a python timestamp, a string
+                respecting the "%Y-%m-%d %H:%M:%S" python format, or "now")
+        Returns:
+            The self Query instance
+        """
+        self.timestamp = timestamp
+        return self
 
     def filter_by(self, *args):
+        """
+        Args:
+            args: It may be:
+                - the parts of a Predicate (key, op, value)
+                - None
+                - a Filter instance
+                - a set/list/tuple of Predicate instances
+        """
         if len(args) == 1:
             filters = args[0]
+            if filters == None:
+                self.filters = Filter()
+                return self
             if not isinstance(filters, (set, list, tuple, Filter)):
                 filters = [filters]
             for predicate in filters:
@@ -298,9 +400,21 @@ class Query(object):
             raise Exception, 'Invalid expression for filter'
         return self
             
-    def select(self, fields):
-        if not isinstance(fields, (set, list, tuple)):
-            fields = [fields]
+    def select(self, *fields):
+
+        # Accept passing iterables
+        if len(fields) == 1:
+            tmp, = fields
+            if not tmp:
+                fields = None
+            elif isinstance(tmp, (list, tuple, set, frozenset)):
+                fields = tuple(tmp)
+
+        if not fields:
+            # Delete all fields
+            self.fields = set()
+            return self
+
         for field in fields:
             self.fields.add(field)
         return self
@@ -309,12 +423,38 @@ class Query(object):
         self.params.update(params)
         return self
 
+    def __or__(self, query):
+        assert self.action == query.action
+        assert self.object == query.object
+        assert self.timestamp == query.timestamp # XXX
+        filter = self.filters | query.filters
+        # fast dict union
+        # http://my.safaribooksonline.com/book/programming/python/0596007973/python-shortcuts/pythoncook2-chp-4-sect-17
+        params = dict(self.params, **query.params)
+        fields = self.fields | query.fields
+        return Query.action(self.action, self.object).filter_by(filter).select(fields)
+
+    def __and__(self, query):
+        assert self.action == query.action
+        assert self.object == query.object
+        assert self.timestamp == query.timestamp # XXX
+        filter = self.filters & query.filters
+        # fast dict intersection
+        # http://my.safaribooksonline.com/book/programming/python/0596007973/python-shortcuts/pythoncook2-chp-4-sect-17
+        params =  dict.fromkeys([x for x in self.params if x in query.params])
+        fields = self.fields & query.fields
+        return Query.action(self.action, self.object).filter_by(filter).select(fields)
+
+    def __le__(self, query):
+        return ( self == self & query ) or ( query == self | query )
+
 class AnalyzedQuery(Query):
 
     # XXX we might need to propagate special parameters sur as DEBUG, etc.
 
-    def __init__(self, query=None):
+    def __init__(self, query=None, metadata=None):
         self.clear()
+        self.metadata = metadata
         if query:
             self.query_uuid = query.query_uuid
             self.analyze(query)
@@ -324,8 +464,10 @@ class AnalyzedQuery(Query):
     @returns(StringTypes)
     def __str__(self):
         out = []
+        fields = self.get_select()
+        fields = ", ".join(fields) if fields else '*'
         out.append("SELECT %s FROM %s WHERE %s" % (
-            ", ".join(self.get_select()),
+            fields,
             self.get_from(),
             self.get_where()
         ))
@@ -343,35 +485,61 @@ class AnalyzedQuery(Query):
     def subquery(self, method):
         # Allows for the construction of a subquery
         if not method in self._subqueries:
-            analyzed_query = AnalyzedQuery()
+            analyzed_query = AnalyzedQuery(metadata=self.metadata)
             analyzed_query.action = self.action
-            analyzed_query.object = method
+            try:
+                type = self.metadata.get_field_type(self.object, method)
+            except ValueError ,e: # backwards 1..N
+                type = method
+            analyzed_query.object = type
             self._subqueries[method] = analyzed_query
         return self._subqueries[method]
 
+    def get_subquery(self, method):
+        return self._subqueries.get(method, None)
+
+    def remove_subquery(self, method):
+        del self._subqueries[method]
+
+    def get_subquery_names(self):
+        return set(self._subqueries.keys())
+
+    def get_subqueries(self):
+        return self._subqueries
+
     def subqueries(self):
         for method, subquery in self._subqueries.iteritems():
             yield (method, subquery)
 
     def filter_by(self, filters):
-        if not filters: return self
         if not isinstance(filters, (set, list, tuple, Filter)):
             filters = [filters]
         for predicate in filters:
-            if '.' in predicate.key:
-                method, subkey = pred.key.split('.', 1)
-                sub_pred = Predicate(subkey, pred.op, pred.value)
+            if predicate and '.' in predicate.key:
+                method, subkey = predicate.key.split('.', 1)
+                # Method contains the name of the subquery, we need the type
+                # XXX type = self.metadata.get_field_type(self.object, method)
+                sub_pred = Predicate(subkey, predicate.op, predicate.value)
                 self.subquery(method).filter_by(sub_pred)
             else:
                 super(AnalyzedQuery, self).filter_by(predicate)
         return self
 
-    def select(self, fields):
-        if not isinstance(fields, (set, list, tuple)):
-            fields = [fields]
+    def select(self, *fields):
+
+        # XXX passing None should reset fields in all subqueries
+
+        # Accept passing iterables
+        if len(fields) == 1:
+            tmp, = fields
+            if isinstance(tmp, (list, tuple, set, frozenset)):
+                fields = tuple(tmp)
+
         for field in fields:
-            if '.' in field:
+            if field and '.' in field:
                 method, subfield = field.split('.', 1)
+                # Method contains the name of the subquery, we need the type
+                # XXX type = self.metadata.get_field_type(self.object, method)
                 self.subquery(method).select(subfield)
             else:
                 super(AnalyzedQuery, self).select(field)
@@ -381,6 +549,8 @@ class AnalyzedQuery(Query):
         for param, value in self.params.items():
             if '.' in param:
                 method, subparam = param.split('.', 1)
+                # Method contains the name of the subquery, we need the type
+                # XXX type = self.metadata.get_field_type(self.object, method)
                 self.subquery(method).set({subparam: value})
             else:
                 super(AnalyzedQuery, self).set({param: value})