regex fix: escape the first '-' too, add start/end of the string.
[plcapi.git] / PLC / Keys.py
1 # $Id$
2 # $URL$
3 import re
4
5 from PLC.Faults import *
6 from PLC.Parameter import Parameter
7 from PLC.Filter import Filter
8 from PLC.Debug import profile
9 from PLC.Table import Row, Table
10 from PLC.KeyTypes import KeyType, KeyTypes
11
12 class Key(Row):
13     """
14     Representation of a row in the keys table. To use, instantiate with a 
15     dict of values. Update as you would a dict. Commit to the database 
16     with sync().
17     """
18
19     table_name = 'keys'
20     primary_key = 'key_id'
21     join_tables = ['person_key', 'peer_key']
22     fields = {
23         'key_id': Parameter(int, "Key identifier"),
24         'key_type': Parameter(str, "Key type"),
25         'key': Parameter(str, "Key value", max = 4096),
26         'person_id': Parameter(int, "User to which this key belongs", nullok = True),
27         'peer_id': Parameter(int, "Peer to which this key belongs", nullok = True),
28         'peer_key_id': Parameter(int, "Foreign key identifier at peer", nullok = True),
29         }
30
31     def validate_key_type(self, key_type):
32         key_types = [row['key_type'] for row in KeyTypes(self.api)]
33         if key_type not in key_types:
34             raise PLCInvalidArgument, "Invalid key type"
35         return key_type
36
37     def validate_key(self, key):
38         # Key must not be blacklisted
39         rows = self.api.db.selectall("SELECT 1 from keys" \
40                                      " WHERE key = %(key)s" \
41                                      " AND is_blacklisted IS True",
42                                      locals())
43         if rows:
44             raise PLCInvalidArgument, "Key is blacklisted and cannot be used"
45
46         return key
47
48     def validate(self):
49         # Basic validation
50         Row.validate(self)
51
52         assert 'key' in self
53         key = self['key']
54
55         if self['key_type'] == 'ssh':
56             # Accept only SSH version 2 keys without options. From
57             # sshd(8):
58             #
59             # Each protocol version 2 public key consists of: options,
60             # keytype, base64 encoded key, comment.  The options field
61             # is optional...The comment field is not used for anything
62             # (but may be convenient for the user to identify the
63             # key). For protocol version 2 the keytype is ``ssh-dss''
64             # or ``ssh-rsa''.
65
66             good_ssh_key = r'^.*(?:ssh-dss|ssh-rsa)[ ]+[A-Za-z0-9+/=]+(?: .*)?$'
67             if not re.match(good_ssh_key, key, re.IGNORECASE):
68                 raise PLCInvalidArgument, "Invalid SSH version 2 public key"
69
70     def blacklist(self, commit = True):
71         """
72         Permanently blacklist key (and all other identical keys),
73         preventing it from ever being added again. Because this could
74         affect multiple keys associated with multiple accounts, it
75         should be admin only.        
76         """
77
78         assert 'key_id' in self
79         assert 'key' in self
80
81         # Get all matching keys
82         rows = self.api.db.selectall("SELECT key_id FROM keys WHERE key = %(key)s",
83                                      self)
84         key_ids = [row['key_id'] for row in rows]
85         assert key_ids
86         assert self['key_id'] in key_ids
87
88         # Keep the keys in the table
89         self.api.db.do("UPDATE keys SET is_blacklisted = True" \
90                        " WHERE key_id IN (%s)" % ", ".join(map(str, key_ids)))
91
92         # But disassociate them from all join tables
93         for table in self.join_tables:
94             self.api.db.do("DELETE FROM %s WHERE key_id IN (%s)" % \
95                            (table, ", ".join(map(str, key_ids))))
96
97         if commit:
98             self.api.db.commit()
99
100 class Keys(Table):
101     """
102     Representation of row(s) from the keys table in the
103     database.
104     """
105
106     def __init__(self, api, key_filter = None, columns = None):
107         Table.__init__(self, api, Key, columns)
108         
109         sql = "SELECT %s FROM view_keys WHERE is_blacklisted IS False" % \
110               ", ".join(self.columns)
111
112         if key_filter is not None:
113             if isinstance(key_filter, (list, tuple, set)):
114                 key_filter = Filter(Key.fields, {'key_id': key_filter})
115             elif isinstance(key_filter, dict):
116                 key_filter = Filter(Key.fields, key_filter)
117             sql += " AND (%s) %s" % key_filter.sql(api)
118
119         self.selectall(sql)