From: Yi Liu <yi.l.liu@intel.com>
Date: Thu, 28 Mar 2024 05:29:58 -0700
Subject: iommu: Pass domain to remove_dev_pasid() op
Git-commit: d2f85a263883b679f87ed8f911746105658e9c47
Patch-mainline: v6.10-rc1
References: jsc#PED-10968

The existing remove_dev_pasid() callbacks of the underlying iommu drivers
look up the attached domain in group->pasid_array. However, the domain
stored in group->pasid_array is not correct in every scenario, and a
wrong domain may cause the remove_dev_pasid() callback to fail. To avoid
such problems, it is more reliable to pass the domain to the
remove_dev_pasid() op explicitly.

Suggested-by: Jason Gunthorpe <jgg@nvidia.com>
Signed-off-by: Yi Liu <yi.l.liu@intel.com>
Reviewed-by: Kevin Tian <kevin.tian@intel.com>
Link: https://lore.kernel.org/r/20240328122958.83332-3-yi.l.liu@intel.com
Signed-off-by: Joerg Roedel <jroedel@suse.de>
---
 drivers/iommu/amd/amd_iommu.h               |    3 ++-
 drivers/iommu/amd/pasid.c                   |    8 ++------
 drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c |    9 ++-------
 drivers/iommu/intel/iommu.c                 |   11 +++--------
 drivers/iommu/iommu.c                       |    9 +++++----
 include/linux/iommu.h                       |    3 ++-
 6 files changed, 16 insertions(+), 27 deletions(-)

--- a/drivers/iommu/amd/amd_iommu.h
+++ b/drivers/iommu/amd/amd_iommu.h
@@ -52,7 +52,8 @@ struct iommu_domain *amd_iommu_domain_al
 void amd_iommu_domain_free(struct iommu_domain *dom);
 int iommu_sva_set_dev_pasid(struct iommu_domain *domain,
 			    struct device *dev, ioasid_t pasid);
-void amd_iommu_remove_dev_pasid(struct device *dev, ioasid_t pasid);
+void amd_iommu_remove_dev_pasid(struct device *dev, ioasid_t pasid,
+				struct iommu_domain *domain);
 
 /* SVA/PASID */
 bool amd_iommu_pasid_supported(void);
--- a/drivers/iommu/amd/pasid.c
+++ b/drivers/iommu/amd/pasid.c
@@ -141,19 +141,15 @@ out_unlock:
 	return ret;
 }
 
-void amd_iommu_remove_dev_pasid(struct device *dev, ioasid_t pasid)
+void amd_iommu_remove_dev_pasid(struct device *dev, ioasid_t pasid,
+				struct iommu_domain *domain)
 {
 	struct protection_domain *sva_pdom;
-	struct iommu_domain *domain;
 	unsigned long flags;
 
 	if (!is_pasid_valid(dev_iommu_priv_get(dev), pasid))
 		return;
 
-	/* Get protection domain */
-	domain = iommu_get_domain_for_dev_pasid(dev, pasid, IOMMU_DOMAIN_SVA);
-	if (!domain)
-		return;
 	sva_pdom = to_pdomain(domain);
 
 	spin_lock_irqsave(&sva_pdom->lock, flags);
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
@@ -3087,14 +3087,9 @@ static int arm_smmu_def_domain_type(stru
 	return 0;
 }
 
-static void arm_smmu_remove_dev_pasid(struct device *dev, ioasid_t pasid)
+static void arm_smmu_remove_dev_pasid(struct device *dev, ioasid_t pasid,
+				      struct iommu_domain *domain)
 {
-	struct iommu_domain *domain;
-
-	domain = iommu_get_domain_for_dev_pasid(dev, pasid, IOMMU_DOMAIN_SVA);
-	if (WARN_ON(IS_ERR(domain)) || !domain)
-		return;
-
 	arm_smmu_sva_remove_dev_pasid(domain, dev, pasid);
 }
 
--- a/drivers/iommu/intel/iommu.c
+++ b/drivers/iommu/intel/iommu.c
@@ -4576,19 +4576,15 @@ static int intel_iommu_iotlb_sync_map(st
 	return 0;
 }
 
-static void intel_iommu_remove_dev_pasid(struct device *dev, ioasid_t pasid)
+static void intel_iommu_remove_dev_pasid(struct device *dev, ioasid_t pasid,
+					 struct iommu_domain *domain)
 {
 	struct device_domain_info *info = dev_iommu_priv_get(dev);
+	struct dmar_domain *dmar_domain = to_dmar_domain(domain);
 	struct dev_pasid_info *curr, *dev_pasid = NULL;
 	struct intel_iommu *iommu = info->iommu;
-	struct dmar_domain *dmar_domain;
-	struct iommu_domain *domain;
 	unsigned long flags;
 
-	domain = iommu_get_domain_for_dev_pasid(dev, pasid, 0);
-	if (WARN_ON_ONCE(!domain))
-		goto out_tear_down;
-
 	/*
 	 * The SVA implementation needs to handle its own stuffs like the mm
 	 * notification. Before consolidating that code into iommu core, let
@@ -4599,7 +4595,6 @@ static void intel_iommu_remove_dev_pasid
 		goto out_tear_down;
 	}
 
-	dmar_domain = to_dmar_domain(domain);
 	spin_lock_irqsave(&dmar_domain->lock, flags);
 	list_for_each_entry(curr, &dmar_domain->dev_pasids, link_domain) {
 		if (curr->dev == dev && curr->pasid == pasid) {
--- a/drivers/iommu/iommu.c
+++ b/drivers/iommu/iommu.c
@@ -3335,20 +3335,21 @@ err_revert:
 
 		if (device == last_gdev)
 			break;
-		ops->remove_dev_pasid(device->dev, pasid);
+		ops->remove_dev_pasid(device->dev, pasid, domain);
 	}
 	return ret;
 }
 
 static void __iommu_remove_group_pasid(struct iommu_group *group,
-				       ioasid_t pasid)
+				       ioasid_t pasid,
+				       struct iommu_domain *domain)
 {
 	struct group_device *device;
 	const struct iommu_ops *ops;
 
 	for_each_group_device(group, device) {
 		ops = dev_iommu_ops(device->dev);
-		ops->remove_dev_pasid(device->dev, pasid);
+		ops->remove_dev_pasid(device->dev, pasid, domain);
 	}
 }
 
@@ -3418,7 +3419,7 @@ void iommu_detach_device_pasid(struct io
 	struct iommu_group *group = dev->iommu_group;
 
 	mutex_lock(&group->mutex);
-	__iommu_remove_group_pasid(group, pasid);
+	__iommu_remove_group_pasid(group, pasid, domain);
 	WARN_ON(xa_erase(&group->pasid_array, pasid) != domain);
 	mutex_unlock(&group->mutex);
 }
--- a/include/linux/iommu.h
+++ b/include/linux/iommu.h
@@ -578,7 +578,8 @@ struct iommu_ops {
 			      struct iommu_page_response *msg);
 
 	int (*def_domain_type)(struct device *dev);
-	void (*remove_dev_pasid)(struct device *dev, ioasid_t pasid);
+	void (*remove_dev_pasid)(struct device *dev, ioasid_t pasid,
+				 struct iommu_domain *domain);
 
 	const struct iommu_domain_ops *default_domain_ops;
 	unsigned long pgsize_bitmap;
