diff --git a/arch/x86/mm/tlb.c b/arch/x86/mm/tlb.c
index a1561957dccb..5bfe61a5e8e3 100644
--- a/arch/x86/mm/tlb.c
+++ b/arch/x86/mm/tlb.c
@@ -151,6 +151,34 @@ void switch_mm(struct mm_struct *prev, struct mm_struct *next,
 	local_irq_restore(flags);
 }
 
+static void sync_current_stack_to_mm(struct mm_struct *mm)
+{
+	unsigned long sp = current_stack_pointer;
+	pgd_t *pgd = pgd_offset(mm, sp);
+
+	if (CONFIG_PGTABLE_LEVELS > 4) {
+		if (unlikely(pgd_none(*pgd))) {
+			pgd_t *pgd_ref = pgd_offset_k(sp);
+
+			set_pgd(pgd, *pgd_ref);
+		}
+	} else {
+		/*
+		 * "pgd" is faked.  The top level entries are "p4d"s, so sync
+		 * the p4d.  This compiles to approximately the same code as
+		 * the 5-level case.
+		 */
+		p4d_t *p4d = p4d_offset(pgd, sp);
+
+		if (unlikely(p4d_none(*p4d))) {
+			pgd_t *pgd_ref = pgd_offset_k(sp);
+			p4d_t *p4d_ref = p4d_offset(pgd_ref, sp);
+
+			set_p4d(p4d, *p4d_ref);
+		}
+	}
+}
+
 void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
 			struct task_struct *tsk)
 {
@@ -226,11 +254,7 @@ void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
 			 * mapped in the new pgd, we'll double-fault.  Forcibly
 			 * map it.
 			 */
-			unsigned int index = pgd_index(current_stack_pointer);
-			pgd_t *pgd = next->pgd + index;
-
-			if (unlikely(pgd_none(*pgd)))
-				set_pgd(pgd, init_mm.pgd[index]);
+			sync_current_stack_to_mm(next);
 		}
 
 		/* Stop remote flushes for the previous mm */