/*
 * bsc1243650_net_sunrpc_xprtsock
 *
 * Fix for CVE-2024-53168, bsc#1243650
 *
 *  Copyright (c) 2025 SUSE
 *  Author: Vincenzo Mezzela <vincenzo.mezzela@suse.com>
 *
 *  Based on the original Linux kernel code. Other copyrights apply.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, see <http://www.gnu.org/licenses/>.
 */

/* klp-ccp: from net/sunrpc/xprtsock.c */
#include <linux/types.h>
#include <linux/string.h>
#include <linux/slab.h>
#include <linux/module.h>
#include <linux/capability.h>
#include <linux/pagemap.h>
#include <linux/errno.h>
#include <linux/socket.h>
#include <linux/in.h>
#include <linux/net.h>
#include <linux/mm.h>

#include <linux/udp.h>
#include <linux/tcp.h>
#include <linux/sunrpc/clnt.h>

#include <linux/sunrpc/sched.h>

/* klp-ccp: from include/linux/sunrpc/svc.h */
#define SUNRPC_SVC_H

/* klp-ccp: from include/linux/sunrpc/cache.h */
struct cache_head;

/* klp-ccp: from include/linux/sunrpc/svcauth.h */
struct svc_rqst;

#define	SVC_OK		5

#define	SVC_COMPLETE	10

/* klp-ccp: from include/linux/sunrpc/svc.h */
struct svc_deferred_req;

/* klp-ccp: from include/linux/sunrpc/svc_xprt.h */
#define SUNRPC_SVC_XPRT_H

/* klp-ccp: from net/sunrpc/xprtsock.c */
#include <linux/sunrpc/xprtsock.h>
#include <linux/file.h>

#include <net/sock.h>
#include <net/checksum.h>

/* klp-ccp: from include/linux/static_call.h */
#define _LINUX_STATIC_CALL_H

/* klp-ccp: from net/sunrpc/xprtsock.c */
#include <linux/bvec.h>
#include <linux/highmem.h>
#include <linux/uio.h>
#include <linux/sched/mm.h>

#include <trace/events/sock.h>
#include <trace/events/sunrpc.h>

#include "../klp_trace.h"
/*
 * Local stand-in for the sunrpc "rpc_socket_connect" tracepoint. The two
 * setup workers below emit it as klpr_trace_rpc_socket_connect().
 * KLPR_TRACE_EVENT comes from ../klp_trace.h.
 * NOTE(review): presumably it forwards to the tracepoint defined in the
 * sunrpc module rather than creating a new one. Confirm in klp_trace.h.
 */
KLPR_TRACE_EVENT(sunrpc, rpc_socket_connect,
	TP_PROTO( \
		struct rpc_xprt *xprt, \
		struct socket *socket, \
		int error \
	), \
	TP_ARGS(xprt, socket, error))

/* klp-ccp: from net/sunrpc/sunrpc.h */
#include <linux/net.h>

/* klp-ccp: from net/sunrpc/xprtsock.c */
/*
 * Defined in the sunrpc module. It is bound to that definition by the
 * KLP_RELOC_SYMBOL declaration at the end of this file.
 */
extern void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt,
		struct socket *sock);

/*
 * Lower bound for xprt->reestablish_timeout. klpp_xs_tcp_setup_socket()
 * applies it once a TCP connect attempt is in flight.
 */
#define XS_TCP_INIT_REEST_TO	(3U * HZ)

/* Selects the RPCDBG_TRANS debug facility for the dprintk() calls in this file. */
# define RPCDBG_FACILITY	RPCDBG_TRANS

/*
 * Return the transport's stored peer address (xprt->addr) viewed as a
 * generic struct sockaddr, e.g. for reading sa_family or for kernel_connect().
 */
static inline struct sockaddr *xs_addr(struct rpc_xprt *xprt)
{
	struct sockaddr *peer = (struct sockaddr *)&xprt->addr;

	return peer;
}

static void
xs_stream_start_connect(struct sock_xprt *transport)
{
	transport->xprt.stat.connect_count++;
	transport->xprt.stat.connect_start = jiffies;
}

/*
 * Copy the socket's current data_ready, error_report, state_change and
 * write_space callbacks into @transport before sunrpc replaces them.
 * NOTE(review): presumably xs_reset_transport() (defined in the sunrpc
 * module, not visible here) restores them when the socket is detached.
 */
static void xs_save_old_callbacks(struct sock_xprt *transport, struct sock *sk)
{
	transport->old_data_ready = sk->sk_data_ready;
	transport->old_error_report = sk->sk_error_report;
	transport->old_state_change = sk->sk_state_change;
	transport->old_write_space = sk->sk_write_space;
}

/*
 * The extern helpers below are defined in the sunrpc module. They are
 * bound to those definitions by the KLP_RELOC_SYMBOL declarations at the
 * end of this file. The sock callbacks among them are installed on new
 * sockets by xs_udp_finish_connecting() and xs_tcp_finish_connecting().
 */
extern void xs_error_report(struct sock *sk);

extern void xs_reset_transport(struct sock_xprt *transport);

extern void xs_data_ready(struct sock *sk);

/*
 * Force a disconnect of the TCP transport via xprt_force_disconnect().
 * klpp_xs_tcp_setup_socket() wakes pending tasks with the real error code
 * before calling this.
 */
static void xs_tcp_force_close(struct rpc_xprt *xprt)
{
	xprt_force_disconnect(xprt);
}

extern void xs_tcp_state_change(struct sock *sk);

extern void xs_udp_write_space(struct sock *sk);

extern void xs_tcp_write_space(struct sock *sk);

/*
 * Push the transport's configured UDP buffer sizes onto its socket.
 *
 * A non-zero rcvsize sets SOCK_RCVBUF_LOCK and sk_rcvbuf to
 * rcvsize * max_reqs * 2. A non-zero sndsize does the same for the send
 * side (SOCK_SNDBUF_LOCK, sk_sndbuf) and then invokes sk->sk_write_space().
 * A size of zero leaves that side of the socket untouched.
 */
static void xs_udp_do_set_buffer_size(struct rpc_xprt *xprt)
{
	struct sock_xprt *sx = container_of(xprt, struct sock_xprt, xprt);
	struct sock *sk = sx->inet;

	if (sx->rcvsize != 0) {
		sk->sk_userlocks |= SOCK_RCVBUF_LOCK;
		sk->sk_rcvbuf = sx->rcvsize * xprt->max_reqs * 2;
	}

	if (sx->sndsize == 0)
		return;

	sk->sk_userlocks |= SOCK_SNDBUF_LOCK;
	sk->sk_sndbuf = sx->sndsize * xprt->max_reqs * 2;
	sk->sk_write_space(sk);
}

/*
 * Defined in the sunrpc module (see the KLP_RELOC_SYMBOL list at the end of
 * this file). klpp_xs_create_sock() calls it on each new socket and
 * releases the socket if it returns non-zero.
 */
extern int xs_bind(struct sock_xprt *transport, struct socket *sock);

/*
 * The real xs_reclassify_socket() only exists under CONFIG_DEBUG_LOCK_ALLOC
 * (lock debugging). The #error records that klp-ccp found that branch is
 * not taken for the target build configuration, so only the empty no-op
 * stub is kept.
 */
#ifdef CONFIG_DEBUG_LOCK_ALLOC
#error "klp-ccp: non-taken branch"
#else
static inline void xs_reclassify_socket(int family, struct socket *sock)
{
}
#endif

/*
 * klpp_xs_create_sock - livepatched replacement for xs_create_sock()
 * @xprt: transport whose network namespace (xprt_net) hosts the socket
 * @transport: socket transport; receives the new struct file in ->file
 * @family: address family passed to __sock_create()
 * @type: socket type (SOCK_DGRAM or SOCK_STREAM at the callers below)
 * @protocol: IPPROTO_UDP or IPPROTO_TCP at the callers below
 * @reuseport: if true, sock_set_reuseport() is applied before binding
 *
 * Creates a kernel socket (kern=1) in xprt->xprt_net and binds it with
 * xs_bind(). It then wraps the socket in a non-blocking struct file that
 * is stored in transport->file.
 *
 * CVE-2024-53168 fix: a TCP socket is converted into a refcounted socket
 * that holds a tracked reference on xprt_net and is counted in the
 * namespace's in-use socket statistics.
 * NOTE(review): presumably this keeps the netns alive for as long as the
 * TCP sock exists, which closes the use-after-free from the CVE. Confirm
 * against the upstream commit.
 *
 * Return: the new socket, or an ERR_PTR() on failure. On an xs_bind()
 * failure the socket is released here.
 * NOTE(review): on a sock_alloc_file() failure the socket is presumably
 * released by sock_alloc_file() itself. Confirm.
 */
static struct socket *klpp_xs_create_sock(struct rpc_xprt *xprt,
		struct sock_xprt *transport, int family, int type,
		int protocol, bool reuseport)
{
	struct file *filp;
	struct socket *sock;
	int err;

	/* Last argument 1 == kern: create an in-kernel socket. */
	err = __sock_create(xprt->xprt_net, family, type, protocol, &sock, 1);
	if (err < 0) {
		dprintk("RPC:       can't create %d transport socket (%d).\n",
				protocol, -err);
		goto out;
	}
	xs_reclassify_socket(family, sock);

	if (reuseport)
		sock_set_reuseport(sock->sk);

	err = xs_bind(transport, sock);
	if (err) {
		sock_release(sock);
		goto out;
	}

	if (protocol == IPPROTO_TCP) {
		/*
		 * The kernel socket created above does not have
		 * sk_net_refcnt set. Drop its non-refcounted netns tracker
		 * (last argument false), then upgrade it:
		 * - mark it refcounted,
		 * - take a tracked reference on xprt_net,
		 * - account it in the per-netns in-use counter.
		 * Order matters: the tracker must be released before the
		 * same ns_tracker slot is reused by get_net_track().
		 */
		__netns_tracker_free(xprt->xprt_net, &sock->sk->ns_tracker, false);
		sock->sk->sk_net_refcnt = 1;
		get_net_track(xprt->xprt_net, &sock->sk->ns_tracker, GFP_KERNEL);
		sock_inuse_add(xprt->xprt_net, 1);
	}

	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
	if (IS_ERR(filp))
		return ERR_CAST(filp);
	transport->file = filp;

	return sock;
out:
	return ERR_PTR(err);
}

#if IS_ENABLED(CONFIG_SUNRPC_SWAP)

/*
 * Call sk_set_memalloc() on the transport's socket while the transport has
 * swap users (xprt->swapper non-zero). If no socket is attached yet,
 * nothing is done; the connect paths call this again once one is installed.
 */
static void xs_set_memalloc(struct rpc_xprt *xprt)
{
	struct sock_xprt *transport = container_of(xprt, struct sock_xprt,
			xprt);
	struct sock *sk = transport->inet;

	if (sk && atomic_read(&xprt->swapper))
		sk_set_memalloc(sk);
}

#else
#error "klp-ccp: non-taken branch"
#endif

/*
 * xs_udp_finish_connecting - attach a freshly created UDP socket to @xprt
 * @xprt: transport being connected
 * @sock: socket returned by klpp_xs_create_sock()
 *
 * If the transport has no socket yet (transport->inet == NULL), @sock is
 * installed under lock_sock():
 * - the original callbacks are saved and the sunrpc callbacks installed,
 * - the transport is marked connected,
 * - @sock and its sk are published in transport->sock/inet,
 * - swap memalloc is applied if needed.
 * If a socket is already attached, @sock is not installed here.
 * NOTE(review): the caller does not release @sock in that case; verify
 * that transport->inet is always NULL when the UDP connect worker runs.
 *
 * In both cases the configured buffer sizes are (re)applied and
 * stat.connect_start is stamped.
 */
static void xs_udp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock)
{
	struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);

	if (!transport->inet) {
		struct sock *sk = sock->sk;

		/* Swap callbacks and transport pointers under the socket lock. */
		lock_sock(sk);

		xs_save_old_callbacks(transport, sk);

		sk->sk_user_data = xprt;
		sk->sk_data_ready = xs_data_ready;
		sk->sk_write_space = xs_udp_write_space;
		/*
		 * NOTE(review): presumably disabled because this socket may be
		 * written from memory-reclaim context, where current's
		 * task_frag must not be used.
		 */
		sk->sk_use_task_frag = false;

		xprt_set_connected(xprt);

		/* Reset to new socket */
		transport->sock = sock;
		transport->inet = sk;

		/* Must follow the transport->inet assignment above. */
		xs_set_memalloc(xprt);

		release_sock(sk);
	}
	xs_udp_do_set_buffer_size(xprt);

	xprt->stat.connect_start = jiffies;
}

/*
 * klpp_xs_udp_setup_socket - livepatched replacement for xs_udp_setup_socket()
 * @work: the transport's connect_worker.work
 *
 * Connect worker for UDP transports. It creates a datagram socket through
 * the patched klpp_xs_create_sock() and attaches it with
 * xs_udp_finish_connecting().
 *
 * Whatever the outcome, it then:
 * - clears the connecting state,
 * - drops the connect lock,
 * - wakes pending tasks with the final status: 0 on success, or the
 *   initial -EIO if the socket could not be created. The PTR_ERR() value
 *   is not propagated.
 *
 * PF_MEMALLOC is set for the duration when the transport has swap users,
 * and the caller's original flag is restored on exit.
 */
void klpp_xs_udp_setup_socket(struct work_struct *work)
{
	struct sock_xprt *transport =
		container_of(work, struct sock_xprt, connect_worker.work);
	struct rpc_xprt *xprt = &transport->xprt;
	struct socket *sock;
	int status = -EIO;
	unsigned int pflags = current->flags;

	if (atomic_read(&xprt->swapper))
		current->flags |= PF_MEMALLOC;
	sock = klpp_xs_create_sock(xprt, transport,
			xs_addr(xprt)->sa_family, SOCK_DGRAM,
			IPPROTO_UDP, false);
	if (IS_ERR(sock))
		goto out;

	dprintk("RPC:       worker connecting xprt %p via %s to "
				"%s (port %s)\n", xprt,
			xprt->address_strings[RPC_DISPLAY_PROTO],
			xprt->address_strings[RPC_DISPLAY_ADDR],
			xprt->address_strings[RPC_DISPLAY_PORT]);

	xs_udp_finish_connecting(xprt, sock);
	klpr_trace_rpc_socket_connect(xprt, sock, 0);
	status = 0;
out:
	/* Common exit for success and failure. */
	xprt_clear_connecting(xprt);
	xprt_unlock_connect(xprt, transport);
	xprt_wake_pending_tasks(xprt, status);
	current_restore_flags(pflags, PF_MEMALLOC);
}

/*
 * NOTE(review): redundant with the identical extern declaration of
 * xs_tcp_set_socket_timeouts() near the top of this file. Harmless; kept
 * as generated by klp-ccp.
 */
void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt,
		struct socket *sock);

/*
 * xs_tcp_finish_connecting - attach @sock to @xprt and start a TCP connect
 * @xprt: transport being connected
 * @sock: socket to connect (newly created or reused by the caller)
 *
 * On first use (transport->inet == NULL) the socket is configured:
 * - for IPv6 peers, public source addresses are preferred,
 * - the sunrpc TCP timeouts and nodelay are set,
 * - under lock_sock(), the original callbacks are saved and the sunrpc
 *   callbacks installed, SOCK_LINGER is cleared, the transport is marked
 *   not connected, and the socket is published in transport->sock/inet.
 *
 * Return: -ENOTCONN if xprt_bound() reports the transport is not bound.
 * Otherwise returns the result of a non-blocking kernel_connect(). The
 * caller treats 0 and -EINPROGRESS as an attempt in flight.
 */
static int xs_tcp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock)
{
	struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);

	if (!transport->inet) {
		struct sock *sk = sock->sk;

		/* Avoid temporary address, they are bad for long-lived
		 * connections such as NFS mounts.
		 * RFC4941, section 3.6 suggests that:
		 *    Individual applications, which have specific
		 *    knowledge about the normal duration of connections,
		 *    MAY override this as appropriate.
		 */
		if (xs_addr(xprt)->sa_family == PF_INET6) {
			ip6_sock_set_addr_preferences(sk,
				IPV6_PREFER_SRC_PUBLIC);
		}

		xs_tcp_set_socket_timeouts(xprt, sock);
		tcp_sock_set_nodelay(sk);

		/* Swap callbacks and transport pointers under the socket lock. */
		lock_sock(sk);

		xs_save_old_callbacks(transport, sk);

		sk->sk_user_data = xprt;
		sk->sk_data_ready = xs_data_ready;
		sk->sk_state_change = xs_tcp_state_change;
		sk->sk_write_space = xs_tcp_write_space;
		sk->sk_error_report = xs_error_report;
		sk->sk_use_task_frag = false;

		/* socket options */
		sock_reset_flag(sk, SOCK_LINGER);

		xprt_clear_connected(xprt);

		/* Reset to new socket */
		transport->sock = sock;
		transport->inet = sk;

		release_sock(sk);
	}

	if (!xprt_bound(xprt))
		return -ENOTCONN;

	/* transport->inet is set by now, so memalloc can be applied. */
	xs_set_memalloc(xprt);

	xs_stream_start_connect(transport);

	/* Tell the socket layer to start connecting... */
	set_bit(XPRT_SOCK_CONNECTING, &transport->sock_state);
	return kernel_connect(sock, xs_addr(xprt), xprt->addrlen, O_NONBLOCK);
}

/*
 * klpp_xs_tcp_setup_socket - livepatched replacement for xs_tcp_setup_socket()
 * @work: the transport's connect_worker.work
 *
 * Connect worker for TCP transports. It does nothing if the transport is
 * already connected.
 *
 * A new stream socket is created through the patched klpp_xs_create_sock()
 * when the transport has no socket, or when a previous connect attempt was
 * sent (XPRT_SOCK_CONNECT_SENT, which is tested and cleared here). The old
 * one is torn down with xs_reset_transport() first. If creation fails,
 * pending tasks are woken with that error.
 *
 * The connect result from xs_tcp_finish_connecting() then decides:
 * - 0 / -EINPROGRESS: mark XPRT_SOCK_CONNECT_SENT and raise
 *   reestablish_timeout to at least XS_TCP_INIT_REEST_TO.
 * - 0 / -EINPROGRESS / -EALREADY: skip xprt_clear_connecting() and only
 *   drop the connect lock.
 * - -EADDRNOTAVAIL: forget the source port and report -EAGAIN.
 * - -EPERM is reported as -ECONNREFUSED; the listed network errors pass
 *   through unchanged.
 * - anything else: logged and reported as -EAGAIN.
 * In all of these error cases, pending tasks are woken with the status
 * before the transport is force-closed.
 *
 * NOTE(review): in the default branch, __func__ now prints
 * "klpp_xs_tcp_setup_socket" rather than the upstream function name.
 */
void klpp_xs_tcp_setup_socket(struct work_struct *work)
{
	struct sock_xprt *transport =
		container_of(work, struct sock_xprt, connect_worker.work);
	struct socket *sock = transport->sock;
	struct rpc_xprt *xprt = &transport->xprt;
	int status;
	unsigned int pflags = current->flags;

	if (atomic_read(&xprt->swapper))
		current->flags |= PF_MEMALLOC;

	if (xprt_connected(xprt))
		goto out;
	if (test_and_clear_bit(XPRT_SOCK_CONNECT_SENT,
			       &transport->sock_state) ||
	    !sock) {
		xs_reset_transport(transport);
		sock = klpp_xs_create_sock(xprt, transport, xs_addr(xprt)->sa_family,
				      SOCK_STREAM, IPPROTO_TCP, true);
		if (IS_ERR(sock)) {
			xprt_wake_pending_tasks(xprt, PTR_ERR(sock));
			goto out;
		}
	}

	dprintk("RPC:       worker connecting xprt %p via %s to "
				"%s (port %s)\n", xprt,
			xprt->address_strings[RPC_DISPLAY_PROTO],
			xprt->address_strings[RPC_DISPLAY_ADDR],
			xprt->address_strings[RPC_DISPLAY_PORT]);

	status = xs_tcp_finish_connecting(xprt, sock);
	klpr_trace_rpc_socket_connect(xprt, sock, status);
	dprintk("RPC:       %p connect status %d connected %d sock state %d\n",
			xprt, -status, xprt_connected(xprt),
			sock->sk->sk_state);
	switch (status) {
	case 0:
	case -EINPROGRESS:
		/* SYN_SENT! */
		set_bit(XPRT_SOCK_CONNECT_SENT, &transport->sock_state);
		if (xprt->reestablish_timeout < XS_TCP_INIT_REEST_TO)
			xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO;
		fallthrough;
	case -EALREADY:
		/* Attempt in flight: keep the connecting state set. */
		goto out_unlock;
	case -EADDRNOTAVAIL:
		/* Source port number is unavailable. Try a new one! */
		transport->srcport = 0;
		status = -EAGAIN;
		break;
	case -EPERM:
		/* Happens, for instance, if a BPF program is preventing
		 * the connect. Remap the error so upper layers can better
		 * deal with it.
		 */
		status = -ECONNREFUSED;
		fallthrough;
	case -EINVAL:
		/* Happens, for instance, if the user specified a link
		 * local IPv6 address without a scope-id.
		 */
	case -ECONNREFUSED:
	case -ECONNRESET:
	case -ENETDOWN:
	case -ENETUNREACH:
	case -EHOSTUNREACH:
	case -EADDRINUSE:
	case -ENOBUFS:
		break;
	default:
		printk("%s: connect returned unhandled error %d\n",
			__func__, status);
		status = -EAGAIN;
	}

	/* xs_tcp_force_close() wakes tasks with a fixed error code.
	 * We need to wake them first to ensure the correct error code.
	 */
	xprt_wake_pending_tasks(xprt, status);
	xs_tcp_force_close(xprt);
out:
	xprt_clear_connecting(xprt);
out_unlock:
	xprt_unlock_connect(xprt, transport);
	current_restore_flags(pflags, PF_MEMALLOC);
}


#include "livepatch_bsc1243650.h"

#include <linux/livepatch.h>

/*
 * Symbols this livepatch uses that are defined in the sunrpc module rather
 * than in the patch itself. Each declaration is tagged with
 * KLP_RELOC_SYMBOL(sunrpc, sunrpc, sym) so that it refers to sunrpc's
 * definition.
 * NOTE(review): presumably this emits livepatch-specific relocations that
 * are resolved when sunrpc is loaded. The macro comes from the headers
 * included above; confirm its expansion there.
 */
extern typeof(rpc_debug) rpc_debug KLP_RELOC_SYMBOL(sunrpc, sunrpc, rpc_debug);
extern typeof(xprt_force_disconnect) xprt_force_disconnect
	 KLP_RELOC_SYMBOL(sunrpc, sunrpc, xprt_force_disconnect);
extern typeof(xprt_unlock_connect) xprt_unlock_connect
	 KLP_RELOC_SYMBOL(sunrpc, sunrpc, xprt_unlock_connect);
extern typeof(xprt_wake_pending_tasks) xprt_wake_pending_tasks
	 KLP_RELOC_SYMBOL(sunrpc, sunrpc, xprt_wake_pending_tasks);
extern typeof(xs_bind) xs_bind KLP_RELOC_SYMBOL(sunrpc, sunrpc, xs_bind);
extern typeof(xs_data_ready) xs_data_ready
	 KLP_RELOC_SYMBOL(sunrpc, sunrpc, xs_data_ready);
extern typeof(xs_error_report) xs_error_report
	 KLP_RELOC_SYMBOL(sunrpc, sunrpc, xs_error_report);
extern typeof(xs_reset_transport) xs_reset_transport
	 KLP_RELOC_SYMBOL(sunrpc, sunrpc, xs_reset_transport);
extern typeof(xs_tcp_set_socket_timeouts) xs_tcp_set_socket_timeouts
	 KLP_RELOC_SYMBOL(sunrpc, sunrpc, xs_tcp_set_socket_timeouts);
extern typeof(xs_tcp_state_change) xs_tcp_state_change
	 KLP_RELOC_SYMBOL(sunrpc, sunrpc, xs_tcp_state_change);
extern typeof(xs_tcp_write_space) xs_tcp_write_space
	 KLP_RELOC_SYMBOL(sunrpc, sunrpc, xs_tcp_write_space);
extern typeof(xs_udp_write_space) xs_udp_write_space
	 KLP_RELOC_SYMBOL(sunrpc, sunrpc, xs_udp_write_space);
