diff --git a/include/net/tls.h b/include/net/tls.h index 3a33924db2bc..61fef2880114 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -390,8 +390,12 @@ tls_offload_ctx_tx(const struct tls_context *tls_ctx) static inline bool tls_sw_has_ctx_tx(const struct sock *sk) { - struct tls_context *ctx = tls_get_ctx(sk); + struct tls_context *ctx; + if (!sk_is_inet(sk) || !inet_test_bit(IS_ICSK, sk)) + return false; + + ctx = tls_get_ctx(sk); if (!ctx) return false; return !!tls_sw_ctx_tx(ctx); @@ -399,8 +403,12 @@ static inline bool tls_sw_has_ctx_tx(const struct sock *sk) static inline bool tls_sw_has_ctx_rx(const struct sock *sk) { - struct tls_context *ctx = tls_get_ctx(sk); + struct tls_context *ctx; + if (!sk_is_inet(sk) || !inet_test_bit(IS_ICSK, sk)) + return false; + + ctx = tls_get_ctx(sk); if (!ctx) return false; return !!tls_sw_ctx_rx(ctx);