3bdad0f449b7da1c2eaa63a9123f24ffe7108d46
[sfa.git] / sfa / trust / certificate.py
1 ##
2 # SFA uses two crypto libraries: pyOpenSSL and M2Crypto to implement
3 # the necessary crypto functionality. Ideally just one of these libraries
4 # would be used, but unfortunately each of these libraries is independently
5 # lacking. The pyOpenSSL library is missing many necessary functions, and
6 # the M2Crypto library has crashed inside of some of the functions. The
7 # design decision is to use pyOpenSSL whenever possible as it seems more
8 # stable, and only use M2Crypto for those functions that are not possible
9 # in pyOpenSSL.
10 #
11 # This module exports two classes: Keypair and Certificate.
12 ##
13 #
14 ### $Id$
15 ### $URL$
16 #
17
18 import os
19 import tempfile
20 import base64
21 import traceback
22 from OpenSSL import crypto
23 import M2Crypto
24 from M2Crypto import X509
25 from tempfile import mkstemp
26 from sfa.util.sfalogging import logger
27 from sfa.util.namespace import urn_to_hrn
28 from sfa.util.faults import *
29
30 def convert_public_key(key):
31     keyconvert_path = "/usr/bin/keyconvert"
32     if not os.path.isfile(keyconvert_path):
33         raise IOError, "Could not find keyconvert in %s" % keyconvert_path
34
35     # we can only convert rsa keys 
36     if "ssh-dss" in key:
37         return None
38     
39     (ssh_f, ssh_fn) = tempfile.mkstemp()
40     ssl_fn = tempfile.mktemp()
41     os.write(ssh_f, key)
42     os.close(ssh_f)
43
44     cmd = keyconvert_path + " " + ssh_fn + " " + ssl_fn
45     os.system(cmd)
46     
47     # this check leaves the temporary file containing the public key so
48     # that it can be expected to see why it failed.
49     # TODO: for production, cleanup the temporary files
50     if not os.path.exists(ssl_fn):
51         return None
52     
53     k = Keypair()
54     try:
55         k.load_pubkey_from_file(ssl_fn)
56     except:
57         traceback.print_exc()
58         k = None
59
60     # remove the temporary files
61     os.remove(ssh_fn)
62     os.remove(ssl_fn)
63
64     return k
65
66 ##
67 # Public-private key pairs are implemented by the Keypair class.
68 # A Keypair object may represent both a public and private key pair, or it
69 # may represent only a public key (this usage is consistent with OpenSSL).
70
71 class Keypair:
72    key = None       # public/private keypair
73    m2key = None     # public key (m2crypto format)
74    
75    ##
76    # Creates a Keypair object
77    # @param create If create==True, creates a new public/private key and
78    #     stores it in the object
79    # @param string If string!=None, load the keypair from the string (PEM)
80    # @param filename If filename!=None, load the keypair from the file
81
82    def __init__(self, create=False, string=None, filename=None):
83       if create:
84          self.create()
85       if string:
86          self.load_from_string(string)
87       if filename:
88          self.load_from_file(filename)
89
90    ##
91    # Create a RSA public/private key pair and store it inside the keypair object
92
93    def create(self):
94       self.key = crypto.PKey()
95       self.key.generate_key(crypto.TYPE_RSA, 1024)
96
97    ##
98    # Save the private key to a file
99    # @param filename name of file to store the keypair in
100
101    def save_to_file(self, filename):
102       open(filename, 'w').write(self.as_pem())
103
104    ##
105    # Load the private key from a file. Implicity the private key includes the public key.
106
107    def load_from_file(self, filename):
108       buffer = open(filename, 'r').read()
109       self.load_from_string(buffer)
110
111    ##
112    # Load the private key from a string. Implicitly the private key includes the public key.
113
114    def load_from_string(self, string):
115       self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
116       self.m2key = M2Crypto.EVP.load_key_string(string)
117
118    ##
119    #  Load the public key from a string. No private key is loaded. 
120
121    def load_pubkey_from_file(self, filename):
122       # load the m2 public key
123       m2rsakey = M2Crypto.RSA.load_pub_key(filename)
124       self.m2key = M2Crypto.EVP.PKey()
125       self.m2key.assign_rsa(m2rsakey)
126
127       # create an m2 x509 cert
128       m2name = M2Crypto.X509.X509_Name()
129       m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0)
130       m2x509 = M2Crypto.X509.X509()
131       m2x509.set_pubkey(self.m2key)
132       m2x509.set_serial_number(0)
133       m2x509.set_issuer_name(m2name)
134       m2x509.set_subject_name(m2name)
135       ASN1 = M2Crypto.ASN1.ASN1_UTCTIME()
136       ASN1.set_time(500)
137       m2x509.set_not_before(ASN1)
138       m2x509.set_not_after(ASN1)
139       junk_key = Keypair(create=True)
140       m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")
141
142       # convert the m2 x509 cert to a pyopenssl x509
143       m2pem = m2x509.as_pem()
144       pyx509 = crypto.load_certificate(crypto.FILETYPE_PEM, m2pem)
145
146       # get the pyopenssl pkey from the pyopenssl x509
147       self.key = pyx509.get_pubkey()
148
149    ##
150    # Load the public key from a string. No private key is loaded.
151
152    def load_pubkey_from_string(self, string):
153       (f, fn) = tempfile.mkstemp()
154       os.write(f, string)
155       os.close(f)
156       self.load_pubkey_from_file(fn)
157       os.remove(fn)
158
159    ##
160    # Return the private key in PEM format.
161
162    def as_pem(self):
163       return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key)
164
165    ##
166    # Return an M2Crypto key object
167
168    def get_m2_pkey(self):
169       if not self.m2key:
170          self.m2key = M2Crypto.EVP.load_key_string(self.as_pem())
171       return self.m2key
172
173    ##
174    # Returns a string containing the public key represented by this object.
175
176    def get_pubkey_string(self):
177       m2pkey = self.get_m2_pkey()
178       return base64.b64encode(m2pkey.as_der())
179
180    ##
181    # Return an OpenSSL pkey object
182
183    def get_openssl_pkey(self):
184       return self.key
185
186
187    ##
188    # Given another Keypair object, return TRUE if the two keys are the same.
189
190    def is_same(self, pkey):
191       return self.as_pem() == pkey.as_pem()
192
193    def sign_string(self, data):
194       k = self.get_m2_pkey()
195       k.sign_init()
196       k.sign_update(data)
197       return base64.b64encode(k.sign_final())
198
199    def verify_string(self, data, sig):
200       k = self.get_m2_pkey()
201       k.verify_init()
202       k.verify_update(data)
203       return M2Crypto.m2.verify_final(k.ctx, base64.b64decode(sig), k.pkey)
204
205    def compute_hash(self, value):
206       return self.sign_string(str(value))      
207
208 ##
209 # The certificate class implements a general purpose X509 certificate, making
210 # use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds
211 # several addition features, such as the ability to maintain a chain of
212 # parent certificates, and storage of application-specific data.
213 #
214 # Certificates include the ability to maintain a chain of parents. Each
215 # certificate includes a pointer to it's parent certificate. When loaded
216 # from a file or a string, the parent chain will be automatically loaded.
217 # When saving a certificate to a file or a string, the caller can choose
218 # whether to save the parent certificates as well.
219
220 class Certificate:
221    digest = "md5"
222
223    cert = None
224    issuerKey = None
225    issuerSubject = None
226    parent = None
227
228    separator="-----parent-----"
229
230    ##
231    # Create a certificate object.
232    #
233    # @param create If create==True, then also create a blank X509 certificate.
234    # @param subject If subject!=None, then create a blank certificate and set
235    #     it's subject name.
236    # @param string If string!=None, load the certficate from the string.
237    # @param filename If filename!=None, load the certficiate from the file.
238
239    def __init__(self, create=False, subject=None, string=None, filename=None, intermediate=None):
240        self.data = {}
241        if create or subject:
242            self.create()
243        if subject:
244            self.set_subject(subject)
245        if string:
246            self.load_from_string(string)
247        if filename:
248            self.load_from_file(filename)
249
250        if intermediate:
251            self.set_intermediate_ca(intermediate)
252
253    ##
254    # Create a blank X509 certificate and store it in this object.
255
256    def create(self):
257        self.cert = crypto.X509()
258        self.cert.set_serial_number(3)
259        self.cert.gmtime_adj_notBefore(0)
260        self.cert.gmtime_adj_notAfter(60*60*24*365*5) # five years
261
262    ##
263    # Given a pyOpenSSL X509 object, store that object inside of this
264    # certificate object.
265
266    def load_from_pyopenssl_x509(self, x509):
267        self.cert = x509
268
269    ##
270    # Load the certificate from a string
271
272    def load_from_string(self, string):
273        # if it is a chain of multiple certs, then split off the first one and
274        # load it (support for the ---parent--- tag as well as normal chained certs)       
275
276        string = string.strip()       
277        
278        
279        if not string.startswith('-----'):
280            string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string
281            
282        parts = []
283        
284        if string.count('-----BEGIN CERTIFICATE-----') > 1 and \
285               string.count(Certificate.separator) == 0:
286            parts = string.split('-----END CERTIFICATE-----',1)
287            parts[0] += '-----END CERTIFICATE-----'
288        else:
289            parts = string.split(Certificate.separator, 1)
290        
291        self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])
292
293        # if there are more certs, then create a parent and let the parent load
294        # itself from the remainder of the string
295        if len(parts) > 1 and parts[1] != '':
296            self.parent = self.__class__()
297            self.parent.load_from_string(parts[1])
298
299    ##
300    # Load the certificate from a file
301
302    def load_from_file(self, filename):     
303        file = open(filename)
304        string = file.read()
305        self.load_from_string(string)
306
307    ##
308    # Save the certificate to a string.
309    #
310    # @param save_parents If save_parents==True, then also save the parent certificates.
311
312    def save_to_string(self, save_parents=True):
313        string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)
314        if save_parents and self.parent:
315           string = string + self.parent.save_to_string(save_parents)
316        return string
317
318    ##
319    # Save the certificate to a file.
320    # @param save_parents If save_parents==True, then also save the parent certificates.
321
322    def save_to_file(self, filename, save_parents=True, filep=None):
323        string = self.save_to_string(save_parents=save_parents)
324        if filep:
325            f = filep
326        else:
327            f = open(filename, 'w')
328        f.write(string)
329        f.close()
330
331    ##
332    # Save the certificate to a random file in /tmp/
333    # @param save_parents If save_parents==True, then also save the parent certificates.  
334    def save_to_random_tmp_file(self, save_parents=True):       
335        fp, filename = mkstemp(suffix='cert', text=True)
336        fp = os.fdopen(fp, "w")
337        self.save_to_file(filename, save_parents=True, filep=fp)
338        return filename   
339
340    ##
341    # Sets the issuer private key and name
342    # @param key Keypair object containing the private key of the issuer
343    # @param subject String containing the name of the issuer
344    # @param cert (optional) Certificate object containing the name of the issuer
345
346    def set_issuer(self, key, subject=None, cert=None):
347        self.issuerKey = key
348        if subject:
349           # it's a mistake to use subject and cert params at the same time
350           assert(not cert)
351           if isinstance(subject, dict) or isinstance(subject, str):
352              req = crypto.X509Req()
353              reqSubject = req.get_subject()
354              if (isinstance(subject, dict)):
355                 for key in reqSubject.keys():
356                     setattr(reqSubject, key, name[key])
357              else:
358                 setattr(reqSubject, "CN", subject)
359              subject = reqSubject
360              # subject is not valid once req is out of scope, so save req
361              self.issuerReq = req
362        if cert:
363           # if a cert was supplied, then get the subject from the cert
364           subject = cert.cert.get_subject()
365        assert(subject)
366        self.issuerSubject = subject
367
368    ##
369    # Get the issuer name
370
371    def get_issuer(self, which="CN"):
372        x = self.cert.get_issuer()
373        return getattr(x, which)
374
375    ##
376    # Set the subject name of the certificate
377
378    def set_subject(self, name):
379        req = crypto.X509Req()
380        subj = req.get_subject()
381        if (isinstance(name, dict)):
382            for key in name.keys():
383                setattr(subj, key, name[key])
384        else:
385            setattr(subj, "CN", name)
386        self.cert.set_subject(subj)
387    ##
388    # Get the subject name of the certificate
389
390    def get_subject(self, which="CN"):
391        x = self.cert.get_subject()
392        return getattr(x, which)
393
394    ##
395    # Get the public key of the certificate.
396    #
397    # @param key Keypair object containing the public key
398
399    def set_pubkey(self, key):
400        assert(isinstance(key, Keypair))
401        self.cert.set_pubkey(key.get_openssl_pkey())
402
403    ##
404    # Get the public key of the certificate.
405    # It is returned in the form of a Keypair object.
406
407    def get_pubkey(self):
408        m2x509 = X509.load_cert_string(self.save_to_string())
409        pkey = Keypair()
410        pkey.key = self.cert.get_pubkey()
411        pkey.m2key = m2x509.get_pubkey()
412        return pkey
413    
414    def set_intermediate_ca(self, val):
415        self.intermediate = val
416        if val:
417            self.add_extension('basicConstraints', 1, 'CA:TRUE')
418        
419
420
421    ##
422    # Add an X509 extension to the certificate. Add_extension can only be called
423    # once for a particular extension name, due to limitations in the underlying
424    # library.
425    #
426    # @param name string containing name of extension
427    # @param value string containing value of the extension
428
429    def add_extension(self, name, critical, value):
430        ext = crypto.X509Extension (name, critical, value)
431        self.cert.add_extensions([ext])
432
433    ##
434    # Get an X509 extension from the certificate
435
436    def get_extension(self, name):
437        # pyOpenSSL does not have a way to get extensions
438        m2x509 = X509.load_cert_string(self.save_to_string())
439        value = m2x509.get_ext(name).get_value()
440        return value
441
442    ##
443    # Set_data is a wrapper around add_extension. It stores the parameter str in
444    # the X509 subject_alt_name extension. Set_data can only be called once, due
445    # to limitations in the underlying library.
446
447    def set_data(self, str, field='subjectAltName'):
448        # pyOpenSSL only allows us to add extensions, so if we try to set the
449        # same extension more than once, it will not work
450        if self.data.has_key(field):
451           raise "cannot set ", field, " more than once"
452        self.data[field] = str
453        self.add_extension(field, 0, str)
454
455    ##
456    # Return the data string that was previously set with set_data
457
458    def get_data(self, field='subjectAltName'):
459        if self.data.has_key(field):
460            return self.data[field]
461
462        try:
463            uri = self.get_extension(field)
464            self.data[field] = uri           
465        except LookupError:
466            return None
467        
468        return self.data[field]
469
470    ##
471    # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer().
472
473    def sign(self):
474        assert self.cert != None
475        assert self.issuerSubject != None
476        assert self.issuerKey != None
477        self.cert.set_issuer(self.issuerSubject)
478        self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest)
479
480     ##
481     # Verify the authenticity of a certificate.
482     # @param pkey is a Keypair object representing a public key. If Pkey
483     #     did not sign the certificate, then an exception will be thrown.
484
485    def verify(self, pkey):
486        # pyOpenSSL does not have a way to verify signatures
487        m2x509 = X509.load_cert_string(self.save_to_string())
488        m2pkey = pkey.get_m2_pkey()
489        # verify it
490        return m2x509.verify(m2pkey)
491
492        # XXX alternatively, if openssl has been patched, do the much simpler:
493        # try:
494        #   self.cert.verify(pkey.get_openssl_key())
495        #   return 1
496        # except:
497        #   return 0
498
499    ##
500    # Return True if pkey is identical to the public key that is contained in the certificate.
501    # @param pkey Keypair object
502
503    def is_pubkey(self, pkey):
504        return self.get_pubkey().is_same(pkey)
505
506    ##
507    # Given a certificate cert, verify that this certificate was signed by the
508    # public key contained in cert. Throw an exception otherwise.
509    #
510    # @param cert certificate object
511
512    def is_signed_by_cert(self, cert):
513        k = cert.get_pubkey()
514        result = self.verify(k)
515        return result
516
517    ##
518    # Set the parent certficiate.
519    #
520    # @param p certificate object.
521
522    def set_parent(self, p):
523         self.parent = p
524
525    ##
526    # Return the certificate object of the parent of this certificate.
527
528    def get_parent(self):
529         return self.parent
530
531    ##
532    # Verification examines a chain of certificates to ensure that each parent
533    # signs the child, and that some certificate in the chain is signed by a
534    # trusted certificate.
535    #
536    # Verification is a basic recursion: <pre>
537    #     if this_certificate was signed by trusted_certs:
538    #         return
539    #     else
540    #         return verify_chain(parent, trusted_certs)
541    # </pre>
542    #
543    # At each recursion, the parent is tested to ensure that it did sign the
544    # child. If a parent did not sign a child, then an exception is thrown. If
545    # the bottom of the recursion is reached and the certificate does not match
546    # a trusted root, then an exception is thrown.
547    #
548    # @param Trusted_certs is a list of certificates that are trusted.
549    #
550
551    def verify_chain(self, trusted_certs = None):
552         # Verify a chain of certificates. Each certificate must be signed by
553         # the public key contained in it's parent. The chain is recursed
554         # until a certificate is found that is signed by a trusted root.
555
556         # TODO: verify expiration time
557         #print "====Verify Chain====="
558         # if this cert is signed by a trusted_cert, then we are set
559         for trusted_cert in trusted_certs:
560             #print "***************"
561             # TODO: verify expiration of trusted_cert ?
562             #print "CLIENT CERT", self.dump()
563             #print "TRUSTED CERT", trusted_cert.dump()
564             #print "Client is signed by Trusted?", self.is_signed_by_cert(trusted_cert)
565             if self.is_signed_by_cert(trusted_cert):
566                 #print self.get_subject(), "is signed by a root"
567                 return trusted_cert
568
569         # if there is no parent, then no way to verify the chain
570         if not self.parent:
571             #print self.get_subject(), "has no parent"
572             raise CertMissingParent(self.get_subject())
573
574         # if it wasn't signed by the parent...
575         if not self.is_signed_by_cert(self.parent):
576             #print self.get_subject(), "is not signed by parent"
577             return CertNotSignedByParent(self.get_subject())
578
579         # if the parent isn't verified...
580         self.parent.verify_chain(trusted_certs)
581
582         return