netlink: Fix autobind race condition that leads to zero port ID
[linux-drm-fsl-dcu.git] / net / netlink / af_netlink.c
index 4cad99d6c68b8867d6662c1edcd9cb693040d5c3..9f51608b968afb4d6388e194ef6712346113aa49 100644 (file)
@@ -1031,7 +1031,7 @@ static inline int netlink_compare(struct rhashtable_compare_arg *arg,
        const struct netlink_compare_arg *x = arg->key;
        const struct netlink_sock *nlk = ptr;
 
-       return nlk->portid != x->portid ||
+       return nlk->rhash_portid != x->portid ||
               !net_eq(sock_net(&nlk->sk), read_pnet(&x->pnet));
 }
 
@@ -1057,7 +1057,7 @@ static int __netlink_insert(struct netlink_table *table, struct sock *sk)
 {
        struct netlink_compare_arg arg;
 
-       netlink_compare_arg_init(&arg, sock_net(sk), nlk_sk(sk)->portid);
+       netlink_compare_arg_init(&arg, sock_net(sk), nlk_sk(sk)->rhash_portid);
        return rhashtable_lookup_insert_key(&table->hash, &arg,
                                            &nlk_sk(sk)->node,
                                            netlink_rhashtable_params);
@@ -1119,7 +1119,7 @@ static int netlink_insert(struct sock *sk, u32 portid)
            unlikely(atomic_read(&table->hash.nelems) >= UINT_MAX))
                goto err;
 
-       nlk_sk(sk)->portid = portid;
+       nlk_sk(sk)->rhash_portid = portid;
        sock_hold(sk);
 
        err = __netlink_insert(table, sk);
@@ -1131,10 +1131,12 @@ static int netlink_insert(struct sock *sk, u32 portid)
                        err = -EOVERFLOW;
                if (err == -EEXIST)
                        err = -EADDRINUSE;
-               nlk_sk(sk)->portid = 0;
                sock_put(sk);
+               goto err;
        }
 
+       nlk_sk(sk)->portid = portid;
+
 err:
        release_sock(sk);
        return err;
@@ -3271,7 +3273,7 @@ static inline u32 netlink_hash(const void *data, u32 len, u32 seed)
        const struct netlink_sock *nlk = data;
        struct netlink_compare_arg arg;
 
-       netlink_compare_arg_init(&arg, sock_net(&nlk->sk), nlk->portid);
+       netlink_compare_arg_init(&arg, sock_net(&nlk->sk), nlk->rhash_portid);
        return jhash2((u32 *)&arg, netlink_compare_arg_len / sizeof(u32), seed);
 }