dfff8b0b56dfae05a8a0438655da24070044c54b
[linux-drm-fsl-dcu.git] / crypto / algif_skcipher.c
1 /*
2  * algif_skcipher: User-space interface for skcipher algorithms
3  *
4  * This file provides the user-space API for symmetric key ciphers.
5  *
6  * Copyright (c) 2010 Herbert Xu <herbert@gondor.apana.org.au>
7  *
8  * This program is free software; you can redistribute it and/or modify it
9  * under the terms of the GNU General Public License as published by the Free
10  * Software Foundation; either version 2 of the License, or (at your option)
11  * any later version.
12  *
13  */
14
15 #include <crypto/scatterwalk.h>
16 #include <crypto/skcipher.h>
17 #include <crypto/if_alg.h>
18 #include <linux/init.h>
19 #include <linux/list.h>
20 #include <linux/kernel.h>
21 #include <linux/mm.h>
22 #include <linux/module.h>
23 #include <linux/net.h>
24 #include <net/sock.h>
25
26 struct skcipher_sg_list {
27         struct list_head list;
28
29         int cur;
30
31         struct scatterlist sg[0];
32 };
33
34 struct skcipher_tfm {
35         struct crypto_skcipher *skcipher;
36         bool has_key;
37 };
38
39 struct skcipher_ctx {
40         struct list_head tsgl;
41         struct af_alg_sgl rsgl;
42
43         void *iv;
44
45         struct af_alg_completion completion;
46
47         atomic_t inflight;
48         size_t used;
49
50         unsigned int len;
51         bool more;
52         bool merge;
53         bool enc;
54
55         struct skcipher_request req;
56 };
57
58 struct skcipher_async_rsgl {
59         struct af_alg_sgl sgl;
60         struct list_head list;
61 };
62
63 struct skcipher_async_req {
64         struct kiocb *iocb;
65         struct skcipher_async_rsgl first_sgl;
66         struct list_head list;
67         struct scatterlist *tsg;
68         char iv[];
69 };
70
71 #define GET_SREQ(areq, ctx) (struct skcipher_async_req *)((char *)areq + \
72         crypto_skcipher_reqsize(crypto_skcipher_reqtfm(&ctx->req)))
73
74 #define GET_REQ_SIZE(ctx) \
75         crypto_skcipher_reqsize(crypto_skcipher_reqtfm(&ctx->req))
76
77 #define GET_IV_SIZE(ctx) \
78         crypto_skcipher_ivsize(crypto_skcipher_reqtfm(&ctx->req))
79
80 #define MAX_SGL_ENTS ((4096 - sizeof(struct skcipher_sg_list)) / \
81                       sizeof(struct scatterlist) - 1)
82
83 static void skcipher_free_async_sgls(struct skcipher_async_req *sreq)
84 {
85         struct skcipher_async_rsgl *rsgl, *tmp;
86         struct scatterlist *sgl;
87         struct scatterlist *sg;
88         int i, n;
89
90         list_for_each_entry_safe(rsgl, tmp, &sreq->list, list) {
91                 af_alg_free_sg(&rsgl->sgl);
92                 if (rsgl != &sreq->first_sgl)
93                         kfree(rsgl);
94         }
95         sgl = sreq->tsg;
96         n = sg_nents(sgl);
97         for_each_sg(sgl, sg, n, i)
98                 put_page(sg_page(sg));
99
100         kfree(sreq->tsg);
101 }
102
103 static void skcipher_async_cb(struct crypto_async_request *req, int err)
104 {
105         struct sock *sk = req->data;
106         struct alg_sock *ask = alg_sk(sk);
107         struct skcipher_ctx *ctx = ask->private;
108         struct skcipher_async_req *sreq = GET_SREQ(req, ctx);
109         struct kiocb *iocb = sreq->iocb;
110
111         atomic_dec(&ctx->inflight);
112         skcipher_free_async_sgls(sreq);
113         kfree(req);
114         iocb->ki_complete(iocb, err, err);
115 }
116
117 static inline int skcipher_sndbuf(struct sock *sk)
118 {
119         struct alg_sock *ask = alg_sk(sk);
120         struct skcipher_ctx *ctx = ask->private;
121
122         return max_t(int, max_t(int, sk->sk_sndbuf & PAGE_MASK, PAGE_SIZE) -
123                           ctx->used, 0);
124 }
125
126 static inline bool skcipher_writable(struct sock *sk)
127 {
128         return PAGE_SIZE <= skcipher_sndbuf(sk);
129 }
130
131 static int skcipher_alloc_sgl(struct sock *sk)
132 {
133         struct alg_sock *ask = alg_sk(sk);
134         struct skcipher_ctx *ctx = ask->private;
135         struct skcipher_sg_list *sgl;
136         struct scatterlist *sg = NULL;
137
138         sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
139         if (!list_empty(&ctx->tsgl))
140                 sg = sgl->sg;
141
142         if (!sg || sgl->cur >= MAX_SGL_ENTS) {
143                 sgl = sock_kmalloc(sk, sizeof(*sgl) +
144                                        sizeof(sgl->sg[0]) * (MAX_SGL_ENTS + 1),
145                                    GFP_KERNEL);
146                 if (!sgl)
147                         return -ENOMEM;
148
149                 sg_init_table(sgl->sg, MAX_SGL_ENTS + 1);
150                 sgl->cur = 0;
151
152                 if (sg)
153                         sg_chain(sg, MAX_SGL_ENTS + 1, sgl->sg);
154
155                 list_add_tail(&sgl->list, &ctx->tsgl);
156         }
157
158         return 0;
159 }
160
161 static void skcipher_pull_sgl(struct sock *sk, size_t used, int put)
162 {
163         struct alg_sock *ask = alg_sk(sk);
164         struct skcipher_ctx *ctx = ask->private;
165         struct skcipher_sg_list *sgl;
166         struct scatterlist *sg;
167         int i;
168
169         while (!list_empty(&ctx->tsgl)) {
170                 sgl = list_first_entry(&ctx->tsgl, struct skcipher_sg_list,
171                                        list);
172                 sg = sgl->sg;
173
174                 for (i = 0; i < sgl->cur; i++) {
175                         size_t plen = min_t(size_t, used, sg[i].length);
176
177                         if (!sg_page(sg + i))
178                                 continue;
179
180                         sg[i].length -= plen;
181                         sg[i].offset += plen;
182
183                         used -= plen;
184                         ctx->used -= plen;
185
186                         if (sg[i].length)
187                                 return;
188                         if (put)
189                                 put_page(sg_page(sg + i));
190                         sg_assign_page(sg + i, NULL);
191                 }
192
193                 list_del(&sgl->list);
194                 sock_kfree_s(sk, sgl,
195                              sizeof(*sgl) + sizeof(sgl->sg[0]) *
196                                             (MAX_SGL_ENTS + 1));
197         }
198
199         if (!ctx->used)
200                 ctx->merge = 0;
201 }
202
203 static void skcipher_free_sgl(struct sock *sk)
204 {
205         struct alg_sock *ask = alg_sk(sk);
206         struct skcipher_ctx *ctx = ask->private;
207
208         skcipher_pull_sgl(sk, ctx->used, 1);
209 }
210
211 static int skcipher_wait_for_wmem(struct sock *sk, unsigned flags)
212 {
213         long timeout;
214         DEFINE_WAIT(wait);
215         int err = -ERESTARTSYS;
216
217         if (flags & MSG_DONTWAIT)
218                 return -EAGAIN;
219
220         sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk);
221
222         for (;;) {
223                 if (signal_pending(current))
224                         break;
225                 prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
226                 timeout = MAX_SCHEDULE_TIMEOUT;
227                 if (sk_wait_event(sk, &timeout, skcipher_writable(sk))) {
228                         err = 0;
229                         break;
230                 }
231         }
232         finish_wait(sk_sleep(sk), &wait);
233
234         return err;
235 }
236
237 static void skcipher_wmem_wakeup(struct sock *sk)
238 {
239         struct socket_wq *wq;
240
241         if (!skcipher_writable(sk))
242                 return;
243
244         rcu_read_lock();
245         wq = rcu_dereference(sk->sk_wq);
246         if (wq_has_sleeper(wq))
247                 wake_up_interruptible_sync_poll(&wq->wait, POLLIN |
248                                                            POLLRDNORM |
249                                                            POLLRDBAND);
250         sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN);
251         rcu_read_unlock();
252 }
253
254 static int skcipher_wait_for_data(struct sock *sk, unsigned flags)
255 {
256         struct alg_sock *ask = alg_sk(sk);
257         struct skcipher_ctx *ctx = ask->private;
258         long timeout;
259         DEFINE_WAIT(wait);
260         int err = -ERESTARTSYS;
261
262         if (flags & MSG_DONTWAIT) {
263                 return -EAGAIN;
264         }
265
266         sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
267
268         for (;;) {
269                 if (signal_pending(current))
270                         break;
271                 prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
272                 timeout = MAX_SCHEDULE_TIMEOUT;
273                 if (sk_wait_event(sk, &timeout, ctx->used)) {
274                         err = 0;
275                         break;
276                 }
277         }
278         finish_wait(sk_sleep(sk), &wait);
279
280         sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
281
282         return err;
283 }
284
285 static void skcipher_data_wakeup(struct sock *sk)
286 {
287         struct alg_sock *ask = alg_sk(sk);
288         struct skcipher_ctx *ctx = ask->private;
289         struct socket_wq *wq;
290
291         if (!ctx->used)
292                 return;
293
294         rcu_read_lock();
295         wq = rcu_dereference(sk->sk_wq);
296         if (wq_has_sleeper(wq))
297                 wake_up_interruptible_sync_poll(&wq->wait, POLLOUT |
298                                                            POLLRDNORM |
299                                                            POLLRDBAND);
300         sk_wake_async(sk, SOCK_WAKE_SPACE, POLL_OUT);
301         rcu_read_unlock();
302 }
303
304 static int skcipher_sendmsg(struct socket *sock, struct msghdr *msg,
305                             size_t size)
306 {
307         struct sock *sk = sock->sk;
308         struct alg_sock *ask = alg_sk(sk);
309         struct skcipher_ctx *ctx = ask->private;
310         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(&ctx->req);
311         unsigned ivsize = crypto_skcipher_ivsize(tfm);
312         struct skcipher_sg_list *sgl;
313         struct af_alg_control con = {};
314         long copied = 0;
315         bool enc = 0;
316         bool init = 0;
317         int err;
318         int i;
319
320         if (msg->msg_controllen) {
321                 err = af_alg_cmsg_send(msg, &con);
322                 if (err)
323                         return err;
324
325                 init = 1;
326                 switch (con.op) {
327                 case ALG_OP_ENCRYPT:
328                         enc = 1;
329                         break;
330                 case ALG_OP_DECRYPT:
331                         enc = 0;
332                         break;
333                 default:
334                         return -EINVAL;
335                 }
336
337                 if (con.iv && con.iv->ivlen != ivsize)
338                         return -EINVAL;
339         }
340
341         err = -EINVAL;
342
343         lock_sock(sk);
344         if (!ctx->more && ctx->used)
345                 goto unlock;
346
347         if (init) {
348                 ctx->enc = enc;
349                 if (con.iv)
350                         memcpy(ctx->iv, con.iv->iv, ivsize);
351         }
352
353         while (size) {
354                 struct scatterlist *sg;
355                 unsigned long len = size;
356                 size_t plen;
357
358                 if (ctx->merge) {
359                         sgl = list_entry(ctx->tsgl.prev,
360                                          struct skcipher_sg_list, list);
361                         sg = sgl->sg + sgl->cur - 1;
362                         len = min_t(unsigned long, len,
363                                     PAGE_SIZE - sg->offset - sg->length);
364
365                         err = memcpy_from_msg(page_address(sg_page(sg)) +
366                                               sg->offset + sg->length,
367                                               msg, len);
368                         if (err)
369                                 goto unlock;
370
371                         sg->length += len;
372                         ctx->merge = (sg->offset + sg->length) &
373                                      (PAGE_SIZE - 1);
374
375                         ctx->used += len;
376                         copied += len;
377                         size -= len;
378                         continue;
379                 }
380
381                 if (!skcipher_writable(sk)) {
382                         err = skcipher_wait_for_wmem(sk, msg->msg_flags);
383                         if (err)
384                                 goto unlock;
385                 }
386
387                 len = min_t(unsigned long, len, skcipher_sndbuf(sk));
388
389                 err = skcipher_alloc_sgl(sk);
390                 if (err)
391                         goto unlock;
392
393                 sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
394                 sg = sgl->sg;
395                 sg_unmark_end(sg + sgl->cur);
396                 do {
397                         i = sgl->cur;
398                         plen = min_t(size_t, len, PAGE_SIZE);
399
400                         sg_assign_page(sg + i, alloc_page(GFP_KERNEL));
401                         err = -ENOMEM;
402                         if (!sg_page(sg + i))
403                                 goto unlock;
404
405                         err = memcpy_from_msg(page_address(sg_page(sg + i)),
406                                               msg, plen);
407                         if (err) {
408                                 __free_page(sg_page(sg + i));
409                                 sg_assign_page(sg + i, NULL);
410                                 goto unlock;
411                         }
412
413                         sg[i].length = plen;
414                         len -= plen;
415                         ctx->used += plen;
416                         copied += plen;
417                         size -= plen;
418                         sgl->cur++;
419                 } while (len && sgl->cur < MAX_SGL_ENTS);
420
421                 if (!size)
422                         sg_mark_end(sg + sgl->cur - 1);
423
424                 ctx->merge = plen & (PAGE_SIZE - 1);
425         }
426
427         err = 0;
428
429         ctx->more = msg->msg_flags & MSG_MORE;
430
431 unlock:
432         skcipher_data_wakeup(sk);
433         release_sock(sk);
434
435         return copied ?: err;
436 }
437
438 static ssize_t skcipher_sendpage(struct socket *sock, struct page *page,
439                                  int offset, size_t size, int flags)
440 {
441         struct sock *sk = sock->sk;
442         struct alg_sock *ask = alg_sk(sk);
443         struct skcipher_ctx *ctx = ask->private;
444         struct skcipher_sg_list *sgl;
445         int err = -EINVAL;
446
447         if (flags & MSG_SENDPAGE_NOTLAST)
448                 flags |= MSG_MORE;
449
450         lock_sock(sk);
451         if (!ctx->more && ctx->used)
452                 goto unlock;
453
454         if (!size)
455                 goto done;
456
457         if (!skcipher_writable(sk)) {
458                 err = skcipher_wait_for_wmem(sk, flags);
459                 if (err)
460                         goto unlock;
461         }
462
463         err = skcipher_alloc_sgl(sk);
464         if (err)
465                 goto unlock;
466
467         ctx->merge = 0;
468         sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
469
470         if (sgl->cur)
471                 sg_unmark_end(sgl->sg + sgl->cur - 1);
472
473         sg_mark_end(sgl->sg + sgl->cur);
474         get_page(page);
475         sg_set_page(sgl->sg + sgl->cur, page, size, offset);
476         sgl->cur++;
477         ctx->used += size;
478
479 done:
480         ctx->more = flags & MSG_MORE;
481
482 unlock:
483         skcipher_data_wakeup(sk);
484         release_sock(sk);
485
486         return err ?: size;
487 }
488
489 static int skcipher_all_sg_nents(struct skcipher_ctx *ctx)
490 {
491         struct skcipher_sg_list *sgl;
492         struct scatterlist *sg;
493         int nents = 0;
494
495         list_for_each_entry(sgl, &ctx->tsgl, list) {
496                 sg = sgl->sg;
497
498                 while (!sg->length)
499                         sg++;
500
501                 nents += sg_nents(sg);
502         }
503         return nents;
504 }
505
506 static int skcipher_recvmsg_async(struct socket *sock, struct msghdr *msg,
507                                   int flags)
508 {
509         struct sock *sk = sock->sk;
510         struct alg_sock *ask = alg_sk(sk);
511         struct skcipher_ctx *ctx = ask->private;
512         struct skcipher_sg_list *sgl;
513         struct scatterlist *sg;
514         struct skcipher_async_req *sreq;
515         struct skcipher_request *req;
516         struct skcipher_async_rsgl *last_rsgl = NULL;
517         unsigned int txbufs = 0, len = 0, tx_nents = skcipher_all_sg_nents(ctx);
518         unsigned int reqlen = sizeof(struct skcipher_async_req) +
519                                 GET_REQ_SIZE(ctx) + GET_IV_SIZE(ctx);
520         int err = -ENOMEM;
521         bool mark = false;
522
523         lock_sock(sk);
524         req = kmalloc(reqlen, GFP_KERNEL);
525         if (unlikely(!req))
526                 goto unlock;
527
528         sreq = GET_SREQ(req, ctx);
529         sreq->iocb = msg->msg_iocb;
530         memset(&sreq->first_sgl, '\0', sizeof(struct skcipher_async_rsgl));
531         INIT_LIST_HEAD(&sreq->list);
532         sreq->tsg = kcalloc(tx_nents, sizeof(*sg), GFP_KERNEL);
533         if (unlikely(!sreq->tsg)) {
534                 kfree(req);
535                 goto unlock;
536         }
537         sg_init_table(sreq->tsg, tx_nents);
538         memcpy(sreq->iv, ctx->iv, GET_IV_SIZE(ctx));
539         skcipher_request_set_tfm(req, crypto_skcipher_reqtfm(&ctx->req));
540         skcipher_request_set_callback(req, CRYPTO_TFM_REQ_MAY_BACKLOG,
541                                       skcipher_async_cb, sk);
542
543         while (iov_iter_count(&msg->msg_iter)) {
544                 struct skcipher_async_rsgl *rsgl;
545                 int used;
546
547                 if (!ctx->used) {
548                         err = skcipher_wait_for_data(sk, flags);
549                         if (err)
550                                 goto free;
551                 }
552                 sgl = list_first_entry(&ctx->tsgl,
553                                        struct skcipher_sg_list, list);
554                 sg = sgl->sg;
555
556                 while (!sg->length)
557                         sg++;
558
559                 used = min_t(unsigned long, ctx->used,
560                              iov_iter_count(&msg->msg_iter));
561                 used = min_t(unsigned long, used, sg->length);
562
563                 if (txbufs == tx_nents) {
564                         struct scatterlist *tmp;
565                         int x;
566                         /* Ran out of tx slots in async request
567                          * need to expand */
568                         tmp = kcalloc(tx_nents * 2, sizeof(*tmp),
569                                       GFP_KERNEL);
570                         if (!tmp)
571                                 goto free;
572
573                         sg_init_table(tmp, tx_nents * 2);
574                         for (x = 0; x < tx_nents; x++)
575                                 sg_set_page(&tmp[x], sg_page(&sreq->tsg[x]),
576                                             sreq->tsg[x].length,
577                                             sreq->tsg[x].offset);
578                         kfree(sreq->tsg);
579                         sreq->tsg = tmp;
580                         tx_nents *= 2;
581                         mark = true;
582                 }
583                 /* Need to take over the tx sgl from ctx
584                  * to the asynch req - these sgls will be freed later */
585                 sg_set_page(sreq->tsg + txbufs++, sg_page(sg), sg->length,
586                             sg->offset);
587
588                 if (list_empty(&sreq->list)) {
589                         rsgl = &sreq->first_sgl;
590                         list_add_tail(&rsgl->list, &sreq->list);
591                 } else {
592                         rsgl = kmalloc(sizeof(*rsgl), GFP_KERNEL);
593                         if (!rsgl) {
594                                 err = -ENOMEM;
595                                 goto free;
596                         }
597                         list_add_tail(&rsgl->list, &sreq->list);
598                 }
599
600                 used = af_alg_make_sg(&rsgl->sgl, &msg->msg_iter, used);
601                 err = used;
602                 if (used < 0)
603                         goto free;
604                 if (last_rsgl)
605                         af_alg_link_sg(&last_rsgl->sgl, &rsgl->sgl);
606
607                 last_rsgl = rsgl;
608                 len += used;
609                 skcipher_pull_sgl(sk, used, 0);
610                 iov_iter_advance(&msg->msg_iter, used);
611         }
612
613         if (mark)
614                 sg_mark_end(sreq->tsg + txbufs - 1);
615
616         skcipher_request_set_crypt(req, sreq->tsg, sreq->first_sgl.sgl.sg,
617                                    len, sreq->iv);
618         err = ctx->enc ? crypto_skcipher_encrypt(req) :
619                          crypto_skcipher_decrypt(req);
620         if (err == -EINPROGRESS) {
621                 atomic_inc(&ctx->inflight);
622                 err = -EIOCBQUEUED;
623                 goto unlock;
624         }
625 free:
626         skcipher_free_async_sgls(sreq);
627         kfree(req);
628 unlock:
629         skcipher_wmem_wakeup(sk);
630         release_sock(sk);
631         return err;
632 }
633
634 static int skcipher_recvmsg_sync(struct socket *sock, struct msghdr *msg,
635                                  int flags)
636 {
637         struct sock *sk = sock->sk;
638         struct alg_sock *ask = alg_sk(sk);
639         struct skcipher_ctx *ctx = ask->private;
640         unsigned bs = crypto_skcipher_blocksize(crypto_skcipher_reqtfm(
641                 &ctx->req));
642         struct skcipher_sg_list *sgl;
643         struct scatterlist *sg;
644         int err = -EAGAIN;
645         int used;
646         long copied = 0;
647
648         lock_sock(sk);
649         while (msg_data_left(msg)) {
650                 sgl = list_first_entry(&ctx->tsgl,
651                                        struct skcipher_sg_list, list);
652                 sg = sgl->sg;
653
654                 while (!sg->length)
655                         sg++;
656
657                 if (!ctx->used) {
658                         err = skcipher_wait_for_data(sk, flags);
659                         if (err)
660                                 goto unlock;
661                 }
662
663                 used = min_t(unsigned long, ctx->used, msg_data_left(msg));
664
665                 used = af_alg_make_sg(&ctx->rsgl, &msg->msg_iter, used);
666                 err = used;
667                 if (err < 0)
668                         goto unlock;
669
670                 if (ctx->more || used < ctx->used)
671                         used -= used % bs;
672
673                 err = -EINVAL;
674                 if (!used)
675                         goto free;
676
677                 skcipher_request_set_crypt(&ctx->req, sg, ctx->rsgl.sg, used,
678                                            ctx->iv);
679
680                 err = af_alg_wait_for_completion(
681                                 ctx->enc ?
682                                         crypto_skcipher_encrypt(&ctx->req) :
683                                         crypto_skcipher_decrypt(&ctx->req),
684                                 &ctx->completion);
685
686 free:
687                 af_alg_free_sg(&ctx->rsgl);
688
689                 if (err)
690                         goto unlock;
691
692                 copied += used;
693                 skcipher_pull_sgl(sk, used, 1);
694                 iov_iter_advance(&msg->msg_iter, used);
695         }
696
697         err = 0;
698
699 unlock:
700         skcipher_wmem_wakeup(sk);
701         release_sock(sk);
702
703         return copied ?: err;
704 }
705
706 static int skcipher_recvmsg(struct socket *sock, struct msghdr *msg,
707                             size_t ignored, int flags)
708 {
709         return (msg->msg_iocb && !is_sync_kiocb(msg->msg_iocb)) ?
710                 skcipher_recvmsg_async(sock, msg, flags) :
711                 skcipher_recvmsg_sync(sock, msg, flags);
712 }
713
714 static unsigned int skcipher_poll(struct file *file, struct socket *sock,
715                                   poll_table *wait)
716 {
717         struct sock *sk = sock->sk;
718         struct alg_sock *ask = alg_sk(sk);
719         struct skcipher_ctx *ctx = ask->private;
720         unsigned int mask;
721
722         sock_poll_wait(file, sk_sleep(sk), wait);
723         mask = 0;
724
725         if (ctx->used)
726                 mask |= POLLIN | POLLRDNORM;
727
728         if (skcipher_writable(sk))
729                 mask |= POLLOUT | POLLWRNORM | POLLWRBAND;
730
731         return mask;
732 }
733
734 static struct proto_ops algif_skcipher_ops = {
735         .family         =       PF_ALG,
736
737         .connect        =       sock_no_connect,
738         .socketpair     =       sock_no_socketpair,
739         .getname        =       sock_no_getname,
740         .ioctl          =       sock_no_ioctl,
741         .listen         =       sock_no_listen,
742         .shutdown       =       sock_no_shutdown,
743         .getsockopt     =       sock_no_getsockopt,
744         .mmap           =       sock_no_mmap,
745         .bind           =       sock_no_bind,
746         .accept         =       sock_no_accept,
747         .setsockopt     =       sock_no_setsockopt,
748
749         .release        =       af_alg_release,
750         .sendmsg        =       skcipher_sendmsg,
751         .sendpage       =       skcipher_sendpage,
752         .recvmsg        =       skcipher_recvmsg,
753         .poll           =       skcipher_poll,
754 };
755
756 static int skcipher_check_key(struct socket *sock)
757 {
758         int err = 0;
759         struct sock *psk;
760         struct alg_sock *pask;
761         struct skcipher_tfm *tfm;
762         struct sock *sk = sock->sk;
763         struct alg_sock *ask = alg_sk(sk);
764
765         lock_sock(sk);
766         if (ask->refcnt)
767                 goto unlock_child;
768
769         psk = ask->parent;
770         pask = alg_sk(ask->parent);
771         tfm = pask->private;
772
773         err = -ENOKEY;
774         lock_sock_nested(psk, SINGLE_DEPTH_NESTING);
775         if (!tfm->has_key)
776                 goto unlock;
777
778         if (!pask->refcnt++)
779                 sock_hold(psk);
780
781         ask->refcnt = 1;
782         sock_put(psk);
783
784         err = 0;
785
786 unlock:
787         release_sock(psk);
788 unlock_child:
789         release_sock(sk);
790
791         return err;
792 }
793
794 static int skcipher_sendmsg_nokey(struct socket *sock, struct msghdr *msg,
795                                   size_t size)
796 {
797         int err;
798
799         err = skcipher_check_key(sock);
800         if (err)
801                 return err;
802
803         return skcipher_sendmsg(sock, msg, size);
804 }
805
806 static ssize_t skcipher_sendpage_nokey(struct socket *sock, struct page *page,
807                                        int offset, size_t size, int flags)
808 {
809         int err;
810
811         err = skcipher_check_key(sock);
812         if (err)
813                 return err;
814
815         return skcipher_sendpage(sock, page, offset, size, flags);
816 }
817
818 static int skcipher_recvmsg_nokey(struct socket *sock, struct msghdr *msg,
819                                   size_t ignored, int flags)
820 {
821         int err;
822
823         err = skcipher_check_key(sock);
824         if (err)
825                 return err;
826
827         return skcipher_recvmsg(sock, msg, ignored, flags);
828 }
829
830 static struct proto_ops algif_skcipher_ops_nokey = {
831         .family         =       PF_ALG,
832
833         .connect        =       sock_no_connect,
834         .socketpair     =       sock_no_socketpair,
835         .getname        =       sock_no_getname,
836         .ioctl          =       sock_no_ioctl,
837         .listen         =       sock_no_listen,
838         .shutdown       =       sock_no_shutdown,
839         .getsockopt     =       sock_no_getsockopt,
840         .mmap           =       sock_no_mmap,
841         .bind           =       sock_no_bind,
842         .accept         =       sock_no_accept,
843         .setsockopt     =       sock_no_setsockopt,
844
845         .release        =       af_alg_release,
846         .sendmsg        =       skcipher_sendmsg_nokey,
847         .sendpage       =       skcipher_sendpage_nokey,
848         .recvmsg        =       skcipher_recvmsg_nokey,
849         .poll           =       skcipher_poll,
850 };
851
852 static void *skcipher_bind(const char *name, u32 type, u32 mask)
853 {
854         struct skcipher_tfm *tfm;
855         struct crypto_skcipher *skcipher;
856
857         tfm = kzalloc(sizeof(*tfm), GFP_KERNEL);
858         if (!tfm)
859                 return ERR_PTR(-ENOMEM);
860
861         skcipher = crypto_alloc_skcipher(name, type, mask);
862         if (IS_ERR(skcipher)) {
863                 kfree(tfm);
864                 return ERR_CAST(skcipher);
865         }
866
867         tfm->skcipher = skcipher;
868
869         return tfm;
870 }
871
872 static void skcipher_release(void *private)
873 {
874         struct skcipher_tfm *tfm = private;
875
876         crypto_free_skcipher(tfm->skcipher);
877         kfree(tfm);
878 }
879
880 static int skcipher_setkey(void *private, const u8 *key, unsigned int keylen)
881 {
882         struct skcipher_tfm *tfm = private;
883         int err;
884
885         err = crypto_skcipher_setkey(tfm->skcipher, key, keylen);
886         tfm->has_key = !err;
887
888         return err;
889 }
890
891 static void skcipher_wait(struct sock *sk)
892 {
893         struct alg_sock *ask = alg_sk(sk);
894         struct skcipher_ctx *ctx = ask->private;
895         int ctr = 0;
896
897         while (atomic_read(&ctx->inflight) && ctr++ < 100)
898                 msleep(100);
899 }
900
901 static void skcipher_sock_destruct(struct sock *sk)
902 {
903         struct alg_sock *ask = alg_sk(sk);
904         struct skcipher_ctx *ctx = ask->private;
905         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(&ctx->req);
906
907         if (atomic_read(&ctx->inflight))
908                 skcipher_wait(sk);
909
910         skcipher_free_sgl(sk);
911         sock_kzfree_s(sk, ctx->iv, crypto_skcipher_ivsize(tfm));
912         sock_kfree_s(sk, ctx, ctx->len);
913         af_alg_release_parent(sk);
914 }
915
916 static int skcipher_accept_parent_nokey(void *private, struct sock *sk)
917 {
918         struct skcipher_ctx *ctx;
919         struct alg_sock *ask = alg_sk(sk);
920         struct skcipher_tfm *tfm = private;
921         struct crypto_skcipher *skcipher = tfm->skcipher;
922         unsigned int len = sizeof(*ctx) + crypto_skcipher_reqsize(skcipher);
923
924         ctx = sock_kmalloc(sk, len, GFP_KERNEL);
925         if (!ctx)
926                 return -ENOMEM;
927
928         ctx->iv = sock_kmalloc(sk, crypto_skcipher_ivsize(skcipher),
929                                GFP_KERNEL);
930         if (!ctx->iv) {
931                 sock_kfree_s(sk, ctx, len);
932                 return -ENOMEM;
933         }
934
935         memset(ctx->iv, 0, crypto_skcipher_ivsize(skcipher));
936
937         INIT_LIST_HEAD(&ctx->tsgl);
938         ctx->len = len;
939         ctx->used = 0;
940         ctx->more = 0;
941         ctx->merge = 0;
942         ctx->enc = 0;
943         atomic_set(&ctx->inflight, 0);
944         af_alg_init_completion(&ctx->completion);
945
946         ask->private = ctx;
947
948         skcipher_request_set_tfm(&ctx->req, skcipher);
949         skcipher_request_set_callback(&ctx->req, CRYPTO_TFM_REQ_MAY_BACKLOG,
950                                       af_alg_complete, &ctx->completion);
951
952         sk->sk_destruct = skcipher_sock_destruct;
953
954         return 0;
955 }
956
957 static int skcipher_accept_parent(void *private, struct sock *sk)
958 {
959         struct skcipher_tfm *tfm = private;
960
961         if (!tfm->has_key && crypto_skcipher_has_setkey(tfm->skcipher))
962                 return -ENOKEY;
963
964         return skcipher_accept_parent_nokey(private, sk);
965 }
966
967 static const struct af_alg_type algif_type_skcipher = {
968         .bind           =       skcipher_bind,
969         .release        =       skcipher_release,
970         .setkey         =       skcipher_setkey,
971         .accept         =       skcipher_accept_parent,
972         .accept_nokey   =       skcipher_accept_parent_nokey,
973         .ops            =       &algif_skcipher_ops,
974         .ops_nokey      =       &algif_skcipher_ops_nokey,
975         .name           =       "skcipher",
976         .owner          =       THIS_MODULE
977 };
978
979 static int __init algif_skcipher_init(void)
980 {
981         return af_alg_register_type(&algif_type_skcipher);
982 }
983
984 static void __exit algif_skcipher_exit(void)
985 {
986         int err = af_alg_unregister_type(&algif_type_skcipher);
987         BUG_ON(err);
988 }
989
990 module_init(algif_skcipher_init);
991 module_exit(algif_skcipher_exit);
992 MODULE_LICENSE("GPL");