from types import StringTypes

from PLC.Faults import *
from PLC.Parameter import Parameter, Mixed, python_type
class Filter(Parameter, dict):
    """
    A type of parameter that represents a filter on one or more
    columns of a database table. fields should be a dictionary of
    field names and types, e.g.

    {'node_id': Parameter(int, "Node identifier"),
     'hostname': Parameter(str, "Fully qualified hostname", max = 255),
     ...}

    Only filters on non-sequence type fields are supported; fields
    whose declared type is a list, tuple or set are left out of the
    set of filterable fields.

    filter should be a dictionary of field names and values
    representing an intersection (if join_with is AND) or union (if
    join_with is OR) filter. If a value is a sequence type, then it
    should represent a list of possible values for that field.

    A filter key prefixed with '~' (e.g. '~hostname') negates the
    condition on that field. A string value containing '*' or '%' is
    matched with LIKE ('*' is translated to '%'), and a value of None
    matches NULL.
    """

    def __init__(self, fields = None, filter = None, doc = "Attribute filter"):
        """
        fields maps each filterable column name to its expected
        Parameter type; filter maps column names (optionally prefixed
        with '~') to a value or a sequence of values; doc is the
        description handed to Parameter. None for fields or filter is
        treated as an empty dictionary.
        """
        # None instead of {} as defaults so no dict is shared between calls
        if fields is None:
            fields = {}
        if filter is None:
            filter = {}

        # Store the filter in our dict instance
        dict.__init__(self, filter)

        # Declare ourselves as a type of parameter that can take
        # either a value or a list of values for each of the specified
        # fields.
        self.fields = {}

        for field, expected in fields.iteritems():
            # Cannot filter on sequences
            if python_type(expected) in (list, tuple, set):
                continue

            # Accept either a value or a list of values of the specified type
            self.fields[field] = Mixed(expected, [expected])

        # Null filter means no filter
        Parameter.__init__(self, self.fields, doc = doc, nullok = True)

    def sql(self, api, join_with = "AND"):
        """
        Returns a SQL conditional that represents this filter.

        api must expose db.quote(value), which is used to render every
        filter value as a SQL literal. join_with ("AND" or "OR") is the
        connective placed between the per-field clauses. An empty
        filter yields "True" for AND and "False" for OR, so the result
        is always a usable condition.

        Raises PLCInvalidArgument if the filter names a field that is
        not one of the declared filter fields, and AssertionError if
        join_with is neither "AND" nor "OR".
        """

        # So that we always return something: start from the identity
        # element of the chosen boolean connective.
        if join_with == "AND":
            conditionals = ["True"]
        elif join_with == "OR":
            conditionals = ["False"]
        else:
            # join_with is interpolated into the SQL below, so anything
            # else is rejected. Raised explicitly rather than via
            # `assert` so the check survives `python -O`.
            raise AssertionError("join_with must be 'AND' or 'OR', not %r" % (join_with,))

        for field, value in self.iteritems():
            # provide for negation with a field starting with ~
            negation = field.startswith("~")
            if negation:
                field = field[1:]

            # Only declared field names get past this check, which is
            # what makes interpolating `field` into the SQL safe.
            if field not in self.fields:
                raise PLCInvalidArgument("Invalid filter field '%s'" % field)

            if isinstance(value, (list, tuple, set)):
                operator = "IN"
                if value:
                    value = "(%s)" % ", ".join([str(api.db.quote(v)) for v in value])
                else:
                    # Turn empty list into (NULL) instead of invalid ()
                    value = "(NULL)"
            elif value is None:
                # "= NULL" never matches in SQL; IS NULL does
                operator = "IS"
                value = "NULL"
            elif isinstance(value, StringTypes) and ("*" in value or "%" in value):
                # Wildcard match; '*' is accepted as an alias for '%'
                operator = "LIKE"
                value = str(api.db.quote(value.replace("*", "%")))
            else:
                operator = "="
                value = str(api.db.quote(value))

            clause = "%s %s %s" % (field, operator, value)
            if negation:
                clause = " ( NOT %s ) " % clause

            conditionals.append(clause)

        return (" %s " % join_with).join(conditionals)