86700e839e6b752d6dcd12185b6772eaf1d9434c
[sfa.git] / util / cert.py
1 from OpenSSL import crypto
2 import M2Crypto
3 from M2Crypto import X509
4 from M2Crypto import EVP
5
6 class Keypair:
7    key = None       # public/private keypair
8    m2key = None     # public key (m2crypto format)
9
10    def __init__(self, create=False):
11       if create:
12          self.create()
13       pass
14
15    def create(self):
16       self.key = crypto.PKey()
17       self.key.generate_key(crypto.TYPE_RSA, 1024)
18
19    def save_to_file(self, filename):
20       open(filename, 'w').write(self.as_pem())
21
22    def load_from_file(self, filename):
23       buffer = open(filename, 'r').read()
24       self.load_from_string(buffer)
25
26    def load_from_string(self, string):
27       self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
28       self.m2key = M2Crypto.EVP.load_key_string(string)
29
30    def as_pem(self):
31       return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key)
32
33    def get_m2_pkey(self):
34       if not self.m2key:
35          self.m2key = M2Crypto.EVP.load_key_string(self.as_pem())
36       return self.m2key
37
38    def get_openssl_pkey(self):
39       return self.key
40
41 class Certificate:
42    digest = "md5"
43
44    data = None
45    cert = None
46    issuerKey = None
47    issuerSubject = None
48    parent = None
49
50    def __init__(self, create=False, subject=None, string=None, filename=None):
51        if create or subject:
52            self.create()
53        if subject:
54            self.set_subject(subject)
55        if string:
56            self.load_from_string(string)
57        if filename:
58            self.load_from_file(filename)
59
60    def create(self):
61        self.cert = crypto.X509()
62        self.cert.set_serial_number(1)
63        self.cert.gmtime_adj_notBefore(0)
64        self.cert.gmtime_adj_notAfter(60*60*24*365*5) # five years
65
66    def load_from_string(self, string):
67        # if it is a chain of multiple certs, then split off the first one and
68        # load it
69        parts = string.split("-----parent-----", 1)
70        self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])
71        
72        # if there are more certs, then create a parent and let the parent load
73        # itself from the remainder of the string
74        if len(parts) > 1:
75            self.parent = Certificate()
76            self.parent.load_from_string(parts[1])
77
78
79    def load_from_file(self, filename):
80        file = open(filename)
81        string = file.read()
82        self.load_from_string(string)
83
84    def save_to_string(self, save_parents=False):
85        string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)
86        if save_parents and self.parent:
87           string = string + "-----parent-----" + self.parent.save_to_string(save_parents)
88        return string
89
90    def save_to_file(self, filename, save_parents=False):
91        string = self.save_to_string(save_parents=save_parents)
92        open(filename, 'w').write(string)
93
94    def set_issuer(self, key, subject=None, cert=None):
95        self.issuerKey = key
96        if subject:
97           # it's a mistake to use subject and cert params at the same time
98           assert(not cert)
99           if isinstance(subject, dict) or isinstance(subject, str):
100              req = crypto.X509Req()
101              reqSubject = req.get_subject()
102              if (isinstance(subject, dict)):
103                 for key in reqSubject.keys():
104                     setattr(reqSubject, key, name[key])
105              else:
106                 setattr(reqSubject, "CN", subject)
107              subject = reqSubject
108              # subject is not valid once req is out of scope, so save req
109              self.issuerReq = req
110        if cert:
111           # if a cert was supplied, then get the subject from the cert
112           subject = cert.cert.get_issuer()
113        assert(subject)
114        self.issuerSubject = subject
115
116    def get_issuer(self, which="CN"):
117        x = self.cert.get_issuer()
118        return getattr(x, which)
119
120    def set_subject(self, name):
121        req = crypto.X509Req()
122        subj = req.get_subject()
123        if (isinstance(name, dict)):
124            for key in name.keys():
125                setattr(subj, key, name[key])
126        else:
127            setattr(subj, "CN", name)
128        self.cert.set_subject(subj)
129
130    def get_subject(self, which="CN"):
131        x = self.cert.get_subject()
132        return getattr(x, which)
133
134    def set_pubkey(self, key):
135        assert(isinstance(key, Keypair))
136        self.cert.set_pubkey(key.get_openssl_pkey())
137
138    def get_pubkey(self):
139        m2x509 = X509.load_cert_string(self.save_to_string())
140        pkey = Keypair()
141        pkey.key = self.cert.get_pubkey()
142        pkey.m2key = m2x509.get_pubkey()
143        return pkey
144
145    def add_extension(self, name, critical, value):
146        ext = crypto.X509Extension (name, critical, value)
147        self.cert.add_extensions([ext])
148
149    def get_extension(self, name):
150        # pyOpenSSL does not have a way to get certificates
151        m2x509 = X509.load_cert_string(self.save_to_string())
152        value = m2x509.get_ext(name).get_value()
153        return value
154
155    def set_data(self, str):
156        # pyOpenSSL only allows us to add extensions, so if we try to set the
157        # same extension more than once, it will not work
158        if self.data != None:
159           raise "cannot set subjectAltName more than once"
160        self.data = str
161        self.add_extension("subjectAltName", 0, "URI:http://" + str)
162
163    def get_data(self):
164        if self.data:
165            return self.data
166
167        try:
168            uri = self.get_extension("subjectAltName")
169        except LookupError:
170            self.data = None
171            return self.data
172
173        if not uri.startswith("URI:http://"):
174            raise "bad encoding in subjectAltName"
175        self.data = uri[11:]
176        return self.data
177
178    def sign(self):
179        assert self.cert != None
180        assert self.issuerSubject != None
181        assert self.issuerKey != None
182        self.cert.set_issuer(self.issuerSubject)
183        self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest)
184
185    def verify(self, pkey):
186        # pyOpenSSL does not have a way to verify signatures
187        m2x509 = X509.load_cert_string(self.save_to_string())
188        m2pkey = pkey.get_m2_pkey()
189        # verify it
190        return m2x509.verify(m2pkey)
191
192        # XXX alternatively, if openssl has been patched, do the much simpler:
193        # try:
194        #   self.cert.verify(pkey.get_openssl_key())
195        #   return 1
196        # except:
197        #   return 0
198
199    def is_signed_by_cert(self, cert):
200        k = cert.get_pubkey()
201        result = self.verify(k)
202        return result
203
204    def set_parent(self, p):
205         self.parent = p
206
207    def get_parent(self):
208         return self.parent
209
210    def verify_chain(self, trusted_certs = None):
211         # if this cert is signed by a trusted_cert, then we are set
212         for trusted_cert in trusted_certs:
213             if self.is_signed_by_cert(trusted_cert):
214                 #print self.get_subject(), "is signed by a root"
215                 return True
216
217         # if there is no parent, then no way to verify the chain
218         if not self.parent:
219             #print self.get_subject(), "has no parent"
220             return False
221
222         # if it wasn't signed by the parent...
223         if not self.is_signed_by_cert(self.parent):
224             #print self.get_subject(), "is not signed by parent"
225             return False
226
227         # if the parent isn't verified...
228         if not self.parent.verify_chain(trusted_certs):
229             #print self.get_subject(), "parent does not verify"
230             return False
231
232         return True