diff --git a/drivers/misc/habanalabs/common/memory.c b/drivers/misc/habanalabs/common/memory.c index 53707e9c5c3e..207fda36cac9 100644 --- a/drivers/misc/habanalabs/common/memory.c +++ b/drivers/misc/habanalabs/common/memory.c @@ -314,16 +314,17 @@ static void free_phys_pg_pack(struct hl_device *hdev, /** * free_device_memory() - free device memory. * @ctx: pointer to the context structure. - * @handle: handle of the memory chunk to free. + * @args: host parameters containing the requested size. * * This function does the following: * - Free the device memory related to the given handle. */ -static int free_device_memory(struct hl_ctx *ctx, u32 handle) +static int free_device_memory(struct hl_ctx *ctx, struct hl_mem_in *args) { struct hl_device *hdev = ctx->hdev; struct hl_vm *vm = &hdev->vm; struct hl_vm_phys_pg_pack *phys_pg_pack; + u32 handle = args->free.handle; spin_lock(&vm->idr_lock); phys_pg_pack = idr_find(&vm->phys_pg_pack_handles, handle); @@ -1117,20 +1118,22 @@ init_page_pack_err: /** * unmap_device_va() - unmap the given device virtual address. * @ctx: pointer to the context structure. - * @vaddr: device virtual address to unmap. + * @args: host parameters with device virtual address to unmap. * @ctx_free: true if in context free flow, false otherwise. * * This function does the following: * - unmap the physical pages related to the given virtual address. * - return the device virtual block to the virtual block list. */ -static int unmap_device_va(struct hl_ctx *ctx, u64 vaddr, bool ctx_free) +static int unmap_device_va(struct hl_ctx *ctx, struct hl_mem_in *args, + bool ctx_free) { struct hl_device *hdev = ctx->hdev; struct hl_vm_phys_pg_pack *phys_pg_pack = NULL; struct hl_vm_hash_node *hnode = NULL; struct hl_userptr *userptr = NULL; struct hl_va_range *va_range; + u64 vaddr = args->unmap.device_virt_addr; enum vm_type_t *vm_type; bool is_userptr; int rc = 0; @@ -1280,7 +1283,7 @@ static int mem_ioctl_no_mmu(struct hl_fpriv *hpriv, union hl_mem_args *args) break; case HL_MEM_OP_FREE: - rc = free_device_memory(ctx, args->in.free.handle); + rc = free_device_memory(ctx, &args->in); break; case HL_MEM_OP_MAP: @@ -1388,7 +1391,7 @@ int hl_mem_ioctl(struct hl_fpriv *hpriv, void *data) goto out; } - rc = free_device_memory(ctx, args->in.free.handle); + rc = free_device_memory(ctx, &args->in); break; case HL_MEM_OP_MAP: @@ -1399,8 +1402,7 @@ int hl_mem_ioctl(struct hl_fpriv *hpriv, void *data) break; case HL_MEM_OP_UNMAP: - rc = unmap_device_va(ctx, args->in.unmap.device_virt_addr, - false); + rc = unmap_device_va(ctx, &args->in, false); break; default: @@ -1858,6 +1860,7 @@ void hl_vm_ctx_fini(struct hl_ctx *ctx) struct hl_vm_phys_pg_pack *phys_pg_list; struct hl_vm_hash_node *hnode; struct hlist_node *tmp_node; + struct hl_mem_in args; int i; if (!hdev->mmu_enable) @@ -1877,7 +1880,8 @@ void hl_vm_ctx_fini(struct hl_ctx *ctx) dev_dbg(hdev->dev, "hl_mem_hash_node of vaddr 0x%llx of asid %d is still alive\n", hnode->vaddr, ctx->asid); - unmap_device_va(ctx, hnode->vaddr, true); + args.unmap.device_virt_addr = hnode->vaddr; + unmap_device_va(ctx, &args, true); } mutex_lock(&ctx->mmu_lock);