From: Jason Gunthorpe <jgg@nvidia.com>
Date: Thu, 11 May 2023 01:42:12 -0300
Subject: iommu: Consolidate the default_domain setup to one function
Git-commit: d99be00f42eac9fc35a164f3f6c8c7a56b295aa9
Patch-mainline: v6.5-rc1
References: jsc#PED-7779 jsc#PED-7780

Make iommu_change_dev_def_domain() general enough to set up the initial
default_domain or replace it with a new default_domain. Call the new
function iommu_setup_default_domain() and make it the only place in the
code that stores to group->default_domain.

Consolidate the three copies of the default_domain setup sequence. The
flow requires:

 - Determining the domain type to use
 - Checking if the current default domain is the same type
 - Allocating a domain
 - Doing iommu_create_device_direct_mappings()
 - Attaching it to devices
 - Storing group->default_domain

This adjusts the domain allocation from the prior patch so that each
allocation step can detect when the requested domain is the one the
group already has, which is a more robust version of what the change
default domain path was already doing.

Reviewed-by: Lu Baolu <baolu.lu@linux.intel.com>
Reviewed-by: Kevin Tian <kevin.tian@intel.com>
Tested-by: Heiko Stuebner <heiko@sntech.de>
Tested-by: Niklas Schnelle <schnelle@linux.ibm.com>
Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>
Link: https://lore.kernel.org/r/14-v5-1b99ae392328+44574-iommu_err_unwind_jgg@nvidia.com
Signed-off-by: Joerg Roedel <jroedel@suse.de>
---
 drivers/iommu/iommu.c | 202 ++++++++++++++++++++++----------------------------
 1 file changed, 89 insertions(+), 113 deletions(-)

diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c
index 7120f57c8028..34f721434b28 100644
--- a/drivers/iommu/iommu.c
+++ b/drivers/iommu/iommu.c
@@ -93,9 +93,6 @@ static const char * const iommu_group_resv_type_string[] = {
 static int iommu_bus_notifier(struct notifier_block *nb,
 			      unsigned long action, void *data);
 static void iommu_release_device(struct device *dev);
-static struct iommu_domain *
-iommu_group_alloc_default_domain(struct iommu_group *group, int req_type);
-static int iommu_get_def_domain_type(struct device *dev);
 static struct iommu_domain *__iommu_domain_alloc(const struct bus_type *bus,
 						 unsigned type);
 static int __iommu_attach_device(struct iommu_domain *domain,
@@ -126,7 +123,9 @@ static void __iommu_group_set_domain_nofail(struct iommu_group *group,
 		group, new_domain, IOMMU_SET_DOMAIN_MUST_SUCCEED));
 }
 
-static int iommu_create_device_direct_mappings(struct iommu_group *group,
+static int iommu_setup_default_domain(struct iommu_group *group,
+				      int target_type);
+static int iommu_create_device_direct_mappings(struct iommu_domain *domain,
 					       struct device *dev);
 static struct iommu_group *iommu_group_get_for_dev(struct device *dev);
 static ssize_t iommu_group_store_type(struct iommu_group *group,
@@ -424,33 +423,18 @@ int iommu_probe_device(struct device *dev)
 
 	mutex_lock(&group->mutex);
 
-	iommu_create_device_direct_mappings(group, dev);
+	if (group->default_domain)
+		iommu_create_device_direct_mappings(group->default_domain, dev);
 
 	if (group->domain) {
 		ret = __iommu_device_set_domain(group, dev, group->domain, 0);
+		if (ret)
+			goto err_unlock;
 	} else if (!group->default_domain) {
-		/*
-		 * Try to allocate a default domain - needs support from the
-		 * IOMMU driver. There are still some drivers which don't
-		 * support default domains, so the return value is not yet
-		 * checked.
-		 */
-		group->default_domain = iommu_group_alloc_default_domain(
-			group, iommu_get_def_domain_type(dev));
-		if (group->default_domain) {
-			iommu_create_device_direct_mappings(group, dev);
-			ret = __iommu_group_set_domain(group,
-						       group->default_domain);
-		}
-
-		/*
-		 * We assume that the iommu driver starts up the device in
-		 * 'set_platform_dma_ops' mode if it does not support default
-		 * domains.
-		 */
+		ret = iommu_setup_default_domain(group, 0);
+		if (ret)
+			goto err_unlock;
 	}
-	if (ret)
-		goto err_unlock;
 
 	mutex_unlock(&group->mutex);
 	iommu_group_put(group);
@@ -967,16 +951,15 @@ int iommu_group_set_name(struct iommu_group *group, const char *name)
 }
 EXPORT_SYMBOL_GPL(iommu_group_set_name);
 
-static int iommu_create_device_direct_mappings(struct iommu_group *group,
+static int iommu_create_device_direct_mappings(struct iommu_domain *domain,
 					       struct device *dev)
 {
-	struct iommu_domain *domain = group->default_domain;
 	struct iommu_resv_region *entry;
 	struct list_head mappings;
 	unsigned long pg_size;
 	int ret = 0;
 
-	if (!domain || !iommu_is_dma_domain(domain))
+	if (!iommu_is_dma_domain(domain))
 		return 0;
 
 	BUG_ON(!domain->pgsize_bitmap);
@@ -1647,6 +1630,15 @@ static int iommu_get_def_domain_type(struct device *dev)
 	return 0;
 }
 
+static struct iommu_domain *
+__iommu_group_alloc_default_domain(const struct bus_type *bus,
+				   struct iommu_group *group, int req_type)
+{
+	if (group->default_domain && group->default_domain->type == req_type)
+		return group->default_domain;
+	return __iommu_domain_alloc(bus, req_type);
+}
+
 /*
  * req_type of 0 means "auto" which means to select a domain based on
  * iommu_def_domain_type or what the driver actually supports.
@@ -1662,17 +1654,17 @@ iommu_group_alloc_default_domain(struct iommu_group *group, int req_type)
 	lockdep_assert_held(&group->mutex);
 
 	if (req_type)
-		return __iommu_domain_alloc(bus, req_type);
+		return __iommu_group_alloc_default_domain(bus, group, req_type);
 
 	/* The driver gave no guidance on what type to use, try the default */
-	dom = __iommu_domain_alloc(bus, iommu_def_domain_type);
+	dom = __iommu_group_alloc_default_domain(bus, group, iommu_def_domain_type);
 	if (dom)
 		return dom;
 
 	/* Otherwise IDENTITY and DMA_FQ defaults will try DMA */
 	if (iommu_def_domain_type == IOMMU_DOMAIN_DMA)
 		return NULL;
-	dom = __iommu_domain_alloc(bus, IOMMU_DOMAIN_DMA);
+	dom = __iommu_group_alloc_default_domain(bus, group, IOMMU_DOMAIN_DMA);
 	if (!dom)
 		return NULL;
 
@@ -1815,21 +1807,6 @@ static void __iommu_group_dma_finalize(struct iommu_group *group)
 				   iommu_group_do_probe_finalize);
 }
 
-static int iommu_do_create_direct_mappings(struct device *dev, void *data)
-{
-	struct iommu_group *group = data;
-
-	iommu_create_device_direct_mappings(group, dev);
-
-	return 0;
-}
-
-static int iommu_group_create_direct_mappings(struct iommu_group *group)
-{
-	return __iommu_group_for_each_dev(group, group,
-					  iommu_do_create_direct_mappings);
-}
-
 int bus_iommu_probe(const struct bus_type *bus)
 {
 	struct iommu_group *group, *next;
@@ -1851,27 +1828,16 @@ int bus_iommu_probe(const struct bus_type *bus)
 		/* Remove item from the list */
 		list_del_init(&group->entry);
 
-		/* Try to allocate default domain */
-		group->default_domain = iommu_group_alloc_default_domain(
-			group, iommu_get_default_domain_type(group, 0));
-		if (!group->default_domain) {
+		ret = iommu_setup_default_domain(group, 0);
+		if (ret) {
 			mutex_unlock(&group->mutex);
-			continue;
+			return ret;
 		}
-
-		iommu_group_create_direct_mappings(group);
-
-		ret = __iommu_group_set_domain(group, group->default_domain);
-
 		mutex_unlock(&group->mutex);
-
-		if (ret)
-			break;
-
 		__iommu_group_dma_finalize(group);
 	}
 
-	return ret;
+	return 0;
 }
 
 bool iommu_present(const struct bus_type *bus)
@@ -2860,68 +2826,83 @@ int iommu_dev_disable_feature(struct device *dev, enum iommu_dev_features feat)
 }
 EXPORT_SYMBOL_GPL(iommu_dev_disable_feature);
 
-/*
- * Changes the default domain of an iommu group
- *
- * @group: The group for which the default domain should be changed
- * @dev: The first device in the group
- * @type: The type of the new default domain that gets associated with the group
- *
- * Returns 0 on success and error code on failure
+/**
+ * iommu_setup_default_domain - Set the default_domain for the group
+ * @group: Group to change
+ * @target_type: Domain type to set as the default_domain
  *
- * Note:
- * 1. Presently, this function is called only when user requests to change the
- *    group's default domain type through /sys/kernel/iommu_groups/<grp_id>/type
- *    Please take a closer look if intended to use for other purposes.
+ * Allocate a default domain and set it as the current domain on the group. If
+ * the group already has a default domain it will be changed to the target_type.
+ * When target_type is 0 the default domain is selected based on driver and
+ * system preferences.
  */
-static int iommu_change_dev_def_domain(struct iommu_group *group,
-				       struct device *dev, int type)
+static int iommu_setup_default_domain(struct iommu_group *group,
+				      int target_type)
 {
-	struct iommu_domain *prev_dom;
+	struct iommu_domain *old_dom = group->default_domain;
+	struct group_device *gdev;
+	struct iommu_domain *dom;
+	int req_type;
 	int ret;
 
 	lockdep_assert_held(&group->mutex);
 
-	prev_dom = group->default_domain;
-	type = iommu_get_default_domain_type(group, type);
-	if (type < 0)
+	req_type = iommu_get_default_domain_type(group, target_type);
+	if (req_type < 0)
 		return -EINVAL;
 
 	/*
-	 * Switch to a new domain only if the requested domain type is different
-	 * from the existing default domain type
+	 * There are still some drivers which don't support default domains, so
+	 * we ignore the failure and leave group->default_domain NULL.
+	 *
+	 * We assume that the iommu driver starts up the device in
+	 * 'set_platform_dma_ops' mode if it does not support default domains.
 	 */
-	if (prev_dom->type == type)
+	dom = iommu_group_alloc_default_domain(group, req_type);
+	if (!dom) {
+		/* Once in default_domain mode we never leave */
+		if (group->default_domain)
+			return -ENODEV;
+		group->default_domain = NULL;
 		return 0;
-
-	group->default_domain = NULL;
-	group->domain = NULL;
-
-	/* Sets group->default_domain to the newly allocated domain */
-	group->default_domain = iommu_group_alloc_default_domain(group, type);
-	if (!group->default_domain) {
-		ret = -EINVAL;
-		goto restore_old_domain;
 	}
 
-	group->domain = prev_dom;
-	ret = iommu_create_device_direct_mappings(group, dev);
-	if (ret)
-		goto free_new_domain;
-
-	ret = __iommu_group_set_domain(group, group->default_domain);
-	if (ret)
-		goto free_new_domain;
-
-	iommu_domain_free(prev_dom);
+	if (group->default_domain == dom)
+		return 0;
 
-	return 0;
+	/*
+	 * IOMMU_RESV_DIRECT and IOMMU_RESV_DIRECT_RELAXABLE regions must be
+	 * mapped before their device is attached, in order to guarantee
+	 * continuity with any FW activity
+	 */
+	for_each_group_device(group, gdev)
+		iommu_create_device_direct_mappings(dom, gdev->dev);
 
-free_new_domain:
-	iommu_domain_free(group->default_domain);
-restore_old_domain:
-	group->default_domain = prev_dom;
+	/* We must set default_domain early for __iommu_device_set_domain */
+	group->default_domain = dom;
+	if (!group->domain) {
+		/*
+		 * Drivers are not allowed to fail the first domain attach.
+		 * The only way to recover from this is to fail attaching the
+		 * iommu driver and call ops->release_device. Put the domain
+		 * in group->default_domain so it is freed after.
+		 */
+		ret = __iommu_group_set_domain_internal(
+			group, dom, IOMMU_SET_DOMAIN_MUST_SUCCEED);
+		if (WARN_ON(ret))
+			goto out_free;
+	} else {
+		ret = __iommu_group_set_domain(group, dom);
+		if (ret) {
+			iommu_domain_free(dom);
+			group->default_domain = old_dom;
+			return ret;
+		}
+	}
 
+out_free:
+	if (old_dom)
+		iommu_domain_free(old_dom);
 	return ret;
 }
 
@@ -2937,8 +2918,6 @@ static int iommu_change_dev_def_domain(struct iommu_group *group,
 static ssize_t iommu_group_store_type(struct iommu_group *group,
 				      const char *buf, size_t count)
 {
-	struct group_device *grp_dev;
-	struct device *dev;
 	int ret, req_type;
 
 	if (!capable(CAP_SYS_ADMIN) || !capable(CAP_SYS_RAWIO))
@@ -2976,10 +2955,7 @@ static ssize_t iommu_group_store_type(struct iommu_group *group,
 		return -EPERM;
 	}
 
-	grp_dev = list_first_entry(&group->devices, struct group_device, list);
-	dev = grp_dev->dev;
-
-	ret = iommu_change_dev_def_domain(group, dev, req_type);
+	ret = iommu_setup_default_domain(group, req_type);
 
 	/*
 	 * Release the mutex here because ops->probe_finalize() call-back of

