X-Git-Url: http://git.onelab.eu/?a=blobdiff_plain;f=PLC%2FKeys.py;h=b81442d54c39870ae18585a05b177e3fec4718c1;hb=c575b7396f84f929087440a6b1c5c08b59bf723e;hp=5a151fd977d271e24c2ea45a0e43684110fa822f;hpb=8154cdf6e85073026b86a0e5e27228ff6de48ae6;p=plcapi.git diff --git a/PLC/Keys.py b/PLC/Keys.py index 5a151fd..b81442d 100644 --- a/PLC/Keys.py +++ b/PLC/Keys.py @@ -1,113 +1,104 @@ -from types import StringTypes +import re from PLC.Faults import * from PLC.Parameter import Parameter +from PLC.Filter import Filter from PLC.Debug import profile from PLC.Table import Row, Table -import PLC +from PLC.KeyTypes import KeyType, KeyTypes + class Key(Row): """ Representation of a row in the keys table. To use, instantiate with a dict of values. Update as you would a dict. Commit to the database with sync(). """ + table_name = 'keys' primary_key = 'key_id' + join_tables = ['person_key', 'peer_key'] fields = { - 'key_id': Parameter(int, "Key type"), + 'key_id': Parameter(int, "Key identifier"), 'key_type': Parameter(str, "Key type"), - 'key': Parameter(str, "Key value"), - 'is_blacklisted': Parameter(str, "Key has been blacklisted and is forever unusable"), - 'person_id': Parameter(int, "Identifier of the account that created this key"), - 'is_primary': Parameter(bool, "Is the primary key for this account") + 'key': Parameter(str, "Key value", max = 4096), + 'person_id': Parameter(int, "User to which this key belongs", nullok = True), + 'peer_id': Parameter(int, "Peer to which this key belongs", nullok = True), + 'peer_key_id': Parameter(int, "Foreign key identifier at peer", nullok = True), } - def __init__(self, api, fields): - Row.__init__(self, fields) - self.api = api - - - def validate_key_type(self, key_type): - # 1. ssh is the only supported key type - if not key_type or not key_type in ['ssh']: - raise PLCInvalidArgument, "Invalid key type" + # for Cache + class_key= 'key' + foreign_fields = ['key_type'] + foreign_xrefs = [] + def validate_key_type(self, key_type): + key_types = [row['key_type'] for row in KeyTypes(self.api)] + if key_type not in key_types: + raise PLCInvalidArgument, "Invalid key type" return key_type def validate_key(self, key): - # 1. key must not be blacklisted - - # Remove leading and trailing spaces - key = key.strip() - # Make sure key is not blank - if not len(key) > 0: - raise PLCInvalidArgument, "Invalid key" - - rows = self.api.db.selectall("SELECT is_blacklisted from keys" \ - " WHERE key = '%s'" % key) + # Key must not be blacklisted + rows = self.api.db.selectall("SELECT 1 from keys" \ + " WHERE key = %(key)s" \ + " AND is_blacklisted IS True", + locals()) if rows: - raise PLCInvalidArgument, "Key is blacklisted" + raise PLCInvalidArgument, "Key is blacklisted and cannot be used" + return key - - def add_person(self, person, commit = True): - """ - Associate key with person - """ - - assert 'key_id' in self - assert isinstance(person, PLC.Persons.Person) - assert 'person_id' in person - person_id = person['person_id'] - key_id = self['key_id'] - - if not 'person_id' in self: - assert key_id not in person['key_ids'] - - self.api.db.do("INSERT INTO person_key (person_id, key_id)" \ - " VALUES (%d, %d)" % (person_id, key_id) ) - if commit: - self.api.db.commit() - - self['person_id'] = person_id - person['key_id'] = key_id - - def set_primary_key(self, person, commit = True): - """ - Set the primary key for a person + def validate(self): + # Basic validation + Row.validate(self) + + assert 'key' in self + key = self['key'] + + if self['key_type'] == 'ssh': + # Accept only SSH version 2 keys without options. From + # sshd(8): + # + # Each protocol version 2 public key consists of: options, + # keytype, base64 encoded key, comment. The options field + # is optional...The comment field is not used for anything + # (but may be convenient for the user to identify the + # key). For protocol version 2 the keytype is ``ssh-dss'' + # or ``ssh-rsa''. + + good_ssh_key = r'^.*(?:ssh-dss|ssh-rsa)[ ]+[A-Za-z0-9+/=]+(?: .*)?$' + if not re.match(good_ssh_key, key, re.IGNORECASE): + raise PLCInvalidArgument, "Invalid SSH version 2 public key" + + def blacklist(self, commit = True): + """ + Permanently blacklist key (and all other identical keys), + preventing it from ever being added again. Because this could + affect multiple keys associated with multiple accounts, it + should be admin only. """ assert 'key_id' in self - assert isinstance(person, PLC.Persons.Person) - assert 'person_id' in person + assert 'key' in self - person_id = person['person_id'] - key_id = self['key_id'] - assert person_id in [self['person_id']] + # Get all matching keys + rows = self.api.db.selectall("SELECT key_id FROM keys WHERE key = %(key)s", + self) + key_ids = [row['key_id'] for row in rows] + assert key_ids + assert self['key_id'] in key_ids - self.api.db.do("UPDATE person_key SET is_primary = False" \ - " WHERE person_id = %d " % person_id) - self.api.db.do("UPDATE person_key SET is_primary = True" \ - " WHERE person_id = %d AND key_id = %d" \ - % (person_id, key_id) ) + # Keep the keys in the table + self.api.db.do("UPDATE keys SET is_blacklisted = True" \ + " WHERE key_id IN (%s)" % ", ".join(map(str, key_ids))) - if commit: - self.api.db.commit() - - self['is_primary'] = True - - def delete(self, commit = True): - """ - Delete key from the database - """ - assert 'key_id' in self - - for table in ['person_key', 'keys']: - self.api.db.do("DELETE FROM %s WHERE key_id = %d" % \ - (table, self['key_id']), self) + # But disassociate them from all join tables + for table in self.join_tables: + self.api.db.do("DELETE FROM %s WHERE key_id IN (%s)" % \ + (table, ", ".join(map(str, key_ids)))) - if commit: - self.api.db.commit() + if commit: + self.api.db.commit() class Keys(Table): """ @@ -115,18 +106,17 @@ class Keys(Table): database. """ - def __init__(self, api, key_id_list = None): - self.api = api - - sql = "SELECT %s FROM keys LEFT JOIN person_key USING (%s) " % \ - (", ".join(Key.fields), Key.primary_key) + def __init__(self, api, key_filter = None, columns = None): + Table.__init__(self, api, Key, columns) - if key_id_list: - sql += " WHERE key_id IN (%s)" % ", ".join(map(str, key_id_list)) + sql = "SELECT %s FROM view_keys WHERE is_blacklisted IS False" % \ + ", ".join(self.columns) - rows = self.api.db.selectall(sql) - - for row in rows: - self[row['key_id']] = Key(api, row) - - + if key_filter is not None: + if isinstance(key_filter, (list, tuple, set)): + key_filter = Filter(Key.fields, {'key_id': key_filter}) + elif isinstance(key_filter, dict): + key_filter = Filter(Key.fields, key_filter) + sql += " AND (%s)" % key_filter.sql(api) + + self.selectall(sql)