- fix typo
[plcapi.git] / PLC / Table.py
1 from PLC.Faults import *
2 from PLC.Parameter import Parameter
3
4 class Row(dict):
5     """
6     Representation of a row in a database table. To use, optionally
7     instantiate with a dict of values. Update as you would a
8     dict. Commit to the database with sync().
9     """
10
11     # Set this to the name of the table that stores the row.
12     table_name = None
13
14     # Set this to the name of the primary key of the table. It is
15     # assumed that the this key is a sequence if it is not set when
16     # sync() is called.
17     primary_key = None
18
19     # Set this to the names of tables that reference this table's
20     # primary key.
21     join_tables = []
22
23     # Set this to a dict of the valid fields of this object. Not all
24     # fields (e.g., joined fields) may be updated via sync().
25     fields = {}
26
27     def __init__(self, api, fields = {}):
28         dict.__init__(self, fields)
29         self.api = api
30
31     def validate(self):
32         """
33         Validates values. Will validate a value with a custom function
34         if a function named 'validate_[key]' exists.
35         """
36
37         # Warn about mandatory fields
38         mandatory_fields = self.api.db.fields(self.table_name, notnull = True, hasdef = False)
39         for field in mandatory_fields:
40             if not self.has_key(field) or self[field] is None:
41                 raise PLCInvalidArgument, field + " must be specified and cannot be unset"
42
43         # Validate values before committing
44         for key, value in self.iteritems():
45             if value is not None and hasattr(self, 'validate_' + key):
46                 validate = getattr(self, 'validate_' + key)
47                 self[key] = validate(value)
48
49     def sync(self, commit = True, insert = None):
50         """
51         Flush changes back to the database.
52         """
53
54         # Validate all specified fields
55         self.validate()
56
57         # Filter out fields that cannot be set or updated directly
58         all_fields = self.api.db.fields(self.table_name)
59         fields = dict(filter(lambda (key, value): \
60                              key in all_fields and \
61                              (key not in self.fields or \
62                               not isinstance(self.fields[key], Parameter) or \
63                               not self.fields[key].ro),
64                              self.items()))
65
66         # Parameterize for safety
67         keys = fields.keys()
68         values = [self.api.db.param(key, value) for (key, value) in fields.items()]
69
70         # If the primary key (usually an auto-incrementing serial
71         # identifier) has not been specified, or the primary key is the
72         # only field in the table, or insert has been forced.
73         if not self.has_key(self.primary_key) or \
74            all_fields == [self.primary_key] or \
75            insert is True:
76             # Insert new row
77             sql = "INSERT INTO %s (%s) VALUES (%s);" % \
78                   (self.table_name, ", ".join(keys), ", ".join(values))
79         else:
80             # Update existing row
81             columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)]
82             sql = "UPDATE %s SET " % self.table_name + \
83                   ", ".join(columns) + \
84                   " WHERE %s = %s" % \
85                   (self.primary_key,
86                    self.api.db.param(self.primary_key, self[self.primary_key]))
87
88         self.api.db.do(sql, fields)
89
90         if not self.has_key(self.primary_key):
91             self[self.primary_key] = self.api.db.last_insert_id(self.table_name, self.primary_key)
92
93         if commit:
94             self.api.db.commit()
95
96     def delete(self, commit = True):
97         """
98         Delete row from its primary table, and from any tables that
99         reference it.
100         """
101
102         assert self.primary_key in self
103
104         for table in self.join_tables + [self.table_name]:
105             if isinstance(table, tuple):
106                 key = table[1]
107                 table = table[0]
108             else:
109                 key = self.primary_key
110
111             sql = "DELETE FROM %s WHERE %s = %s" % \
112                   (table, key,
113                    self.api.db.param(self.primary_key, self[self.primary_key]))
114
115             self.api.db.do(sql, self)
116
117         if commit:
118             self.api.db.commit()
119
120 class Table(dict):
121     """
122     Representation of row(s) in a database table.
123     """
124
125     def sync(self, commit = True):
126         """
127         Flush changes back to the database.
128         """
129
130         for row in self.values():
131             row.sync(commit)