From: Vasant Hegde <vasant.hegde@amd.com>
Date: Wed, 22 Nov 2023 09:02:15 +0000
Subject: iommu/amd/pgtbl_v2: Invalidate updated page ranges only
Git-commit: c7fc12354be0ba47566d55f4ebdc6a47bd69d5ed
Patch-mainline: v6.8-rc1
References: jsc#PED-10968

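Enhance __domain_flush_pages() to detect the domain's page table mode and
use that information to build the invalidation commands, so that
amd_iommu_domain_flush_pages() can also be used to invalidate v2 page
tables. Use it in iommu_v2_map_pages() to invalidate only the updated
IOVA range, instead of flushing the whole domain TLB whenever more than
one PTE was written.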

Also pass the PASID and gn arguments to device_flush_iotlb() so that it
can build IOTLB invalidation commands for both v1 and v2 page tables.

Signed-off-by: Vasant Hegde <vasant.hegde@amd.com>
Reviewed-by: Jason Gunthorpe <jgg@nvidia.com>
Link: https://lore.kernel.org/r/20231122090215.6191-10-vasant.hegde@amd.com
Signed-off-by: Joerg Roedel <jroedel@suse.de>
---
 drivers/iommu/amd/io_pgtable_v2.c | 10 ++--------
 drivers/iommu/amd/iommu.c         | 28 ++++++++++++++++++++--------
 2 files changed, 22 insertions(+), 16 deletions(-)

diff --git a/drivers/iommu/amd/io_pgtable_v2.c b/drivers/iommu/amd/io_pgtable_v2.c
index f818a7e254d4..6d69ba60744f 100644
--- a/drivers/iommu/amd/io_pgtable_v2.c
+++ b/drivers/iommu/amd/io_pgtable_v2.c
@@ -244,7 +244,6 @@ static int iommu_v2_map_pages(struct io_pgtable_ops *ops, unsigned long iova,
 	unsigned long mapped_size = 0;
 	unsigned long o_iova = iova;
 	size_t size = pgcount << __ffs(pgsize);
-	int count = 0;
 	int ret = 0;
 	bool updated = false;
 
@@ -265,19 +264,14 @@ static int iommu_v2_map_pages(struct io_pgtable_ops *ops, unsigned long iova,
 
 		*pte = set_pte_attr(paddr, map_size, prot);
 
-		count++;
 		iova += map_size;
 		paddr += map_size;
 		mapped_size += map_size;
 	}
 
 out:
-	if (updated) {
-		if (count > 1)
-			amd_iommu_flush_tlb(&pdom->domain, 0);
-		else
-			amd_iommu_flush_page(&pdom->domain, 0, o_iova);
-	}
+	if (updated)
+		amd_iommu_domain_flush_pages(pdom, o_iova, size);
 
 	if (mapped)
 		*mapped += mapped_size;
diff --git a/drivers/iommu/amd/iommu.c b/drivers/iommu/amd/iommu.c
index 77cf1e3de053..255ea754c0cd 100644
--- a/drivers/iommu/amd/iommu.c
+++ b/drivers/iommu/amd/iommu.c
@@ -85,6 +85,11 @@ static void detach_device(struct device *dev);
  *
  ****************************************************************************/
 
+static inline bool pdom_is_v2_pgtbl_mode(struct protection_domain *pdom)
+{
+	return (pdom && (pdom->flags & PD_IOMMUV2_MASK));
+}
+
 static inline int get_acpihid_device_id(struct device *dev,
 					struct acpihid_map_entry **entry)
 {
@@ -1382,8 +1387,8 @@ void amd_iommu_flush_all_caches(struct amd_iommu *iommu)
 /*
  * Command send function for flushing on-device TLB
  */
-static int device_flush_iotlb(struct iommu_dev_data *dev_data,
-			      u64 address, size_t size)
+static int device_flush_iotlb(struct iommu_dev_data *dev_data, u64 address,
+			      size_t size, ioasid_t pasid, bool gn)
 {
 	struct amd_iommu *iommu;
 	struct iommu_cmd cmd;
@@ -1395,7 +1400,7 @@ static int device_flush_iotlb(struct iommu_dev_data *dev_data,
 		return -EINVAL;
 
 	build_inv_iotlb_pages(&cmd, dev_data->devid, qdep, address,
-			      size, IOMMU_NO_PASID, false);
+			      size, pasid, gn);
 
 	return iommu_queue_command(iommu, &cmd);
 }
@@ -1441,8 +1446,11 @@ static int device_flush_dte(struct iommu_dev_data *dev_data)
 			return ret;
 	}
 
-	if (dev_data->ats_enabled)
-		ret = device_flush_iotlb(dev_data, 0, ~0UL);
+	if (dev_data->ats_enabled) {
+		/* Invalidate the entire contents of an IOTLB */
+		ret = device_flush_iotlb(dev_data, 0, ~0UL,
+					 IOMMU_NO_PASID, false);
+	}
 
 	return ret;
 }
@@ -1458,9 +1466,13 @@ static void __domain_flush_pages(struct protection_domain *domain,
 	struct iommu_dev_data *dev_data;
 	struct iommu_cmd cmd;
 	int ret = 0, i;
+	ioasid_t pasid = IOMMU_NO_PASID;
+	bool gn = false;
+
+	if (pdom_is_v2_pgtbl_mode(domain))
+		gn = true;
 
-	build_inv_iommu_pages(&cmd, address, size, domain->id,
-			      IOMMU_NO_PASID, false);
+	build_inv_iommu_pages(&cmd, address, size, domain->id, pasid, gn);
 
 	for (i = 0; i < amd_iommu_get_num_iommus(); ++i) {
 		if (!domain->dev_iommu[i])
@@ -1478,7 +1490,7 @@ static void __domain_flush_pages(struct protection_domain *domain,
 		if (!dev_data->ats_enabled)
 			continue;
 
-		ret |= device_flush_iotlb(dev_data, address, size);
+		ret |= device_flush_iotlb(dev_data, address, size, pasid, gn);
 	}
 
 	WARN_ON(ret);

