From: Jason Gunthorpe <jgg@nvidia.com>
Date: Sun, 12 Nov 2023 14:50:13 -0400
Subject: iommufd: Add iommufd_ctx to iommufd_put_object()
Git-commit: bd7a282650b8beb57bc9d19bfcb714b1ccae843a
Patch-mainline: v6.7-rc5
References: jsc#PED-7779 jsc#PED-7780

Will be used in the next patch.

Link: https://lore.kernel.org/r/1-v2-ca9e00171c5b+123-iommufd_syz4_jgg@nvidia.com/
Reviewed-by: Kevin Tian <kevin.tian@intel.com>
Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>
Acked-by: Joerg Roedel <jroedel@suse.de>
---
 drivers/iommu/iommufd/device.c          | 14 +++++++-------
 drivers/iommu/iommufd/hw_pagetable.c    |  8 ++++----
 drivers/iommu/iommufd/ioas.c            | 14 +++++++-------
 drivers/iommu/iommufd/iommufd_private.h |  3 ++-
 drivers/iommu/iommufd/selftest.c        | 14 +++++++-------
 drivers/iommu/iommufd/vfio_compat.c     | 18 +++++++++---------
 6 files changed, 36 insertions(+), 35 deletions(-)

diff --git a/drivers/iommu/iommufd/device.c b/drivers/iommu/iommufd/device.c
index 59d3a07300d9..873630c111c1 100644
--- a/drivers/iommu/iommufd/device.c
+++ b/drivers/iommu/iommufd/device.c
@@ -571,7 +571,7 @@ iommufd_device_auto_get_domain(struct iommufd_device *idev,
 			continue;
 		destroy_hwpt = (*do_attach)(idev, hwpt);
 		if (IS_ERR(destroy_hwpt)) {
-			iommufd_put_object(&hwpt->obj);
+			iommufd_put_object(idev->ictx, &hwpt->obj);
 			/*
 			 * -EINVAL means the domain is incompatible with the
 			 * device. Other error codes should propagate to
@@ -583,7 +583,7 @@ iommufd_device_auto_get_domain(struct iommufd_device *idev,
 			goto out_unlock;
 		}
 		*pt_id = hwpt->obj.id;
-		iommufd_put_object(&hwpt->obj);
+		iommufd_put_object(idev->ictx, &hwpt->obj);
 		goto out_unlock;
 	}
 
@@ -652,7 +652,7 @@ static int iommufd_device_change_pt(struct iommufd_device *idev, u32 *pt_id,
 		destroy_hwpt = ERR_PTR(-EINVAL);
 		goto out_put_pt_obj;
 	}
-	iommufd_put_object(pt_obj);
+	iommufd_put_object(idev->ictx, pt_obj);
 
 	/* This destruction has to be after we unlock everything */
 	if (destroy_hwpt)
@@ -660,7 +660,7 @@ static int iommufd_device_change_pt(struct iommufd_device *idev, u32 *pt_id,
 	return 0;
 
 out_put_pt_obj:
-	iommufd_put_object(pt_obj);
+	iommufd_put_object(idev->ictx, pt_obj);
 	return PTR_ERR(destroy_hwpt);
 }
 
@@ -792,7 +792,7 @@ static int iommufd_access_change_ioas_id(struct iommufd_access *access, u32 id)
 	if (IS_ERR(ioas))
 		return PTR_ERR(ioas);
 	rc = iommufd_access_change_ioas(access, ioas);
-	iommufd_put_object(&ioas->obj);
+	iommufd_put_object(access->ictx, &ioas->obj);
 	return rc;
 }
 
@@ -941,7 +941,7 @@ void iommufd_access_notify_unmap(struct io_pagetable *iopt, unsigned long iova,
 
 		access->ops->unmap(access->data, iova, length);
 
-		iommufd_put_object(&access->obj);
+		iommufd_put_object(access->ictx, &access->obj);
 		xa_lock(&ioas->iopt.access_list);
 	}
 	xa_unlock(&ioas->iopt.access_list);
@@ -1243,6 +1243,6 @@ int iommufd_get_hw_info(struct iommufd_ucmd *ucmd)
 out_free:
 	kfree(data);
 out_put:
-	iommufd_put_object(&idev->obj);
+	iommufd_put_object(ucmd->ictx, &idev->obj);
 	return rc;
 }
diff --git a/drivers/iommu/iommufd/hw_pagetable.c b/drivers/iommu/iommufd/hw_pagetable.c
index 2abbeafdbd22..cbb5df0a6c32 100644
--- a/drivers/iommu/iommufd/hw_pagetable.c
+++ b/drivers/iommu/iommufd/hw_pagetable.c
@@ -318,9 +318,9 @@ int iommufd_hwpt_alloc(struct iommufd_ucmd *ucmd)
 	if (ioas)
 		mutex_unlock(&ioas->mutex);
 out_put_pt:
-	iommufd_put_object(pt_obj);
+	iommufd_put_object(ucmd->ictx, pt_obj);
 out_put_idev:
-	iommufd_put_object(&idev->obj);
+	iommufd_put_object(ucmd->ictx, &idev->obj);
 	return rc;
 }
 
@@ -345,7 +345,7 @@ int iommufd_hwpt_set_dirty_tracking(struct iommufd_ucmd *ucmd)
 	rc = iopt_set_dirty_tracking(&ioas->iopt, hwpt_paging->common.domain,
 				     enable);
 
-	iommufd_put_object(&hwpt_paging->common.obj);
+	iommufd_put_object(ucmd->ictx, &hwpt_paging->common.obj);
 	return rc;
 }
 
@@ -368,6 +368,6 @@ int iommufd_hwpt_get_dirty_bitmap(struct iommufd_ucmd *ucmd)
 	rc = iopt_read_and_clear_dirty_data(
 		&ioas->iopt, hwpt_paging->common.domain, cmd->flags, cmd);
 
-	iommufd_put_object(&hwpt_paging->common.obj);
+	iommufd_put_object(ucmd->ictx, &hwpt_paging->common.obj);
 	return rc;
 }
diff --git a/drivers/iommu/iommufd/ioas.c b/drivers/iommu/iommufd/ioas.c
index d5624577f79f..742248276548 100644
--- a/drivers/iommu/iommufd/ioas.c
+++ b/drivers/iommu/iommufd/ioas.c
@@ -105,7 +105,7 @@ int iommufd_ioas_iova_ranges(struct iommufd_ucmd *ucmd)
 		rc = -EMSGSIZE;
 out_put:
 	up_read(&ioas->iopt.iova_rwsem);
-	iommufd_put_object(&ioas->obj);
+	iommufd_put_object(ucmd->ictx, &ioas->obj);
 	return rc;
 }
 
@@ -175,7 +175,7 @@ int iommufd_ioas_allow_iovas(struct iommufd_ucmd *ucmd)
 		interval_tree_remove(node, &allowed_iova);
 		kfree(container_of(node, struct iopt_allowed, node));
 	}
-	iommufd_put_object(&ioas->obj);
+	iommufd_put_object(ucmd->ictx, &ioas->obj);
 	return rc;
 }
 
@@ -228,7 +228,7 @@ int iommufd_ioas_map(struct iommufd_ucmd *ucmd)
 	cmd->iova = iova;
 	rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
 out_put:
-	iommufd_put_object(&ioas->obj);
+	iommufd_put_object(ucmd->ictx, &ioas->obj);
 	return rc;
 }
 
@@ -258,7 +258,7 @@ int iommufd_ioas_copy(struct iommufd_ucmd *ucmd)
 		return PTR_ERR(src_ioas);
 	rc = iopt_get_pages(&src_ioas->iopt, cmd->src_iova, cmd->length,
 			    &pages_list);
-	iommufd_put_object(&src_ioas->obj);
+	iommufd_put_object(ucmd->ictx, &src_ioas->obj);
 	if (rc)
 		return rc;
 
@@ -279,7 +279,7 @@ int iommufd_ioas_copy(struct iommufd_ucmd *ucmd)
 	cmd->dst_iova = iova;
 	rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
 out_put_dst:
-	iommufd_put_object(&dst_ioas->obj);
+	iommufd_put_object(ucmd->ictx, &dst_ioas->obj);
 out_pages:
 	iopt_free_pages_list(&pages_list);
 	return rc;
@@ -315,7 +315,7 @@ int iommufd_ioas_unmap(struct iommufd_ucmd *ucmd)
 	rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
 
 out_put:
-	iommufd_put_object(&ioas->obj);
+	iommufd_put_object(ucmd->ictx, &ioas->obj);
 	return rc;
 }
 
@@ -393,6 +393,6 @@ int iommufd_ioas_option(struct iommufd_ucmd *ucmd)
 		rc = -EOPNOTSUPP;
 	}
 
-	iommufd_put_object(&ioas->obj);
+	iommufd_put_object(ucmd->ictx, &ioas->obj);
 	return rc;
 }
diff --git a/drivers/iommu/iommufd/iommufd_private.h b/drivers/iommu/iommufd/iommufd_private.h
index a74cfefffbc6..f918a22f0d48 100644
--- a/drivers/iommu/iommufd/iommufd_private.h
+++ b/drivers/iommu/iommufd/iommufd_private.h
@@ -154,7 +154,8 @@ static inline bool iommufd_lock_obj(struct iommufd_object *obj)
 
 struct iommufd_object *iommufd_get_object(struct iommufd_ctx *ictx, u32 id,
 					  enum iommufd_object_type type);
-static inline void iommufd_put_object(struct iommufd_object *obj)
+static inline void iommufd_put_object(struct iommufd_ctx *ictx,
+				      struct iommufd_object *obj)
 {
 	refcount_dec(&obj->users);
 	up_read(&obj->destroy_rwsem);
diff --git a/drivers/iommu/iommufd/selftest.c b/drivers/iommu/iommufd/selftest.c
index 5d93434003d8..022ef8f55088 100644
--- a/drivers/iommu/iommufd/selftest.c
+++ b/drivers/iommu/iommufd/selftest.c
@@ -86,7 +86,7 @@ void iommufd_test_syz_conv_iova_id(struct iommufd_ucmd *ucmd,
 	if (IS_ERR(ioas))
 		return;
 	*iova = iommufd_test_syz_conv_iova(&ioas->iopt, iova);
-	iommufd_put_object(&ioas->obj);
+	iommufd_put_object(ucmd->ictx, &ioas->obj);
 }
 
 struct mock_iommu_domain {
@@ -500,7 +500,7 @@ get_md_pagetable(struct iommufd_ucmd *ucmd, u32 mockpt_id,
 		return hwpt;
 	if (hwpt->domain->type != IOMMU_DOMAIN_UNMANAGED ||
 	    hwpt->domain->ops != mock_ops.default_domain_ops) {
-		iommufd_put_object(&hwpt->obj);
+		iommufd_put_object(ucmd->ictx, &hwpt->obj);
 		return ERR_PTR(-EINVAL);
 	}
 	*mock = container_of(hwpt->domain, struct mock_iommu_domain, domain);
@@ -518,7 +518,7 @@ get_md_pagetable_nested(struct iommufd_ucmd *ucmd, u32 mockpt_id,
 		return hwpt;
 	if (hwpt->domain->type != IOMMU_DOMAIN_NESTED ||
 	    hwpt->domain->ops != &domain_nested_ops) {
-		iommufd_put_object(&hwpt->obj);
+		iommufd_put_object(ucmd->ictx, &hwpt->obj);
 		return ERR_PTR(-EINVAL);
 	}
 	*mock_nested = container_of(hwpt->domain,
@@ -681,7 +681,7 @@ static int iommufd_test_mock_domain_replace(struct iommufd_ucmd *ucmd,
 	rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
 
 out_dev_obj:
-	iommufd_put_object(dev_obj);
+	iommufd_put_object(ucmd->ictx, dev_obj);
 	return rc;
 }
 
@@ -699,7 +699,7 @@ static int iommufd_test_add_reserved(struct iommufd_ucmd *ucmd,
 	down_write(&ioas->iopt.iova_rwsem);
 	rc = iopt_reserve_iova(&ioas->iopt, start, start + length - 1, NULL);
 	up_write(&ioas->iopt.iova_rwsem);
-	iommufd_put_object(&ioas->obj);
+	iommufd_put_object(ucmd->ictx, &ioas->obj);
 	return rc;
 }
 
@@ -754,7 +754,7 @@ static int iommufd_test_md_check_pa(struct iommufd_ucmd *ucmd,
 	rc = 0;
 
 out_put:
-	iommufd_put_object(&hwpt->obj);
+	iommufd_put_object(ucmd->ictx, &hwpt->obj);
 	return rc;
 }
 
@@ -1233,7 +1233,7 @@ static int iommufd_test_dirty(struct iommufd_ucmd *ucmd, unsigned int mockpt_id,
 out_free:
 	kvfree(tmp);
 out_put:
-	iommufd_put_object(&hwpt->obj);
+	iommufd_put_object(ucmd->ictx, &hwpt->obj);
 	return rc;
 }
 
diff --git a/drivers/iommu/iommufd/vfio_compat.c b/drivers/iommu/iommufd/vfio_compat.c
index 538fbf76354d..a3ad5f0b6c59 100644
--- a/drivers/iommu/iommufd/vfio_compat.c
+++ b/drivers/iommu/iommufd/vfio_compat.c
@@ -41,7 +41,7 @@ int iommufd_vfio_compat_ioas_get_id(struct iommufd_ctx *ictx, u32 *out_ioas_id)
 	if (IS_ERR(ioas))
 		return PTR_ERR(ioas);
 	*out_ioas_id = ioas->obj.id;
-	iommufd_put_object(&ioas->obj);
+	iommufd_put_object(ictx, &ioas->obj);
 	return 0;
 }
 EXPORT_SYMBOL_NS_GPL(iommufd_vfio_compat_ioas_get_id, IOMMUFD_VFIO);
@@ -98,7 +98,7 @@ int iommufd_vfio_compat_ioas_create(struct iommufd_ctx *ictx)
 
 	if (ictx->vfio_ioas && iommufd_lock_obj(&ictx->vfio_ioas->obj)) {
 		ret = 0;
-		iommufd_put_object(&ictx->vfio_ioas->obj);
+		iommufd_put_object(ictx, &ictx->vfio_ioas->obj);
 		goto out_abort;
 	}
 	ictx->vfio_ioas = ioas;
@@ -133,7 +133,7 @@ int iommufd_vfio_ioas(struct iommufd_ucmd *ucmd)
 		if (IS_ERR(ioas))
 			return PTR_ERR(ioas);
 		cmd->ioas_id = ioas->obj.id;
-		iommufd_put_object(&ioas->obj);
+		iommufd_put_object(ucmd->ictx, &ioas->obj);
 		return iommufd_ucmd_respond(ucmd, sizeof(*cmd));
 
 	case IOMMU_VFIO_IOAS_SET:
@@ -143,7 +143,7 @@ int iommufd_vfio_ioas(struct iommufd_ucmd *ucmd)
 		xa_lock(&ucmd->ictx->objects);
 		ucmd->ictx->vfio_ioas = ioas;
 		xa_unlock(&ucmd->ictx->objects);
-		iommufd_put_object(&ioas->obj);
+		iommufd_put_object(ucmd->ictx, &ioas->obj);
 		return 0;
 
 	case IOMMU_VFIO_IOAS_CLEAR:
@@ -190,7 +190,7 @@ static int iommufd_vfio_map_dma(struct iommufd_ctx *ictx, unsigned int cmd,
 	iova = map.iova;
 	rc = iopt_map_user_pages(ictx, &ioas->iopt, &iova, u64_to_user_ptr(map.vaddr),
 				 map.size, iommu_prot, 0);
-	iommufd_put_object(&ioas->obj);
+	iommufd_put_object(ictx, &ioas->obj);
 	return rc;
 }
 
@@ -249,7 +249,7 @@ static int iommufd_vfio_unmap_dma(struct iommufd_ctx *ictx, unsigned int cmd,
 		rc = -EFAULT;
 
 err_put:
-	iommufd_put_object(&ioas->obj);
+	iommufd_put_object(ictx, &ioas->obj);
 	return rc;
 }
 
@@ -272,7 +272,7 @@ static int iommufd_vfio_cc_iommu(struct iommufd_ctx *ictx)
 	}
 	mutex_unlock(&ioas->mutex);
 
-	iommufd_put_object(&ioas->obj);
+	iommufd_put_object(ictx, &ioas->obj);
 	return rc;
 }
 
@@ -349,7 +349,7 @@ static int iommufd_vfio_set_iommu(struct iommufd_ctx *ictx, unsigned long type)
 	 */
 	if (type == VFIO_TYPE1_IOMMU)
 		rc = iopt_disable_large_pages(&ioas->iopt);
-	iommufd_put_object(&ioas->obj);
+	iommufd_put_object(ictx, &ioas->obj);
 	return rc;
 }
 
@@ -511,7 +511,7 @@ static int iommufd_vfio_iommu_get_info(struct iommufd_ctx *ictx,
 
 out_put:
 	up_read(&ioas->iopt.iova_rwsem);
-	iommufd_put_object(&ioas->obj);
+	iommufd_put_object(ictx, &ioas->obj);
 	return rc;
 }
 

