crypto: algif_skcipher - sendmsg SG marking is off by one
[linux-drm-fsl-dcu.git] / crypto / algif_skcipher.c
index ca9efe17db1ac4e9e2806528ea28d2d87a954b8f..a81c10faf9c4ea75e5d366e0bdf878b221cdd0a0 100644 (file)
@@ -31,6 +31,11 @@ struct skcipher_sg_list {
        struct scatterlist sg[0];
 };
 
+struct skcipher_tfm {
+       struct crypto_skcipher *skcipher;
+       bool has_key;
+};
+
 struct skcipher_ctx {
        struct list_head tsgl;
        struct af_alg_sgl rsgl;
@@ -40,14 +45,14 @@ struct skcipher_ctx {
        struct af_alg_completion completion;
 
        atomic_t inflight;
-       unsigned used;
+       size_t used;
 
        unsigned int len;
        bool more;
        bool merge;
        bool enc;
 
-       struct ablkcipher_request req;
+       struct skcipher_request req;
 };
 
 struct skcipher_async_rsgl {
@@ -64,13 +69,13 @@ struct skcipher_async_req {
 };
 
 #define GET_SREQ(areq, ctx) (struct skcipher_async_req *)((char *)areq + \
-       crypto_ablkcipher_reqsize(crypto_ablkcipher_reqtfm(&ctx->req)))
+       crypto_skcipher_reqsize(crypto_skcipher_reqtfm(&ctx->req)))
 
 #define GET_REQ_SIZE(ctx) \
-       crypto_ablkcipher_reqsize(crypto_ablkcipher_reqtfm(&ctx->req))
+       crypto_skcipher_reqsize(crypto_skcipher_reqtfm(&ctx->req))
 
 #define GET_IV_SIZE(ctx) \
-       crypto_ablkcipher_ivsize(crypto_ablkcipher_reqtfm(&ctx->req))
+       crypto_skcipher_ivsize(crypto_skcipher_reqtfm(&ctx->req))
 
 #define MAX_SGL_ENTS ((4096 - sizeof(struct skcipher_sg_list)) / \
                      sizeof(struct scatterlist) - 1)
@@ -153,7 +158,7 @@ static int skcipher_alloc_sgl(struct sock *sk)
        return 0;
 }
 
-static void skcipher_pull_sgl(struct sock *sk, int used, int put)
+static void skcipher_pull_sgl(struct sock *sk, size_t used, int put)
 {
        struct alg_sock *ask = alg_sk(sk);
        struct skcipher_ctx *ctx = ask->private;
@@ -167,7 +172,7 @@ static void skcipher_pull_sgl(struct sock *sk, int used, int put)
                sg = sgl->sg;
 
                for (i = 0; i < sgl->cur; i++) {
-                       int plen = min_t(int, used, sg[i].length);
+                       size_t plen = min_t(size_t, used, sg[i].length);
 
                        if (!sg_page(sg + i))
                                continue;
@@ -302,8 +307,8 @@ static int skcipher_sendmsg(struct socket *sock, struct msghdr *msg,
        struct sock *sk = sock->sk;
        struct alg_sock *ask = alg_sk(sk);
        struct skcipher_ctx *ctx = ask->private;
-       struct crypto_ablkcipher *tfm = crypto_ablkcipher_reqtfm(&ctx->req);
-       unsigned ivsize = crypto_ablkcipher_ivsize(tfm);
+       struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(&ctx->req);
+       unsigned ivsize = crypto_skcipher_ivsize(tfm);
        struct skcipher_sg_list *sgl;
        struct af_alg_control con = {};
        long copied = 0;
@@ -348,7 +353,7 @@ static int skcipher_sendmsg(struct socket *sock, struct msghdr *msg,
        while (size) {
                struct scatterlist *sg;
                unsigned long len = size;
-               int plen;
+               size_t plen;
 
                if (ctx->merge) {
                        sgl = list_entry(ctx->tsgl.prev,
@@ -387,10 +392,11 @@ static int skcipher_sendmsg(struct socket *sock, struct msghdr *msg,
 
                sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
                sg = sgl->sg;
-               sg_unmark_end(sg + sgl->cur);
+               if (sgl->cur)
+                       sg_unmark_end(sg + sgl->cur - 1);
                do {
                        i = sgl->cur;
-                       plen = min_t(int, len, PAGE_SIZE);
+                       plen = min_t(size_t, len, PAGE_SIZE);
 
                        sg_assign_page(sg + i, alloc_page(GFP_KERNEL));
                        err = -ENOMEM;
@@ -507,7 +513,7 @@ static int skcipher_recvmsg_async(struct socket *sock, struct msghdr *msg,
        struct skcipher_sg_list *sgl;
        struct scatterlist *sg;
        struct skcipher_async_req *sreq;
-       struct ablkcipher_request *req;
+       struct skcipher_request *req;
        struct skcipher_async_rsgl *last_rsgl = NULL;
        unsigned int txbufs = 0, len = 0, tx_nents = skcipher_all_sg_nents(ctx);
        unsigned int reqlen = sizeof(struct skcipher_async_req) +
@@ -531,9 +537,9 @@ static int skcipher_recvmsg_async(struct socket *sock, struct msghdr *msg,
        }
        sg_init_table(sreq->tsg, tx_nents);
        memcpy(sreq->iv, ctx->iv, GET_IV_SIZE(ctx));
-       ablkcipher_request_set_tfm(req, crypto_ablkcipher_reqtfm(&ctx->req));
-       ablkcipher_request_set_callback(req, CRYPTO_TFM_REQ_MAY_BACKLOG,
-                                       skcipher_async_cb, sk);
+       skcipher_request_set_tfm(req, crypto_skcipher_reqtfm(&ctx->req));
+       skcipher_request_set_callback(req, CRYPTO_TFM_REQ_MAY_BACKLOG,
+                                     skcipher_async_cb, sk);
 
        while (iov_iter_count(&msg->msg_iter)) {
                struct skcipher_async_rsgl *rsgl;
@@ -608,10 +614,10 @@ static int skcipher_recvmsg_async(struct socket *sock, struct msghdr *msg,
        if (mark)
                sg_mark_end(sreq->tsg + txbufs - 1);
 
-       ablkcipher_request_set_crypt(req, sreq->tsg, sreq->first_sgl.sgl.sg,
-                                    len, sreq->iv);
-       err = ctx->enc ? crypto_ablkcipher_encrypt(req) :
-                        crypto_ablkcipher_decrypt(req);
+       skcipher_request_set_crypt(req, sreq->tsg, sreq->first_sgl.sgl.sg,
+                                  len, sreq->iv);
+       err = ctx->enc ? crypto_skcipher_encrypt(req) :
+                        crypto_skcipher_decrypt(req);
        if (err == -EINPROGRESS) {
                atomic_inc(&ctx->inflight);
                err = -EIOCBQUEUED;
@@ -632,7 +638,7 @@ static int skcipher_recvmsg_sync(struct socket *sock, struct msghdr *msg,
        struct sock *sk = sock->sk;
        struct alg_sock *ask = alg_sk(sk);
        struct skcipher_ctx *ctx = ask->private;
-       unsigned bs = crypto_ablkcipher_blocksize(crypto_ablkcipher_reqtfm(
+       unsigned bs = crypto_skcipher_blocksize(crypto_skcipher_reqtfm(
                &ctx->req));
        struct skcipher_sg_list *sgl;
        struct scatterlist *sg;
@@ -642,13 +648,6 @@ static int skcipher_recvmsg_sync(struct socket *sock, struct msghdr *msg,
 
        lock_sock(sk);
        while (msg_data_left(msg)) {
-               sgl = list_first_entry(&ctx->tsgl,
-                                      struct skcipher_sg_list, list);
-               sg = sgl->sg;
-
-               while (!sg->length)
-                       sg++;
-
                if (!ctx->used) {
                        err = skcipher_wait_for_data(sk, flags);
                        if (err)
@@ -669,14 +668,20 @@ static int skcipher_recvmsg_sync(struct socket *sock, struct msghdr *msg,
                if (!used)
                        goto free;
 
-               ablkcipher_request_set_crypt(&ctx->req, sg,
-                                            ctx->rsgl.sg, used,
-                                            ctx->iv);
+               sgl = list_first_entry(&ctx->tsgl,
+                                      struct skcipher_sg_list, list);
+               sg = sgl->sg;
+
+               while (!sg->length)
+                       sg++;
+
+               skcipher_request_set_crypt(&ctx->req, sg, ctx->rsgl.sg, used,
+                                          ctx->iv);
 
                err = af_alg_wait_for_completion(
                                ctx->enc ?
-                                       crypto_ablkcipher_encrypt(&ctx->req) :
-                                       crypto_ablkcipher_decrypt(&ctx->req),
+                                       crypto_skcipher_encrypt(&ctx->req) :
+                                       crypto_skcipher_decrypt(&ctx->req),
                                &ctx->completion);
 
 free:
@@ -749,19 +754,139 @@ static struct proto_ops algif_skcipher_ops = {
        .poll           =       skcipher_poll,
 };
 
+static int skcipher_check_key(struct socket *sock)
+{
+       int err = 0;
+       struct sock *psk;
+       struct alg_sock *pask;
+       struct skcipher_tfm *tfm;
+       struct sock *sk = sock->sk;
+       struct alg_sock *ask = alg_sk(sk);
+
+       lock_sock(sk);
+       if (ask->refcnt)
+               goto unlock_child;
+
+       psk = ask->parent;
+       pask = alg_sk(ask->parent);
+       tfm = pask->private;
+
+       err = -ENOKEY;
+       lock_sock_nested(psk, SINGLE_DEPTH_NESTING);
+       if (!tfm->has_key)
+               goto unlock;
+
+       if (!pask->refcnt++)
+               sock_hold(psk);
+
+       ask->refcnt = 1;
+       sock_put(psk);
+
+       err = 0;
+
+unlock:
+       release_sock(psk);
+unlock_child:
+       release_sock(sk);
+
+       return err;
+}
+
+static int skcipher_sendmsg_nokey(struct socket *sock, struct msghdr *msg,
+                                 size_t size)
+{
+       int err;
+
+       err = skcipher_check_key(sock);
+       if (err)
+               return err;
+
+       return skcipher_sendmsg(sock, msg, size);
+}
+
+static ssize_t skcipher_sendpage_nokey(struct socket *sock, struct page *page,
+                                      int offset, size_t size, int flags)
+{
+       int err;
+
+       err = skcipher_check_key(sock);
+       if (err)
+               return err;
+
+       return skcipher_sendpage(sock, page, offset, size, flags);
+}
+
+static int skcipher_recvmsg_nokey(struct socket *sock, struct msghdr *msg,
+                                 size_t ignored, int flags)
+{
+       int err;
+
+       err = skcipher_check_key(sock);
+       if (err)
+               return err;
+
+       return skcipher_recvmsg(sock, msg, ignored, flags);
+}
+
+static struct proto_ops algif_skcipher_ops_nokey = {
+       .family         =       PF_ALG,
+
+       .connect        =       sock_no_connect,
+       .socketpair     =       sock_no_socketpair,
+       .getname        =       sock_no_getname,
+       .ioctl          =       sock_no_ioctl,
+       .listen         =       sock_no_listen,
+       .shutdown       =       sock_no_shutdown,
+       .getsockopt     =       sock_no_getsockopt,
+       .mmap           =       sock_no_mmap,
+       .bind           =       sock_no_bind,
+       .accept         =       sock_no_accept,
+       .setsockopt     =       sock_no_setsockopt,
+
+       .release        =       af_alg_release,
+       .sendmsg        =       skcipher_sendmsg_nokey,
+       .sendpage       =       skcipher_sendpage_nokey,
+       .recvmsg        =       skcipher_recvmsg_nokey,
+       .poll           =       skcipher_poll,
+};
+
 static void *skcipher_bind(const char *name, u32 type, u32 mask)
 {
-       return crypto_alloc_ablkcipher(name, type, mask);
+       struct skcipher_tfm *tfm;
+       struct crypto_skcipher *skcipher;
+
+       tfm = kzalloc(sizeof(*tfm), GFP_KERNEL);
+       if (!tfm)
+               return ERR_PTR(-ENOMEM);
+
+       skcipher = crypto_alloc_skcipher(name, type, mask);
+       if (IS_ERR(skcipher)) {
+               kfree(tfm);
+               return ERR_CAST(skcipher);
+       }
+
+       tfm->skcipher = skcipher;
+
+       return tfm;
 }
 
 static void skcipher_release(void *private)
 {
-       crypto_free_ablkcipher(private);
+       struct skcipher_tfm *tfm = private;
+
+       crypto_free_skcipher(tfm->skcipher);
+       kfree(tfm);
 }
 
 static int skcipher_setkey(void *private, const u8 *key, unsigned int keylen)
 {
-       return crypto_ablkcipher_setkey(private, key, keylen);
+       struct skcipher_tfm *tfm = private;
+       int err;
+
+       err = crypto_skcipher_setkey(tfm->skcipher, key, keylen);
+       tfm->has_key = !err;
+
+       return err;
 }
 
 static void skcipher_wait(struct sock *sk)
@@ -778,35 +903,37 @@ static void skcipher_sock_destruct(struct sock *sk)
 {
        struct alg_sock *ask = alg_sk(sk);
        struct skcipher_ctx *ctx = ask->private;
-       struct crypto_ablkcipher *tfm = crypto_ablkcipher_reqtfm(&ctx->req);
+       struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(&ctx->req);
 
        if (atomic_read(&ctx->inflight))
                skcipher_wait(sk);
 
        skcipher_free_sgl(sk);
-       sock_kzfree_s(sk, ctx->iv, crypto_ablkcipher_ivsize(tfm));
+       sock_kzfree_s(sk, ctx->iv, crypto_skcipher_ivsize(tfm));
        sock_kfree_s(sk, ctx, ctx->len);
        af_alg_release_parent(sk);
 }
 
-static int skcipher_accept_parent(void *private, struct sock *sk)
+static int skcipher_accept_parent_nokey(void *private, struct sock *sk)
 {
        struct skcipher_ctx *ctx;
        struct alg_sock *ask = alg_sk(sk);
-       unsigned int len = sizeof(*ctx) + crypto_ablkcipher_reqsize(private);
+       struct skcipher_tfm *tfm = private;
+       struct crypto_skcipher *skcipher = tfm->skcipher;
+       unsigned int len = sizeof(*ctx) + crypto_skcipher_reqsize(skcipher);
 
        ctx = sock_kmalloc(sk, len, GFP_KERNEL);
        if (!ctx)
                return -ENOMEM;
 
-       ctx->iv = sock_kmalloc(sk, crypto_ablkcipher_ivsize(private),
+       ctx->iv = sock_kmalloc(sk, crypto_skcipher_ivsize(skcipher),
                               GFP_KERNEL);
        if (!ctx->iv) {
                sock_kfree_s(sk, ctx, len);
                return -ENOMEM;
        }
 
-       memset(ctx->iv, 0, crypto_ablkcipher_ivsize(private));
+       memset(ctx->iv, 0, crypto_skcipher_ivsize(skcipher));
 
        INIT_LIST_HEAD(&ctx->tsgl);
        ctx->len = len;
@@ -819,21 +946,33 @@ static int skcipher_accept_parent(void *private, struct sock *sk)
 
        ask->private = ctx;
 
-       ablkcipher_request_set_tfm(&ctx->req, private);
-       ablkcipher_request_set_callback(&ctx->req, CRYPTO_TFM_REQ_MAY_BACKLOG,
-                                       af_alg_complete, &ctx->completion);
+       skcipher_request_set_tfm(&ctx->req, skcipher);
+       skcipher_request_set_callback(&ctx->req, CRYPTO_TFM_REQ_MAY_BACKLOG,
+                                     af_alg_complete, &ctx->completion);
 
        sk->sk_destruct = skcipher_sock_destruct;
 
        return 0;
 }
 
+static int skcipher_accept_parent(void *private, struct sock *sk)
+{
+       struct skcipher_tfm *tfm = private;
+
+       if (!tfm->has_key && crypto_skcipher_has_setkey(tfm->skcipher))
+               return -ENOKEY;
+
+       return skcipher_accept_parent_nokey(private, sk);
+}
+
 static const struct af_alg_type algif_type_skcipher = {
        .bind           =       skcipher_bind,
        .release        =       skcipher_release,
        .setkey         =       skcipher_setkey,
        .accept         =       skcipher_accept_parent,
+       .accept_nokey   =       skcipher_accept_parent_nokey,
        .ops            =       &algif_skcipher_ops,
+       .ops_nokey      =       &algif_skcipher_ops_nokey,
        .name           =       "skcipher",
        .owner          =       THIS_MODULE
 };