SUNRPC: Recognize control messages in server-side TCP socket code

To support kTLS, the server-side TCP socket receive path needs to
watch for CMSGs.

Acked-by: Jakub Kicinski <kuba@kernel.org>
Signed-off-by: Chuck Lever <chuck.lever@oracle.com>
This commit is contained in:
Chuck Lever 2023-04-17 09:42:14 -04:00
parent 6a0cdf56bf
commit 5e052dda12
2 changed files with 48 additions and 2 deletions

View File

@ -69,6 +69,8 @@ extern const struct tls_cipher_size_desc tls_cipher_size_desc[];
#define TLS_CRYPTO_INFO_READY(info) ((info)->cipher_type)
#define TLS_RECORD_TYPE_ALERT 0x15
#define TLS_RECORD_TYPE_HANDSHAKE 0x16
#define TLS_RECORD_TYPE_DATA 0x17
#define TLS_AAD_SPACE_SIZE 13

View File

@ -43,6 +43,7 @@
#include <net/udp.h>
#include <net/tcp.h>
#include <net/tcp_states.h>
#include <net/tls.h>
#include <linux/uaccess.h>
#include <linux/highmem.h>
#include <asm/ioctls.h>
@ -216,6 +217,49 @@ static int svc_one_sock_name(struct svc_sock *svsk, char *buf, int remaining)
return len;
}
static int
svc_tcp_sock_process_cmsg(struct svc_sock *svsk, struct msghdr *msg,
struct cmsghdr *cmsg, int ret)
{
if (cmsg->cmsg_level == SOL_TLS &&
cmsg->cmsg_type == TLS_GET_RECORD_TYPE) {
u8 content_type = *((u8 *)CMSG_DATA(cmsg));
switch (content_type) {
case TLS_RECORD_TYPE_DATA:
/* TLS sets EOR at the end of each application data
* record, even though there might be more frames
* waiting to be decrypted.
*/
msg->msg_flags &= ~MSG_EOR;
break;
case TLS_RECORD_TYPE_ALERT:
ret = -ENOTCONN;
break;
default:
ret = -EAGAIN;
}
}
return ret;
}
static int
svc_tcp_sock_recv_cmsg(struct svc_sock *svsk, struct msghdr *msg)
{
union {
struct cmsghdr cmsg;
u8 buf[CMSG_SPACE(sizeof(u8))];
} u;
int ret;
msg->msg_control = &u;
msg->msg_controllen = sizeof(u);
ret = sock_recvmsg(svsk->sk_sock, msg, MSG_DONTWAIT);
if (unlikely(msg->msg_controllen != sizeof(u)))
ret = svc_tcp_sock_process_cmsg(svsk, msg, &u.cmsg, ret);
return ret;
}
#if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE
static void svc_flush_bvec(const struct bio_vec *bvec, size_t size, size_t seek)
{
@ -263,7 +307,7 @@ static ssize_t svc_tcp_read_msg(struct svc_rqst *rqstp, size_t buflen,
iov_iter_advance(&msg.msg_iter, seek);
buflen -= seek;
}
len = sock_recvmsg(svsk->sk_sock, &msg, MSG_DONTWAIT);
len = svc_tcp_sock_recv_cmsg(svsk, &msg);
if (len > 0)
svc_flush_bvec(bvec, len, seek);
@ -877,7 +921,7 @@ static ssize_t svc_tcp_read_marker(struct svc_sock *svsk,
iov.iov_base = ((char *)&svsk->sk_marker) + svsk->sk_tcplen;
iov.iov_len = want;
iov_iter_kvec(&msg.msg_iter, ITER_DEST, &iov, 1, want);
len = sock_recvmsg(svsk->sk_sock, &msg, MSG_DONTWAIT);
len = svc_tcp_sock_recv_cmsg(svsk, &msg);
if (len < 0)
return len;
svsk->sk_tcplen += len;