- rename fill to selectall
[plcapi.git] / PLC / Table.py
1 from PLC.Faults import *
2 from PLC.Parameter import Parameter
3
4 class Row(dict):
5     """
6     Representation of a row in a database table. To use, optionally
7     instantiate with a dict of values. Update as you would a
8     dict. Commit to the database with sync().
9     """
10
11     # Set this to the name of the table that stores the row.
12     table_name = None
13
14     # Set this to the name of the primary key of the table. It is
15     # assumed that the this key is a sequence if it is not set when
16     # sync() is called.
17     primary_key = None
18
19     # Set this to the names of tables that reference this table's
20     # primary key.
21     join_tables = []
22
23     # Set this to a dict of the valid fields of this object and their
24     # types. Not all fields (e.g., joined fields) may be updated via
25     # sync().
26     fields = {}
27
28     def __init__(self, api, fields = {}):
29         dict.__init__(self, fields)
30         self.api = api
31
32     def validate(self):
33         """
34         Validates values. Will validate a value with a custom function
35         if a function named 'validate_[key]' exists.
36         """
37
38         # Warn about mandatory fields
39         mandatory_fields = self.api.db.fields(self.table_name, notnull = True, hasdef = False)
40         for field in mandatory_fields:
41             if not self.has_key(field) or self[field] is None:
42                 raise PLCInvalidArgument, field + " must be specified and cannot be unset"
43
44         # Validate values before committing
45         for key, value in self.iteritems():
46             if value is not None and hasattr(self, 'validate_' + key):
47                 validate = getattr(self, 'validate_' + key)
48                 self[key] = validate(value)
49
50     def sync(self, commit = True, insert = None):
51         """
52         Flush changes back to the database.
53         """
54
55         # Validate all specified fields
56         self.validate()
57
58         # Filter out fields that cannot be set or updated directly
59         all_fields = self.api.db.fields(self.table_name)
60         fields = dict(filter(lambda (key, value): \
61                              key in all_fields and \
62                              (key not in self.fields or \
63                               not isinstance(self.fields[key], Parameter) or \
64                               not self.fields[key].ro),
65                              self.items()))
66
67         # Parameterize for safety
68         keys = fields.keys()
69         values = [self.api.db.param(key, value) for (key, value) in fields.items()]
70
71         # If the primary key (usually an auto-incrementing serial
72         # identifier) has not been specified, or the primary key is the
73         # only field in the table, or insert has been forced.
74         if not self.has_key(self.primary_key) or \
75            all_fields == [self.primary_key] or \
76            insert is True:
77             # Insert new row
78             sql = "INSERT INTO %s (%s) VALUES (%s)" % \
79                   (self.table_name, ", ".join(keys), ", ".join(values))
80         else:
81             # Update existing row
82             columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)]
83             sql = "UPDATE %s SET " % self.table_name + \
84                   ", ".join(columns) + \
85                   " WHERE %s = %s" % \
86                   (self.primary_key,
87                    self.api.db.param(self.primary_key, self[self.primary_key]))
88
89         self.api.db.do(sql, fields)
90
91         if not self.has_key(self.primary_key):
92             self[self.primary_key] = self.api.db.last_insert_id(self.table_name, self.primary_key)
93
94         if commit:
95             self.api.db.commit()
96
97     def delete(self, commit = True):
98         """
99         Delete row from its primary table, and from any tables that
100         reference it.
101         """
102
103         assert self.primary_key in self
104
105         for table in self.join_tables + [self.table_name]:
106             if isinstance(table, tuple):
107                 key = table[1]
108                 table = table[0]
109             else:
110                 key = self.primary_key
111
112             sql = "DELETE FROM %s WHERE %s = %s" % \
113                   (table, key,
114                    self.api.db.param(self.primary_key, self[self.primary_key]))
115
116             self.api.db.do(sql, self)
117
118         if commit:
119             self.api.db.commit()
120
121 class Table(dict):
122     """
123     Representation of row(s) in a database table.
124     """
125
126     def __init__(self, api, row):
127         self.api = api
128         self.row = row
129
130     def sync(self, commit = True):
131         """
132         Flush changes back to the database.
133         """
134
135         for row in self.values():
136             row.sync(commit)
137
138     def selectall(self, sql, params = None):
139         """
140         Given a list of rows from the database, fill ourselves with
141         Row objects keyed on the primary key defined by the Row class
142         we were initialized with.
143         """
144
145         for row in self.api.db.selectall(sql, params):
146             self[row[self.row.primary_key]] = self.row(self.api, row)