[CRYPTO] ctr: Fix multi-page processing
[linux-drm-fsl-dcu.git] / crypto / ctr.c
index b816e959fa550a120e9f2c7add924788225ce5c8..57da7d0affcb3a6065f33c5f190402f4afac6df7 100644 (file)
@@ -59,6 +59,21 @@ static int crypto_ctr_setkey(struct crypto_tfm *parent, const u8 *key,
        return err;
 }
 
+static void crypto_ctr_crypt_final(struct blkcipher_walk *walk,
+                                  struct crypto_cipher *tfm, u8 *ctrblk,
+                                  unsigned int countersize)
+{
+       unsigned int bsize = crypto_cipher_blocksize(tfm);
+       u8 *keystream = ctrblk + bsize;
+       u8 *src = walk->src.virt.addr;
+       u8 *dst = walk->dst.virt.addr;
+       unsigned int nbytes = walk->nbytes;
+
+       crypto_cipher_encrypt_one(tfm, keystream, ctrblk);
+       crypto_xor(keystream, src, nbytes);
+       memcpy(dst, keystream, nbytes);
+}
+
 static int crypto_ctr_crypt_segment(struct blkcipher_walk *walk,
                                    struct crypto_cipher *tfm, u8 *ctrblk,
                                    unsigned int countersize)
@@ -66,35 +81,23 @@ static int crypto_ctr_crypt_segment(struct blkcipher_walk *walk,
        void (*fn)(struct crypto_tfm *, u8 *, const u8 *) =
                   crypto_cipher_alg(tfm)->cia_encrypt;
        unsigned int bsize = crypto_cipher_blocksize(tfm);
-       unsigned long alignmask = crypto_cipher_alignmask(tfm) |
-                                 (__alignof__(u32) - 1);
-       u8 ks[bsize + alignmask];
-       u8 *keystream = (u8 *)ALIGN((unsigned long)ks, alignmask + 1);
        u8 *src = walk->src.virt.addr;
        u8 *dst = walk->dst.virt.addr;
        unsigned int nbytes = walk->nbytes;
 
        do {
                /* create keystream */
-               fn(crypto_cipher_tfm(tfm), keystream, ctrblk);
-               crypto_xor(keystream, src, min(nbytes, bsize));
-
-               /* copy result into dst */
-               memcpy(dst, keystream, min(nbytes, bsize));
+               fn(crypto_cipher_tfm(tfm), dst, ctrblk);
+               crypto_xor(dst, src, bsize);
 
                /* increment counter in counterblock */
                crypto_inc(ctrblk + bsize - countersize, countersize);
 
-               if (nbytes < bsize)
-                       break;
-
                src += bsize;
                dst += bsize;
-               nbytes -= bsize;
-
-       } while (nbytes);
+       } while ((nbytes -= bsize) >= bsize);
 
-       return 0;
+       return nbytes;
 }
 
 static int crypto_ctr_crypt_inplace(struct blkcipher_walk *walk,
@@ -104,30 +107,22 @@ static int crypto_ctr_crypt_inplace(struct blkcipher_walk *walk,
        void (*fn)(struct crypto_tfm *, u8 *, const u8 *) =
                   crypto_cipher_alg(tfm)->cia_encrypt;
        unsigned int bsize = crypto_cipher_blocksize(tfm);
-       unsigned long alignmask = crypto_cipher_alignmask(tfm) |
-                                 (__alignof__(u32) - 1);
        unsigned int nbytes = walk->nbytes;
        u8 *src = walk->src.virt.addr;
-       u8 ks[bsize + alignmask];
-       u8 *keystream = (u8 *)ALIGN((unsigned long)ks, alignmask + 1);
+       u8 *keystream = ctrblk + bsize;
 
        do {
                /* create keystream */
                fn(crypto_cipher_tfm(tfm), keystream, ctrblk);
-               crypto_xor(src, keystream, min(nbytes, bsize));
+               crypto_xor(src, keystream, bsize);
 
                /* increment counter in counterblock */
                crypto_inc(ctrblk + bsize - countersize, countersize);
 
-               if (nbytes < bsize)
-                       break;
-
                src += bsize;
-               nbytes -= bsize;
+       } while ((nbytes -= bsize) >= bsize);
 
-       } while (nbytes);
-
-       return 0;
+       return nbytes;
 }
 
 static int crypto_ctr_crypt(struct blkcipher_desc *desc,
@@ -143,7 +138,7 @@ static int crypto_ctr_crypt(struct blkcipher_desc *desc,
                crypto_instance_ctx(crypto_tfm_alg_instance(&tfm->base));
        unsigned long alignmask = crypto_cipher_alignmask(child) |
                                  (__alignof__(u32) - 1);
-       u8 cblk[bsize + alignmask];
+       u8 cblk[bsize * 2 + alignmask];
        u8 *counterblk = (u8 *)ALIGN((unsigned long)cblk, alignmask + 1);
        int err;
 
@@ -158,7 +153,7 @@ static int crypto_ctr_crypt(struct blkcipher_desc *desc,
        /* initialize counter portion of counter block */
        crypto_inc(counterblk + bsize - ictx->countersize, ictx->countersize);
 
-       while (walk.nbytes) {
+       while (walk.nbytes >= bsize) {
                if (walk.src.virt.addr == walk.dst.virt.addr)
                        nbytes = crypto_ctr_crypt_inplace(&walk, child,
                                                          counterblk,
@@ -170,6 +165,13 @@ static int crypto_ctr_crypt(struct blkcipher_desc *desc,
 
                err = blkcipher_walk_done(desc, &walk, nbytes);
        }
+
+       if (walk.nbytes) {
+               crypto_ctr_crypt_final(&walk, child, counterblk,
+                                      ictx->countersize);
+               err = blkcipher_walk_done(desc, &walk, 0);
+       }
+
        return err;
 }
 
@@ -277,7 +279,7 @@ static struct crypto_instance *crypto_ctr_alloc(struct rtattr **tb)
        inst->alg.cra_flags = CRYPTO_ALG_TYPE_BLKCIPHER;
        inst->alg.cra_priority = alg->cra_priority;
        inst->alg.cra_blocksize = 1;
-       inst->alg.cra_alignmask = __alignof__(u32) - 1;
+       inst->alg.cra_alignmask = alg->cra_alignmask | (__alignof__(u32) - 1);
        inst->alg.cra_type = &crypto_blkcipher_type;
 
        inst->alg.cra_blkcipher.ivsize = ivsize;