diff --git a/include/linux/skbuff.h b/include/linux/skbuff.h
index 73902acf2b71..04f52e719571 100644
--- a/include/linux/skbuff.h
+++ b/include/linux/skbuff.h
@@ -485,6 +485,7 @@ void sock_zerocopy_put_abort(struct ubuf_info *uarg);
 
 void sock_zerocopy_callback(struct ubuf_info *uarg, bool success);
 
+int skb_zerocopy_iter_dgram(struct sk_buff *skb, struct msghdr *msg, int len);
 int skb_zerocopy_iter_stream(struct sock *sk, struct sk_buff *skb,
 			     struct msghdr *msg, int len,
 			     struct ubuf_info *uarg);
diff --git a/net/core/skbuff.c b/net/core/skbuff.c
index 3c814565ed7c..1350901c5cb8 100644
--- a/net/core/skbuff.c
+++ b/net/core/skbuff.c
@@ -1105,6 +1105,12 @@ EXPORT_SYMBOL_GPL(sock_zerocopy_put_abort);
 extern int __zerocopy_sg_from_iter(struct sock *sk, struct sk_buff *skb,
 				   struct iov_iter *from, size_t length);
 
+int skb_zerocopy_iter_dgram(struct sk_buff *skb, struct msghdr *msg, int len)
+{
+	return __zerocopy_sg_from_iter(skb->sk, skb, &msg->msg_iter, len);
+}
+EXPORT_SYMBOL_GPL(skb_zerocopy_iter_dgram);
+
 int skb_zerocopy_iter_stream(struct sock *sk, struct sk_buff *skb,
 			     struct msghdr *msg, int len,
 			     struct ubuf_info *uarg)
diff --git a/net/core/sock.c b/net/core/sock.c
index 6d7e189e3cd9..f5bb89785e47 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -1018,7 +1018,10 @@ set_rcvbuf:
 
 	case SO_ZEROCOPY:
 		if (sk->sk_family == PF_INET || sk->sk_family == PF_INET6) {
-			if (sk->sk_protocol != IPPROTO_TCP)
+			if (!((sk->sk_type == SOCK_STREAM &&
+			       sk->sk_protocol == IPPROTO_TCP) ||
+			      (sk->sk_type == SOCK_DGRAM &&
+			       sk->sk_protocol == IPPROTO_UDP)))
 				ret = -ENOTSUPP;
 		} else if (sk->sk_family != PF_RDS) {
 			ret = -ENOTSUPP;
diff --git a/net/ipv4/ip_output.c b/net/ipv4/ip_output.c
index 5dbec21856f4..6f843aff628c 100644
--- a/net/ipv4/ip_output.c
+++ b/net/ipv4/ip_output.c
@@ -867,6 +867,7 @@ static int __ip_append_data(struct sock *sk,
 			    unsigned int flags)
 {
 	struct inet_sock *inet = inet_sk(sk);
+	struct ubuf_info *uarg = NULL;
 	struct sk_buff *skb;
 
 	struct ip_options *opt = cork->opt;
@@ -916,6 +917,19 @@ static int __ip_append_data(struct sock *sk,
 	    (!exthdrlen || (rt->dst.dev->features & NETIF_F_HW_ESP_TX_CSUM)))
 		csummode = CHECKSUM_PARTIAL;
 
+	if (flags & MSG_ZEROCOPY && length && sock_flag(sk, SOCK_ZEROCOPY)) {
+		uarg = sock_zerocopy_realloc(sk, length, skb_zcopy(skb));
+		if (!uarg)
+			return -ENOBUFS;
+		if (rt->dst.dev->features & NETIF_F_SG &&
+		    csummode == CHECKSUM_PARTIAL) {
+			paged = true;
+		} else {
+			uarg->zerocopy = 0;
+			skb_zcopy_set(skb, uarg);
+		}
+	}
+
 	cork->length += length;
 
 	/* So, what's going on in the loop below?
@@ -1006,6 +1020,7 @@ alloc_new_skb:
 			cork->tx_flags = 0;
 			skb_shinfo(skb)->tskey = tskey;
 			tskey = 0;
+			skb_zcopy_set(skb, uarg);
 
 			/*
 			 *	Find where to start putting bytes.
@@ -1068,7 +1083,7 @@ alloc_new_skb:
 				err = -EFAULT;
 				goto error;
 			}
-		} else {
+		} else if (!uarg || !uarg->zerocopy) {
 			int i = skb_shinfo(skb)->nr_frags;
 
 			err = -ENOMEM;
@@ -1098,6 +1113,10 @@ alloc_new_skb:
 			skb->data_len += copy;
 			skb->truesize += copy;
 			wmem_alloc_delta += copy;
+		} else {
+			err = skb_zerocopy_iter_dgram(skb, from, copy);
+			if (err < 0)
+				goto error;
 		}
 		offset += copy;
 		length -= copy;
@@ -1105,11 +1124,13 @@ alloc_new_skb:
 
 	if (wmem_alloc_delta)
 		refcount_add(wmem_alloc_delta, &sk->sk_wmem_alloc);
+	sock_zerocopy_put(uarg);
 	return 0;
 
 error_efault:
 	err = -EFAULT;
 error:
+	sock_zerocopy_put_abort(uarg);
 	cork->length -= length;
 	IP_INC_STATS(sock_net(sk), IPSTATS_MIB_OUTDISCARDS);
 	refcount_add(wmem_alloc_delta, &sk->sk_wmem_alloc);
diff --git a/net/ipv6/ip6_output.c b/net/ipv6/ip6_output.c
index 827a3f5ff3bb..7df04d20a91f 100644
--- a/net/ipv6/ip6_output.c
+++ b/net/ipv6/ip6_output.c
@@ -1245,6 +1245,7 @@ static int __ip6_append_data(struct sock *sk,
 {
 	struct sk_buff *skb, *skb_prev = NULL;
 	unsigned int maxfraglen, fragheaderlen, mtu, orig_mtu, pmtu;
+	struct ubuf_info *uarg = NULL;
 	int exthdrlen = 0;
 	int dst_exthdrlen = 0;
 	int hh_len;
@@ -1322,6 +1323,19 @@ emsgsize:
 	    rt->dst.dev->features & (NETIF_F_IPV6_CSUM | NETIF_F_HW_CSUM))
 		csummode = CHECKSUM_PARTIAL;
 
+	if (flags & MSG_ZEROCOPY && length && sock_flag(sk, SOCK_ZEROCOPY)) {
+		uarg = sock_zerocopy_realloc(sk, length, skb_zcopy(skb));
+		if (!uarg)
+			return -ENOBUFS;
+		if (rt->dst.dev->features & NETIF_F_SG &&
+		    csummode == CHECKSUM_PARTIAL) {
+			paged = true;
+		} else {
+			uarg->zerocopy = 0;
+			skb_zcopy_set(skb, uarg);
+		}
+	}
+
 	/*
 	 * Let's try using as much space as possible.
 	 * Use MTU if total length of the message fits into the MTU.
@@ -1445,6 +1459,7 @@ alloc_new_skb:
 			cork->tx_flags = 0;
 			skb_shinfo(skb)->tskey = tskey;
 			tskey = 0;
+			skb_zcopy_set(skb, uarg);
 
 			/*
 			 *	Find where to start putting bytes
@@ -1506,7 +1521,7 @@ alloc_new_skb:
 				err = -EFAULT;
 				goto error;
 			}
-		} else {
+		} else if (!uarg || !uarg->zerocopy) {
 			int i = skb_shinfo(skb)->nr_frags;
 
 			err = -ENOMEM;
@@ -1536,6 +1551,10 @@ alloc_new_skb:
 			skb->data_len += copy;
 			skb->truesize += copy;
 			wmem_alloc_delta += copy;
+		} else {
+			err = skb_zerocopy_iter_dgram(skb, from, copy);
+			if (err < 0)
+				goto error;
 		}
 		offset += copy;
 		length -= copy;
@@ -1543,11 +1562,13 @@ alloc_new_skb:
 
 	if (wmem_alloc_delta)
 		refcount_add(wmem_alloc_delta, &sk->sk_wmem_alloc);
+	sock_zerocopy_put(uarg);
 	return 0;
 
 error_efault:
 	err = -EFAULT;
 error:
+	sock_zerocopy_put_abort(uarg);
 	cork->length -= length;
 	IP6_INC_STATS(sock_net(sk), rt->rt6i_idev, IPSTATS_MIB_OUTDISCARDS);
 	refcount_add(wmem_alloc_delta, &sk->sk_wmem_alloc);