diff --git a/arch/arm64/kvm/guest.c b/arch/arm64/kvm/guest.c
index 4ddb20017b2f..60815ae477cf 100644
--- a/arch/arm64/kvm/guest.c
+++ b/arch/arm64/kvm/guest.c
@@ -1053,6 +1053,14 @@ long kvm_vm_ioctl_mte_copy_tags(struct kvm *kvm,
 		} else {
 			num_tags = mte_copy_tags_from_user(maddr, tags,
 							MTE_GRANULES_PER_PAGE);
+
+			/*
+			 * Set the flag after checking the write
+			 * completed fully
+			 */
+			if (num_tags == MTE_GRANULES_PER_PAGE)
+				set_bit(PG_mte_tagged, &page->flags);
+
 			kvm_release_pfn_dirty(pfn);
 		}
 
@@ -1061,10 +1069,6 @@ long kvm_vm_ioctl_mte_copy_tags(struct kvm *kvm,
 			goto out;
 		}
 
-		/* Set the flag after checking the write completed fully */
-		if (write)
-			set_bit(PG_mte_tagged, &page->flags);
-
 		gfn++;
 		tags += num_tags;
 		length -= PAGE_SIZE;