diff --git a/drivers/vfio/pci/pds/dirty.c b/drivers/vfio/pci/pds/dirty.c index 09fed7c1771a..4824cdfe01ed 100644 --- a/drivers/vfio/pci/pds/dirty.c +++ b/drivers/vfio/pci/pds/dirty.c @@ -85,18 +85,20 @@ static int pds_vfio_dirty_alloc_bitmaps(struct pds_vfio_dirty *dirty, return -ENOMEM; } - dirty->region.host_seq.bmp = host_seq_bmp; - dirty->region.host_ack.bmp = host_ack_bmp; + dirty->region.host_seq = host_seq_bmp; + dirty->region.host_ack = host_ack_bmp; + dirty->region.bmp_bytes = bytes; return 0; } static void pds_vfio_dirty_free_bitmaps(struct pds_vfio_dirty *dirty) { - vfree(dirty->region.host_seq.bmp); - vfree(dirty->region.host_ack.bmp); - dirty->region.host_seq.bmp = NULL; - dirty->region.host_ack.bmp = NULL; + vfree(dirty->region.host_seq); + vfree(dirty->region.host_ack); + dirty->region.host_seq = NULL; + dirty->region.host_ack = NULL; + dirty->region.bmp_bytes = 0; } static void __pds_vfio_dirty_free_sgl(struct pds_vfio_pci_device *pds_vfio, @@ -301,8 +303,8 @@ void pds_vfio_dirty_disable(struct pds_vfio_pci_device *pds_vfio, bool send_cmd) static int pds_vfio_dirty_seq_ack(struct pds_vfio_pci_device *pds_vfio, struct pds_vfio_region *region, - struct pds_vfio_bmp_info *bmp_info, - u32 offset, u32 bmp_bytes, bool read_seq) + unsigned long *seq_ack_bmp, u32 offset, + u32 bmp_bytes, bool read_seq) { const char *bmp_type_str = read_seq ? "read_seq" : "write_ack"; u8 dma_dir = read_seq ? DMA_FROM_DEVICE : DMA_TO_DEVICE; @@ -319,7 +321,7 @@ static int pds_vfio_dirty_seq_ack(struct pds_vfio_pci_device *pds_vfio, int err; int i; - bmp = (void *)((u64)bmp_info->bmp + offset); + bmp = (void *)((u64)seq_ack_bmp + offset); page_offset = offset_in_page(bmp); bmp -= page_offset; @@ -387,7 +389,7 @@ static int pds_vfio_dirty_write_ack(struct pds_vfio_pci_device *pds_vfio, u32 offset, u32 len) { - return pds_vfio_dirty_seq_ack(pds_vfio, region, ®ion->host_ack, + return pds_vfio_dirty_seq_ack(pds_vfio, region, region->host_ack, offset, len, WRITE_ACK); } @@ -395,7 +397,7 @@ static int pds_vfio_dirty_read_seq(struct pds_vfio_pci_device *pds_vfio, struct pds_vfio_region *region, u32 offset, u32 len) { - return pds_vfio_dirty_seq_ack(pds_vfio, region, ®ion->host_seq, + return pds_vfio_dirty_seq_ack(pds_vfio, region, region->host_seq, offset, len, READ_SEQ); } @@ -411,8 +413,8 @@ static int pds_vfio_dirty_process_bitmaps(struct pds_vfio_pci_device *pds_vfio, int dword_count; dword_count = len_bytes / sizeof(u64); - seq = (__le64 *)((u64)region->host_seq.bmp + bmp_offset); - ack = (__le64 *)((u64)region->host_ack.bmp + bmp_offset); + seq = (__le64 *)((u64)region->host_seq + bmp_offset); + ack = (__le64 *)((u64)region->host_ack + bmp_offset); bmp_offset_bit = bmp_offset * 8; for (int i = 0; i < dword_count; i++) { @@ -479,6 +481,13 @@ static int pds_vfio_dirty_sync(struct pds_vfio_pci_device *pds_vfio, return -EINVAL; } + if (bmp_bytes > region->bmp_bytes) { + dev_err(dev, + "Calculated bitmap bytes %llu larger than region's cached bmp_bytes %llu\n", + bmp_bytes, region->bmp_bytes); + return -EINVAL; + } + bmp_offset = DIV_ROUND_UP((iova - region->start) / region->page_size, sizeof(u64)); diff --git a/drivers/vfio/pci/pds/dirty.h b/drivers/vfio/pci/pds/dirty.h index 07662d369e7c..a1f6d894f913 100644 --- a/drivers/vfio/pci/pds/dirty.h +++ b/drivers/vfio/pci/pds/dirty.h @@ -4,14 +4,10 @@ #ifndef _DIRTY_H_ #define _DIRTY_H_ -struct pds_vfio_bmp_info { - unsigned long *bmp; - u32 bmp_bytes; -}; - struct pds_vfio_region { - struct pds_vfio_bmp_info host_seq; - struct pds_vfio_bmp_info host_ack; + unsigned long *host_seq; + unsigned long *host_ack; + u64 bmp_bytes; u64 size; u64 start; u64 page_size;