diff --git a/drivers/staging/rdma/hfi1/mmu_rb.c b/drivers/staging/rdma/hfi1/mmu_rb.c index c7ad0164ea9a..eac4d041d351 100644 --- a/drivers/staging/rdma/hfi1/mmu_rb.c +++ b/drivers/staging/rdma/hfi1/mmu_rb.c @@ -71,6 +71,7 @@ static inline void mmu_notifier_range_start(struct mmu_notifier *, struct mm_struct *, unsigned long, unsigned long); static void mmu_notifier_mem_invalidate(struct mmu_notifier *, + struct mm_struct *, unsigned long, unsigned long); static struct mmu_rb_node *__mmu_rb_search(struct mmu_rb_handler *, unsigned long, unsigned long); @@ -137,7 +138,7 @@ void hfi1_mmu_rb_unregister(struct rb_root *root) rbnode = rb_entry(node, struct mmu_rb_node, node); rb_erase(node, root); if (handler->ops->remove) - handler->ops->remove(root, rbnode, false); + handler->ops->remove(root, rbnode, NULL); } } @@ -201,14 +202,14 @@ static struct mmu_rb_node *__mmu_rb_search(struct mmu_rb_handler *handler, } static void __mmu_rb_remove(struct mmu_rb_handler *handler, - struct mmu_rb_node *node, bool arg) + struct mmu_rb_node *node, struct mm_struct *mm) { /* Validity of handler and node pointers has been checked by caller. */ hfi1_cdbg(MMU, "Removing node addr 0x%llx, len %u", node->addr, node->len); __mmu_int_rb_remove(node, handler->root); if (handler->ops->remove) - handler->ops->remove(handler->root, node, arg); + handler->ops->remove(handler->root, node, mm); } struct mmu_rb_node *hfi1_mmu_rb_search(struct rb_root *root, unsigned long addr, @@ -237,7 +238,7 @@ void hfi1_mmu_rb_remove(struct rb_root *root, struct mmu_rb_node *node) return; spin_lock_irqsave(&handler->lock, flags); - __mmu_rb_remove(handler, node, false); + __mmu_rb_remove(handler, node, NULL); spin_unlock_irqrestore(&handler->lock, flags); } @@ -260,7 +261,7 @@ unlock: static inline void mmu_notifier_page(struct mmu_notifier *mn, struct mm_struct *mm, unsigned long addr) { - mmu_notifier_mem_invalidate(mn, addr, addr + PAGE_SIZE); + mmu_notifier_mem_invalidate(mn, mm, addr, addr + PAGE_SIZE); } static inline void mmu_notifier_range_start(struct mmu_notifier *mn, @@ -268,25 +269,28 @@ static inline void mmu_notifier_range_start(struct mmu_notifier *mn, unsigned long start, unsigned long end) { - mmu_notifier_mem_invalidate(mn, start, end); + mmu_notifier_mem_invalidate(mn, mm, start, end); } static void mmu_notifier_mem_invalidate(struct mmu_notifier *mn, + struct mm_struct *mm, unsigned long start, unsigned long end) { struct mmu_rb_handler *handler = container_of(mn, struct mmu_rb_handler, mn); struct rb_root *root = handler->root; - struct mmu_rb_node *node; + struct mmu_rb_node *node, *ptr = NULL; unsigned long flags; spin_lock_irqsave(&handler->lock, flags); - for (node = __mmu_int_rb_iter_first(root, start, end - 1); node; - node = __mmu_int_rb_iter_next(node, start, end - 1)) { + for (node = __mmu_int_rb_iter_first(root, start, end - 1); + node; node = ptr) { + /* Guard against node removal. */ + ptr = __mmu_int_rb_iter_next(node, start, end - 1); hfi1_cdbg(MMU, "Invalidating node addr 0x%llx, len %u", node->addr, node->len); if (handler->ops->invalidate(root, node)) - __mmu_rb_remove(handler, node, true); + __mmu_rb_remove(handler, node, mm); } spin_unlock_irqrestore(&handler->lock, flags); } diff --git a/drivers/staging/rdma/hfi1/mmu_rb.h b/drivers/staging/rdma/hfi1/mmu_rb.h index f8523fdb8a18..19a306e83c7d 100644 --- a/drivers/staging/rdma/hfi1/mmu_rb.h +++ b/drivers/staging/rdma/hfi1/mmu_rb.h @@ -59,7 +59,8 @@ struct mmu_rb_node { struct mmu_rb_ops { bool (*filter)(struct mmu_rb_node *, unsigned long, unsigned long); int (*insert)(struct rb_root *, struct mmu_rb_node *); - void (*remove)(struct rb_root *, struct mmu_rb_node *, bool); + void (*remove)(struct rb_root *, struct mmu_rb_node *, + struct mm_struct *); int (*invalidate)(struct rb_root *, struct mmu_rb_node *); }; diff --git a/drivers/staging/rdma/hfi1/user_exp_rcv.c b/drivers/staging/rdma/hfi1/user_exp_rcv.c index 0861e095df8d..5b72849bbd71 100644 --- a/drivers/staging/rdma/hfi1/user_exp_rcv.c +++ b/drivers/staging/rdma/hfi1/user_exp_rcv.c @@ -87,7 +87,8 @@ static u32 find_phys_blocks(struct page **, unsigned, struct tid_pageset *); static int set_rcvarray_entry(struct file *, unsigned long, u32, struct tid_group *, struct page **, unsigned); static int mmu_rb_insert(struct rb_root *, struct mmu_rb_node *); -static void mmu_rb_remove(struct rb_root *, struct mmu_rb_node *, bool); +static void mmu_rb_remove(struct rb_root *, struct mmu_rb_node *, + struct mm_struct *); static int mmu_rb_invalidate(struct rb_root *, struct mmu_rb_node *); static int program_rcvarray(struct file *, unsigned long, struct tid_group *, struct tid_pageset *, unsigned, u16, struct page **, @@ -899,7 +900,7 @@ static int unprogram_rcvarray(struct file *fp, u32 tidinfo, if (!node || node->rcventry != (uctxt->expected_base + rcventry)) return -EBADF; if (HFI1_CAP_IS_USET(TID_UNMAP)) - mmu_rb_remove(&fd->tid_rb_root, &node->mmu, false); + mmu_rb_remove(&fd->tid_rb_root, &node->mmu, NULL); else hfi1_mmu_rb_remove(&fd->tid_rb_root, &node->mmu); @@ -965,7 +966,7 @@ static void unlock_exp_tids(struct hfi1_ctxtdata *uctxt, continue; if (HFI1_CAP_IS_USET(TID_UNMAP)) mmu_rb_remove(&fd->tid_rb_root, - &node->mmu, false); + &node->mmu, NULL); else hfi1_mmu_rb_remove(&fd->tid_rb_root, &node->mmu); @@ -1032,7 +1033,7 @@ static int mmu_rb_insert(struct rb_root *root, struct mmu_rb_node *node) } static void mmu_rb_remove(struct rb_root *root, struct mmu_rb_node *node, - bool notifier) + struct mm_struct *mm) { struct hfi1_filedata *fdata = container_of(root, struct hfi1_filedata, tid_rb_root); diff --git a/drivers/staging/rdma/hfi1/user_sdma.c b/drivers/staging/rdma/hfi1/user_sdma.c index ab6b6a42000f..e08c74fe4c6b 100644 --- a/drivers/staging/rdma/hfi1/user_sdma.c +++ b/drivers/staging/rdma/hfi1/user_sdma.c @@ -299,7 +299,8 @@ static int defer_packet_queue( static void activate_packet_queue(struct iowait *, int); static bool sdma_rb_filter(struct mmu_rb_node *, unsigned long, unsigned long); static int sdma_rb_insert(struct rb_root *, struct mmu_rb_node *); -static void sdma_rb_remove(struct rb_root *, struct mmu_rb_node *, bool); +static void sdma_rb_remove(struct rb_root *, struct mmu_rb_node *, + struct mm_struct *); static int sdma_rb_invalidate(struct rb_root *, struct mmu_rb_node *); static struct mmu_rb_ops sdma_rb_ops = { @@ -1063,8 +1064,10 @@ static int pin_vector_pages(struct user_sdma_request *req, rb_node = hfi1_mmu_rb_search(&pq->sdma_rb_root, (unsigned long)iovec->iov.iov_base, iovec->iov.iov_len); - if (rb_node) + if (rb_node && !IS_ERR(rb_node)) node = container_of(rb_node, struct sdma_mmu_node, rb); + else + rb_node = NULL; if (!node) { node = kzalloc(sizeof(*node), GFP_KERNEL); @@ -1502,7 +1505,7 @@ static void user_sdma_free_request(struct user_sdma_request *req, bool unpin) &req->pq->sdma_rb_root, (unsigned long)req->iovs[i].iov.iov_base, req->iovs[i].iov.iov_len); - if (!mnode) + if (!mnode || IS_ERR(mnode)) continue; node = container_of(mnode, struct sdma_mmu_node, rb); @@ -1547,7 +1550,7 @@ static int sdma_rb_insert(struct rb_root *root, struct mmu_rb_node *mnode) } static void sdma_rb_remove(struct rb_root *root, struct mmu_rb_node *mnode, - bool notifier) + struct mm_struct *mm) { struct sdma_mmu_node *node = container_of(mnode, struct sdma_mmu_node, rb); @@ -1557,14 +1560,19 @@ static void sdma_rb_remove(struct rb_root *root, struct mmu_rb_node *mnode, node->pq->n_locked -= node->npages; spin_unlock(&node->pq->evict_lock); - unpin_vector_pages(notifier ? NULL : current->mm, node->pages, - node->npages); + /* + * If mm is set, we are being called by the MMU notifier and we + * should not pass a mm_struct to unpin_vector_page(). This is to + * prevent a deadlock when hfi1_release_user_pages() attempts to + * take the mmap_sem, which the MMU notifier has already taken. + */ + unpin_vector_pages(mm ? NULL : current->mm, node->pages, node->npages); /* * If called by the MMU notifier, we have to adjust the pinned * page count ourselves. */ - if (notifier) - current->mm->pinned_vm -= node->npages; + if (mm) + mm->pinned_vm -= node->npages; kfree(node); }