add functions to sign and verify strings using a key
[sfa.git] / geni / util / cert.py
1 ##
2 # Geniwrapper uses two crypto libraries: pyOpenSSL and M2Crypto to implement
3 # the necessary crypto functionality. Ideally just one of these libraries
4 # would be used, but unfortunately each of these libraries is independently
5 # lacking. The pyOpenSSL library is missing many necessary functions, and
6 # the M2Crypto library has crashed inside of some of the functions. The
7 # design decision is to use pyOpenSSL whenever possible as it seems more
8 # stable, and only use M2Crypto for those functions that are not possible
9 # in pyOpenSSL.
10 #
11 # This module exports two classes: Keypair and Certificate.
12 ##
13
14 import os
15 import tempfile
16 import base64
17 from OpenSSL import crypto
18 import M2Crypto
19 from M2Crypto import X509
20 from M2Crypto import EVP
21
22 from excep import *
23
24 ##
25 # Public-private key pairs are implemented by the Keypair class.
26 # A Keypair object may represent both a public and private key pair, or it
27 # may represent only a public key (this usage is consistent with OpenSSL).
28
29 class Keypair:
30    key = None       # public/private keypair
31    m2key = None     # public key (m2crypto format)
32
33    ##
34    # Creates a Keypair object
35    # @param create If create==True, creates a new public/private key and
36    #     stores it in the object
37    # @param string If string!=None, load the keypair from the string (PEM)
38    # @param filename If filename!=None, load the keypair from the file
39
40    def __init__(self, create=False, string=None, filename=None):
41       if create:
42          self.create()
43       if string:
44          self.load_from_string(string)
45       if filename:
46          self.load_from_file(filename)
47
48    ##
49    # Create a RSA public/private key pair and store it inside the keypair object
50
51    def create(self):\r
52       self.key = crypto.PKey()
53       self.key.generate_key(crypto.TYPE_RSA, 1024)
54
55    ##
56    # Save the private key to a file
57    # @param filename name of file to store the keypair in
58
59    def save_to_file(self, filename):
60       open(filename, 'w').write(self.as_pem())
61
62    ##
63    # Load the private key from a file. Implicity the private key includes the public key.
64
65    def load_from_file(self, filename):
66       buffer = open(filename, 'r').read()
67       self.load_from_string(buffer)
68
69    ##
70    # Load the private key from a string. Implicitly the private key includes the public key.
71
72    def load_from_string(self, string):
73       self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
74       self.m2key = M2Crypto.EVP.load_key_string(string)
75
76    ##
77    #  Load the public key from a string. No private key is loaded. 
78
79    def load_pubkey_from_file(self, filename):
80       # load the m2 public key
81       m2rsakey = M2Crypto.RSA.load_pub_key(filename)
82       self.m2key = M2Crypto.EVP.PKey()
83       self.m2key.assign_rsa(m2rsakey)
84
85       # create an m2 x509 cert
86       m2name = M2Crypto.X509.X509_Name()
87       m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0)
88       m2x509 = M2Crypto.X509.X509()
89       m2x509.set_pubkey(self.m2key)
90       m2x509.set_serial_number(0)
91       m2x509.set_issuer_name(m2name)
92       m2x509.set_subject_name(m2name)
93       ASN1 = M2Crypto.ASN1.ASN1_UTCTIME()
94       ASN1.set_time(500)
95       m2x509.set_not_before(ASN1)
96       m2x509.set_not_after(ASN1)
97       junk_key = Keypair(create=True)
98       m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")
99
100       # convert the m2 x509 cert to a pyopenssl x509
101       m2pem = m2x509.as_pem()
102       pyx509 = crypto.load_certificate(crypto.FILETYPE_PEM, m2pem)
103
104       # get the pyopenssl pkey from the pyopenssl x509
105       self.key = pyx509.get_pubkey()
106
107    ##
108    # Load the public key from a string. No private key is loaded.
109
110    def load_pubkey_from_string(self, string):
111       (f, fn) = tempfile.mkstemp()
112       os.write(f, string)
113       os.close(f)
114       self.load_pubkey_from_file(fn)
115       os.remove(fn)
116
117    ##
118    # Return the private key in PEM format.
119
120    def as_pem(self):
121       return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key)
122
123    def get_m2_pkey(self):
124       if not self.m2key:
125          self.m2key = M2Crypto.EVP.load_key_string(self.as_pem())
126       return self.m2key
127
128    ##
129    # Return an OpenSSL pkey object
130
131    def get_openssl_pkey(self):
132       return self.key
133
134    ##
135    # Given another Keypair object, return TRUE if the two keys are the same.
136
137    def is_same(self, pkey):
138       return self.as_pem() == pkey.as_pem()
139
140    def sign_string(self, data):
141       k = self.get_m2_pkey()
142       k.sign_init()
143       k.sign_update(data)
144       return base64.b64encode(k.sign_final())
145
146    def verify_string(self, data, sig):
147       k = self.get_m2_pkey()
148       k.verify_init()
149       k.verify_update(data)
150       return M2Crypto.m2.verify_final(k.ctx, base64.b64decode(sig), k.pkey)
151
152 ##
153 # The certificate class implements a general purpose X509 certificate, making
154 # use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds
155 # several addition features, such as the ability to maintain a chain of
156 # parent certificates, and storage of application-specific data.
157 #
158 # Certificates include the ability to maintain a chain of parents. Each\r
159 # certificate includes a pointer to it's parent certificate. When loaded\r
160 # from a file or a string, the parent chain will be automatically loaded.\r
161 # When saving a certificate to a file or a string, the caller can choose\r
162 # whether to save the parent certificates as well.\r
163
164 class Certificate:
165    digest = "md5"
166
167    data = None
168    cert = None
169    issuerKey = None
170    issuerSubject = None
171    parent = None
172
173    ##
174    # Create a certificate object.
175    #\r
176    # @param create If create==True, then also create a blank X509 certificate.\r
177    # @param subject If subject!=None, then create a blank certificate and set\r
178    #     it's subject name.\r
179    # @param string If string!=None, load the certficate from the string.\r
180    # @param filename If filename!=None, load the certficiate from the file.\r
181
182    def __init__(self, create=False, subject=None, string=None, filename=None):
183        if create or subject:
184            self.create()
185        if subject:
186            self.set_subject(subject)
187        if string:
188            self.load_from_string(string)
189        if filename:
190            self.load_from_file(filename)
191
192    ##
193    # Create a blank X509 certificate and store it in this object.
194
195    def create(self):
196        self.cert = crypto.X509()
197        self.cert.set_serial_number(1)
198        self.cert.gmtime_adj_notBefore(0)
199        self.cert.gmtime_adj_notAfter(60*60*24*365*5) # five years
200
201    ##
202    # Given a pyOpenSSL X509 object, store that object inside of this
203    # certificate object.
204
205    def load_from_pyopenssl_x509(self, x509):
206        self.cert = x509
207
208    ##
209    # Load the certificate from a string
210
211    def load_from_string(self, string):
212        # if it is a chain of multiple certs, then split off the first one and
213        # load it
214        parts = string.split("-----parent-----", 1)
215        self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])
216
217        # if there are more certs, then create a parent and let the parent load
218        # itself from the remainder of the string
219        if len(parts) > 1:
220            self.parent = self.__class__()
221            self.parent.load_from_string(parts[1])
222
223    ##
224    # Load the certificate from a file
225
226    def load_from_file(self, filename):
227        file = open(filename)
228        string = file.read()
229        self.load_from_string(string)
230
231    ##
232    # Save the certificate to a string.
233    #
234    # @param save_parents If save_parents==True, then also save the parent certificates.
235
236    def save_to_string(self, save_parents=False):
237        string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)
238        if save_parents and self.parent:
239           string = string + "-----parent-----" + self.parent.save_to_string(save_parents)
240        return string
241
242    ##
243    # Save the certificate to a file.
244    # @param save_parents If save_parents==True, then also save the parent certificates.
245
246    def save_to_file(self, filename, save_parents=False):
247        string = self.save_to_string(save_parents=save_parents)
248        open(filename, 'w').write(string)
249
250    ##
251    # Sets the issuer private key and name
252    # @param key Keypair object containing the private key of the issuer
253    # @param subject String containing the name of the issuer
254    # @param cert (optional) Certificate object containing the name of the issuer
255
256    def set_issuer(self, key, subject=None, cert=None):
257        self.issuerKey = key
258        if subject:
259           # it's a mistake to use subject and cert params at the same time
260           assert(not cert)
261           if isinstance(subject, dict) or isinstance(subject, str):
262              req = crypto.X509Req()
263              reqSubject = req.get_subject()
264              if (isinstance(subject, dict)):
265                 for key in reqSubject.keys():
266                     setattr(reqSubject, key, name[key])
267              else:
268                 setattr(reqSubject, "CN", subject)
269              subject = reqSubject
270              # subject is not valid once req is out of scope, so save req
271              self.issuerReq = req
272        if cert:
273           # if a cert was supplied, then get the subject from the cert
274           subject = cert.cert.get_issuer()
275        assert(subject)
276        self.issuerSubject = subject
277
278    ##
279    # Get the issuer name
280
281    def get_issuer(self, which="CN"):
282        x = self.cert.get_issuer()
283        return getattr(x, which)
284
285    ##
286    # Set the subject name of the certificate
287
288    def set_subject(self, name):
289        req = crypto.X509Req()
290        subj = req.get_subject()
291        if (isinstance(name, dict)):
292            for key in name.keys():
293                setattr(subj, key, name[key])
294        else:
295            setattr(subj, "CN", name)
296        self.cert.set_subject(subj)
297    ##
298    # Get the subject name of the certificate
299
300    def get_subject(self, which="CN"):
301        x = self.cert.get_subject()
302        return getattr(x, which)
303
304    ##
305    # Get the public key of the certificate.
306    #
307    # @param key Keypair object containing the public key
308
309    def set_pubkey(self, key):
310        assert(isinstance(key, Keypair))
311        self.cert.set_pubkey(key.get_openssl_pkey())
312
313    ##
314    # Get the public key of the certificate.
315    # It is returned in the form of a Keypair object.
316
317    def get_pubkey(self):
318        m2x509 = X509.load_cert_string(self.save_to_string())
319        pkey = Keypair()
320        pkey.key = self.cert.get_pubkey()
321        pkey.m2key = m2x509.get_pubkey()
322        return pkey
323
324    ##
325    # Add an X509 extension to the certificate. Add_extension can only be called
326    # once for a particular extension name, due to limitations in the underlying
327    # library.
328    #
329    # @param name string containing name of extension
330    # @param value string containing value of the extension
331
332    def add_extension(self, name, critical, value):
333        ext = crypto.X509Extension (name, critical, value)
334        self.cert.add_extensions([ext])
335
336    ##
337    # Get an X509 extension from the certificate
338
339    def get_extension(self, name):
340        # pyOpenSSL does not have a way to get extensions
341        m2x509 = X509.load_cert_string(self.save_to_string())
342        value = m2x509.get_ext(name).get_value()
343        return value
344
345    ##
346    # Set_data is a wrapper around add_extension. It stores the parameter str in
347    # the X509 subject_alt_name extension. Set_data can only be called once, due
348    # to limitations in the underlying library.
349
350    def set_data(self, str):
351        # pyOpenSSL only allows us to add extensions, so if we try to set the
352        # same extension more than once, it will not work
353        if self.data != None:
354           raise "cannot set subjectAltName more than once"
355        self.data = str
356        self.add_extension("subjectAltName", 0, "URI:http://" + str)
357
358    ##
359    # Return the data string that was previously set with set_data
360
361    def get_data(self):
362        if self.data:
363            return self.data
364
365        try:
366            uri = self.get_extension("subjectAltName")
367        except LookupError:
368            self.data = None
369            return self.data
370
371        if not uri.startswith("URI:http://"):
372            raise "bad encoding in subjectAltName"
373        self.data = uri[11:]
374        return self.data
375
376    ##
377    # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer().
378
379    def sign(self):
380        assert self.cert != None
381        assert self.issuerSubject != None
382        assert self.issuerKey != None
383        self.cert.set_issuer(self.issuerSubject)
384        self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest)
385
386     ##
387     # Verify the authenticity of a certificate.
388     # @param pkey is a Keypair object representing a public key. If Pkey
389     #     did not sign the certificate, then an exception will be thrown.
390 \r
391    def verify(self, pkey):
392        # pyOpenSSL does not have a way to verify signatures
393        m2x509 = X509.load_cert_string(self.save_to_string())
394        m2pkey = pkey.get_m2_pkey()
395        # verify it
396        return m2x509.verify(m2pkey)
397
398        # XXX alternatively, if openssl has been patched, do the much simpler:
399        # try:
400        #   self.cert.verify(pkey.get_openssl_key())
401        #   return 1
402        # except:
403        #   return 0
404
405    ##
406    # Return True if pkey is identical to the public key that is contained in the certificate.
407    # @param pkey Keypair object
408
409    def is_pubkey(self, pkey):
410        return self.get_pubkey().is_same(pkey)
411
412    ##
413    # Given a certificate cert, verify that this certificate was signed by the
414    # public key contained in cert. Throw an exception otherwise.
415    #
416    # @param cert certificate object
417
418    def is_signed_by_cert(self, cert):
419        k = cert.get_pubkey()
420        result = self.verify(k)
421        return result
422
423    ##
424    # Set the parent certficiate.
425    #
426    # @param p certificate object.
427
428    def set_parent(self, p):
429         self.parent = p
430
431    ##
432    # Return the certificate object of the parent of this certificate.
433
434    def get_parent(self):
435         return self.parent
436
437    ##
438    # Verification examines a chain of certificates to ensure that each parent
439    # signs the child, and that some certificate in the chain is signed by a
440    # trusted certificate.
441    #
442    # Verification is a basic recursion: <pre>
443    #     if this_certificate was signed by trusted_certs:\r
444    #         return\r
445    #     else\r
446    #         return verify_chain(parent, trusted_certs)\r
447    # </pre>\r
448    #\r
449    # At each recursion, the parent is tested to ensure that it did sign the
450    # child. If a parent did not sign a child, then an exception is thrown. If
451    # the bottom of the recursion is reached and the certificate does not match
452    # a trusted root, then an exception is thrown.
453    #
454    # @param Trusted_certs is a list of certificates that are trusted.\r
455    #
456
457    def verify_chain(self, trusted_certs = None):
458         # Verify a chain of certificates. Each certificate must be signed by
459         # the public key contained in it's parent. The chain is recursed
460         # until a certificate is found that is signed by a trusted root.
461
462         # TODO: verify expiration time
463
464         # if this cert is signed by a trusted_cert, then we are set
465         for trusted_cert in trusted_certs:
466             # TODO: verify expiration of trusted_cert ?
467             if self.is_signed_by_cert(trusted_cert):
468                 #print self.get_subject(), "is signed by a root"
469                 return
470
471         # if there is no parent, then no way to verify the chain
472         if not self.parent:
473             #print self.get_subject(), "has no parent"
474             raise CertMissingParent(self.get_subject())
475
476         # if it wasn't signed by the parent...
477         if not self.is_signed_by_cert(self.parent):
478             #print self.get_subject(), "is not signed by parent"
479             return CertNotSignedByParent(self.get_subject())
480
481         # if the parent isn't verified...
482         self.parent.verify_chain(trusted_certs)
483
484         return