more detailed info passed when raising an exception
[plcapi.git] / PLC / Table.py
1 from PLC.Faults import *
2 from PLC.Parameter import Parameter
3
4 class Row(dict):
5     """
6     Representation of a row in a database table. To use, optionally
7     instantiate with a dict of values. Update as you would a
8     dict. Commit to the database with sync().
9     """
10
11     # Set this to the name of the table that stores the row.
12     table_name = None
13
14     # Set this to the name of the primary key of the table. It is
15     # assumed that the this key is a sequence if it is not set when
16     # sync() is called.
17     primary_key = None
18
19     # Set this to the names of tables that reference this table's
20     # primary key.
21     join_tables = []
22
23     # Set this to a dict of the valid fields of this object and their
24     # types. Not all fields (e.g., joined fields) may be updated via
25     # sync().
26     fields = {}
27
28     def __init__(self, api, fields = {}):
29         dict.__init__(self, fields)
30         self.api = api
31
32     def validate(self):
33         """
34         Validates values. Will validate a value with a custom function
35         if a function named 'validate_[key]' exists.
36         """
37
38         # Warn about mandatory fields
39         mandatory_fields = self.api.db.fields(self.table_name, notnull = True, hasdef = False)
40         for field in mandatory_fields:
41             if not self.has_key(field) or self[field] is None:
42                 raise PLCInvalidArgument, field + " must be specified and cannot be unset in class %s"%self.__class__.__name__
43
44         # Validate values before committing
45         for key, value in self.iteritems():
46             if value is not None and hasattr(self, 'validate_' + key):
47                 validate = getattr(self, 'validate_' + key)
48                 self[key] = validate(value)
49
50     def sync(self, commit = True, insert = None):
51         """
52         Flush changes back to the database.
53         """
54
55         # Validate all specified fields
56         self.validate()
57
58         # Filter out fields that cannot be set or updated directly
59         all_fields = self.api.db.fields(self.table_name)
60         fields = dict(filter(lambda (key, value): \
61                              key in all_fields and \
62                              (key not in self.fields or \
63                               not isinstance(self.fields[key], Parameter) or \
64                               not self.fields[key].ro),
65                              self.items()))
66
67         # Parameterize for safety
68         keys = fields.keys()
69         values = [self.api.db.param(key, value) for (key, value) in fields.items()]
70
71         # If the primary key (usually an auto-incrementing serial
72         # identifier) has not been specified, or the primary key is the
73         # only field in the table, or insert has been forced.
74         if not self.has_key(self.primary_key) or \
75            all_fields == [self.primary_key] or \
76            insert is True:
77             # Insert new row
78             sql = "INSERT INTO %s (%s) VALUES (%s)" % \
79                   (self.table_name, ", ".join(keys), ", ".join(values))
80         else:
81             # Update existing row
82             columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)]
83             sql = "UPDATE %s SET " % self.table_name + \
84                   ", ".join(columns) + \
85                   " WHERE %s = %s" % \
86                   (self.primary_key,
87                    self.api.db.param(self.primary_key, self[self.primary_key]))
88
89         self.api.db.do(sql, fields)
90
91         if not self.has_key(self.primary_key):
92             self[self.primary_key] = self.api.db.last_insert_id(self.table_name, self.primary_key)
93
94         if commit:
95             self.api.db.commit()
96
97     def delete(self, commit = True):
98         """
99         Delete row from its primary table, and from any tables that
100         reference it.
101         """
102
103         assert self.primary_key in self
104
105         for table in self.join_tables + [self.table_name]:
106             if isinstance(table, tuple):
107                 key = table[1]
108                 table = table[0]
109             else:
110                 key = self.primary_key
111
112             sql = "DELETE FROM %s WHERE %s = %s" % \
113                   (table, key,
114                    self.api.db.param(self.primary_key, self[self.primary_key]))
115
116             self.api.db.do(sql, self)
117
118         if commit:
119             self.api.db.commit()
120
121 class Table(list):
122     """
123     Representation of row(s) in a database table.
124     """
125
126     def __init__(self, api, classobj, columns = None):
127         self.api = api
128         self.classobj = classobj
129         self.rows = {}
130
131         if columns is None:
132             columns = classobj.fields
133         else:
134             columns = filter(lambda x: x in classobj.fields, columns)
135             if not columns:
136                 raise PLCInvalidArgument, "No valid return fields specified"
137
138         self.columns = columns
139
140     def sync(self, commit = True):
141         """
142         Flush changes back to the database.
143         """
144
145         for row in self:
146             row.sync(commit)
147
148     def selectall(self, sql, params = None):
149         """
150         Given a list of rows from the database, fill ourselves with
151         Row objects.
152         """
153
154         for row in self.api.db.selectall(sql, params):
155             obj = self.classobj(self.api, row)
156             self.append(obj)
157
158     def dict(self, key_field = None):
159         """
160         Return ourself as a dict keyed on key_field.
161         """
162
163         if key_field is None:
164             key_field = self.classobj.primary_key
165
166         return dict([(obj[key_field], obj) for obj in self])