merge changes from head
[plcapi.git] / PLC / Keys.py
index 542224d..b81442d 100644 (file)
@@ -2,6 +2,7 @@ import re
 
 from PLC.Faults import *
 from PLC.Parameter import Parameter
 
 from PLC.Faults import *
 from PLC.Parameter import Parameter
+from PLC.Filter import Filter
 from PLC.Debug import profile
 from PLC.Table import Row, Table
 from PLC.KeyTypes import KeyType, KeyTypes
 from PLC.Debug import profile
 from PLC.Table import Row, Table
 from PLC.KeyTypes import KeyType, KeyTypes
@@ -15,15 +16,24 @@ class Key(Row):
 
     table_name = 'keys'
     primary_key = 'key_id'
 
     table_name = 'keys'
     primary_key = 'key_id'
-    join_tables = ['person_key']
+    join_tables = ['person_key', 'peer_key']
     fields = {
         'key_id': Parameter(int, "Key identifier"),
         'key_type': Parameter(str, "Key type"),
         'key': Parameter(str, "Key value", max = 4096),
     fields = {
         'key_id': Parameter(int, "Key identifier"),
         'key_type': Parameter(str, "Key type"),
         'key': Parameter(str, "Key value", max = 4096),
+        'person_id': Parameter(int, "User to which this key belongs", nullok = True),
+        'peer_id': Parameter(int, "Peer to which this key belongs", nullok = True),
+        'peer_key_id': Parameter(int, "Foreign key identifier at peer", nullok = True),
         }
 
         }
 
+    # for Cache
+    class_key= 'key'
+    foreign_fields = ['key_type']
+    foreign_xrefs = []
+
     def validate_key_type(self, key_type):
     def validate_key_type(self, key_type):
-        if key_type not in KeyTypes(self.api):
+        key_types = [row['key_type'] for row in KeyTypes(self.api)]
+        if key_type not in key_types:
             raise PLCInvalidArgument, "Invalid key type"
        return key_type
 
             raise PLCInvalidArgument, "Invalid key type"
        return key_type
 
@@ -83,7 +93,7 @@ class Key(Row):
                        " WHERE key_id IN (%s)" % ", ".join(map(str, key_ids)))
 
        # But disassociate them from all join tables
                        " WHERE key_id IN (%s)" % ", ".join(map(str, key_ids)))
 
        # But disassociate them from all join tables
-        for table in ['person_key']:
+        for table in self.join_tables:
             self.api.db.do("DELETE FROM %s WHERE key_id IN (%s)" % \
                            (table, ", ".join(map(str, key_ids))))
 
             self.api.db.do("DELETE FROM %s WHERE key_id IN (%s)" % \
                            (table, ", ".join(map(str, key_ids))))
 
@@ -96,19 +106,17 @@ class Keys(Table):
     database.
     """
 
     database.
     """
 
-    def __init__(self, api, key_id_list = None, is_blacklisted = False):
-        self.api = api
+    def __init__(self, api, key_filter = None, columns = None):
+        Table.__init__(self, api, Key, columns)
        
        
-       sql = "SELECT %s FROM keys WHERE True" % \
-              ", ".join(Key.fields)
-
-        if is_blacklisted is not None:
-            sql += " AND is_blacklisted IS %(is_blacklisted)s"            
+       sql = "SELECT %s FROM view_keys WHERE is_blacklisted IS False" % \
+              ", ".join(self.columns)
 
 
-       if key_id_list:
-            sql += " AND key_id IN (%s)" %  ", ".join(map(str, key_id_list))
+        if key_filter is not None:
+            if isinstance(key_filter, (list, tuple, set)):
+                key_filter = Filter(Key.fields, {'key_id': key_filter})
+            elif isinstance(key_filter, dict):
+                key_filter = Filter(Key.fields, key_filter)
+            sql += " AND (%s)" % key_filter.sql(api)
 
 
-       rows = self.api.db.selectall(sql, locals())
-       
-       for row in rows:        
-            self[row['key_id']] = Key(api, row)
+       self.selectall(sql)