0cf1c384f68cd481bfb11559f5e93b8f990ba00a
[myslice.git] / manifold / core / filter.py
1 from types import StringTypes
2 try:
3     set
4 except NameError:
5     from sets import Set
6     set = Set
7
8 import time
9 import datetime # Jordan
10 #from manifold.util.parameter import Parameter, Mixed, python_type
11 from manifold.util.predicate import Predicate, eq
12 from itertools                  import ifilter
13
14 class Filter(set):
15     """
16     A filter is a set of predicates
17     """
18
19     @staticmethod
20     def from_list(l):
21         f = Filter()
22         try:
23             for element in l:
24                 f.add(Predicate(*element))
25         except Exception, e:
26             print "Error in setting Filter from list", e
27             return None
28         return f
29         
30     @staticmethod
31     def from_dict(d):
32         f = Filter()
33         for key, value in d.items():
34             if key[0] in Predicate.operators.keys():
35                 f.add(Predicate(key[1:], key[0], value))
36             else:
37                 f.add(Predicate(key, '=', value))
38         return f
39
40     def filter_by(self, predicate):
41         self.add(predicate)
42         return self
43
44     def __str__(self):
45         return '<Filter: %s>' % ' AND '.join([str(pred) for pred in self])
46
47     def __repr__(self):
48         return self.__str__()
49
50     def __key(self):
51         return tuple([hash(pred) for pred in self])
52
53     def __hash__(self):
54         return hash(self.__key())
55
56     def __additem__(self, value):
57         if value.__class__ != Predicate:
58             raise TypeError("Element of class Predicate expected, received %s" % value.__class__.__name__)
59         set.__additem__(self, value)
60
61     def keys(self):
62         return set([x.key for x in self])
63
64     def has(self, key):
65         for x in self:
66             if x.key == key:
67                 return True
68         return False
69
70     def has_op(self, key, op):
71         for x in self:
72             if x.key == key and x.op == op:
73                 return True
74         return False
75
76     def has_eq(self, key):
77         return self.has_op(key, eq)
78
79     def get(self, key):
80         ret = []
81         for x in self:
82             if x.key == key:
83                 ret.append(x)
84         return ret
85
86     def delete(self, key):
87         to_del = []
88         for x in self:
89             if x.key == key:
90                 to_del.append(x)
91         for x in to_del:
92             self.remove(x)
93             
94         #self = filter(lambda x: x.key != key, self)
95
96     def get_op(self, key, op):
97         for x in self:
98             if x.key == key and x.op == op:
99                 return x.value
100         raise KeyError, key
101
102     def get_eq(self, key):
103         return self.get_op(key, eq)
104
105     def set_op(self, key, op, value):
106         for x in self:
107             if x.key == key and x.op == op:
108                 x.value = value
109                 return
110         raise KeyError, key
111
112     def set_eq(self, key, value):
113         return self.set_op(key, eq, value)
114
115     def get_predicates(self, key):
116         # XXX Would deserve returning a filter (cf usage in SFA gateway)
117         ret = []
118         for x in self:
119             if x.key == key:
120                 ret.append(x)
121         return ret
122
123 #    def filter(self, dic):
124 #        # We go through every filter sequentially
125 #        for predicate in self:
126 #            print "predicate", predicate
127 #            dic = predicate.filter(dic)
128 #        return dic
129
130     def match(self, dic):
131         for predicate in self:
132             if not predicate.match(dic, ignore_missing=True):
133                 return False
134         return True
135
136     def filter(self, l):
137         output = []
138         for x in l:
139             if self.match(x):
140                 output.append(x)
141         return output
142
143     def to_list(self):
144         return [list(pred.get_str_tuple()) for pred in self]
145
146 #class OldFilter(Parameter, dict):
147 #    """
148 #    A type of parameter that represents a filter on one or more
149 #    columns of a database table.
150 #    Special features provide support for negation, upper and lower bounds, 
151 #    as well as sorting and clipping.
152 #
153 #
154 #    fields should be a dictionary of field names and types.
155 #    As of PLCAPI-4.3-26, we provide support for filtering on
156 #    sequence types as well, with the special '&' and '|' modifiers.
157 #    example : fields = {'node_id': Parameter(int, "Node identifier"),
158 #                        'hostname': Parameter(int, "Fully qualified hostname", max = 255),
159 #                        ...}
160 #
161 #
162 #    filter should be a dictionary of field names and values
163 #    representing  the criteria for filtering. 
164 #    example : filter = { 'hostname' : '*.edu' , site_id : [34,54] }
165 #    Whether the filter represents an intersection (AND) or a union (OR) 
166 #    of these criteria is determined by the join_with argument 
167 #    provided to the sql method below
168 #
169 #    Special features:
170 #
171 #    * a field starting with '&' or '|' should refer to a sequence type
172 #      the semantic is then that the object value (expected to be a list)
173 #      should contain all (&) or any (|) value specified in the corresponding
174 #      filter value. See other examples below.
175 #    example : filter = { '|role_ids' : [ 20, 40 ] }
176 #    example : filter = { '|roles' : ['tech', 'pi'] }
177 #    example : filter = { '&roles' : ['admin', 'tech'] }
178 #    example : filter = { '&roles' : 'tech' }
179 #
180 #    * a field starting with the ~ character means negation.
181 #    example :  filter = { '~peer_id' : None }
182 #
183 #    * a field starting with < [  ] or > means lower than or greater than
184 #      < > uses strict comparison
185 #      [ ] is for using <= or >= instead
186 #    example :  filter = { ']event_id' : 2305 }
187 #    example :  filter = { '>time' : 1178531418 }
188 #      in this example the integer value denotes a unix timestamp
189 #
190 #    * if a value is a sequence type, then it should represent 
191 #      a list of possible values for that field
192 #    example : filter = { 'node_id' : [12,34,56] }
193 #
194 #    * a (string) value containing either a * or a % character is
195 #      treated as a (sql) pattern; * are replaced with % that is the
196 #      SQL wildcard character.
197 #    example :  filter = { 'hostname' : '*.jp' } 
198 #
199 #    * the filter's keys starting with '-' are special and relate to sorting and clipping
200 #    * '-SORT' : a field name, or an ordered list of field names that are used for sorting
201 #      these fields may start with + (default) or - for denoting increasing or decreasing order
202 #    example : filter = { '-SORT' : [ '+node_id', '-hostname' ] }
203 #    * '-OFFSET' : the number of first rows to be ommitted
204 #    * '-LIMIT' : the amount of rows to be returned 
205 #    example : filter = { '-OFFSET' : 100, '-LIMIT':25}
206 #
207 #    Here are a few realistic examples
208 #
209 #    GetNodes ( { 'node_type' : 'regular' , 'hostname' : '*.edu' , '-SORT' : 'hostname' , '-OFFSET' : 30 , '-LIMIT' : 25 } )
210 #      would return regular (usual) nodes matching '*.edu' in alphabetical order from 31th to 55th
211 #
212 #    GetPersons ( { '|role_ids' : [ 20 , 40] } )
213 #      would return all persons that have either pi (20) or tech (40) roles
214 #
215 #    GetPersons ( { '&role_ids' : 10 } )
216 #    GetPersons ( { '&role_ids' : 10 } )
217 #    GetPersons ( { '|role_ids' : [ 10 ] } )
218 #    GetPersons ( { '|role_ids' : [ 10 ] } )
219 #      all 4 forms are equivalent and would return all admin users in the system
220 #    """
221 #
222 #    def __init__(self, fields = {}, filter = {}, doc = "Attribute filter"):
223 #        # Store the filter in our dict instance
224 #        dict.__init__(self, filter)
225 #
226 #        # Declare ourselves as a type of parameter that can take
227 #        # either a value or a list of values for each of the specified
228 #        # fields.
229 #        self.fields = dict ( [ ( field, Mixed (expected, [expected])) 
230 #                                 for (field,expected) in fields.iteritems() ] )
231 #
232 #        # Null filter means no filter
233 #        Parameter.__init__(self, self.fields, doc = doc, nullok = True)
234 #
235 #    def sql(self, api, join_with = "AND"):
236 #        """
237 #        Returns a SQL conditional that represents this filter.
238 #        """
239 #
240 #        # So that we always return something
241 #        if join_with == "AND":
242 #            conditionals = ["True"]
243 #        elif join_with == "OR":
244 #            conditionals = ["False"]
245 #        else:
246 #            assert join_with in ("AND", "OR")
247 #
248 #        # init 
249 #        sorts = []
250 #        clips = []
251 #
252 #        for field, value in self.iteritems():
253 #            # handle negation, numeric comparisons
254 #            # simple, 1-depth only mechanism
255 #
256 #            modifiers={'~' : False, 
257 #                       '<' : False, '>' : False,
258 #                       '[' : False, ']' : False,
259 #                       '-' : False,
260 #                       '&' : False, '|' : False,
261 #                       '{': False ,
262 #                       }
263 #            def check_modifiers(field):
264 #                if field[0] in modifiers.keys():
265 #                    modifiers[field[0]] = True
266 #                    field = field[1:]
267 #                    return check_modifiers(field)
268 #                return field
269 #            field = check_modifiers(field)
270 #
271 #            # filter on fields
272 #            if not modifiers['-']:
273 #                if field not in self.fields:
274 #                    raise PLCInvalidArgument, "Invalid filter field '%s'" % field
275 #
276 #                # handling array fileds always as compound values
277 #                if modifiers['&'] or modifiers['|']:
278 #                    if not isinstance(value, (list, tuple, set)):
279 #                        value = [value,]
280 #
281 #                if isinstance(value, (list, tuple, set)):
282 #                    # handling filters like '~slice_id':[]
283 #                    # this should return true, as it's the opposite of 'slice_id':[] which is false
284 #                    # prior to this fix, 'slice_id':[] would have returned ``slice_id IN (NULL) '' which is unknown 
285 #                    # so it worked by coincidence, but the negation '~slice_ids':[] would return false too
286 #                    if not value:
287 #                        if modifiers['&'] or modifiers['|']:
288 #                            operator = "="
289 #                            value = "'{}'"
290 #                        else:
291 #                            field=""
292 #                            operator=""
293 #                            value = "FALSE"
294 #                    else:
295 #                        value = map(str, map(api.db.quote, value))
296 #                        if modifiers['&']:
297 #                            operator = "@>"
298 #                            value = "ARRAY[%s]" % ", ".join(value)
299 #                        elif modifiers['|']:
300 #                            operator = "&&"
301 #                            value = "ARRAY[%s]" % ", ".join(value)
302 #                        else:
303 #                            operator = "IN"
304 #                            value = "(%s)" % ", ".join(value)
305 #                else:
306 #                    if value is None:
307 #                        operator = "IS"
308 #                        value = "NULL"
309 #                    elif isinstance(value, StringTypes) and \
310 #                            (value.find("*") > -1 or value.find("%") > -1):
311 #                        operator = "LIKE"
312 #                        # insert *** in pattern instead of either * or %
313 #                        # we dont use % as requests are likely to %-expansion later on
314 #                        # actual replacement to % done in PostgreSQL.py
315 #                        value = value.replace ('*','***')
316 #                        value = value.replace ('%','***')
317 #                        value = str(api.db.quote(value))
318 #                    else:
319 #                        operator = "="
320 #                        if modifiers['<']:
321 #                            operator='<'
322 #                        if modifiers['>']:
323 #                            operator='>'
324 #                        if modifiers['[']:
325 #                            operator='<='
326 #                        if modifiers[']']:
327 #                            operator='>='
328 #                        #else:
329 #                        #    value = str(api.db.quote(value))
330 #                        # jordan
331 #                        if isinstance(value, StringTypes) and value[-2:] != "()": # XXX
332 #                            value = str(api.db.quote(value))
333 #                        if isinstance(value, datetime.datetime):
334 #                            value = str(api.db.quote(str(value)))
335
336 #                #if prefix: 
337 #                #    field = "%s.%s" % (prefix,field)
338 #                if field:
339 #                    clause = "\"%s\" %s %s" % (field, operator, value)
340 #                else:
341 #                    clause = "%s %s %s" % (field, operator, value)
342 #
343 #                if modifiers['~']:
344 #                    clause = " ( NOT %s ) " % (clause)
345 #
346 #                conditionals.append(clause)
347 #            # sorting and clipping
348 #            else:
349 #                if field not in ('SORT','OFFSET','LIMIT'):
350 #                    raise PLCInvalidArgument, "Invalid filter, unknown sort and clip field %r"%field
351 #                # sorting
352 #                if field == 'SORT':
353 #                    if not isinstance(value,(list,tuple,set)):
354 #                        value=[value]
355 #                    for field in value:
356 #                        order = 'ASC'
357 #                        if field[0] == '+':
358 #                            field = field[1:]
359 #                        elif field[0] == '-':
360 #                            field = field[1:]
361 #                            order = 'DESC'
362 #                        if field not in self.fields:
363 #                            raise PLCInvalidArgument, "Invalid field %r in SORT filter"%field
364 #                        sorts.append("%s %s"%(field,order))
365 #                # clipping
366 #                elif field == 'OFFSET':
367 #                    clips.append("OFFSET %d"%value)
368 #                # clipping continued
369 #                elif field == 'LIMIT' :
370 #                    clips.append("LIMIT %d"%value)
371 #
372 #        where_part = (" %s " % join_with).join(conditionals)
373 #        clip_part = ""
374 #        if sorts:
375 #            clip_part += " ORDER BY " + ",".join(sorts)
376 #        if clips:
377 #            clip_part += " " + " ".join(clips)
378 ##       print 'where_part=',where_part,'clip_part',clip_part
379 #        return (where_part,clip_part)
380 #