renaming the toplevel geni/ package into sfa/
[sfa.git] / sfa / trust / certificate.py
1 ##
2 # Geniwrapper uses two crypto libraries: pyOpenSSL and M2Crypto to implement
3 # the necessary crypto functionality. Ideally just one of these libraries
4 # would be used, but unfortunately each of these libraries is independently
5 # lacking. The pyOpenSSL library is missing many necessary functions, and
6 # the M2Crypto library has crashed inside of some of the functions. The
7 # design decision is to use pyOpenSSL whenever possible as it seems more
8 # stable, and only use M2Crypto for those functions that are not possible
9 # in pyOpenSSL.
10 #
11 # This module exports two classes: Keypair and Certificate.
12 ##
13 #
14 ### $Id$
15 ### $URL$
16 #
17
18 import os
19 import tempfile
20 import base64
21 from OpenSSL import crypto
22 import M2Crypto
23 from M2Crypto import X509
24 from M2Crypto import EVP
25
26 from sfa.util.faults import *
27
28 def convert_public_key(key):
29     # find the keyconvert program
30     from sfa.util.config import Config
31     config = Config()
32     keyconvert = 'keyconvert'
33     loaded = False
34     default_path = "/usr/share/keyconvert/" + keyconvert
35     cwd = os.path.dirname(os.path.abspath(__file__))
36     alt_path = os.sep.join(cwd.split(os.sep)[:-1] + ['keyconvert', 'keyconvert'])
37     geni_path = config.basepath + os.sep + "keyconvert/keyconvert"
38     files = [default_path, alt_path, geni_path]
39     for path in files:
40         if os.path.isfile(path):
41             keyconvert_fn = path
42             loaded = True
43             break
44
45     if not loaded:
46         raise Exception, "Could not find keyconvert in " + ", ".join(files)
47
48     # we can only convert rsa keys 
49     if "ssh-dss" in key:
50         print "XXX: DSA key encountered, ignoring"
51         return None
52     
53     (ssh_f, ssh_fn) = tempfile.mkstemp()
54     ssl_fn = tempfile.mktemp()
55     os.write(ssh_f, key)
56     os.close(ssh_f)
57
58     if not os.path.exists(keyconvert_fn):
59         report.trace("  keyconvet utility " + str(keyconvert_fn) + "does not exist")
60         sys.exit(-1)
61
62     cmd = keyconvert_fn + " " + ssh_fn + " " + ssl_fn
63     os.system(cmd)
64
65     # this check leaves the temporary file containing the public key so
66     # that it can be expected to see why it failed.
67     # TODO: for production, cleanup the temporary files
68     if not os.path.exists(ssl_fn):
69         report.trace("  failed to convert key from " + ssh_fn + " to " + ssl_fn)
70         return None
71
72     k = Keypair()
73     try:
74         k.load_pubkey_from_file(ssl_fn)
75     except:
76         print "XXX: Error while converting key: ", key_str
77         k = None
78
79     # remove the temporary files
80     os.remove(ssh_fn)
81     os.remove(ssl_fn)
82
83     return k
84
85 ##
86 # Public-private key pairs are implemented by the Keypair class.
87 # A Keypair object may represent both a public and private key pair, or it
88 # may represent only a public key (this usage is consistent with OpenSSL).
89
90 class Keypair:
91    key = None       # public/private keypair
92    m2key = None     # public key (m2crypto format)
93
94    ##
95    # Creates a Keypair object
96    # @param create If create==True, creates a new public/private key and
97    #     stores it in the object
98    # @param string If string!=None, load the keypair from the string (PEM)
99    # @param filename If filename!=None, load the keypair from the file
100
101    def __init__(self, create=False, string=None, filename=None):
102       if create:
103          self.create()
104       if string:
105          self.load_from_string(string)
106       if filename:
107          self.load_from_file(filename)
108
109    ##
110    # Create a RSA public/private key pair and store it inside the keypair object
111
112    def create(self):
113       self.key = crypto.PKey()
114       self.key.generate_key(crypto.TYPE_RSA, 1024)
115
116    ##
117    # Save the private key to a file
118    # @param filename name of file to store the keypair in
119
120    def save_to_file(self, filename):
121       open(filename, 'w').write(self.as_pem())
122
123    ##
124    # Load the private key from a file. Implicity the private key includes the public key.
125
126    def load_from_file(self, filename):
127       buffer = open(filename, 'r').read()
128       self.load_from_string(buffer)
129
130    ##
131    # Load the private key from a string. Implicitly the private key includes the public key.
132
133    def load_from_string(self, string):
134       self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
135       self.m2key = M2Crypto.EVP.load_key_string(string)
136
137    ##
138    #  Load the public key from a string. No private key is loaded. 
139
140    def load_pubkey_from_file(self, filename):
141       # load the m2 public key
142       m2rsakey = M2Crypto.RSA.load_pub_key(filename)
143       self.m2key = M2Crypto.EVP.PKey()
144       self.m2key.assign_rsa(m2rsakey)
145
146       # create an m2 x509 cert
147       m2name = M2Crypto.X509.X509_Name()
148       m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0)
149       m2x509 = M2Crypto.X509.X509()
150       m2x509.set_pubkey(self.m2key)
151       m2x509.set_serial_number(0)
152       m2x509.set_issuer_name(m2name)
153       m2x509.set_subject_name(m2name)
154       ASN1 = M2Crypto.ASN1.ASN1_UTCTIME()
155       ASN1.set_time(500)
156       m2x509.set_not_before(ASN1)
157       m2x509.set_not_after(ASN1)
158       junk_key = Keypair(create=True)
159       m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")
160
161       # convert the m2 x509 cert to a pyopenssl x509
162       m2pem = m2x509.as_pem()
163       pyx509 = crypto.load_certificate(crypto.FILETYPE_PEM, m2pem)
164
165       # get the pyopenssl pkey from the pyopenssl x509
166       self.key = pyx509.get_pubkey()
167
168    ##
169    # Load the public key from a string. No private key is loaded.
170
171    def load_pubkey_from_string(self, string):
172       (f, fn) = tempfile.mkstemp()
173       os.write(f, string)
174       os.close(f)
175       self.load_pubkey_from_file(fn)
176       os.remove(fn)
177
178    ##
179    # Return the private key in PEM format.
180
181    def as_pem(self):
182       return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key)
183
184    def get_m2_pkey(self):
185       if not self.m2key:
186          self.m2key = M2Crypto.EVP.load_key_string(self.as_pem())
187       return self.m2key
188
189    ##
190    # Return an OpenSSL pkey object
191
192    def get_openssl_pkey(self):
193       return self.key
194
195    ##
196    # Given another Keypair object, return TRUE if the two keys are the same.
197
198    def is_same(self, pkey):
199       return self.as_pem() == pkey.as_pem()
200
201    def sign_string(self, data):
202       k = self.get_m2_pkey()
203       k.sign_init()
204       k.sign_update(data)
205       return base64.b64encode(k.sign_final())
206
207    def verify_string(self, data, sig):
208       k = self.get_m2_pkey()
209       k.verify_init()
210       k.verify_update(data)
211       return M2Crypto.m2.verify_final(k.ctx, base64.b64decode(sig), k.pkey)
212
213 ##
214 # The certificate class implements a general purpose X509 certificate, making
215 # use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds
216 # several addition features, such as the ability to maintain a chain of
217 # parent certificates, and storage of application-specific data.
218 #
219 # Certificates include the ability to maintain a chain of parents. Each
220 # certificate includes a pointer to it's parent certificate. When loaded
221 # from a file or a string, the parent chain will be automatically loaded.
222 # When saving a certificate to a file or a string, the caller can choose
223 # whether to save the parent certificates as well.
224
225 class Certificate:
226    digest = "md5"
227
228    data = None
229    cert = None
230    issuerKey = None
231    issuerSubject = None
232    parent = None
233
234    separator="-----parent-----"
235
236    ##
237    # Create a certificate object.
238    #
239    # @param create If create==True, then also create a blank X509 certificate.
240    # @param subject If subject!=None, then create a blank certificate and set
241    #     it's subject name.
242    # @param string If string!=None, load the certficate from the string.
243    # @param filename If filename!=None, load the certficiate from the file.
244
245    def __init__(self, create=False, subject=None, string=None, filename=None):
246        if create or subject:
247            self.create()
248        if subject:
249            self.set_subject(subject)
250        if string:
251            self.load_from_string(string)
252        if filename:
253            self.load_from_file(filename)
254
255    ##
256    # Create a blank X509 certificate and store it in this object.
257
258    def create(self):
259        self.cert = crypto.X509()
260        self.cert.set_serial_number(1)
261        self.cert.gmtime_adj_notBefore(0)
262        self.cert.gmtime_adj_notAfter(60*60*24*365*5) # five years
263
264    ##
265    # Given a pyOpenSSL X509 object, store that object inside of this
266    # certificate object.
267
268    def load_from_pyopenssl_x509(self, x509):
269        self.cert = x509
270
271    ##
272    # Load the certificate from a string
273
274    def load_from_string(self, string):
275        # if it is a chain of multiple certs, then split off the first one and
276        # load it
277        parts = string.split(Certificate.separator, 1)
278        self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])
279
280        # if there are more certs, then create a parent and let the parent load
281        # itself from the remainder of the string
282        if len(parts) > 1:
283            self.parent = self.__class__()
284            self.parent.load_from_string(parts[1])
285
286    ##
287    # Load the certificate from a file
288
289    def load_from_file(self, filename):
290        file = open(filename)
291        string = file.read()
292        self.load_from_string(string)
293
294    ##
295    # Save the certificate to a string.
296    #
297    # @param save_parents If save_parents==True, then also save the parent certificates.
298
299    def save_to_string(self, save_parents=False):
300        string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)
301        if save_parents and self.parent:
302           string = string + Certificate.separator + self.parent.save_to_string(save_parents)
303        return string
304
305    ##
306    # Save the certificate to a file.
307    # @param save_parents If save_parents==True, then also save the parent certificates.
308
309    def save_to_file(self, filename, save_parents=False):
310        string = self.save_to_string(save_parents=save_parents)
311        open(filename, 'w').write(string)
312
313    ##
314    # Sets the issuer private key and name
315    # @param key Keypair object containing the private key of the issuer
316    # @param subject String containing the name of the issuer
317    # @param cert (optional) Certificate object containing the name of the issuer
318
319    def set_issuer(self, key, subject=None, cert=None):
320        self.issuerKey = key
321        if subject:
322           # it's a mistake to use subject and cert params at the same time
323           assert(not cert)
324           if isinstance(subject, dict) or isinstance(subject, str):
325              req = crypto.X509Req()
326              reqSubject = req.get_subject()
327              if (isinstance(subject, dict)):
328                 for key in reqSubject.keys():
329                     setattr(reqSubject, key, name[key])
330              else:
331                 setattr(reqSubject, "CN", subject)
332              subject = reqSubject
333              # subject is not valid once req is out of scope, so save req
334              self.issuerReq = req
335        if cert:
336           # if a cert was supplied, then get the subject from the cert
337           subject = cert.cert.get_issuer()
338        assert(subject)
339        self.issuerSubject = subject
340
341    ##
342    # Get the issuer name
343
344    def get_issuer(self, which="CN"):
345        x = self.cert.get_issuer()
346        return getattr(x, which)
347
348    ##
349    # Set the subject name of the certificate
350
351    def set_subject(self, name):
352        req = crypto.X509Req()
353        subj = req.get_subject()
354        if (isinstance(name, dict)):
355            for key in name.keys():
356                setattr(subj, key, name[key])
357        else:
358            setattr(subj, "CN", name)
359        self.cert.set_subject(subj)
360    ##
361    # Get the subject name of the certificate
362
363    def get_subject(self, which="CN"):
364        x = self.cert.get_subject()
365        return getattr(x, which)
366
367    ##
368    # Get the public key of the certificate.
369    #
370    # @param key Keypair object containing the public key
371
372    def set_pubkey(self, key):
373        assert(isinstance(key, Keypair))
374        self.cert.set_pubkey(key.get_openssl_pkey())
375
376    ##
377    # Get the public key of the certificate.
378    # It is returned in the form of a Keypair object.
379
380    def get_pubkey(self):
381        m2x509 = X509.load_cert_string(self.save_to_string())
382        pkey = Keypair()
383        pkey.key = self.cert.get_pubkey()
384        pkey.m2key = m2x509.get_pubkey()
385        return pkey
386
387    ##
388    # Add an X509 extension to the certificate. Add_extension can only be called
389    # once for a particular extension name, due to limitations in the underlying
390    # library.
391    #
392    # @param name string containing name of extension
393    # @param value string containing value of the extension
394
395    def add_extension(self, name, critical, value):
396        ext = crypto.X509Extension (name, critical, value)
397        self.cert.add_extensions([ext])
398
399    ##
400    # Get an X509 extension from the certificate
401
402    def get_extension(self, name):
403        # pyOpenSSL does not have a way to get extensions
404        m2x509 = X509.load_cert_string(self.save_to_string())
405        value = m2x509.get_ext(name).get_value()
406        return value
407
408    ##
409    # Set_data is a wrapper around add_extension. It stores the parameter str in
410    # the X509 subject_alt_name extension. Set_data can only be called once, due
411    # to limitations in the underlying library.
412
413    def set_data(self, str):
414        # pyOpenSSL only allows us to add extensions, so if we try to set the
415        # same extension more than once, it will not work
416        if self.data != None:
417           raise "cannot set subjectAltName more than once"
418        self.data = str
419        self.add_extension("subjectAltName", 0, "URI:http://" + str)
420
421    ##
422    # Return the data string that was previously set with set_data
423
424    def get_data(self):
425        if self.data:
426            return self.data
427
428        try:
429            uri = self.get_extension("subjectAltName")
430        except LookupError:
431            self.data = None
432            return self.data
433
434        if not uri.startswith("URI:http://"):
435            raise "bad encoding in subjectAltName"
436        self.data = uri[11:]
437        return self.data
438
439    ##
440    # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer().
441
442    def sign(self):
443        assert self.cert != None
444        assert self.issuerSubject != None
445        assert self.issuerKey != None
446        self.cert.set_issuer(self.issuerSubject)
447        self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest)
448
449     ##
450     # Verify the authenticity of a certificate.
451     # @param pkey is a Keypair object representing a public key. If Pkey
452     #     did not sign the certificate, then an exception will be thrown.
453
454    def verify(self, pkey):
455        # pyOpenSSL does not have a way to verify signatures
456        m2x509 = X509.load_cert_string(self.save_to_string())
457        m2pkey = pkey.get_m2_pkey()
458        # verify it
459        return m2x509.verify(m2pkey)
460
461        # XXX alternatively, if openssl has been patched, do the much simpler:
462        # try:
463        #   self.cert.verify(pkey.get_openssl_key())
464        #   return 1
465        # except:
466        #   return 0
467
468    ##
469    # Return True if pkey is identical to the public key that is contained in the certificate.
470    # @param pkey Keypair object
471
472    def is_pubkey(self, pkey):
473        return self.get_pubkey().is_same(pkey)
474
475    ##
476    # Given a certificate cert, verify that this certificate was signed by the
477    # public key contained in cert. Throw an exception otherwise.
478    #
479    # @param cert certificate object
480
481    def is_signed_by_cert(self, cert):
482        k = cert.get_pubkey()
483        result = self.verify(k)
484        return result
485
486    ##
487    # Set the parent certficiate.
488    #
489    # @param p certificate object.
490
491    def set_parent(self, p):
492         self.parent = p
493
494    ##
495    # Return the certificate object of the parent of this certificate.
496
497    def get_parent(self):
498         return self.parent
499
500    ##
501    # Verification examines a chain of certificates to ensure that each parent
502    # signs the child, and that some certificate in the chain is signed by a
503    # trusted certificate.
504    #
505    # Verification is a basic recursion: <pre>
506    #     if this_certificate was signed by trusted_certs:
507    #         return
508    #     else
509    #         return verify_chain(parent, trusted_certs)
510    # </pre>
511    #
512    # At each recursion, the parent is tested to ensure that it did sign the
513    # child. If a parent did not sign a child, then an exception is thrown. If
514    # the bottom of the recursion is reached and the certificate does not match
515    # a trusted root, then an exception is thrown.
516    #
517    # @param Trusted_certs is a list of certificates that are trusted.
518    #
519
520    def verify_chain(self, trusted_certs = None):
521         # Verify a chain of certificates. Each certificate must be signed by
522         # the public key contained in it's parent. The chain is recursed
523         # until a certificate is found that is signed by a trusted root.
524
525         # TODO: verify expiration time
526
527         # if this cert is signed by a trusted_cert, then we are set
528         for trusted_cert in trusted_certs:
529             # TODO: verify expiration of trusted_cert ?
530             if self.is_signed_by_cert(trusted_cert):
531                 #print self.get_subject(), "is signed by a root"
532                 return
533
534         # if there is no parent, then no way to verify the chain
535         if not self.parent:
536             #print self.get_subject(), "has no parent"
537             raise CertMissingParent(self.get_subject())
538
539         # if it wasn't signed by the parent...
540         if not self.is_signed_by_cert(self.parent):
541             #print self.get_subject(), "is not signed by parent"
542             return CertNotSignedByParent(self.get_subject())
543
544         # if the parent isn't verified...
545         self.parent.verify_chain(trusted_certs)
546
547         return