- re-enable return_fields specification
[plcapi.git] / PLC / Keys.py
index 88c7120..841e6eb 100644 (file)
@@ -2,6 +2,7 @@ import re
 
 from PLC.Faults import *
 from PLC.Parameter import Parameter
+from PLC.Filter import Filter
 from PLC.Debug import profile
 from PLC.Table import Row, Table
 from PLC.KeyTypes import KeyType, KeyTypes
@@ -15,24 +16,20 @@ class Key(Row):
 
     table_name = 'keys'
     primary_key = 'key_id'
+    join_tables = ['person_key']
     fields = {
         'key_id': Parameter(int, "Key identifier"),
         'key_type': Parameter(str, "Key type"),
         'key': Parameter(str, "Key value", max = 4096),
         }
 
-    def __init__(self, api, fields = {}):
-        Row.__init__(self, fields)
-       self.api = api
-
     def validate_key_type(self, key_type):
-        if key_type not in KeyTypes(self.api):
+        key_types = [row['key_type'] for row in KeyTypes(self.api)]
+        if key_type not in key_types:
             raise PLCInvalidArgument, "Invalid key type"
        return key_type
 
     def validate_key(self, key):
-        key = key.strip()
-
        # Key must not be blacklisted
        rows = self.api.db.selectall("SELECT 1 from keys" \
                                     " WHERE key = %(key)s" \
@@ -95,39 +92,23 @@ class Key(Row):
         if commit:
             self.api.db.commit()
 
-    def delete(self, commit = True):
-        """
-        Delete key from the database.
-        """
-
-       assert 'key_id' in self
-       
-       for table in ['person_key', 'keys']:
-            self.api.db.do("DELETE FROM %s WHERE key_id = %d" % \
-                           (table, self['key_id']))
-        
-        if commit:
-            self.api.db.commit()
-
 class Keys(Table):
     """
     Representation of row(s) from the keys table in the
     database.
     """
 
-    def __init__(self, api, key_id_list = None, is_blacklisted = False):
-        self.api = api
+    def __init__(self, api, key_filter = None, columns = None):
+        Table.__init__(self, api, Key, columns)
        
-       sql = "SELECT %s FROM keys WHERE True" % \
-              ", ".join(Key.fields)
-
-        if is_blacklisted is not None:
-            sql += " AND is_blacklisted IS %(is_blacklisted)s"            
+       sql = "SELECT %s FROM keys WHERE is_blacklisted IS False" % \
+              ", ".join(self.columns)
 
-       if key_id_list:
-            sql += " AND key_id IN (%s)" %  ", ".join(map(str, key_id_list))
+        if key_filter is not None:
+            if isinstance(key_filter, (list, tuple, set)):
+                key_filter = Filter(Key.fields, {'key_id': key_filter})
+            elif isinstance(key_filter, dict):
+                key_filter = Filter(Key.fields, key_filter)
+            sql += " AND (%s)" % key_filter.sql(api)
 
-       rows = self.api.db.selectall(sql, locals())
-       
-       for row in rows:        
-            self[row['key_id']] = Key(api, row)
+       self.selectall(sql)