/*
 * livepatch_bsc1208838
 *
 * Fix for CVE-2023-1078, bsc#1208838
 *
 *  Upstream commits:
 *  f753a68980cf ("rds: rds_rm_zerocopy_callback() use list_first_entry()")
 *  68762148d1b0 ("rds: rds_rm_zerocopy_callback() correct order for
 *                 list_add_tail()")
 *
 *  SLE12-SP4, SLE12-SP5, SLE15 and SLE15-SP1 commit:
 *  none yet
 *
 *  SLE15-SP2 and -SP3 commit:
 *  none yet
 *
 *  SLE15-SP4 commit:
 *  none yet
 *
 *
 *  Copyright (c) 2023 SUSE
 *  Author: Lukas Hruska <lukas.hruska@suse.cz>
 *
 *  Based on the original Linux kernel code. Other copyrights apply.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, see <http://www.gnu.org/licenses/>.
 */

#if IS_ENABLED(CONFIG_RDS)

#if !IS_MODULE(CONFIG_RDS)
#error "Live patch supports only CONFIG=m"
#endif

/* klp-ccp: from net/rds/message.c */
#include <linux/kernel.h>

#include <linux/slab.h>
#include <linux/gfp.h>
#include <linux/mmzone.h>

/* klp-ccp: from net/rds/message.c */
#include <linux/export.h>
#include <linux/skbuff.h>
#include <linux/list.h>
#include <linux/errqueue.h>
/* klp-ccp: from net/rds/rds.h */
#include <net/sock.h>
#include <linux/scatterlist.h>
#include <linux/highmem.h>
#include <linux/mutex.h>
#include <linux/rds.h>
#include <linux/rhashtable.h>
#include <linux/refcount.h>
#include <linux/in6.h>

#ifdef RDS_DEBUG
#error "klp-ccp: non-taken branch"
#else

/*
 * No-op replacement for the rds debug print helper, used when RDS_DEBUG is
 * not defined.  The __printf(1, 2) annotation keeps compile-time checking of
 * the format string against its arguments even though nothing is printed.
 */
static inline __printf(1, 2)
void rdsdebug(char *fmt, ...)
{
}
#endif

#define RDS_HEADER_EXT_SPACE	16

/*
 * RDS message header, copied from net/rds/rds.h for this livepatch.
 *
 * It is embedded by value in struct rds_incoming (i_hdr), so its size feeds
 * into the member offsets of struct rds_message.
 * NOTE(review): the layout presumably has to match the rds module being
 * patched exactly -- do not reorder or resize members.
 */
struct rds_header {
	__be64	h_sequence;
	__be64	h_ack;
	__be32	h_len;
	__be16	h_sport;
	__be16	h_dport;
	u8	h_flags;
	u8	h_credit;
	u8	h_padding[4];
	__sum16	h_csum;

	/* extension header area, RDS_HEADER_EXT_SPACE bytes */
	u8	h_exthdr[RDS_HEADER_EXT_SPACE];
};

#define RDS_RX_MAX_TRACES	(RDS_MSG_RX_DGRAM_TRACE_MAX + 1)

/*
 * Per-message receive-side state, copied from net/rds/rds.h.
 *
 * None of these members are accessed by this livepatch; the definition is
 * needed only because struct rds_message embeds it by value as m_inc.
 * NOTE(review): layout presumably must match the patched rds module.
 */
struct rds_incoming {
	refcount_t		i_refcount;
	struct list_head	i_item;
	struct rds_connection	*i_conn;
	struct rds_conn_path	*i_conn_path;
	struct rds_header	i_hdr;
	unsigned long		i_rx_jiffies;
	struct in6_addr		i_saddr;

	rds_rdma_cookie_t	i_rdma_cookie;
	ktime_t			i_rx_tstamp;
	/* one slot per trace point, RDS_RX_MAX_TRACES entries */
	u64			i_rx_lat_trace[RDS_RX_MAX_TRACES];
};

/*
 * Object referenced through op_rdma_mr by the rdma and atomic ops of
 * struct rds_message, copied from net/rds/rds.h.
 *
 * The only member this livepatch touches is r_refcount, which
 * klpr_rds_mr_put() decrements.
 * NOTE(review): layout presumably must match the patched rds module.
 */
struct rds_mr {
	struct rb_node		r_rb_node;
	/* dropped via klpr_rds_mr_put(); final put calls __rds_put_mr_final */
	refcount_t		r_refcount;
	u32			r_key;

	/* A copy of the creation flags */
	unsigned int		r_use_once:1;
	unsigned int		r_invalidate:1;
	unsigned int		r_write:1;

	/* This is for RDS_MR_DEAD.
	 * It would be nice & consistent to make this part of the above
	 * bit field here, but we need to use test_and_set_bit.
	 */
	unsigned long		r_state;
	struct rds_sock		*r_sock; /* back pointer to the socket that owns us */
	struct rds_transport	*r_trans;
	void			*r_trans_private;
};

#define RDS_MSG_PAGEVEC		7

/*
 * Zerocopy notifier hanging off rm->data.op_mmp_znotifier.
 *
 * It lives inside the union of struct rds_msg_zcopy_info, which is how
 * rds_info_from_znotifier() gets from a notifier back to its container.
 */
struct rds_znotifier {
	/* handed to mm_unaccount_pinned_pages() by the zerocopy callback */
	struct mmpin		z_mmp;
	/* cookie value queued on the socket by the zerocopy callback */
	u32			z_cookie;
};

/*
 * Container shared by the zerocopy notifier and the cookie batch.
 *
 * znotif and zcookies occupy the same storage.  The zerocopy callback first
 * consumes znotif (z_cookie, z_mmp) and only afterwards overwrites the union
 * with a zeroed zcookies batch, so the order of those accesses matters.
 */
struct rds_msg_zcopy_info {
	/* links this info into rds_msg_zcopy_queue.zcookie_head */
	struct list_head rs_zcookie_next;
	union {
		struct rds_znotifier znotif;
		struct rds_zcopy_cookies zcookies;
	};
};

/*
 * Per-socket queue of struct rds_msg_zcopy_info entries (linked through
 * rs_zcookie_next); embedded in struct rds_sock as rs_zcookie_queue.
 */
struct rds_msg_zcopy_queue {
	struct list_head zcookie_head;
	spinlock_t lock; /* protects zcookie_head queue */
};

/*
 * An RDS message, copied from net/rds/rds.h.
 *
 * Members used by this livepatch: m_refcount, m_sock_item, m_conn_item,
 * m_flags, m_rs_lock, m_rs, and the op_* fields of the data, rdma and atomic
 * sub-structures that klpr_rds_message_purge() reads or resets.
 * NOTE(review): layout presumably must match the patched rds module exactly
 * -- do not reorder, add or resize members.
 */
struct rds_message {
	/* final reference drop in klpp_rds_message_put() purges and frees */
	refcount_t		m_refcount;
	/* both list items must be empty by the time the message is freed */
	struct list_head	m_sock_item;
	struct list_head	m_conn_item;
	struct rds_incoming	m_inc;
	u64			m_ack_seq;
	struct in6_addr		m_daddr;
	/* bit RDS_MSG_PAGEVEC is tested by the purge path */
	unsigned long		m_flags;

	/* Never access m_rs without holding m_rs_lock.
	 * Lock nesting is
	 *  rm->m_rs_lock
	 *   -> rs->rs_lock
	 */
	spinlock_t		m_rs_lock;
	wait_queue_head_t	m_flush_wait;

	struct rds_sock		*m_rs;

	/* cookie to send to remote, in rds header */
	rds_rdma_cookie_t	m_rdma_cookie;

	unsigned int		m_used_sgs;
	unsigned int		m_total_sgs;

	void			*m_final_op;

	/* anonymous struct: atomic, rdma and data are separate members */
	struct {
		struct rm_atomic_op {
			int			op_type;
			union {
				struct {
					uint64_t	compare;
					uint64_t	swap;
					uint64_t	compare_mask;
					uint64_t	swap_mask;
				} op_m_cswp;
				struct {
					uint64_t	add;
					uint64_t	nocarry_mask;
				} op_m_fadd;
			};

			u32			op_rkey;
			u64			op_remote_addr;
			unsigned int		op_notify:1;
			unsigned int		op_recverr:1;
			unsigned int		op_mapped:1;
			unsigned int		op_silent:1;
			unsigned int		op_active:1;
			struct scatterlist	*op_sg;
			struct rds_notifier	*op_notifier;

			struct rds_mr		*op_rdma_mr;
		} atomic;
		struct rm_rdma_op {
			u32			op_rkey;
			u64			op_remote_addr;
			unsigned int		op_write:1;
			unsigned int		op_fence:1;
			unsigned int		op_notify:1;
			unsigned int		op_recverr:1;
			unsigned int		op_mapped:1;
			unsigned int		op_silent:1;
			unsigned int		op_active:1;
			unsigned int		op_bytes;
			unsigned int		op_nents;
			unsigned int		op_count;
			struct scatterlist	*op_sg;
			struct rds_notifier	*op_notifier;

			struct rds_mr		*op_rdma_mr;
		} rdma;
		struct rm_data_op {
			unsigned int		op_active:1;
			/* number of op_sg entries released by the purge path */
			unsigned int		op_nents;
			unsigned int		op_count;
			unsigned int		op_dmasg;
			unsigned int		op_dmaoff;
			/* non-NULL selects the zerocopy release path in purge */
			struct rds_znotifier	*op_mmp_znotifier;
			struct scatterlist	*op_sg;
		} data;
	};

	struct rds_conn_path *m_conn_path;
};

#define RDS_BOUND_KEY_LEN \
	(sizeof(struct in6_addr) + sizeof(__u32) + sizeof(__be16))

/*
 * RDS socket, copied from net/rds/rds.h.
 *
 * This livepatch only uses rs_sk (through rds_rs_to_sk()) and
 * rs_zcookie_queue (in the zerocopy callback); every other member is here to
 * keep the offsets of those two correct.
 * NOTE(review): layout presumably must match the patched rds module exactly
 * -- do not reorder, add or resize members.
 */
struct rds_sock {
	struct sock		rs_sk;

	u64			rs_user_addr;
	u64			rs_user_bytes;

	/*
	 * bound_addr used for both incoming and outgoing, no INADDR_ANY
	 * support.
	 */
	struct rhash_head	rs_bound_node;
	u8			rs_bound_key[RDS_BOUND_KEY_LEN];
	struct sockaddr_in6	rs_bound_sin6;
	struct in6_addr		rs_conn_addr;
	__be16			rs_conn_port;
	struct rds_transport    *rs_transport;

	/*
	 * rds_sendmsg caches the conn it used the last time around.
	 * This helps avoid costly lookups.
	 */
	struct rds_connection	*rs_conn;

	/* flag indicating we were congested or not */
	int			rs_congested;
	/* seen congestion (ENOBUFS) when sending? */
	int			rs_seen_congestion;

	/* rs_lock protects all these adjacent members before the newline */
	spinlock_t		rs_lock;
	struct list_head	rs_send_queue;
	u32			rs_snd_bytes;
	int			rs_rcv_bytes;
	struct list_head	rs_notify_queue;	/* currently used for failed RDMAs */

	/* Congestion wake_up. If rs_cong_monitor is set, we use cong_mask
	 * to decide whether the application should be woken up.
	 * If not set, we use rs_cong_track to find out whether a cong map
	 * update arrived.
	 */
	uint64_t		rs_cong_mask;
	uint64_t		rs_cong_notify;
	struct list_head	rs_cong_list;
	unsigned long		rs_cong_track;

	/*
	 * rs_recv_lock protects the receive queue, and is
	 * used to serialize with rds_release.
	 */
	rwlock_t		rs_recv_lock;
	struct list_head	rs_recv_queue;

	/* just for stats reporting */
	struct list_head	rs_item;

	/* these have their own lock */
	spinlock_t		rs_rdma_lock;
	struct rb_root		rs_rdma_keys;

	/* Socket options - in case there will be more */
	unsigned char		rs_recverr,
				rs_cong_monitor;
	u32			rs_hash_initval;

	/* Socket receive path trace points */
	u8			rs_rx_traces;
	u8			rs_rx_trace[RDS_MSG_RX_DGRAM_TRACE_MAX];
	/* zerocopy completion cookies queued by the zerocopy callback */
	struct rds_msg_zcopy_queue rs_zcookie_queue;
	u8			rs_tos;
};

/*
 * rds_rs_to_sk() - get the generic socket embedded in an RDS socket.
 * @rs: RDS socket
 *
 * Returns the address of rs->rs_sk; no reference is taken.
 */
static inline struct sock *rds_rs_to_sk(struct rds_sock *rs)
{
	struct sock *sk = &rs->rs_sk;

	return sk;
}

/*
 * klpe_* pointers stand in for functions of the rds module that the patched
 * code calls.  They start out NULL (static storage) and are filled in by
 * __klp_resolve_kallsyms_relocs() from the klp_funcs[] table, either at
 * livepatch init or from the module notifier when rds is loaded later.
 */
static void (*klpe_rds_wake_sk_sleep)(struct rds_sock *rs);

/* Patched replacement for rds_message_put(); defined further down. */
void klpp_rds_message_put(struct rds_message *rm);

static void (*klpe_rds_rdma_free_op)(struct rm_rdma_op *ro);
static void (*klpe_rds_atomic_free_op)(struct rm_atomic_op *ao);

/* Resolved to __rds_put_mr_final; invoked on the final rds_mr reference. */
static void (*klpe___rds_put_mr_final)(struct rds_mr *mr);
/*
 * klpr_rds_mr_put() - drop one reference on an rds_mr.
 * @mr: object whose r_refcount is decremented
 *
 * When the count reaches zero the module's __rds_put_mr_final() is invoked
 * through the resolved klpe___rds_put_mr_final pointer.
 */
static inline void klpr_rds_mr_put(struct rds_mr *mr)
{
	/* Not the last reference: nothing else to do. */
	if (!refcount_dec_and_test(&mr->r_refcount))
		return;

	(*klpe___rds_put_mr_final)(mr);
}

/* klp-ccp: from net/rds/message.c */
/*
 * rds_zcookie_add() - append a cookie to an info's cookie batch.
 * @info:   container whose zcookies batch receives the cookie
 * @cookie: value to store
 *
 * Returns false, leaving the batch untouched, when it already holds
 * RDS_MAX_ZCOOKIES entries; otherwise stores the cookie, bumps the count
 * and returns true.
 */
static inline bool rds_zcookie_add(struct rds_msg_zcopy_info *info, u32 cookie)
{
	struct rds_zcopy_cookies *batch = &info->zcookies;
	int used = batch->num;

	if (used == RDS_MAX_ZCOOKIES)
		return false;

	batch->cookies[used] = cookie;
	batch->num = used + 1;
	return true;
}

/*
 * rds_info_from_znotifier() - map a notifier back to its container.
 * @znotif: notifier embedded as the znotif member of an rds_msg_zcopy_info
 *
 * Pure pointer arithmetic (container_of); @znotif is not dereferenced.
 */
static struct rds_msg_zcopy_info *rds_info_from_znotifier(struct rds_znotifier *znotif)
{
	struct rds_msg_zcopy_info *info;

	info = container_of(znotif, struct rds_msg_zcopy_info, znotif);
	return info;
}

static void klpp_rds_rm_zerocopy_callback(struct rds_sock *rs,
				     struct rds_znotifier *znotif)
{
	struct rds_msg_zcopy_info *info;
	struct rds_msg_zcopy_queue *q;
	u32 cookie = znotif->z_cookie;
	struct rds_zcopy_cookies *ck;
	struct list_head *head;
	unsigned long flags;

	mm_unaccount_pinned_pages(&znotif->z_mmp);
	q = &rs->rs_zcookie_queue;
	spin_lock_irqsave(&q->lock, flags);
	head = &q->zcookie_head;
	if (!list_empty(head)) {
		info = list_first_entry(head, struct rds_msg_zcopy_info,
					rs_zcookie_next);
		if (rds_zcookie_add(info, cookie)) {
			spin_unlock_irqrestore(&q->lock, flags);
			kfree(rds_info_from_znotifier(znotif));
			/* caller invokes rds_wake_sk_sleep() */
			return;
		}
	}

	info = rds_info_from_znotifier(znotif);
	ck = &info->zcookies;
	memset(ck, 0, sizeof(*ck));
	WARN_ON(!rds_zcookie_add(info, cookie));
	list_add_tail(&info->rs_zcookie_next, &q->zcookie_head);

	spin_unlock_irqrestore(&q->lock, flags);
	/* caller invokes rds_wake_sk_sleep() */
}

/*
 * klpr_rds_message_purge() - release what an rds_message still holds.
 * @rm: message being torn down
 *
 * Does nothing when RDS_MSG_PAGEVEC is set in m_flags.  Otherwise, under
 * m_rs_lock, detaches the message from its socket (queuing the zerocopy
 * cookie and waking the socket first if a notifier is attached), then
 * releases the data pages and any active rdma/atomic op and their rds_mr
 * references.
 */
static void klpr_rds_message_purge(struct rds_message *rm)
{
	unsigned long idx, irqflags;
	struct rds_sock *rs;
	bool zcopy = false;

	if (unlikely(test_bit(RDS_MSG_PAGEVEC, &rm->m_flags)))
		return;

	spin_lock_irqsave(&rm->m_rs_lock, irqflags);
	rs = rm->m_rs;
	if (rs) {
		struct rds_znotifier *znotif = rm->data.op_mmp_znotifier;

		if (znotif) {
			zcopy = true;
			klpp_rds_rm_zerocopy_callback(rs, znotif);
			(*klpe_rds_wake_sk_sleep)(rs);
			rm->data.op_mmp_znotifier = NULL;
		}
		sock_put(rds_rs_to_sk(rs));
		rm->m_rs = NULL;
	}
	spin_unlock_irqrestore(&rm->m_rs_lock, irqflags);

	for (idx = 0; idx < rm->data.op_nents; idx++) {
		struct page *pg = sg_page(&rm->data.op_sg[idx]);

		/* XXX will have to put_page for page refs */
		if (zcopy)
			put_page(pg);
		else
			__free_page(pg);
	}
	rm->data.op_nents = 0;

	if (rm->rdma.op_active)
		(*klpe_rds_rdma_free_op)(&rm->rdma);
	if (rm->rdma.op_rdma_mr)
		klpr_rds_mr_put(rm->rdma.op_rdma_mr);

	if (rm->atomic.op_active)
		(*klpe_rds_atomic_free_op)(&rm->atomic);
	if (rm->atomic.op_rdma_mr)
		klpr_rds_mr_put(rm->atomic.op_rdma_mr);
}

/*
 * klpp_rds_message_put() - drop a reference on an rds_message.
 * @rm: message to release
 *
 * Patched replacement for rds_message_put().  Warns if the refcount is
 * already zero on entry.  When this call drops the last reference, the
 * message must be off both its socket and connection lists (BUG otherwise);
 * it is then purged and freed.
 */
void klpp_rds_message_put(struct rds_message *rm)
{
	rdsdebug("put rm %p ref %d\n", rm, refcount_read(&rm->m_refcount));
	WARN(!refcount_read(&rm->m_refcount), "danger refcount zero on %p\n", rm);

	/* Other holders remain: leave the message alone. */
	if (!refcount_dec_and_test(&rm->m_refcount))
		return;

	BUG_ON(!list_empty(&rm->m_sock_item));
	BUG_ON(!list_empty(&rm->m_conn_item));

	klpr_rds_message_purge(rm);
	kfree(rm);
}



#define LP_MODULE "rds"

#include <linux/kernel.h>
#include <linux/module.h>
#include "livepatch_bsc1208838.h"
#include "../kallsyms_relocs.h"

/*
 * Symbols of the "rds" module that the patched code calls.  Each entry pairs
 * a symbol name with the address of the klpe_* pointer that receives it and
 * the name of the module providing it; the table is handed to
 * __klp_resolve_kallsyms_relocs().
 */
static struct klp_kallsyms_reloc klp_funcs[] = {
	{ "__rds_put_mr_final", (void *)&klpe___rds_put_mr_final, "rds" },
	{ "rds_atomic_free_op", (void *)&klpe_rds_atomic_free_op, "rds" },
	{ "rds_rdma_free_op", (void *)&klpe_rds_rdma_free_op, "rds" },
	{ "rds_wake_sk_sleep", (void *)&klpe_rds_wake_sk_sleep, "rds" },
};

/*
 * module_notify() - resolve the klpe_* pointers when rds gets loaded.
 * @nb:     notifier block (unused)
 * @action: module state being announced
 * @data:   the struct module the event is about
 *
 * Ignores everything except MODULE_STATE_COMING for LP_MODULE.  For that
 * event the klp_funcs[] table is resolved under module_mutex; a failure
 * triggers a WARN.
 *
 * Returns 0 for ignored events, otherwise the result of
 * __klp_resolve_kallsyms_relocs().
 */
static int module_notify(struct notifier_block *nb,
			unsigned long action, void *data)
{
	struct module *coming = data;
	int err;

	if (action != MODULE_STATE_COMING)
		return 0;
	if (strcmp(coming->name, LP_MODULE) != 0)
		return 0;

	mutex_lock(&module_mutex);
	err = __klp_resolve_kallsyms_relocs(klp_funcs, ARRAY_SIZE(klp_funcs));
	mutex_unlock(&module_mutex);

	WARN(err, "%s: delayed kallsyms lookup failed. System is broken and can crash.\n",
		__func__);

	return err;
}

/*
 * Notifier block registered by livepatch_bsc1208838_init() so that
 * module_notify() is called on module state changes.  The priority is set
 * to INT_MIN+1, one above the lowest possible value.
 */
static struct notifier_block module_nb = {
	.notifier_call = module_notify,
	.priority = INT_MIN+1,
};

/*
 * livepatch_bsc1208838_init() - set up the livepatch for bsc#1208838.
 *
 * Under module_mutex: if LP_MODULE is already loaded, the klp_funcs[] table
 * is resolved right away; then the module notifier is registered so a later
 * load of the module is handled by module_notify().  The notifier is not
 * registered when the immediate resolution fails.
 *
 * Returns 0 on success, otherwise the error from the symbol resolution or
 * from register_module_notifier().
 */
int livepatch_bsc1208838_init(void)
{
	int err = 0;

	mutex_lock(&module_mutex);

	if (find_module(LP_MODULE))
		err = __klp_resolve_kallsyms_relocs(klp_funcs,
						    ARRAY_SIZE(klp_funcs));
	if (!err)
		err = register_module_notifier(&module_nb);

	mutex_unlock(&module_mutex);
	return err;
}

/*
 * livepatch_bsc1208838_cleanup() - undo livepatch_bsc1208838_init().
 *
 * Unregisters the module notifier; the klpe_* pointers are left as they are.
 */
void livepatch_bsc1208838_cleanup(void)
{
	unregister_module_notifier(&module_nb);
}

#endif /* IS_ENABLED(CONFIG_RDS) */
