- re-enable return_fields specification
[plcapi.git] / PLC / Keys.py
1 import re
2
3 from PLC.Faults import *
4 from PLC.Parameter import Parameter
5 from PLC.Filter import Filter
6 from PLC.Debug import profile
7 from PLC.Table import Row, Table
8 from PLC.KeyTypes import KeyType, KeyTypes
9
10 class Key(Row):
11     """
12     Representation of a row in the keys table. To use, instantiate with a 
13     dict of values. Update as you would a dict. Commit to the database 
14     with sync().
15     """
16
17     table_name = 'keys'
18     primary_key = 'key_id'
19     join_tables = ['person_key']
20     fields = {
21         'key_id': Parameter(int, "Key identifier"),
22         'key_type': Parameter(str, "Key type"),
23         'key': Parameter(str, "Key value", max = 4096),
24         }
25
26     def validate_key_type(self, key_type):
27         key_types = [row['key_type'] for row in KeyTypes(self.api)]
28         if key_type not in key_types:
29             raise PLCInvalidArgument, "Invalid key type"
30         return key_type
31
32     def validate_key(self, key):
33         # Key must not be blacklisted
34         rows = self.api.db.selectall("SELECT 1 from keys" \
35                                      " WHERE key = %(key)s" \
36                                      " AND is_blacklisted IS True",
37                                      locals())
38         if rows:
39             raise PLCInvalidArgument, "Key is blacklisted and cannot be used"
40
41         return key
42
43     def validate(self):
44         # Basic validation
45         Row.validate(self)
46
47         assert 'key' in self
48         key = self['key']
49
50         if self['key_type'] == 'ssh':
51             # Accept only SSH version 2 keys without options. From
52             # sshd(8):
53             #
54             # Each protocol version 2 public key consists of: options,
55             # keytype, base64 encoded key, comment.  The options field
56             # is optional...The comment field is not used for anything
57             # (but may be convenient for the user to identify the
58             # key). For protocol version 2 the keytype is ``ssh-dss''
59             # or ``ssh-rsa''.
60
61             good_ssh_key = r'^.*(?:ssh-dss|ssh-rsa)[ ]+[A-Za-z0-9+/=]+(?: .*)?$'
62             if not re.match(good_ssh_key, key, re.IGNORECASE):
63                 raise PLCInvalidArgument, "Invalid SSH version 2 public key"
64
65     def blacklist(self, commit = True):
66         """
67         Permanently blacklist key (and all other identical keys),
68         preventing it from ever being added again. Because this could
69         affect multiple keys associated with multiple accounts, it
70         should be admin only.        
71         """
72
73         assert 'key_id' in self
74         assert 'key' in self
75
76         # Get all matching keys
77         rows = self.api.db.selectall("SELECT key_id FROM keys WHERE key = %(key)s",
78                                      self)
79         key_ids = [row['key_id'] for row in rows]
80         assert key_ids
81         assert self['key_id'] in key_ids
82
83         # Keep the keys in the table
84         self.api.db.do("UPDATE keys SET is_blacklisted = True" \
85                        " WHERE key_id IN (%s)" % ", ".join(map(str, key_ids)))
86
87         # But disassociate them from all join tables
88         for table in ['person_key']:
89             self.api.db.do("DELETE FROM %s WHERE key_id IN (%s)" % \
90                            (table, ", ".join(map(str, key_ids))))
91
92         if commit:
93             self.api.db.commit()
94
95 class Keys(Table):
96     """
97     Representation of row(s) from the keys table in the
98     database.
99     """
100
101     def __init__(self, api, key_filter = None, columns = None):
102         Table.__init__(self, api, Key, columns)
103         
104         sql = "SELECT %s FROM keys WHERE is_blacklisted IS False" % \
105               ", ".join(self.columns)
106
107         if key_filter is not None:
108             if isinstance(key_filter, (list, tuple, set)):
109                 key_filter = Filter(Key.fields, {'key_id': key_filter})
110             elif isinstance(key_filter, dict):
111                 key_filter = Filter(Key.fields, key_filter)
112             sql += " AND (%s)" % key_filter.sql(api)
113
114         self.selectall(sql)