- re-enable return_fields specification
[plcapi.git] / PLC / Table.py
1 from PLC.Faults import *
2 from PLC.Parameter import Parameter
3
4 class Row(dict):
5     """
6     Representation of a row in a database table. To use, optionally
7     instantiate with a dict of values. Update as you would a
8     dict. Commit to the database with sync().
9     """
10
11     # Set this to the name of the table that stores the row.
12     table_name = None
13
14     # Set this to the name of the primary key of the table. It is
15     # assumed that the this key is a sequence if it is not set when
16     # sync() is called.
17     primary_key = None
18
19     # Set this to the names of tables that reference this table's
20     # primary key.
21     join_tables = []
22
23     # Set this to a dict of the valid fields of this object and their
24     # types. Not all fields (e.g., joined fields) may be updated via
25     # sync().
26     fields = {}
27
28     def __init__(self, api, fields = {}):
29         dict.__init__(self, fields)
30         self.api = api
31
32     def validate(self):
33         """
34         Validates values. Will validate a value with a custom function
35         if a function named 'validate_[key]' exists.
36         """
37
38         # Warn about mandatory fields
39         mandatory_fields = self.api.db.fields(self.table_name, notnull = True, hasdef = False)
40         for field in mandatory_fields:
41             if not self.has_key(field) or self[field] is None:
42                 raise PLCInvalidArgument, field + " must be specified and cannot be unset"
43
44         # Validate values before committing
45         for key, value in self.iteritems():
46             if value is not None and hasattr(self, 'validate_' + key):
47                 validate = getattr(self, 'validate_' + key)
48                 self[key] = validate(value)
49
50     def sync(self, commit = True, insert = None):
51         """
52         Flush changes back to the database.
53         """
54
55         # Validate all specified fields
56         self.validate()
57
58         # Filter out fields that cannot be set or updated directly
59         all_fields = self.api.db.fields(self.table_name)
60         fields = dict(filter(lambda (key, value): \
61                              key in all_fields and \
62                              (key not in self.fields or \
63                               not isinstance(self.fields[key], Parameter) or \
64                               not self.fields[key].ro),
65                              self.items()))
66
67         # Parameterize for safety
68         keys = fields.keys()
69         values = [self.api.db.param(key, value) for (key, value) in fields.items()]
70
71         # If the primary key (usually an auto-incrementing serial
72         # identifier) has not been specified, or the primary key is the
73         # only field in the table, or insert has been forced.
74         if not self.has_key(self.primary_key) or \
75            all_fields == [self.primary_key] or \
76            insert is True:
77             # Insert new row
78             sql = "INSERT INTO %s (%s) VALUES (%s)" % \
79                   (self.table_name, ", ".join(keys), ", ".join(values))
80         else:
81             # Update existing row
82             columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)]
83             sql = "UPDATE %s SET " % self.table_name + \
84                   ", ".join(columns) + \
85                   " WHERE %s = %s" % \
86                   (self.primary_key,
87                    self.api.db.param(self.primary_key, self[self.primary_key]))
88
89         self.api.db.do(sql, fields)
90
91         if not self.has_key(self.primary_key):
92             self[self.primary_key] = self.api.db.last_insert_id(self.table_name, self.primary_key)
93
94         if commit:
95             self.api.db.commit()
96
97     def delete(self, commit = True):
98         """
99         Delete row from its primary table, and from any tables that
100         reference it.
101         """
102
103         assert self.primary_key in self
104
105         for table in self.join_tables + [self.table_name]:
106             if isinstance(table, tuple):
107                 key = table[1]
108                 table = table[0]
109             else:
110                 key = self.primary_key
111
112             sql = "DELETE FROM %s WHERE %s = %s" % \
113                   (table, key,
114                    self.api.db.param(self.primary_key, self[self.primary_key]))
115
116             self.api.db.do(sql, self)
117
118         if commit:
119             self.api.db.commit()
120
121 class Table(list):
122     """
123     Representation of row(s) in a database table.
124     """
125
126     def __init__(self, api, row, columns = None):
127         self.api = api
128         self.row = row
129
130         if columns is None:
131             columns = row.fields
132         else:
133             columns = filter(lambda x: x in row.fields, columns)
134             if not columns:
135                 raise PLCInvalidArgument, "No valid return fields specified"
136
137         self.columns = columns
138
139     def sync(self, commit = True):
140         """
141         Flush changes back to the database.
142         """
143
144         for row in self:
145             row.sync(commit)
146
147     def selectall(self, sql, params = None):
148         """
149         Given a list of rows from the database, fill ourselves with
150         Row objects.
151         """
152
153         for row in self.api.db.selectall(sql, params):
154             self.append(self.row(self.api, row))