- explicitly set new person['enabled'] field to False
[plcapi.git] / PLC / Keys.py
index 88c7120..b81442d 100644 (file)
@@ -2,6 +2,7 @@ import re
 
 from PLC.Faults import *
 from PLC.Parameter import Parameter
 
 from PLC.Faults import *
 from PLC.Parameter import Parameter
+from PLC.Filter import Filter
 from PLC.Debug import profile
 from PLC.Table import Row, Table
 from PLC.KeyTypes import KeyType, KeyTypes
 from PLC.Debug import profile
 from PLC.Table import Row, Table
 from PLC.KeyTypes import KeyType, KeyTypes
@@ -15,24 +16,28 @@ class Key(Row):
 
     table_name = 'keys'
     primary_key = 'key_id'
 
     table_name = 'keys'
     primary_key = 'key_id'
+    join_tables = ['person_key', 'peer_key']
     fields = {
         'key_id': Parameter(int, "Key identifier"),
         'key_type': Parameter(str, "Key type"),
         'key': Parameter(str, "Key value", max = 4096),
     fields = {
         'key_id': Parameter(int, "Key identifier"),
         'key_type': Parameter(str, "Key type"),
         'key': Parameter(str, "Key value", max = 4096),
+        'person_id': Parameter(int, "User to which this key belongs", nullok = True),
+        'peer_id': Parameter(int, "Peer to which this key belongs", nullok = True),
+        'peer_key_id': Parameter(int, "Foreign key identifier at peer", nullok = True),
         }
 
         }
 
-    def __init__(self, api, fields = {}):
-        Row.__init__(self, fields)
-       self.api = api
+    # for Cache
+    class_key= 'key'
+    foreign_fields = ['key_type']
+    foreign_xrefs = []
 
     def validate_key_type(self, key_type):
 
     def validate_key_type(self, key_type):
-        if key_type not in KeyTypes(self.api):
+        key_types = [row['key_type'] for row in KeyTypes(self.api)]
+        if key_type not in key_types:
             raise PLCInvalidArgument, "Invalid key type"
        return key_type
 
     def validate_key(self, key):
             raise PLCInvalidArgument, "Invalid key type"
        return key_type
 
     def validate_key(self, key):
-        key = key.strip()
-
        # Key must not be blacklisted
        rows = self.api.db.selectall("SELECT 1 from keys" \
                                     " WHERE key = %(key)s" \
        # Key must not be blacklisted
        rows = self.api.db.selectall("SELECT 1 from keys" \
                                     " WHERE key = %(key)s" \
@@ -88,46 +93,30 @@ class Key(Row):
                        " WHERE key_id IN (%s)" % ", ".join(map(str, key_ids)))
 
        # But disassociate them from all join tables
                        " WHERE key_id IN (%s)" % ", ".join(map(str, key_ids)))
 
        # But disassociate them from all join tables
-        for table in ['person_key']:
+        for table in self.join_tables:
             self.api.db.do("DELETE FROM %s WHERE key_id IN (%s)" % \
                            (table, ", ".join(map(str, key_ids))))
 
         if commit:
             self.api.db.commit()
 
             self.api.db.do("DELETE FROM %s WHERE key_id IN (%s)" % \
                            (table, ", ".join(map(str, key_ids))))
 
         if commit:
             self.api.db.commit()
 
-    def delete(self, commit = True):
-        """
-        Delete key from the database.
-        """
-
-       assert 'key_id' in self
-       
-       for table in ['person_key', 'keys']:
-            self.api.db.do("DELETE FROM %s WHERE key_id = %d" % \
-                           (table, self['key_id']))
-        
-        if commit:
-            self.api.db.commit()
-
 class Keys(Table):
     """
     Representation of row(s) from the keys table in the
     database.
     """
 
 class Keys(Table):
     """
     Representation of row(s) from the keys table in the
     database.
     """
 
-    def __init__(self, api, key_id_list = None, is_blacklisted = False):
-        self.api = api
+    def __init__(self, api, key_filter = None, columns = None):
+        Table.__init__(self, api, Key, columns)
        
        
-       sql = "SELECT %s FROM keys WHERE True" % \
-              ", ".join(Key.fields)
+       sql = "SELECT %s FROM view_keys WHERE is_blacklisted IS False" % \
+              ", ".join(self.columns)
 
 
-        if is_blacklisted is not None:
-            sql += " AND is_blacklisted IS %(is_blacklisted)s"            
+        if key_filter is not None:
+            if isinstance(key_filter, (list, tuple, set)):
+                key_filter = Filter(Key.fields, {'key_id': key_filter})
+            elif isinstance(key_filter, dict):
+                key_filter = Filter(Key.fields, key_filter)
+            sql += " AND (%s)" % key_filter.sql(api)
 
 
-       if key_id_list:
-            sql += " AND key_id IN (%s)" %  ", ".join(map(str, key_id_list))
-
-       rows = self.api.db.selectall(sql, locals())
-       
-       for row in rows:        
-            self[row['key_id']] = Key(api, row)
+       self.selectall(sql)