added compute_hash() method
[sfa.git] / sfa / trust / certificate.py
1 ##
2 # Geniwrapper uses two crypto libraries: pyOpenSSL and M2Crypto to implement
3 # the necessary crypto functionality. Ideally just one of these libraries
4 # would be used, but unfortunately each of these libraries is independently
5 # lacking. The pyOpenSSL library is missing many necessary functions, and
6 # the M2Crypto library has crashed inside of some of the functions. The
7 # design decision is to use pyOpenSSL whenever possible as it seems more
8 # stable, and only use M2Crypto for those functions that are not possible
9 # in pyOpenSSL.
10 #
11 # This module exports two classes: Keypair and Certificate.
12 ##
13 #
14 ### $Id$
15 ### $URL$
16 #
17
18 import os
19 import tempfile
20 import base64
21 import traceback
22 from OpenSSL import crypto
23 import M2Crypto
24 from M2Crypto import X509
25 from M2Crypto import EVP
26
27 from sfa.util.faults import *
28
29 def convert_public_key(key):
30     keyconvert_path = "/usr/bin/keyconvert"
31     if not os.path.isfile(keyconvert_path):
32         raise IOError, "Could not find keyconvert in %s" % keyconvert_path
33
34     # we can only convert rsa keys 
35     if "ssh-dss" in key:
36         print "XXX: DSA key encountered, ignoring"
37         return None
38     
39     (ssh_f, ssh_fn) = tempfile.mkstemp()
40     ssl_fn = tempfile.mktemp()
41     os.write(ssh_f, key)
42     os.close(ssh_f)
43
44     cmd = keyconvert_path + " " + ssh_fn + " " + ssl_fn
45     os.system(cmd)
46
47     # this check leaves the temporary file containing the public key so
48     # that it can be expected to see why it failed.
49     # TODO: for production, cleanup the temporary files
50     if not os.path.exists(ssl_fn):
51         report.trace("  failed to convert key from " + ssh_fn + " to " + ssl_fn)
52         return None
53
54     k = Keypair()
55     try:
56         k.load_pubkey_from_file(ssl_fn)
57     except:
58         print "XXX: Error while converting key: ", key
59         traceback.print_exc()
60         k = None
61
62     # remove the temporary files
63     os.remove(ssh_fn)
64     os.remove(ssl_fn)
65
66     return k
67
68 ##
69 # Public-private key pairs are implemented by the Keypair class.
70 # A Keypair object may represent both a public and private key pair, or it
71 # may represent only a public key (this usage is consistent with OpenSSL).
72
73 class Keypair:
74    key = None       # public/private keypair
75    m2key = None     # public key (m2crypto format)
76
77    ##
78    # Creates a Keypair object
79    # @param create If create==True, creates a new public/private key and
80    #     stores it in the object
81    # @param string If string!=None, load the keypair from the string (PEM)
82    # @param filename If filename!=None, load the keypair from the file
83
84    def __init__(self, create=False, string=None, filename=None):
85       if create:
86          self.create()
87       if string:
88          self.load_from_string(string)
89       if filename:
90          self.load_from_file(filename)
91
92    ##
93    # Create a RSA public/private key pair and store it inside the keypair object
94
95    def create(self):
96       self.key = crypto.PKey()
97       self.key.generate_key(crypto.TYPE_RSA, 1024)
98
99    ##
100    # Save the private key to a file
101    # @param filename name of file to store the keypair in
102
103    def save_to_file(self, filename):
104       open(filename, 'w').write(self.as_pem())
105
106    ##
107    # Load the private key from a file. Implicity the private key includes the public key.
108
109    def load_from_file(self, filename):
110       buffer = open(filename, 'r').read()
111       self.load_from_string(buffer)
112
113    ##
114    # Load the private key from a string. Implicitly the private key includes the public key.
115
116    def load_from_string(self, string):
117       self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
118       self.m2key = M2Crypto.EVP.load_key_string(string)
119
120    ##
121    #  Load the public key from a string. No private key is loaded. 
122
123    def load_pubkey_from_file(self, filename):
124       # load the m2 public key
125       m2rsakey = M2Crypto.RSA.load_pub_key(filename)
126       self.m2key = M2Crypto.EVP.PKey()
127       self.m2key.assign_rsa(m2rsakey)
128
129       # create an m2 x509 cert
130       m2name = M2Crypto.X509.X509_Name()
131       m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0)
132       m2x509 = M2Crypto.X509.X509()
133       m2x509.set_pubkey(self.m2key)
134       m2x509.set_serial_number(0)
135       m2x509.set_issuer_name(m2name)
136       m2x509.set_subject_name(m2name)
137       ASN1 = M2Crypto.ASN1.ASN1_UTCTIME()
138       ASN1.set_time(500)
139       m2x509.set_not_before(ASN1)
140       m2x509.set_not_after(ASN1)
141       junk_key = Keypair(create=True)
142       m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")
143
144       # convert the m2 x509 cert to a pyopenssl x509
145       m2pem = m2x509.as_pem()
146       pyx509 = crypto.load_certificate(crypto.FILETYPE_PEM, m2pem)
147
148       # get the pyopenssl pkey from the pyopenssl x509
149       self.key = pyx509.get_pubkey()
150
151    ##
152    # Load the public key from a string. No private key is loaded.
153
154    def load_pubkey_from_string(self, string):
155       (f, fn) = tempfile.mkstemp()
156       os.write(f, string)
157       os.close(f)
158       self.load_pubkey_from_file(fn)
159       os.remove(fn)
160
161    ##
162    # Return the private key in PEM format.
163
164    def as_pem(self):
165       return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key)
166
167    ##
168    # Return an M2Crypto key object
169
170    def get_m2_pkey(self):
171       if not self.m2key:
172          self.m2key = M2Crypto.EVP.load_key_string(self.as_pem())
173       return self.m2key
174
175    ##
176    # Returns a string containing the public key represented by this object.
177
178    def get_pubkey_string(self):
179       m2pkey = self.get_m2_pkey()
180       return base64.b64encode(m2pkey.as_der())
181
182    ##
183    # Return an OpenSSL pkey object
184
185    def get_openssl_pkey(self):
186       return self.key
187
188    ##
189    # Given another Keypair object, return TRUE if the two keys are the same.
190
191    def is_same(self, pkey):
192       return self.as_pem() == pkey.as_pem()
193
194    def sign_string(self, data):
195       k = self.get_m2_pkey()
196       k.sign_init()
197       k.sign_update(data)
198       return base64.b64encode(k.sign_final())
199
200    def verify_string(self, data, sig):
201       k = self.get_m2_pkey()
202       k.verify_init()
203       k.verify_update(data)
204       return M2Crypto.m2.verify_final(k.ctx, base64.b64decode(sig), k.pkey)
205
206    def compute_hash(self, value):
207       return self.sign_string(str(value))      
208
209 ##
210 # The certificate class implements a general purpose X509 certificate, making
211 # use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds
212 # several addition features, such as the ability to maintain a chain of
213 # parent certificates, and storage of application-specific data.
214 #
215 # Certificates include the ability to maintain a chain of parents. Each
216 # certificate includes a pointer to it's parent certificate. When loaded
217 # from a file or a string, the parent chain will be automatically loaded.
218 # When saving a certificate to a file or a string, the caller can choose
219 # whether to save the parent certificates as well.
220
221 class Certificate:
222    digest = "md5"
223
224    data = None
225    cert = None
226    issuerKey = None
227    issuerSubject = None
228    parent = None
229
230    separator="-----parent-----"
231
232    ##
233    # Create a certificate object.
234    #
235    # @param create If create==True, then also create a blank X509 certificate.
236    # @param subject If subject!=None, then create a blank certificate and set
237    #     it's subject name.
238    # @param string If string!=None, load the certficate from the string.
239    # @param filename If filename!=None, load the certficiate from the file.
240
241    def __init__(self, create=False, subject=None, string=None, filename=None):
242        if create or subject:
243            self.create()
244        if subject:
245            self.set_subject(subject)
246        if string:
247            self.load_from_string(string)
248        if filename:
249            self.load_from_file(filename)
250
251    ##
252    # Create a blank X509 certificate and store it in this object.
253
254    def create(self):
255        self.cert = crypto.X509()
256        self.cert.set_serial_number(1)
257        self.cert.gmtime_adj_notBefore(0)
258        self.cert.gmtime_adj_notAfter(60*60*24*365*5) # five years
259
260    ##
261    # Given a pyOpenSSL X509 object, store that object inside of this
262    # certificate object.
263
264    def load_from_pyopenssl_x509(self, x509):
265        self.cert = x509
266
267    ##
268    # Load the certificate from a string
269
270    def load_from_string(self, string):
271        # if it is a chain of multiple certs, then split off the first one and
272        # load it
273        parts = string.split(Certificate.separator, 1)
274        self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])
275
276        # if there are more certs, then create a parent and let the parent load
277        # itself from the remainder of the string
278        if len(parts) > 1:
279            self.parent = self.__class__()
280            self.parent.load_from_string(parts[1])
281
282    ##
283    # Load the certificate from a file
284
285    def load_from_file(self, filename):
286        file = open(filename)
287        string = file.read()
288        self.load_from_string(string)
289
290    ##
291    # Save the certificate to a string.
292    #
293    # @param save_parents If save_parents==True, then also save the parent certificates.
294
295    def save_to_string(self, save_parents=False):
296        string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)
297        if save_parents and self.parent:
298           string = string + Certificate.separator + self.parent.save_to_string(save_parents)
299        return string
300
301    ##
302    # Save the certificate to a file.
303    # @param save_parents If save_parents==True, then also save the parent certificates.
304
305    def save_to_file(self, filename, save_parents=False):
306        string = self.save_to_string(save_parents=save_parents)
307        open(filename, 'w').write(string)
308
309    ##
310    # Sets the issuer private key and name
311    # @param key Keypair object containing the private key of the issuer
312    # @param subject String containing the name of the issuer
313    # @param cert (optional) Certificate object containing the name of the issuer
314
315    def set_issuer(self, key, subject=None, cert=None):
316        self.issuerKey = key
317        if subject:
318           # it's a mistake to use subject and cert params at the same time
319           assert(not cert)
320           if isinstance(subject, dict) or isinstance(subject, str):
321              req = crypto.X509Req()
322              reqSubject = req.get_subject()
323              if (isinstance(subject, dict)):
324                 for key in reqSubject.keys():
325                     setattr(reqSubject, key, name[key])
326              else:
327                 setattr(reqSubject, "CN", subject)
328              subject = reqSubject
329              # subject is not valid once req is out of scope, so save req
330              self.issuerReq = req
331        if cert:
332           # if a cert was supplied, then get the subject from the cert
333           subject = cert.cert.get_issuer()
334        assert(subject)
335        self.issuerSubject = subject
336
337    ##
338    # Get the issuer name
339
340    def get_issuer(self, which="CN"):
341        x = self.cert.get_issuer()
342        return getattr(x, which)
343
344    ##
345    # Set the subject name of the certificate
346
347    def set_subject(self, name):
348        req = crypto.X509Req()
349        subj = req.get_subject()
350        if (isinstance(name, dict)):
351            for key in name.keys():
352                setattr(subj, key, name[key])
353        else:
354            setattr(subj, "CN", name)
355        self.cert.set_subject(subj)
356    ##
357    # Get the subject name of the certificate
358
359    def get_subject(self, which="CN"):
360        x = self.cert.get_subject()
361        return getattr(x, which)
362
363    ##
364    # Get the public key of the certificate.
365    #
366    # @param key Keypair object containing the public key
367
368    def set_pubkey(self, key):
369        assert(isinstance(key, Keypair))
370        self.cert.set_pubkey(key.get_openssl_pkey())
371
372    ##
373    # Get the public key of the certificate.
374    # It is returned in the form of a Keypair object.
375
376    def get_pubkey(self):
377        m2x509 = X509.load_cert_string(self.save_to_string())
378        pkey = Keypair()
379        pkey.key = self.cert.get_pubkey()
380        pkey.m2key = m2x509.get_pubkey()
381        return pkey
382
383    ##
384    # Add an X509 extension to the certificate. Add_extension can only be called
385    # once for a particular extension name, due to limitations in the underlying
386    # library.
387    #
388    # @param name string containing name of extension
389    # @param value string containing value of the extension
390
391    def add_extension(self, name, critical, value):
392        ext = crypto.X509Extension (name, critical, value)
393        self.cert.add_extensions([ext])
394
395    ##
396    # Get an X509 extension from the certificate
397
398    def get_extension(self, name):
399        # pyOpenSSL does not have a way to get extensions
400        m2x509 = X509.load_cert_string(self.save_to_string())
401        value = m2x509.get_ext(name).get_value()
402        return value
403
404    ##
405    # Set_data is a wrapper around add_extension. It stores the parameter str in
406    # the X509 subject_alt_name extension. Set_data can only be called once, due
407    # to limitations in the underlying library.
408
409    def set_data(self, str):
410        # pyOpenSSL only allows us to add extensions, so if we try to set the
411        # same extension more than once, it will not work
412        if self.data != None:
413           raise "cannot set subjectAltName more than once"
414        self.data = str
415        self.add_extension("subjectAltName", 0, "URI:http://" + str)
416
417    ##
418    # Return the data string that was previously set with set_data
419
420    def get_data(self):
421        if self.data:
422            return self.data
423
424        try:
425            uri = self.get_extension("subjectAltName")
426        except LookupError:
427            self.data = None
428            return self.data
429
430        if not uri.startswith("URI:http://"):
431            raise "bad encoding in subjectAltName"
432        self.data = uri[11:]
433        return self.data
434
435    ##
436    # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer().
437
438    def sign(self):
439        assert self.cert != None
440        assert self.issuerSubject != None
441        assert self.issuerKey != None
442        self.cert.set_issuer(self.issuerSubject)
443        self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest)
444
445     ##
446     # Verify the authenticity of a certificate.
447     # @param pkey is a Keypair object representing a public key. If Pkey
448     #     did not sign the certificate, then an exception will be thrown.
449
450    def verify(self, pkey):
451        # pyOpenSSL does not have a way to verify signatures
452        m2x509 = X509.load_cert_string(self.save_to_string())
453        m2pkey = pkey.get_m2_pkey()
454        # verify it
455        return m2x509.verify(m2pkey)
456
457        # XXX alternatively, if openssl has been patched, do the much simpler:
458        # try:
459        #   self.cert.verify(pkey.get_openssl_key())
460        #   return 1
461        # except:
462        #   return 0
463
464    ##
465    # Return True if pkey is identical to the public key that is contained in the certificate.
466    # @param pkey Keypair object
467
468    def is_pubkey(self, pkey):
469        return self.get_pubkey().is_same(pkey)
470
471    ##
472    # Given a certificate cert, verify that this certificate was signed by the
473    # public key contained in cert. Throw an exception otherwise.
474    #
475    # @param cert certificate object
476
477    def is_signed_by_cert(self, cert):
478        k = cert.get_pubkey()
479        result = self.verify(k)
480        return result
481
482    ##
483    # Set the parent certficiate.
484    #
485    # @param p certificate object.
486
487    def set_parent(self, p):
488         self.parent = p
489
490    ##
491    # Return the certificate object of the parent of this certificate.
492
493    def get_parent(self):
494         return self.parent
495
496    ##
497    # Verification examines a chain of certificates to ensure that each parent
498    # signs the child, and that some certificate in the chain is signed by a
499    # trusted certificate.
500    #
501    # Verification is a basic recursion: <pre>
502    #     if this_certificate was signed by trusted_certs:
503    #         return
504    #     else
505    #         return verify_chain(parent, trusted_certs)
506    # </pre>
507    #
508    # At each recursion, the parent is tested to ensure that it did sign the
509    # child. If a parent did not sign a child, then an exception is thrown. If
510    # the bottom of the recursion is reached and the certificate does not match
511    # a trusted root, then an exception is thrown.
512    #
513    # @param Trusted_certs is a list of certificates that are trusted.
514    #
515
516    def verify_chain(self, trusted_certs = None):
517         # Verify a chain of certificates. Each certificate must be signed by
518         # the public key contained in it's parent. The chain is recursed
519         # until a certificate is found that is signed by a trusted root.
520
521         # TODO: verify expiration time
522         #print "====Verify Chain====="
523         # if this cert is signed by a trusted_cert, then we are set
524         for trusted_cert in trusted_certs:
525             #print "***************"
526             # TODO: verify expiration of trusted_cert ?
527             #print "CLIENT CERT", self.dump()
528             #print "TRUSTED CERT", trusted_cert.dump()
529             #print "Client is signed by Trusted?", self.is_signed_by_cert(trusted_cert)
530             if self.is_signed_by_cert(trusted_cert):
531                 #print self.get_subject(), "is signed by a root"
532                 return
533
534         # if there is no parent, then no way to verify the chain
535         if not self.parent:
536             #print self.get_subject(), "has no parent"
537             raise CertMissingParent(self.get_subject())
538
539         # if it wasn't signed by the parent...
540         if not self.is_signed_by_cert(self.parent):
541             #print self.get_subject(), "is not signed by parent"
542             return CertNotSignedByParent(self.get_subject())
543
544         # if the parent isn't verified...
545         self.parent.verify_chain(trusted_certs)
546
547         return