merged in from trunk -r 17776:17849
[sfa.git] / sfa / trust / certificate.py
1 ##
2 # SFA uses two crypto libraries: pyOpenSSL and M2Crypto to implement
3 # the necessary crypto functionality. Ideally just one of these libraries
4 # would be used, but unfortunately each of these libraries is independently
5 # lacking. The pyOpenSSL library is missing many necessary functions, and
6 # the M2Crypto library has crashed inside of some of the functions. The
7 # design decision is to use pyOpenSSL whenever possible as it seems more
8 # stable, and only use M2Crypto for those functions that are not possible
9 # in pyOpenSSL.
10 #
11 # This module exports two classes: Keypair and Certificate.
12 ##
13 #
14 ### $Id$
15 ### $URL$
16 #
17
18 import os
19 import tempfile
20 import base64
21 import traceback
22 from OpenSSL import crypto
23 import M2Crypto
24 from M2Crypto import X509
25 from tempfile import mkstemp
26
27 from sfa.util.faults import *
28
29 def convert_public_key(key):
30     keyconvert_path = "/usr/bin/keyconvert"
31     if not os.path.isfile(keyconvert_path):
32         raise IOError, "Could not find keyconvert in %s" % keyconvert_path
33
34     # we can only convert rsa keys 
35     if "ssh-dss" in key:
36         return None
37     
38     (ssh_f, ssh_fn) = tempfile.mkstemp()
39     ssl_fn = tempfile.mktemp()
40     os.write(ssh_f, key)
41     os.close(ssh_f)
42
43     cmd = keyconvert_path + " " + ssh_fn + " " + ssl_fn
44     os.system(cmd)
45     
46     # this check leaves the temporary file containing the public key so
47     # that it can be expected to see why it failed.
48     # TODO: for production, cleanup the temporary files
49     if not os.path.exists(ssl_fn):
50         return None
51     
52     k = Keypair()
53     try:
54         k.load_pubkey_from_file(ssl_fn)
55     except:
56         traceback.print_exc()
57         k = None
58
59     # remove the temporary files
60     os.remove(ssh_fn)
61     os.remove(ssl_fn)
62
63     return k
64
65 ##
66 # Public-private key pairs are implemented by the Keypair class.
67 # A Keypair object may represent both a public and private key pair, or it
68 # may represent only a public key (this usage is consistent with OpenSSL).
69
70 class Keypair:
71    key = None       # public/private keypair
72    m2key = None     # public key (m2crypto format)
73    
74    ##
75    # Creates a Keypair object
76    # @param create If create==True, creates a new public/private key and
77    #     stores it in the object
78    # @param string If string!=None, load the keypair from the string (PEM)
79    # @param filename If filename!=None, load the keypair from the file
80
81    def __init__(self, create=False, string=None, filename=None):
82       if create:
83          self.create()
84       if string:
85          self.load_from_string(string)
86       if filename:
87          self.load_from_file(filename)
88
89    ##
90    # Create a RSA public/private key pair and store it inside the keypair object
91
92    def create(self):
93       self.key = crypto.PKey()
94       self.key.generate_key(crypto.TYPE_RSA, 1024)
95
96    ##
97    # Save the private key to a file
98    # @param filename name of file to store the keypair in
99
100    def save_to_file(self, filename):
101       open(filename, 'w').write(self.as_pem())
102
103    ##
104    # Load the private key from a file. Implicity the private key includes the public key.
105
106    def load_from_file(self, filename):
107       buffer = open(filename, 'r').read()
108       self.load_from_string(buffer)
109
110    ##
111    # Load the private key from a string. Implicitly the private key includes the public key.
112
113    def load_from_string(self, string):
114       self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
115       self.m2key = M2Crypto.EVP.load_key_string(string)
116
117    ##
118    #  Load the public key from a string. No private key is loaded. 
119
120    def load_pubkey_from_file(self, filename):
121       # load the m2 public key
122       m2rsakey = M2Crypto.RSA.load_pub_key(filename)
123       self.m2key = M2Crypto.EVP.PKey()
124       self.m2key.assign_rsa(m2rsakey)
125
126       # create an m2 x509 cert
127       m2name = M2Crypto.X509.X509_Name()
128       m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0)
129       m2x509 = M2Crypto.X509.X509()
130       m2x509.set_pubkey(self.m2key)
131       m2x509.set_serial_number(0)
132       m2x509.set_issuer_name(m2name)
133       m2x509.set_subject_name(m2name)
134       ASN1 = M2Crypto.ASN1.ASN1_UTCTIME()
135       ASN1.set_time(500)
136       m2x509.set_not_before(ASN1)
137       m2x509.set_not_after(ASN1)
138       junk_key = Keypair(create=True)
139       m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")
140
141       # convert the m2 x509 cert to a pyopenssl x509
142       m2pem = m2x509.as_pem()
143       pyx509 = crypto.load_certificate(crypto.FILETYPE_PEM, m2pem)
144
145       # get the pyopenssl pkey from the pyopenssl x509
146       self.key = pyx509.get_pubkey()
147
148    ##
149    # Load the public key from a string. No private key is loaded.
150
151    def load_pubkey_from_string(self, string):
152       (f, fn) = tempfile.mkstemp()
153       os.write(f, string)
154       os.close(f)
155       self.load_pubkey_from_file(fn)
156       os.remove(fn)
157
158    ##
159    # Return the private key in PEM format.
160
161    def as_pem(self):
162       return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key)
163
164    ##
165    # Return an M2Crypto key object
166
167    def get_m2_pkey(self):
168       if not self.m2key:
169          self.m2key = M2Crypto.EVP.load_key_string(self.as_pem())
170       return self.m2key
171
172    ##
173    # Returns a string containing the public key represented by this object.
174
175    def get_pubkey_string(self):
176       m2pkey = self.get_m2_pkey()
177       return base64.b64encode(m2pkey.as_der())
178
179    ##
180    # Return an OpenSSL pkey object
181
182    def get_openssl_pkey(self):
183       return self.key
184
185    ##
186    # Given another Keypair object, return TRUE if the two keys are the same.
187
188    def is_same(self, pkey):
189       return self.as_pem() == pkey.as_pem()
190
191    def sign_string(self, data):
192       k = self.get_m2_pkey()
193       k.sign_init()
194       k.sign_update(data)
195       return base64.b64encode(k.sign_final())
196
197    def verify_string(self, data, sig):
198       k = self.get_m2_pkey()
199       k.verify_init()
200       k.verify_update(data)
201       return M2Crypto.m2.verify_final(k.ctx, base64.b64decode(sig), k.pkey)
202
203    def compute_hash(self, value):
204       return self.sign_string(str(value))      
205
206 ##
207 # The certificate class implements a general purpose X509 certificate, making
208 # use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds
209 # several addition features, such as the ability to maintain a chain of
210 # parent certificates, and storage of application-specific data.
211 #
212 # Certificates include the ability to maintain a chain of parents. Each
213 # certificate includes a pointer to it's parent certificate. When loaded
214 # from a file or a string, the parent chain will be automatically loaded.
215 # When saving a certificate to a file or a string, the caller can choose
216 # whether to save the parent certificates as well.
217
218 class Certificate:
219    digest = "md5"
220
221    cert = None
222    issuerKey = None
223    issuerSubject = None
224    parent = None
225
226    separator="-----parent-----"
227
228    ##
229    # Create a certificate object.
230    #
231    # @param create If create==True, then also create a blank X509 certificate.
232    # @param subject If subject!=None, then create a blank certificate and set
233    #     it's subject name.
234    # @param string If string!=None, load the certficate from the string.
235    # @param filename If filename!=None, load the certficiate from the file.
236
237    def __init__(self, create=False, subject=None, string=None, filename=None):
238        self.data = {}
239        if create or subject:
240            self.create()
241        if subject:
242            self.set_subject(subject)
243        if string:
244            self.load_from_string(string)
245        if filename:
246            self.load_from_file(filename)
247
248    ##
249    # Create a blank X509 certificate and store it in this object.
250
251    def create(self):
252        self.cert = crypto.X509()
253        self.cert.set_serial_number(1)
254        self.cert.gmtime_adj_notBefore(0)
255        self.cert.gmtime_adj_notAfter(60*60*24*365*5) # five years
256
257    ##
258    # Given a pyOpenSSL X509 object, store that object inside of this
259    # certificate object.
260
261    def load_from_pyopenssl_x509(self, x509):
262        self.cert = x509
263
264    ##
265    # Load the certificate from a string
266
267    def load_from_string(self, string):
268        # if it is a chain of multiple certs, then split off the first one and
269        # load it (support for the ---parent--- tag as well as normal chained certs)       
270
271        string = string.strip()       
272        
273        
274        if not string.startswith('-----'):
275            string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string
276            
277        parts = []
278        
279        if string.count('-----BEGIN CERTIFICATE-----') > 1 and \
280               string.count(Certificate.separator) == 0:
281            parts = string.split('-----END CERTIFICATE-----',1)
282            parts[0] += '-----END CERTIFICATE-----'
283        else:
284            parts = string.split(Certificate.separator, 1)
285        
286        self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])
287
288        # if there are more certs, then create a parent and let the parent load
289        # itself from the remainder of the string
290        if len(parts) > 1 and parts[1] != '':
291            self.parent = self.__class__()
292            self.parent.load_from_string(parts[1])
293
294    ##
295    # Load the certificate from a file
296
297    def load_from_file(self, filename):     
298        file = open(filename)
299        string = file.read()
300        self.load_from_string(string)
301
302    ##
303    # Save the certificate to a string.
304    #
305    # @param save_parents If save_parents==True, then also save the parent certificates.
306
307    def save_to_string(self, save_parents=True):
308        string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)
309        if save_parents and self.parent:
310           string = string + self.parent.save_to_string(save_parents)
311        return string
312
313    ##
314    # Save the certificate to a file.
315    # @param save_parents If save_parents==True, then also save the parent certificates.
316
317    def save_to_file(self, filename, save_parents=True, filep=None):
318        string = self.save_to_string(save_parents=save_parents)
319        if filep:
320            f = filep
321        else:
322            f = open(filename, 'w')
323        f.write(string)
324        f.close()
325
326    ##
327    # Save the certificate to a random file in /tmp/
328    # @param save_parents If save_parents==True, then also save the parent certificates.  
329    def save_to_random_tmp_file(self, save_parents=True):       
330        fp, filename = mkstemp(suffix='cert', text=True)
331        fp = os.fdopen(fp, "w")
332        self.save_to_file(filename, save_parents=True, filep=fp)
333        return filename   
334
335    ##
336    # Sets the issuer private key and name
337    # @param key Keypair object containing the private key of the issuer
338    # @param subject String containing the name of the issuer
339    # @param cert (optional) Certificate object containing the name of the issuer
340
341    def set_issuer(self, key, subject=None, cert=None):
342        self.issuerKey = key
343        if subject:
344           # it's a mistake to use subject and cert params at the same time
345           assert(not cert)
346           if isinstance(subject, dict) or isinstance(subject, str):
347              req = crypto.X509Req()
348              reqSubject = req.get_subject()
349              if (isinstance(subject, dict)):
350                 for key in reqSubject.keys():
351                     setattr(reqSubject, key, name[key])
352              else:
353                 setattr(reqSubject, "CN", subject)
354              subject = reqSubject
355              # subject is not valid once req is out of scope, so save req
356              self.issuerReq = req
357        if cert:
358           # if a cert was supplied, then get the subject from the cert
359           subject = cert.cert.get_issuer()
360        assert(subject)
361        self.issuerSubject = subject
362
363    ##
364    # Get the issuer name
365
366    def get_issuer(self, which="CN"):
367        x = self.cert.get_issuer()
368        return getattr(x, which)
369
370    ##
371    # Set the subject name of the certificate
372
373    def set_subject(self, name):
374        req = crypto.X509Req()
375        subj = req.get_subject()
376        if (isinstance(name, dict)):
377            for key in name.keys():
378                setattr(subj, key, name[key])
379        else:
380            setattr(subj, "CN", name)
381        self.cert.set_subject(subj)
382    ##
383    # Get the subject name of the certificate
384
385    def get_subject(self, which="CN"):
386        x = self.cert.get_subject()
387        return getattr(x, which)
388
389    ##
390    # Get the public key of the certificate.
391    #
392    # @param key Keypair object containing the public key
393
394    def set_pubkey(self, key):
395        assert(isinstance(key, Keypair))
396        self.cert.set_pubkey(key.get_openssl_pkey())
397
398    ##
399    # Get the public key of the certificate.
400    # It is returned in the form of a Keypair object.
401
402    def get_pubkey(self):
403        m2x509 = X509.load_cert_string(self.save_to_string())
404        pkey = Keypair()
405        pkey.key = self.cert.get_pubkey()
406        pkey.m2key = m2x509.get_pubkey()
407        return pkey
408
409    ##
410    # Add an X509 extension to the certificate. Add_extension can only be called
411    # once for a particular extension name, due to limitations in the underlying
412    # library.
413    #
414    # @param name string containing name of extension
415    # @param value string containing value of the extension
416
417    def add_extension(self, name, critical, value):
418        ext = crypto.X509Extension (name, critical, value)
419        self.cert.add_extensions([ext])
420
421    ##
422    # Get an X509 extension from the certificate
423
424    def get_extension(self, name):
425        # pyOpenSSL does not have a way to get extensions
426        m2x509 = X509.load_cert_string(self.save_to_string())
427        value = m2x509.get_ext(name).get_value()
428        return value
429
430    ##
431    # Set_data is a wrapper around add_extension. It stores the parameter str in
432    # the X509 subject_alt_name extension. Set_data can only be called once, due
433    # to limitations in the underlying library.
434
435    def set_data(self, str, field='subjectAltName'):
436        # pyOpenSSL only allows us to add extensions, so if we try to set the
437        # same extension more than once, it will not work
438        if self.data.has_key(field):
439           raise "cannot set ", field, " more than once"
440        self.data[field] = str
441        self.add_extension(field, 0, str)
442
443    ##
444    # Return the data string that was previously set with set_data
445
446    def get_data(self, field='subjectAltName'):
447        if self.data.has_key(field):
448            return self.data[field]
449
450        try:
451            uri = self.get_extension(field)
452            self.data[field] = uri           
453        except LookupError:
454            return None
455        
456        return self.data[field]
457
458    ##
459    # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer().
460
461    def sign(self):
462        assert self.cert != None
463        assert self.issuerSubject != None
464        assert self.issuerKey != None
465        self.cert.set_issuer(self.issuerSubject)
466        self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest)
467
468     ##
469     # Verify the authenticity of a certificate.
470     # @param pkey is a Keypair object representing a public key. If Pkey
471     #     did not sign the certificate, then an exception will be thrown.
472
473    def verify(self, pkey):
474        # pyOpenSSL does not have a way to verify signatures
475        m2x509 = X509.load_cert_string(self.save_to_string())
476        m2pkey = pkey.get_m2_pkey()
477        # verify it
478        return m2x509.verify(m2pkey)
479
480        # XXX alternatively, if openssl has been patched, do the much simpler:
481        # try:
482        #   self.cert.verify(pkey.get_openssl_key())
483        #   return 1
484        # except:
485        #   return 0
486
487    ##
488    # Return True if pkey is identical to the public key that is contained in the certificate.
489    # @param pkey Keypair object
490
491    def is_pubkey(self, pkey):
492        return self.get_pubkey().is_same(pkey)
493
494    ##
495    # Given a certificate cert, verify that this certificate was signed by the
496    # public key contained in cert. Throw an exception otherwise.
497    #
498    # @param cert certificate object
499
500    def is_signed_by_cert(self, cert):
501        k = cert.get_pubkey()
502        result = self.verify(k)
503        return result
504
505    ##
506    # Set the parent certficiate.
507    #
508    # @param p certificate object.
509
510    def set_parent(self, p):
511         self.parent = p
512
513    ##
514    # Return the certificate object of the parent of this certificate.
515
516    def get_parent(self):
517         return self.parent
518
519    ##
520    # Verification examines a chain of certificates to ensure that each parent
521    # signs the child, and that some certificate in the chain is signed by a
522    # trusted certificate.
523    #
524    # Verification is a basic recursion: <pre>
525    #     if this_certificate was signed by trusted_certs:
526    #         return
527    #     else
528    #         return verify_chain(parent, trusted_certs)
529    # </pre>
530    #
531    # At each recursion, the parent is tested to ensure that it did sign the
532    # child. If a parent did not sign a child, then an exception is thrown. If
533    # the bottom of the recursion is reached and the certificate does not match
534    # a trusted root, then an exception is thrown.
535    #
536    # @param Trusted_certs is a list of certificates that are trusted.
537    #
538
539    def verify_chain(self, trusted_certs = None):
540         # Verify a chain of certificates. Each certificate must be signed by
541         # the public key contained in it's parent. The chain is recursed
542         # until a certificate is found that is signed by a trusted root.
543
544         # TODO: verify expiration time
545         #print "====Verify Chain====="
546         # if this cert is signed by a trusted_cert, then we are set
547         for trusted_cert in trusted_certs:
548             #print "***************"
549             # TODO: verify expiration of trusted_cert ?
550             #print "CLIENT CERT", self.dump()
551             #print "TRUSTED CERT", trusted_cert.dump()
552             #print "Client is signed by Trusted?", self.is_signed_by_cert(trusted_cert)
553             if self.is_signed_by_cert(trusted_cert):
554                 # make sure sure the trusted cert's hrn is a prefix of the
555                 # signed cert's hrn
556                 if not self.get_subject().startswith(trusted_cert.get_subject()):
557                     raise GidParentHrn(trusted_cert.get_subject()) 
558                 #print self.get_subject(), "is signed by a root"
559                 return
560
561         # if there is no parent, then no way to verify the chain
562         if not self.parent:
563             #print self.get_subject(), "has no parent"
564             raise CertMissingParent(self.get_subject())
565
566         # if it wasn't signed by the parent...
567         if not self.is_signed_by_cert(self.parent):
568             #print self.get_subject(), "is not signed by parent"
569             return CertNotSignedByParent(self.get_subject())
570
571         # if the parent isn't verified...
572         self.parent.verify_chain(trusted_certs)
573
574         return