Merge commit 'origin/jktest' into jktest3
[sfa.git] / sfa / trust / certificate.py
1 #----------------------------------------------------------------------
2 # Copyright (c) 2008 Board of Trustees, Princeton University
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining
5 # a copy of this software and/or hardware specification (the "Work") to
6 # deal in the Work without restriction, including without limitation the
7 # rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Work, and to permit persons to whom the Work
9 # is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice shall be
12 # included in all copies or substantial portions of the Work.
13 #
14 # THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 
15 # OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 
16 # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
17 # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 
18 # HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 
19 # WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
20 # OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS 
21 # IN THE WORK.
22 #----------------------------------------------------------------------
23
24 ##
25 # SFA uses two crypto libraries: pyOpenSSL and M2Crypto to implement
26 # the necessary crypto functionality. Ideally just one of these libraries
27 # would be used, but unfortunately each of these libraries is independently
28 # lacking. The pyOpenSSL library is missing many necessary functions, and
29 # the M2Crypto library has crashed inside of some of the functions. The
30 # design decision is to use pyOpenSSL whenever possible as it seems more
31 # stable, and only use M2Crypto for those functions that are not possible
32 # in pyOpenSSL.
33 #
34 # This module exports two classes: Keypair and Certificate.
35 ##
36 #
37 ### $Id$
38 ### $URL$
39 #
40
41 import os
42 import tempfile
43 import base64
44 import traceback
45 from tempfile import mkstemp
46
47 from OpenSSL import crypto
48 import M2Crypto
49 from M2Crypto import X509
50
51 from sfa.util.sfalogging import sfa_logger
52 from sfa.util.namespace import urn_to_hrn
53 from sfa.util.faults import *
54
55 def convert_public_key(key):
56     keyconvert_path = "/usr/bin/keyconvert.py"
57     if not os.path.isfile(keyconvert_path):
58         raise IOError, "Could not find keyconvert in %s" % keyconvert_path
59
60     # we can only convert rsa keys
61     if "ssh-dss" in key:
62         return None
63
64     (ssh_f, ssh_fn) = tempfile.mkstemp()
65     ssl_fn = tempfile.mktemp()
66     os.write(ssh_f, key)
67     os.close(ssh_f)
68
69     cmd = keyconvert_path + " " + ssh_fn + " " + ssl_fn
70     os.system(cmd)
71
72     # this check leaves the temporary file containing the public key so
73     # that it can be expected to see why it failed.
74     # TODO: for production, cleanup the temporary files
75     if not os.path.exists(ssl_fn):
76         return None
77
78     k = Keypair()
79     try:
80         k.load_pubkey_from_file(ssl_fn)
81     except:
82         sfa_logger.log_exc("convert_public_key caught exception")
83         k = None
84
85     # remove the temporary files
86     os.remove(ssh_fn)
87     os.remove(ssl_fn)
88
89     return k
90
91 ##
92 # Public-private key pairs are implemented by the Keypair class.
93 # A Keypair object may represent both a public and private key pair, or it
94 # may represent only a public key (this usage is consistent with OpenSSL).
95
96 class Keypair:
97     key = None       # public/private keypair
98     m2key = None     # public key (m2crypto format)
99
100     ##
101     # Creates a Keypair object
102     # @param create If create==True, creates a new public/private key and
103     #     stores it in the object
104     # @param string If string!=None, load the keypair from the string (PEM)
105     # @param filename If filename!=None, load the keypair from the file
106
107     def __init__(self, create=False, string=None, filename=None):
108         if create:
109             self.create()
110         if string:
111             self.load_from_string(string)
112         if filename:
113             self.load_from_file(filename)
114
115     ##
116     # Create a RSA public/private key pair and store it inside the keypair object
117
118     def create(self):
119         self.key = crypto.PKey()
120         self.key.generate_key(crypto.TYPE_RSA, 1024)
121
122     ##
123     # Save the private key to a file
124     # @param filename name of file to store the keypair in
125
126     def save_to_file(self, filename):
127         open(filename, 'w').write(self.as_pem())
128
129     ##
130     # Load the private key from a file. Implicity the private key includes the public key.
131
132     def load_from_file(self, filename):
133         buffer = open(filename, 'r').read()
134         self.load_from_string(buffer)
135
136     ##
137     # Load the private key from a string. Implicitly the private key includes the public key.
138
139     def load_from_string(self, string):
140         self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
141         self.m2key = M2Crypto.EVP.load_key_string(string)
142
143     ##
144     #  Load the public key from a string. No private key is loaded.
145
146     def load_pubkey_from_file(self, filename):
147         # load the m2 public key
148         m2rsakey = M2Crypto.RSA.load_pub_key(filename)
149         self.m2key = M2Crypto.EVP.PKey()
150         self.m2key.assign_rsa(m2rsakey)
151
152         # create an m2 x509 cert
153         m2name = M2Crypto.X509.X509_Name()
154         m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0)
155         m2x509 = M2Crypto.X509.X509()
156         m2x509.set_pubkey(self.m2key)
157         m2x509.set_serial_number(0)
158         m2x509.set_issuer_name(m2name)
159         m2x509.set_subject_name(m2name)
160         ASN1 = M2Crypto.ASN1.ASN1_UTCTIME()
161         ASN1.set_time(500)
162         m2x509.set_not_before(ASN1)
163         m2x509.set_not_after(ASN1)
164         junk_key = Keypair(create=True)
165         m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")
166
167         # convert the m2 x509 cert to a pyopenssl x509
168         m2pem = m2x509.as_pem()
169         pyx509 = crypto.load_certificate(crypto.FILETYPE_PEM, m2pem)
170
171         # get the pyopenssl pkey from the pyopenssl x509
172         self.key = pyx509.get_pubkey()
173
174     ##
175     # Load the public key from a string. No private key is loaded.
176
177     def load_pubkey_from_string(self, string):
178         (f, fn) = tempfile.mkstemp()
179         os.write(f, string)
180         os.close(f)
181         self.load_pubkey_from_file(fn)
182         os.remove(fn)
183
184     ##
185     # Return the private key in PEM format.
186
187     def as_pem(self):
188         return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key)
189
190     ##
191     # Return an M2Crypto key object
192
193     def get_m2_pkey(self):
194         if not self.m2key:
195             self.m2key = M2Crypto.EVP.load_key_string(self.as_pem())
196         return self.m2key
197
198     ##
199     # Returns a string containing the public key represented by this object.
200
201     def get_pubkey_string(self):
202         m2pkey = self.get_m2_pkey()
203         return base64.b64encode(m2pkey.as_der())
204
205     ##
206     # Return an OpenSSL pkey object
207
208     def get_openssl_pkey(self):
209         return self.key
210
211
212     ##
213     # Given another Keypair object, return TRUE if the two keys are the same.
214
215     def is_same(self, pkey):
216         return self.as_pem() == pkey.as_pem()
217
218     def sign_string(self, data):
219         k = self.get_m2_pkey()
220         k.sign_init()
221         k.sign_update(data)
222         return base64.b64encode(k.sign_final())
223
224     def verify_string(self, data, sig):
225         k = self.get_m2_pkey()
226         k.verify_init()
227         k.verify_update(data)
228         return M2Crypto.m2.verify_final(k.ctx, base64.b64decode(sig), k.pkey)
229
230     def compute_hash(self, value):
231         return self.sign_string(str(value))
232
233 ##
234 # The certificate class implements a general purpose X509 certificate, making
235 # use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds
236 # several addition features, such as the ability to maintain a chain of
237 # parent certificates, and storage of application-specific data.
238 #
239 # Certificates include the ability to maintain a chain of parents. Each
240 # certificate includes a pointer to it's parent certificate. When loaded
241 # from a file or a string, the parent chain will be automatically loaded.
242 # When saving a certificate to a file or a string, the caller can choose
243 # whether to save the parent certificates as well.
244
245 class Certificate:
246     digest = "md5"
247
248     cert = None
249     issuerKey = None
250     issuerSubject = None
251     parent = None
252
253     separator="-----parent-----"
254
255     ##
256     # Create a certificate object.
257     #
258     # @param create If create==True, then also create a blank X509 certificate.
259     # @param subject If subject!=None, then create a blank certificate and set
260     #     it's subject name.
261     # @param string If string!=None, load the certficate from the string.
262     # @param filename If filename!=None, load the certficiate from the file.
263
264     def __init__(self, create=False, subject=None, string=None, filename=None, intermediate=None):
265         self.data = {}
266         if create or subject:
267             self.create()
268         if subject:
269             self.set_subject(subject)
270         if string:
271             self.load_from_string(string)
272         if filename:
273             self.load_from_file(filename)
274
275         if intermediate:
276             self.set_intermediate_ca(intermediate)
277
278     ##
279     # Create a blank X509 certificate and store it in this object.
280
281     def create(self):
282         self.cert = crypto.X509()
283         self.cert.set_serial_number(3)
284         self.cert.gmtime_adj_notBefore(0)
285         self.cert.gmtime_adj_notAfter(60*60*24*365*5) # five years
286
287     ##
288     # Given a pyOpenSSL X509 object, store that object inside of this
289     # certificate object.
290
291     def load_from_pyopenssl_x509(self, x509):
292         self.cert = x509
293
294     ##
295     # Load the certificate from a string
296
297     def load_from_string(self, string):
298         # if it is a chain of multiple certs, then split off the first one and
299         # load it (support for the ---parent--- tag as well as normal chained certs)
300
301         string = string.strip()
302         
303         # If it's not in proper PEM format, wrap it
304         if string.count('-----BEGIN CERTIFICATE') == 0:
305             string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string
306
307         # If there is a PEM cert in there, but there is some other text first
308         # such as the text of the certificate, skip the text
309         beg = string.find('-----BEGIN CERTIFICATE')
310         if beg > 0:
311             # skipping over non cert beginning                                                                                                              
312             string = string[beg:]
313
314         parts = []
315
316         if string.count('-----BEGIN CERTIFICATE-----') > 1 and \
317                string.count(Certificate.separator) == 0:
318             parts = string.split('-----END CERTIFICATE-----',1)
319             parts[0] += '-----END CERTIFICATE-----'
320         else:
321             parts = string.split(Certificate.separator, 1)
322
323         self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])
324
325         # if there are more certs, then create a parent and let the parent load
326         # itself from the remainder of the string
327         if len(parts) > 1 and parts[1] != '':
328             self.parent = self.__class__()
329             self.parent.load_from_string(parts[1])
330
331     ##
332     # Load the certificate from a file
333
334     def load_from_file(self, filename):
335         file = open(filename)
336         string = file.read()
337         self.load_from_string(string)
338
339     ##
340     # Save the certificate to a string.
341     #
342     # @param save_parents If save_parents==True, then also save the parent certificates.
343
344     def save_to_string(self, save_parents=True):
345         string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)
346         if save_parents and self.parent:
347             string = string + self.parent.save_to_string(save_parents)
348         return string
349
350     ##
351     # Save the certificate to a file.
352     # @param save_parents If save_parents==True, then also save the parent certificates.
353
354     def save_to_file(self, filename, save_parents=True, filep=None):
355         string = self.save_to_string(save_parents=save_parents)
356         if filep:
357             f = filep
358         else:
359             f = open(filename, 'w')
360         f.write(string)
361         f.close()
362
363     ##
364     # Save the certificate to a random file in /tmp/
365     # @param save_parents If save_parents==True, then also save the parent certificates.
366     def save_to_random_tmp_file(self, save_parents=True):
367         fp, filename = mkstemp(suffix='cert', text=True)
368         fp = os.fdopen(fp, "w")
369         self.save_to_file(filename, save_parents=True, filep=fp)
370         return filename
371
372     ##
373     # Sets the issuer private key and name
374     # @param key Keypair object containing the private key of the issuer
375     # @param subject String containing the name of the issuer
376     # @param cert (optional) Certificate object containing the name of the issuer
377
378     def set_issuer(self, key, subject=None, cert=None):
379         self.issuerKey = key
380         if subject:
381             # it's a mistake to use subject and cert params at the same time
382             assert(not cert)
383             if isinstance(subject, dict) or isinstance(subject, str):
384                 req = crypto.X509Req()
385                 reqSubject = req.get_subject()
386                 if (isinstance(subject, dict)):
387                     for key in reqSubject.keys():
388                         setattr(reqSubject, key, subject[key])
389                 else:
390                     setattr(reqSubject, "CN", subject)
391                 subject = reqSubject
392                 # subject is not valid once req is out of scope, so save req
393                 self.issuerReq = req
394         if cert:
395             # if a cert was supplied, then get the subject from the cert
396             subject = cert.cert.get_subject()
397         assert(subject)
398         self.issuerSubject = subject
399
400     ##
401     # Get the issuer name
402
403     def get_issuer(self, which="CN"):
404         x = self.cert.get_issuer()
405         return getattr(x, which)
406
407     ##
408     # Set the subject name of the certificate
409
410     def set_subject(self, name):
411         req = crypto.X509Req()
412         subj = req.get_subject()
413         if (isinstance(name, dict)):
414             for key in name.keys():
415                 setattr(subj, key, name[key])
416         else:
417             setattr(subj, "CN", name)
418         self.cert.set_subject(subj)
419     ##
420     # Get the subject name of the certificate
421
422     def get_subject(self, which="CN"):
423         x = self.cert.get_subject()
424         return getattr(x, which)
425
426     ##
427     # Get the public key of the certificate.
428     #
429     # @param key Keypair object containing the public key
430
431     def set_pubkey(self, key):
432         assert(isinstance(key, Keypair))
433         self.cert.set_pubkey(key.get_openssl_pkey())
434
435     ##
436     # Get the public key of the certificate.
437     # It is returned in the form of a Keypair object.
438
439     def get_pubkey(self):
440         m2x509 = X509.load_cert_string(self.save_to_string())
441         pkey = Keypair()
442         pkey.key = self.cert.get_pubkey()
443         pkey.m2key = m2x509.get_pubkey()
444         return pkey
445
446     def set_intermediate_ca(self, val):
447         self.intermediate = val
448         if val:
449             self.add_extension('basicConstraints', 1, 'CA:TRUE')
450
451
452
453     ##
454     # Add an X509 extension to the certificate. Add_extension can only be called
455     # once for a particular extension name, due to limitations in the underlying
456     # library.
457     #
458     # @param name string containing name of extension
459     # @param value string containing value of the extension
460
461     def add_extension(self, name, critical, value):
462         ext = crypto.X509Extension (name, critical, value)
463         self.cert.add_extensions([ext])
464
465     ##
466     # Get an X509 extension from the certificate
467
468     def get_extension(self, name):
469         # pyOpenSSL does not have a way to get extensions
470         m2x509 = X509.load_cert_string(self.save_to_string())
471         value = m2x509.get_ext(name).get_value()
472         return value
473
474     ##
475     # Set_data is a wrapper around add_extension. It stores the parameter str in
476     # the X509 subject_alt_name extension. Set_data can only be called once, due
477     # to limitations in the underlying library.
478
479     def set_data(self, str, field='subjectAltName'):
480         # pyOpenSSL only allows us to add extensions, so if we try to set the
481         # same extension more than once, it will not work
482         if self.data.has_key(field):
483             raise "Cannot set ", field, " more than once"
484         self.data[field] = str
485         self.add_extension(field, 0, str)
486
487     ##
488     # Return the data string that was previously set with set_data
489
490     def get_data(self, field='subjectAltName'):
491         if self.data.has_key(field):
492             return self.data[field]
493
494         try:
495             uri = self.get_extension(field)
496             self.data[field] = uri
497         except LookupError:
498             return None
499
500         return self.data[field]
501
502     ##
503     # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer().
504
505     def sign(self):
506         assert self.cert != None
507         assert self.issuerSubject != None
508         assert self.issuerKey != None
509         self.cert.set_issuer(self.issuerSubject)
510         self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest)
511
512     ##
513     # Verify the authenticity of a certificate.
514     # @param pkey is a Keypair object representing a public key. If Pkey
515     #     did not sign the certificate, then an exception will be thrown.
516
517     def verify(self, pkey):
518         # pyOpenSSL does not have a way to verify signatures
519         m2x509 = X509.load_cert_string(self.save_to_string())
520         m2pkey = pkey.get_m2_pkey()
521         # verify it
522         return m2x509.verify(m2pkey)
523
524         # XXX alternatively, if openssl has been patched, do the much simpler:
525         # try:
526         #   self.cert.verify(pkey.get_openssl_key())
527         #   return 1
528         # except:
529         #   return 0
530
531     ##
532     # Return True if pkey is identical to the public key that is contained in the certificate.
533     # @param pkey Keypair object
534
535     def is_pubkey(self, pkey):
536         return self.get_pubkey().is_same(pkey)
537
538     ##
539     # Given a certificate cert, verify that this certificate was signed by the
540     # public key contained in cert. Throw an exception otherwise.
541     #
542     # @param cert certificate object
543
544     def is_signed_by_cert(self, cert):
545         k = cert.get_pubkey()
546         result = self.verify(k)
547         return result
548
549     ##
550     # Set the parent certficiate.
551     #
552     # @param p certificate object.
553
554     def set_parent(self, p):
555         self.parent = p
556
557     ##
558     # Return the certificate object of the parent of this certificate.
559
560     def get_parent(self):
561         return self.parent
562
563     ##
564     # Verification examines a chain of certificates to ensure that each parent
565     # signs the child, and that some certificate in the chain is signed by a
566     # trusted certificate.
567     #
568     # Verification is a basic recursion: <pre>
569     #     if this_certificate was signed by trusted_certs:
570     #         return
571     #     else
572     #         return verify_chain(parent, trusted_certs)
573     # </pre>
574     #
575     # At each recursion, the parent is tested to ensure that it did sign the
576     # child. If a parent did not sign a child, then an exception is thrown. If
577     # the bottom of the recursion is reached and the certificate does not match
578     # a trusted root, then an exception is thrown.
579     #
580     # @param Trusted_certs is a list of certificates that are trusted.
581     #
582
583     def verify_chain(self, trusted_certs = None):
584         # Verify a chain of certificates. Each certificate must be signed by
585         # the public key contained in it's parent. The chain is recursed
586         # until a certificate is found that is signed by a trusted root.
587
588         # verify expiration time
589         if self.cert.has_expired():
590             raise CertExpired(self.get_subject(), "client cert")   
591         
592         # if this cert is signed by a trusted_cert, then we are set
593         for trusted_cert in trusted_certs:
594             if self.is_signed_by_cert(trusted_cert):
595                 sfa_logger.debug("Cert %s signed by trusted cert %s", self.get_subject(), trusted_cert.get_subject())
596                 # verify expiration of trusted_cert ?
597                 if not trusted_cert.cert.has_expired():
598                     return trusted_cert
599                 else:
600                     sfa_logger.debug("Trusted cert %s is expired", trusted_cert.get_subject())       
601
602         # if there is no parent, then no way to verify the chain
603         if not self.parent:
604             sfa_logger.debug("%r has no parent"%self.get_subject())
605             raise CertMissingParent(self.get_subject())
606
607         # if it wasn't signed by the parent...
608         if not self.is_signed_by_cert(self.parent):
609             sfa_logger.debug("%r is not signed by parent"%self.get_subject())
610             return CertNotSignedByParent(self.get_subject())
611
612         # if the parent isn't verified...
613         self.parent.verify_chain(trusted_certs)
614
615         return