diff --git a/net/netfilter/nfnetlink_cttimeout.c b/net/netfilter/nfnetlink_cttimeout.c index 4cdcd969b64c..68216cdc7083 100644 --- a/net/netfilter/nfnetlink_cttimeout.c +++ b/net/netfilter/nfnetlink_cttimeout.c @@ -330,16 +330,16 @@ static int ctnl_timeout_try_del(struct net *net, struct ctnl_timeout *timeout) { int ret = 0; - /* we want to avoid races with nf_ct_timeout_find_get. */ - if (atomic_dec_and_test(&timeout->refcnt)) { + /* We want to avoid races with ctnl_timeout_put. So only when the + * current refcnt is 1, we decrease it to 0. + */ + if (atomic_cmpxchg(&timeout->refcnt, 1, 0) == 1) { /* We are protected by nfnl mutex. */ list_del_rcu(&timeout->head); nf_ct_l4proto_put(timeout->l4proto); ctnl_untimeout(net, timeout); kfree_rcu(timeout, rcu_head); } else { - /* still in use, restore reference counter. */ - atomic_inc(&timeout->refcnt); ret = -EBUSY; } return ret; @@ -543,7 +543,9 @@ ctnl_timeout_find_get(struct net *net, const char *name) static void ctnl_timeout_put(struct ctnl_timeout *timeout) { - atomic_dec(&timeout->refcnt); + if (atomic_dec_and_test(&timeout->refcnt)) + kfree_rcu(timeout, rcu_head); + module_put(THIS_MODULE); } #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ @@ -591,7 +593,9 @@ static void __net_exit cttimeout_net_exit(struct net *net) list_for_each_entry_safe(cur, tmp, &net->nfct_timeout_list, head) { list_del_rcu(&cur->head); nf_ct_l4proto_put(cur->l4proto); - kfree_rcu(cur, rcu_head); + + if (atomic_dec_and_test(&cur->refcnt)) + kfree_rcu(cur, rcu_head); } }