small fix
[unfold.git] / manifold / core / query.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 #
4 # Query representation
5 #
6 # Copyright (C) UPMC Paris Universitas
7 # Authors:
8 #   Jordan Augé         <jordan.auge@lip6.fr>
9 #   Marc-Olivier Buob   <marc-olivier.buob@lip6.fr>
10 #   Thierry Parmentelat <thierry.parmentelat@inria.fr>
11
12 from types                      import StringTypes
13 from manifold.core.filter         import Filter, Predicate
14 from manifold.util.frozendict     import frozendict
15 from manifold.util.type           import returns, accepts
16 import copy
17
18 import json
19 import uuid
20
def uniqid():
    """Return a fresh random identifier: the hex form of a random UUID (uuid4)."""
    token = uuid.uuid4()
    return token.hex
23
# Module-wide debug switch. When true, Query.to_json() / AnalyzedQuery.to_json()
# print the generated JavaScript, and Query.fill_from_POST() prints a traceback
# when the incoming request cannot be decoded.
# (The former `debug=False` line was a dead store, immediately overwritten.)
debug = True
26
class ParameterError(StandardError):
    """Raised when a Query is constructed with invalid or unexpected parameters."""
28
class Query(object):
    """
    Implements a TopHat query.

    We assume this is a correct DAG specification.

    1/ A field designates several tables = OR specification.
    2/ The set of fields specifies an AND between OR clauses.

    Accepted constructor forms:

        Query(object, filters, fields)
        Query(object, filters, params, fields)
        Query(action, object, filters, params, fields, timestamp)
        Query({"object": ..., "action": ..., "filters": ..., ...})
        Query(object = ..., action = ..., filters = ..., ...)

    Passing 2 or 5 positional arguments raises ParameterError; any other
    call (e.g. keyword arguments without "object") keeps the defaults
    installed by clear(). Queries can also be built with the LINQ-like
    helpers, e.g. Query.get("resource").select("hrn").
    """

    #---------------------------------------------------------------------------
    # Constructor
    #---------------------------------------------------------------------------

    def __init__(self, *args, **kwargs):
        """
        Build a query; see the class docstring for the accepted forms.

        Filters given as a list/tuple of predicate specifications are turned
        into a Filter; fields given as a list/tuple/frozenset become a set.

        Raises:
            ParameterError: unknown keys were passed in the dict / keyword
                form, or 2 or 5 positional arguments were given.
            TypeError: a field name is not a string.
        """
        self.query_uuid = uniqid()

        # Initialize optional parameters
        self.clear()

        len_args = len(args)

        if len_args == 1 and isinstance(args[0], dict):
            # Work on a copy: recognized keys are consumed (popped) below and
            # the caller's dict must be left untouched.
            kwargs = dict(args[0])
            args = ()

        if len_args == 3:
            # Initialization from a tuple: (object, filters, fields)
            self.action    = 'get'
            self.params    = {}
            self.timestamp = 'now'
            self.object, self.filters, self.fields = args
        elif len_args == 4:
            # (object, filters, params, fields)
            self.object, self.filters, self.params, self.fields = args
            self.action    = 'get'
            self.timestamp = 'now'
        elif len_args == 6:
            # (action, object, filters, params, fields, timestamp)
            (self.action, self.object, self.filters,
             self.params, self.fields, self.timestamp) = args
        elif len_args in (2, 5):
            raise ParameterError(
                "Invalid number of positional arguments: %d (expected 3, 4 or 6)" % len_args)
        elif "object" in kwargs:
            # Initialization from a dict / keyword arguments
            if "action" in kwargs:
                self.action = kwargs.pop("action")
            else:
                print("W: defaulting to get action")
                self.action = "get"

            self.object    = kwargs.pop("object")
            self.filters   = kwargs.pop("filters", Filter([]))
            self.fields    = set(kwargs.pop("fields", []))
            # "update table set x = 3" => params == set
            self.params    = kwargs.pop("params", {})
            self.timestamp = kwargs.pop("timestamp", "now")

            if kwargs:
                raise ParameterError("Invalid parameter(s) : %r" % (list(kwargs.keys()),))
        # Otherwise no constructor form matched: the defaults from clear() remain.

        # Replace missing / empty values by their defaults
        if not self.filters:   self.filters   = Filter([])
        if not self.params:    self.params    = {}
        if not self.fields:    self.fields    = set([])
        if not self.timestamp: self.timestamp = "now"

        # Filters given as a sequence of predicate specifications
        if isinstance(self.filters, (list, tuple)):
            raw_filters = self.filters
            self.filters = Filter([])
            for x in raw_filters:
                self.filters.add(Predicate(x))

        # select() adds to self.fields, so it must be a mutable set
        if isinstance(self.fields, (list, tuple, frozenset)):
            self.fields = set(self.fields)

        for field in self.fields:
            if not isinstance(field, StringTypes):
                raise TypeError("Invalid field name %s (string expected, got %s)" % (field, type(field)))

    #---------------------------------------------------------------------------
    # Helpers
    #---------------------------------------------------------------------------

    def copy(self):
        """Return a deep copy of this query (query_uuid included)."""
        return copy.deepcopy(self)

    def clear(self):
        """Reset action, object, filters, params, fields and timestamp to their defaults."""
        self.action    = 'get'
        self.object    = None
        self.filters   = Filter([])
        self.params    = {}
        self.fields    = set([])
        self.timestamp = 'now' # ignored for now

    @returns(StringTypes)
    def __str__(self):
        """Render the query as SQL-like text; '*' stands for an empty field set."""
        fields = self.get_select()
        return "SELECT %s FROM %s WHERE %s" % (
            ", ".join(fields) if fields else '*',
            self.get_from(),
            self.get_where()
        )

    @returns(StringTypes)
    def __repr__(self):
        return self.__str__()

    def __key(self):
        """
        Tuple of the attributes hashed by __hash__ (query_uuid and timestamp
        are not part of it).
        """
        return (self.action, self.object, self.filters, frozendict(self.params), frozenset(self.fields))

    def __hash__(self):
        # NOTE(review): no __eq__ is defined, so equality stays identity-based;
        # and since the key is built from mutable attributes, the hash changes
        # whenever filter_by()/select()/set() modify the query.
        return hash(self.__key())

    #---------------------------------------------------------------------------
    # Conversion
    #---------------------------------------------------------------------------

    def to_dict(self):
        """
        Return the query attributes as a plain dict. `filters` is returned
        as-is, `fields` as a list; query_uuid is not included.
        """
        return {
            'action': self.action,
            'object': self.object,
            'timestamp': self.timestamp,
            'filters': self.filters,
            'params': self.params,
            'fields': list(self.fields)
        }

    @staticmethod
    def _to_js(value, as_string=False):
        """
        Encode `value` as a JavaScript literal through JSON.

        With as_string=True the value is first rendered with '%s' (as the
        former '%(x)s' templates did). '</' is escaped so that user-supplied
        values cannot terminate an enclosing <script> element, and JSON
        quoting prevents them from breaking out of the string literal.
        """
        if as_string:
            value = '%s' % (value,)
        return json.dumps(value).replace('</', '<\\/')

    def to_json (self, analyzed_query=None):
        """
        Serialize this query as a JavaScript `new ManifoldQuery(...)`
        expression (see manifold.js for the receiving side).

        Args:
            analyzed_query: optional object whose to_json() output is
                embedded as the analyzed-query argument; 'null' otherwise.

        Returns:
            The JavaScript source, as a string.
        """
        query_uuid = self._to_js(self.query_uuid, as_string=True)
        a = self._to_js(self.action, as_string=True)
        o = self._to_js(self.object, as_string=True)
        t = self._to_js(self.timestamp, as_string=True)
        f = self._to_js(self.filters.to_list())
        p = self._to_js(self.params)
        c = self._to_js(list(self.fields))
        # xxx unique can be removed, but for now we pad the js structure
        unique = 0
        aq = analyzed_query.to_json() if analyzed_query else 'null'
        # plain queries carry no subqueries
        sq = "{}"

        result = " new ManifoldQuery(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)" % (
            a, o, t, f, p, c, unique, query_uuid, aq, sq)
        if debug: print('ManifoldQuery.to_json: %s' % result)
        return result

    # this builds a ManifoldQuery object from a dict as received from javascript through its ajax request
    # we use a json-encoded string - see manifold.js for the sender part
    # e.g. here's what I captured from the server's output
    # manifoldproxy.proxy: request.POST <QueryDict: {u'json': [u'{"action":"get","object":"resource","timestamp":"latest","filters":[["slice_hrn","=","ple.inria.omftest"]],"params":[],"fields":["hrn","hostname"],"unique":0,"query_uuid":"436aae70a48141cc826f88e08fbd74b1","analyzed_query":null,"subqueries":{}}']}>
    def fill_from_POST (self, POST_dict):
        """
        Set attributes from the JSON document found under POST_dict['json'].

        Errors are reported on stdout (with a traceback when `debug` is set)
        instead of being raised.

        NOTE(review): every decoded key is set as an attribute without
        validation (request data can overwrite any attribute), and values are
        stored raw: filters/fields stay plain lists, not Filter/set as
        __init__ produces.
        """
        try:
            json_string = POST_dict['json']
            decoded = json.loads(json_string)
            for (key, value) in decoded.items():
                setattr(self, key, value)
        except Exception:
            print("Could not decode incoming ajax request as a Query, POST= %s" % (POST_dict,))
            if debug:
                import traceback
                traceback.print_exc()

    #---------------------------------------------------------------------------
    # Accessors
    #---------------------------------------------------------------------------

    @returns(StringTypes)
    def get_action(self):
        return self.action

    @returns(frozenset)
    def get_select(self):
        """Return the selected fields as an immutable snapshot."""
        return frozenset(self.fields)

    @returns(StringTypes)
    def get_from(self):
        return self.object

    @returns(Filter)
    def get_where(self):
        return self.filters

    @returns(dict)
    def get_params(self):
        return self.params

    @returns(StringTypes)
    def get_timestamp(self):
        return self.timestamp

    #---------------------------------------------------------------------------
    # LINQ-like syntax
    #---------------------------------------------------------------------------

    @classmethod
    def action(cls, action, object):
        """
        Create a query performing `action` on `object`.

        NOTE(review): always instantiates Query, even when called through a
        subclass (e.g. AnalyzedQuery.get(...) returns a plain Query).
        """
        query = Query()
        query.action = action
        query.object = object
        return query

    @classmethod
    def get(cls, object): return cls.action('get', object)

    @classmethod
    def update(cls, object): return cls.action('update', object)

    @classmethod
    def create(cls, object): return cls.action('create', object)

    @classmethod
    def delete(cls, object): return cls.action('delete', object)

    @classmethod
    def execute(cls, object): return cls.action('execute', object)

    def filter_by(self, *args):
        """
        Add predicates to the query's filters and return self (chainable).

        Accepts either one argument (a predicate, or a set/list/tuple/Filter
        of predicates) or three arguments passed to Predicate(key, op, value).

        Raises:
            Exception: when given a number of arguments other than 1 or 3.
        """
        if len(args) == 1:
            filters = args[0]
            if not isinstance(filters, (set, list, tuple, Filter)):
                filters = [filters]
            for predicate in filters:
                self.filters.add(predicate)
        elif len(args) == 3:
            self.filters.add(Predicate(*args))
        else:
            raise Exception('Invalid expression for filter')
        return self

    def select(self, fields):
        """Add one field or a set/list/tuple of fields; return self (chainable)."""
        if not isinstance(fields, (set, list, tuple)):
            fields = [fields]
        for field in fields:
            self.fields.add(field)
        return self

    def set(self, params):
        """Merge the `params` mapping into the query parameters; return self (chainable)."""
        self.params.update(params)
        return self
310         return self
311
class AnalyzedQuery(Query):
    """
    Query in which every dotted name ("method.subname") found among the
    fields, filter keys and parameters is moved into a subquery on `method`
    (created on demand by subquery()); undotted names stay on this query.
    """

    # XXX we might need to propagate special parameters such as DEBUG, etc.

    def __init__(self, query=None):
        """
        Build an empty analyzed query, or the analysis of `query` (reusing
        its query_uuid) when one is given.
        """
        # Query.__init__ is not called: clear() installs the defaults and
        # query_uuid is assigned below.
        self.clear()
        if query:
            self.query_uuid = query.query_uuid
            self.analyze(query)
        else:
            self.query_uuid = uniqid()

    @returns(StringTypes)
    def __str__(self):
        """Render the query as SQL-like text, followed by one line per subquery."""
        fields = self.get_select()
        out = ["SELECT %s FROM %s WHERE %s" % (
            ", ".join(fields) if fields else '*',
            self.get_from(),
            self.get_where()
        )]
        for cpt, (method, subquery) in enumerate(self.subqueries(), 1):
            out.append('  [SQ #%d : %s] %s' % (cpt, method, str(subquery)))
        return "\n".join(out)

    def clear(self):
        """Reset the query attributes and drop every subquery."""
        super(AnalyzedQuery, self).clear()
        self._subqueries = {}

    def subquery(self, method):
        """Return the subquery on `method`, creating it (with this query's action) if needed."""
        if not method in self._subqueries:
            analyzed_query = AnalyzedQuery()
            analyzed_query.action = self.action
            analyzed_query.object = method
            self._subqueries[method] = analyzed_query
        return self._subqueries[method]

    def subqueries(self):
        """Yield (method, subquery) pairs."""
        for method, subquery in self._subqueries.items():
            yield (method, subquery)

    def filter_by(self, filters, *args):
        """
        Add predicates, routing those whose key is dotted ("method.subkey")
        to the subquery on `method`, with the key reduced to "subkey".

        Accepts a single predicate, a set/list/tuple/Filter of predicates,
        or the three arguments (key, op, value) of one predicate, like
        Query.filter_by. Returns self (chainable).

        Raises:
            Exception: when given 2 arguments, or more than 3.
        """
        if args:
            if len(args) != 2:
                raise Exception('Invalid expression for filter')
            filters = [Predicate(filters, *args)]
        if not filters: return self
        if not isinstance(filters, (set, list, tuple, Filter)):
            filters = [filters]
        for predicate in filters:
            if '.' in predicate.key:
                method, subkey = predicate.key.split('.', 1)
                sub_predicate = Predicate(subkey, predicate.op, predicate.value)
                self.subquery(method).filter_by(sub_predicate)
            else:
                super(AnalyzedQuery, self).filter_by(predicate)
        return self

    def select(self, fields):
        """Add fields, routing dotted ones ("method.subfield") to subqueries; return self."""
        if not isinstance(fields, (set, list, tuple)):
            fields = [fields]
        for field in fields:
            if '.' in field:
                method, subfield = field.split('.', 1)
                self.subquery(method).select(subfield)
            else:
                super(AnalyzedQuery, self).select(field)
        return self

    def set(self, params):
        """
        Merge the `params` mapping, routing dotted keys ("method.subparam")
        to subqueries; return self (chainable). Empty values (None, {}, or
        the [] sent by the JavaScript side) are ignored.
        """
        # Iterate the *argument*: the former loop over self.params dropped
        # every parameter passed in (analyze() starts from cleared params).
        if not params: return self
        for param, value in params.items():
            if '.' in param:
                method, subparam = param.split('.', 1)
                self.subquery(method).set({subparam: value})
            else:
                super(AnalyzedQuery, self).set({param: value})
        return self

    def analyze(self, query):
        """Reset this query, then copy `query` into it, dispatching dotted names to subqueries."""
        self.clear()
        self.action    = query.action
        self.object    = query.object
        # clear() reset the timestamp to 'now'; keep the analyzed query's one.
        self.timestamp = query.timestamp
        self.filter_by(query.filters)
        self.set(query.params)
        self.select(query.fields)

    @staticmethod
    def _to_js(value, as_string=False):
        """
        Encode `value` as a JavaScript literal through JSON.

        With as_string=True the value is first rendered with '%s'. '</' is
        escaped so user-supplied values cannot terminate an enclosing
        <script> element, and JSON quoting prevents them from breaking out
        of the string literal.
        """
        if as_string:
            value = '%s' % (value,)
        return json.dumps(value).replace('</', '<\\/')

    def to_json (self):
        """
        Serialize this query as a JavaScript `new ManifoldQuery(...)`
        expression. The analyzed-query slot is always null; subqueries are
        serialized recursively into an object keyed by method name.
        """
        query_uuid = self._to_js(self.query_uuid, as_string=True)
        a = self._to_js(self.action, as_string=True)
        o = self._to_js(self.object, as_string=True)
        t = self._to_js(self.timestamp, as_string=True)
        f = self._to_js(self.filters.to_list())
        p = self._to_js(self.params)
        c = self._to_js(list(self.fields))
        # xxx unique can be removed, but for now we pad the js structure
        unique = 0

        aq = 'null'
        sq = "{%s}" % ", ".join(
            "%s:%s" % (self._to_js(method, as_string=True), subquery.to_json())
            for (method, subquery) in self._subqueries.items())

        result = " new ManifoldQuery(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)" % (
            a, o, t, f, p, c, unique, query_uuid, aq, sq)
        if debug: print('ManifoldQuery.to_json: %s' % result)
        return result
414         if debug: print 'ManifoldQuery.to_json:',result
415         return result