diff --git a/drivers/infiniband/hw/hfi1/mmu_rb.c b/drivers/infiniband/hw/hfi1/mmu_rb.c index ccbf52c8ff6f..d41fd87a39f2 100644 --- a/drivers/infiniband/hw/hfi1/mmu_rb.c +++ b/drivers/infiniband/hw/hfi1/mmu_rb.c @@ -217,21 +217,27 @@ static struct mmu_rb_node *__mmu_rb_search(struct mmu_rb_handler *handler, return node; } -struct mmu_rb_node *hfi1_mmu_rb_extract(struct mmu_rb_handler *handler, - unsigned long addr, unsigned long len) +bool hfi1_mmu_rb_remove_unless_exact(struct mmu_rb_handler *handler, + unsigned long addr, unsigned long len, + struct mmu_rb_node **rb_node) { struct mmu_rb_node *node; unsigned long flags; + bool ret = false; spin_lock_irqsave(&handler->lock, flags); node = __mmu_rb_search(handler, addr, len); if (node) { + if (node->addr == addr && node->len == len) + goto unlock; __mmu_int_rb_remove(node, &handler->root); list_del(&node->list); /* remove from LRU list */ + ret = true; } +unlock: spin_unlock_irqrestore(&handler->lock, flags); - - return node; + *rb_node = node; + return ret; } void hfi1_mmu_rb_evict(struct mmu_rb_handler *handler, void *evict_arg) diff --git a/drivers/infiniband/hw/hfi1/mmu_rb.h b/drivers/infiniband/hw/hfi1/mmu_rb.h index 754f6ebf13fb..f04cec1e99d1 100644 --- a/drivers/infiniband/hw/hfi1/mmu_rb.h +++ b/drivers/infiniband/hw/hfi1/mmu_rb.h @@ -81,7 +81,8 @@ int hfi1_mmu_rb_insert(struct mmu_rb_handler *handler, void hfi1_mmu_rb_evict(struct mmu_rb_handler *handler, void *evict_arg); void hfi1_mmu_rb_remove(struct mmu_rb_handler *handler, struct mmu_rb_node *mnode); -struct mmu_rb_node *hfi1_mmu_rb_extract(struct mmu_rb_handler *handler, - unsigned long addr, unsigned long len); +bool hfi1_mmu_rb_remove_unless_exact(struct mmu_rb_handler *handler, + unsigned long addr, unsigned long len, + struct mmu_rb_node **rb_node); #endif /* _HFI1_MMU_RB_H */ diff --git a/drivers/infiniband/hw/hfi1/user_sdma.c b/drivers/infiniband/hw/hfi1/user_sdma.c index 16fd519216dc..79450cf2a3d5 100644 --- a/drivers/infiniband/hw/hfi1/user_sdma.c +++ b/drivers/infiniband/hw/hfi1/user_sdma.c @@ -1165,14 +1165,23 @@ static int pin_vector_pages(struct user_sdma_request *req, struct hfi1_user_sdma_pkt_q *pq = req->pq; struct sdma_mmu_node *node = NULL; struct mmu_rb_node *rb_node; + bool extracted; - rb_node = hfi1_mmu_rb_extract(pq->handler, - (unsigned long)iovec->iov.iov_base, - iovec->iov.iov_len); - if (rb_node) + extracted = + hfi1_mmu_rb_remove_unless_exact(pq->handler, + (unsigned long) + iovec->iov.iov_base, + iovec->iov.iov_len, &rb_node); + if (rb_node) { node = container_of(rb_node, struct sdma_mmu_node, rb); - else - rb_node = NULL; + if (!extracted) { + atomic_inc(&node->refcount); + iovec->pages = node->pages; + iovec->npages = node->npages; + iovec->node = node; + return 0; + } + } if (!node) { node = kzalloc(sizeof(*node), GFP_KERNEL);