diff --git a/arch/x86/include/asm/kvm-x86-ops.h b/arch/x86/include/asm/kvm-x86-ops.h
index 3c368b639c04..1a6d7e3f6c32 100644
--- a/arch/x86/include/asm/kvm-x86-ops.h
+++ b/arch/x86/include/asm/kvm-x86-ops.h
@@ -118,6 +118,7 @@ KVM_X86_OP_OPTIONAL(mem_enc_register_region)
 KVM_X86_OP_OPTIONAL(mem_enc_unregister_region)
 KVM_X86_OP_OPTIONAL(vm_copy_enc_context_from)
 KVM_X86_OP_OPTIONAL(vm_move_enc_context_from)
+KVM_X86_OP_OPTIONAL(guest_memory_reclaimed)
 KVM_X86_OP(get_msr_feature)
 KVM_X86_OP(can_emulate_instruction)
 KVM_X86_OP(apic_init_signal_blocked)
diff --git a/arch/x86/include/asm/kvm_host.h b/arch/x86/include/asm/kvm_host.h
index e0c0f0e1f754..4ff36610af6a 100644
--- a/arch/x86/include/asm/kvm_host.h
+++ b/arch/x86/include/asm/kvm_host.h
@@ -1484,6 +1484,7 @@ struct kvm_x86_ops {
 	int (*mem_enc_unregister_region)(struct kvm *kvm, struct kvm_enc_region *argp);
 	int (*vm_copy_enc_context_from)(struct kvm *kvm, unsigned int source_fd);
 	int (*vm_move_enc_context_from)(struct kvm *kvm, unsigned int source_fd);
+	void (*guest_memory_reclaimed)(struct kvm *kvm);
 
 	int (*get_msr_feature)(struct kvm_msr_entry *entry);
 
diff --git a/arch/x86/kvm/svm/sev.c b/arch/x86/kvm/svm/sev.c
index 9a0375987029..0ad70c12c7c3 100644
--- a/arch/x86/kvm/svm/sev.c
+++ b/arch/x86/kvm/svm/sev.c
@@ -2262,6 +2262,14 @@ do_wbinvd:
 	wbinvd_on_all_cpus();
 }
 
+void sev_guest_memory_reclaimed(struct kvm *kvm)
+{
+	if (!sev_guest(kvm))
+		return;
+
+	wbinvd_on_all_cpus();
+}
+
 void sev_free_vcpu(struct kvm_vcpu *vcpu)
 {
 	struct vcpu_svm *svm;
diff --git a/arch/x86/kvm/svm/svm.c b/arch/x86/kvm/svm/svm.c
index bd4c64b362d2..7e45d03cd018 100644
--- a/arch/x86/kvm/svm/svm.c
+++ b/arch/x86/kvm/svm/svm.c
@@ -4620,6 +4620,7 @@ static struct kvm_x86_ops svm_x86_ops __initdata = {
 	.mem_enc_ioctl = sev_mem_enc_ioctl,
 	.mem_enc_register_region = sev_mem_enc_register_region,
 	.mem_enc_unregister_region = sev_mem_enc_unregister_region,
+	.guest_memory_reclaimed = sev_guest_memory_reclaimed,
 
 	.vm_copy_enc_context_from = sev_vm_copy_enc_context_from,
 	.vm_move_enc_context_from = sev_vm_move_enc_context_from,
diff --git a/arch/x86/kvm/svm/svm.h b/arch/x86/kvm/svm/svm.h
index f77a7d2d39dd..f76deff71002 100644
--- a/arch/x86/kvm/svm/svm.h
+++ b/arch/x86/kvm/svm/svm.h
@@ -609,6 +609,8 @@ int sev_mem_enc_unregister_region(struct kvm *kvm,
 				  struct kvm_enc_region *range);
 int sev_vm_copy_enc_context_from(struct kvm *kvm, unsigned int source_fd);
 int sev_vm_move_enc_context_from(struct kvm *kvm, unsigned int source_fd);
+void sev_guest_memory_reclaimed(struct kvm *kvm);
+
 void pre_sev_run(struct vcpu_svm *svm, int cpu);
 void __init sev_set_cpu_caps(void);
 void __init sev_hardware_setup(void);
diff --git a/arch/x86/kvm/x86.c b/arch/x86/kvm/x86.c
index c89dc09a764f..a6ab19afc638 100644
--- a/arch/x86/kvm/x86.c
+++ b/arch/x86/kvm/x86.c
@@ -9889,6 +9889,11 @@ void kvm_arch_mmu_notifier_invalidate_range(struct kvm *kvm,
 		kvm_make_all_cpus_request(kvm, KVM_REQ_APIC_PAGE_RELOAD);
 }
 
+void kvm_arch_guest_memory_reclaimed(struct kvm *kvm)
+{
+	static_call_cond(kvm_x86_guest_memory_reclaimed)(kvm);
+}
+
 static void kvm_vcpu_reload_apic_access_page(struct kvm_vcpu *vcpu)
 {
 	if (!lapic_in_kernel(vcpu))
diff --git a/include/linux/kvm_host.h b/include/linux/kvm_host.h
index 2dab4b696682..34eed5f85ed6 100644
--- a/include/linux/kvm_host.h
+++ b/include/linux/kvm_host.h
@@ -2219,6 +2219,8 @@ static inline long kvm_arch_vcpu_async_ioctl(struct file *filp,
 void kvm_arch_mmu_notifier_invalidate_range(struct kvm *kvm,
 					    unsigned long start, unsigned long end);
 
+void kvm_arch_guest_memory_reclaimed(struct kvm *kvm);
+
 #ifdef CONFIG_HAVE_KVM_VCPU_RUN_PID_CHANGE
 int kvm_arch_vcpu_run_pid_change(struct kvm_vcpu *vcpu);
 #else
diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
index 2a23f24d13cf..f30bb8c16f26 100644
--- a/virt/kvm/kvm_main.c
+++ b/virt/kvm/kvm_main.c
@@ -164,6 +164,10 @@ __weak void kvm_arch_mmu_notifier_invalidate_range(struct kvm *kvm,
 {
 }
 
+__weak void kvm_arch_guest_memory_reclaimed(struct kvm *kvm)
+{
+}
+
 bool kvm_is_zone_device_pfn(kvm_pfn_t pfn)
 {
 	/*
@@ -357,6 +361,12 @@ void kvm_flush_remote_tlbs(struct kvm *kvm)
 EXPORT_SYMBOL_GPL(kvm_flush_remote_tlbs);
 #endif
 
+static void kvm_flush_shadow_all(struct kvm *kvm)
+{
+	kvm_arch_flush_shadow_all(kvm);
+	kvm_arch_guest_memory_reclaimed(kvm);
+}
+
 #ifdef KVM_ARCH_NR_OBJS_PER_MEMORY_CACHE
 static inline void *mmu_memory_cache_alloc_obj(struct kvm_mmu_memory_cache *mc,
 					       gfp_t gfp_flags)
@@ -485,12 +495,15 @@ typedef bool (*hva_handler_t)(struct kvm *kvm, struct kvm_gfn_range *range);
 typedef void (*on_lock_fn_t)(struct kvm *kvm, unsigned long start,
 			     unsigned long end);
 
+typedef void (*on_unlock_fn_t)(struct kvm *kvm);
+
 struct kvm_hva_range {
 	unsigned long start;
 	unsigned long end;
 	pte_t pte;
 	hva_handler_t handler;
 	on_lock_fn_t on_lock;
+	on_unlock_fn_t on_unlock;
 	bool flush_on_ret;
 	bool may_block;
 };
@@ -578,8 +591,11 @@ static __always_inline int __kvm_handle_hva_range(struct kvm *kvm,
 	if (range->flush_on_ret && ret)
 		kvm_flush_remote_tlbs(kvm);
 
-	if (locked)
+	if (locked) {
 		KVM_MMU_UNLOCK(kvm);
+		if (!IS_KVM_NULL_FN(range->on_unlock))
+			range->on_unlock(kvm);
+	}
 
 	srcu_read_unlock(&kvm->srcu, idx);
 
@@ -600,6 +616,7 @@ static __always_inline int kvm_handle_hva_range(struct mmu_notifier *mn,
 		.pte		= pte,
 		.handler	= handler,
 		.on_lock	= (void *)kvm_null_fn,
+		.on_unlock	= (void *)kvm_null_fn,
 		.flush_on_ret	= true,
 		.may_block	= false,
 	};
@@ -619,6 +636,7 @@ static __always_inline int kvm_handle_hva_range_no_flush(struct mmu_notifier *mn
 		.pte		= __pte(0),
 		.handler	= handler,
 		.on_lock	= (void *)kvm_null_fn,
+		.on_unlock	= (void *)kvm_null_fn,
 		.flush_on_ret	= false,
 		.may_block	= false,
 	};
@@ -687,6 +705,7 @@ static int kvm_mmu_notifier_invalidate_range_start(struct mmu_notifier *mn,
 		.pte		= __pte(0),
 		.handler	= kvm_unmap_gfn_range,
 		.on_lock	= kvm_inc_notifier_count,
+		.on_unlock	= kvm_arch_guest_memory_reclaimed,
 		.flush_on_ret	= true,
 		.may_block	= mmu_notifier_range_blockable(range),
 	};
@@ -741,6 +760,7 @@ static void kvm_mmu_notifier_invalidate_range_end(struct mmu_notifier *mn,
 		.pte		= __pte(0),
 		.handler	= (void *)kvm_null_fn,
 		.on_lock	= kvm_dec_notifier_count,
+		.on_unlock	= (void *)kvm_null_fn,
 		.flush_on_ret	= false,
 		.may_block	= mmu_notifier_range_blockable(range),
 	};
@@ -813,7 +833,7 @@ static void kvm_mmu_notifier_release(struct mmu_notifier *mn,
 	int idx;
 
 	idx = srcu_read_lock(&kvm->srcu);
-	kvm_arch_flush_shadow_all(kvm);
+	kvm_flush_shadow_all(kvm);
 	srcu_read_unlock(&kvm->srcu, idx);
 }
 
@@ -1225,7 +1245,7 @@ static void kvm_destroy_vm(struct kvm *kvm)
 	WARN_ON(rcuwait_active(&kvm->mn_memslots_update_rcuwait));
 	kvm->mn_active_invalidate_count = 0;
 #else
-	kvm_arch_flush_shadow_all(kvm);
+	kvm_flush_shadow_all(kvm);
 #endif
 	kvm_arch_destroy_vm(kvm);
 	kvm_destroy_devices(kvm);
@@ -1652,6 +1672,7 @@ static void kvm_invalidate_memslot(struct kvm *kvm,
 	 *	- kvm_is_visible_gfn (mmu_check_root)
 	 */
 	kvm_arch_flush_shadow_memslot(kvm, old);
+	kvm_arch_guest_memory_reclaimed(kvm);
 
 	/* Was released by kvm_swap_active_memslots, reacquire. */
 	mutex_lock(&kvm->slots_arch_lock);