diff --git a/drivers/infiniband/core/umem_odp.c b/drivers/infiniband/core/umem_odp.c index 77adf405e23c..7300d0a10d1e 100644 --- a/drivers/infiniband/core/umem_odp.c +++ b/drivers/infiniband/core/umem_odp.c @@ -176,18 +176,15 @@ static void add_umem_to_per_mm(struct ib_umem_odp *umem_odp) struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm; down_write(&per_mm->umem_rwsem); - if (likely(ib_umem_start(umem_odp) != ib_umem_end(umem_odp))) { - /* - * Note that the representation of the intervals in the - * interval tree considers the ending point as contained in - * the interval, while the function ib_umem_end returns the - * first address which is not contained in the umem. - */ - umem_odp->interval_tree.start = ib_umem_start(umem_odp); - umem_odp->interval_tree.last = ib_umem_end(umem_odp) - 1; - interval_tree_insert(&umem_odp->interval_tree, - &per_mm->umem_tree); - } + /* + * Note that the representation of the intervals in the interval tree + * considers the ending point as contained in the interval, while the + * function ib_umem_end returns the first address which is not + * contained in the umem. + */ + umem_odp->interval_tree.start = ib_umem_start(umem_odp); + umem_odp->interval_tree.last = ib_umem_end(umem_odp) - 1; + interval_tree_insert(&umem_odp->interval_tree, &per_mm->umem_tree); up_write(&per_mm->umem_rwsem); } @@ -196,11 +193,8 @@ static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp) struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm; down_write(&per_mm->umem_rwsem); - if (likely(ib_umem_start(umem_odp) != ib_umem_end(umem_odp))) - interval_tree_remove(&umem_odp->interval_tree, - &per_mm->umem_tree); + interval_tree_remove(&umem_odp->interval_tree, &per_mm->umem_tree); complete_all(&umem_odp->notifier_completion); - up_write(&per_mm->umem_rwsem); } @@ -320,6 +314,9 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_umem_odp *root, int pages = size >> PAGE_SHIFT; int ret; + if (!size) + return ERR_PTR(-EINVAL); + odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL); if (!odp_data) return ERR_PTR(-ENOMEM); @@ -381,6 +378,9 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) struct mm_struct *mm = umem->owning_mm; int ret_val; + if (umem_odp->umem.address == 0 && umem_odp->umem.length == 0) + umem_odp->is_implicit_odp = 1; + umem_odp->page_shift = PAGE_SHIFT; if (access & IB_ACCESS_HUGETLB) { struct vm_area_struct *vma; @@ -401,7 +401,10 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) init_completion(&umem_odp->notifier_completion); - if (ib_umem_odp_num_pages(umem_odp)) { + if (!umem_odp->is_implicit_odp) { + if (!ib_umem_odp_num_pages(umem_odp)) + return -EINVAL; + umem_odp->page_list = vzalloc(array_size(sizeof(*umem_odp->page_list), ib_umem_odp_num_pages(umem_odp))); @@ -420,7 +423,9 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) ret_val = get_per_mm(umem_odp); if (ret_val) goto out_dma_list; - add_umem_to_per_mm(umem_odp); + + if (!umem_odp->is_implicit_odp) + add_umem_to_per_mm(umem_odp); return 0; @@ -439,13 +444,14 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp) * It is the driver's responsibility to ensure, before calling us, * that the hardware will not attempt to access the MR any more. */ - ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem_odp), - ib_umem_end(umem_odp)); - - remove_umem_from_per_mm(umem_odp); + if (!umem_odp->is_implicit_odp) { + ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem_odp), + ib_umem_end(umem_odp)); + remove_umem_from_per_mm(umem_odp); + vfree(umem_odp->dma_list); + vfree(umem_odp->page_list); + } put_per_mm(umem_odp); - vfree(umem_odp->dma_list); - vfree(umem_odp->page_list); } /* diff --git a/drivers/infiniband/hw/mlx5/mr.c b/drivers/infiniband/hw/mlx5/mr.c index b74fad08412f..ba2ec495b6e3 100644 --- a/drivers/infiniband/hw/mlx5/mr.c +++ b/drivers/infiniband/hw/mlx5/mr.c @@ -1600,7 +1600,7 @@ static void dereg_mr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr) /* Wait for all running page-fault handlers to finish. */ synchronize_srcu(&dev->mr_srcu); /* Destroy all page mappings */ - if (umem_odp->page_list) + if (!umem_odp->is_implicit_odp) mlx5_ib_invalidate_range(umem_odp, ib_umem_start(umem_odp), ib_umem_end(umem_odp)); diff --git a/drivers/infiniband/hw/mlx5/odp.c b/drivers/infiniband/hw/mlx5/odp.c index 82b716a28ec1..80c07d85b966 100644 --- a/drivers/infiniband/hw/mlx5/odp.c +++ b/drivers/infiniband/hw/mlx5/odp.c @@ -584,7 +584,7 @@ static int pagefault_mr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr, struct ib_umem_odp *odp; size_t size; - if (!odp_mr->page_list) { + if (odp_mr->is_implicit_odp) { odp = implicit_mr_get_data(mr, io_virt, bcnt); if (IS_ERR(odp)) diff --git a/include/rdma/ib_umem_odp.h b/include/rdma/ib_umem_odp.h index 030d5cbad02c..14b38b4459c5 100644 --- a/include/rdma/ib_umem_odp.h +++ b/include/rdma/ib_umem_odp.h @@ -69,6 +69,14 @@ struct ib_umem_odp { /* Tree tracking */ struct interval_tree_node interval_tree; + /* + * An implicit odp umem cannot be DMA mapped, has 0 length, and serves + * only as an anchor for the driver to hold onto the per_mm. FIXME: + * This should be removed and drivers should work with the per_mm + * directly. + */ + bool is_implicit_odp; + struct completion notifier_completion; int dying; unsigned int page_shift;