preun scriplets more robust, otherwise may prevent uninstall
[sfa.git] / sfa / trust / certificate.py
1 #----------------------------------------------------------------------
2 # Copyright (c) 2008 Board of Trustees, Princeton University
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining
5 # a copy of this software and/or hardware specification (the "Work") to
6 # deal in the Work without restriction, including without limitation the
7 # rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Work, and to permit persons to whom the Work
9 # is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice shall be
12 # included in all copies or substantial portions of the Work.
13 #
14 # THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 
15 # OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 
16 # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
17 # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 
18 # HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 
19 # WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
20 # OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS 
21 # IN THE WORK.
22 #----------------------------------------------------------------------
23
24 ##
25 # SFA uses two crypto libraries: pyOpenSSL and M2Crypto to implement
26 # the necessary crypto functionality. Ideally just one of these libraries
27 # would be used, but unfortunately each of these libraries is independently
28 # lacking. The pyOpenSSL library is missing many necessary functions, and
29 # the M2Crypto library has crashed inside of some of the functions. The
30 # design decision is to use pyOpenSSL whenever possible as it seems more
31 # stable, and only use M2Crypto for those functions that are not possible
32 # in pyOpenSSL.
33 #
34 # This module exports two classes: Keypair and Certificate.
35 ##
36 #
37 ### $Id$
38 ### $URL$
39 #
40
41 import os
42 import tempfile
43 import base64
44 import traceback
45 from OpenSSL import crypto
46 import M2Crypto
47 from M2Crypto import X509
48 from tempfile import mkstemp
49 from sfa.util.sfalogging import logger
50 from sfa.util.namespace import urn_to_hrn
51 from sfa.util.faults import *
52
53 def convert_public_key(key):
54     keyconvert_path = "/usr/bin/keyconvert.py"
55     if not os.path.isfile(keyconvert_path):
56         raise IOError, "Could not find keyconvert in %s" % keyconvert_path
57
58     # we can only convert rsa keys
59     if "ssh-dss" in key:
60         return None
61
62     (ssh_f, ssh_fn) = tempfile.mkstemp()
63     ssl_fn = tempfile.mktemp()
64     os.write(ssh_f, key)
65     os.close(ssh_f)
66
67     cmd = keyconvert_path + " " + ssh_fn + " " + ssl_fn
68     os.system(cmd)
69
70     # this check leaves the temporary file containing the public key so
71     # that it can be expected to see why it failed.
72     # TODO: for production, cleanup the temporary files
73     if not os.path.exists(ssl_fn):
74         return None
75
76     k = Keypair()
77     try:
78         k.load_pubkey_from_file(ssl_fn)
79     except:
80         traceback.print_exc()
81         k = None
82
83     # remove the temporary files
84     os.remove(ssh_fn)
85     os.remove(ssl_fn)
86
87     return k
88
89 ##
90 # Public-private key pairs are implemented by the Keypair class.
91 # A Keypair object may represent both a public and private key pair, or it
92 # may represent only a public key (this usage is consistent with OpenSSL).
93
94 class Keypair:
95     key = None       # public/private keypair
96     m2key = None     # public key (m2crypto format)
97
98     ##
99     # Creates a Keypair object
100     # @param create If create==True, creates a new public/private key and
101     #     stores it in the object
102     # @param string If string!=None, load the keypair from the string (PEM)
103     # @param filename If filename!=None, load the keypair from the file
104
105     def __init__(self, create=False, string=None, filename=None):
106         if create:
107             self.create()
108         if string:
109             self.load_from_string(string)
110         if filename:
111             self.load_from_file(filename)
112
113     ##
114     # Create a RSA public/private key pair and store it inside the keypair object
115
116     def create(self):
117         self.key = crypto.PKey()
118         self.key.generate_key(crypto.TYPE_RSA, 1024)
119
120     ##
121     # Save the private key to a file
122     # @param filename name of file to store the keypair in
123
124     def save_to_file(self, filename):
125         open(filename, 'w').write(self.as_pem())
126
127     ##
128     # Load the private key from a file. Implicity the private key includes the public key.
129
130     def load_from_file(self, filename):
131         buffer = open(filename, 'r').read()
132         self.load_from_string(buffer)
133
134     ##
135     # Load the private key from a string. Implicitly the private key includes the public key.
136
137     def load_from_string(self, string):
138         self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
139         self.m2key = M2Crypto.EVP.load_key_string(string)
140
141     ##
142     #  Load the public key from a string. No private key is loaded.
143
144     def load_pubkey_from_file(self, filename):
145         # load the m2 public key
146         m2rsakey = M2Crypto.RSA.load_pub_key(filename)
147         self.m2key = M2Crypto.EVP.PKey()
148         self.m2key.assign_rsa(m2rsakey)
149
150         # create an m2 x509 cert
151         m2name = M2Crypto.X509.X509_Name()
152         m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0)
153         m2x509 = M2Crypto.X509.X509()
154         m2x509.set_pubkey(self.m2key)
155         m2x509.set_serial_number(0)
156         m2x509.set_issuer_name(m2name)
157         m2x509.set_subject_name(m2name)
158         ASN1 = M2Crypto.ASN1.ASN1_UTCTIME()
159         ASN1.set_time(500)
160         m2x509.set_not_before(ASN1)
161         m2x509.set_not_after(ASN1)
162         junk_key = Keypair(create=True)
163         m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")
164
165         # convert the m2 x509 cert to a pyopenssl x509
166         m2pem = m2x509.as_pem()
167         pyx509 = crypto.load_certificate(crypto.FILETYPE_PEM, m2pem)
168
169         # get the pyopenssl pkey from the pyopenssl x509
170         self.key = pyx509.get_pubkey()
171
172     ##
173     # Load the public key from a string. No private key is loaded.
174
175     def load_pubkey_from_string(self, string):
176         (f, fn) = tempfile.mkstemp()
177         os.write(f, string)
178         os.close(f)
179         self.load_pubkey_from_file(fn)
180         os.remove(fn)
181
182     ##
183     # Return the private key in PEM format.
184
185     def as_pem(self):
186         return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key)
187
188     ##
189     # Return an M2Crypto key object
190
191     def get_m2_pkey(self):
192         if not self.m2key:
193             self.m2key = M2Crypto.EVP.load_key_string(self.as_pem())
194         return self.m2key
195
196     ##
197     # Returns a string containing the public key represented by this object.
198
199     def get_pubkey_string(self):
200         m2pkey = self.get_m2_pkey()
201         return base64.b64encode(m2pkey.as_der())
202
203     ##
204     # Return an OpenSSL pkey object
205
206     def get_openssl_pkey(self):
207         return self.key
208
209
210     ##
211     # Given another Keypair object, return TRUE if the two keys are the same.
212
213     def is_same(self, pkey):
214         return self.as_pem() == pkey.as_pem()
215
216     def sign_string(self, data):
217         k = self.get_m2_pkey()
218         k.sign_init()
219         k.sign_update(data)
220         return base64.b64encode(k.sign_final())
221
222     def verify_string(self, data, sig):
223         k = self.get_m2_pkey()
224         k.verify_init()
225         k.verify_update(data)
226         return M2Crypto.m2.verify_final(k.ctx, base64.b64decode(sig), k.pkey)
227
228     def compute_hash(self, value):
229         return self.sign_string(str(value))
230
231 ##
232 # The certificate class implements a general purpose X509 certificate, making
233 # use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds
234 # several addition features, such as the ability to maintain a chain of
235 # parent certificates, and storage of application-specific data.
236 #
237 # Certificates include the ability to maintain a chain of parents. Each
238 # certificate includes a pointer to it's parent certificate. When loaded
239 # from a file or a string, the parent chain will be automatically loaded.
240 # When saving a certificate to a file or a string, the caller can choose
241 # whether to save the parent certificates as well.
242
243 class Certificate:
244     digest = "md5"
245
246     cert = None
247     issuerKey = None
248     issuerSubject = None
249     parent = None
250
251     separator="-----parent-----"
252
253     ##
254     # Create a certificate object.
255     #
256     # @param create If create==True, then also create a blank X509 certificate.
257     # @param subject If subject!=None, then create a blank certificate and set
258     #     it's subject name.
259     # @param string If string!=None, load the certficate from the string.
260     # @param filename If filename!=None, load the certficiate from the file.
261
262     def __init__(self, create=False, subject=None, string=None, filename=None, intermediate=None):
263         self.data = {}
264         if create or subject:
265             self.create()
266         if subject:
267             self.set_subject(subject)
268         if string:
269             self.load_from_string(string)
270         if filename:
271             self.load_from_file(filename)
272
273         if intermediate:
274             self.set_intermediate_ca(intermediate)
275
276     ##
277     # Create a blank X509 certificate and store it in this object.
278
279     def create(self):
280         self.cert = crypto.X509()
281         self.cert.set_serial_number(3)
282         self.cert.gmtime_adj_notBefore(0)
283         self.cert.gmtime_adj_notAfter(60*60*24*365*5) # five years
284
285     ##
286     # Given a pyOpenSSL X509 object, store that object inside of this
287     # certificate object.
288
289     def load_from_pyopenssl_x509(self, x509):
290         self.cert = x509
291
292     ##
293     # Load the certificate from a string
294
295     def load_from_string(self, string):
296         # if it is a chain of multiple certs, then split off the first one and
297         # load it (support for the ---parent--- tag as well as normal chained certs)
298
299         string = string.strip()
300
301
302         if not string.startswith('-----'):
303             string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string
304
305         parts = []
306
307         if string.count('-----BEGIN CERTIFICATE-----') > 1 and \
308                string.count(Certificate.separator) == 0:
309             parts = string.split('-----END CERTIFICATE-----',1)
310             parts[0] += '-----END CERTIFICATE-----'
311         else:
312             parts = string.split(Certificate.separator, 1)
313
314         self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])
315
316         # if there are more certs, then create a parent and let the parent load
317         # itself from the remainder of the string
318         if len(parts) > 1 and parts[1] != '':
319             self.parent = self.__class__()
320             self.parent.load_from_string(parts[1])
321
322     ##
323     # Load the certificate from a file
324
325     def load_from_file(self, filename):
326         file = open(filename)
327         string = file.read()
328         self.load_from_string(string)
329
330     ##
331     # Save the certificate to a string.
332     #
333     # @param save_parents If save_parents==True, then also save the parent certificates.
334
335     def save_to_string(self, save_parents=True):
336         string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)
337         if save_parents and self.parent:
338             string = string + self.parent.save_to_string(save_parents)
339         return string
340
341     ##
342     # Save the certificate to a file.
343     # @param save_parents If save_parents==True, then also save the parent certificates.
344
345     def save_to_file(self, filename, save_parents=True, filep=None):
346         string = self.save_to_string(save_parents=save_parents)
347         if filep:
348             f = filep
349         else:
350             f = open(filename, 'w')
351         f.write(string)
352         f.close()
353
354     ##
355     # Save the certificate to a random file in /tmp/
356     # @param save_parents If save_parents==True, then also save the parent certificates.
357     def save_to_random_tmp_file(self, save_parents=True):
358         fp, filename = mkstemp(suffix='cert', text=True)
359         fp = os.fdopen(fp, "w")
360         self.save_to_file(filename, save_parents=True, filep=fp)
361         return filename
362
363     ##
364     # Sets the issuer private key and name
365     # @param key Keypair object containing the private key of the issuer
366     # @param subject String containing the name of the issuer
367     # @param cert (optional) Certificate object containing the name of the issuer
368
369     def set_issuer(self, key, subject=None, cert=None):
370         self.issuerKey = key
371         if subject:
372             # it's a mistake to use subject and cert params at the same time
373             assert(not cert)
374             if isinstance(subject, dict) or isinstance(subject, str):
375                 req = crypto.X509Req()
376                 reqSubject = req.get_subject()
377                 if (isinstance(subject, dict)):
378                     for key in reqSubject.keys():
379                         setattr(reqSubject, key, subject[key])
380                 else:
381                     setattr(reqSubject, "CN", subject)
382                 subject = reqSubject
383                 # subject is not valid once req is out of scope, so save req
384                 self.issuerReq = req
385         if cert:
386             # if a cert was supplied, then get the subject from the cert
387             subject = cert.cert.get_subject()
388         assert(subject)
389         self.issuerSubject = subject
390
391     ##
392     # Get the issuer name
393
394     def get_issuer(self, which="CN"):
395         x = self.cert.get_issuer()
396         return getattr(x, which)
397
398     ##
399     # Set the subject name of the certificate
400
401     def set_subject(self, name):
402         req = crypto.X509Req()
403         subj = req.get_subject()
404         if (isinstance(name, dict)):
405             for key in name.keys():
406                 setattr(subj, key, name[key])
407         else:
408             setattr(subj, "CN", name)
409         self.cert.set_subject(subj)
410     ##
411     # Get the subject name of the certificate
412
413     def get_subject(self, which="CN"):
414         x = self.cert.get_subject()
415         return getattr(x, which)
416
417     ##
418     # Get the public key of the certificate.
419     #
420     # @param key Keypair object containing the public key
421
422     def set_pubkey(self, key):
423         assert(isinstance(key, Keypair))
424         self.cert.set_pubkey(key.get_openssl_pkey())
425
426     ##
427     # Get the public key of the certificate.
428     # It is returned in the form of a Keypair object.
429
430     def get_pubkey(self):
431         m2x509 = X509.load_cert_string(self.save_to_string())
432         pkey = Keypair()
433         pkey.key = self.cert.get_pubkey()
434         pkey.m2key = m2x509.get_pubkey()
435         return pkey
436
437     def set_intermediate_ca(self, val):
438         self.intermediate = val
439         if val:
440             self.add_extension('basicConstraints', 1, 'CA:TRUE')
441
442
443
444     ##
445     # Add an X509 extension to the certificate. Add_extension can only be called
446     # once for a particular extension name, due to limitations in the underlying
447     # library.
448     #
449     # @param name string containing name of extension
450     # @param value string containing value of the extension
451
452     def add_extension(self, name, critical, value):
453         ext = crypto.X509Extension (name, critical, value)
454         self.cert.add_extensions([ext])
455
456     ##
457     # Get an X509 extension from the certificate
458
459     def get_extension(self, name):
460         # pyOpenSSL does not have a way to get extensions
461         m2x509 = X509.load_cert_string(self.save_to_string())
462         value = m2x509.get_ext(name).get_value()
463         return value
464
465     ##
466     # Set_data is a wrapper around add_extension. It stores the parameter str in
467     # the X509 subject_alt_name extension. Set_data can only be called once, due
468     # to limitations in the underlying library.
469
470     def set_data(self, str, field='subjectAltName'):
471         # pyOpenSSL only allows us to add extensions, so if we try to set the
472         # same extension more than once, it will not work
473         if self.data.has_key(field):
474             raise "Cannot set ", field, " more than once"
475         self.data[field] = str
476         self.add_extension(field, 0, str)
477
478     ##
479     # Return the data string that was previously set with set_data
480
481     def get_data(self, field='subjectAltName'):
482         if self.data.has_key(field):
483             return self.data[field]
484
485         try:
486             uri = self.get_extension(field)
487             self.data[field] = uri
488         except LookupError:
489             return None
490
491         return self.data[field]
492
493     ##
494     # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer().
495
496     def sign(self):
497         assert self.cert != None
498         assert self.issuerSubject != None
499         assert self.issuerKey != None
500         self.cert.set_issuer(self.issuerSubject)
501         self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest)
502
503     ##
504     # Verify the authenticity of a certificate.
505     # @param pkey is a Keypair object representing a public key. If Pkey
506     #     did not sign the certificate, then an exception will be thrown.
507
508     def verify(self, pkey):
509         # pyOpenSSL does not have a way to verify signatures
510         m2x509 = X509.load_cert_string(self.save_to_string())
511         m2pkey = pkey.get_m2_pkey()
512         # verify it
513         return m2x509.verify(m2pkey)
514
515         # XXX alternatively, if openssl has been patched, do the much simpler:
516         # try:
517         #   self.cert.verify(pkey.get_openssl_key())
518         #   return 1
519         # except:
520         #   return 0
521
522     ##
523     # Return True if pkey is identical to the public key that is contained in the certificate.
524     # @param pkey Keypair object
525
526     def is_pubkey(self, pkey):
527         return self.get_pubkey().is_same(pkey)
528
529     ##
530     # Given a certificate cert, verify that this certificate was signed by the
531     # public key contained in cert. Throw an exception otherwise.
532     #
533     # @param cert certificate object
534
535     def is_signed_by_cert(self, cert):
536         k = cert.get_pubkey()
537         result = self.verify(k)
538         return result
539
540     ##
541     # Set the parent certficiate.
542     #
543     # @param p certificate object.
544
545     def set_parent(self, p):
546         self.parent = p
547
548     ##
549     # Return the certificate object of the parent of this certificate.
550
551     def get_parent(self):
552         return self.parent
553
554     ##
555     # Verification examines a chain of certificates to ensure that each parent
556     # signs the child, and that some certificate in the chain is signed by a
557     # trusted certificate.
558     #
559     # Verification is a basic recursion: <pre>
560     #     if this_certificate was signed by trusted_certs:
561     #         return
562     #     else
563     #         return verify_chain(parent, trusted_certs)
564     # </pre>
565     #
566     # At each recursion, the parent is tested to ensure that it did sign the
567     # child. If a parent did not sign a child, then an exception is thrown. If
568     # the bottom of the recursion is reached and the certificate does not match
569     # a trusted root, then an exception is thrown.
570     #
571     # @param Trusted_certs is a list of certificates that are trusted.
572     #
573
574     def verify_chain(self, trusted_certs = None):
575         # Verify a chain of certificates. Each certificate must be signed by
576         # the public key contained in it's parent. The chain is recursed
577         # until a certificate is found that is signed by a trusted root.
578
579         # verify expiration time
580         if self.cert.has_expired():
581             raise CertExpired(self.get_subject(), "client cert")   
582         
583         # if this cert is signed by a trusted_cert, then we are set
584         for trusted_cert in trusted_certs:
585             if self.is_signed_by_cert(trusted_cert):
586                 logger.debug("Cert %s signed by trusted cert %s", self.get_subject(), trusted_cert.get_subject())
587                 # verify expiration of trusted_cert ?
588                 if not trusted_cert.cert.has_expired():
589                     return trusted_cert
590                 else:
591                     logger.debug("Trusted cert %s is expired", trusted_cert.get_subject())       
592
593         # if there is no parent, then no way to verify the chain
594         if not self.parent:
595             #print self.get_subject(), "has no parent"
596             raise CertMissingParent(self.get_subject())
597
598         # if it wasn't signed by the parent...
599         if not self.is_signed_by_cert(self.parent):
600             #print self.get_subject(), "is not signed by parent"
601             return CertNotSignedByParent(self.get_subject())
602
603         # if the parent isn't verified...
604         self.parent.verify_chain(trusted_certs)
605
606         return