/*
 * livepatch_bsc1232900
 *
 * Fix for CVE-2024-49855, bsc#1232900
 *
 *  Upstream commit:
 *  c9ea57c91f03 ("nbd: fix race between timeout and normal completion")
 *
 *  SLE12-SP5 commit:
 *  Not affected
 *
 *  SLE15-SP3 commit:
 *  Not affected
 *
 *  SLE15-SP4 and -SP5 commit:
 *  Not affected
 *
 *  SLE15-SP6 commit:
 *  57c54c81084d6fb5c9e9df0725264ca23ee310d6
 *
 *  SLE MICRO-6-0 commit:
 *  57c54c81084d6fb5c9e9df0725264ca23ee310d6
 *
 *  Copyright (c) 2025 SUSE
 *  Author: Ali Abdallah <ali.abdallah@suse.com>
 *
 *  Based on the original Linux kernel code. Other copyrights apply.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, see <http://www.gnu.org/licenses/>.
 */

/* klp-ccp: from drivers/block/nbd.c */
#define pr_fmt(fmt) "nbd: " fmt

#include <linux/major.h>

#include <linux/blkdev.h>
#include <linux/module.h>
#include <linux/init.h>
#include <linux/sched.h>
#include <linux/sched/mm.h>
#include <linux/fs.h>
#include <linux/bio.h>
#include <linux/stat.h>
#include <linux/errno.h>

#include <linux/ioctl.h>
#include <linux/mutex.h>
#include <linux/compiler.h>
#include <linux/completion.h>
#include <linux/err.h>
#include <linux/kernel.h>
#include <linux/slab.h>

/* klp-ccp: from include/linux/prefetch.h */
#define _LINUX_PREFETCH_H

#define prefetch(x) __builtin_prefetch(x)

/* klp-ccp: from include/linux/scatterlist.h */
#define _LINUX_SCATTERLIST_H

/* klp-ccp: from include/linux/kthread.h */
#define _LINUX_KTHREAD_H

/* klp-ccp: from include/linux/net.h */
#define _LINUX_NET_H

/* klp-ccp: from drivers/block/nbd.c */
#include <linux/net.h>
#include <linux/kthread.h>
#include <linux/types.h>

#include <linux/blk-mq.h>

#include <linux/uaccess.h>
#include <asm/types.h>

#include <linux/nbd.h>

#define CREATE_TRACE_POINTS

/* klp-ccp: from include/trace/events/nbd.h */
#define _TRACE_NBD_H

/* klp-ccp: from include/linux/tracepoint.h */
#define _LINUX_TRACEPOINT_H

/* klp-ccp: from include/trace/define_trace.h */
#ifdef CREATE_TRACE_POINTS

#undef CREATE_TRACE_POINTS

#include <linux/stringify.h>

#define TRACE_EVENT(name, proto, args, tstruct, assign, print)	\
	DEFINE_TRACE(name, PARAMS(proto), PARAMS(args))

#define DECLARE_TRACE(name, proto, args)	\
	DEFINE_TRACE(name, PARAMS(proto), PARAMS(args))

#define TRACE_HEADER_MULTI_READ

/* klp-ccp: from include/trace/events/nbd.h */
#if !defined(_TRACE_NBD_H) || defined(TRACE_HEADER_MULTI_READ)

#include <linux/tracepoint.h>

#else
#error "klp-ccp: a preceeding branch should have been taken"
#endif

#include <trace/define_trace.h>

#else
#error "klp-ccp: a preceeding branch should have been taken"
/* klp-ccp: from include/trace/define_trace.h */
#endif /* CREATE_TRACE_POINTS */

/* klp-ccp: from drivers/block/nbd.c */
/*
 * Per-connection state, copied from drivers/block/nbd.c.
 *
 * This livepatch dereferences nbd_sock objects allocated by the running nbd
 * module, so the member list and order must stay identical to the target
 * kernel's definition. Do not add, remove or reorder fields.
 */
struct nbd_sock {
	struct socket *sock;
	/* Held by klpp_nbd_xmit_timeout() around the cookie checks below. */
	struct mutex tx_lock;
	struct request *pending;
	int sent;
	bool dead;
	int fallback_index;
	/*
	 * Compared with nbd_cmd::cookie to tell whether a command was sent on
	 * this socket or on an earlier one that has since been replaced.
	 */
	int cookie;
};

/*
 * Bit in nbd_config::runtime_flags. klpp_nbd_xmit_timeout() sets it right
 * before shutting the sockets down after a fatal command timeout.
 */
#define NBD_RT_TIMEDOUT			0

/*
 * Per-configuration state, copied from drivers/block/nbd.c.
 *
 * Instances are owned by the running nbd module, so the member list and
 * order must stay identical to the target kernel's definition.
 */
struct nbd_config {
	u32 flags;
	/* NBD_RT_* bit numbers, e.g. NBD_RT_TIMEDOUT. */
	unsigned long runtime_flags;
	u64 dead_conn_timeout;

	/* Array of num_connections sockets, indexed by nbd_cmd::index. */
	struct nbd_sock **socks;
	int num_connections;
	atomic_t live_connections;
	wait_queue_head_t conn_wait;

	atomic_t recv_threads;
	wait_queue_head_t recv_wq;
	unsigned int blksize_bits;
	loff_t bytesize;
	/*
	 * The target kernel is expected to have CONFIG_DEBUG_FS enabled; the
	 * #error stops the build if that assumption (and thus the layout) breaks.
	 */
#if IS_ENABLED(CONFIG_DEBUG_FS)
	struct dentry *dbg_dir;
#else
#error "klp-ccp: a preceeding branch should have been taken"
#endif
};

/*
 * Per-device state, copied from drivers/block/nbd.c.
 *
 * Instances are owned by the running nbd module, so the member list and
 * order must stay identical to the target kernel's definition.
 */
struct nbd_device {
	/*
	 * tag_set.timeout == 0 means userspace disabled timeout-driven socket
	 * disconnection; klpp_nbd_xmit_timeout() then only warns and re-arms.
	 */
	struct blk_mq_tag_set tag_set;

	int index;
	/*
	 * References on ->config: taken by nbd_get_config_unlocked() (only while
	 * non-zero) and dropped by nbd_config_put().
	 */
	refcount_t config_refs;
	refcount_t refs;
	struct nbd_config *config;
	struct mutex config_lock;
	/* Used by nbd_to_dev() to obtain the device for log messages. */
	struct gendisk *disk;
	struct workqueue_struct *recv_workq;
	struct work_struct remove_work;

	struct list_head list;
	struct task_struct *task_setup;

	unsigned long flags;
	pid_t pid; /* pid of nbd-client, if attached */

	char *backend;
};

/*
 * Bit numbers in nbd_cmd::flags.
 *
 * NBD_CMD_REQUEUED: set by klpp_nbd_requeue_cmd() via test_and_set_bit(), so
 * blk_mq_requeue_request() is only called for the first requeue attempt.
 *
 * NBD_CMD_INFLIGHT: while set, the command may be completed. The requeue
 * and timeout paths clear it under nbd_cmd::lock before handing the request
 * back, so the normal completion path will not also complete it. It is set
 * again when the command is next queued to nbd.
 */
#define NBD_CMD_REQUEUED	1

#define NBD_CMD_INFLIGHT	2

/*
 * Per-request driver data (the blk-mq pdu), copied from drivers/block/nbd.c.
 * The layout must stay identical to the target kernel's definition.
 */
struct nbd_cmd {
	struct nbd_device *nbd;
	/* Serializes the timeout handler against other users of the command. */
	struct mutex lock;
	/* Index into nbd_config::socks of the connection the command was sent on. */
	int index;
	/* nbd_sock::cookie of that connection at send time. */
	int cookie;
	/* Timer re-arms while timeouts are disabled; used for the "Runtime" log. */
	int retries;
	/* Completion status set by the timeout handler (TIMEOUT / IOERR). */
	blk_status_t status;
	/* NBD_CMD_* bits. */
	unsigned long flags;
	u32 cmd_cookie;
};

/*
 * Drops a reference on nbd->config. Lives in the nbd module; the reference
 * is resolved through the KLP_RELOC_SYMBOL declaration at the end of file.
 */
extern void nbd_config_put(struct nbd_device *nbd);

/* Map an nbd device to the struct device of its gendisk, for dev_*() logging. */
static inline struct device *nbd_to_dev(struct nbd_device *nbd)
{
	struct gendisk *disk = nbd->disk;

	return disk_to_dev(disk);
}

/*
 * klpp_nbd_requeue_cmd - hand @cmd back to blk-mq for resubmission
 * @cmd: the nbd command; the caller must hold cmd->lock
 *
 * NBD_CMD_INFLIGHT is dropped first, so the normal completion path will no
 * longer complete this command. The flag comes back when the command is
 * queued to nbd again. NBD_CMD_REQUEUED makes sure the request is passed
 * to blk_mq_requeue_request() at most once.
 */
void klpp_nbd_requeue_cmd(struct nbd_cmd *cmd)
{
	lockdep_assert_held(&cmd->lock);

	__clear_bit(NBD_CMD_INFLIGHT, &cmd->flags);

	/* Someone already requeued it: nothing more to do. */
	if (test_and_set_bit(NBD_CMD_REQUEUED, &cmd->flags))
		return;

	blk_mq_requeue_request(blk_mq_rq_from_pdu(cmd), true);
}

/* Human-readable name of an NBD_CMD_* wire command, for log messages. */
static const char *nbdcmd_to_ascii(int cmd)
{
	const char *name;

	switch (cmd) {
	case NBD_CMD_READ:
		name = "read";
		break;
	case NBD_CMD_WRITE:
		name = "write";
		break;
	case NBD_CMD_DISC:
		name = "disconnect";
		break;
	case NBD_CMD_FLUSH:
		name = "flush";
		break;
	case NBD_CMD_TRIM:
		name = "trim/discard";
		break;
	default:
		name = "invalid";
		break;
	}

	return name;
}

/*
 * Helpers implemented in the nbd module and called (not replaced) by this
 * patch. The references are resolved through the KLP_RELOC_SYMBOL
 * declarations at the end of this file.
 */
extern void nbd_mark_nsock_dead(struct nbd_device *nbd, struct nbd_sock *nsock,
				int notify);

extern void sock_shutdown(struct nbd_device *nbd);

/*
 * Translate a block-layer operation into its NBD_CMD_* wire command.
 * Operations that have no NBD equivalent map to U32_MAX.
 */
static u32 req_to_nbd_cmd_type(struct request *req)
{
	u32 type = U32_MAX;

	switch (req_op(req)) {
	case REQ_OP_READ:
		type = NBD_CMD_READ;
		break;
	case REQ_OP_WRITE:
		type = NBD_CMD_WRITE;
		break;
	case REQ_OP_FLUSH:
		type = NBD_CMD_FLUSH;
		break;
	case REQ_OP_DISCARD:
		type = NBD_CMD_TRIM;
		break;
	default:
		break;
	}

	return type;
}

/*
 * Take a reference on nbd->config without holding config_lock.
 *
 * Returns NULL when config_refs is already zero (no live configuration).
 * On success the caller owns a reference and must drop it with
 * nbd_config_put().
 */
static struct nbd_config *nbd_get_config_unlocked(struct nbd_device *nbd)
{
	if (!refcount_inc_not_zero(&nbd->config_refs))
		return NULL;

	/*
	 * Order the config_refs read before the nbd->config read. This pairs
	 * with the barrier in nbd_alloc_and_init_config(), which publishes
	 * nbd->config before config_refs becomes non-zero.
	 */
	smp_mb__after_atomic();

	return nbd->config;
}

/*
 * klpp_nbd_xmit_timeout - livepatched nbd_xmit_timeout(), the blk-mq
 * timeout handler for nbd requests
 * @req: the request whose timer expired
 *
 * Carries upstream commit c9ea57c91f03 ("nbd: fix race between timeout and
 * normal completion"). Every decision is made under cmd->lock. On every
 * path that gives the request back to the block layer (requeue or
 * complete), NBD_CMD_INFLIGHT is cleared before cmd->lock is released,
 * either directly or via klpp_nbd_requeue_cmd(). As a result the timeout
 * handler and the normal completion path never both complete the same
 * request.
 *
 * Return: BLK_EH_RESET_TIMER to re-arm the timer. This happens when cmd->lock
 * is contended, or when timeouts are disabled and the command is still on its
 * original socket. BLK_EH_DONE otherwise: the request was requeued,
 * completed here via blk_mq_complete_request(), or is not ours to touch.
 */
enum blk_eh_timer_return klpp_nbd_xmit_timeout(struct request *req)
{
	struct nbd_cmd *cmd = blk_mq_rq_to_pdu(req);
	struct nbd_device *nbd = cmd->nbd;
	struct nbd_config *config;

	/* Another path holds the command; look again when the timer fires. */
	if (!mutex_trylock(&cmd->lock))
		return BLK_EH_RESET_TIMER;

	/*
	 * INFLIGHT is clear: the command was already requeued (which clears
	 * the flag) or is presumably being completed elsewhere. Completing it
	 * here as well is the race this patch fixes.
	 */
	if (!test_bit(NBD_CMD_INFLIGHT, &cmd->flags)) {
		mutex_unlock(&cmd->lock);
		return BLK_EH_DONE;
	}

	/*
	 * No live configuration (config_refs already zero): fail the request.
	 * No config reference was taken, so nbd_config_put() is skipped.
	 */
	config = nbd_get_config_unlocked(nbd);
	if (!config) {
		cmd->status = BLK_STS_TIMEOUT;
		__clear_bit(NBD_CMD_INFLIGHT, &cmd->flags);
		mutex_unlock(&cmd->lock);
		goto done;
	}

	if (config->num_connections > 1 ||
	    (config->num_connections == 1 && nbd->tag_set.timeout)) {
		dev_err_ratelimited(nbd_to_dev(nbd),
				    "Connection timed out, retrying (%d/%d alive)\n",
				    atomic_read(&config->live_connections),
				    config->num_connections);
		/*
		 * Hooray we have more connections, requeue this IO, the submit
		 * path will put it on a real connection. Or if only one
		 * connection is configured, the submit path will wait until
		 * a new connection is reconfigured or until dead timeout.
		 *
		 * If config->socks is NULL this branch does nothing and control
		 * falls through to the checks below.
		 */
		if (config->socks) {
			/* Only touch the socket if the index is valid for this config. */
			if (cmd->index < config->num_connections) {
				struct nbd_sock *nsock =
					config->socks[cmd->index];
				mutex_lock(&nsock->tx_lock);
				/* We can have multiple outstanding requests, so
				 * we don't want to mark the nsock dead if we've
				 * already reconnected with a new socket, so
				 * only mark it dead if it's the same socket we
				 * were sent out on.
				 */
				if (cmd->cookie == nsock->cookie)
					nbd_mark_nsock_dead(nbd, nsock, 1);
				mutex_unlock(&nsock->tx_lock);
			}
			/* Clears INFLIGHT before cmd->lock is dropped. */
			klpp_nbd_requeue_cmd(cmd);
			mutex_unlock(&cmd->lock);
			/* Drop the reference from nbd_get_config_unlocked(). */
			nbd_config_put(nbd);
			return BLK_EH_DONE;
		}
	}

	if (!nbd->tag_set.timeout) {
		/*
		 * Userspace sets timeout=0 to disable socket disconnection,
		 * so just warn and reset the timer.
		 *
		 * NOTE(review): unlike the retry branch above, config->socks and
		 * cmd->index are not validated here. This mirrors upstream, so it
		 * is left unchanged in the livepatch.
		 */
		struct nbd_sock *nsock = config->socks[cmd->index];
		cmd->retries++;
		dev_info(nbd_to_dev(nbd), "Possible stuck request %p: control (%s@%llu,%uB). Runtime %u seconds\n",
			req, nbdcmd_to_ascii(req_to_nbd_cmd_type(req)),
			(unsigned long long)blk_rq_pos(req) << 9,
			blk_rq_bytes(req), (req->timeout / HZ) * cmd->retries);

		mutex_lock(&nsock->tx_lock);
		/*
		 * The socket was replaced after this command was sent on it:
		 * requeue instead of waiting for a reply on the old socket.
		 */
		if (cmd->cookie != nsock->cookie) {
			klpp_nbd_requeue_cmd(cmd);
			mutex_unlock(&nsock->tx_lock);
			mutex_unlock(&cmd->lock);
			nbd_config_put(nbd);
			return BLK_EH_DONE;
		}
		/* Still on the original socket: keep waiting, INFLIGHT stays set. */
		mutex_unlock(&nsock->tx_lock);
		mutex_unlock(&cmd->lock);
		nbd_config_put(nbd);
		return BLK_EH_RESET_TIMER;
	}

	/*
	 * Timeouts are enabled and no retry path applied. Flag the config as
	 * timed out, fail the command with BLK_STS_IOERR, release cmd->lock,
	 * then shut the sockets down.
	 */
	dev_err_ratelimited(nbd_to_dev(nbd), "Connection timed out\n");
	set_bit(NBD_RT_TIMEDOUT, &config->runtime_flags);
	cmd->status = BLK_STS_IOERR;
	__clear_bit(NBD_CMD_INFLIGHT, &cmd->flags);
	mutex_unlock(&cmd->lock);
	sock_shutdown(nbd);
	nbd_config_put(nbd);
done:
	/* INFLIGHT is already clear on every path reaching this point. */
	blk_mq_complete_request(req);
	return BLK_EH_DONE;
}

/*
 * Redundant redeclaration emitted by klp-ccp. It is identical to the extern
 * declaration earlier in the file and harmless.
 */
void nbd_config_put(struct nbd_device *nbd);


#include "livepatch_bsc1232900.h"

#include <linux/livepatch.h>

/*
 * Bind the nbd-module helpers that this patch calls but does not carry.
 *
 * NOTE(review): KLP_RELOC_SYMBOL is defined by the livepatch build headers.
 * Presumably it turns each reference into a livepatch relocation against
 * the "nbd" object, resolved when the patch is applied to the loaded module.
 * Confirm the exact meaning of the first two arguments against the macro
 * definition.
 */
extern typeof(nbd_config_put) nbd_config_put
	 KLP_RELOC_SYMBOL(nbd, nbd, nbd_config_put);
extern typeof(nbd_mark_nsock_dead) nbd_mark_nsock_dead
	 KLP_RELOC_SYMBOL(nbd, nbd, nbd_mark_nsock_dead);
extern typeof(sock_shutdown) sock_shutdown
	 KLP_RELOC_SYMBOL(nbd, nbd, sock_shutdown);
