assert primary key exists in record before attempting to delete it
[plcapi.git] / PLC / Keys.py
index ebabd19..07a05e4 100644 (file)
@@ -6,49 +6,52 @@ from PLC.Filter import Filter
 from PLC.Debug import profile
 from PLC.Table import Row, Table
 from PLC.KeyTypes import KeyType, KeyTypes
+from PLC.NovaTable import NovaObject, NovaTable
 
-class Key(Row):
+class Key(NovaObject):
     """
     Representation of a row in the keys table. To use, instantiate with a
     dict of values. Update as you would a dict. Commit to the database
     with sync().
     """
 
-    table_name = 'keys'
-    primary_key = 'key_id'
+    tablename = 'keys'
     join_tables = ['person_key', 'peer_key']
     fields = {
-        'key_id': Parameter(int, "Key identifier"),
-        'key_type': Parameter(str, "Key type"),
-        'key': Parameter(str, "Key value", max = 4096),
-        'person_id': Parameter(int, "User to which this key belongs", nullok = True),
-        'peer_id': Parameter(int, "Peer to which this key belongs", nullok = True),
-        'peer_key_id': Parameter(int, "Foreign key identifier at peer", nullok = True),
+        'id': Parameter(str, "Key identifier", primary_key=True),
+        #'key_type': Parameter(str, "Key type"),
+        'public_key': Parameter(str, "Key string", max = 4096),
+        'name': Parameter(str, "Key name",)
         }
 
-    def validate_key_type(self, key_type):
-        key_types = [row['key_type'] for row in KeyTypes(self.api)]
-        if key_type not in key_types:
-            raise PLCInvalidArgument, "Invalid key type"
-        return key_type
+    def sync(self, insert = False, validate = True):
+        NovaObject.sync(self, insert, validate)
+        if insert == True or 'id' not in self:
+            self.object = self.api.client_shell.nova.keypairs.create(self.id,
+                                                                     self.key)
+        else:
+            self.object = self.api.client_shell.nova.keypairs.update(self.id, dict(self))
 
-    def validate_key(self, key):
-        # Key must not be blacklisted
-        rows = self.api.db.selectall("SELECT 1 from keys" \
-                                     " WHERE key = %(key)s" \
-                                     " AND is_blacklisted IS True",
-                                     locals())
-        if rows:
-            raise PLCInvalidArgument, "Key is blacklisted and cannot be used"
+    def delete(self):
+        assert 'id' in self
+        self.api.client_shell.nova.keypairs.delete(self.id)
 
+    def validate_public_key(self, key):
+        # Key must not be blacklisted
+        pass
         return key
 
+    def validate_name(self, name)
+        keys = Keys(self.api, name)
+        if keys:
+            raise PLCInvalidArgument, "Key name alredy in use"
+
     def validate(self):
         # Basic validation
-        Row.validate(self)
+        NovaObject.validate(self)
 
-        assert 'key' in self
-        key = self['key']
+        assert 'public_key' in self
+        key = self['public_key']
 
         if self['key_type'] == 'ssh':
             # Accept only SSH version 2 keys without options. From
@@ -95,25 +98,27 @@ class Key(Row):
         if commit:
             self.api.db.commit()
 
-class Keys(Table):
+class Keys(NovaTable):
     """
     Representation of row(s) from the keys table in the
     database.
     """
 
     def __init__(self, api, key_filter = None, columns = None):
-        Table.__init__(self, api, Key, columns)
-
-        sql = "SELECT %s FROM view_keys WHERE is_blacklisted IS False" % \
-              ", ".join(self.columns)
+        self.api = api
+        keysManager = self.api.client_shell.nova.keypairs
+        keyPairs = []
 
         if key_filter is not None:
             if isinstance(key_filter, (list, tuple, set, int, long)):
-                key_filter = Filter(Key.fields, {'key_id': key_filter})
+                keyPairs = filter(lambda kp: kp.uuid in key_filter,
+                                  keysManager.findall())
             elif isinstance(key_filter, dict):
-                key_filter = Filter(Key.fields, key_filter)
+                keyPairs = keysManager.findall(**key_filter)
+            elif isinstnace(key_filter, StringTypes):
+                keyPairs = keyManagers.findall(uuid = key_filter)
             else:
                 raise PLCInvalidArgument, "Wrong key filter %r"%key_filter
-            sql += " AND (%s) %s" % key_filter.sql(api)
 
-        self.selectall(sql)
+        self.extend(keyPairs)
+