1397a599597831af6136a631523429794c27c1cb
[sfa.git] / util / cert.py
1 # cert.py
2 #
3 # a general purpose class for dealing with certificates
4 #
5 # this class serves as an interface between a lower-level X.509 certificate
6 # library such as pyOpenSSL or M2Crypto. Currently both of these libraries
7 # are being used due to lack of functionality in pyOpenSSL and some apparant
8 # bugs in M2Crypto
9
10 import os
11 import tempfile
12 from OpenSSL import crypto
13 import M2Crypto
14 from M2Crypto import X509
15 from M2Crypto import EVP
16
17 # Keypair
18 #
19 # represents a private/public key pair, or a public key
20
21 class Keypair:
22    key = None       # public/private keypair
23    m2key = None     # public key (m2crypto format)
24
25    def __init__(self, create=False, string=None, filename=None):
26       if create:
27          self.create()
28       if string:
29          self.load_from_string(string)
30       if filename:
31          self.load_from_file(filename)
32
33    def create(self):
34       self.key = crypto.PKey()
35       self.key.generate_key(crypto.TYPE_RSA, 1024)
36
37    def save_to_file(self, filename):
38       open(filename, 'w').write(self.as_pem())
39
40    def load_from_file(self, filename):
41       buffer = open(filename, 'r').read()
42       self.load_from_string(buffer)
43
44    def load_from_string(self, string):
45       self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
46       self.m2key = M2Crypto.EVP.load_key_string(string)
47
48    def load_pubkey_from_file(self, filename):
49       # load the m2 public key
50       m2rsakey = M2Crypto.RSA.load_pub_key(filename)
51       self.m2key = M2Crypto.EVP.PKey()
52       self.m2key.assign_rsa(m2rsakey)
53
54       # create an m2 x509 cert
55       m2name = M2Crypto.X509.X509_Name()
56       m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0)
57       m2x509 = M2Crypto.X509.X509()
58       m2x509.set_pubkey(self.m2key)
59       m2x509.set_serial_number(0)
60       m2x509.set_issuer_name(m2name)
61       m2x509.set_subject_name(m2name)
62       ASN1 = M2Crypto.ASN1.ASN1_UTCTIME()
63       ASN1.set_time(500)
64       m2x509.set_not_before(ASN1)
65       m2x509.set_not_after(ASN1)
66       junk_key = Keypair(create=True)
67       m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")
68
69       # convert the m2 x509 cert to a pyopenssl x509
70       m2pem = m2x509.as_pem()
71       pyx509 = crypto.load_certificate(crypto.FILETYPE_PEM, m2pem)
72
73       # get the pyopenssl pkey from the pyopenssl x509
74       self.key = pyx509.get_pubkey()
75
76    def load_pubkey_from_string(self, string):
77       (f, fn) = tempfile.mkstemp()
78       os.write(f, string)
79       os.close(f)
80       self.load_pubkey_from_file(fn)
81       os.remove(fn)
82
83    def as_pem(self):
84       return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key)
85
86    def get_m2_pkey(self):
87       if not self.m2key:
88          self.m2key = M2Crypto.EVP.load_key_string(self.as_pem())
89       return self.m2key
90
91    def get_openssl_pkey(self):
92       return self.key
93
94    def is_same(self, pkey):
95       return self.as_pem() == pkey.as_pem()
96
97 # Certificate
98 #
99 # Represents an X.509 certificate. Support is included for a list of
100 # certificates by use of a "parent" member. See load_from_string() and
101 # save_to_string() for insight into how a recursive chain of certs is
102 # serialized.
103 #
104 # Certificates support an application-defined "data" field, which is
105 # stored in the subjectAltName field of the X.509 certificate.
106
107 class Certificate:
108    digest = "md5"
109
110    data = None
111    cert = None
112    issuerKey = None
113    issuerSubject = None
114    parent = None
115
116    def __init__(self, create=False, subject=None, string=None, filename=None):
117        if create or subject:
118            self.create()
119        if subject:
120            self.set_subject(subject)
121        if string:
122            self.load_from_string(string)
123        if filename:
124            self.load_from_file(filename)
125
126    def create(self):
127        self.cert = crypto.X509()
128        self.cert.set_serial_number(1)
129        self.cert.gmtime_adj_notBefore(0)
130        self.cert.gmtime_adj_notAfter(60*60*24*365*5) # five years
131
132    def load_from_pyopenssl_x509(self, x509):
133        self.cert = x509
134
135    def load_from_string(self, string):
136        # if it is a chain of multiple certs, then split off the first one and
137        # load it
138        parts = string.split("-----parent-----", 1)
139        self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])
140
141        # if there are more certs, then create a parent and let the parent load
142        # itself from the remainder of the string
143        if len(parts) > 1:
144            self.parent = Certificate()
145            self.parent.load_from_string(parts[1])
146
147
148    def load_from_file(self, filename):
149        file = open(filename)
150        string = file.read()
151        self.load_from_string(string)
152
153    def save_to_string(self, save_parents=False):
154        string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)
155        if save_parents and self.parent:
156           string = string + "-----parent-----" + self.parent.save_to_string(save_parents)
157        return string
158
159    def save_to_file(self, filename, save_parents=False):
160        string = self.save_to_string(save_parents=save_parents)
161        open(filename, 'w').write(string)
162
163    def set_issuer(self, key, subject=None, cert=None):
164        self.issuerKey = key
165        if subject:
166           # it's a mistake to use subject and cert params at the same time
167           assert(not cert)
168           if isinstance(subject, dict) or isinstance(subject, str):
169              req = crypto.X509Req()
170              reqSubject = req.get_subject()
171              if (isinstance(subject, dict)):
172                 for key in reqSubject.keys():
173                     setattr(reqSubject, key, name[key])
174              else:
175                 setattr(reqSubject, "CN", subject)
176              subject = reqSubject
177              # subject is not valid once req is out of scope, so save req
178              self.issuerReq = req
179        if cert:
180           # if a cert was supplied, then get the subject from the cert
181           subject = cert.cert.get_issuer()
182        assert(subject)
183        self.issuerSubject = subject
184
185    def get_issuer(self, which="CN"):
186        x = self.cert.get_issuer()
187        return getattr(x, which)
188
189    def set_subject(self, name):
190        req = crypto.X509Req()
191        subj = req.get_subject()
192        if (isinstance(name, dict)):
193            for key in name.keys():
194                setattr(subj, key, name[key])
195        else:
196            setattr(subj, "CN", name)
197        self.cert.set_subject(subj)
198
199    def get_subject(self, which="CN"):
200        x = self.cert.get_subject()
201        return getattr(x, which)
202
203    def set_pubkey(self, key):
204        assert(isinstance(key, Keypair))
205        self.cert.set_pubkey(key.get_openssl_pkey())
206
207    def get_pubkey(self):
208        m2x509 = X509.load_cert_string(self.save_to_string())
209        pkey = Keypair()
210        pkey.key = self.cert.get_pubkey()
211        pkey.m2key = m2x509.get_pubkey()
212        return pkey
213
214    def add_extension(self, name, critical, value):
215        ext = crypto.X509Extension (name, critical, value)
216        self.cert.add_extensions([ext])
217
218    def get_extension(self, name):
219        # pyOpenSSL does not have a way to get extensions
220        m2x509 = X509.load_cert_string(self.save_to_string())
221        value = m2x509.get_ext(name).get_value()
222        return value
223
224    def set_data(self, str):
225        # pyOpenSSL only allows us to add extensions, so if we try to set the
226        # same extension more than once, it will not work
227        if self.data != None:
228           raise "cannot set subjectAltName more than once"
229        self.data = str
230        self.add_extension("subjectAltName", 0, "URI:http://" + str)
231
232    def get_data(self):
233        if self.data:
234            return self.data
235
236        try:
237            uri = self.get_extension("subjectAltName")
238        except LookupError:
239            self.data = None
240            return self.data
241
242        if not uri.startswith("URI:http://"):
243            raise "bad encoding in subjectAltName"
244        self.data = uri[11:]
245        return self.data
246
247    def sign(self):
248        assert self.cert != None
249        assert self.issuerSubject != None
250        assert self.issuerKey != None
251        self.cert.set_issuer(self.issuerSubject)
252        self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest)
253
254    def verify(self, pkey):
255        # pyOpenSSL does not have a way to verify signatures
256        m2x509 = X509.load_cert_string(self.save_to_string())
257        m2pkey = pkey.get_m2_pkey()
258        # verify it
259        return m2x509.verify(m2pkey)
260
261        # XXX alternatively, if openssl has been patched, do the much simpler:
262        # try:
263        #   self.cert.verify(pkey.get_openssl_key())
264        #   return 1
265        # except:
266        #   return 0
267
268    def is_pubkey(self, pkey):
269        return self.get_pubkey().is_same(pkey)
270
271    def is_signed_by_cert(self, cert):
272        k = cert.get_pubkey()
273        result = self.verify(k)
274        return result
275
276    def set_parent(self, p):
277         self.parent = p
278
279    def get_parent(self):
280         return self.parent
281
282    def verify_chain(self, trusted_certs = None):
283         # Verify a chain of certificates. Each certificate must be signed by
284         # the public key contained in it's parent. The chain is recursed
285         # until a certificate is found that is signed by a trusted root.
286
287         # TODO: verify expiration time
288
289         # if this cert is signed by a trusted_cert, then we are set
290         for trusted_cert in trusted_certs:
291             # TODO: verify expiration of trusted_cert ?
292             if self.is_signed_by_cert(trusted_cert):
293                 #print self.get_subject(), "is signed by a root"
294                 return
295
296         # if there is no parent, then no way to verify the chain
297         if not self.parent:
298             #print self.get_subject(), "has no parent"
299             raise MissingParent(self.get_subject())
300
301         # if it wasn't signed by the parent...
302         if not self.is_signed_by_cert(self.parent):
303             #print self.get_subject(), "is not signed by parent"
304             return NotSignedByParent(self.get_subject())
305
306         # if the parent isn't verified...
307         self.parent.verify_chain(trusted_certs)
308
309         return