diff --git a/include/net/udp.h b/include/net/udp.h index d4d064c59232..adf2ff8ac87c 100644 --- a/include/net/udp.h +++ b/include/net/udp.h @@ -515,6 +515,29 @@ static inline struct sk_buff *udp_rcv_segment(struct sock *sk, return segs; } +static inline void udp_post_segment_fix_csum(struct sk_buff *skb) +{ + /* UDP-lite can't land here - no GRO */ + WARN_ON_ONCE(UDP_SKB_CB(skb)->partial_cov); + + /* UDP packets generated with UDP_SEGMENT and traversing: + * + * UDP tunnel(xmit) -> veth (segmentation) -> veth (gro) -> UDP tunnel (rx) + * + * can reach an UDP socket with CHECKSUM_NONE, because + * __iptunnel_pull_header() converts CHECKSUM_PARTIAL into NONE. + * SKB_GSO_UDP_L4 or SKB_GSO_FRAGLIST packets with no UDP tunnel will + * have a valid checksum, as the GRO engine validates the UDP csum + * before the aggregation and nobody strips such info in between. + * Instead of adding another check in the tunnel fastpath, we can force + * a valid csum after the segmentation. + * Additionally fixup the UDP CB. + */ + UDP_SKB_CB(skb)->cscov = skb->len; + if (skb->ip_summed == CHECKSUM_NONE && !skb->csum_valid) + skb->csum_valid = 1; +} + #ifdef CONFIG_BPF_SYSCALL struct sk_psock; struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock); diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 4a0478b17243..fe85dcf8c008 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -2178,6 +2178,8 @@ static int udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) segs = udp_rcv_segment(sk, skb, true); skb_list_walk_safe(segs, skb, next) { __skb_pull(skb, skb_transport_offset(skb)); + + udp_post_segment_fix_csum(skb); ret = udp_queue_rcv_one_skb(sk, skb); if (ret > 0) ip_protocol_deliver_rcu(dev_net(skb->dev), skb, ret); diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c index d25e5a9252fd..fa2f54738392 100644 --- a/net/ipv6/udp.c +++ b/net/ipv6/udp.c @@ -749,6 +749,7 @@ static int udpv6_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) skb_list_walk_safe(segs, skb, next) { __skb_pull(skb, skb_transport_offset(skb)); + udp_post_segment_fix_csum(skb); ret = udpv6_queue_rcv_one_skb(sk, skb); if (ret > 0) ip6_protocol_deliver_rcu(dev_net(skb->dev), skb, ret,