diff --git a/arch/s390/mm/pgtable.c b/arch/s390/mm/pgtable.c
index 4e54492f463a..84bddda8d412 100644
--- a/arch/s390/mm/pgtable.c
+++ b/arch/s390/mm/pgtable.c
@@ -585,7 +585,7 @@ int gmap_fault(struct gmap *gmap, unsigned long gaddr,
 		rc = vmaddr;
 		goto out_up;
 	}
-	if (fixup_user_fault(current, gmap->mm, vmaddr, fault_flags)) {
+	if (fixup_user_fault(current, gmap->mm, vmaddr, fault_flags, NULL)) {
 		rc = -EFAULT;
 		goto out_up;
 	}
@@ -727,7 +727,8 @@ int gmap_ipte_notify(struct gmap *gmap, unsigned long gaddr, unsigned long len)
 			break;
 		}
 		/* Get the page mapped */
-		if (fixup_user_fault(current, gmap->mm, addr, FAULT_FLAG_WRITE)) {
+		if (fixup_user_fault(current, gmap->mm, addr, FAULT_FLAG_WRITE,
+				     NULL)) {
 			rc = -EFAULT;
 			break;
 		}
@@ -802,7 +803,8 @@ retry:
 	if (!(pte_val(*ptep) & _PAGE_INVALID) &&
 	     (pte_val(*ptep) & _PAGE_PROTECT)) {
 		pte_unmap_unlock(ptep, ptl);
-		if (fixup_user_fault(current, mm, addr, FAULT_FLAG_WRITE)) {
+		if (fixup_user_fault(current, mm, addr, FAULT_FLAG_WRITE,
+				     NULL)) {
 			up_read(&mm->mmap_sem);
 			return -EFAULT;
 		}
diff --git a/include/linux/mm.h b/include/linux/mm.h
index 792f2469c142..1d6ec55d8b25 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -1194,7 +1194,8 @@ int invalidate_inode_page(struct page *page);
 extern int handle_mm_fault(struct mm_struct *mm, struct vm_area_struct *vma,
 			unsigned long address, unsigned int flags);
 extern int fixup_user_fault(struct task_struct *tsk, struct mm_struct *mm,
-			    unsigned long address, unsigned int fault_flags);
+			    unsigned long address, unsigned int fault_flags,
+			    bool *unlocked);
 #else
 static inline int handle_mm_fault(struct mm_struct *mm,
 			struct vm_area_struct *vma, unsigned long address,
@@ -1206,7 +1207,7 @@ static inline int handle_mm_fault(struct mm_struct *mm,
 }
 static inline int fixup_user_fault(struct task_struct *tsk,
 		struct mm_struct *mm, unsigned long address,
-		unsigned int fault_flags)
+		unsigned int fault_flags, bool *unlocked)
 {
 	/* should never happen if there's no MMU */
 	BUG();
diff --git a/kernel/futex.c b/kernel/futex.c
index eed92a8a4c49..c6f514573b28 100644
--- a/kernel/futex.c
+++ b/kernel/futex.c
@@ -604,7 +604,7 @@ static int fault_in_user_writeable(u32 __user *uaddr)
 
 	down_read(&mm->mmap_sem);
 	ret = fixup_user_fault(current, mm, (unsigned long)uaddr,
-			       FAULT_FLAG_WRITE);
+			       FAULT_FLAG_WRITE, NULL);
 	up_read(&mm->mmap_sem);
 
 	return ret < 0 ? ret : 0;
diff --git a/mm/gup.c b/mm/gup.c
index aa21c4b865a5..b64a36175884 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -618,6 +618,8 @@ EXPORT_SYMBOL(__get_user_pages);
  * @mm:		mm_struct of target mm
  * @address:	user address
  * @fault_flags:flags to pass down to handle_mm_fault()
+ * @unlocked:	did we unlock the mmap_sem while retrying, maybe NULL if caller
+ *		does not allow retry
  *
  * This is meant to be called in the specific scenario where for locking reasons
  * we try to access user memory in atomic context (within a pagefault_disable()
@@ -629,22 +631,28 @@ EXPORT_SYMBOL(__get_user_pages);
  * The main difference with get_user_pages() is that this function will
  * unconditionally call handle_mm_fault() which will in turn perform all the
  * necessary SW fixup of the dirty and young bits in the PTE, while
- * handle_mm_fault() only guarantees to update these in the struct page.
+ * get_user_pages() only guarantees to update these in the struct page.
  *
  * This is important for some architectures where those bits also gate the
  * access permission to the page because they are maintained in software.  On
  * such architectures, gup() will not be enough to make a subsequent access
  * succeed.
  *
- * This has the same semantics wrt the @mm->mmap_sem as does filemap_fault().
+ * This function will not return with an unlocked mmap_sem. So it has not the
+ * same semantics wrt the @mm->mmap_sem as does filemap_fault().
  */
 int fixup_user_fault(struct task_struct *tsk, struct mm_struct *mm,
-		     unsigned long address, unsigned int fault_flags)
+		     unsigned long address, unsigned int fault_flags,
+		     bool *unlocked)
 {
 	struct vm_area_struct *vma;
 	vm_flags_t vm_flags;
-	int ret;
+	int ret, major = 0;
 
+	if (unlocked)
+		fault_flags |= FAULT_FLAG_ALLOW_RETRY;
+
+retry:
 	vma = find_extend_vma(mm, address);
 	if (!vma || address < vma->vm_start)
 		return -EFAULT;
@@ -654,6 +662,7 @@ int fixup_user_fault(struct task_struct *tsk, struct mm_struct *mm,
 		return -EFAULT;
 
 	ret = handle_mm_fault(mm, vma, address, fault_flags);
+	major |= ret & VM_FAULT_MAJOR;
 	if (ret & VM_FAULT_ERROR) {
 		if (ret & VM_FAULT_OOM)
 			return -ENOMEM;
@@ -663,8 +672,19 @@ int fixup_user_fault(struct task_struct *tsk, struct mm_struct *mm,
 			return -EFAULT;
 		BUG();
 	}
+
+	if (ret & VM_FAULT_RETRY) {
+		down_read(&mm->mmap_sem);
+		if (!(fault_flags & FAULT_FLAG_TRIED)) {
+			*unlocked = true;
+			fault_flags &= ~FAULT_FLAG_ALLOW_RETRY;
+			fault_flags |= FAULT_FLAG_TRIED;
+			goto retry;
+		}
+	}
+
 	if (tsk) {
-		if (ret & VM_FAULT_MAJOR)
+		if (major)
 			tsk->maj_flt++;
 		else
 			tsk->min_flt++;