Merge remote-tracking branches 'asoc/fix/rt298', 'asoc/fix/sx', 'asoc/fix/wm8904...
[linux-drm-fsl-dcu.git] / crypto / rsa.c
1 /* RSA asymmetric public-key algorithm [RFC3447]
2  *
3  * Copyright (c) 2015, Intel Corporation
4  * Authors: Tadeusz Struk <tadeusz.struk@intel.com>
5  *
6  * This program is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU General Public Licence
8  * as published by the Free Software Foundation; either version
9  * 2 of the Licence, or (at your option) any later version.
10  */
11
12 #include <linux/module.h>
13 #include <crypto/internal/rsa.h>
14 #include <crypto/internal/akcipher.h>
15 #include <crypto/akcipher.h>
16
17 /*
18  * RSAEP function [RFC3447 sec 5.1.1]
19  * c = m^e mod n;
20  */
21 static int _rsa_enc(const struct rsa_key *key, MPI c, MPI m)
22 {
23         /* (1) Validate 0 <= m < n */
24         if (mpi_cmp_ui(m, 0) < 0 || mpi_cmp(m, key->n) >= 0)
25                 return -EINVAL;
26
27         /* (2) c = m^e mod n */
28         return mpi_powm(c, m, key->e, key->n);
29 }
30
31 /*
32  * RSADP function [RFC3447 sec 5.1.2]
33  * m = c^d mod n;
34  */
35 static int _rsa_dec(const struct rsa_key *key, MPI m, MPI c)
36 {
37         /* (1) Validate 0 <= c < n */
38         if (mpi_cmp_ui(c, 0) < 0 || mpi_cmp(c, key->n) >= 0)
39                 return -EINVAL;
40
41         /* (2) m = c^d mod n */
42         return mpi_powm(m, c, key->d, key->n);
43 }
44
45 /*
46  * RSASP1 function [RFC3447 sec 5.2.1]
47  * s = m^d mod n
48  */
49 static int _rsa_sign(const struct rsa_key *key, MPI s, MPI m)
50 {
51         /* (1) Validate 0 <= m < n */
52         if (mpi_cmp_ui(m, 0) < 0 || mpi_cmp(m, key->n) >= 0)
53                 return -EINVAL;
54
55         /* (2) s = m^d mod n */
56         return mpi_powm(s, m, key->d, key->n);
57 }
58
59 /*
60  * RSAVP1 function [RFC3447 sec 5.2.2]
61  * m = s^e mod n;
62  */
63 static int _rsa_verify(const struct rsa_key *key, MPI m, MPI s)
64 {
65         /* (1) Validate 0 <= s < n */
66         if (mpi_cmp_ui(s, 0) < 0 || mpi_cmp(s, key->n) >= 0)
67                 return -EINVAL;
68
69         /* (2) m = s^e mod n */
70         return mpi_powm(m, s, key->e, key->n);
71 }
72
73 static inline struct rsa_key *rsa_get_key(struct crypto_akcipher *tfm)
74 {
75         return akcipher_tfm_ctx(tfm);
76 }
77
78 static int rsa_enc(struct akcipher_request *req)
79 {
80         struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
81         const struct rsa_key *pkey = rsa_get_key(tfm);
82         MPI m, c = mpi_alloc(0);
83         int ret = 0;
84         int sign;
85
86         if (!c)
87                 return -ENOMEM;
88
89         if (unlikely(!pkey->n || !pkey->e)) {
90                 ret = -EINVAL;
91                 goto err_free_c;
92         }
93
94         if (req->dst_len < mpi_get_size(pkey->n)) {
95                 req->dst_len = mpi_get_size(pkey->n);
96                 ret = -EOVERFLOW;
97                 goto err_free_c;
98         }
99
100         m = mpi_read_raw_data(req->src, req->src_len);
101         if (!m) {
102                 ret = -ENOMEM;
103                 goto err_free_c;
104         }
105
106         ret = _rsa_enc(pkey, c, m);
107         if (ret)
108                 goto err_free_m;
109
110         ret = mpi_read_buffer(c, req->dst, req->dst_len, &req->dst_len, &sign);
111         if (ret)
112                 goto err_free_m;
113
114         if (sign < 0) {
115                 ret = -EBADMSG;
116                 goto err_free_m;
117         }
118
119 err_free_m:
120         mpi_free(m);
121 err_free_c:
122         mpi_free(c);
123         return ret;
124 }
125
126 static int rsa_dec(struct akcipher_request *req)
127 {
128         struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
129         const struct rsa_key *pkey = rsa_get_key(tfm);
130         MPI c, m = mpi_alloc(0);
131         int ret = 0;
132         int sign;
133
134         if (!m)
135                 return -ENOMEM;
136
137         if (unlikely(!pkey->n || !pkey->d)) {
138                 ret = -EINVAL;
139                 goto err_free_m;
140         }
141
142         if (req->dst_len < mpi_get_size(pkey->n)) {
143                 req->dst_len = mpi_get_size(pkey->n);
144                 ret = -EOVERFLOW;
145                 goto err_free_m;
146         }
147
148         c = mpi_read_raw_data(req->src, req->src_len);
149         if (!c) {
150                 ret = -ENOMEM;
151                 goto err_free_m;
152         }
153
154         ret = _rsa_dec(pkey, m, c);
155         if (ret)
156                 goto err_free_c;
157
158         ret = mpi_read_buffer(m, req->dst, req->dst_len, &req->dst_len, &sign);
159         if (ret)
160                 goto err_free_c;
161
162         if (sign < 0) {
163                 ret = -EBADMSG;
164                 goto err_free_c;
165         }
166
167 err_free_c:
168         mpi_free(c);
169 err_free_m:
170         mpi_free(m);
171         return ret;
172 }
173
174 static int rsa_sign(struct akcipher_request *req)
175 {
176         struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
177         const struct rsa_key *pkey = rsa_get_key(tfm);
178         MPI m, s = mpi_alloc(0);
179         int ret = 0;
180         int sign;
181
182         if (!s)
183                 return -ENOMEM;
184
185         if (unlikely(!pkey->n || !pkey->d)) {
186                 ret = -EINVAL;
187                 goto err_free_s;
188         }
189
190         if (req->dst_len < mpi_get_size(pkey->n)) {
191                 req->dst_len = mpi_get_size(pkey->n);
192                 ret = -EOVERFLOW;
193                 goto err_free_s;
194         }
195
196         m = mpi_read_raw_data(req->src, req->src_len);
197         if (!m) {
198                 ret = -ENOMEM;
199                 goto err_free_s;
200         }
201
202         ret = _rsa_sign(pkey, s, m);
203         if (ret)
204                 goto err_free_m;
205
206         ret = mpi_read_buffer(s, req->dst, req->dst_len, &req->dst_len, &sign);
207         if (ret)
208                 goto err_free_m;
209
210         if (sign < 0) {
211                 ret = -EBADMSG;
212                 goto err_free_m;
213         }
214
215 err_free_m:
216         mpi_free(m);
217 err_free_s:
218         mpi_free(s);
219         return ret;
220 }
221
222 static int rsa_verify(struct akcipher_request *req)
223 {
224         struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
225         const struct rsa_key *pkey = rsa_get_key(tfm);
226         MPI s, m = mpi_alloc(0);
227         int ret = 0;
228         int sign;
229
230         if (!m)
231                 return -ENOMEM;
232
233         if (unlikely(!pkey->n || !pkey->e)) {
234                 ret = -EINVAL;
235                 goto err_free_m;
236         }
237
238         if (req->dst_len < mpi_get_size(pkey->n)) {
239                 req->dst_len = mpi_get_size(pkey->n);
240                 ret = -EOVERFLOW;
241                 goto err_free_m;
242         }
243
244         s = mpi_read_raw_data(req->src, req->src_len);
245         if (!s) {
246                 ret = -ENOMEM;
247                 goto err_free_m;
248         }
249
250         ret = _rsa_verify(pkey, m, s);
251         if (ret)
252                 goto err_free_s;
253
254         ret = mpi_read_buffer(m, req->dst, req->dst_len, &req->dst_len, &sign);
255         if (ret)
256                 goto err_free_s;
257
258         if (sign < 0) {
259                 ret = -EBADMSG;
260                 goto err_free_s;
261         }
262
263 err_free_s:
264         mpi_free(s);
265 err_free_m:
266         mpi_free(m);
267         return ret;
268 }
269
270 static int rsa_check_key_length(unsigned int len)
271 {
272         switch (len) {
273         case 512:
274         case 1024:
275         case 1536:
276         case 2048:
277         case 3072:
278         case 4096:
279                 return 0;
280         }
281
282         return -EINVAL;
283 }
284
285 static int rsa_setkey(struct crypto_akcipher *tfm, const void *key,
286                       unsigned int keylen)
287 {
288         struct rsa_key *pkey = akcipher_tfm_ctx(tfm);
289         int ret;
290
291         ret = rsa_parse_key(pkey, key, keylen);
292         if (ret)
293                 return ret;
294
295         if (rsa_check_key_length(mpi_get_size(pkey->n) << 3)) {
296                 rsa_free_key(pkey);
297                 ret = -EINVAL;
298         }
299         return ret;
300 }
301
302 static void rsa_exit_tfm(struct crypto_akcipher *tfm)
303 {
304         struct rsa_key *pkey = akcipher_tfm_ctx(tfm);
305
306         rsa_free_key(pkey);
307 }
308
309 static struct akcipher_alg rsa = {
310         .encrypt = rsa_enc,
311         .decrypt = rsa_dec,
312         .sign = rsa_sign,
313         .verify = rsa_verify,
314         .setkey = rsa_setkey,
315         .exit = rsa_exit_tfm,
316         .base = {
317                 .cra_name = "rsa",
318                 .cra_driver_name = "rsa-generic",
319                 .cra_priority = 100,
320                 .cra_module = THIS_MODULE,
321                 .cra_ctxsize = sizeof(struct rsa_key),
322         },
323 };
324
325 static int rsa_init(void)
326 {
327         return crypto_register_akcipher(&rsa);
328 }
329
330 static void rsa_exit(void)
331 {
332         crypto_unregister_akcipher(&rsa);
333 }
334
335 module_init(rsa_init);
336 module_exit(rsa_exit);
337 MODULE_ALIAS_CRYPTO("rsa");
338 MODULE_LICENSE("GPL");
339 MODULE_DESCRIPTION("RSA generic algorithm");