checking in some debugging statements, will remove them later
[sfa.git] / sfa / trust / certificate.py
1 ##
2 # Geniwrapper uses two crypto libraries: pyOpenSSL and M2Crypto to implement
3 # the necessary crypto functionality. Ideally just one of these libraries
4 # would be used, but unfortunately each of these libraries is independently
5 # lacking. The pyOpenSSL library is missing many necessary functions, and
6 # the M2Crypto library has crashed inside of some of the functions. The
7 # design decision is to use pyOpenSSL whenever possible as it seems more
8 # stable, and only use M2Crypto for those functions that are not possible
9 # in pyOpenSSL.
10 #
11 # This module exports two classes: Keypair and Certificate.
12 ##
13 #
14 ### $Id$
15 ### $URL$
16 #
17
18 import os
19 import tempfile
20 import base64
21 import traceback
22 from OpenSSL import crypto
23 import M2Crypto
24 from M2Crypto import X509
25 from M2Crypto import EVP
26
27 from sfa.util.faults import *
28
29 def convert_public_key(key):
30     keyconvert_path = "/usr/bin/keyconvert"
31     if not os.path.isfile(keyconvert_path):
32         raise IOError, "Could not find keyconvert in %s" % keyconvert_path
33
34     # we can only convert rsa keys 
35     if "ssh-dss" in key:
36         print "XXX: DSA key encountered, ignoring"
37         return None
38     
39     (ssh_f, ssh_fn) = tempfile.mkstemp()
40     ssl_fn = tempfile.mktemp()
41     os.write(ssh_f, key)
42     os.close(ssh_f)
43
44     cmd = keyconvert_path + " " + ssh_fn + " " + ssl_fn
45     os.system(cmd)
46
47     # this check leaves the temporary file containing the public key so
48     # that it can be expected to see why it failed.
49     # TODO: for production, cleanup the temporary files
50     if not os.path.exists(ssl_fn):
51         report.trace("  failed to convert key from " + ssh_fn + " to " + ssl_fn)
52         return None
53
54     k = Keypair()
55     try:
56         k.load_pubkey_from_file(ssl_fn)
57     except:
58         print "XXX: Error while converting key: ", key
59         traceback.print_exc()
60         k = None
61
62     # remove the temporary files
63     os.remove(ssh_fn)
64     os.remove(ssl_fn)
65
66     return k
67
68 ##
69 # Public-private key pairs are implemented by the Keypair class.
70 # A Keypair object may represent both a public and private key pair, or it
71 # may represent only a public key (this usage is consistent with OpenSSL).
72
73 class Keypair:
74    key = None       # public/private keypair
75    m2key = None     # public key (m2crypto format)
76
77    ##
78    # Creates a Keypair object
79    # @param create If create==True, creates a new public/private key and
80    #     stores it in the object
81    # @param string If string!=None, load the keypair from the string (PEM)
82    # @param filename If filename!=None, load the keypair from the file
83
84    def __init__(self, create=False, string=None, filename=None):
85       if create:
86          self.create()
87       if string:
88          self.load_from_string(string)
89       if filename:
90          self.load_from_file(filename)
91
92    ##
93    # Create a RSA public/private key pair and store it inside the keypair object
94
95    def create(self):
96       self.key = crypto.PKey()
97       self.key.generate_key(crypto.TYPE_RSA, 1024)
98
99    ##
100    # Save the private key to a file
101    # @param filename name of file to store the keypair in
102
103    def save_to_file(self, filename):
104       open(filename, 'w').write(self.as_pem())
105
106    ##
107    # Load the private key from a file. Implicity the private key includes the public key.
108
109    def load_from_file(self, filename):
110       buffer = open(filename, 'r').read()
111       self.load_from_string(buffer)
112
113    ##
114    # Load the private key from a string. Implicitly the private key includes the public key.
115
116    def load_from_string(self, string):
117       self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
118       self.m2key = M2Crypto.EVP.load_key_string(string)
119
120    ##
121    #  Load the public key from a string. No private key is loaded. 
122
123    def load_pubkey_from_file(self, filename):
124       # load the m2 public key
125       m2rsakey = M2Crypto.RSA.load_pub_key(filename)
126       self.m2key = M2Crypto.EVP.PKey()
127       self.m2key.assign_rsa(m2rsakey)
128
129       # create an m2 x509 cert
130       m2name = M2Crypto.X509.X509_Name()
131       m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0)
132       m2x509 = M2Crypto.X509.X509()
133       m2x509.set_pubkey(self.m2key)
134       m2x509.set_serial_number(0)
135       m2x509.set_issuer_name(m2name)
136       m2x509.set_subject_name(m2name)
137       ASN1 = M2Crypto.ASN1.ASN1_UTCTIME()
138       ASN1.set_time(500)
139       m2x509.set_not_before(ASN1)
140       m2x509.set_not_after(ASN1)
141       junk_key = Keypair(create=True)
142       m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")
143
144       # convert the m2 x509 cert to a pyopenssl x509
145       m2pem = m2x509.as_pem()
146       pyx509 = crypto.load_certificate(crypto.FILETYPE_PEM, m2pem)
147
148       # get the pyopenssl pkey from the pyopenssl x509
149       self.key = pyx509.get_pubkey()
150
151    ##
152    # Load the public key from a string. No private key is loaded.
153
154    def load_pubkey_from_string(self, string):
155       (f, fn) = tempfile.mkstemp()
156       os.write(f, string)
157       os.close(f)
158       self.load_pubkey_from_file(fn)
159       os.remove(fn)
160
161    ##
162    # Return the private key in PEM format.
163
164    def as_pem(self):
165       return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key)
166
167    ##
168    # Return an M2Crypto key object
169
170    def get_m2_pkey(self):
171       if not self.m2key:
172          self.m2key = M2Crypto.EVP.load_key_string(self.as_pem())
173       return self.m2key
174
175    ##
176    # Returns a string containing the public key represented by this object.
177
178    def get_pubkey_string(self):
179       m2pkey = self.get_m2_pkey()
180       return base64.b64encode(m2pkey.as_der())
181
182    ##
183    # Return an OpenSSL pkey object
184
185    def get_openssl_pkey(self):
186       return self.key
187
188    ##
189    # Given another Keypair object, return TRUE if the two keys are the same.
190
191    def is_same(self, pkey):
192       return self.as_pem() == pkey.as_pem()
193
194    def sign_string(self, data):
195       k = self.get_m2_pkey()
196       k.sign_init()
197       k.sign_update(data)
198       return base64.b64encode(k.sign_final())
199
200    def verify_string(self, data, sig):
201       k = self.get_m2_pkey()
202       k.verify_init()
203       k.verify_update(data)
204       return M2Crypto.m2.verify_final(k.ctx, base64.b64decode(sig), k.pkey)
205
206 ##
207 # The certificate class implements a general purpose X509 certificate, making
208 # use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds
209 # several addition features, such as the ability to maintain a chain of
210 # parent certificates, and storage of application-specific data.
211 #
212 # Certificates include the ability to maintain a chain of parents. Each
213 # certificate includes a pointer to it's parent certificate. When loaded
214 # from a file or a string, the parent chain will be automatically loaded.
215 # When saving a certificate to a file or a string, the caller can choose
216 # whether to save the parent certificates as well.
217
218 class Certificate:
219    digest = "md5"
220
221    data = None
222    cert = None
223    issuerKey = None
224    issuerSubject = None
225    parent = None
226
227    separator="-----parent-----"
228
229    ##
230    # Create a certificate object.
231    #
232    # @param create If create==True, then also create a blank X509 certificate.
233    # @param subject If subject!=None, then create a blank certificate and set
234    #     it's subject name.
235    # @param string If string!=None, load the certficate from the string.
236    # @param filename If filename!=None, load the certficiate from the file.
237
238    def __init__(self, create=False, subject=None, string=None, filename=None):
239        if create or subject:
240            self.create()
241        if subject:
242            self.set_subject(subject)
243        if string:
244            self.load_from_string(string)
245        if filename:
246            self.load_from_file(filename)
247
248    ##
249    # Create a blank X509 certificate and store it in this object.
250
251    def create(self):
252        self.cert = crypto.X509()
253        self.cert.set_serial_number(1)
254        self.cert.gmtime_adj_notBefore(0)
255        self.cert.gmtime_adj_notAfter(60*60*24*365*5) # five years
256
257    ##
258    # Given a pyOpenSSL X509 object, store that object inside of this
259    # certificate object.
260
261    def load_from_pyopenssl_x509(self, x509):
262        self.cert = x509
263
264    ##
265    # Load the certificate from a string
266
267    def load_from_string(self, string):
268        # if it is a chain of multiple certs, then split off the first one and
269        # load it
270        parts = string.split(Certificate.separator, 1)
271        self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])
272
273        # if there are more certs, then create a parent and let the parent load
274        # itself from the remainder of the string
275        if len(parts) > 1:
276            self.parent = self.__class__()
277            self.parent.load_from_string(parts[1])
278
279    ##
280    # Load the certificate from a file
281
282    def load_from_file(self, filename):
283        file = open(filename)
284        string = file.read()
285        self.load_from_string(string)
286
287    ##
288    # Save the certificate to a string.
289    #
290    # @param save_parents If save_parents==True, then also save the parent certificates.
291
292    def save_to_string(self, save_parents=False):
293        string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)
294        if save_parents and self.parent:
295           string = string + Certificate.separator + self.parent.save_to_string(save_parents)
296        return string
297
298    ##
299    # Save the certificate to a file.
300    # @param save_parents If save_parents==True, then also save the parent certificates.
301
302    def save_to_file(self, filename, save_parents=False):
303        string = self.save_to_string(save_parents=save_parents)
304        open(filename, 'w').write(string)
305
306    ##
307    # Sets the issuer private key and name
308    # @param key Keypair object containing the private key of the issuer
309    # @param subject String containing the name of the issuer
310    # @param cert (optional) Certificate object containing the name of the issuer
311
312    def set_issuer(self, key, subject=None, cert=None):
313        self.issuerKey = key
314        if subject:
315           # it's a mistake to use subject and cert params at the same time
316           assert(not cert)
317           if isinstance(subject, dict) or isinstance(subject, str):
318              req = crypto.X509Req()
319              reqSubject = req.get_subject()
320              if (isinstance(subject, dict)):
321                 for key in reqSubject.keys():
322                     setattr(reqSubject, key, name[key])
323              else:
324                 setattr(reqSubject, "CN", subject)
325              subject = reqSubject
326              # subject is not valid once req is out of scope, so save req
327              self.issuerReq = req
328        if cert:
329           # if a cert was supplied, then get the subject from the cert
330           subject = cert.cert.get_issuer()
331        assert(subject)
332        self.issuerSubject = subject
333
334    ##
335    # Get the issuer name
336
337    def get_issuer(self, which="CN"):
338        x = self.cert.get_issuer()
339        return getattr(x, which)
340
341    ##
342    # Set the subject name of the certificate
343
344    def set_subject(self, name):
345        req = crypto.X509Req()
346        subj = req.get_subject()
347        if (isinstance(name, dict)):
348            for key in name.keys():
349                setattr(subj, key, name[key])
350        else:
351            setattr(subj, "CN", name)
352        self.cert.set_subject(subj)
353    ##
354    # Get the subject name of the certificate
355
356    def get_subject(self, which="CN"):
357        x = self.cert.get_subject()
358        return getattr(x, which)
359
360    ##
361    # Get the public key of the certificate.
362    #
363    # @param key Keypair object containing the public key
364
365    def set_pubkey(self, key):
366        assert(isinstance(key, Keypair))
367        self.cert.set_pubkey(key.get_openssl_pkey())
368
369    ##
370    # Get the public key of the certificate.
371    # It is returned in the form of a Keypair object.
372
373    def get_pubkey(self):
374        m2x509 = X509.load_cert_string(self.save_to_string())
375        pkey = Keypair()
376        pkey.key = self.cert.get_pubkey()
377        pkey.m2key = m2x509.get_pubkey()
378        return pkey
379
380    ##
381    # Add an X509 extension to the certificate. Add_extension can only be called
382    # once for a particular extension name, due to limitations in the underlying
383    # library.
384    #
385    # @param name string containing name of extension
386    # @param value string containing value of the extension
387
388    def add_extension(self, name, critical, value):
389        ext = crypto.X509Extension (name, critical, value)
390        self.cert.add_extensions([ext])
391
392    ##
393    # Get an X509 extension from the certificate
394
395    def get_extension(self, name):
396        # pyOpenSSL does not have a way to get extensions
397        m2x509 = X509.load_cert_string(self.save_to_string())
398        value = m2x509.get_ext(name).get_value()
399        return value
400
401    ##
402    # Set_data is a wrapper around add_extension. It stores the parameter str in
403    # the X509 subject_alt_name extension. Set_data can only be called once, due
404    # to limitations in the underlying library.
405
406    def set_data(self, str):
407        # pyOpenSSL only allows us to add extensions, so if we try to set the
408        # same extension more than once, it will not work
409        if self.data != None:
410           raise "cannot set subjectAltName more than once"
411        self.data = str
412        self.add_extension("subjectAltName", 0, "URI:http://" + str)
413
414    ##
415    # Return the data string that was previously set with set_data
416
417    def get_data(self):
418        if self.data:
419            return self.data
420
421        try:
422            uri = self.get_extension("subjectAltName")
423        except LookupError:
424            self.data = None
425            return self.data
426
427        if not uri.startswith("URI:http://"):
428            raise "bad encoding in subjectAltName"
429        self.data = uri[11:]
430        return self.data
431
432    ##
433    # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer().
434
435    def sign(self):
436        assert self.cert != None
437        assert self.issuerSubject != None
438        assert self.issuerKey != None
439        self.cert.set_issuer(self.issuerSubject)
440        self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest)
441
442     ##
443     # Verify the authenticity of a certificate.
444     # @param pkey is a Keypair object representing a public key. If Pkey
445     #     did not sign the certificate, then an exception will be thrown.
446
447    def verify(self, pkey):
448        # pyOpenSSL does not have a way to verify signatures
449        m2x509 = X509.load_cert_string(self.save_to_string())
450        m2pkey = pkey.get_m2_pkey()
451        # verify it
452        return m2x509.verify(m2pkey)
453
454        # XXX alternatively, if openssl has been patched, do the much simpler:
455        # try:
456        #   self.cert.verify(pkey.get_openssl_key())
457        #   return 1
458        # except:
459        #   return 0
460
461    ##
462    # Return True if pkey is identical to the public key that is contained in the certificate.
463    # @param pkey Keypair object
464
465    def is_pubkey(self, pkey):
466        return self.get_pubkey().is_same(pkey)
467
468    ##
469    # Given a certificate cert, verify that this certificate was signed by the
470    # public key contained in cert. Throw an exception otherwise.
471    #
472    # @param cert certificate object
473
474    def is_signed_by_cert(self, cert):
475        k = cert.get_pubkey()
476        result = self.verify(k)
477        return result
478
479    ##
480    # Set the parent certficiate.
481    #
482    # @param p certificate object.
483
484    def set_parent(self, p):
485         self.parent = p
486
487    ##
488    # Return the certificate object of the parent of this certificate.
489
490    def get_parent(self):
491         return self.parent
492
493    ##
494    # Verification examines a chain of certificates to ensure that each parent
495    # signs the child, and that some certificate in the chain is signed by a
496    # trusted certificate.
497    #
498    # Verification is a basic recursion: <pre>
499    #     if this_certificate was signed by trusted_certs:
500    #         return
501    #     else
502    #         return verify_chain(parent, trusted_certs)
503    # </pre>
504    #
505    # At each recursion, the parent is tested to ensure that it did sign the
506    # child. If a parent did not sign a child, then an exception is thrown. If
507    # the bottom of the recursion is reached and the certificate does not match
508    # a trusted root, then an exception is thrown.
509    #
510    # @param Trusted_certs is a list of certificates that are trusted.
511    #
512
513    def verify_chain(self, trusted_certs = None):
514         # Verify a chain of certificates. Each certificate must be signed by
515         # the public key contained in it's parent. The chain is recursed
516         # until a certificate is found that is signed by a trusted root.
517
518         # TODO: verify expiration time
519         print "====Verify Chain====="
520         # if this cert is signed by a trusted_cert, then we are set
521         for trusted_cert in trusted_certs:
522             print "***************"
523             # TODO: verify expiration of trusted_cert ?
524             print "CLIENT CERT", self.dump()
525             print "TRUSTED CERT", trusted_cert.dump()
526             print "Client is signed by Trusted?", self.is_signed_by_cert(trusted_cert)
527             if self.is_signed_by_cert(trusted_cert):
528                 #print self.get_subject(), "is signed by a root"
529                 return
530
531         # if there is no parent, then no way to verify the chain
532         if not self.parent:
533             #print self.get_subject(), "has no parent"
534             raise CertMissingParent(self.get_subject())
535
536         # if it wasn't signed by the parent...
537         if not self.is_signed_by_cert(self.parent):
538             #print self.get_subject(), "is not signed by parent"
539             return CertNotSignedByParent(self.get_subject())
540
541         # if the parent isn't verified...
542         self.parent.verify_chain(trusted_certs)
543
544         return