diff --git a/include/net/tls.h b/include/net/tls.h
index a528a082da73..a5a938583295 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -519,6 +519,9 @@ static inline bool tls_sw_has_ctx_tx(const struct sock *sk)
 	return !!tls_sw_ctx_tx(ctx);
 }
 
+void tls_sw_write_space(struct sock *sk, struct tls_context *ctx);
+void tls_device_write_space(struct sock *sk, struct tls_context *ctx);
+
 static inline struct tls_offload_context_rx *
 tls_offload_ctx_rx(const struct tls_context *tls_ctx)
 {
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index 3e5e8e021a87..4a1da837a733 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -546,6 +546,32 @@ static int tls_device_push_pending_record(struct sock *sk, int flags)
 	return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
 }
 
+void tls_device_write_space(struct sock *sk, struct tls_context *ctx)
+{
+	int rc = 0;
+
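+	/* Resume a partially sent record only when no writer is pending;
+	 * a blocked writer will push the remainder itself once woken.
+	 */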
+	if (!sk->sk_write_pending && tls_is_partially_sent_record(ctx)) {
+		gfp_t sk_allocation = sk->sk_allocation;
+
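+		/* sk_write_space can be invoked from softirq context, so
+		 * the push below must not sleep on memory allocation.
+		 */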
+		sk->sk_allocation = GFP_ATOMIC;
+		rc = tls_push_partial_record(sk, ctx,
+					     MSG_DONTWAIT | MSG_NOSIGNAL);
+		sk->sk_allocation = sk_allocation;
+	}
+
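+	/* Wake writers only if the pending record was pushed out in full;
+	 * otherwise there is still no room for new data.
+	 */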
+	if (!rc)
+		ctx->sk_write_space(sk);
+}
+
 void handle_device_resync(struct sock *sk, u32 seq, u64 rcd_sn)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 7e05af75536d..17e8667917aa 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -212,7 +212,6 @@ int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
 static void tls_write_space(struct sock *sk)
 {
 	struct tls_context *ctx = tls_get_ctx(sk);
-	struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
 
 	/* If in_tcp_sendpages call lower protocol write space handler
 	 * to ensure we wake up any waiting operations there. For example
@@ -223,14 +222,15 @@ static void tls_write_space(struct sock *sk)
 		return;
 	}
 
-	/* Schedule the transmission if tx list is ready */
-	if (is_tx_ready(tx_ctx) && !sk->sk_write_pending) {
-		/* Schedule the transmission */
-		if (!test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
-			schedule_delayed_work(&tx_ctx->tx_work.work, 0);
-	}
-
-	ctx->sk_write_space(sk);
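+	/* Hand off to the device-offload or software handler; each is
+	 * responsible for waking waiting writers via ctx->sk_write_space().
+	 */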
+#ifdef CONFIG_TLS_DEVICE
+	if (ctx->tx_conf == TLS_HW)
+		tls_device_write_space(sk, ctx);
+	else
+#endif
+		tls_sw_write_space(sk, ctx);
 }
 
 static void tls_ctx_free(struct tls_context *ctx)
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 1cc830582fa8..917caacd4d31 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -2126,6 +2126,21 @@ static void tx_work_handler(struct work_struct *work)
 	release_sock(sk);
 }
 
+void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
+{
+	struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
+
+	/* Schedule the transmission if tx list is ready */
+	if (is_tx_ready(tx_ctx) && !sk->sk_write_pending) {
+		/* BIT_TX_SCHEDULED prevents scheduling the work twice */
+		if (!test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
+			schedule_delayed_work(&tx_ctx->tx_work.work, 0);
+	}
+
+	/* Wake any writers waiting for socket buffer space */
+	ctx->sk_write_space(sk);
+}
+
 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);