diff --git a/include/net/sock_reuseport.h b/include/net/sock_reuseport.h index 0e558ca7afbf..1333d0cddfbc 100644 --- a/include/net/sock_reuseport.h +++ b/include/net/sock_reuseport.h @@ -32,6 +32,7 @@ extern int reuseport_alloc(struct sock *sk, bool bind_inany); extern int reuseport_add_sock(struct sock *sk, struct sock *sk2, bool bind_inany); extern void reuseport_detach_sock(struct sock *sk); +void reuseport_stop_listen_sock(struct sock *sk); extern struct sock *reuseport_select_sock(struct sock *sk, u32 hash, struct sk_buff *skb, diff --git a/net/core/sock_reuseport.c b/net/core/sock_reuseport.c index f478c65a281b..41fcd55ab5ae 100644 --- a/net/core/sock_reuseport.c +++ b/net/core/sock_reuseport.c @@ -17,6 +17,8 @@ DEFINE_SPINLOCK(reuseport_lock); static DEFINE_IDA(reuseport_ida); +static int reuseport_resurrect(struct sock *sk, struct sock_reuseport *old_reuse, + struct sock_reuseport *reuse, bool bind_inany); static int reuseport_sock_index(struct sock *sk, const struct sock_reuseport *reuse, @@ -61,6 +63,29 @@ static bool __reuseport_detach_sock(struct sock *sk, return true; } +static void __reuseport_add_closed_sock(struct sock *sk, + struct sock_reuseport *reuse) +{ + reuse->socks[reuse->max_socks - reuse->num_closed_socks - 1] = sk; + /* paired with READ_ONCE() in inet_csk_bind_conflict() */ + WRITE_ONCE(reuse->num_closed_socks, reuse->num_closed_socks + 1); +} + +static bool __reuseport_detach_closed_sock(struct sock *sk, + struct sock_reuseport *reuse) +{ + int i = reuseport_sock_index(sk, reuse, true); + + if (i == -1) + return false; + + reuse->socks[i] = reuse->socks[reuse->max_socks - reuse->num_closed_socks]; + /* paired with READ_ONCE() in inet_csk_bind_conflict() */ + WRITE_ONCE(reuse->num_closed_socks, reuse->num_closed_socks - 1); + + return true; +} + static struct sock_reuseport *__reuseport_alloc(unsigned int max_socks) { unsigned int size = sizeof(struct sock_reuseport) + @@ -92,6 +117,12 @@ int reuseport_alloc(struct sock *sk, bool bind_inany) reuse = rcu_dereference_protected(sk->sk_reuseport_cb, lockdep_is_held(&reuseport_lock)); if (reuse) { + if (reuse->num_closed_socks) { + /* sk was shutdown()ed before */ + ret = reuseport_resurrect(sk, reuse, NULL, bind_inany); + goto out; + } + /* Only set reuse->bind_inany if the bind_inany is true. * Otherwise, it will overwrite the reuse->bind_inany * which was set by the bind/hash path. @@ -133,8 +164,23 @@ static struct sock_reuseport *reuseport_grow(struct sock_reuseport *reuse) u32 more_socks_size, i; more_socks_size = reuse->max_socks * 2U; - if (more_socks_size > U16_MAX) + if (more_socks_size > U16_MAX) { + if (reuse->num_closed_socks) { + /* Make room by removing a closed sk. + * The child has already been migrated. + * Only reqsk left at this point. + */ + struct sock *sk; + + sk = reuse->socks[reuse->max_socks - reuse->num_closed_socks]; + RCU_INIT_POINTER(sk->sk_reuseport_cb, NULL); + __reuseport_detach_closed_sock(sk, reuse); + + return reuse; + } + return NULL; + } more_reuse = __reuseport_alloc(more_socks_size); if (!more_reuse) @@ -200,7 +246,15 @@ int reuseport_add_sock(struct sock *sk, struct sock *sk2, bool bind_inany) reuse = rcu_dereference_protected(sk2->sk_reuseport_cb, lockdep_is_held(&reuseport_lock)); old_reuse = rcu_dereference_protected(sk->sk_reuseport_cb, - lockdep_is_held(&reuseport_lock)); + lockdep_is_held(&reuseport_lock)); + if (old_reuse && old_reuse->num_closed_socks) { + /* sk was shutdown()ed before */ + int err = reuseport_resurrect(sk, old_reuse, reuse, reuse->bind_inany); + + spin_unlock_bh(&reuseport_lock); + return err; + } + if (old_reuse && old_reuse->num_socks != 1) { spin_unlock_bh(&reuseport_lock); return -EBUSY; @@ -225,6 +279,65 @@ int reuseport_add_sock(struct sock *sk, struct sock *sk2, bool bind_inany) } EXPORT_SYMBOL(reuseport_add_sock); +static int reuseport_resurrect(struct sock *sk, struct sock_reuseport *old_reuse, + struct sock_reuseport *reuse, bool bind_inany) +{ + if (old_reuse == reuse) { + /* If sk was in the same reuseport group, just pop sk out of + * the closed section and push sk into the listening section. + */ + __reuseport_detach_closed_sock(sk, old_reuse); + __reuseport_add_sock(sk, old_reuse); + return 0; + } + + if (!reuse) { + /* In bind()/listen() path, we cannot carry over the eBPF prog + * for the shutdown()ed socket. In setsockopt() path, we should + * not change the eBPF prog of listening sockets by attaching a + * prog to the shutdown()ed socket. Thus, we will allocate a new + * reuseport group and detach sk from the old group. + */ + int id; + + reuse = __reuseport_alloc(INIT_SOCKS); + if (!reuse) + return -ENOMEM; + + id = ida_alloc(&reuseport_ida, GFP_ATOMIC); + if (id < 0) { + kfree(reuse); + return id; + } + + reuse->reuseport_id = id; + reuse->bind_inany = bind_inany; + } else { + /* Move sk from the old group to the new one if + * - all the other listeners in the old group were close()d or + * shutdown()ed, and then sk2 has listen()ed on the same port + * OR + * - sk listen()ed without bind() (or with autobind), was + * shutdown()ed, and then listen()s on another port which + * sk2 listen()s on. + */ + if (reuse->num_socks + reuse->num_closed_socks == reuse->max_socks) { + reuse = reuseport_grow(reuse); + if (!reuse) + return -ENOMEM; + } + } + + __reuseport_detach_closed_sock(sk, old_reuse); + __reuseport_add_sock(sk, reuse); + rcu_assign_pointer(sk->sk_reuseport_cb, reuse); + + if (old_reuse->num_socks + old_reuse->num_closed_socks == 0) + call_rcu(&old_reuse->rcu, reuseport_free_rcu); + + return 0; +} + void reuseport_detach_sock(struct sock *sk) { struct sock_reuseport *reuse; @@ -233,6 +346,10 @@ void reuseport_detach_sock(struct sock *sk) reuse = rcu_dereference_protected(sk->sk_reuseport_cb, lockdep_is_held(&reuseport_lock)); + /* reuseport_grow() has detached a closed sk */ + if (!reuse) + goto out; + /* Notify the bpf side. The sk may be added to a sockarray * map. If so, sockarray logic will remove it from the map. * @@ -244,15 +361,49 @@ void reuseport_detach_sock(struct sock *sk) bpf_sk_reuseport_detach(sk); rcu_assign_pointer(sk->sk_reuseport_cb, NULL); - __reuseport_detach_sock(sk, reuse); + + if (!__reuseport_detach_closed_sock(sk, reuse)) + __reuseport_detach_sock(sk, reuse); if (reuse->num_socks + reuse->num_closed_socks == 0) call_rcu(&reuse->rcu, reuseport_free_rcu); +out: spin_unlock_bh(&reuseport_lock); } EXPORT_SYMBOL(reuseport_detach_sock); +void reuseport_stop_listen_sock(struct sock *sk) +{ + if (sk->sk_protocol == IPPROTO_TCP) { + struct sock_reuseport *reuse; + + spin_lock_bh(&reuseport_lock); + + reuse = rcu_dereference_protected(sk->sk_reuseport_cb, + lockdep_is_held(&reuseport_lock)); + + if (sock_net(sk)->ipv4.sysctl_tcp_migrate_req) { + /* Migration capable, move sk from the listening section + * to the closed section. + */ + bpf_sk_reuseport_detach(sk); + + __reuseport_detach_sock(sk, reuse); + __reuseport_add_closed_sock(sk, reuse); + + spin_unlock_bh(&reuseport_lock); + return; + } + + spin_unlock_bh(&reuseport_lock); + } + + /* Not capable to do migration, detach immediately */ + reuseport_detach_sock(sk); +} +EXPORT_SYMBOL(reuseport_stop_listen_sock); + static struct sock *run_bpf_filter(struct sock_reuseport *reuse, u16 socks, struct bpf_prog *prog, struct sk_buff *skb, int hdr_len) @@ -352,9 +503,13 @@ int reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog) struct sock_reuseport *reuse; struct bpf_prog *old_prog; - if (sk_unhashed(sk) && sk->sk_reuseport) { - int err = reuseport_alloc(sk, false); + if (sk_unhashed(sk)) { + int err; + if (!sk->sk_reuseport) + return -EINVAL; + + err = reuseport_alloc(sk, false); if (err) return err; } else if (!rcu_access_pointer(sk->sk_reuseport_cb)) { @@ -380,13 +535,24 @@ int reuseport_detach_prog(struct sock *sk) struct sock_reuseport *reuse; struct bpf_prog *old_prog; - if (!rcu_access_pointer(sk->sk_reuseport_cb)) - return sk->sk_reuseport ? -ENOENT : -EINVAL; - old_prog = NULL; spin_lock_bh(&reuseport_lock); reuse = rcu_dereference_protected(sk->sk_reuseport_cb, lockdep_is_held(&reuseport_lock)); + + /* reuse must be checked after acquiring the reuseport_lock + * because reuseport_grow() can detach a closed sk. + */ + if (!reuse) { + spin_unlock_bh(&reuseport_lock); + return sk->sk_reuseport ? -ENOENT : -EINVAL; + } + + if (sk_unhashed(sk) && reuse->num_closed_socks) { + spin_unlock_bh(&reuseport_lock); + return -ENOENT; + } + old_prog = rcu_replace_pointer(reuse->prog, old_prog, lockdep_is_held(&reuseport_lock)); spin_unlock_bh(&reuseport_lock); diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c index fd472eae4f5c..fa806e9167ec 100644 --- a/net/ipv4/inet_connection_sock.c +++ b/net/ipv4/inet_connection_sock.c @@ -135,10 +135,18 @@ static int inet_csk_bind_conflict(const struct sock *sk, bool relax, bool reuseport_ok) { struct sock *sk2; + bool reuseport_cb_ok; bool reuse = sk->sk_reuse; bool reuseport = !!sk->sk_reuseport; + struct sock_reuseport *reuseport_cb; kuid_t uid = sock_i_uid((struct sock *)sk); + rcu_read_lock(); + reuseport_cb = rcu_dereference(sk->sk_reuseport_cb); + /* paired with WRITE_ONCE() in __reuseport_(add|detach)_closed_sock */ + reuseport_cb_ok = !reuseport_cb || READ_ONCE(reuseport_cb->num_closed_socks); + rcu_read_unlock(); + /* * Unlike other sk lookup places we do not check * for sk_net here, since _all_ the socks listed @@ -156,14 +164,14 @@ static int inet_csk_bind_conflict(const struct sock *sk, if ((!relax || (!reuseport_ok && reuseport && sk2->sk_reuseport && - !rcu_access_pointer(sk->sk_reuseport_cb) && + reuseport_cb_ok && (sk2->sk_state == TCP_TIME_WAIT || uid_eq(uid, sock_i_uid(sk2))))) && inet_rcv_saddr_equal(sk, sk2, true)) break; } else if (!reuseport_ok || !reuseport || !sk2->sk_reuseport || - rcu_access_pointer(sk->sk_reuseport_cb) || + !reuseport_cb_ok || (sk2->sk_state != TCP_TIME_WAIT && !uid_eq(uid, sock_i_uid(sk2)))) { if (inet_rcv_saddr_equal(sk, sk2, true)) diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c index c96866a53a66..80aeaf9e6e16 100644 --- a/net/ipv4/inet_hashtables.c +++ b/net/ipv4/inet_hashtables.c @@ -697,7 +697,7 @@ void inet_unhash(struct sock *sk) goto unlock; if (rcu_access_pointer(sk->sk_reuseport_cb)) - reuseport_detach_sock(sk); + reuseport_stop_listen_sock(sk); if (ilb) { inet_unhash2(hashinfo, sk); ilb->count--;