From: Suravee Suthikulpanit <suravee.suthikulpanit@amd.com>
Date: Mon, 5 Feb 2024 11:56:01 +0000
Subject: iommu/amd: Introduce get_amd_iommu_from_dev()
Git-commit: 6f35fe5d8a0ad0125c52fb20f5d67000e369eb3a
Patch-mainline: v6.9-rc1
References: jsc#PED-10968

Introduce get_amd_iommu_from_dev() and get_amd_iommu_from_dev_data(),
and replace rlookup_amd_iommu() with the new helper functions where
applicable to avoid an unnecessary loop when looking up struct amd_iommu
from struct device.

Suggested-by: Jason Gunthorpe <jgg@ziepe.ca>
Signed-off-by: Suravee Suthikulpanit <suravee.suthikulpanit@amd.com>
Signed-off-by: Vasant Hegde <vasant.hegde@amd.com>
Reviewed-by: Jason Gunthorpe <jgg@nvidia.com>
Link: https://lore.kernel.org/r/20240205115615.6053-4-vasant.hegde@amd.com
Signed-off-by: Joerg Roedel <jroedel@suse.de>
---
 drivers/iommu/amd/amd_iommu.h | 15 ++++++++++++
 drivers/iommu/amd/iommu.c     | 54 +++++++++++--------------------------------
 include/linux/iommu.h         | 16 +++++++++++++
 3 files changed, 44 insertions(+), 41 deletions(-)

diff --git a/drivers/iommu/amd/amd_iommu.h b/drivers/iommu/amd/amd_iommu.h
index 7464222ad8af..4e1f9a444a1a 100644
--- a/drivers/iommu/amd/amd_iommu.h
+++ b/drivers/iommu/amd/amd_iommu.h
@@ -138,6 +138,21 @@ static inline void *alloc_pgtable_page(int nid, gfp_t gfp)
 	return page ? page_address(page) : NULL;
 }
 
+/*
+ * This must be called after device probe completes. During probe
+ * use rlookup_amd_iommu() to get the iommu.
+ */
+static inline struct amd_iommu *get_amd_iommu_from_dev(struct device *dev)
+{
+	return iommu_get_iommu_dev(dev, struct amd_iommu, iommu);
+}
+
+/* This must be called after device probe completes. */
+static inline struct amd_iommu *get_amd_iommu_from_dev_data(struct iommu_dev_data *dev_data)
+{
+	return iommu_get_iommu_dev(dev_data->dev, struct amd_iommu, iommu);
+}
+
 bool translation_pre_enabled(struct amd_iommu *iommu);
 bool amd_iommu_is_attach_deferred(struct device *dev);
 int __init add_special_device(u8 type, u8 id, u32 *devid, bool cmd_line);
diff --git a/drivers/iommu/amd/iommu.c b/drivers/iommu/amd/iommu.c
index 743465cb677f..2a5ae7a841cc 100644
--- a/drivers/iommu/amd/iommu.c
+++ b/drivers/iommu/amd/iommu.c
@@ -1384,14 +1384,9 @@ void amd_iommu_flush_all_caches(struct amd_iommu *iommu)
 static int device_flush_iotlb(struct iommu_dev_data *dev_data, u64 address,
 			      size_t size, ioasid_t pasid, bool gn)
 {
-	struct amd_iommu *iommu;
+	struct amd_iommu *iommu = get_amd_iommu_from_dev_data(dev_data);
 	struct iommu_cmd cmd;
-	int qdep;
-
-	qdep     = dev_data->ats_qdep;
-	iommu    = rlookup_amd_iommu(dev_data->dev);
-	if (!iommu)
-		return -EINVAL;
+	int qdep = dev_data->ats_qdep;
 
 	build_inv_iotlb_pages(&cmd, dev_data->devid, qdep, address,
 			      size, pasid, gn);
@@ -1411,16 +1406,12 @@ static int device_flush_dte_alias(struct pci_dev *pdev, u16 alias, void *data)
  */
 static int device_flush_dte(struct iommu_dev_data *dev_data)
 {
-	struct amd_iommu *iommu;
+	struct amd_iommu *iommu = get_amd_iommu_from_dev_data(dev_data);
 	struct pci_dev *pdev = NULL;
 	struct amd_iommu_pci_seg *pci_seg;
 	u16 alias;
 	int ret;
 
-	iommu = rlookup_amd_iommu(dev_data->dev);
-	if (!iommu)
-		return -EINVAL;
-
 	if (dev_is_pci(dev_data->dev))
 		pdev = to_pci_dev(dev_data->dev);
 
@@ -1805,11 +1796,7 @@ static void clear_dte_entry(struct amd_iommu *iommu, u16 devid)
 static void do_attach(struct iommu_dev_data *dev_data,
 		      struct protection_domain *domain)
 {
-	struct amd_iommu *iommu;
-
-	iommu = rlookup_amd_iommu(dev_data->dev);
-	if (!iommu)
-		return;
+	struct amd_iommu *iommu = get_amd_iommu_from_dev_data(dev_data);
 
 	/* Update data structures */
 	dev_data->domain = domain;
@@ -1833,11 +1820,7 @@ static void do_attach(struct iommu_dev_data *dev_data,
 static void do_detach(struct iommu_dev_data *dev_data)
 {
 	struct protection_domain *domain = dev_data->domain;
-	struct amd_iommu *iommu;
-
-	iommu = rlookup_amd_iommu(dev_data->dev);
-	if (!iommu)
-		return;
+	struct amd_iommu *iommu = get_amd_iommu_from_dev_data(dev_data);
 
 	/* Update data structures */
 	dev_data->domain = NULL;
@@ -2003,10 +1986,8 @@ static void update_device_table(struct protection_domain *domain)
 	struct iommu_dev_data *dev_data;
 
 	list_for_each_entry(dev_data, &domain->dev_list, list) {
-		struct amd_iommu *iommu = rlookup_amd_iommu(dev_data->dev);
+		struct amd_iommu *iommu = get_amd_iommu_from_dev_data(dev_data);
 
-		if (!iommu)
-			continue;
 		set_dte_entry(iommu, dev_data);
 		clone_aliases(iommu, dev_data->dev);
 	}
@@ -2187,11 +2168,8 @@ static struct iommu_domain *do_iommu_domain_alloc(unsigned int type,
 	struct protection_domain *domain;
 	struct amd_iommu *iommu = NULL;
 
-	if (dev) {
-		iommu = rlookup_amd_iommu(dev);
-		if (!iommu)
-			return ERR_PTR(-ENODEV);
-	}
+	if (dev)
+		iommu = get_amd_iommu_from_dev(dev);
 
 	/*
 	 * Since DTE[Mode]=0 is prohibited on SNP-enabled system,
@@ -2272,7 +2250,7 @@ static int amd_iommu_attach_device(struct iommu_domain *dom,
 {
 	struct iommu_dev_data *dev_data = dev_iommu_priv_get(dev);
 	struct protection_domain *domain = to_pdomain(dom);
-	struct amd_iommu *iommu = rlookup_amd_iommu(dev);
+	struct amd_iommu *iommu = get_amd_iommu_from_dev(dev);
 	int ret;
 
 	/*
@@ -2411,7 +2389,7 @@ static bool amd_iommu_capable(struct device *dev, enum iommu_cap cap)
 	case IOMMU_CAP_DEFERRED_FLUSH:
 		return true;
 	case IOMMU_CAP_DIRTY_TRACKING: {
-		struct amd_iommu *iommu = rlookup_amd_iommu(dev);
+		struct amd_iommu *iommu = get_amd_iommu_from_dev(dev);
 
 		return amd_iommu_hd_support(iommu);
 	}
@@ -2440,9 +2418,7 @@ static int amd_iommu_set_dirty_tracking(struct iommu_domain *domain,
 	}
 
 	list_for_each_entry(dev_data, &pdomain->dev_list, list) {
-		iommu = rlookup_amd_iommu(dev_data->dev);
-		if (!iommu)
-			continue;
+		iommu = get_amd_iommu_from_dev_data(dev_data);
 
 		dev_table = get_dev_table(iommu);
 		pte_root = dev_table[dev_data->devid].data[0];
@@ -2502,9 +2478,7 @@ static void amd_iommu_get_resv_regions(struct device *dev,
 		return;
 
 	devid = PCI_SBDF_TO_DEVID(sbdf);
-	iommu = rlookup_amd_iommu(dev);
-	if (!iommu)
-		return;
+	iommu = get_amd_iommu_from_dev(dev);
 	pci_seg = iommu->pci_seg;
 
 	list_for_each_entry(entry, &pci_seg->unity_map, list) {
@@ -2838,9 +2812,7 @@ int amd_iommu_complete_ppr(struct pci_dev *pdev, u32 pasid,
 	struct iommu_cmd cmd;
 
 	dev_data = dev_iommu_priv_get(&pdev->dev);
-	iommu    = rlookup_amd_iommu(&pdev->dev);
-	if (!iommu)
-		return -ENODEV;
+	iommu    = get_amd_iommu_from_dev(&pdev->dev);
 
 	build_complete_ppr(&cmd, dev_data->devid, pasid, status,
 			   tag, dev_data->pri_tlp);
diff --git a/include/linux/iommu.h b/include/linux/iommu.h
index 1ea2a820e1eb..32bb121e8032 100644
--- a/include/linux/iommu.h
+++ b/include/linux/iommu.h
@@ -654,6 +654,22 @@ static inline struct iommu_device *dev_to_iommu_device(struct device *dev)
 	return (struct iommu_device *)dev_get_drvdata(dev);
 }
 
+/**
+ * __iommu_get_iommu_dev - Get iommu_device for a device
+ * @dev: an end-point device
+ *
+ * Note that this function must be called from the iommu_ops
+ * to retrieve the iommu_device for a device, which the core code
+ * guarantees it will not invoke the op without an attached iommu.
+ */
+static inline struct iommu_device *__iommu_get_iommu_dev(struct device *dev)
+{
+	return dev->iommu->iommu_dev;
+}
+
+#define iommu_get_iommu_dev(dev, type, member) \
+	container_of(__iommu_get_iommu_dev(dev), type, member)
+
 static inline void iommu_iotlb_gather_init(struct iommu_iotlb_gather *gather)
 {
 	*gather = (struct iommu_iotlb_gather) {

