forked from Minki/linux
skmsg: Extract __tcp_bpf_recvmsg() and tcp_bpf_wait_data()
Although these two functions are only used by TCP, they are not specific to TCP at all, both operate on skmsg and ingress_msg, so fit in net/core/skmsg.c very well. And we will need them for non-TCP, so rename and move them to skmsg.c and export them to modules. Signed-off-by: Cong Wang <cong.wang@bytedance.com> Signed-off-by: Alexei Starovoitov <ast@kernel.org> Link: https://lore.kernel.org/bpf/20210331023237.41094-13-xiyou.wangcong@gmail.com
This commit is contained in:
parent
d7f571188e
commit
2bc793e327
@ -125,6 +125,10 @@ int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
|
||||
struct sk_msg *msg, u32 bytes);
|
||||
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
|
||||
struct sk_msg *msg, u32 bytes);
|
||||
int sk_msg_wait_data(struct sock *sk, struct sk_psock *psock, int flags,
|
||||
long timeo, int *err);
|
||||
int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
|
||||
int len, int flags);
|
||||
|
||||
static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
|
||||
{
|
||||
|
@ -2209,8 +2209,6 @@ void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
|
||||
|
||||
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes,
|
||||
int flags);
|
||||
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
|
||||
struct msghdr *msg, int len, int flags);
|
||||
#endif /* CONFIG_NET_SOCK_MSG */
|
||||
|
||||
#if !defined(CONFIG_BPF_SYSCALL) || !defined(CONFIG_NET_SOCK_MSG)
|
||||
|
@ -399,6 +399,104 @@ out:
|
||||
}
|
||||
EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);
|
||||
|
||||
int sk_msg_wait_data(struct sock *sk, struct sk_psock *psock, int flags,
|
||||
long timeo, int *err)
|
||||
{
|
||||
DEFINE_WAIT_FUNC(wait, woken_wake_function);
|
||||
int ret = 0;
|
||||
|
||||
if (sk->sk_shutdown & RCV_SHUTDOWN)
|
||||
return 1;
|
||||
|
||||
if (!timeo)
|
||||
return ret;
|
||||
|
||||
add_wait_queue(sk_sleep(sk), &wait);
|
||||
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
|
||||
ret = sk_wait_event(sk, &timeo,
|
||||
!list_empty(&psock->ingress_msg) ||
|
||||
!skb_queue_empty(&sk->sk_receive_queue), &wait);
|
||||
sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
|
||||
remove_wait_queue(sk_sleep(sk), &wait);
|
||||
return ret;
|
||||
}
|
||||
EXPORT_SYMBOL_GPL(sk_msg_wait_data);
|
||||
|
||||
/* Receive sk_msg from psock->ingress_msg to @msg. */
|
||||
int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
|
||||
int len, int flags)
|
||||
{
|
||||
struct iov_iter *iter = &msg->msg_iter;
|
||||
int peek = flags & MSG_PEEK;
|
||||
struct sk_msg *msg_rx;
|
||||
int i, copied = 0;
|
||||
|
||||
msg_rx = sk_psock_peek_msg(psock);
|
||||
while (copied != len) {
|
||||
struct scatterlist *sge;
|
||||
|
||||
if (unlikely(!msg_rx))
|
||||
break;
|
||||
|
||||
i = msg_rx->sg.start;
|
||||
do {
|
||||
struct page *page;
|
||||
int copy;
|
||||
|
||||
sge = sk_msg_elem(msg_rx, i);
|
||||
copy = sge->length;
|
||||
page = sg_page(sge);
|
||||
if (copied + copy > len)
|
||||
copy = len - copied;
|
||||
copy = copy_page_to_iter(page, sge->offset, copy, iter);
|
||||
if (!copy)
|
||||
return copied ? copied : -EFAULT;
|
||||
|
||||
copied += copy;
|
||||
if (likely(!peek)) {
|
||||
sge->offset += copy;
|
||||
sge->length -= copy;
|
||||
if (!msg_rx->skb)
|
||||
sk_mem_uncharge(sk, copy);
|
||||
msg_rx->sg.size -= copy;
|
||||
|
||||
if (!sge->length) {
|
||||
sk_msg_iter_var_next(i);
|
||||
if (!msg_rx->skb)
|
||||
put_page(page);
|
||||
}
|
||||
} else {
|
||||
/* Lets not optimize peek case if copy_page_to_iter
|
||||
* didn't copy the entire length lets just break.
|
||||
*/
|
||||
if (copy != sge->length)
|
||||
return copied;
|
||||
sk_msg_iter_var_next(i);
|
||||
}
|
||||
|
||||
if (copied == len)
|
||||
break;
|
||||
} while (i != msg_rx->sg.end);
|
||||
|
||||
if (unlikely(peek)) {
|
||||
msg_rx = sk_psock_next_msg(psock, msg_rx);
|
||||
if (!msg_rx)
|
||||
break;
|
||||
continue;
|
||||
}
|
||||
|
||||
msg_rx->sg.start = i;
|
||||
if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
|
||||
msg_rx = sk_psock_dequeue_msg(psock);
|
||||
kfree_sk_msg(msg_rx);
|
||||
}
|
||||
msg_rx = sk_psock_peek_msg(psock);
|
||||
}
|
||||
|
||||
return copied;
|
||||
}
|
||||
EXPORT_SYMBOL_GPL(sk_msg_recvmsg);
|
||||
|
||||
static struct sk_msg *sk_psock_create_ingress_msg(struct sock *sk,
|
||||
struct sk_buff *skb)
|
||||
{
|
||||
|
@ -10,80 +10,6 @@
|
||||
#include <net/inet_common.h>
|
||||
#include <net/tls.h>
|
||||
|
||||
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
|
||||
struct msghdr *msg, int len, int flags)
|
||||
{
|
||||
struct iov_iter *iter = &msg->msg_iter;
|
||||
int peek = flags & MSG_PEEK;
|
||||
struct sk_msg *msg_rx;
|
||||
int i, copied = 0;
|
||||
|
||||
msg_rx = sk_psock_peek_msg(psock);
|
||||
while (copied != len) {
|
||||
struct scatterlist *sge;
|
||||
|
||||
if (unlikely(!msg_rx))
|
||||
break;
|
||||
|
||||
i = msg_rx->sg.start;
|
||||
do {
|
||||
struct page *page;
|
||||
int copy;
|
||||
|
||||
sge = sk_msg_elem(msg_rx, i);
|
||||
copy = sge->length;
|
||||
page = sg_page(sge);
|
||||
if (copied + copy > len)
|
||||
copy = len - copied;
|
||||
copy = copy_page_to_iter(page, sge->offset, copy, iter);
|
||||
if (!copy)
|
||||
return copied ? copied : -EFAULT;
|
||||
|
||||
copied += copy;
|
||||
if (likely(!peek)) {
|
||||
sge->offset += copy;
|
||||
sge->length -= copy;
|
||||
if (!msg_rx->skb)
|
||||
sk_mem_uncharge(sk, copy);
|
||||
msg_rx->sg.size -= copy;
|
||||
|
||||
if (!sge->length) {
|
||||
sk_msg_iter_var_next(i);
|
||||
if (!msg_rx->skb)
|
||||
put_page(page);
|
||||
}
|
||||
} else {
|
||||
/* Lets not optimize peek case if copy_page_to_iter
|
||||
* didn't copy the entire length lets just break.
|
||||
*/
|
||||
if (copy != sge->length)
|
||||
return copied;
|
||||
sk_msg_iter_var_next(i);
|
||||
}
|
||||
|
||||
if (copied == len)
|
||||
break;
|
||||
} while (i != msg_rx->sg.end);
|
||||
|
||||
if (unlikely(peek)) {
|
||||
msg_rx = sk_psock_next_msg(psock, msg_rx);
|
||||
if (!msg_rx)
|
||||
break;
|
||||
continue;
|
||||
}
|
||||
|
||||
msg_rx->sg.start = i;
|
||||
if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
|
||||
msg_rx = sk_psock_dequeue_msg(psock);
|
||||
kfree_sk_msg(msg_rx);
|
||||
}
|
||||
msg_rx = sk_psock_peek_msg(psock);
|
||||
}
|
||||
|
||||
return copied;
|
||||
}
|
||||
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);
|
||||
|
||||
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
|
||||
struct sk_msg *msg, u32 apply_bytes, int flags)
|
||||
{
|
||||
@ -237,28 +163,6 @@ static bool tcp_bpf_stream_read(const struct sock *sk)
|
||||
return !empty;
|
||||
}
|
||||
|
||||
static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
|
||||
int flags, long timeo, int *err)
|
||||
{
|
||||
DEFINE_WAIT_FUNC(wait, woken_wake_function);
|
||||
int ret = 0;
|
||||
|
||||
if (sk->sk_shutdown & RCV_SHUTDOWN)
|
||||
return 1;
|
||||
|
||||
if (!timeo)
|
||||
return ret;
|
||||
|
||||
add_wait_queue(sk_sleep(sk), &wait);
|
||||
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
|
||||
ret = sk_wait_event(sk, &timeo,
|
||||
!list_empty(&psock->ingress_msg) ||
|
||||
!skb_queue_empty(&sk->sk_receive_queue), &wait);
|
||||
sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
|
||||
remove_wait_queue(sk_sleep(sk), &wait);
|
||||
return ret;
|
||||
}
|
||||
|
||||
static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
|
||||
int nonblock, int flags, int *addr_len)
|
||||
{
|
||||
@ -278,13 +182,13 @@ static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
|
||||
}
|
||||
lock_sock(sk);
|
||||
msg_bytes_ready:
|
||||
copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
|
||||
copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
|
||||
if (!copied) {
|
||||
int data, err = 0;
|
||||
long timeo;
|
||||
|
||||
timeo = sock_rcvtimeo(sk, nonblock);
|
||||
data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
|
||||
data = sk_msg_wait_data(sk, psock, flags, timeo, &err);
|
||||
if (data) {
|
||||
if (!sk_psock_queue_empty(psock))
|
||||
goto msg_bytes_ready;
|
||||
|
@ -1789,8 +1789,8 @@ int tls_sw_recvmsg(struct sock *sk,
|
||||
skb = tls_wait_data(sk, psock, flags, timeo, &err);
|
||||
if (!skb) {
|
||||
if (psock) {
|
||||
int ret = __tcp_bpf_recvmsg(sk, psock,
|
||||
msg, len, flags);
|
||||
int ret = sk_msg_recvmsg(sk, psock, msg, len,
|
||||
flags);
|
||||
|
||||
if (ret > 0) {
|
||||
decrypted += ret;
|
||||
|
Loading…
Reference in New Issue
Block a user