====
[plcapi.git] / PLC / Keys.py
index 5a151fd..ebabd19 100644 (file)
-from types import StringTypes
+import re
 
 from PLC.Faults import *
 from PLC.Parameter import Parameter
+from PLC.Filter import Filter
 from PLC.Debug import profile
 from PLC.Table import Row, Table
-import PLC
+from PLC.KeyTypes import KeyType, KeyTypes
+
 class Key(Row):
     """
-    Representation of a row in the keys table. To use, instantiate with a 
-    dict of values. Update as you would a dict. Commit to the database 
+    Representation of a row in the keys table. To use, instantiate with a
+    dict of values. Update as you would a dict. Commit to the database
     with sync().
     """
+
     table_name = 'keys'
     primary_key = 'key_id'
+    join_tables = ['person_key', 'peer_key']
     fields = {
-        'key_id': Parameter(int, "Key type"),
+        'key_id': Parameter(int, "Key identifier"),
         'key_type': Parameter(str, "Key type"),
-        'key': Parameter(str, "Key value"),
-        'is_blacklisted': Parameter(str, "Key has been blacklisted and is forever unusable"),
-       'person_id': Parameter(int, "Identifier of the account that created this key"),
-       'is_primary': Parameter(bool, "Is the primary key for this account")
+        'key': Parameter(str, "Key value", max = 4096),
+        'person_id': Parameter(int, "User to which this key belongs", nullok = True),
+        'peer_id': Parameter(int, "Peer to which this key belongs", nullok = True),
+        'peer_key_id': Parameter(int, "Foreign key identifier at peer", nullok = True),
         }
 
-    def __init__(self, api, fields):
-        Row.__init__(self, fields)
-       self.api = api
-        
-   
     def validate_key_type(self, key_type):
-       # 1. ssh is the only supported key type
-       if not key_type or not key_type in ['ssh']:
-               raise PLCInvalidArgument, "Invalid key type"
-
-       return key_type
+        key_types = [row['key_type'] for row in KeyTypes(self.api)]
+        if key_type not in key_types:
+            raise PLCInvalidArgument, "Invalid key type"
+        return key_type
 
     def validate_key(self, key):
-       # 1. key must not be blacklisted
-
-       # Remove leading and trailing spaces
-       key = key.strip()
-       # Make sure key is not blank 
-       if not len(key) > 0:
-                raise PLCInvalidArgument, "Invalid key"
-
-       rows = self.api.db.selectall("SELECT is_blacklisted from keys" \
-                                    " WHERE key = '%s'" % key)
-       if rows:
-               raise PLCInvalidArgument, "Key is blacklisted"  
-       return key
-    
-    def add_person(self, person, commit = True):
-       """
-       Associate key with person
-       """
-       
-       assert 'key_id' in self
-       assert isinstance(person, PLC.Persons.Person)
-       assert 'person_id' in person
-
-       person_id = person['person_id']
-       key_id = self['key_id']
-       
-       if not 'person_id' in self:
-               assert key_id not in person['key_ids']
-               
-               self.api.db.do("INSERT INTO person_key (person_id, key_id)" \
-                              " VALUES (%d, %d)" % (person_id, key_id) )
-               if commit:
-                       self.api.db.commit()
-
-               self['person_id'] = person_id
-               person['key_id'] = key_id 
-
-    def set_primary_key(self, person, commit = True):
-       """
-       Set the primary key for a person
-       """
-
-       assert 'key_id' in self
-        assert isinstance(person, PLC.Persons.Person)
-        assert 'person_id' in person
-
-       person_id = person['person_id']
-        key_id = self['key_id']
-       assert person_id in [self['person_id']]
-
-       self.api.db.do("UPDATE person_key SET is_primary = False" \
-                      " WHERE person_id = %d " % person_id)
-       self.api.db.do("UPDATE person_key SET is_primary = True" \
-                      " WHERE person_id = %d AND key_id = %d" \
-                      % (person_id, key_id) )
-
-       if commit:
-               self.api.db.commit()
-       
-       self['is_primary'] = True
-       
-    def delete(self, commit = True):
+        # Key must not be blacklisted
+        rows = self.api.db.selectall("SELECT 1 from keys" \
+                                     " WHERE key = %(key)s" \
+                                     " AND is_blacklisted IS True",
+                                     locals())
+        if rows:
+            raise PLCInvalidArgument, "Key is blacklisted and cannot be used"
+
+        return key
+
+    def validate(self):
+        # Basic validation
+        Row.validate(self)
+
+        assert 'key' in self
+        key = self['key']
+
+        if self['key_type'] == 'ssh':
+            # Accept only SSH version 2 keys without options. From
+            # sshd(8):
+            #
+            # Each protocol version 2 public key consists of: options,
+            # keytype, base64 encoded key, comment.  The options field
+            # is optional...The comment field is not used for anything
+            # (but may be convenient for the user to identify the
+            # key). For protocol version 2 the keytype is ``ssh-dss''
+            # or ``ssh-rsa''.
+
+            good_ssh_key = r'^.*(?:ssh-dss|ssh-rsa)[ ]+[A-Za-z0-9+/=]+(?: .*)?$'
+            if not re.match(good_ssh_key, key, re.IGNORECASE):
+                raise PLCInvalidArgument, "Invalid SSH version 2 public key"
+
+    def blacklist(self, commit = True):
         """
-       Delete key from the database
-       """
-       assert 'key_id' in self
-       
-       for table in ['person_key', 'keys']:
-               self.api.db.do("DELETE FROM %s WHERE key_id = %d" % \
-                (table, self['key_id']), self)
+        Permanently blacklist key (and all other identical keys),
+        preventing it from ever being added again. Because this could
+        affect multiple keys associated with multiple accounts, it
+        should be admin only.
+        """
+
+        assert 'key_id' in self
+        assert 'key' in self
+
+        # Get all matching keys
+        rows = self.api.db.selectall("SELECT key_id FROM keys WHERE key = %(key)s",
+                                     self)
+        key_ids = [row['key_id'] for row in rows]
+        assert key_ids
+        assert self['key_id'] in key_ids
 
-       if commit:
-                       self.api.db.commit()
+        # Keep the keys in the table
+        self.api.db.do("UPDATE keys SET is_blacklisted = True" \
+                       " WHERE key_id IN (%s)" % ", ".join(map(str, key_ids)))
+
+        # But disassociate them from all join tables
+        for table in self.join_tables:
+            self.api.db.do("DELETE FROM %s WHERE key_id IN (%s)" % \
+                           (table, ", ".join(map(str, key_ids))))
+
+        if commit:
+            self.api.db.commit()
 
 class Keys(Table):
     """
@@ -115,18 +101,19 @@ class Keys(Table):
     database.
     """
 
-    def __init__(self, api, key_id_list = None):
-        self.api = api
-       
-       sql = "SELECT %s FROM keys LEFT JOIN person_key USING (%s) " % \
-               (", ".join(Key.fields), Key.primary_key)
-       
-       if key_id_list:
-               sql += " WHERE key_id IN (%s)" %  ", ".join(map(str, key_id_list))
-
-       rows = self.api.db.selectall(sql)
-       
-       for row in rows:        
-               self[row['key_id']] = Key(api, row)
-               
-  
+    def __init__(self, api, key_filter = None, columns = None):
+        Table.__init__(self, api, Key, columns)
+
+        sql = "SELECT %s FROM view_keys WHERE is_blacklisted IS False" % \
+              ", ".join(self.columns)
+
+        if key_filter is not None:
+            if isinstance(key_filter, (list, tuple, set, int, long)):
+                key_filter = Filter(Key.fields, {'key_id': key_filter})
+            elif isinstance(key_filter, dict):
+                key_filter = Filter(Key.fields, key_filter)
+            else:
+                raise PLCInvalidArgument, "Wrong key filter %r"%key_filter
+            sql += " AND (%s) %s" % key_filter.sql(api)
+
+        self.selectall(sql)