Merge branch 'master' of ssh://git.planet-lab.org/git/sfa
[sfa.git] / sfa / trust / certificate.py
1 #----------------------------------------------------------------------
2 # Copyright (c) 2008 Board of Trustees, Princeton University
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining
5 # a copy of this software and/or hardware specification (the "Work") to
6 # deal in the Work without restriction, including without limitation the
7 # rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Work, and to permit persons to whom the Work
9 # is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice shall be
12 # included in all copies or substantial portions of the Work.
13 #
14 # THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 
15 # OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 
16 # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
17 # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 
18 # HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 
19 # WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
20 # OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS 
21 # IN THE WORK.
22 #----------------------------------------------------------------------
23
24 ##
25 # SFA uses two crypto libraries: pyOpenSSL and M2Crypto to implement
26 # the necessary crypto functionality. Ideally just one of these libraries
27 # would be used, but unfortunately each of these libraries is independently
28 # lacking. The pyOpenSSL library is missing many necessary functions, and
29 # the M2Crypto library has crashed inside of some of the functions. The
30 # design decision is to use pyOpenSSL whenever possible as it seems more
31 # stable, and only use M2Crypto for those functions that are not possible
32 # in pyOpenSSL.
33 #
34 # This module exports two classes: Keypair and Certificate.
35 ##
36 #
37
38 import os
39 import tempfile
40 import base64
41 import traceback
42 from tempfile import mkstemp
43
44 from OpenSSL import crypto
45 import M2Crypto
46 from M2Crypto import X509
47
48 from sfa.util.sfalogging import sfa_logger
49 from sfa.util.xrn import urn_to_hrn
50 from sfa.util.faults import *
51
52 def convert_public_key(key):
53     keyconvert_path = "/usr/bin/keyconvert.py"
54     if not os.path.isfile(keyconvert_path):
55         raise IOError, "Could not find keyconvert in %s" % keyconvert_path
56
57     # we can only convert rsa keys
58     if "ssh-dss" in key:
59         return None
60
61     (ssh_f, ssh_fn) = tempfile.mkstemp()
62     ssl_fn = tempfile.mktemp()
63     os.write(ssh_f, key)
64     os.close(ssh_f)
65
66     cmd = keyconvert_path + " " + ssh_fn + " " + ssl_fn
67     os.system(cmd)
68
69     # this check leaves the temporary file containing the public key so
70     # that it can be expected to see why it failed.
71     # TODO: for production, cleanup the temporary files
72     if not os.path.exists(ssl_fn):
73         return None
74
75     k = Keypair()
76     try:
77         k.load_pubkey_from_file(ssl_fn)
78     except:
79         sfa_logger().log_exc("convert_public_key caught exception")
80         k = None
81
82     # remove the temporary files
83     os.remove(ssh_fn)
84     os.remove(ssl_fn)
85
86     return k
87
88 ##
89 # Public-private key pairs are implemented by the Keypair class.
90 # A Keypair object may represent both a public and private key pair, or it
91 # may represent only a public key (this usage is consistent with OpenSSL).
92
93 class Keypair:
94     key = None       # public/private keypair
95     m2key = None     # public key (m2crypto format)
96
97     ##
98     # Creates a Keypair object
99     # @param create If create==True, creates a new public/private key and
100     #     stores it in the object
101     # @param string If string!=None, load the keypair from the string (PEM)
102     # @param filename If filename!=None, load the keypair from the file
103
104     def __init__(self, create=False, string=None, filename=None):
105         if create:
106             self.create()
107         if string:
108             self.load_from_string(string)
109         if filename:
110             self.load_from_file(filename)
111
112     ##
113     # Create a RSA public/private key pair and store it inside the keypair object
114
115     def create(self):
116         self.key = crypto.PKey()
117         self.key.generate_key(crypto.TYPE_RSA, 1024)
118
119     ##
120     # Save the private key to a file
121     # @param filename name of file to store the keypair in
122
123     def save_to_file(self, filename):
124         open(filename, 'w').write(self.as_pem())
125         self.filename=filename
126
127     ##
128     # Load the private key from a file. Implicity the private key includes the public key.
129
130     def load_from_file(self, filename):
131         buffer = open(filename, 'r').read()
132         self.load_from_string(buffer)
133         self.filename=filename
134
135     ##
136     # Load the private key from a string. Implicitly the private key includes the public key.
137
138     def load_from_string(self, string):
139         self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
140         self.m2key = M2Crypto.EVP.load_key_string(string)
141
142     ##
143     #  Load the public key from a string. No private key is loaded.
144
145     def load_pubkey_from_file(self, filename):
146         # load the m2 public key
147         m2rsakey = M2Crypto.RSA.load_pub_key(filename)
148         self.m2key = M2Crypto.EVP.PKey()
149         self.m2key.assign_rsa(m2rsakey)
150
151         # create an m2 x509 cert
152         m2name = M2Crypto.X509.X509_Name()
153         m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0)
154         m2x509 = M2Crypto.X509.X509()
155         m2x509.set_pubkey(self.m2key)
156         m2x509.set_serial_number(0)
157         m2x509.set_issuer_name(m2name)
158         m2x509.set_subject_name(m2name)
159         ASN1 = M2Crypto.ASN1.ASN1_UTCTIME()
160         ASN1.set_time(500)
161         m2x509.set_not_before(ASN1)
162         m2x509.set_not_after(ASN1)
163         # x509v3 so it can have extensions
164         # prob not necc since this cert itself is junk but still...
165         m2x509.set_version(2)
166         junk_key = Keypair(create=True)
167         m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")
168
169         # convert the m2 x509 cert to a pyopenssl x509
170         m2pem = m2x509.as_pem()
171         pyx509 = crypto.load_certificate(crypto.FILETYPE_PEM, m2pem)
172
173         # get the pyopenssl pkey from the pyopenssl x509
174         self.key = pyx509.get_pubkey()
175         self.filename=filename
176
177     ##
178     # Load the public key from a string. No private key is loaded.
179
180     def load_pubkey_from_string(self, string):
181         (f, fn) = tempfile.mkstemp()
182         os.write(f, string)
183         os.close(f)
184         self.load_pubkey_from_file(fn)
185         os.remove(fn)
186
187     ##
188     # Return the private key in PEM format.
189
190     def as_pem(self):
191         return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key)
192
193     ##
194     # Return an M2Crypto key object
195
196     def get_m2_pkey(self):
197         if not self.m2key:
198             self.m2key = M2Crypto.EVP.load_key_string(self.as_pem())
199         return self.m2key
200
201     ##
202     # Returns a string containing the public key represented by this object.
203
204     def get_pubkey_string(self):
205         m2pkey = self.get_m2_pkey()
206         return base64.b64encode(m2pkey.as_der())
207
208     ##
209     # Return an OpenSSL pkey object
210
211     def get_openssl_pkey(self):
212         return self.key
213
214     ##
215     # Given another Keypair object, return TRUE if the two keys are the same.
216
217     def is_same(self, pkey):
218         return self.as_pem() == pkey.as_pem()
219
220     def sign_string(self, data):
221         k = self.get_m2_pkey()
222         k.sign_init()
223         k.sign_update(data)
224         return base64.b64encode(k.sign_final())
225
226     def verify_string(self, data, sig):
227         k = self.get_m2_pkey()
228         k.verify_init()
229         k.verify_update(data)
230         return M2Crypto.m2.verify_final(k.ctx, base64.b64decode(sig), k.pkey)
231
232     def compute_hash(self, value):
233         return self.sign_string(str(value))
234
235     # only informative
236     def get_filename(self):
237         return getattr(self,'filename',None)
238
239     def dump (self, *args, **kwargs):
240         print self.dump_string(*args, **kwargs)
241
242     def dump_string (self):
243         result=""
244         result += "KEYPAIR: pubkey=%40s..."%self.get_pubkey_string()
245         filename=self.get_filename()
246         if filename: result += "Filename %s\n"%filename
247         return result
248     
249 ##
250 # The certificate class implements a general purpose X509 certificate, making
251 # use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds
252 # several addition features, such as the ability to maintain a chain of
253 # parent certificates, and storage of application-specific data.
254 #
255 # Certificates include the ability to maintain a chain of parents. Each
256 # certificate includes a pointer to it's parent certificate. When loaded
257 # from a file or a string, the parent chain will be automatically loaded.
258 # When saving a certificate to a file or a string, the caller can choose
259 # whether to save the parent certificates as well.
260
261 class Certificate:
262     digest = "md5"
263
264     cert = None
265     issuerKey = None
266     issuerSubject = None
267     parent = None
268
269     separator="-----parent-----"
270
271     ##
272     # Create a certificate object.
273     #
274     # @param create If create==True, then also create a blank X509 certificate.
275     # @param subject If subject!=None, then create a blank certificate and set
276     #     it's subject name.
277     # @param string If string!=None, load the certficate from the string.
278     # @param filename If filename!=None, load the certficiate from the file.
279
280     def __init__(self, create=False, subject=None, string=None, filename=None, intermediate=None):
281         self.data = {}
282         if create or subject:
283             self.create()
284         if subject:
285             self.set_subject(subject)
286         if string:
287             self.load_from_string(string)
288         if filename:
289             self.load_from_file(filename)
290
291         if intermediate:
292             self.set_intermediate_ca(intermediate)
293
294     ##
295     # Create a blank X509 certificate and store it in this object.
296
297     def create(self):
298         self.cert = crypto.X509()
299         self.cert.set_serial_number(3)
300         self.cert.gmtime_adj_notBefore(0)
301         self.cert.gmtime_adj_notAfter(60*60*24*365*5) # five years
302         self.cert.set_version(2) # x509v3 so it can have extensions        
303
304
305     ##
306     # Given a pyOpenSSL X509 object, store that object inside of this
307     # certificate object.
308
309     def load_from_pyopenssl_x509(self, x509):
310         self.cert = x509
311
312     ##
313     # Load the certificate from a string
314
315     def load_from_string(self, string):
316         # if it is a chain of multiple certs, then split off the first one and
317         # load it (support for the ---parent--- tag as well as normal chained certs)
318
319         string = string.strip()
320         
321         # If it's not in proper PEM format, wrap it
322         if string.count('-----BEGIN CERTIFICATE') == 0:
323             string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string
324
325         # If there is a PEM cert in there, but there is some other text first
326         # such as the text of the certificate, skip the text
327         beg = string.find('-----BEGIN CERTIFICATE')
328         if beg > 0:
329             # skipping over non cert beginning                                                                                                              
330             string = string[beg:]
331
332         parts = []
333
334         if string.count('-----BEGIN CERTIFICATE-----') > 1 and \
335                string.count(Certificate.separator) == 0:
336             parts = string.split('-----END CERTIFICATE-----',1)
337             parts[0] += '-----END CERTIFICATE-----'
338         else:
339             parts = string.split(Certificate.separator, 1)
340
341         self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])
342
343         # if there are more certs, then create a parent and let the parent load
344         # itself from the remainder of the string
345         if len(parts) > 1 and parts[1] != '':
346             self.parent = self.__class__()
347             self.parent.load_from_string(parts[1])
348
349     ##
350     # Load the certificate from a file
351
352     def load_from_file(self, filename):
353         file = open(filename)
354         string = file.read()
355         self.load_from_string(string)
356         self.filename=filename
357
358     ##
359     # Save the certificate to a string.
360     #
361     # @param save_parents If save_parents==True, then also save the parent certificates.
362
363     def save_to_string(self, save_parents=True):
364         string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)
365         if save_parents and self.parent:
366             string = string + self.parent.save_to_string(save_parents)
367         return string
368
369     ##
370     # Save the certificate to a file.
371     # @param save_parents If save_parents==True, then also save the parent certificates.
372
373     def save_to_file(self, filename, save_parents=True, filep=None):
374         string = self.save_to_string(save_parents=save_parents)
375         if filep:
376             f = filep
377         else:
378             f = open(filename, 'w')
379         f.write(string)
380         f.close()
381         self.filename=filename
382
383     ##
384     # Save the certificate to a random file in /tmp/
385     # @param save_parents If save_parents==True, then also save the parent certificates.
386     def save_to_random_tmp_file(self, save_parents=True):
387         fp, filename = mkstemp(suffix='cert', text=True)
388         fp = os.fdopen(fp, "w")
389         self.save_to_file(filename, save_parents=True, filep=fp)
390         return filename
391
392     ##
393     # Sets the issuer private key and name
394     # @param key Keypair object containing the private key of the issuer
395     # @param subject String containing the name of the issuer
396     # @param cert (optional) Certificate object containing the name of the issuer
397
398     def set_issuer(self, key, subject=None, cert=None):
399         self.issuerKey = key
400         if subject:
401             # it's a mistake to use subject and cert params at the same time
402             assert(not cert)
403             if isinstance(subject, dict) or isinstance(subject, str):
404                 req = crypto.X509Req()
405                 reqSubject = req.get_subject()
406                 if (isinstance(subject, dict)):
407                     for key in reqSubject.keys():
408                         setattr(reqSubject, key, subject[key])
409                 else:
410                     setattr(reqSubject, "CN", subject)
411                 subject = reqSubject
412                 # subject is not valid once req is out of scope, so save req
413                 self.issuerReq = req
414         if cert:
415             # if a cert was supplied, then get the subject from the cert
416             subject = cert.cert.get_subject()
417         assert(subject)
418         self.issuerSubject = subject
419
420     ##
421     # Get the issuer name
422
423     def get_issuer(self, which="CN"):
424         x = self.cert.get_issuer()
425         return getattr(x, which)
426
427     ##
428     # Set the subject name of the certificate
429
430     def set_subject(self, name):
431         req = crypto.X509Req()
432         subj = req.get_subject()
433         if (isinstance(name, dict)):
434             for key in name.keys():
435                 setattr(subj, key, name[key])
436         else:
437             setattr(subj, "CN", name)
438         self.cert.set_subject(subj)
439     ##
440     # Get the subject name of the certificate
441
442     def get_subject(self, which="CN"):
443         x = self.cert.get_subject()
444         return getattr(x, which)
445
446     ##
447     # Get the public key of the certificate.
448     #
449     # @param key Keypair object containing the public key
450
451     def set_pubkey(self, key):
452         assert(isinstance(key, Keypair))
453         self.cert.set_pubkey(key.get_openssl_pkey())
454
455     ##
456     # Get the public key of the certificate.
457     # It is returned in the form of a Keypair object.
458
459     def get_pubkey(self):
460         m2x509 = X509.load_cert_string(self.save_to_string())
461         pkey = Keypair()
462         pkey.key = self.cert.get_pubkey()
463         pkey.m2key = m2x509.get_pubkey()
464         return pkey
465
466     def set_intermediate_ca(self, val):
467         self.intermediate = val
468         if val:
469             self.add_extension('basicConstraints', 1, 'CA:TRUE')
470
471
472
473     ##
474     # Add an X509 extension to the certificate. Add_extension can only be called
475     # once for a particular extension name, due to limitations in the underlying
476     # library.
477     #
478     # @param name string containing name of extension
479     # @param value string containing value of the extension
480
481     def add_extension(self, name, critical, value):
482         ext = crypto.X509Extension (name, critical, value)
483         self.cert.add_extensions([ext])
484
485     ##
486     # Get an X509 extension from the certificate
487
488     def get_extension(self, name):
489
490         # pyOpenSSL does not have a way to get extensions
491         m2x509 = X509.load_cert_string(self.save_to_string())
492         value = m2x509.get_ext(name).get_value()
493         
494         return value
495
496     ##
497     # Set_data is a wrapper around add_extension. It stores the parameter str in
498     # the X509 subject_alt_name extension. Set_data can only be called once, due
499     # to limitations in the underlying library.
500
501     def set_data(self, str, field='subjectAltName'):
502         # pyOpenSSL only allows us to add extensions, so if we try to set the
503         # same extension more than once, it will not work
504         if self.data.has_key(field):
505             raise "Cannot set ", field, " more than once"
506         self.data[field] = str
507         self.add_extension(field, 0, str)
508
509     ##
510     # Return the data string that was previously set with set_data
511
512     def get_data(self, field='subjectAltName'):
513         if self.data.has_key(field):
514             return self.data[field]
515
516         try:
517             uri = self.get_extension(field)
518             self.data[field] = uri
519         except LookupError:
520             return None
521
522         return self.data[field]
523
524     ##
525     # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer().
526
527     def sign(self):
528         sfa_logger().debug('certificate.sign')
529         assert self.cert != None
530         assert self.issuerSubject != None
531         assert self.issuerKey != None
532         self.cert.set_issuer(self.issuerSubject)
533         self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest)
534
535     ##
536     # Verify the authenticity of a certificate.
537     # @param pkey is a Keypair object representing a public key. If Pkey
538     #     did not sign the certificate, then an exception will be thrown.
539
540     def verify(self, pkey):
541         # pyOpenSSL does not have a way to verify signatures
542         m2x509 = X509.load_cert_string(self.save_to_string())
543         m2pkey = pkey.get_m2_pkey()
544         # verify it
545         return m2x509.verify(m2pkey)
546
547         # XXX alternatively, if openssl has been patched, do the much simpler:
548         # try:
549         #   self.cert.verify(pkey.get_openssl_key())
550         #   return 1
551         # except:
552         #   return 0
553
554     ##
555     # Return True if pkey is identical to the public key that is contained in the certificate.
556     # @param pkey Keypair object
557
558     def is_pubkey(self, pkey):
559         return self.get_pubkey().is_same(pkey)
560
561     ##
562     # Given a certificate cert, verify that this certificate was signed by the
563     # public key contained in cert. Throw an exception otherwise.
564     #
565     # @param cert certificate object
566
567     def is_signed_by_cert(self, cert):
568         k = cert.get_pubkey()
569         result = self.verify(k)
570         return result
571
572     ##
573     # Set the parent certficiate.
574     #
575     # @param p certificate object.
576
577     def set_parent(self, p):
578         self.parent = p
579
580     ##
581     # Return the certificate object of the parent of this certificate.
582
583     def get_parent(self):
584         return self.parent
585
586     ##
587     # Verification examines a chain of certificates to ensure that each parent
588     # signs the child, and that some certificate in the chain is signed by a
589     # trusted certificate.
590     #
591     # Verification is a basic recursion: <pre>
592     #     if this_certificate was signed by trusted_certs:
593     #         return
594     #     else
595     #         return verify_chain(parent, trusted_certs)
596     # </pre>
597     #
598     # At each recursion, the parent is tested to ensure that it did sign the
599     # child. If a parent did not sign a child, then an exception is thrown. If
600     # the bottom of the recursion is reached and the certificate does not match
601     # a trusted root, then an exception is thrown.
602     #
603     # @param Trusted_certs is a list of certificates that are trusted.
604     #
605
606     def verify_chain(self, trusted_certs = None):
607         # Verify a chain of certificates. Each certificate must be signed by
608         # the public key contained in it's parent. The chain is recursed
609         # until a certificate is found that is signed by a trusted root.
610
611         # verify expiration time
612         if self.cert.has_expired():
613             sfa_logger().debug("verify_chain: NO our certificate has expired")
614             raise CertExpired(self.get_subject(), "client cert")   
615         
616         # if this cert is signed by a trusted_cert, then we are set
617         for trusted_cert in trusted_certs:
618             if self.is_signed_by_cert(trusted_cert):
619                 # verify expiration of trusted_cert ?
620                 if not trusted_cert.cert.has_expired():
621                     sfa_logger().debug("verify_chain: YES cert %s signed by trusted cert %s"%(
622                             self.get_subject(), trusted_cert.get_subject()))
623                     return trusted_cert
624                 else:
625                     sfa_logger().debug("verify_chain: NO cert %s is signed by trusted_cert %s, but this is expired..."%(
626                             self.get_subject(),trusted_cert.get_subject()))
627                     raise CertExpired(self.get_subject(),"trusted_cert %s"%trusted_cert.get_subject())
628
629         # if there is no parent, then no way to verify the chain
630         if not self.parent:
631             sfa_logger().debug("verify_chain: NO %s has no parent and is not in trusted roots"%self.get_subject())
632             raise CertMissingParent(self.get_subject())
633
634         # if it wasn't signed by the parent...
635         if not self.is_signed_by_cert(self.parent):
636             sfa_logger().debug("verify_chain: NO %s is not signed by parent"%self.get_subject())
637             return CertNotSignedByParent(self.get_subject())
638
639         # if the parent isn't verified...
640         sfa_logger().debug("verify_chain: .. %s, -> verifying parent %s"%(self.get_subject(),self.parent.get_subject()))
641         self.parent.verify_chain(trusted_certs)
642
643         return
644
645     ### more introspection
646     def get_extensions(self):
647         # pyOpenSSL does not have a way to get extensions
648         triples=[]
649         m2x509 = X509.load_cert_string(self.save_to_string())
650         nb_extensions=m2x509.get_ext_count()
651         sfa_logger().debug("X509 had %d extensions"%nb_extensions)
652         for i in range(nb_extensions):
653             ext=m2x509.get_ext_at(i)
654             triples.append( (ext.get_name(), ext.get_value(), ext.get_critical(),) )
655         return triples
656
657     def get_data_names(self):
658         return self.data.keys()
659
660     def get_all_datas (self):
661         triples=self.get_extensions()
662         for name in self.get_data_names(): 
663             triples.append( (name,self.get_data(name),'data',) )
664         return triples
665
666     # only informative
667     def get_filename(self):
668         return getattr(self,'filename',None)
669
670     def dump (self, *args, **kwargs):
671         print self.dump_string(*args, **kwargs)
672
673     def dump_string (self,show_extensions=False):
674         result = ""
675         result += "CERTIFICATE for %s\n"%self.get_subject()
676         result += "Issued by %s\n"%self.get_issuer()
677         filename=self.get_filename()
678         if filename: result += "Filename %s\n"%filename
679         if show_extensions:
680             all_datas=self.get_all_datas()
681             result += " has %d extensions/data attached"%len(all_datas)
682             for (n,v,c) in all_datas:
683                 if c=='data':
684                     result += "   data: %s=%s\n"%(n,v)
685                 else:
686                     result += "    ext: %s (crit=%s)=<<<%s>>>\n"%(n,c,v)
687         return result