diff --git a/include/linux/jump_label.h b/include/linux/jump_label.h
index b6a29c126cc4..2168cc6b8b30 100644
--- a/include/linux/jump_label.h
+++ b/include/linux/jump_label.h
@@ -151,6 +151,7 @@ extern struct jump_entry __start___jump_table[];
 extern struct jump_entry __stop___jump_table[];
 
 extern void jump_label_init(void);
+extern void jump_label_invalidate_init(void);
 extern void jump_label_lock(void);
 extern void jump_label_unlock(void);
 extern void arch_jump_label_transform(struct jump_entry *entry,
@@ -198,6 +199,8 @@ static __always_inline void jump_label_init(void)
 	static_key_initialized = true;
 }
 
+static inline void jump_label_invalidate_init(void) {}
+
 static __always_inline bool static_key_false(struct static_key *key)
 {
 	if (unlikely(static_key_count(key) > 0))
diff --git a/init/main.c b/init/main.c
index a8100b954839..969eaf140ef0 100644
--- a/init/main.c
+++ b/init/main.c
@@ -89,6 +89,7 @@
 #include <linux/io.h>
 #include <linux/cache.h>
 #include <linux/rodata_test.h>
+#include <linux/jump_label.h>
 
 #include <asm/io.h>
 #include <asm/bugs.h>
@@ -1000,6 +1001,7 @@ static int __ref kernel_init(void *unused)
 	/* need to finish all async __init code before freeing the memory */
 	async_synchronize_full();
 	ftrace_free_init_mem();
+	jump_label_invalidate_init();
 	free_initmem();
 	mark_readonly();
 	system_state = SYSTEM_RUNNING;
diff --git a/kernel/jump_label.c b/kernel/jump_label.c
index b4517095db6a..b71776576a66 100644
--- a/kernel/jump_label.c
+++ b/kernel/jump_label.c
@@ -16,6 +16,7 @@
 #include <linux/jump_label_ratelimit.h>
 #include <linux/bug.h>
 #include <linux/cpu.h>
+#include <asm/sections.h>
 
 #ifdef HAVE_JUMP_LABEL
 
@@ -417,6 +418,20 @@ void __init jump_label_init(void)
 	cpus_read_unlock();
 }
 
+/* Disable any jump label entries in __init code */
+void __init jump_label_invalidate_init(void)
+{
+	struct jump_entry *iter_start = __start___jump_table;
+	struct jump_entry *iter_stop = __stop___jump_table;
+	struct jump_entry *iter;
+
+	for (iter = iter_start; iter < iter_stop; iter++) {
+		if (iter->code >= (unsigned long)_sinittext &&
+		    iter->code < (unsigned long)_einittext)
+			iter->code = 0;
+	}
+}
+
 #ifdef CONFIG_MODULES
 
 static enum jump_label_type jump_label_init_type(struct jump_entry *entry)
@@ -633,6 +648,7 @@ static void jump_label_del_module(struct module *mod)
 	}
 }
 
+/* Disable any jump label entries in module init code */
 static void jump_label_invalidate_module_init(struct module *mod)
 {
 	struct jump_entry *iter_start = mod->jump_entries;