- remove is_primary
[plcapi.git] / PLC / Keys.py
1 import re
2
3 from PLC.Faults import *
4 from PLC.Parameter import Parameter
5 from PLC.Debug import profile
6 from PLC.Table import Row, Table
7 from PLC.KeyTypes import KeyType, KeyTypes
8
9 class Key(Row):
10     """
11     Representation of a row in the keys table. To use, instantiate with a 
12     dict of values. Update as you would a dict. Commit to the database 
13     with sync().
14     """
15
16     table_name = 'keys'
17     primary_key = 'key_id'
18     fields = {
19         'key_id': Parameter(int, "Key identifier"),
20         'key_type': Parameter(str, "Key type"),
21         'key': Parameter(str, "Key value", max = 4096),
22         }
23
24     def __init__(self, api, fields = {}):
25         Row.__init__(self, fields)
26         self.api = api
27
28     def validate_key_type(self, key_type):
29         if key_type not in KeyTypes(self.api):
30             raise PLCInvalidArgument, "Invalid key type"
31         return key_type
32
33     def validate_key(self, key):
34         key = key.strip()
35
36         # Key must not be blacklisted
37         rows = self.api.db.selectall("SELECT 1 from keys" \
38                                      " WHERE key = %(key)s" \
39                                      " AND is_blacklisted IS True",
40                                      locals())
41         if rows:
42             raise PLCInvalidArgument, "Key is blacklisted and cannot be used"
43
44         return key
45
46     def validate(self):
47         # Basic validation
48         Row.validate(self)
49
50         assert 'key' in self
51         key = self['key']
52
53         if self['key_type'] == 'ssh':
54             # Accept only SSH version 2 keys without options. From
55             # sshd(8):
56             #
57             # Each protocol version 2 public key consists of: options,
58             # keytype, base64 encoded key, comment.  The options field
59             # is optional...The comment field is not used for anything
60             # (but may be convenient for the user to identify the
61             # key). For protocol version 2 the keytype is ``ssh-dss''
62             # or ``ssh-rsa''.
63
64             good_ssh_key = r'^.*(?:ssh-dss|ssh-rsa)[ ]+[A-Za-z0-9+/=]+(?: .*)?$'
65             if not re.match(good_ssh_key, key, re.IGNORECASE):
66                 raise PLCInvalidArgument, "Invalid SSH version 2 public key"
67
68     def blacklist(self, commit = True):
69         """
70         Permanently blacklist key (and all other identical keys),
71         preventing it from ever being added again. Because this could
72         affect multiple keys associated with multiple accounts, it
73         should be admin only.        
74         """
75
76         assert 'key_id' in self
77         assert 'key' in self
78
79         # Get all matching keys
80         rows = self.api.db.selectall("SELECT key_id FROM keys WHERE key = %(key)s",
81                                      self)
82         key_ids = [row['key_id'] for row in rows]
83         assert key_ids
84         assert self['key_id'] in key_ids
85
86         # Keep the keys in the table
87         self.api.db.do("UPDATE keys SET is_blacklisted = True" \
88                        " WHERE key_id IN (%s)" % ", ".join(map(str, key_ids)))
89
90         # But disassociate them from all join tables
91         for table in ['person_key']:
92             self.api.db.do("DELETE FROM %s WHERE key_id IN (%s)" % \
93                            (table, ", ".join(map(str, key_ids))))
94
95         if commit:
96             self.api.db.commit()
97
98     def delete(self, commit = True):
99         """
100         Delete key from the database.
101         """
102
103         assert 'key_id' in self
104         
105         for table in ['person_key', 'keys']:
106             self.api.db.do("DELETE FROM %s WHERE key_id = %d" % \
107                            (table, self['key_id']))
108         
109         if commit:
110             self.api.db.commit()
111
112 class Keys(Table):
113     """
114     Representation of row(s) from the keys table in the
115     database.
116     """
117
118     def __init__(self, api, key_id_list = None, is_blacklisted = False):
119         self.api = api
120         
121         sql = "SELECT %s FROM keys WHERE True" % \
122               ", ".join(Key.fields)
123
124         if is_blacklisted is not None:
125             sql += " AND is_blacklisted IS %(is_blacklisted)s"            
126
127         if key_id_list:
128             sql += " AND key_id IN (%s)" %  ", ".join(map(str, key_id_list))
129
130         rows = self.api.db.selectall(sql, locals())
131         
132         for row in rows:        
133             self[row['key_id']] = Key(api, row)