From 6a15f46cfce003d21ee8ba0bee213df66250e60e Mon Sep 17 00:00:00 2001 From: Mark Huang Date: Wed, 11 Oct 2006 15:39:58 +0000 Subject: [PATCH] - remove is_primary - set max on key length - validate key_type against KeyTypes - fix blacklisted key check (check for is_blacklisted IS True) - accounts can have multiple keys but not vice-versa; move add_person/remove_person to Persons.add_key/remove_key - validate SSH version 2 public keys - add blacklist() function to permanently blacklist key (disassociate it and all other identical keys, and set is_blacklisted to True) - whitespace nits - Keys: only return non-blacklisted keys --- PLC/Keys.py | 163 ++++++++++++++++++++++++++-------------------------- 1 file changed, 82 insertions(+), 81 deletions(-) diff --git a/PLC/Keys.py b/PLC/Keys.py index 5a151fd9..88c71205 100644 --- a/PLC/Keys.py +++ b/PLC/Keys.py @@ -1,113 +1,113 @@ -from types import StringTypes +import re from PLC.Faults import * from PLC.Parameter import Parameter from PLC.Debug import profile from PLC.Table import Row, Table -import PLC +from PLC.KeyTypes import KeyType, KeyTypes + class Key(Row): """ Representation of a row in the keys table. To use, instantiate with a dict of values. Update as you would a dict. Commit to the database with sync(). """ + table_name = 'keys' primary_key = 'key_id' fields = { - 'key_id': Parameter(int, "Key type"), + 'key_id': Parameter(int, "Key identifier"), 'key_type': Parameter(str, "Key type"), - 'key': Parameter(str, "Key value"), - 'is_blacklisted': Parameter(str, "Key has been blacklisted and is forever unusable"), - 'person_id': Parameter(int, "Identifier of the account that created this key"), - 'is_primary': Parameter(bool, "Is the primary key for this account") + 'key': Parameter(str, "Key value", max = 4096), } - def __init__(self, api, fields): + def __init__(self, api, fields = {}): Row.__init__(self, fields) self.api = api - - - def validate_key_type(self, key_type): - # 1. ssh is the only supported key type - if not key_type or not key_type in ['ssh']: - raise PLCInvalidArgument, "Invalid key type" + def validate_key_type(self, key_type): + if key_type not in KeyTypes(self.api): + raise PLCInvalidArgument, "Invalid key type" return key_type def validate_key(self, key): - # 1. key must not be blacklisted - - # Remove leading and trailing spaces - key = key.strip() - # Make sure key is not blank - if not len(key) > 0: - raise PLCInvalidArgument, "Invalid key" + key = key.strip() - rows = self.api.db.selectall("SELECT is_blacklisted from keys" \ - " WHERE key = '%s'" % key) + # Key must not be blacklisted + rows = self.api.db.selectall("SELECT 1 from keys" \ + " WHERE key = %(key)s" \ + " AND is_blacklisted IS True", + locals()) if rows: - raise PLCInvalidArgument, "Key is blacklisted" + raise PLCInvalidArgument, "Key is blacklisted and cannot be used" + return key - - def add_person(self, person, commit = True): - """ - Associate key with person - """ - - assert 'key_id' in self - assert isinstance(person, PLC.Persons.Person) - assert 'person_id' in person - person_id = person['person_id'] - key_id = self['key_id'] - - if not 'person_id' in self: - assert key_id not in person['key_ids'] - - self.api.db.do("INSERT INTO person_key (person_id, key_id)" \ - " VALUES (%d, %d)" % (person_id, key_id) ) - if commit: - self.api.db.commit() - - self['person_id'] = person_id - person['key_id'] = key_id - - def set_primary_key(self, person, commit = True): - """ - Set the primary key for a person + def validate(self): + # Basic validation + Row.validate(self) + + assert 'key' in self + key = self['key'] + + if self['key_type'] == 'ssh': + # Accept only SSH version 2 keys without options. From + # sshd(8): + # + # Each protocol version 2 public key consists of: options, + # keytype, base64 encoded key, comment. The options field + # is optional...The comment field is not used for anything + # (but may be convenient for the user to identify the + # key). For protocol version 2 the keytype is ``ssh-dss'' + # or ``ssh-rsa''. + + good_ssh_key = r'^.*(?:ssh-dss|ssh-rsa)[ ]+[A-Za-z0-9+/=]+(?: .*)?$' + if not re.match(good_ssh_key, key, re.IGNORECASE): + raise PLCInvalidArgument, "Invalid SSH version 2 public key" + + def blacklist(self, commit = True): + """ + Permanently blacklist key (and all other identical keys), + preventing it from ever being added again. Because this could + affect multiple keys associated with multiple accounts, it + should be admin only. """ assert 'key_id' in self - assert isinstance(person, PLC.Persons.Person) - assert 'person_id' in person + assert 'key' in self - person_id = person['person_id'] - key_id = self['key_id'] - assert person_id in [self['person_id']] + # Get all matching keys + rows = self.api.db.selectall("SELECT key_id FROM keys WHERE key = %(key)s", + self) + key_ids = [row['key_id'] for row in rows] + assert key_ids + assert self['key_id'] in key_ids - self.api.db.do("UPDATE person_key SET is_primary = False" \ - " WHERE person_id = %d " % person_id) - self.api.db.do("UPDATE person_key SET is_primary = True" \ - " WHERE person_id = %d AND key_id = %d" \ - % (person_id, key_id) ) + # Keep the keys in the table + self.api.db.do("UPDATE keys SET is_blacklisted = True" \ + " WHERE key_id IN (%s)" % ", ".join(map(str, key_ids))) + + # But disassociate them from all join tables + for table in ['person_key']: + self.api.db.do("DELETE FROM %s WHERE key_id IN (%s)" % \ + (table, ", ".join(map(str, key_ids)))) + + if commit: + self.api.db.commit() - if commit: - self.api.db.commit() - - self['is_primary'] = True - def delete(self, commit = True): """ - Delete key from the database - """ + Delete key from the database. + """ + assert 'key_id' in self for table in ['person_key', 'keys']: - self.api.db.do("DELETE FROM %s WHERE key_id = %d" % \ - (table, self['key_id']), self) - - if commit: - self.api.db.commit() + self.api.db.do("DELETE FROM %s WHERE key_id = %d" % \ + (table, self['key_id'])) + + if commit: + self.api.db.commit() class Keys(Table): """ @@ -115,18 +115,19 @@ class Keys(Table): database. """ - def __init__(self, api, key_id_list = None): + def __init__(self, api, key_id_list = None, is_blacklisted = False): self.api = api - sql = "SELECT %s FROM keys LEFT JOIN person_key USING (%s) " % \ - (", ".join(Key.fields), Key.primary_key) - + sql = "SELECT %s FROM keys WHERE True" % \ + ", ".join(Key.fields) + + if is_blacklisted is not None: + sql += " AND is_blacklisted IS %(is_blacklisted)s" + if key_id_list: - sql += " WHERE key_id IN (%s)" % ", ".join(map(str, key_id_list)) + sql += " AND key_id IN (%s)" % ", ".join(map(str, key_id_list)) - rows = self.api.db.selectall(sql) + rows = self.api.db.selectall(sql, locals()) for row in rows: - self[row['key_id']] = Key(api, row) - - + self[row['key_id']] = Key(api, row) -- 2.47.0