/*
 * livepatch_bsc1248615
 *
 * Fix for CVE-2024-58239, bsc#1248615
 *
 *  Copyright (c) 2025 SUSE
 *  Author: Fernando Gonzalez <fernando.gonzalez@suse.com>
 *
 *  Based on the original Linux kernel code. Other copyrights apply.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, see <http://www.gnu.org/licenses/>.
 */



/* klp-ccp: from net/tls/tls_sw.c */
#include <linux/bug.h>
#include <linux/sched/signal.h>
#include <linux/module.h>
#include <linux/splice.h>
#include <crypto/aead.h>
#include <net/strparser.h>
#include <net/tls.h>

/* klp-ccp: from include/net/tls.h */
/*
 * Pointer through which the patched code calls tls_err_abort(). It is
 * listed in the klp_funcs[] relocation table of this file under the
 * symbol name "tls_err_abort" of module "tls" and must be resolved
 * before klpp_tls_sw_recvmsg() runs.
 */
static void (*klpe_tls_err_abort)(struct sock *sk, int err);

/*
 * Prototype of the patched receive routine defined later in this file.
 * The "klpp_" prefix marks it as the livepatch replacement of the
 * function it was copied from.
 */
int klpp_tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
		   int nonblock, int flags, int *addr_len);

/* klp-ccp: from net/tls/tls_sw.c */
/*
 * Per-record decrypt request, passed by pointer to decrypt_skb_update().
 *
 * NOTE(review): this is a private copy of a structure from
 * net/tls/tls_sw.c; since the callee lives in the tls module the layout
 * presumably has to stay identical to the module's own definition.
 */
struct tls_decrypt_arg {
	/* Set by the caller when the record may be decrypted straight into
	 * the message iterator; the caller skips its own copy step if set.
	 */
	bool zc;
	/* Set by the caller from ctx->async_capable for data records and
	 * read back after decrypt_skb_update() returns.
	 */
	bool async;
	/* Not referenced in this file - presumably maintained by the
	 * decrypt path inside the tls module; verify there.
	 */
	bool async_done;
};
/*
 * Externalized callees of the patched function. Each pointer below has a
 * matching entry in the klp_funcs[] relocation table of this file (symbol
 * name without the "klpe_" prefix, module "tls") and is only valid once
 * that table has been resolved.
 */

/* Called at the recv_end label to wait for outstanding async decrypts. */
static int (*klpe_tls_decrypt_async_wait)(struct tls_sw_context_rx *ctx);

/* Returns the next record skb, or NULL with *err set by the callee. */
static struct sk_buff *(*klpe_tls_wait_data)(struct sock *sk, struct sk_psock *psock,
				     bool nonblock, long timeo, int *err);

/* Decrypts one record; a negative return is treated as a bad message. */
static int (*klpe_decrypt_skb_update)(struct sock *sk, struct sk_buff *skb,
			      struct iov_iter *dest,
			      struct tls_decrypt_arg *darg);

/* Called after each record; its return value decides whether the receive
 * loop goes on (true) or stops (false).
 */
static bool (*klpe_tls_sw_advance_skb)(struct sock *sk, struct sk_buff *skb,
			       unsigned int len);

/* Handles records queued on the rx list; a negative return is an error,
 * otherwise the caller uses the value as a byte count.
 */
static int (*klpe_process_rx_list)(struct tls_sw_context_rx *ctx,
			   struct msghdr *msg,
			   u8 *control,
			   bool *cmsg,
			   size_t skip,
			   size_t len,
			   bool zc,
			   bool is_peek);

/*
 * klpp_tls_sw_recvmsg() - livepatch copy of the software TLS receive
 * routine from net/tls/tls_sw.c (see the file header: CVE-2024-58239).
 *
 * @sk:       socket to receive from; locked with lock_sock() for the whole
 *            operation and released at the "end" label.
 * @msg:      destination message; payload goes to msg->msg_iter, the
 *            record type is reported through a SOL_TLS control message and
 *            MSG_EOR may be set in msg->msg_flags.
 * @len:      maximum number of bytes to deliver.
 * @nonblock: OR-ed into @flags before any flag is tested.
 * @flags:    MSG_* flags; MSG_ERRQUEUE, MSG_PEEK, MSG_WAITALL and
 *            MSG_DONTWAIT are examined here.
 * @addr_len: not used by this function.
 *
 * Return: the number of bytes accounted as copied if that is non-zero,
 * otherwise the last value of err (0 or a negative error code). With
 * MSG_ERRQUEUE the result of sock_recv_errqueue() is returned directly.
 *
 * All helpers from the tls module are reached through the klpe_* function
 * pointers declared above, which must have been resolved beforehand.
 */
int klpp_tls_sw_recvmsg(struct sock *sk,
		   struct msghdr *msg,
		   size_t len,
		   int nonblock,
		   int flags,
		   int *addr_len)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct sk_psock *psock;
	/* Record type of the data returned by this call; 0 = not known yet */
	unsigned char control = 0;
	/* Bytes handled inside the receive loop below */
	ssize_t decrypted = 0;
	struct strp_msg *rxm;
	struct tls_msg *tlm;
	struct sk_buff *skb;
	/* Bytes obtained from the initial process_rx_list() call */
	ssize_t copied = 0;
	/* Sticky: set once any record reported async decryption */
	bool async = false;
	/* True once the record-type control message has been emitted */
	bool cmsg = false;
	int target, err = 0;
	long timeo;
	bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
	bool is_peek = flags & MSG_PEEK;
	bool bpf_strp_enabled;

	/* NOTE(review): nonblock presumably already carries the
	 * MSG_DONTWAIT bit, so that the MSG_DONTWAIT tests below see it -
	 * confirm against the caller.
	 */
	flags |= nonblock;

	/* Error queue reads bypass TLS processing and the socket lock */
	if (unlikely(flags & MSG_ERRQUEUE))
		return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);

	/* psock may be NULL (tested below); paired with sk_psock_put()
	 * at the "end" label.
	 */
	psock = sk_psock_get(sk);
	lock_sock(sk);
	bpf_strp_enabled = sk_psock_strp_enabled(psock);

	/* Process pending decrypted records. It must be non-zero-copy */
	err = (*klpe_process_rx_list)(ctx, msg, &control, &cmsg, 0, len, false,
			      is_peek);
	if (err < 0) {
		/* Abort the connection and return the error as is */
		(*klpe_tls_err_abort)(sk, err);
		goto end;
	}

	copied = err;
	/* Stop here if the request is already satisfied, or if the rx list
	 * delivered something that is not application data: the loop below
	 * must not append further records to it.
	 * NOTE(review): the "copied && control != TLS_RECORD_TYPE_DATA" test
	 * looks like the CVE-2024-58239 change - confirm against the
	 * upstream commit.
	 */
	if (len <= copied || (copied && control != TLS_RECORD_TYPE_DATA))
		goto recv_end;

	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
	/* From here on len is the space still left in the request */
	len = len - copied;
	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);

	/* Keep going while there is room left and either the low-water
	 * target has not been reached or a record is pending in
	 * ctx->recv_pkt.
	 */
	while (len && (decrypted + copied < target || ctx->recv_pkt)) {
		struct tls_decrypt_arg darg = {};
		/* Set when only part of the record fit into the request */
		bool retain_skb = false;
		int to_decrypt, chunk;

		skb = (*klpe_tls_wait_data)(sk, psock, flags & MSG_DONTWAIT, timeo, &err);
		if (!skb) {
			/* No TLS record: try sk_msg_recvmsg() when a psock
			 * is attached; a positive result is accounted like
			 * decrypted data and the loop is retried.
			 */
			if (psock) {
				int ret = sk_msg_recvmsg(sk, psock, msg, len,
							 flags);

				if (ret > 0) {
					decrypted += ret;
					len -= ret;
					continue;
				}
			}
			goto recv_end;
		} else {
			/* For TLS 1.3 the type is cleared here and taken
			 * from ctx->control only after decryption (below);
			 * otherwise ctx->control is used right away.
			 */
			tlm = tls_msg(skb);
			if (prot->version == TLS_1_3_VERSION)
				tlm->control = 0;
			else
				tlm->control = ctx->control;
		}

		rxm = strp_msg(skb);

		/* Record length without the protocol overhead */
		to_decrypt = rxm->full_len - prot->overhead_size;

		/* Zero-copy is requested only if the record fits into the
		 * remaining space, the destination is not a kvec, this is
		 * not a peek, the record is data, the version is not
		 * TLS 1.3 and no BPF stream parser is enabled.
		 */
		if (to_decrypt <= len && !is_kvec && !is_peek &&
		    ctx->control == TLS_RECORD_TYPE_DATA &&
		    prot->version != TLS_1_3_VERSION &&
		    !bpf_strp_enabled)
			darg.zc = true;

		/* Do not use async mode if record is non-data */
		if (ctx->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
			darg.async = ctx->async_capable;
		else
			darg.async = false;

		err = (*klpe_decrypt_skb_update)(sk, skb, &msg->msg_iter, &darg);
		if (err < 0) {
			/* The socket is always aborted with -EBADMSG; err
			 * itself keeps the value returned by the callee.
			 */
			(*klpe_tls_err_abort)(sk, -EBADMSG);
			goto recv_end;
		}

		if (prot->version == TLS_1_3_VERSION)
			tlm->control = ctx->control;
		/* Remember async use across iterations for recv_end */
		async |= darg.async;

		/* If the type of records being processed is not known yet,
		 * set it to record type just dequeued. If it is already known,
		 * but does not match the record type just dequeued, go to end.
		 * We always get record type here since for tls1.2, record type
		 * is known just after record is dequeued from stream parser.
		 * For tls1.3, we disable async.
		 */

		if (!control)
			control = tlm->control;
		else if (control != tlm->control)
			goto recv_end;

		/* Emit the record-type control message once per call */
		if (!cmsg) {
			int cerr;

			cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
					sizeof(control), &control);
			cmsg = true;
			/* A non-data record whose type could not be
			 * reported (error or truncated control buffer) is
			 * failed with -EIO; for data records the result of
			 * put_cmsg() is ignored.
			 */
			if (control != TLS_RECORD_TYPE_DATA) {
				if (cerr || msg->msg_flags & MSG_CTRUNC) {
					err = -EIO;
					goto recv_end;
				}
			}
		}

		if (async) {
			/* TLS 1.2-only, to_decrypt must be text length */
			chunk = min_t(int, to_decrypt, len);
			/* The copy step is skipped; queued records are
			 * handled after recv_end.
			 */
			goto pick_next_record;
		}
		/* TLS 1.3 may have updated the length by more than overhead */
		chunk = rxm->full_len;

		if (!darg.zc) {
			if (bpf_strp_enabled) {
				err = sk_psock_tls_strp_read(psock, skb);
				if (err != __SK_PASS) {
					/* Verdict is not "pass": mark the
					 * record as fully consumed, free it
					 * here only for __SK_DROP, detach it
					 * from the context and restart the
					 * stream parser.
					 */
					rxm->offset = rxm->offset + rxm->full_len;
					rxm->full_len = 0;
					if (err == __SK_DROP)
						consume_skb(skb);
					ctx->recv_pkt = NULL;
					__strp_unpause(&ctx->strp);
					continue;
				}
			}

			/* Record is larger than the remaining space: copy
			 * what fits and keep the skb for a later call.
			 */
			if (chunk > len) {
				retain_skb = true;
				chunk = len;
			}

			err = skb_copy_datagram_msg(skb, rxm->offset,
						    msg, chunk);
			if (err < 0)
				goto recv_end;

			/* A peek leaves the record untouched */
			if (!is_peek) {
				rxm->offset = rxm->offset + chunk;
				rxm->full_len = rxm->full_len - chunk;
			}
		}

pick_next_record:
		decrypted += chunk;
		len -= chunk;

		/* For async or peek case, queue the current skb */
		if (async || is_peek || retain_skb) {
			skb_queue_tail(&ctx->rx_list, skb);
			/* The queued skb is no longer passed on below */
			skb = NULL;
		}

		if ((*klpe_tls_sw_advance_skb)(sk, skb, chunk)) {
			/* Return full control message to
			 * userspace before trying to parse
			 * another message type
			 */
			msg->msg_flags |= MSG_EOR;
			if (control != TLS_RECORD_TYPE_DATA)
				goto recv_end;
		} else {
			break;
		}
	}

recv_end:
	if (async) {
		/* Wait for all previously submitted records to be decrypted */
		err = (*klpe_tls_decrypt_async_wait)(ctx);
		if (err) {
			/* one of async decrypt failed */
			(*klpe_tls_err_abort)(sk, err);
			/* Nothing is reported as received in this case */
			copied = 0;
			decrypted = 0;
			goto end;
		}

		/* Drain records from the rx_list & copy if required */
		/* Peek/kvec: skip the first "copied" bytes, zc = false.
		 * Otherwise: skip nothing and pass zc = true.
		 */
		if (is_peek || is_kvec)
			err = (*klpe_process_rx_list)(ctx, msg, &control, &cmsg, copied,
					      decrypted, false, is_peek);
		else
			err = (*klpe_process_rx_list)(ctx, msg, &control, &cmsg, 0,
					      decrypted, true, is_peek);
		if (err < 0) {
			(*klpe_tls_err_abort)(sk, err);
			copied = 0;
			goto end;
		}
	}

	/* Total = bytes from the initial rx list pass + bytes from the loop */
	copied += decrypted;

end:
	release_sock(sk);
	sk_defer_free_flush(sk);
	if (psock)
		sk_psock_put(sk, psock);
	/* Byte count if anything was delivered, else the pending err value */
	return copied ? : err;
}


#include "livepatch_bsc1248615.h"

#include <linux/kernel.h>
#include <linux/module.h>
#include "../kallsyms_relocs.h"

/* Name of the kernel module that provides the externalized symbols. */
#define LP_MODULE "tls"

/*
 * Relocation table handed to klp_resolve_kallsyms_relocs(). Every entry
 * pairs a symbol name with the address of the klpe_* function pointer
 * declared for it in this file and the name of the owning module.
 */
static struct klp_kallsyms_reloc klp_funcs[] = {
	{ "decrypt_skb_update",
	  (void *)&klpe_decrypt_skb_update, LP_MODULE },
	{ "process_rx_list",
	  (void *)&klpe_process_rx_list, LP_MODULE },
	{ "tls_decrypt_async_wait",
	  (void *)&klpe_tls_decrypt_async_wait, LP_MODULE },
	{ "tls_err_abort",
	  (void *)&klpe_tls_err_abort, LP_MODULE },
	{ "tls_sw_advance_skb",
	  (void *)&klpe_tls_sw_advance_skb, LP_MODULE },
	{ "tls_wait_data",
	  (void *)&klpe_tls_wait_data, LP_MODULE },
};

/*
 * Module notifier callback.
 *
 * Ignores every event except MODULE_STATE_COMING for the module named
 * LP_MODULE. For that event the klp_funcs[] table is resolved; a failure
 * triggers a WARN and is passed back as the return value.
 *
 * @nb:     notifier block this callback was registered with (unused).
 * @action: module state being announced.
 * @data:   the struct module the event refers to.
 *
 * Return: 0 when the event is ignored, otherwise the result of
 * klp_resolve_kallsyms_relocs().
 */
static int module_notify(struct notifier_block *nb,
			unsigned long action, void *data)
{
	struct module *coming = data;
	int err;

	/* Only load events are of interest ... */
	if (action != MODULE_STATE_COMING)
		return 0;

	/* ... and only those of the module this livepatch depends on. */
	if (strcmp(coming->name, LP_MODULE) != 0)
		return 0;

	err = klp_resolve_kallsyms_relocs(klp_funcs, ARRAY_SIZE(klp_funcs));

	WARN(err, "%s: delayed kallsyms lookup failed. System is broken and can crash.\n",
		__func__);

	return err;
}

/*
 * Notifier block registered by livepatch_bsc1248615_init() and removed by
 * livepatch_bsc1248615_cleanup(); it routes module events to
 * module_notify().
 */
static struct notifier_block module_nb = {
	.notifier_call = module_notify,
	/* NOTE(review): INT_MIN+1 is one above the smallest possible
	 * priority, presumably so that this callback runs after almost all
	 * other module notifiers - confirm the intended ordering.
	 */
	.priority = INT_MIN+1,
};

/*
 * livepatch_bsc1248615_init() - set up symbol resolution for this patch.
 *
 * Initializes the kallsyms relocation support, registers the module
 * notifier and, if the module named LP_MODULE can be found and pinned
 * right now, resolves klp_funcs[] immediately. If that immediate
 * resolution fails the notifier is unregistered again.
 *
 * Return: 0 on success (including the case where the module is not
 * present yet), otherwise the error code of the step that failed.
 */
int livepatch_bsc1248615_init(void)
{
	struct module *target;
	int rc;

	rc = klp_kallsyms_relocs_init();
	if (rc)
		return rc;

	rc = register_module_notifier(&module_nb);
	if (rc)
		return rc;

	/* Look the module up and take a reference inside the RCU-sched
	 * section; forget the pointer if the reference cannot be taken.
	 */
	rcu_read_lock_sched();
	target = (*klpe_find_module)(LP_MODULE);
	if (!try_module_get(target))
		target = NULL;
	rcu_read_unlock_sched();

	/* rc is 0 at this point, so it only changes if resolution runs. */
	if (target) {
		rc = klp_resolve_kallsyms_relocs(klp_funcs,
						ARRAY_SIZE(klp_funcs));
		if (rc)
			unregister_module_notifier(&module_nb);
	}

	/* Called unconditionally, i.e. also with a NULL pointer. */
	module_put(target);

	return rc;
}

/*
 * livepatch_bsc1248615_cleanup() - undo livepatch_bsc1248615_init() by
 * removing the module notifier that was registered there.
 */
void livepatch_bsc1248615_cleanup(void)
{
	unregister_module_notifier(&module_nb);
}
