/*
 * livepatch_bsc1235250
 *
 * Fix for CVE-2024-56664, bsc#1235250
 *
 *  Copyright (c) 2025 SUSE
 *  Author: Vincenzo Mezzela <vincenzo.mezzela@suse.com>
 *
 *  Based on the original Linux kernel code. Other copyrights apply.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, see <http://www.gnu.org/licenses/>.
 */

/* klp-ccp: from net/core/sock_map.c */
#include <linux/bpf.h>
#include <linux/btf_ids.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>

/*
 * Private state of a sockmap (BPF_MAP_TYPE_SOCKMAP) array map, copied from
 * net/core/sock_map.c.
 *
 * NOTE(review): this code reaches instances through container_of() on a
 * struct bpf_map handed in by the running kernel, so this definition
 * presumably must stay layout-identical to the kernel's own — do not
 * reorder, add or remove fields.
 */
struct bpf_stab {
	struct bpf_map map;		/* embedded generic map; container_of() anchor */
	struct sock **sks;		/* socket slots, indexed by u32 key (< max_entries) */
	struct sk_psock_progs progs;
	spinlock_t lock;		/* held around slot removal in klpp___sock_map_delete() */
};

/*
 * Defined in vmlinux, not in this module; it is bound to the running kernel's
 * symbol by the KLP_RELOC_SYMBOL declaration at the end of this file.
 * NOTE(review): presumably drops the map link/reference @sk holds for the
 * entry identified by @link_raw (a slot pointer for sockmap, an element for
 * sockhash) — confirm against net/core/sock_map.c.
 */
extern void sock_map_unref(struct sock *sk, void *link_raw);

/*
 * Patched copy of __sock_map_delete() from net/core/sock_map.c.
 *
 * Remove the socket stored in sockmap slot @psk while holding @stab->lock.
 * If @sk_test is NULL, whatever occupies the slot is removed. Otherwise the
 * slot is cleared only if it still holds @sk_test.
 *
 * Return: 0 if a socket was taken out of the slot and passed to
 * sock_map_unref(). -EINVAL if the slot was empty or held a socket other
 * than @sk_test.
 */
static int klpp___sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	/*
	 * Start from NULL so that a mismatch (slot no longer holding @sk_test)
	 * leaves nothing to unreference and yields -EINVAL, instead of acting
	 * on whatever socket currently sits in the slot.
	 * NOTE(review): this initialisation plus comparing against *psk
	 * directly below appears to be the CVE-2024-56664 change named in the
	 * file header — confirm against the upstream commit.
	 */
	struct sock *sk = NULL;
	int err = 0;

	spin_lock_bh(&stab->lock);
	/*
	 * Compare against the slot's current value under the lock. On a match,
	 * atomically take the socket out of the slot, leaving NULL behind.
	 */
	if (!sk_test || sk_test == *psk)
		sk = xchg(psk, NULL);

	if (likely(sk))
		sock_map_unref(sk, psk);
	else
		err = -EINVAL;

	spin_unlock_bh(&stab->lock);
	return err;
}

/*
 * Link-driven removal for a sockmap entry.
 *
 * For sockmap links, @link_raw is treated as the struct sock ** slot the link
 * refers to. The slot is cleared only if it still holds @sk. Any error from
 * the delete helper (e.g. the slot was already reused) is deliberately
 * ignored.
 */
static void klpr_sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct sock **slot = link_raw;

	klpp___sock_map_delete(container_of(map, struct bpf_stab, map), sk, slot);
}

/*
 * Patched delete_elem handler for sockmap: remove whatever socket occupies
 * the slot selected by the u32 pointed to by @key.
 *
 * Return: -EINVAL for an out-of-range key or an empty slot. Otherwise the
 * result of klpp___sock_map_delete().
 */
long klpp_sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 idx = *(u32 *)key;

	/* Reject keys outside the array before touching stab->sks. */
	if (unlikely(idx >= map->max_entries))
		return -EINVAL;

	/* NULL sk_test: delete unconditionally, whatever the slot holds. */
	return klpp___sock_map_delete(stab, NULL, stab->sks + idx);
}

/*
 * One sockhash (BPF_MAP_TYPE_SOCKHASH) element, copied from
 * net/core/sock_map.c.
 *
 * NOTE(review): elements reach this code from the running kernel (as a link's
 * link_raw), so this layout presumably must match the kernel's definition
 * exactly — do not change it.
 */
struct bpf_shtab_elem {
	struct rcu_head rcu;		/* used by kfree_rcu() in sock_hash_free_elem() */
	u32 hash;			/* selects the bucket, see sock_hash_select_bucket() */
	struct sock *sk;		/* socket stored in this element */
	struct hlist_node node;		/* linkage in bpf_shtab_bucket::head */
	u8 key[];			/* key bytes; compared using map->key_size */
};

/*
 * One sockhash bucket: the element list and the lock that serialises
 * lookup + removal in sock_hash_delete_from_link(). Layout copied from
 * net/core/sock_map.c; keep it identical.
 */
struct bpf_shtab_bucket {
	struct hlist_head head;		/* elements, linked via bpf_shtab_elem::node */
	spinlock_t lock;
};

/*
 * Private state of a sockhash map, copied from net/core/sock_map.c.
 *
 * NOTE(review): reached via container_of() on a kernel-owned struct bpf_map,
 * so the layout presumably must match the running kernel's definition.
 */
struct bpf_shtab {
	struct bpf_map map;		/* embedded generic map; container_of() anchor */
	struct bpf_shtab_bucket *buckets;
	u32 buckets_num;		/* bucket index is hash & (buckets_num - 1) */
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;			/* decremented in sock_hash_free_elem() */
};

/*
 * Map @hash to its bucket in @htab. The low bits of the hash are used as the
 * index, which only covers every bucket if buckets_num is a power of two
 * (NOTE(review): assumed, not checked here).
 */
static struct bpf_shtab_bucket *sock_hash_select_bucket(struct bpf_shtab *htab,
							u32 hash)
{
	u32 mask = htab->buckets_num - 1;

	return htab->buckets + (hash & mask);
}

/*
 * Defined in vmlinux; bound via KLP_RELOC_SYMBOL at the end of this file.
 * Used below to look up the element matching (@hash, @key) in a bucket list.
 * NOTE(review): presumably returns NULL when no element matches — the caller
 * checks for NULL; confirm against net/core/sock_map.c.
 */
extern struct bpf_shtab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size);

/*
 * Release a sockhash element that the caller has already unlinked
 * (see sock_hash_delete_from_link()).
 *
 * The map's element count is decremented at once. The memory is handed to
 * kfree_rcu(), so it is only freed after an RCU grace period and RCU readers
 * still holding @elem remain safe.
 */
static void sock_hash_free_elem(struct bpf_shtab *htab,
				struct bpf_shtab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

/*
 * Link-driven removal for a sockhash entry.
 *
 * @link_raw is the struct bpf_shtab_elem the link refers to. It is unlinked,
 * unreferenced and freed only if it is still the element stored under its
 * key. @sk is not used; the socket is taken from elem->sk.
 *
 * Expects to run inside an RCU read-side section (warned once otherwise).
 * The bucket lock is taken with bottom halves disabled around lookup and
 * removal.
 */
static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	struct bpf_shtab_elem *elem_probe, *elem = link_raw;
	struct bpf_shtab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	/*
	 * Pointer identity, not just key equality: the key may now map to a
	 * different element, which must be left alone.
	 */
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	spin_unlock_bh(&bucket->lock);
}

/*
 * Detach @sk from the map that @link belongs to, dispatching on the map type.
 *
 * Sockmap links are handled by klpr_sock_map_delete_from_link() and sockhash
 * links by sock_hash_delete_from_link(). Links into any other map type are
 * ignored.
 *
 * The original used "return f(...);" inside this void function, which ISO C
 * forbids (C11 6.8.6.4p1; GCC accepts it only as an extension). Each case now
 * calls the helper and breaks, with identical behaviour.
 */
static void klpr_sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		klpr_sock_map_delete_from_link(link->map, sk, link->link_raw);
		break;
	case BPF_MAP_TYPE_SOCKHASH:
		sock_hash_delete_from_link(link->map, sk, link->link_raw);
		break;
	default:
		/* Nothing to do for other map types. */
		break;
	}
}

/*
 * Patched sock_map_remove_links(): drain every map link attached to @psock.
 *
 * Links are popped one at a time until sk_psock_link_pop() returns NULL. Each
 * link is first used to remove @sk from its map, and only then is the link
 * itself released.
 */
void klpp_sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *cur;

	for (cur = sk_psock_link_pop(psock); cur; cur = sk_psock_link_pop(psock)) {
		klpr_sock_map_unlink(sk, cur);
		sk_psock_free_link(cur);
	}
}


#include "livepatch_bsc1235250.h"

#include <linux/livepatch.h>

/*
 * Bind the vmlinux functions used above to the running kernel's definitions
 * through livepatch relocations rather than ordinary module linking.
 * NOTE(review): presumably needed because these symbols are not exported to
 * modules — confirm for the target kernel.
 */
extern typeof(sk_psock_link_pop) sk_psock_link_pop
	 KLP_RELOC_SYMBOL(vmlinux, vmlinux, sk_psock_link_pop);
extern typeof(sock_hash_lookup_elem_raw) sock_hash_lookup_elem_raw
	 KLP_RELOC_SYMBOL(vmlinux, vmlinux, sock_hash_lookup_elem_raw);
extern typeof(sock_map_unref) sock_map_unref
	 KLP_RELOC_SYMBOL(vmlinux, vmlinux, sock_map_unref);
