io_uring: ensure recv and recvmsg handle MSG_WAITALL correctly

We currently don't attempt to get the full asked for length even if
MSG_WAITALL is set, if we get a partial receive. If we do see a partial
receive, then just note how many bytes we did and return -EAGAIN to
get it retried.

The iov is advanced appropriately for the vector based case, and we
manually bump the buffer and remainder for the non-vector case.

Cc: stable@vger.kernel.org
Reported-by: Constantine Gavrilov <constantine.gavrilov@gmail.com>
Signed-off-by: Jens Axboe <axboe@kernel.dk>
This commit is contained in:
Jens Axboe 2022-03-23 09:32:35 -06:00
parent 4d55f238f8
commit 7ba89d2af1

View File

@ -612,6 +612,7 @@ struct io_sr_msg {
int msg_flags; int msg_flags;
int bgid; int bgid;
size_t len; size_t len;
size_t done_io;
}; };
struct io_open { struct io_open {
@ -5417,12 +5418,21 @@ static int io_recvmsg_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
if (req->ctx->compat) if (req->ctx->compat)
sr->msg_flags |= MSG_CMSG_COMPAT; sr->msg_flags |= MSG_CMSG_COMPAT;
#endif #endif
sr->done_io = 0;
return 0; return 0;
} }
static bool io_net_retry(struct socket *sock, int flags)
{
if (!(flags & MSG_WAITALL))
return false;
return sock->type == SOCK_STREAM || sock->type == SOCK_SEQPACKET;
}
static int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags) static int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
{ {
struct io_async_msghdr iomsg, *kmsg; struct io_async_msghdr iomsg, *kmsg;
struct io_sr_msg *sr = &req->sr_msg;
struct socket *sock; struct socket *sock;
struct io_buffer *kbuf; struct io_buffer *kbuf;
unsigned flags; unsigned flags;
@ -5465,6 +5475,10 @@ static int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
return io_setup_async_msg(req, kmsg); return io_setup_async_msg(req, kmsg);
if (ret == -ERESTARTSYS) if (ret == -ERESTARTSYS)
ret = -EINTR; ret = -EINTR;
if (ret > 0 && io_net_retry(sock, flags)) {
sr->done_io += ret;
return io_setup_async_msg(req, kmsg);
}
req_set_fail(req); req_set_fail(req);
} else if ((flags & MSG_WAITALL) && (kmsg->msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))) { } else if ((flags & MSG_WAITALL) && (kmsg->msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))) {
req_set_fail(req); req_set_fail(req);
@ -5474,6 +5488,10 @@ static int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
if (kmsg->free_iov) if (kmsg->free_iov)
kfree(kmsg->free_iov); kfree(kmsg->free_iov);
req->flags &= ~REQ_F_NEED_CLEANUP; req->flags &= ~REQ_F_NEED_CLEANUP;
if (ret >= 0)
ret += sr->done_io;
else if (sr->done_io)
ret = sr->done_io;
__io_req_complete(req, issue_flags, ret, io_put_kbuf(req, issue_flags)); __io_req_complete(req, issue_flags, ret, io_put_kbuf(req, issue_flags));
return 0; return 0;
} }
@ -5524,12 +5542,22 @@ static int io_recv(struct io_kiocb *req, unsigned int issue_flags)
return -EAGAIN; return -EAGAIN;
if (ret == -ERESTARTSYS) if (ret == -ERESTARTSYS)
ret = -EINTR; ret = -EINTR;
if (ret > 0 && io_net_retry(sock, flags)) {
sr->len -= ret;
sr->buf += ret;
sr->done_io += ret;
return -EAGAIN;
}
req_set_fail(req); req_set_fail(req);
} else if ((flags & MSG_WAITALL) && (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))) { } else if ((flags & MSG_WAITALL) && (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))) {
out_free: out_free:
req_set_fail(req); req_set_fail(req);
} }
if (ret >= 0)
ret += sr->done_io;
else if (sr->done_io)
ret = sr->done_io;
__io_req_complete(req, issue_flags, ret, io_put_kbuf(req, issue_flags)); __io_req_complete(req, issue_flags, ret, io_put_kbuf(req, issue_flags));
return 0; return 0;
} }