Fixed up URNs in GID creation
[sfa.git] / sfa / trust / certificate.py
1 ##
2 # SFA uses two crypto libraries: pyOpenSSL and M2Crypto to implement
3 # the necessary crypto functionality. Ideally just one of these libraries
4 # would be used, but unfortunately each of these libraries is independently
5 # lacking. The pyOpenSSL library is missing many necessary functions, and
6 # the M2Crypto library has crashed inside of some of the functions. The
7 # design decision is to use pyOpenSSL whenever possible as it seems more
8 # stable, and only use M2Crypto for those functions that are not possible
9 # in pyOpenSSL.
10 #
11 # This module exports two classes: Keypair and Certificate.
12 ##
13 #
14 ### $Id$
15 ### $URL$
16 #
17
18 import os
19 import tempfile
20 import base64
21 import traceback
22 from OpenSSL import crypto
23 import M2Crypto
24 from M2Crypto import X509
25 from tempfile import mkstemp
26 from sfa.util.sfalogging import logger
27
28 from sfa.util.faults import *
29
30 def convert_public_key(key):
31     keyconvert_path = "/usr/bin/keyconvert"
32     if not os.path.isfile(keyconvert_path):
33         raise IOError, "Could not find keyconvert in %s" % keyconvert_path
34
35     # we can only convert rsa keys 
36     if "ssh-dss" in key:
37         return None
38     
39     (ssh_f, ssh_fn) = tempfile.mkstemp()
40     ssl_fn = tempfile.mktemp()
41     os.write(ssh_f, key)
42     os.close(ssh_f)
43
44     cmd = keyconvert_path + " " + ssh_fn + " " + ssl_fn
45     os.system(cmd)
46     
47     # this check leaves the temporary file containing the public key so
48     # that it can be expected to see why it failed.
49     # TODO: for production, cleanup the temporary files
50     if not os.path.exists(ssl_fn):
51         return None
52     
53     k = Keypair()
54     try:
55         k.load_pubkey_from_file(ssl_fn)
56     except:
57         traceback.print_exc()
58         k = None
59
60     # remove the temporary files
61     os.remove(ssh_fn)
62     os.remove(ssl_fn)
63
64     return k
65
66 ##
67 # Public-private key pairs are implemented by the Keypair class.
68 # A Keypair object may represent both a public and private key pair, or it
69 # may represent only a public key (this usage is consistent with OpenSSL).
70
71 class Keypair:
72    key = None       # public/private keypair
73    m2key = None     # public key (m2crypto format)
74    
75    ##
76    # Creates a Keypair object
77    # @param create If create==True, creates a new public/private key and
78    #     stores it in the object
79    # @param string If string!=None, load the keypair from the string (PEM)
80    # @param filename If filename!=None, load the keypair from the file
81
82    def __init__(self, create=False, string=None, filename=None):
83       if create:
84          self.create()
85       if string:
86          self.load_from_string(string)
87       if filename:
88          self.load_from_file(filename)
89
90    ##
91    # Create a RSA public/private key pair and store it inside the keypair object
92
93    def create(self):
94       self.key = crypto.PKey()
95       self.key.generate_key(crypto.TYPE_RSA, 1024)
96
97    ##
98    # Save the private key to a file
99    # @param filename name of file to store the keypair in
100
101    def save_to_file(self, filename):
102       open(filename, 'w').write(self.as_pem())
103
104    ##
105    # Load the private key from a file. Implicity the private key includes the public key.
106
107    def load_from_file(self, filename):
108       buffer = open(filename, 'r').read()
109       self.load_from_string(buffer)
110
111    ##
112    # Load the private key from a string. Implicitly the private key includes the public key.
113
114    def load_from_string(self, string):
115       self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
116       self.m2key = M2Crypto.EVP.load_key_string(string)
117
118    ##
119    #  Load the public key from a string. No private key is loaded. 
120
121    def load_pubkey_from_file(self, filename):
122       # load the m2 public key
123       m2rsakey = M2Crypto.RSA.load_pub_key(filename)
124       self.m2key = M2Crypto.EVP.PKey()
125       self.m2key.assign_rsa(m2rsakey)
126
127       # create an m2 x509 cert
128       m2name = M2Crypto.X509.X509_Name()
129       m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0)
130       m2x509 = M2Crypto.X509.X509()
131       m2x509.set_pubkey(self.m2key)
132       m2x509.set_serial_number(0)
133       m2x509.set_issuer_name(m2name)
134       m2x509.set_subject_name(m2name)
135       ASN1 = M2Crypto.ASN1.ASN1_UTCTIME()
136       ASN1.set_time(500)
137       m2x509.set_not_before(ASN1)
138       m2x509.set_not_after(ASN1)
139       junk_key = Keypair(create=True)
140       m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")
141
142       # convert the m2 x509 cert to a pyopenssl x509
143       m2pem = m2x509.as_pem()
144       pyx509 = crypto.load_certificate(crypto.FILETYPE_PEM, m2pem)
145
146       # get the pyopenssl pkey from the pyopenssl x509
147       self.key = pyx509.get_pubkey()
148
149    ##
150    # Load the public key from a string. No private key is loaded.
151
152    def load_pubkey_from_string(self, string):
153       (f, fn) = tempfile.mkstemp()
154       os.write(f, string)
155       os.close(f)
156       self.load_pubkey_from_file(fn)
157       os.remove(fn)
158
159    ##
160    # Return the private key in PEM format.
161
162    def as_pem(self):
163       return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key)
164
165    ##
166    # Return an M2Crypto key object
167
168    def get_m2_pkey(self):
169       if not self.m2key:
170          self.m2key = M2Crypto.EVP.load_key_string(self.as_pem())
171       return self.m2key
172
173    ##
174    # Returns a string containing the public key represented by this object.
175
176    def get_pubkey_string(self):
177       m2pkey = self.get_m2_pkey()
178       return base64.b64encode(m2pkey.as_der())
179
180    ##
181    # Return an OpenSSL pkey object
182
183    def get_openssl_pkey(self):
184       return self.key
185
186    ##
187    # Given another Keypair object, return TRUE if the two keys are the same.
188
189    def is_same(self, pkey):
190       return self.as_pem() == pkey.as_pem()
191
192    def sign_string(self, data):
193       k = self.get_m2_pkey()
194       k.sign_init()
195       k.sign_update(data)
196       return base64.b64encode(k.sign_final())
197
198    def verify_string(self, data, sig):
199       k = self.get_m2_pkey()
200       k.verify_init()
201       k.verify_update(data)
202       return M2Crypto.m2.verify_final(k.ctx, base64.b64decode(sig), k.pkey)
203
204    def compute_hash(self, value):
205       return self.sign_string(str(value))      
206
207 ##
208 # The certificate class implements a general purpose X509 certificate, making
209 # use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds
210 # several addition features, such as the ability to maintain a chain of
211 # parent certificates, and storage of application-specific data.
212 #
213 # Certificates include the ability to maintain a chain of parents. Each
214 # certificate includes a pointer to it's parent certificate. When loaded
215 # from a file or a string, the parent chain will be automatically loaded.
216 # When saving a certificate to a file or a string, the caller can choose
217 # whether to save the parent certificates as well.
218
219 class Certificate:
220    digest = "md5"
221
222    cert = None
223    issuerKey = None
224    issuerSubject = None
225    parent = None
226
227    separator="-----parent-----"
228
229    ##
230    # Create a certificate object.
231    #
232    # @param create If create==True, then also create a blank X509 certificate.
233    # @param subject If subject!=None, then create a blank certificate and set
234    #     it's subject name.
235    # @param string If string!=None, load the certficate from the string.
236    # @param filename If filename!=None, load the certficiate from the file.
237
238    def __init__(self, create=False, subject=None, string=None, filename=None):
239        self.data = {}
240        if create or subject:
241            self.create()
242        if subject:
243            self.set_subject(subject)
244        if string:
245            self.load_from_string(string)
246        if filename:
247            self.load_from_file(filename)
248
249    ##
250    # Create a blank X509 certificate and store it in this object.
251
252    def create(self):
253        self.cert = crypto.X509()
254        self.cert.set_serial_number(1)
255        self.cert.gmtime_adj_notBefore(0)
256        self.cert.gmtime_adj_notAfter(60*60*24*365*5) # five years
257
258    ##
259    # Given a pyOpenSSL X509 object, store that object inside of this
260    # certificate object.
261
262    def load_from_pyopenssl_x509(self, x509):
263        self.cert = x509
264
265    ##
266    # Load the certificate from a string
267
268    def load_from_string(self, string):
269        # if it is a chain of multiple certs, then split off the first one and
270        # load it (support for the ---parent--- tag as well as normal chained certs)       
271
272        string = string.strip()       
273        
274        
275        if not string.startswith('-----'):
276            string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string
277            
278        parts = []
279        
280        if string.count('-----BEGIN CERTIFICATE-----') > 1 and \
281               string.count(Certificate.separator) == 0:
282            parts = string.split('-----END CERTIFICATE-----',1)
283            parts[0] += '-----END CERTIFICATE-----'
284        else:
285            parts = string.split(Certificate.separator, 1)
286        
287        self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])
288
289        # if there are more certs, then create a parent and let the parent load
290        # itself from the remainder of the string
291        if len(parts) > 1 and parts[1] != '':
292            self.parent = self.__class__()
293            self.parent.load_from_string(parts[1])
294
295    ##
296    # Load the certificate from a file
297
298    def load_from_file(self, filename):     
299        file = open(filename)
300        string = file.read()
301        self.load_from_string(string)
302
303    ##
304    # Save the certificate to a string.
305    #
306    # @param save_parents If save_parents==True, then also save the parent certificates.
307
308    def save_to_string(self, save_parents=True):
309        string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)
310        if save_parents and self.parent:
311           string = string + self.parent.save_to_string(save_parents)
312        return string
313
314    ##
315    # Save the certificate to a file.
316    # @param save_parents If save_parents==True, then also save the parent certificates.
317
318    def save_to_file(self, filename, save_parents=True, filep=None):
319        string = self.save_to_string(save_parents=save_parents)
320        if filep:
321            f = filep
322        else:
323            f = open(filename, 'w')
324        f.write(string)
325        f.close()
326
327    ##
328    # Save the certificate to a random file in /tmp/
329    # @param save_parents If save_parents==True, then also save the parent certificates.  
330    def save_to_random_tmp_file(self, save_parents=True):       
331        fp, filename = mkstemp(suffix='cert', text=True)
332        fp = os.fdopen(fp, "w")
333        self.save_to_file(filename, save_parents=True, filep=fp)
334        return filename   
335
336    ##
337    # Sets the issuer private key and name
338    # @param key Keypair object containing the private key of the issuer
339    # @param subject String containing the name of the issuer
340    # @param cert (optional) Certificate object containing the name of the issuer
341
342    def set_issuer(self, key, subject=None, cert=None):
343        self.issuerKey = key
344        if subject:
345           # it's a mistake to use subject and cert params at the same time
346           assert(not cert)
347           if isinstance(subject, dict) or isinstance(subject, str):
348              req = crypto.X509Req()
349              reqSubject = req.get_subject()
350              if (isinstance(subject, dict)):
351                 for key in reqSubject.keys():
352                     setattr(reqSubject, key, name[key])
353              else:
354                 setattr(reqSubject, "CN", subject)
355              subject = reqSubject
356              # subject is not valid once req is out of scope, so save req
357              self.issuerReq = req
358        if cert:
359           # if a cert was supplied, then get the subject from the cert
360           subject = cert.cert.get_issuer()
361        assert(subject)
362        self.issuerSubject = subject
363
364    ##
365    # Get the issuer name
366
367    def get_issuer(self, which="CN"):
368        x = self.cert.get_issuer()
369        return getattr(x, which)
370
371    ##
372    # Set the subject name of the certificate
373
374    def set_subject(self, name):
375        req = crypto.X509Req()
376        subj = req.get_subject()
377        if (isinstance(name, dict)):
378            for key in name.keys():
379                setattr(subj, key, name[key])
380        else:
381            setattr(subj, "CN", name)
382        self.cert.set_subject(subj)
383    ##
384    # Get the subject name of the certificate
385
386    def get_subject(self, which="CN"):
387        x = self.cert.get_subject()
388        return getattr(x, which)
389
390    ##
391    # Get the public key of the certificate.
392    #
393    # @param key Keypair object containing the public key
394
395    def set_pubkey(self, key):
396        assert(isinstance(key, Keypair))
397        self.cert.set_pubkey(key.get_openssl_pkey())
398
399    ##
400    # Get the public key of the certificate.
401    # It is returned in the form of a Keypair object.
402
403    def get_pubkey(self):
404        m2x509 = X509.load_cert_string(self.save_to_string())
405        pkey = Keypair()
406        pkey.key = self.cert.get_pubkey()
407        pkey.m2key = m2x509.get_pubkey()
408        return pkey
409
410    ##
411    # Add an X509 extension to the certificate. Add_extension can only be called
412    # once for a particular extension name, due to limitations in the underlying
413    # library.
414    #
415    # @param name string containing name of extension
416    # @param value string containing value of the extension
417
418    def add_extension(self, name, critical, value):
419        ext = crypto.X509Extension (name, critical, value)
420        self.cert.add_extensions([ext])
421
422    ##
423    # Get an X509 extension from the certificate
424
425    def get_extension(self, name):
426        # pyOpenSSL does not have a way to get extensions
427        m2x509 = X509.load_cert_string(self.save_to_string())
428        value = m2x509.get_ext(name).get_value()
429        return value
430
431    ##
432    # Set_data is a wrapper around add_extension. It stores the parameter str in
433    # the X509 subject_alt_name extension. Set_data can only be called once, due
434    # to limitations in the underlying library.
435
436    def set_data(self, str, field='subjectAltName'):
437        # pyOpenSSL only allows us to add extensions, so if we try to set the
438        # same extension more than once, it will not work
439        if self.data.has_key(field):
440           raise "cannot set ", field, " more than once"
441        self.data[field] = str
442        self.add_extension(field, 0, str)
443
444    ##
445    # Return the data string that was previously set with set_data
446
447    def get_data(self, field='subjectAltName'):
448        if self.data.has_key(field):
449            return self.data[field]
450
451        try:
452            uri = self.get_extension(field)
453            self.data[field] = uri           
454        except LookupError:
455            return None
456        
457        return self.data[field]
458
459    ##
460    # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer().
461
462    def sign(self):
463        assert self.cert != None
464        assert self.issuerSubject != None
465        assert self.issuerKey != None
466        self.cert.set_issuer(self.issuerSubject)
467        self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest)
468
469     ##
470     # Verify the authenticity of a certificate.
471     # @param pkey is a Keypair object representing a public key. If Pkey
472     #     did not sign the certificate, then an exception will be thrown.
473
474    def verify(self, pkey):
475        # pyOpenSSL does not have a way to verify signatures
476        m2x509 = X509.load_cert_string(self.save_to_string())
477        m2pkey = pkey.get_m2_pkey()
478        # verify it
479        return m2x509.verify(m2pkey)
480
481        # XXX alternatively, if openssl has been patched, do the much simpler:
482        # try:
483        #   self.cert.verify(pkey.get_openssl_key())
484        #   return 1
485        # except:
486        #   return 0
487
488    ##
489    # Return True if pkey is identical to the public key that is contained in the certificate.
490    # @param pkey Keypair object
491
492    def is_pubkey(self, pkey):
493        return self.get_pubkey().is_same(pkey)
494
495    ##
496    # Given a certificate cert, verify that this certificate was signed by the
497    # public key contained in cert. Throw an exception otherwise.
498    #
499    # @param cert certificate object
500
501    def is_signed_by_cert(self, cert):
502        k = cert.get_pubkey()
503        result = self.verify(k)
504        return result
505
506    ##
507    # Set the parent certficiate.
508    #
509    # @param p certificate object.
510
511    def set_parent(self, p):
512         self.parent = p
513
514    ##
515    # Return the certificate object of the parent of this certificate.
516
517    def get_parent(self):
518         return self.parent
519
520    ##
521    # Verification examines a chain of certificates to ensure that each parent
522    # signs the child, and that some certificate in the chain is signed by a
523    # trusted certificate.
524    #
525    # Verification is a basic recursion: <pre>
526    #     if this_certificate was signed by trusted_certs:
527    #         return
528    #     else
529    #         return verify_chain(parent, trusted_certs)
530    # </pre>
531    #
532    # At each recursion, the parent is tested to ensure that it did sign the
533    # child. If a parent did not sign a child, then an exception is thrown. If
534    # the bottom of the recursion is reached and the certificate does not match
535    # a trusted root, then an exception is thrown.
536    #
537    # @param Trusted_certs is a list of certificates that are trusted.
538    #
539
540    def verify_chain(self, trusted_certs = None):
541         # Verify a chain of certificates. Each certificate must be signed by
542         # the public key contained in it's parent. The chain is recursed
543         # until a certificate is found that is signed by a trusted root.
544
545         # TODO: verify expiration time
546         #print "====Verify Chain====="
547         # if this cert is signed by a trusted_cert, then we are set
548         for trusted_cert in trusted_certs:
549             #print "***************"
550             # TODO: verify expiration of trusted_cert ?
551             #print "CLIENT CERT", self.dump()
552             #print "TRUSTED CERT", trusted_cert.dump()
553             #print "Client is signed by Trusted?", self.is_signed_by_cert(trusted_cert)
554             if self.is_signed_by_cert(trusted_cert):
555                 # make sure sure the trusted cert's hrn is a prefix of the
556                 # signed cert's hrn
557                 if not self.get_subject().startswith(trusted_cert.get_subject()):
558                     raise GidParentHrn(trusted_cert.get_subject()) 
559                 #print self.get_subject(), "is signed by a root"
560                 return
561
562         # if there is no parent, then no way to verify the chain
563         if not self.parent:
564             #print self.get_subject(), "has no parent"
565             raise CertMissingParent(self.get_subject())
566
567         # if it wasn't signed by the parent...
568         if not self.is_signed_by_cert(self.parent):
569             #print self.get_subject(), "is not signed by parent"
570             return CertNotSignedByParent(self.get_subject())
571
572         # if the parent isn't verified...
573         self.parent.verify_chain(trusted_certs)
574
575         return