diff --git a/include/net/mctp.h b/include/net/mctp.h index a824d47c3c6d..bf783dc3ea45 100644 --- a/include/net/mctp.h +++ b/include/net/mctp.h @@ -67,30 +67,36 @@ struct mctp_sock { /* Key for matching incoming packets to sockets or reassembly contexts. * Packets are matched on (src,dest,tag). * - * Lifetime requirements: + * Lifetime / locking requirements: * - * - keys are free()ed via RCU + * - individual key data (ie, the struct itself) is protected by key->lock; + * changes must be made with that lock held. + * + * - the lookup fields: peer_addr, local_addr and tag are set before the + * key is added to lookup lists, and never updated. + * + * - A ref to the key must be held (throuh key->refs) if a pointer to the + * key is to be accessed after key->lock is released. * * - a mctp_sk_key contains a reference to a struct sock; this is valid * for the life of the key. On sock destruction (through unhash), the key is - * removed from lists (see below), and will not be observable after a RCU - * grace period. - * - * any RX occurring within that grace period may still queue to the socket, - * but will hit the SOCK_DEAD case before the socket is freed. + * removed from lists (see below), and marked invalid. * * - these mctp_sk_keys appear on two lists: * 1) the struct mctp_sock->keys list * 2) the struct netns_mctp->keys list * - * updates to either list are performed under the netns_mctp->keys - * lock. + * presences on these lists requires a (single) refcount to be held; both + * lists are updated as a single operation. + * + * Updates and lookups in either list are performed under the + * netns_mctp->keys lock. Lookup functions will need to lock the key and + * take a reference before unlocking the keys_lock. Consequently, the list's + * keys_lock *cannot* be acquired with the individual key->lock held. * * - a key may have a sk_buff attached as part of an in-progress message - * reassembly (->reasm_head). The reassembly context is protected by - * reasm_lock, which may be acquired with the keys lock (above) held, if - * necessary. Consequently, keys lock *cannot* be acquired with the - * reasm_lock held. + * reassembly (->reasm_head). The reasm data is protected by the individual + * key->lock. * * - there are two destruction paths for a mctp_sk_key: * @@ -116,14 +122,22 @@ struct mctp_sk_key { /* per-socket list */ struct hlist_node sklist; + /* lock protects against concurrent updates to the reassembly and + * expiry data below. + */ + spinlock_t lock; + + /* Keys are referenced during the output path, which may sleep */ + refcount_t refs; + /* incoming fragment reassembly context */ - spinlock_t reasm_lock; struct sk_buff *reasm_head; struct sk_buff **reasm_tailp; bool reasm_dead; u8 last_seq; - struct rcu_head rcu; + /* key validity */ + bool valid; }; struct mctp_skb_cb { @@ -191,6 +205,8 @@ int mctp_do_route(struct mctp_route *rt, struct sk_buff *skb); int mctp_local_output(struct sock *sk, struct mctp_route *rt, struct sk_buff *skb, mctp_eid_t daddr, u8 req_tag); +void mctp_key_unref(struct mctp_sk_key *key); + /* routing <--> device interface */ unsigned int mctp_default_net(struct net *net); int mctp_default_net_set(struct net *net, unsigned int index); diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c index a9526ac29dff..2767d548736b 100644 --- a/net/mctp/af_mctp.c +++ b/net/mctp/af_mctp.c @@ -263,21 +263,21 @@ static void mctp_sk_unhash(struct sock *sk) /* remove tag allocations */ spin_lock_irqsave(&net->mctp.keys_lock, flags); hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) { - hlist_del_rcu(&key->sklist); - hlist_del_rcu(&key->hlist); + hlist_del(&key->sklist); + hlist_del(&key->hlist); - spin_lock(&key->reasm_lock); + spin_lock(&key->lock); if (key->reasm_head) kfree_skb(key->reasm_head); key->reasm_head = NULL; key->reasm_dead = true; - spin_unlock(&key->reasm_lock); + key->valid = false; + spin_unlock(&key->lock); - kfree_rcu(key, rcu); + /* key is no longer on the lookup lists, unref */ + mctp_key_unref(key); } spin_unlock_irqrestore(&net->mctp.keys_lock, flags); - - synchronize_rcu(); } static struct proto mctp_proto = { diff --git a/net/mctp/route.c b/net/mctp/route.c index 224fd25b3678..b2243b150e71 100644 --- a/net/mctp/route.c +++ b/net/mctp/route.c @@ -83,25 +83,43 @@ static bool mctp_key_match(struct mctp_sk_key *key, mctp_eid_t local, return true; } +/* returns a key (with key->lock held, and refcounted), or NULL if no such + * key exists. + */ static struct mctp_sk_key *mctp_lookup_key(struct net *net, struct sk_buff *skb, - mctp_eid_t peer) + mctp_eid_t peer, + unsigned long *irqflags) + __acquires(&key->lock) { struct mctp_sk_key *key, *ret; + unsigned long flags; struct mctp_hdr *mh; u8 tag; - WARN_ON(!rcu_read_lock_held()); - mh = mctp_hdr(skb); tag = mh->flags_seq_tag & (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO); ret = NULL; + spin_lock_irqsave(&net->mctp.keys_lock, flags); - hlist_for_each_entry_rcu(key, &net->mctp.keys, hlist) { - if (mctp_key_match(key, mh->dest, peer, tag)) { + hlist_for_each_entry(key, &net->mctp.keys, hlist) { + if (!mctp_key_match(key, mh->dest, peer, tag)) + continue; + + spin_lock(&key->lock); + if (key->valid) { + refcount_inc(&key->refs); ret = key; break; } + spin_unlock(&key->lock); + } + + if (ret) { + spin_unlock(&net->mctp.keys_lock); + *irqflags = flags; + } else { + spin_unlock_irqrestore(&net->mctp.keys_lock, flags); } return ret; @@ -121,11 +139,19 @@ static struct mctp_sk_key *mctp_key_alloc(struct mctp_sock *msk, key->local_addr = local; key->tag = tag; key->sk = &msk->sk; - spin_lock_init(&key->reasm_lock); + key->valid = true; + spin_lock_init(&key->lock); + refcount_set(&key->refs, 1); return key; } +void mctp_key_unref(struct mctp_sk_key *key) +{ + if (refcount_dec_and_test(&key->refs)) + kfree(key); +} + static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk) { struct net *net = sock_net(&msk->sk); @@ -138,12 +164,17 @@ static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk) hlist_for_each_entry(tmp, &net->mctp.keys, hlist) { if (mctp_key_match(tmp, key->local_addr, key->peer_addr, key->tag)) { - rc = -EEXIST; - break; + spin_lock(&tmp->lock); + if (tmp->valid) + rc = -EEXIST; + spin_unlock(&tmp->lock); + if (rc) + break; } } if (!rc) { + refcount_inc(&key->refs); hlist_add_head(&key->hlist, &net->mctp.keys); hlist_add_head(&key->sklist, &msk->keys); } @@ -153,28 +184,35 @@ static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk) return rc; } -/* Must be called with key->reasm_lock, which it will release. Will schedule - * the key for an RCU free. +/* We're done with the key; unset valid and remove from lists. There may still + * be outstanding refs on the key though... */ static void __mctp_key_unlock_drop(struct mctp_sk_key *key, struct net *net, unsigned long flags) - __releases(&key->reasm_lock) + __releases(&key->lock) { struct sk_buff *skb; skb = key->reasm_head; key->reasm_head = NULL; key->reasm_dead = true; - spin_unlock_irqrestore(&key->reasm_lock, flags); + key->valid = false; + spin_unlock_irqrestore(&key->lock, flags); spin_lock_irqsave(&net->mctp.keys_lock, flags); - hlist_del_rcu(&key->hlist); - hlist_del_rcu(&key->sklist); + hlist_del(&key->hlist); + hlist_del(&key->sklist); spin_unlock_irqrestore(&net->mctp.keys_lock, flags); - kfree_rcu(key, rcu); + + /* one unref for the lists */ + mctp_key_unref(key); + + /* and one for the local reference */ + mctp_key_unref(key); if (skb) kfree_skb(skb); + } static int mctp_frag_queue(struct mctp_sk_key *key, struct sk_buff *skb) @@ -248,8 +286,10 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) rcu_read_lock(); - /* lookup socket / reasm context, exactly matching (src,dest,tag) */ - key = mctp_lookup_key(net, skb, mh->src); + /* lookup socket / reasm context, exactly matching (src,dest,tag). + * we hold a ref on the key, and key->lock held. + */ + key = mctp_lookup_key(net, skb, mh->src, &f); if (flags & MCTP_HDR_FLAG_SOM) { if (key) { @@ -260,10 +300,12 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) * key for reassembly - we'll create a more specific * one for future packets if required (ie, !EOM). */ - key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY); + key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY, &f); if (key) { msk = container_of(key->sk, struct mctp_sock, sk); + spin_unlock_irqrestore(&key->lock, f); + mctp_key_unref(key); key = NULL; } } @@ -282,11 +324,11 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) if (flags & MCTP_HDR_FLAG_EOM) { sock_queue_rcv_skb(&msk->sk, skb); if (key) { - spin_lock_irqsave(&key->reasm_lock, f); /* we've hit a pending reassembly; not much we * can do but drop it */ __mctp_key_unlock_drop(key, net, f); + key = NULL; } rc = 0; goto out_unlock; @@ -303,7 +345,7 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) goto out_unlock; } - /* we can queue without the reasm lock here, as the + /* we can queue without the key lock here, as the * key isn't observable yet */ mctp_frag_queue(key, skb); @@ -318,17 +360,17 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) if (rc) kfree(key); - } else { - /* existing key: start reassembly */ - spin_lock_irqsave(&key->reasm_lock, f); + /* we don't need to release key->lock on exit */ + key = NULL; + } else { if (key->reasm_head || key->reasm_dead) { /* duplicate start? drop everything */ __mctp_key_unlock_drop(key, net, f); rc = -EEXIST; + key = NULL; } else { rc = mctp_frag_queue(key, skb); - spin_unlock_irqrestore(&key->reasm_lock, f); } } @@ -337,8 +379,6 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) * using the message-specific key */ - spin_lock_irqsave(&key->reasm_lock, f); - /* we need to be continuing an existing reassembly... */ if (!key->reasm_head) rc = -EINVAL; @@ -352,8 +392,7 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) sock_queue_rcv_skb(key->sk, key->reasm_head); key->reasm_head = NULL; __mctp_key_unlock_drop(key, net, f); - } else { - spin_unlock_irqrestore(&key->reasm_lock, f); + key = NULL; } } else { @@ -363,6 +402,10 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) out_unlock: rcu_read_unlock(); + if (key) { + spin_unlock_irqrestore(&key->lock, f); + mctp_key_unref(key); + } out: if (rc) kfree_skb(skb); @@ -459,6 +502,7 @@ static void mctp_reserve_tag(struct net *net, struct mctp_sk_key *key, */ hlist_add_head_rcu(&key->hlist, &mns->keys); hlist_add_head_rcu(&key->sklist, &msk->keys); + refcount_inc(&key->refs); } /* Allocate a locally-owned tag value for (saddr, daddr), and reserve @@ -492,14 +536,26 @@ static int mctp_alloc_local_tag(struct mctp_sock *msk, * tags. If we find a conflict, clear that bit from tagbits */ hlist_for_each_entry(tmp, &mns->keys, hlist) { + /* We can check the lookup fields (*_addr, tag) without the + * lock held, they don't change over the lifetime of the key. + */ + /* if we don't own the tag, it can't conflict */ if (tmp->tag & MCTP_HDR_FLAG_TO) continue; - if ((tmp->peer_addr == daddr || - tmp->peer_addr == MCTP_ADDR_ANY) && - tmp->local_addr == saddr) + if (!((tmp->peer_addr == daddr || + tmp->peer_addr == MCTP_ADDR_ANY) && + tmp->local_addr == saddr)) + continue; + + spin_lock(&tmp->lock); + /* key must still be valid. If we find a match, clear the + * potential tag value + */ + if (tmp->valid) tagbits &= ~(1 << tmp->tag); + spin_unlock(&tmp->lock); if (!tagbits) break;