25bb99a753194b1fac69f4a775a45c873b4dd267
[sfa.git] / sfa / trust / certificate.py
1 #----------------------------------------------------------------------
2 # Copyright (c) 2008 Board of Trustees, Princeton University
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining
5 # a copy of this software and/or hardware specification (the "Work") to
6 # deal in the Work without restriction, including without limitation the
7 # rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Work, and to permit persons to whom the Work
9 # is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice shall be
12 # included in all copies or substantial portions of the Work.
13 #
14 # THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 
15 # OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 
16 # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
17 # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 
18 # HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 
19 # WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
20 # OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS 
21 # IN THE WORK.
22 #----------------------------------------------------------------------
23
24 ##
25 # SFA uses two crypto libraries: pyOpenSSL and M2Crypto to implement
26 # the necessary crypto functionality. Ideally just one of these libraries
27 # would be used, but unfortunately each of these libraries is independently
28 # lacking. The pyOpenSSL library is missing many necessary functions, and
29 # the M2Crypto library has crashed inside of some of the functions. The
30 # design decision is to use pyOpenSSL whenever possible as it seems more
31 # stable, and only use M2Crypto for those functions that are not possible
32 # in pyOpenSSL.
33 #
34 # This module exports two classes: Keypair and Certificate.
35 ##
36 #
37
38 import os
39 import tempfile
40 import base64
41 import traceback
42 from tempfile import mkstemp
43
44 from OpenSSL import crypto
45 import M2Crypto
46 from M2Crypto import X509
47
48 from sfa.util.sfalogging import sfa_logger
49 from sfa.util.namespace import urn_to_hrn
50 from sfa.util.faults import *
51
52 def convert_public_key(key):
53     keyconvert_path = "/usr/bin/keyconvert.py"
54     if not os.path.isfile(keyconvert_path):
55         raise IOError, "Could not find keyconvert in %s" % keyconvert_path
56
57     # we can only convert rsa keys
58     if "ssh-dss" in key:
59         return None
60
61     (ssh_f, ssh_fn) = tempfile.mkstemp()
62     ssl_fn = tempfile.mktemp()
63     os.write(ssh_f, key)
64     os.close(ssh_f)
65
66     cmd = keyconvert_path + " " + ssh_fn + " " + ssl_fn
67     os.system(cmd)
68
69     # this check leaves the temporary file containing the public key so
70     # that it can be expected to see why it failed.
71     # TODO: for production, cleanup the temporary files
72     if not os.path.exists(ssl_fn):
73         return None
74
75     k = Keypair()
76     try:
77         k.load_pubkey_from_file(ssl_fn)
78     except:
79         sfa_logger().log_exc("convert_public_key caught exception")
80         k = None
81
82     # remove the temporary files
83     os.remove(ssh_fn)
84     os.remove(ssl_fn)
85
86     return k
87
88 ##
89 # Public-private key pairs are implemented by the Keypair class.
90 # A Keypair object may represent both a public and private key pair, or it
91 # may represent only a public key (this usage is consistent with OpenSSL).
92
93 class Keypair:
94     key = None       # public/private keypair
95     m2key = None     # public key (m2crypto format)
96
97     ##
98     # Creates a Keypair object
99     # @param create If create==True, creates a new public/private key and
100     #     stores it in the object
101     # @param string If string!=None, load the keypair from the string (PEM)
102     # @param filename If filename!=None, load the keypair from the file
103
104     def __init__(self, create=False, string=None, filename=None):
105         if create:
106             self.create()
107         if string:
108             self.load_from_string(string)
109         if filename:
110             self.load_from_file(filename)
111
112     ##
113     # Create a RSA public/private key pair and store it inside the keypair object
114
115     def create(self):
116         self.key = crypto.PKey()
117         self.key.generate_key(crypto.TYPE_RSA, 1024)
118
119     ##
120     # Save the private key to a file
121     # @param filename name of file to store the keypair in
122
123     def save_to_file(self, filename):
124         open(filename, 'w').write(self.as_pem())
125         self.filename=filename
126
127     ##
128     # Load the private key from a file. Implicity the private key includes the public key.
129
130     def load_from_file(self, filename):
131         buffer = open(filename, 'r').read()
132         self.load_from_string(buffer)
133         self.filename=filename
134
135     ##
136     # Load the private key from a string. Implicitly the private key includes the public key.
137
138     def load_from_string(self, string):
139         self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
140         self.m2key = M2Crypto.EVP.load_key_string(string)
141
142     ##
143     #  Load the public key from a string. No private key is loaded.
144
145     def load_pubkey_from_file(self, filename):
146         # load the m2 public key
147         m2rsakey = M2Crypto.RSA.load_pub_key(filename)
148         self.m2key = M2Crypto.EVP.PKey()
149         self.m2key.assign_rsa(m2rsakey)
150
151         # create an m2 x509 cert
152         m2name = M2Crypto.X509.X509_Name()
153         m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0)
154         m2x509 = M2Crypto.X509.X509()
155         m2x509.set_pubkey(self.m2key)
156         m2x509.set_serial_number(0)
157         m2x509.set_issuer_name(m2name)
158         m2x509.set_subject_name(m2name)
159         ASN1 = M2Crypto.ASN1.ASN1_UTCTIME()
160         ASN1.set_time(500)
161         m2x509.set_not_before(ASN1)
162         m2x509.set_not_after(ASN1)
163         junk_key = Keypair(create=True)
164         m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")
165
166         # convert the m2 x509 cert to a pyopenssl x509
167         m2pem = m2x509.as_pem()
168         pyx509 = crypto.load_certificate(crypto.FILETYPE_PEM, m2pem)
169
170         # get the pyopenssl pkey from the pyopenssl x509
171         self.key = pyx509.get_pubkey()
172         self.filename=filename
173
174     ##
175     # Load the public key from a string. No private key is loaded.
176
177     def load_pubkey_from_string(self, string):
178         (f, fn) = tempfile.mkstemp()
179         os.write(f, string)
180         os.close(f)
181         self.load_pubkey_from_file(fn)
182         os.remove(fn)
183
184     ##
185     # Return the private key in PEM format.
186
187     def as_pem(self):
188         return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key)
189
190     ##
191     # Return an M2Crypto key object
192
193     def get_m2_pkey(self):
194         if not self.m2key:
195             self.m2key = M2Crypto.EVP.load_key_string(self.as_pem())
196         return self.m2key
197
198     ##
199     # Returns a string containing the public key represented by this object.
200
201     def get_pubkey_string(self):
202         m2pkey = self.get_m2_pkey()
203         return base64.b64encode(m2pkey.as_der())
204
205     ##
206     # Return an OpenSSL pkey object
207
208     def get_openssl_pkey(self):
209         return self.key
210
211     ##
212     # Given another Keypair object, return TRUE if the two keys are the same.
213
214     def is_same(self, pkey):
215         return self.as_pem() == pkey.as_pem()
216
217     def sign_string(self, data):
218         k = self.get_m2_pkey()
219         k.sign_init()
220         k.sign_update(data)
221         return base64.b64encode(k.sign_final())
222
223     def verify_string(self, data, sig):
224         k = self.get_m2_pkey()
225         k.verify_init()
226         k.verify_update(data)
227         return M2Crypto.m2.verify_final(k.ctx, base64.b64decode(sig), k.pkey)
228
229     def compute_hash(self, value):
230         return self.sign_string(str(value))
231
232     # only informative
233     def get_filename(self):
234         return getattr(self,'filename',None)
235
236     def dump (self, *args, **kwargs):
237         print self.dump_string(*args, **kwargs)
238
239     def dump_string (self):
240         result=""
241         result += "KEYPAIR: pubkey=%40s..."%self.get_pubkey_string()
242         filename=self.get_filename()
243         if filename: result += "Filename %s\n"%filename
244         return result
245     
246 ##
247 # The certificate class implements a general purpose X509 certificate, making
248 # use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds
249 # several addition features, such as the ability to maintain a chain of
250 # parent certificates, and storage of application-specific data.
251 #
252 # Certificates include the ability to maintain a chain of parents. Each
253 # certificate includes a pointer to it's parent certificate. When loaded
254 # from a file or a string, the parent chain will be automatically loaded.
255 # When saving a certificate to a file or a string, the caller can choose
256 # whether to save the parent certificates as well.
257
258 class Certificate:
259     digest = "md5"
260
261     cert = None
262     issuerKey = None
263     issuerSubject = None
264     parent = None
265
266     separator="-----parent-----"
267
268     ##
269     # Create a certificate object.
270     #
271     # @param create If create==True, then also create a blank X509 certificate.
272     # @param subject If subject!=None, then create a blank certificate and set
273     #     it's subject name.
274     # @param string If string!=None, load the certficate from the string.
275     # @param filename If filename!=None, load the certficiate from the file.
276
277     def __init__(self, create=False, subject=None, string=None, filename=None, intermediate=None):
278         self.data = {}
279         if create or subject:
280             self.create()
281         if subject:
282             self.set_subject(subject)
283         if string:
284             self.load_from_string(string)
285         if filename:
286             self.load_from_file(filename)
287
288         if intermediate:
289             self.set_intermediate_ca(intermediate)
290
291     ##
292     # Create a blank X509 certificate and store it in this object.
293
294     def create(self):
295         self.cert = crypto.X509()
296         self.cert.set_serial_number(3)
297         self.cert.gmtime_adj_notBefore(0)
298         self.cert.gmtime_adj_notAfter(60*60*24*365*5) # five years
299
300     ##
301     # Given a pyOpenSSL X509 object, store that object inside of this
302     # certificate object.
303
304     def load_from_pyopenssl_x509(self, x509):
305         self.cert = x509
306
307     ##
308     # Load the certificate from a string
309
310     def load_from_string(self, string):
311         # if it is a chain of multiple certs, then split off the first one and
312         # load it (support for the ---parent--- tag as well as normal chained certs)
313
314         string = string.strip()
315         
316         # If it's not in proper PEM format, wrap it
317         if string.count('-----BEGIN CERTIFICATE') == 0:
318             string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string
319
320         # If there is a PEM cert in there, but there is some other text first
321         # such as the text of the certificate, skip the text
322         beg = string.find('-----BEGIN CERTIFICATE')
323         if beg > 0:
324             # skipping over non cert beginning                                                                                                              
325             string = string[beg:]
326
327         parts = []
328
329         if string.count('-----BEGIN CERTIFICATE-----') > 1 and \
330                string.count(Certificate.separator) == 0:
331             parts = string.split('-----END CERTIFICATE-----',1)
332             parts[0] += '-----END CERTIFICATE-----'
333         else:
334             parts = string.split(Certificate.separator, 1)
335
336         self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])
337
338         # if there are more certs, then create a parent and let the parent load
339         # itself from the remainder of the string
340         if len(parts) > 1 and parts[1] != '':
341             self.parent = self.__class__()
342             self.parent.load_from_string(parts[1])
343
344     ##
345     # Load the certificate from a file
346
347     def load_from_file(self, filename):
348         file = open(filename)
349         string = file.read()
350         self.load_from_string(string)
351         self.filename=filename
352
353     ##
354     # Save the certificate to a string.
355     #
356     # @param save_parents If save_parents==True, then also save the parent certificates.
357
358     def save_to_string(self, save_parents=True):
359         string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)
360         if save_parents and self.parent:
361             string = string + self.parent.save_to_string(save_parents)
362         return string
363
364     ##
365     # Save the certificate to a file.
366     # @param save_parents If save_parents==True, then also save the parent certificates.
367
368     def save_to_file(self, filename, save_parents=True, filep=None):
369         string = self.save_to_string(save_parents=save_parents)
370         if filep:
371             f = filep
372         else:
373             f = open(filename, 'w')
374         f.write(string)
375         f.close()
376         self.filename=filename
377
378     ##
379     # Save the certificate to a random file in /tmp/
380     # @param save_parents If save_parents==True, then also save the parent certificates.
381     def save_to_random_tmp_file(self, save_parents=True):
382         fp, filename = mkstemp(suffix='cert', text=True)
383         fp = os.fdopen(fp, "w")
384         self.save_to_file(filename, save_parents=True, filep=fp)
385         return filename
386
387     ##
388     # Sets the issuer private key and name
389     # @param key Keypair object containing the private key of the issuer
390     # @param subject String containing the name of the issuer
391     # @param cert (optional) Certificate object containing the name of the issuer
392
393     def set_issuer(self, key, subject=None, cert=None):
394         self.issuerKey = key
395         if subject:
396             # it's a mistake to use subject and cert params at the same time
397             assert(not cert)
398             if isinstance(subject, dict) or isinstance(subject, str):
399                 req = crypto.X509Req()
400                 reqSubject = req.get_subject()
401                 if (isinstance(subject, dict)):
402                     for key in reqSubject.keys():
403                         setattr(reqSubject, key, subject[key])
404                 else:
405                     setattr(reqSubject, "CN", subject)
406                 subject = reqSubject
407                 # subject is not valid once req is out of scope, so save req
408                 self.issuerReq = req
409         if cert:
410             # if a cert was supplied, then get the subject from the cert
411             subject = cert.cert.get_subject()
412         assert(subject)
413         self.issuerSubject = subject
414
415     ##
416     # Get the issuer name
417
418     def get_issuer(self, which="CN"):
419         x = self.cert.get_issuer()
420         return getattr(x, which)
421
422     ##
423     # Set the subject name of the certificate
424
425     def set_subject(self, name):
426         req = crypto.X509Req()
427         subj = req.get_subject()
428         if (isinstance(name, dict)):
429             for key in name.keys():
430                 setattr(subj, key, name[key])
431         else:
432             setattr(subj, "CN", name)
433         self.cert.set_subject(subj)
434     ##
435     # Get the subject name of the certificate
436
437     def get_subject(self, which="CN"):
438         x = self.cert.get_subject()
439         return getattr(x, which)
440
441     ##
442     # Get the public key of the certificate.
443     #
444     # @param key Keypair object containing the public key
445
446     def set_pubkey(self, key):
447         assert(isinstance(key, Keypair))
448         self.cert.set_pubkey(key.get_openssl_pkey())
449
450     ##
451     # Get the public key of the certificate.
452     # It is returned in the form of a Keypair object.
453
454     def get_pubkey(self):
455         m2x509 = X509.load_cert_string(self.save_to_string())
456         pkey = Keypair()
457         pkey.key = self.cert.get_pubkey()
458         pkey.m2key = m2x509.get_pubkey()
459         return pkey
460
461     def set_intermediate_ca(self, val):
462         self.intermediate = val
463         if val:
464             self.add_extension('basicConstraints', 1, 'CA:TRUE')
465
466
467
468     ##
469     # Add an X509 extension to the certificate. Add_extension can only be called
470     # once for a particular extension name, due to limitations in the underlying
471     # library.
472     #
473     # @param name string containing name of extension
474     # @param value string containing value of the extension
475
476     def add_extension(self, name, critical, value):
477         ext = crypto.X509Extension (name, critical, value)
478         self.cert.add_extensions([ext])
479
480     ##
481     # Get an X509 extension from the certificate
482
483     def get_extension(self, name):
484
485         # pyOpenSSL does not have a way to get extensions
486         m2x509 = X509.load_cert_string(self.save_to_string())
487         value = m2x509.get_ext(name).get_value()
488         
489         return value
490
491     ##
492     # Set_data is a wrapper around add_extension. It stores the parameter str in
493     # the X509 subject_alt_name extension. Set_data can only be called once, due
494     # to limitations in the underlying library.
495
496     def set_data(self, str, field='subjectAltName'):
497         # pyOpenSSL only allows us to add extensions, so if we try to set the
498         # same extension more than once, it will not work
499         if self.data.has_key(field):
500             raise "Cannot set ", field, " more than once"
501         self.data[field] = str
502         self.add_extension(field, 0, str)
503
504     ##
505     # Return the data string that was previously set with set_data
506
507     def get_data(self, field='subjectAltName'):
508         if self.data.has_key(field):
509             return self.data[field]
510
511         try:
512             uri = self.get_extension(field)
513             self.data[field] = uri
514         except LookupError:
515             return None
516
517         return self.data[field]
518
519     ##
520     # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer().
521
522     def sign(self):
523         sfa_logger().debug('certificate.sign')
524         assert self.cert != None
525         assert self.issuerSubject != None
526         assert self.issuerKey != None
527         self.cert.set_issuer(self.issuerSubject)
528         self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest)
529
530     ##
531     # Verify the authenticity of a certificate.
532     # @param pkey is a Keypair object representing a public key. If Pkey
533     #     did not sign the certificate, then an exception will be thrown.
534
535     def verify(self, pkey):
536         # pyOpenSSL does not have a way to verify signatures
537         m2x509 = X509.load_cert_string(self.save_to_string())
538         m2pkey = pkey.get_m2_pkey()
539         # verify it
540         return m2x509.verify(m2pkey)
541
542         # XXX alternatively, if openssl has been patched, do the much simpler:
543         # try:
544         #   self.cert.verify(pkey.get_openssl_key())
545         #   return 1
546         # except:
547         #   return 0
548
549     ##
550     # Return True if pkey is identical to the public key that is contained in the certificate.
551     # @param pkey Keypair object
552
553     def is_pubkey(self, pkey):
554         return self.get_pubkey().is_same(pkey)
555
556     ##
557     # Given a certificate cert, verify that this certificate was signed by the
558     # public key contained in cert. Throw an exception otherwise.
559     #
560     # @param cert certificate object
561
562     def is_signed_by_cert(self, cert):
563         print 'is_signed_by_cert'
564         k = cert.get_pubkey()
565         result = self.verify(k)
566         return result
567
568     ##
569     # Set the parent certficiate.
570     #
571     # @param p certificate object.
572
573     def set_parent(self, p):
574         self.parent = p
575
576     ##
577     # Return the certificate object of the parent of this certificate.
578
579     def get_parent(self):
580         return self.parent
581
582     ##
583     # Verification examines a chain of certificates to ensure that each parent
584     # signs the child, and that some certificate in the chain is signed by a
585     # trusted certificate.
586     #
587     # Verification is a basic recursion: <pre>
588     #     if this_certificate was signed by trusted_certs:
589     #         return
590     #     else
591     #         return verify_chain(parent, trusted_certs)
592     # </pre>
593     #
594     # At each recursion, the parent is tested to ensure that it did sign the
595     # child. If a parent did not sign a child, then an exception is thrown. If
596     # the bottom of the recursion is reached and the certificate does not match
597     # a trusted root, then an exception is thrown.
598     #
599     # @param Trusted_certs is a list of certificates that are trusted.
600     #
601
602     def verify_chain(self, trusted_certs = None):
603         # Verify a chain of certificates. Each certificate must be signed by
604         # the public key contained in it's parent. The chain is recursed
605         # until a certificate is found that is signed by a trusted root.
606
607         # verify expiration time
608         if self.cert.has_expired():
609             sfa_logger().debug("verify_chain: NO our certificate has expired")
610             raise CertExpired(self.get_subject(), "client cert")   
611         
612         # if this cert is signed by a trusted_cert, then we are set
613         for trusted_cert in trusted_certs:
614             if self.is_signed_by_cert(trusted_cert):
615                 # verify expiration of trusted_cert ?
616                 if not trusted_cert.cert.has_expired():
617                     sfa_logger().debug("verify_chain: YES cert %s signed by trusted cert %s"%(
618                             self.get_subject(), trusted_cert.get_subject()))
619                     return trusted_cert
620                 else:
621                     sfa_logger().debug("verify_chain: NO cert %s is signed by trusted_cert %s, but this is expired..."%(
622                             self.get_subject(),trusted_cert.get_subject()))
623                     raise CertExpired(self.get_subject(),"trusted_cert %s"%trusted_cert.get_subject())
624
625         # if there is no parent, then no way to verify the chain
626         if not self.parent:
627             sfa_logger().debug("verify_chain: NO %s has no parent and is not in trusted roots"%self.get_subject())
628             raise CertMissingParent(self.get_subject())
629
630         # if it wasn't signed by the parent...
631         if not self.is_signed_by_cert(self.parent):
632             sfa_logger().debug("verify_chain: NO %s is not signed by parent"%self.get_subject())
633             return CertNotSignedByParent(self.get_subject())
634
635         # if the parent isn't verified...
636         sfa_logger().debug("verify_chain: .. %s, -> verifying parent %s",self.get_subject(),self.parent.get_subject())
637         self.parent.verify_chain(trusted_certs)
638
639         return
640
641     ### more introspection
642     def get_extensions(self):
643         # pyOpenSSL does not have a way to get extensions
644         triples=[]
645         m2x509 = X509.load_cert_string(self.save_to_string())
646         nb_extensions=m2x509.get_ext_count()
647         sfa_logger().debug("X509 had %d extensions"%nb_extensions)
648         for i in range(nb_extensions):
649             ext=m2x509.get_ext_at(i)
650             triples.append( (ext.get_name(), ext.get_value(), ext.get_critical(),) )
651         return triples
652
653     def get_data_names(self):
654         return self.data.keys()
655
656     def get_all_datas (self):
657         triples=self.get_extensions()
658         for name in self.get_data_names(): 
659             triples.append( (name,self.get_data(name),'data',) )
660         return triples
661
662     # only informative
663     def get_filename(self):
664         return getattr(self,'filename',None)
665
666     def dump (self, *args, **kwargs):
667         print self.dump_string(*args, **kwargs)
668
669     def dump_string (self,show_extensions=False):
670         result = ""
671         result += "CERTIFICATE for %s\n"%self.get_subject()
672         result += "Issued by %s\n"%self.get_issuer()
673         filename=self.get_filename()
674         if filename: result += "Filename %s\n"%filename
675         if show_extensions:
676             all_datas=self.get_all_datas()
677             result += " has %d extensions/data attached"%len(all_datas)
678             for (n,v,c) in all_datas:
679                 if c=='data':
680                     result += "   data: %s=%s\n"%(n,v)
681                 else:
682                     result += "    ext: %s (crit=%s)=<<<%s>>>\n"%(n,c,v)
683         return result