crypto: af_alg - Disallow bind/setkey/... after accept(2)
[linux-drm-fsl-dcu.git] / crypto / af_alg.c
index a8e7aa3e257bbc3c6254b82d2966b9fde3031184..7b5b5926c767204cb1f851f3e4c069fb07867804 100644 (file)
@@ -125,6 +125,23 @@ int af_alg_release(struct socket *sock)
 }
 EXPORT_SYMBOL_GPL(af_alg_release);
 
+void af_alg_release_parent(struct sock *sk)
+{
+       struct alg_sock *ask = alg_sk(sk);
+       bool last;
+
+       sk = ask->parent;
+       ask = alg_sk(sk);
+
+       lock_sock(sk);
+       last = !--ask->refcnt;
+       release_sock(sk);
+
+       if (last)
+               sock_put(sk);
+}
+EXPORT_SYMBOL_GPL(af_alg_release_parent);
+
 static int alg_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
 {
        const u32 forbidden = CRYPTO_ALG_INTERNAL;
@@ -133,6 +150,7 @@ static int alg_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
        struct sockaddr_alg *sa = (void *)uaddr;
        const struct af_alg_type *type;
        void *private;
+       int err;
 
        if (sock->state == SS_CONNECTED)
                return -EINVAL;
@@ -160,16 +178,22 @@ static int alg_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
                return PTR_ERR(private);
        }
 
+       err = -EBUSY;
        lock_sock(sk);
+       if (ask->refcnt)
+               goto unlock;
 
        swap(ask->type, type);
        swap(ask->private, private);
 
+       err = 0;
+
+unlock:
        release_sock(sk);
 
        alg_do_release(type, private);
 
-       return 0;
+       return err;
 }
 
 static int alg_setkey(struct sock *sk, char __user *ukey,
@@ -202,11 +226,15 @@ static int alg_setsockopt(struct socket *sock, int level, int optname,
        struct sock *sk = sock->sk;
        struct alg_sock *ask = alg_sk(sk);
        const struct af_alg_type *type;
-       int err = -ENOPROTOOPT;
+       int err = -EBUSY;
 
        lock_sock(sk);
+       if (ask->refcnt)
+               goto unlock;
+
        type = ask->type;
 
+       err = -ENOPROTOOPT;
        if (level != SOL_ALG || !type)
                goto unlock;
 
@@ -264,7 +292,8 @@ int af_alg_accept(struct sock *sk, struct socket *newsock)
 
        sk2->sk_family = PF_ALG;
 
-       sock_hold(sk);
+       if (!ask->refcnt++)
+               sock_hold(sk);
        alg_sk(sk2)->parent = sk;
        alg_sk(sk2)->type = type;