- use Table.selectall()
[plcapi.git] / PLC / Keys.py
1 import re
2
3 from PLC.Faults import *
4 from PLC.Parameter import Parameter
5 from PLC.Debug import profile
6 from PLC.Table import Row, Table
7 from PLC.KeyTypes import KeyType, KeyTypes
8
9 class Key(Row):
10     """
11     Representation of a row in the keys table. To use, instantiate with a 
12     dict of values. Update as you would a dict. Commit to the database 
13     with sync().
14     """
15
16     table_name = 'keys'
17     primary_key = 'key_id'
18     join_tables = ['person_key']
19     fields = {
20         'key_id': Parameter(int, "Key identifier"),
21         'key_type': Parameter(str, "Key type"),
22         'key': Parameter(str, "Key value", max = 4096),
23         }
24
25     def validate_key_type(self, key_type):
26         if key_type not in KeyTypes(self.api):
27             raise PLCInvalidArgument, "Invalid key type"
28         return key_type
29
30     def validate_key(self, key):
31         # Key must not be blacklisted
32         rows = self.api.db.selectall("SELECT 1 from keys" \
33                                      " WHERE key = %(key)s" \
34                                      " AND is_blacklisted IS True",
35                                      locals())
36         if rows:
37             raise PLCInvalidArgument, "Key is blacklisted and cannot be used"
38
39         return key
40
41     def validate(self):
42         # Basic validation
43         Row.validate(self)
44
45         assert 'key' in self
46         key = self['key']
47
48         if self['key_type'] == 'ssh':
49             # Accept only SSH version 2 keys without options. From
50             # sshd(8):
51             #
52             # Each protocol version 2 public key consists of: options,
53             # keytype, base64 encoded key, comment.  The options field
54             # is optional...The comment field is not used for anything
55             # (but may be convenient for the user to identify the
56             # key). For protocol version 2 the keytype is ``ssh-dss''
57             # or ``ssh-rsa''.
58
59             good_ssh_key = r'^.*(?:ssh-dss|ssh-rsa)[ ]+[A-Za-z0-9+/=]+(?: .*)?$'
60             if not re.match(good_ssh_key, key, re.IGNORECASE):
61                 raise PLCInvalidArgument, "Invalid SSH version 2 public key"
62
63     def blacklist(self, commit = True):
64         """
65         Permanently blacklist key (and all other identical keys),
66         preventing it from ever being added again. Because this could
67         affect multiple keys associated with multiple accounts, it
68         should be admin only.        
69         """
70
71         assert 'key_id' in self
72         assert 'key' in self
73
74         # Get all matching keys
75         rows = self.api.db.selectall("SELECT key_id FROM keys WHERE key = %(key)s",
76                                      self)
77         key_ids = [row['key_id'] for row in rows]
78         assert key_ids
79         assert self['key_id'] in key_ids
80
81         # Keep the keys in the table
82         self.api.db.do("UPDATE keys SET is_blacklisted = True" \
83                        " WHERE key_id IN (%s)" % ", ".join(map(str, key_ids)))
84
85         # But disassociate them from all join tables
86         for table in ['person_key']:
87             self.api.db.do("DELETE FROM %s WHERE key_id IN (%s)" % \
88                            (table, ", ".join(map(str, key_ids))))
89
90         if commit:
91             self.api.db.commit()
92
93 class Keys(Table):
94     """
95     Representation of row(s) from the keys table in the
96     database.
97     """
98
99     def __init__(self, api, key_ids = None, is_blacklisted = False):
100         Table.__init__(self, api, Key)
101         
102         sql = "SELECT %s FROM keys WHERE is_blacklisted IS False" % \
103               ", ".join(Key.fields)
104
105         if key_ids:
106             sql += " AND key_id IN (%s)" %  ", ".join(map(str, key_ids))
107
108         self.selectall(sql)