diff options
Diffstat (limited to 'net/netlink/af_netlink.c')
-rw-r--r-- | net/netlink/af_netlink.c | 33 |
1 files changed, 20 insertions, 13 deletions
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c index 298e1df7132a..01b702d63457 100644 --- a/net/netlink/af_netlink.c +++ b/net/netlink/af_netlink.c @@ -98,7 +98,7 @@ static void netlink_skb_destructor(struct sk_buff *skb); /* nl_table locking explained: * Lookup and traversal are protected with an RCU read-side lock. Insertion - * and removal are protected with nl_sk_hash_lock while using RCU list + * and removal are protected with per bucket lock while using RCU list * modification primitives and may run in parallel to RCU protected lookups. * Destruction of the Netlink socket may only occur *after* nl_table_lock has * been acquired * either during or after the socket has been removed from @@ -110,10 +110,6 @@ static atomic_t nl_table_users = ATOMIC_INIT(0); #define nl_deref_protected(X) rcu_dereference_protected(X, lockdep_is_held(&nl_table_lock)); -/* Protects netlink socket hash table mutations */ -DEFINE_MUTEX(nl_sk_hash_lock); -EXPORT_SYMBOL_GPL(nl_sk_hash_lock); - static ATOMIC_NOTIFIER_HEAD(netlink_chain); static DEFINE_SPINLOCK(netlink_tap_lock); @@ -998,6 +994,19 @@ static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid, &netlink_compare, &arg); } +static bool __netlink_insert(struct netlink_table *table, struct sock *sk, + struct net *net) +{ + struct netlink_compare_arg arg = { + .net = net, + .portid = nlk_sk(sk)->portid, + }; + + return rhashtable_lookup_compare_insert(&table->hash, + &nlk_sk(sk)->node, + &netlink_compare, &arg); +} + static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid) { struct netlink_table *table = &nl_table[protocol]; @@ -1043,9 +1052,7 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid) struct netlink_table *table = &nl_table[sk->sk_protocol]; int err = -EADDRINUSE; - mutex_lock(&nl_sk_hash_lock); - if (__netlink_lookup(table, portid, net)) - goto err; + lock_sock(sk); err = -EBUSY; if (nlk_sk(sk)->portid) @@ -1058,10 +1065,12 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid) nlk_sk(sk)->portid = portid; sock_hold(sk); - rhashtable_insert(&table->hash, &nlk_sk(sk)->node); - err = 0; + if (__netlink_insert(table, sk, net)) + err = 0; + else + sock_put(sk); err: - mutex_unlock(&nl_sk_hash_lock); + release_sock(sk); return err; } @@ -1069,13 +1078,11 @@ static void netlink_remove(struct sock *sk) { struct netlink_table *table; - mutex_lock(&nl_sk_hash_lock); table = &nl_table[sk->sk_protocol]; if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node)) { WARN_ON(atomic_read(&sk->sk_refcnt) == 1); __sock_put(sk); } - mutex_unlock(&nl_sk_hash_lock); netlink_table_grab(); if (nlk_sk(sk)->subscriptions) { |