diff --git a/include/linux/vmacache.h b/include/linux/vmacache.h
index a5b3aa8d281f..3e9a963edd6a 100644
--- a/include/linux/vmacache.h
+++ b/include/linux/vmacache.h
@@ -5,12 +5,6 @@
 #include <linux/sched.h>
 #include <linux/mm.h>
 
-/*
- * Hash based on the page number. Provides a good hit rate for
- * workloads with good locality and those with random accesses as well.
- */
-#define VMACACHE_HASH(addr) ((addr >> PAGE_SHIFT) & VMACACHE_MASK)
-
 static inline void vmacache_flush(struct task_struct *tsk)
 {
 	memset(tsk->vmacache.vmas, 0, sizeof(tsk->vmacache.vmas));
diff --git a/mm/vmacache.c b/mm/vmacache.c
index db7596eb6132..ea517bef7dc5 100644
--- a/mm/vmacache.c
+++ b/mm/vmacache.c
@@ -6,6 +6,18 @@
 #include <linux/sched/task.h>
 #include <linux/mm.h>
 #include <linux/vmacache.h>
+#include <asm/pgtable.h>
+
+/*
+ * Hash based on the pmd of addr if configured with MMU, which provides a good
+ * hit rate for workloads with spatial locality.  Otherwise, use pages.
+ */
+#ifdef CONFIG_MMU
+#define VMACACHE_SHIFT	PMD_SHIFT
+#else
+#define VMACACHE_SHIFT	PAGE_SHIFT
+#endif
+#define VMACACHE_HASH(addr) ((addr >> VMACACHE_SHIFT) & VMACACHE_MASK)
 
 /*
  * Flush vma caches for threads that share a given mm.
@@ -87,6 +99,7 @@ static bool vmacache_valid(struct mm_struct *mm)
 
 struct vm_area_struct *vmacache_find(struct mm_struct *mm, unsigned long addr)
 {
+	int idx = VMACACHE_HASH(addr);
 	int i;
 
 	count_vm_vmacache_event(VMACACHE_FIND_CALLS);
@@ -95,16 +108,20 @@ struct vm_area_struct *vmacache_find(struct mm_struct *mm, unsigned long addr)
 		return NULL;
 
 	for (i = 0; i < VMACACHE_SIZE; i++) {
-		struct vm_area_struct *vma = current->vmacache.vmas[i];
+		struct vm_area_struct *vma = current->vmacache.vmas[idx];
 
-		if (!vma)
-			continue;
-		if (WARN_ON_ONCE(vma->vm_mm != mm))
-			break;
-		if (vma->vm_start <= addr && vma->vm_end > addr) {
-			count_vm_vmacache_event(VMACACHE_FIND_HITS);
-			return vma;
+		if (vma) {
+#ifdef CONFIG_DEBUG_VM_VMACACHE
+			if (WARN_ON_ONCE(vma->vm_mm != mm))
+				break;
+#endif
+			if (vma->vm_start <= addr && vma->vm_end > addr) {
+				count_vm_vmacache_event(VMACACHE_FIND_HITS);
+				return vma;
+			}
 		}
+		if (++idx == VMACACHE_SIZE)
+			idx = 0;
 	}
 
 	return NULL;
@@ -115,6 +132,7 @@ struct vm_area_struct *vmacache_find_exact(struct mm_struct *mm,
 					   unsigned long start,
 					   unsigned long end)
 {
+	int idx = VMACACHE_HASH(start);
 	int i;
 
 	count_vm_vmacache_event(VMACACHE_FIND_CALLS);
@@ -123,12 +141,14 @@ struct vm_area_struct *vmacache_find_exact(struct mm_struct *mm,
 		return NULL;
 
 	for (i = 0; i < VMACACHE_SIZE; i++) {
-		struct vm_area_struct *vma = current->vmacache.vmas[i];
+		struct vm_area_struct *vma = current->vmacache.vmas[idx];
 
 		if (vma && vma->vm_start == start && vma->vm_end == end) {
 			count_vm_vmacache_event(VMACACHE_FIND_HITS);
 			return vma;
 		}
+		if (++idx == VMACACHE_SIZE)
+			idx = 0;
 	}
 
 	return NULL;