Use READ/WRITE_ONCE() for IP local_port_range.

Commit 227b60f510 added a seqlock to ensure that the low and high
port numbers were always updated together.
This is overkill because the two 16bit port numbers can be held in
a u32 and read/written in a single instruction.

More recently 91d0b78c51 added support for finer per-socket limits.
The user-supplied value is 'high << 16 | low' but they are held
separately and the socket options protected by the socket lock.

Use a u32 containing 'high << 16 | low' for both the 'net' and 'sk'
fields and use READ_ONCE()/WRITE_ONCE() to ensure both values are
always updated together.

Change (the now trival) inet_get_local_port_range() to a static inline
to optimise the calling code.
(In particular avoiding returning integers by reference.)

Signed-off-by: David Laight <david.laight@aculab.com>
Reviewed-by: Eric Dumazet <edumazet@google.com>
Reviewed-by: David Ahern <dsahern@kernel.org>
Acked-by: Mat Martineau <martineau@kernel.org>
Reviewed-by: Kuniyuki Iwashima <kuniyu@amazon.com>
Link: https://lore.kernel.org/r/4e505d4198e946a8be03fb1b4c3072b0@AcuMS.aculab.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
This commit is contained in:
David Laight 2023-12-06 13:44:20 +00:00 committed by Jakub Kicinski
parent 36b0bdb6d3
commit d9f28735af
7 changed files with 43 additions and 57 deletions

View File

@ -234,10 +234,7 @@ struct inet_sock {
int uc_index; int uc_index;
int mc_index; int mc_index;
__be32 mc_addr; __be32 mc_addr;
struct { u32 local_port_range; /* high << 16 | low */
__u16 lo;
__u16 hi;
} local_port_range;
struct ip_mc_socklist __rcu *mc_list; struct ip_mc_socklist __rcu *mc_list;
struct inet_cork_full cork; struct inet_cork_full cork;

View File

@ -349,7 +349,13 @@ static inline u64 snmp_fold_field64(void __percpu *mib, int offt, size_t syncp_o
} \ } \
} }
void inet_get_local_port_range(const struct net *net, int *low, int *high); static inline void inet_get_local_port_range(const struct net *net, int *low, int *high)
{
u32 range = READ_ONCE(net->ipv4.ip_local_ports.range);
*low = range & 0xffff;
*high = range >> 16;
}
void inet_sk_get_local_port_range(const struct sock *sk, int *low, int *high); void inet_sk_get_local_port_range(const struct sock *sk, int *low, int *high);
#ifdef CONFIG_SYSCTL #ifdef CONFIG_SYSCTL

View File

@ -19,8 +19,7 @@ struct hlist_head;
struct fib_table; struct fib_table;
struct sock; struct sock;
struct local_ports { struct local_ports {
seqlock_t lock; u32 range; /* high << 16 | low */
int range[2];
bool warned; bool warned;
}; };

View File

@ -1847,9 +1847,7 @@ static __net_init int inet_init_net(struct net *net)
/* /*
* Set defaults for local port range * Set defaults for local port range
*/ */
seqlock_init(&net->ipv4.ip_local_ports.lock); net->ipv4.ip_local_ports.range = 60999u << 16 | 32768u;
net->ipv4.ip_local_ports.range[0] = 32768;
net->ipv4.ip_local_ports.range[1] = 60999;
seqlock_init(&net->ipv4.ping_group_range.lock); seqlock_init(&net->ipv4.ping_group_range.lock);
/* /*

View File

@ -117,34 +117,25 @@ bool inet_rcv_saddr_any(const struct sock *sk)
return !sk->sk_rcv_saddr; return !sk->sk_rcv_saddr;
} }
void inet_get_local_port_range(const struct net *net, int *low, int *high)
{
unsigned int seq;
do {
seq = read_seqbegin(&net->ipv4.ip_local_ports.lock);
*low = net->ipv4.ip_local_ports.range[0];
*high = net->ipv4.ip_local_ports.range[1];
} while (read_seqretry(&net->ipv4.ip_local_ports.lock, seq));
}
EXPORT_SYMBOL(inet_get_local_port_range);
void inet_sk_get_local_port_range(const struct sock *sk, int *low, int *high) void inet_sk_get_local_port_range(const struct sock *sk, int *low, int *high)
{ {
const struct inet_sock *inet = inet_sk(sk); const struct inet_sock *inet = inet_sk(sk);
const struct net *net = sock_net(sk); const struct net *net = sock_net(sk);
int lo, hi, sk_lo, sk_hi; int lo, hi, sk_lo, sk_hi;
u32 sk_range;
inet_get_local_port_range(net, &lo, &hi); inet_get_local_port_range(net, &lo, &hi);
sk_lo = inet->local_port_range.lo; sk_range = READ_ONCE(inet->local_port_range);
sk_hi = inet->local_port_range.hi; if (unlikely(sk_range)) {
sk_lo = sk_range & 0xffff;
sk_hi = sk_range >> 16;
if (unlikely(lo <= sk_lo && sk_lo <= hi)) if (lo <= sk_lo && sk_lo <= hi)
lo = sk_lo; lo = sk_lo;
if (unlikely(lo <= sk_hi && sk_hi <= hi)) if (lo <= sk_hi && sk_hi <= hi)
hi = sk_hi; hi = sk_hi;
}
*low = lo; *low = lo;
*high = hi; *high = hi;

View File

@ -1055,6 +1055,19 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname,
case IP_TOS: /* This sets both TOS and Precedence */ case IP_TOS: /* This sets both TOS and Precedence */
ip_sock_set_tos(sk, val); ip_sock_set_tos(sk, val);
return 0; return 0;
case IP_LOCAL_PORT_RANGE:
{
u16 lo = val;
u16 hi = val >> 16;
if (optlen != sizeof(u32))
return -EINVAL;
if (lo != 0 && hi != 0 && lo > hi)
return -EINVAL;
WRITE_ONCE(inet->local_port_range, val);
return 0;
}
} }
err = 0; err = 0;
@ -1332,20 +1345,6 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname,
err = xfrm_user_policy(sk, optname, optval, optlen); err = xfrm_user_policy(sk, optname, optval, optlen);
break; break;
case IP_LOCAL_PORT_RANGE:
{
const __u16 lo = val;
const __u16 hi = val >> 16;
if (optlen != sizeof(__u32))
goto e_inval;
if (lo != 0 && hi != 0 && lo > hi)
goto e_inval;
inet->local_port_range.lo = lo;
inet->local_port_range.hi = hi;
break;
}
default: default:
err = -ENOPROTOOPT; err = -ENOPROTOOPT;
break; break;
@ -1692,6 +1691,9 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
return -EFAULT; return -EFAULT;
return 0; return 0;
} }
case IP_LOCAL_PORT_RANGE:
val = READ_ONCE(inet->local_port_range);
goto copyval;
} }
if (needs_rtnl) if (needs_rtnl)
@ -1721,9 +1723,6 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
else else
err = ip_get_mcast_msfilter(sk, optval, optlen, len); err = ip_get_mcast_msfilter(sk, optval, optlen, len);
goto out; goto out;
case IP_LOCAL_PORT_RANGE:
val = inet->local_port_range.hi << 16 | inet->local_port_range.lo;
break;
case IP_PROTOCOL: case IP_PROTOCOL:
val = inet_sk(sk)->inet_num; val = inet_sk(sk)->inet_num;
break; break;

View File

@ -50,26 +50,22 @@ static int tcp_plb_max_cong_thresh = 256;
static int sysctl_tcp_low_latency __read_mostly; static int sysctl_tcp_low_latency __read_mostly;
/* Update system visible IP port range */ /* Update system visible IP port range */
static void set_local_port_range(struct net *net, int range[2]) static void set_local_port_range(struct net *net, unsigned int low, unsigned int high)
{ {
bool same_parity = !((range[0] ^ range[1]) & 1); bool same_parity = !((low ^ high) & 1);
write_seqlock_bh(&net->ipv4.ip_local_ports.lock);
if (same_parity && !net->ipv4.ip_local_ports.warned) { if (same_parity && !net->ipv4.ip_local_ports.warned) {
net->ipv4.ip_local_ports.warned = true; net->ipv4.ip_local_ports.warned = true;
pr_err_ratelimited("ip_local_port_range: prefer different parity for start/end values.\n"); pr_err_ratelimited("ip_local_port_range: prefer different parity for start/end values.\n");
} }
net->ipv4.ip_local_ports.range[0] = range[0]; WRITE_ONCE(net->ipv4.ip_local_ports.range, high << 16 | low);
net->ipv4.ip_local_ports.range[1] = range[1];
write_sequnlock_bh(&net->ipv4.ip_local_ports.lock);
} }
/* Validate changes from /proc interface. */ /* Validate changes from /proc interface. */
static int ipv4_local_port_range(struct ctl_table *table, int write, static int ipv4_local_port_range(struct ctl_table *table, int write,
void *buffer, size_t *lenp, loff_t *ppos) void *buffer, size_t *lenp, loff_t *ppos)
{ {
struct net *net = struct net *net = table->data;
container_of(table->data, struct net, ipv4.ip_local_ports.range);
int ret; int ret;
int range[2]; int range[2];
struct ctl_table tmp = { struct ctl_table tmp = {
@ -93,7 +89,7 @@ static int ipv4_local_port_range(struct ctl_table *table, int write,
(range[0] < READ_ONCE(net->ipv4.sysctl_ip_prot_sock))) (range[0] < READ_ONCE(net->ipv4.sysctl_ip_prot_sock)))
ret = -EINVAL; ret = -EINVAL;
else else
set_local_port_range(net, range); set_local_port_range(net, range[0], range[1]);
} }
return ret; return ret;
@ -733,8 +729,8 @@ static struct ctl_table ipv4_net_table[] = {
}, },
{ {
.procname = "ip_local_port_range", .procname = "ip_local_port_range",
.maxlen = sizeof(init_net.ipv4.ip_local_ports.range), .maxlen = 0,
.data = &init_net.ipv4.ip_local_ports.range, .data = &init_net,
.mode = 0644, .mode = 0644,
.proc_handler = ipv4_local_port_range, .proc_handler = ipv4_local_port_range,
}, },