added a function to return the public key string from a keypair object
[sfa.git] / sfa / trust / certificate.py
1 ##
2 # Geniwrapper uses two crypto libraries: pyOpenSSL and M2Crypto to implement
3 # the necessary crypto functionality. Ideally just one of these libraries
4 # would be used, but unfortunately each of these libraries is independently
5 # lacking. The pyOpenSSL library is missing many necessary functions, and
6 # the M2Crypto library has crashed inside of some of the functions. The
7 # design decision is to use pyOpenSSL whenever possible as it seems more
8 # stable, and only use M2Crypto for those functions that are not possible
9 # in pyOpenSSL.
10 #
11 # This module exports two classes: Keypair and Certificate.
12 ##
13 #
14 ### $Id$
15 ### $URL$
16 #
17
18 import os
19 import tempfile
20 import base64
21 from OpenSSL import crypto
22 import M2Crypto
23 from M2Crypto import X509
24 from M2Crypto import EVP
25
26 from sfa.util.faults import *
27
28 def convert_public_key(key):
29     keyconvert_path = "/usr/bin/keyconvert"
30     if not os.path.isfile(keyconvert_path):
31         raise IOError, "Could not find keyconvert in %s" % keyconvert_path
32
33     # we can only convert rsa keys 
34     if "ssh-dss" in key:
35         print "XXX: DSA key encountered, ignoring"
36         return None
37     
38     (ssh_f, ssh_fn) = tempfile.mkstemp()
39     ssl_fn = tempfile.mktemp()
40     os.write(ssh_f, key)
41     os.close(ssh_f)
42
43     cmd = keyconvert_path + " " + ssh_fn + " " + ssl_fn
44     os.system(cmd)
45
46     # this check leaves the temporary file containing the public key so
47     # that it can be expected to see why it failed.
48     # TODO: for production, cleanup the temporary files
49     if not os.path.exists(ssl_fn):
50         report.trace("  failed to convert key from " + ssh_fn + " to " + ssl_fn)
51         return None
52
53     k = Keypair()
54     try:
55         k.load_pubkey_from_file(ssl_fn)
56     except:
57         print "XXX: Error while converting key: ", key_str
58         k = None
59
60     # remove the temporary files
61     os.remove(ssh_fn)
62     os.remove(ssl_fn)
63
64     return k
65
66 ##
67 # Public-private key pairs are implemented by the Keypair class.
68 # A Keypair object may represent both a public and private key pair, or it
69 # may represent only a public key (this usage is consistent with OpenSSL).
70
71 class Keypair:
72    key = None       # public/private keypair
73    m2key = None     # public key (m2crypto format)
74
75    ##
76    # Creates a Keypair object
77    # @param create If create==True, creates a new public/private key and
78    #     stores it in the object
79    # @param string If string!=None, load the keypair from the string (PEM)
80    # @param filename If filename!=None, load the keypair from the file
81
82    def __init__(self, create=False, string=None, filename=None):
83       if create:
84          self.create()
85       if string:
86          self.load_from_string(string)
87       if filename:
88          self.load_from_file(filename)
89
90    ##
91    # Create a RSA public/private key pair and store it inside the keypair object
92
93    def create(self):
94       self.key = crypto.PKey()
95       self.key.generate_key(crypto.TYPE_RSA, 1024)
96
97    ##
98    # Save the private key to a file
99    # @param filename name of file to store the keypair in
100
101    def save_to_file(self, filename):
102       open(filename, 'w').write(self.as_pem())
103
104    ##
105    # Load the private key from a file. Implicity the private key includes the public key.
106
107    def load_from_file(self, filename):
108       buffer = open(filename, 'r').read()
109       self.load_from_string(buffer)
110
111    ##
112    # Load the private key from a string. Implicitly the private key includes the public key.
113
114    def load_from_string(self, string):
115       self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
116       self.m2key = M2Crypto.EVP.load_key_string(string)
117
118    ##
119    #  Load the public key from a string. No private key is loaded. 
120
121    def load_pubkey_from_file(self, filename):
122       # load the m2 public key
123       m2rsakey = M2Crypto.RSA.load_pub_key(filename)
124       self.m2key = M2Crypto.EVP.PKey()
125       self.m2key.assign_rsa(m2rsakey)
126
127       # create an m2 x509 cert
128       m2name = M2Crypto.X509.X509_Name()
129       m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0)
130       m2x509 = M2Crypto.X509.X509()
131       m2x509.set_pubkey(self.m2key)
132       m2x509.set_serial_number(0)
133       m2x509.set_issuer_name(m2name)
134       m2x509.set_subject_name(m2name)
135       ASN1 = M2Crypto.ASN1.ASN1_UTCTIME()
136       ASN1.set_time(500)
137       m2x509.set_not_before(ASN1)
138       m2x509.set_not_after(ASN1)
139       junk_key = Keypair(create=True)
140       m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")
141
142       # convert the m2 x509 cert to a pyopenssl x509
143       m2pem = m2x509.as_pem()
144       pyx509 = crypto.load_certificate(crypto.FILETYPE_PEM, m2pem)
145
146       # get the pyopenssl pkey from the pyopenssl x509
147       self.key = pyx509.get_pubkey()
148
149    ##
150    # Load the public key from a string. No private key is loaded.
151
152    def load_pubkey_from_string(self, string):
153       (f, fn) = tempfile.mkstemp()
154       os.write(f, string)
155       os.close(f)
156       self.load_pubkey_from_file(fn)
157       os.remove(fn)
158
159    ##
160    # Return the private key in PEM format.
161
162    def as_pem(self):
163       return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key)
164
165    ##
166    # Return an M2Crypto key object
167
168    def get_m2_pkey(self):
169       if not self.m2key:
170          self.m2key = M2Crypto.EVP.load_key_string(self.as_pem())
171       return self.m2key
172
173    ##
174    # Returns a string containing the public key represented by this object.
175
176    def get_pubkey_string(self):
177       m2pkey = self.get_m2_pkey()
178       return base64.b64encode(m2pkey.as_der())
179
180    ##
181    # Return an OpenSSL pkey object
182
183    def get_openssl_pkey(self):
184       return self.key
185
186    ##
187    # Given another Keypair object, return TRUE if the two keys are the same.
188
189    def is_same(self, pkey):
190       return self.as_pem() == pkey.as_pem()
191
192    def sign_string(self, data):
193       k = self.get_m2_pkey()
194       k.sign_init()
195       k.sign_update(data)
196       return base64.b64encode(k.sign_final())
197
198    def verify_string(self, data, sig):
199       k = self.get_m2_pkey()
200       k.verify_init()
201       k.verify_update(data)
202       return M2Crypto.m2.verify_final(k.ctx, base64.b64decode(sig), k.pkey)
203
204 ##
205 # The certificate class implements a general purpose X509 certificate, making
206 # use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds
207 # several addition features, such as the ability to maintain a chain of
208 # parent certificates, and storage of application-specific data.
209 #
210 # Certificates include the ability to maintain a chain of parents. Each
211 # certificate includes a pointer to it's parent certificate. When loaded
212 # from a file or a string, the parent chain will be automatically loaded.
213 # When saving a certificate to a file or a string, the caller can choose
214 # whether to save the parent certificates as well.
215
216 class Certificate:
217    digest = "md5"
218
219    data = None
220    cert = None
221    issuerKey = None
222    issuerSubject = None
223    parent = None
224
225    separator="-----parent-----"
226
227    ##
228    # Create a certificate object.
229    #
230    # @param create If create==True, then also create a blank X509 certificate.
231    # @param subject If subject!=None, then create a blank certificate and set
232    #     it's subject name.
233    # @param string If string!=None, load the certficate from the string.
234    # @param filename If filename!=None, load the certficiate from the file.
235
236    def __init__(self, create=False, subject=None, string=None, filename=None):
237        if create or subject:
238            self.create()
239        if subject:
240            self.set_subject(subject)
241        if string:
242            self.load_from_string(string)
243        if filename:
244            self.load_from_file(filename)
245
246    ##
247    # Create a blank X509 certificate and store it in this object.
248
249    def create(self):
250        self.cert = crypto.X509()
251        self.cert.set_serial_number(1)
252        self.cert.gmtime_adj_notBefore(0)
253        self.cert.gmtime_adj_notAfter(60*60*24*365*5) # five years
254
255    ##
256    # Given a pyOpenSSL X509 object, store that object inside of this
257    # certificate object.
258
259    def load_from_pyopenssl_x509(self, x509):
260        self.cert = x509
261
262    ##
263    # Load the certificate from a string
264
265    def load_from_string(self, string):
266        # if it is a chain of multiple certs, then split off the first one and
267        # load it
268        parts = string.split(Certificate.separator, 1)
269        self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])
270
271        # if there are more certs, then create a parent and let the parent load
272        # itself from the remainder of the string
273        if len(parts) > 1:
274            self.parent = self.__class__()
275            self.parent.load_from_string(parts[1])
276
277    ##
278    # Load the certificate from a file
279
280    def load_from_file(self, filename):
281        file = open(filename)
282        string = file.read()
283        self.load_from_string(string)
284
285    ##
286    # Save the certificate to a string.
287    #
288    # @param save_parents If save_parents==True, then also save the parent certificates.
289
290    def save_to_string(self, save_parents=False):
291        string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)
292        if save_parents and self.parent:
293           string = string + Certificate.separator + self.parent.save_to_string(save_parents)
294        return string
295
296    ##
297    # Save the certificate to a file.
298    # @param save_parents If save_parents==True, then also save the parent certificates.
299
300    def save_to_file(self, filename, save_parents=False):
301        string = self.save_to_string(save_parents=save_parents)
302        open(filename, 'w').write(string)
303
304    ##
305    # Sets the issuer private key and name
306    # @param key Keypair object containing the private key of the issuer
307    # @param subject String containing the name of the issuer
308    # @param cert (optional) Certificate object containing the name of the issuer
309
310    def set_issuer(self, key, subject=None, cert=None):
311        self.issuerKey = key
312        if subject:
313           # it's a mistake to use subject and cert params at the same time
314           assert(not cert)
315           if isinstance(subject, dict) or isinstance(subject, str):
316              req = crypto.X509Req()
317              reqSubject = req.get_subject()
318              if (isinstance(subject, dict)):
319                 for key in reqSubject.keys():
320                     setattr(reqSubject, key, name[key])
321              else:
322                 setattr(reqSubject, "CN", subject)
323              subject = reqSubject
324              # subject is not valid once req is out of scope, so save req
325              self.issuerReq = req
326        if cert:
327           # if a cert was supplied, then get the subject from the cert
328           subject = cert.cert.get_issuer()
329        assert(subject)
330        self.issuerSubject = subject
331
332    ##
333    # Get the issuer name
334
335    def get_issuer(self, which="CN"):
336        x = self.cert.get_issuer()
337        return getattr(x, which)
338
339    ##
340    # Set the subject name of the certificate
341
342    def set_subject(self, name):
343        req = crypto.X509Req()
344        subj = req.get_subject()
345        if (isinstance(name, dict)):
346            for key in name.keys():
347                setattr(subj, key, name[key])
348        else:
349            setattr(subj, "CN", name)
350        self.cert.set_subject(subj)
351    ##
352    # Get the subject name of the certificate
353
354    def get_subject(self, which="CN"):
355        x = self.cert.get_subject()
356        return getattr(x, which)
357
358    ##
359    # Get the public key of the certificate.
360    #
361    # @param key Keypair object containing the public key
362
363    def set_pubkey(self, key):
364        assert(isinstance(key, Keypair))
365        self.cert.set_pubkey(key.get_openssl_pkey())
366
367    ##
368    # Get the public key of the certificate.
369    # It is returned in the form of a Keypair object.
370
371    def get_pubkey(self):
372        m2x509 = X509.load_cert_string(self.save_to_string())
373        pkey = Keypair()
374        pkey.key = self.cert.get_pubkey()
375        pkey.m2key = m2x509.get_pubkey()
376        return pkey
377
378    ##
379    # Add an X509 extension to the certificate. Add_extension can only be called
380    # once for a particular extension name, due to limitations in the underlying
381    # library.
382    #
383    # @param name string containing name of extension
384    # @param value string containing value of the extension
385
386    def add_extension(self, name, critical, value):
387        ext = crypto.X509Extension (name, critical, value)
388        self.cert.add_extensions([ext])
389
390    ##
391    # Get an X509 extension from the certificate
392
393    def get_extension(self, name):
394        # pyOpenSSL does not have a way to get extensions
395        m2x509 = X509.load_cert_string(self.save_to_string())
396        value = m2x509.get_ext(name).get_value()
397        return value
398
399    ##
400    # Set_data is a wrapper around add_extension. It stores the parameter str in
401    # the X509 subject_alt_name extension. Set_data can only be called once, due
402    # to limitations in the underlying library.
403
404    def set_data(self, str):
405        # pyOpenSSL only allows us to add extensions, so if we try to set the
406        # same extension more than once, it will not work
407        if self.data != None:
408           raise "cannot set subjectAltName more than once"
409        self.data = str
410        self.add_extension("subjectAltName", 0, "URI:http://" + str)
411
412    ##
413    # Return the data string that was previously set with set_data
414
415    def get_data(self):
416        if self.data:
417            return self.data
418
419        try:
420            uri = self.get_extension("subjectAltName")
421        except LookupError:
422            self.data = None
423            return self.data
424
425        if not uri.startswith("URI:http://"):
426            raise "bad encoding in subjectAltName"
427        self.data = uri[11:]
428        return self.data
429
430    ##
431    # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer().
432
433    def sign(self):
434        assert self.cert != None
435        assert self.issuerSubject != None
436        assert self.issuerKey != None
437        self.cert.set_issuer(self.issuerSubject)
438        self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest)
439
440     ##
441     # Verify the authenticity of a certificate.
442     # @param pkey is a Keypair object representing a public key. If Pkey
443     #     did not sign the certificate, then an exception will be thrown.
444
445    def verify(self, pkey):
446        # pyOpenSSL does not have a way to verify signatures
447        m2x509 = X509.load_cert_string(self.save_to_string())
448        m2pkey = pkey.get_m2_pkey()
449        # verify it
450        return m2x509.verify(m2pkey)
451
452        # XXX alternatively, if openssl has been patched, do the much simpler:
453        # try:
454        #   self.cert.verify(pkey.get_openssl_key())
455        #   return 1
456        # except:
457        #   return 0
458
459    ##
460    # Return True if pkey is identical to the public key that is contained in the certificate.
461    # @param pkey Keypair object
462
463    def is_pubkey(self, pkey):
464        return self.get_pubkey().is_same(pkey)
465
466    ##
467    # Given a certificate cert, verify that this certificate was signed by the
468    # public key contained in cert. Throw an exception otherwise.
469    #
470    # @param cert certificate object
471
472    def is_signed_by_cert(self, cert):
473        k = cert.get_pubkey()
474        result = self.verify(k)
475        return result
476
477    ##
478    # Set the parent certficiate.
479    #
480    # @param p certificate object.
481
482    def set_parent(self, p):
483         self.parent = p
484
485    ##
486    # Return the certificate object of the parent of this certificate.
487
488    def get_parent(self):
489         return self.parent
490
491    ##
492    # Verification examines a chain of certificates to ensure that each parent
493    # signs the child, and that some certificate in the chain is signed by a
494    # trusted certificate.
495    #
496    # Verification is a basic recursion: <pre>
497    #     if this_certificate was signed by trusted_certs:
498    #         return
499    #     else
500    #         return verify_chain(parent, trusted_certs)
501    # </pre>
502    #
503    # At each recursion, the parent is tested to ensure that it did sign the
504    # child. If a parent did not sign a child, then an exception is thrown. If
505    # the bottom of the recursion is reached and the certificate does not match
506    # a trusted root, then an exception is thrown.
507    #
508    # @param Trusted_certs is a list of certificates that are trusted.
509    #
510
511    def verify_chain(self, trusted_certs = None):
512         # Verify a chain of certificates. Each certificate must be signed by
513         # the public key contained in it's parent. The chain is recursed
514         # until a certificate is found that is signed by a trusted root.
515
516         # TODO: verify expiration time
517
518         # if this cert is signed by a trusted_cert, then we are set
519         for trusted_cert in trusted_certs:
520             # TODO: verify expiration of trusted_cert ?
521             if self.is_signed_by_cert(trusted_cert):
522                 #print self.get_subject(), "is signed by a root"
523                 return
524
525         # if there is no parent, then no way to verify the chain
526         if not self.parent:
527             #print self.get_subject(), "has no parent"
528             raise CertMissingParent(self.get_subject())
529
530         # if it wasn't signed by the parent...
531         if not self.is_signed_by_cert(self.parent):
532             #print self.get_subject(), "is not signed by parent"
533             return CertNotSignedByParent(self.get_subject())
534
535         # if the parent isn't verified...
536         self.parent.verify_chain(trusted_certs)
537
538         return