started to repair sfadump
[sfa.git] / sfa / trust / certificate.py
1 #----------------------------------------------------------------------
2 # Copyright (c) 2008 Board of Trustees, Princeton University
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining
5 # a copy of this software and/or hardware specification (the "Work") to
6 # deal in the Work without restriction, including without limitation the
7 # rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Work, and to permit persons to whom the Work
9 # is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice shall be
12 # included in all copies or substantial portions of the Work.
13 #
14 # THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 
15 # OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 
16 # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
17 # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 
18 # HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 
19 # WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
20 # OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS 
21 # IN THE WORK.
22 #----------------------------------------------------------------------
23
24 ##
25 # SFA uses two crypto libraries: pyOpenSSL and M2Crypto to implement
26 # the necessary crypto functionality. Ideally just one of these libraries
27 # would be used, but unfortunately each of these libraries is independently
28 # lacking. The pyOpenSSL library is missing many necessary functions, and
29 # the M2Crypto library has crashed inside of some of the functions. The
30 # design decision is to use pyOpenSSL whenever possible as it seems more
31 # stable, and only use M2Crypto for those functions that are not possible
32 # in pyOpenSSL.
33 #
34 # This module exports two classes: Keypair and Certificate.
35 ##
36 #
37 ### $Id$
38 ### $URL$
39 #
40
41 import os
42 import tempfile
43 import base64
44 import traceback
45 from tempfile import mkstemp
46
47 from OpenSSL import crypto
48 import M2Crypto
49 from M2Crypto import X509
50
51 from sfa.util.sfalogging import sfa_logger
52 from sfa.util.namespace import urn_to_hrn
53 from sfa.util.faults import *
54
55 def convert_public_key(key):
56     keyconvert_path = "/usr/bin/keyconvert.py"
57     if not os.path.isfile(keyconvert_path):
58         raise IOError, "Could not find keyconvert in %s" % keyconvert_path
59
60     # we can only convert rsa keys
61     if "ssh-dss" in key:
62         return None
63
64     (ssh_f, ssh_fn) = tempfile.mkstemp()
65     ssl_fn = tempfile.mktemp()
66     os.write(ssh_f, key)
67     os.close(ssh_f)
68
69     cmd = keyconvert_path + " " + ssh_fn + " " + ssl_fn
70     os.system(cmd)
71
72     # this check leaves the temporary file containing the public key so
73     # that it can be expected to see why it failed.
74     # TODO: for production, cleanup the temporary files
75     if not os.path.exists(ssl_fn):
76         return None
77
78     k = Keypair()
79     try:
80         k.load_pubkey_from_file(ssl_fn)
81     except:
82         sfa_logger().log_exc("convert_public_key caught exception")
83         k = None
84
85     # remove the temporary files
86     os.remove(ssh_fn)
87     os.remove(ssl_fn)
88
89     return k
90
91 ##
92 # Public-private key pairs are implemented by the Keypair class.
93 # A Keypair object may represent both a public and private key pair, or it
94 # may represent only a public key (this usage is consistent with OpenSSL).
95
96 class Keypair:
97     key = None       # public/private keypair
98     m2key = None     # public key (m2crypto format)
99
100     ##
101     # Creates a Keypair object
102     # @param create If create==True, creates a new public/private key and
103     #     stores it in the object
104     # @param string If string!=None, load the keypair from the string (PEM)
105     # @param filename If filename!=None, load the keypair from the file
106
107     def __init__(self, create=False, string=None, filename=None):
108         if create:
109             self.create()
110         if string:
111             self.load_from_string(string)
112         if filename:
113             self.load_from_file(filename)
114
115     ##
116     # Create a RSA public/private key pair and store it inside the keypair object
117
118     def create(self):
119         self.key = crypto.PKey()
120         self.key.generate_key(crypto.TYPE_RSA, 1024)
121
122     ##
123     # Save the private key to a file
124     # @param filename name of file to store the keypair in
125
126     def save_to_file(self, filename):
127         open(filename, 'w').write(self.as_pem())
128
129     ##
130     # Load the private key from a file. Implicity the private key includes the public key.
131
132     def load_from_file(self, filename):
133         buffer = open(filename, 'r').read()
134         self.load_from_string(buffer)
135
136     ##
137     # Load the private key from a string. Implicitly the private key includes the public key.
138
139     def load_from_string(self, string):
140         self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
141         self.m2key = M2Crypto.EVP.load_key_string(string)
142
143     ##
144     #  Load the public key from a string. No private key is loaded.
145
146     def load_pubkey_from_file(self, filename):
147         # load the m2 public key
148         m2rsakey = M2Crypto.RSA.load_pub_key(filename)
149         self.m2key = M2Crypto.EVP.PKey()
150         self.m2key.assign_rsa(m2rsakey)
151
152         # create an m2 x509 cert
153         m2name = M2Crypto.X509.X509_Name()
154         m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0)
155         m2x509 = M2Crypto.X509.X509()
156         m2x509.set_pubkey(self.m2key)
157         m2x509.set_serial_number(0)
158         m2x509.set_issuer_name(m2name)
159         m2x509.set_subject_name(m2name)
160         ASN1 = M2Crypto.ASN1.ASN1_UTCTIME()
161         ASN1.set_time(500)
162         m2x509.set_not_before(ASN1)
163         m2x509.set_not_after(ASN1)
164         junk_key = Keypair(create=True)
165         m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")
166
167         # convert the m2 x509 cert to a pyopenssl x509
168         m2pem = m2x509.as_pem()
169         pyx509 = crypto.load_certificate(crypto.FILETYPE_PEM, m2pem)
170
171         # get the pyopenssl pkey from the pyopenssl x509
172         self.key = pyx509.get_pubkey()
173
174     ##
175     # Load the public key from a string. No private key is loaded.
176
177     def load_pubkey_from_string(self, string):
178         (f, fn) = tempfile.mkstemp()
179         os.write(f, string)
180         os.close(f)
181         self.load_pubkey_from_file(fn)
182         os.remove(fn)
183
184     ##
185     # Return the private key in PEM format.
186
187     def as_pem(self):
188         return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key)
189
190     ##
191     # Return an M2Crypto key object
192
193     def get_m2_pkey(self):
194         if not self.m2key:
195             self.m2key = M2Crypto.EVP.load_key_string(self.as_pem())
196         return self.m2key
197
198     ##
199     # Returns a string containing the public key represented by this object.
200
201     def get_pubkey_string(self):
202         m2pkey = self.get_m2_pkey()
203         return base64.b64encode(m2pkey.as_der())
204
205     ##
206     # Return an OpenSSL pkey object
207
208     def get_openssl_pkey(self):
209         return self.key
210
211
212     ##
213     # Given another Keypair object, return TRUE if the two keys are the same.
214
215     def is_same(self, pkey):
216         return self.as_pem() == pkey.as_pem()
217
218     def sign_string(self, data):
219         k = self.get_m2_pkey()
220         k.sign_init()
221         k.sign_update(data)
222         return base64.b64encode(k.sign_final())
223
224     def verify_string(self, data, sig):
225         k = self.get_m2_pkey()
226         k.verify_init()
227         k.verify_update(data)
228         return M2Crypto.m2.verify_final(k.ctx, base64.b64decode(sig), k.pkey)
229
230     def compute_hash(self, value):
231         return self.sign_string(str(value))
232
233 ##
234 # The certificate class implements a general purpose X509 certificate, making
235 # use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds
236 # several addition features, such as the ability to maintain a chain of
237 # parent certificates, and storage of application-specific data.
238 #
239 # Certificates include the ability to maintain a chain of parents. Each
240 # certificate includes a pointer to it's parent certificate. When loaded
241 # from a file or a string, the parent chain will be automatically loaded.
242 # When saving a certificate to a file or a string, the caller can choose
243 # whether to save the parent certificates as well.
244
245 class Certificate:
246     digest = "md5"
247
248     cert = None
249     issuerKey = None
250     issuerSubject = None
251     parent = None
252
253     separator="-----parent-----"
254
255     ##
256     # Create a certificate object.
257     #
258     # @param create If create==True, then also create a blank X509 certificate.
259     # @param subject If subject!=None, then create a blank certificate and set
260     #     it's subject name.
261     # @param string If string!=None, load the certficate from the string.
262     # @param filename If filename!=None, load the certficiate from the file.
263
264     def __init__(self, create=False, subject=None, string=None, filename=None, intermediate=None):
265         self.data = {}
266         if create or subject:
267             self.create()
268         if subject:
269             self.set_subject(subject)
270         if string:
271             self.load_from_string(string)
272         if filename:
273             self.load_from_file(filename)
274
275         if intermediate:
276             self.set_intermediate_ca(intermediate)
277
278     ##
279     # Create a blank X509 certificate and store it in this object.
280
281     def create(self):
282         self.cert = crypto.X509()
283         self.cert.set_serial_number(3)
284         self.cert.gmtime_adj_notBefore(0)
285         self.cert.gmtime_adj_notAfter(60*60*24*365*5) # five years
286
287     ##
288     # Given a pyOpenSSL X509 object, store that object inside of this
289     # certificate object.
290
291     def load_from_pyopenssl_x509(self, x509):
292         self.cert = x509
293
294     ##
295     # Load the certificate from a string
296
297     def load_from_string(self, string):
298         # if it is a chain of multiple certs, then split off the first one and
299         # load it (support for the ---parent--- tag as well as normal chained certs)
300
301         string = string.strip()
302
303
304         if not string.startswith('-----'):
305             string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string
306
307         parts = []
308
309         if string.count('-----BEGIN CERTIFICATE-----') > 1 and \
310                string.count(Certificate.separator) == 0:
311             parts = string.split('-----END CERTIFICATE-----',1)
312             parts[0] += '-----END CERTIFICATE-----'
313         else:
314             parts = string.split(Certificate.separator, 1)
315
316         self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])
317
318         # if there are more certs, then create a parent and let the parent load
319         # itself from the remainder of the string
320         if len(parts) > 1 and parts[1] != '':
321             self.parent = self.__class__()
322             self.parent.load_from_string(parts[1])
323
324     ##
325     # Load the certificate from a file
326
327     def load_from_file(self, filename):
328         file = open(filename)
329         string = file.read()
330         self.load_from_string(string)
331
332     ##
333     # Save the certificate to a string.
334     #
335     # @param save_parents If save_parents==True, then also save the parent certificates.
336
337     def save_to_string(self, save_parents=True):
338         string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)
339         if save_parents and self.parent:
340             string = string + self.parent.save_to_string(save_parents)
341         return string
342
343     ##
344     # Save the certificate to a file.
345     # @param save_parents If save_parents==True, then also save the parent certificates.
346
347     def save_to_file(self, filename, save_parents=True, filep=None):
348         string = self.save_to_string(save_parents=save_parents)
349         if filep:
350             f = filep
351         else:
352             f = open(filename, 'w')
353         f.write(string)
354         f.close()
355
356     ##
357     # Save the certificate to a random file in /tmp/
358     # @param save_parents If save_parents==True, then also save the parent certificates.
359     def save_to_random_tmp_file(self, save_parents=True):
360         fp, filename = mkstemp(suffix='cert', text=True)
361         fp = os.fdopen(fp, "w")
362         self.save_to_file(filename, save_parents=True, filep=fp)
363         return filename
364
365     ##
366     # Sets the issuer private key and name
367     # @param key Keypair object containing the private key of the issuer
368     # @param subject String containing the name of the issuer
369     # @param cert (optional) Certificate object containing the name of the issuer
370
371     def set_issuer(self, key, subject=None, cert=None):
372         self.issuerKey = key
373         if subject:
374             # it's a mistake to use subject and cert params at the same time
375             assert(not cert)
376             if isinstance(subject, dict) or isinstance(subject, str):
377                 req = crypto.X509Req()
378                 reqSubject = req.get_subject()
379                 if (isinstance(subject, dict)):
380                     for key in reqSubject.keys():
381                         setattr(reqSubject, key, subject[key])
382                 else:
383                     setattr(reqSubject, "CN", subject)
384                 subject = reqSubject
385                 # subject is not valid once req is out of scope, so save req
386                 self.issuerReq = req
387         if cert:
388             # if a cert was supplied, then get the subject from the cert
389             subject = cert.cert.get_subject()
390         assert(subject)
391         self.issuerSubject = subject
392
393     ##
394     # Get the issuer name
395
396     def get_issuer(self, which="CN"):
397         x = self.cert.get_issuer()
398         return getattr(x, which)
399
400     ##
401     # Set the subject name of the certificate
402
403     def set_subject(self, name):
404         req = crypto.X509Req()
405         subj = req.get_subject()
406         if (isinstance(name, dict)):
407             for key in name.keys():
408                 setattr(subj, key, name[key])
409         else:
410             setattr(subj, "CN", name)
411         self.cert.set_subject(subj)
412     ##
413     # Get the subject name of the certificate
414
415     def get_subject(self, which="CN"):
416         x = self.cert.get_subject()
417         return getattr(x, which)
418
419     ##
420     # Get the public key of the certificate.
421     #
422     # @param key Keypair object containing the public key
423
424     def set_pubkey(self, key):
425         assert(isinstance(key, Keypair))
426         self.cert.set_pubkey(key.get_openssl_pkey())
427
428     ##
429     # Get the public key of the certificate.
430     # It is returned in the form of a Keypair object.
431
432     def get_pubkey(self):
433         m2x509 = X509.load_cert_string(self.save_to_string())
434         pkey = Keypair()
435         pkey.key = self.cert.get_pubkey()
436         pkey.m2key = m2x509.get_pubkey()
437         return pkey
438
439     def set_intermediate_ca(self, val):
440         self.intermediate = val
441         if val:
442             self.add_extension('basicConstraints', 1, 'CA:TRUE')
443
444
445
446     ##
447     # Add an X509 extension to the certificate. Add_extension can only be called
448     # once for a particular extension name, due to limitations in the underlying
449     # library.
450     #
451     # @param name string containing name of extension
452     # @param value string containing value of the extension
453
454     def add_extension(self, name, critical, value):
455         ext = crypto.X509Extension (name, critical, value)
456         self.cert.add_extensions([ext])
457
458     ##
459     # Get an X509 extension from the certificate
460
461     def get_extension(self, name):
462
463         # pyOpenSSL does not have a way to get extensions
464         m2x509 = X509.load_cert_string(self.save_to_string())
465         value = m2x509.get_ext(name).get_value()
466         
467         return value
468
469     ##
470     # Set_data is a wrapper around add_extension. It stores the parameter str in
471     # the X509 subject_alt_name extension. Set_data can only be called once, due
472     # to limitations in the underlying library.
473
474     def set_data(self, str, field='subjectAltName'):
475         # pyOpenSSL only allows us to add extensions, so if we try to set the
476         # same extension more than once, it will not work
477         if self.data.has_key(field):
478             raise "Cannot set ", field, " more than once"
479         self.data[field] = str
480         self.add_extension(field, 0, str)
481
482     ##
483     # Return the data string that was previously set with set_data
484
485     def get_data(self, field='subjectAltName'):
486         if self.data.has_key(field):
487             return self.data[field]
488
489         try:
490             uri = self.get_extension(field)
491             self.data[field] = uri
492         except LookupError:
493             return None
494
495         return self.data[field]
496
497     ##
498     # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer().
499
500     def sign(self):
501         sfa_logger().debug('certificate.sign')
502         assert self.cert != None
503         assert self.issuerSubject != None
504         assert self.issuerKey != None
505         self.cert.set_issuer(self.issuerSubject)
506         self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest)
507
508     ##
509     # Verify the authenticity of a certificate.
510     # @param pkey is a Keypair object representing a public key. If Pkey
511     #     did not sign the certificate, then an exception will be thrown.
512
513     def verify(self, pkey):
514         # pyOpenSSL does not have a way to verify signatures
515         m2x509 = X509.load_cert_string(self.save_to_string())
516         m2pkey = pkey.get_m2_pkey()
517         # verify it
518         return m2x509.verify(m2pkey)
519
520         # XXX alternatively, if openssl has been patched, do the much simpler:
521         # try:
522         #   self.cert.verify(pkey.get_openssl_key())
523         #   return 1
524         # except:
525         #   return 0
526
527     ##
528     # Return True if pkey is identical to the public key that is contained in the certificate.
529     # @param pkey Keypair object
530
531     def is_pubkey(self, pkey):
532         return self.get_pubkey().is_same(pkey)
533
534     ##
535     # Given a certificate cert, verify that this certificate was signed by the
536     # public key contained in cert. Throw an exception otherwise.
537     #
538     # @param cert certificate object
539
540     def is_signed_by_cert(self, cert):
541         print 'is_signed_by_cert'
542         k = cert.get_pubkey()
543         result = self.verify(k)
544         return result
545
546     ##
547     # Set the parent certficiate.
548     #
549     # @param p certificate object.
550
551     def set_parent(self, p):
552         self.parent = p
553
554     ##
555     # Return the certificate object of the parent of this certificate.
556
557     def get_parent(self):
558         return self.parent
559
560     ##
561     # Verification examines a chain of certificates to ensure that each parent
562     # signs the child, and that some certificate in the chain is signed by a
563     # trusted certificate.
564     #
565     # Verification is a basic recursion: <pre>
566     #     if this_certificate was signed by trusted_certs:
567     #         return
568     #     else
569     #         return verify_chain(parent, trusted_certs)
570     # </pre>
571     #
572     # At each recursion, the parent is tested to ensure that it did sign the
573     # child. If a parent did not sign a child, then an exception is thrown. If
574     # the bottom of the recursion is reached and the certificate does not match
575     # a trusted root, then an exception is thrown.
576     #
577     # @param Trusted_certs is a list of certificates that are trusted.
578     #
579
580     def verify_chain(self, trusted_certs = None):
581         # Verify a chain of certificates. Each certificate must be signed by
582         # the public key contained in it's parent. The chain is recursed
583         # until a certificate is found that is signed by a trusted root.
584
585         # verify expiration time
586         if self.cert.has_expired():
587             sfa_logger().debug("verify_chain: NO our certificate has expired")
588             raise CertExpired(self.get_subject(), "client cert")   
589         
590         # if this cert is signed by a trusted_cert, then we are set
591         for trusted_cert in trusted_certs:
592             if self.is_signed_by_cert(trusted_cert):
593                 # verify expiration of trusted_cert ?
594                 if not trusted_cert.cert.has_expired():
595                     sfa_logger().debug("verify_chain: YES cert %s signed by trusted cert %s"%(
596                             self.get_subject(), trusted_cert.get_subject()))
597                     return trusted_cert
598                 else:
599                     sfa_logger().debug("verify_chain: NO cert %s is signed by trusted_cert %s, but this is expired..."%(
600                             self.get_subject(),trusted_cert.get_subject()))
601                     raise CertExpired(self.get_subject(),"trusted_cert %s"%trusted_cert.get_subject())
602
603         # if there is no parent, then no way to verify the chain
604         if not self.parent:
605             sfa_logger().debug("verify_chain: NO %s has no parent and is not in trusted roots"%self.get_subject())
606             raise CertMissingParent(self.get_subject())
607
608         # if it wasn't signed by the parent...
609         if not self.is_signed_by_cert(self.parent):
610             sfa_logger().debug("verify_chain: NO %s is not signed by parent"%self.get_subject())
611             return CertNotSignedByParent(self.get_subject())
612
613         # if the parent isn't verified...
614         sfa_logger().debug("verify_chain: .. %s, -> verifying parent %s",self.get_subject(),self.parent.get_subject())
615         self.parent.verify_chain(trusted_certs)
616
617         return
618
619     ### more introspection
620     def get_extensions(self):
621         # pyOpenSSL does not have a way to get extensions
622         triples=[]
623         m2x509 = X509.load_cert_string(self.save_to_string())
624         nb_extensions=m2x509.get_ext_count()
625         sfa_logger().debug("X509 had %d extensions"%nb_extensions)
626         for i in range(nb_extensions):
627             ext=m2x509.get_ext_at(i)
628             triples.append( (ext.get_name(), ext.get_value(), ext.get_critical(),) )
629         return triples
630
631     def get_data_names(self):
632         return self.data.keys()
633
634     def get_all_datas (self):
635         triples=self.get_extensions()
636         for name in self.get_data_names(): 
637             triples.append( (name,self.get_data(name),'data',) )
638         return triples
639
640     def dump (self, *args, **kwargs):
641         print self.dump_string(*args, **kwargs)
642
643     def dump_string (self):
644         result = ""
645         result += "Certificate for %s\n"%self.get_subject()
646         result += "Issued by %s\n"%self.get_issuer()
647         for (n,v,c) in self.get_all_datas():
648             if c=='data':
649                 result += "   data: %s=%s\n"%(n,v)
650             else:
651                 result += "    ext: %s=%s (%s)\n"%(n,v,c)
652         return result