From 73c618456dc5cf2acb597256d633060cf75de8d6 Mon Sep 17 00:00:00 2001 From: Jeremy Kerr Date: Wed, 29 Sep 2021 15:26:07 +0800 Subject: mctp: locking, lifetime and validity changes for sk_keys We will want to invalidate sk_keys in a future change, which will require a boolean flag to mark invalidated items in the socket & net namespace lists. We'll also need to take a reference to keys, held over non-atomic contexts, so we need a refcount on keys also. This change adds a validity flag (currently always true) and refcount to struct mctp_sk_key. With a refcount on the keys, using RCU no longer makes much sense; we have exact indications on the lifetime of keys. So, we also change the RCU list traversal to a locked implementation. Signed-off-by: Jeremy Kerr Signed-off-by: David S. Miller --- net/mctp/af_mctp.c | 14 +++---- net/mctp/route.c | 118 +++++++++++++++++++++++++++++++++++++++-------------- 2 files changed, 94 insertions(+), 38 deletions(-) (limited to 'net') diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c index a9526ac29dff..2767d548736b 100644 --- a/net/mctp/af_mctp.c +++ b/net/mctp/af_mctp.c @@ -263,21 +263,21 @@ static void mctp_sk_unhash(struct sock *sk) /* remove tag allocations */ spin_lock_irqsave(&net->mctp.keys_lock, flags); hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) { - hlist_del_rcu(&key->sklist); - hlist_del_rcu(&key->hlist); + hlist_del(&key->sklist); + hlist_del(&key->hlist); - spin_lock(&key->reasm_lock); + spin_lock(&key->lock); if (key->reasm_head) kfree_skb(key->reasm_head); key->reasm_head = NULL; key->reasm_dead = true; - spin_unlock(&key->reasm_lock); + key->valid = false; + spin_unlock(&key->lock); - kfree_rcu(key, rcu); + /* key is no longer on the lookup lists, unref */ + mctp_key_unref(key); } spin_unlock_irqrestore(&net->mctp.keys_lock, flags); - - synchronize_rcu(); } static struct proto mctp_proto = { diff --git a/net/mctp/route.c b/net/mctp/route.c index 224fd25b3678..b2243b150e71 100644 --- a/net/mctp/route.c +++ b/net/mctp/route.c @@ -83,25 +83,43 @@ static bool mctp_key_match(struct mctp_sk_key *key, mctp_eid_t local, return true; } +/* returns a key (with key->lock held, and refcounted), or NULL if no such + * key exists. + */ static struct mctp_sk_key *mctp_lookup_key(struct net *net, struct sk_buff *skb, - mctp_eid_t peer) + mctp_eid_t peer, + unsigned long *irqflags) + __acquires(&key->lock) { struct mctp_sk_key *key, *ret; + unsigned long flags; struct mctp_hdr *mh; u8 tag; - WARN_ON(!rcu_read_lock_held()); - mh = mctp_hdr(skb); tag = mh->flags_seq_tag & (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO); ret = NULL; + spin_lock_irqsave(&net->mctp.keys_lock, flags); - hlist_for_each_entry_rcu(key, &net->mctp.keys, hlist) { - if (mctp_key_match(key, mh->dest, peer, tag)) { + hlist_for_each_entry(key, &net->mctp.keys, hlist) { + if (!mctp_key_match(key, mh->dest, peer, tag)) + continue; + + spin_lock(&key->lock); + if (key->valid) { + refcount_inc(&key->refs); ret = key; break; } + spin_unlock(&key->lock); + } + + if (ret) { + spin_unlock(&net->mctp.keys_lock); + *irqflags = flags; + } else { + spin_unlock_irqrestore(&net->mctp.keys_lock, flags); } return ret; @@ -121,11 +139,19 @@ static struct mctp_sk_key *mctp_key_alloc(struct mctp_sock *msk, key->local_addr = local; key->tag = tag; key->sk = &msk->sk; - spin_lock_init(&key->reasm_lock); + key->valid = true; + spin_lock_init(&key->lock); + refcount_set(&key->refs, 1); return key; } +void mctp_key_unref(struct mctp_sk_key *key) +{ + if (refcount_dec_and_test(&key->refs)) + kfree(key); +} + static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk) { struct net *net = sock_net(&msk->sk); @@ -138,12 +164,17 @@ static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk) hlist_for_each_entry(tmp, &net->mctp.keys, hlist) { if (mctp_key_match(tmp, key->local_addr, key->peer_addr, key->tag)) { - rc = -EEXIST; - break; + spin_lock(&tmp->lock); + if (tmp->valid) + rc = -EEXIST; + spin_unlock(&tmp->lock); + if (rc) + break; } } if (!rc) { + refcount_inc(&key->refs); hlist_add_head(&key->hlist, &net->mctp.keys); hlist_add_head(&key->sklist, &msk->keys); } @@ -153,28 +184,35 @@ static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk) return rc; } -/* Must be called with key->reasm_lock, which it will release. Will schedule - * the key for an RCU free. +/* We're done with the key; unset valid and remove from lists. There may still + * be outstanding refs on the key though... */ static void __mctp_key_unlock_drop(struct mctp_sk_key *key, struct net *net, unsigned long flags) - __releases(&key->reasm_lock) + __releases(&key->lock) { struct sk_buff *skb; skb = key->reasm_head; key->reasm_head = NULL; key->reasm_dead = true; - spin_unlock_irqrestore(&key->reasm_lock, flags); + key->valid = false; + spin_unlock_irqrestore(&key->lock, flags); spin_lock_irqsave(&net->mctp.keys_lock, flags); - hlist_del_rcu(&key->hlist); - hlist_del_rcu(&key->sklist); + hlist_del(&key->hlist); + hlist_del(&key->sklist); spin_unlock_irqrestore(&net->mctp.keys_lock, flags); - kfree_rcu(key, rcu); + + /* one unref for the lists */ + mctp_key_unref(key); + + /* and one for the local reference */ + mctp_key_unref(key); if (skb) kfree_skb(skb); + } static int mctp_frag_queue(struct mctp_sk_key *key, struct sk_buff *skb) @@ -248,8 +286,10 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) rcu_read_lock(); - /* lookup socket / reasm context, exactly matching (src,dest,tag) */ - key = mctp_lookup_key(net, skb, mh->src); + /* lookup socket / reasm context, exactly matching (src,dest,tag). + * we hold a ref on the key, and key->lock held. + */ + key = mctp_lookup_key(net, skb, mh->src, &f); if (flags & MCTP_HDR_FLAG_SOM) { if (key) { @@ -260,10 +300,12 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) * key for reassembly - we'll create a more specific * one for future packets if required (ie, !EOM). */ - key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY); + key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY, &f); if (key) { msk = container_of(key->sk, struct mctp_sock, sk); + spin_unlock_irqrestore(&key->lock, f); + mctp_key_unref(key); key = NULL; } } @@ -282,11 +324,11 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) if (flags & MCTP_HDR_FLAG_EOM) { sock_queue_rcv_skb(&msk->sk, skb); if (key) { - spin_lock_irqsave(&key->reasm_lock, f); /* we've hit a pending reassembly; not much we * can do but drop it */ __mctp_key_unlock_drop(key, net, f); + key = NULL; } rc = 0; goto out_unlock; @@ -303,7 +345,7 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) goto out_unlock; } - /* we can queue without the reasm lock here, as the + /* we can queue without the key lock here, as the * key isn't observable yet */ mctp_frag_queue(key, skb); @@ -318,17 +360,17 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) if (rc) kfree(key); - } else { - /* existing key: start reassembly */ - spin_lock_irqsave(&key->reasm_lock, f); + /* we don't need to release key->lock on exit */ + key = NULL; + } else { if (key->reasm_head || key->reasm_dead) { /* duplicate start? drop everything */ __mctp_key_unlock_drop(key, net, f); rc = -EEXIST; + key = NULL; } else { rc = mctp_frag_queue(key, skb); - spin_unlock_irqrestore(&key->reasm_lock, f); } } @@ -337,8 +379,6 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) * using the message-specific key */ - spin_lock_irqsave(&key->reasm_lock, f); - /* we need to be continuing an existing reassembly... */ if (!key->reasm_head) rc = -EINVAL; @@ -352,8 +392,7 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) sock_queue_rcv_skb(key->sk, key->reasm_head); key->reasm_head = NULL; __mctp_key_unlock_drop(key, net, f); - } else { - spin_unlock_irqrestore(&key->reasm_lock, f); + key = NULL; } } else { @@ -363,6 +402,10 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) out_unlock: rcu_read_unlock(); + if (key) { + spin_unlock_irqrestore(&key->lock, f); + mctp_key_unref(key); + } out: if (rc) kfree_skb(skb); @@ -459,6 +502,7 @@ static void mctp_reserve_tag(struct net *net, struct mctp_sk_key *key, */ hlist_add_head_rcu(&key->hlist, &mns->keys); hlist_add_head_rcu(&key->sklist, &msk->keys); + refcount_inc(&key->refs); } /* Allocate a locally-owned tag value for (saddr, daddr), and reserve @@ -492,14 +536,26 @@ static int mctp_alloc_local_tag(struct mctp_sock *msk, * tags. If we find a conflict, clear that bit from tagbits */ hlist_for_each_entry(tmp, &mns->keys, hlist) { + /* We can check the lookup fields (*_addr, tag) without the + * lock held, they don't change over the lifetime of the key. + */ + /* if we don't own the tag, it can't conflict */ if (tmp->tag & MCTP_HDR_FLAG_TO) continue; - if ((tmp->peer_addr == daddr || - tmp->peer_addr == MCTP_ADDR_ANY) && - tmp->local_addr == saddr) + if (!((tmp->peer_addr == daddr || + tmp->peer_addr == MCTP_ADDR_ANY) && + tmp->local_addr == saddr)) + continue; + + spin_lock(&tmp->lock); + /* key must still be valid. If we find a match, clear the + * potential tag value + */ + if (tmp->valid) tagbits &= ~(1 << tmp->tag); + spin_unlock(&tmp->lock); if (!tagbits) break; -- cgit v1.2.3