/*
 * livepatch_bsc1198133
 *
 * Fix for CVE-2022-1158, bsc#1198133
 *
 *  Upstream commit:
 *  2a8859f373b0 ("KVM: x86/mmu: do compare-and-exchange of gPTE via the user
 *                 address")
 *
 *  SLE12-SP3 commit:
 *  Not affected
 *
 *  SLE12-SP4, SLE12-SP5, SLE15 and SLE15-SP1 commit:
 *  Not affected
 *
 *  SLE15-SP2 and -SP3 commit:
 *  0581a6684913d1c0b6d90c21bfc1710eff2df07d
 *
 *  Copyright (c) 2022 SUSE
 *  Author: Nicolai Stange <nstange@suse.de>
 *
 *  Based on the original Linux kernel code. Other copyrights apply.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, see <http://www.gnu.org/licenses/>.
 */

#if IS_ENABLED(CONFIG_X86_64)

#if !IS_MODULE(CONFIG_KVM)
#error "Live patch supports only CONFIG_KVM=m"
#endif

/* from include/linux/tracepoint.h */
/*
 * Livepatch flavour of the tracepoint declaration macro.
 *
 * For a tracepoint called <name> this emits:
 *  - a static pointer klpe___tracepoint_<name> to the struct tracepoint
 *    (the pointer is listed in klp_funcs further down in this file so it
 *    can be resolved by symbol name), and
 *  - a static inline klpr_trace_<name>(proto) that goes through that
 *    pointer: when the tracepoint's static key is enabled it invokes
 *    __DO_TRACE(), and when CONFIG_LOCKDEP is enabled and cond holds it
 *    additionally performs an rcu_dereference_sched() on the tracepoint's
 *    funcs inside an rcu_read_lock_sched_notrace() section.
 *
 * klpr_trace_<name>() dereferences the pointer unconditionally, so it must
 * not be called before klpe___tracepoint_<name> has been set.
 */
#define KLPR___DECLARE_TRACE(name, proto, args, cond, data_proto, data_args) \
	static struct tracepoint (*klpe___tracepoint_##name);		\
	static inline void klpr_trace_##name(proto)			\
	{								\
		if (unlikely(static_key_enabled(&(*klpe___tracepoint_##name).key))) \
			__DO_TRACE(&(*klpe___tracepoint_##name),	\
				TP_PROTO(data_proto),			\
				TP_ARGS(data_args),			\
				TP_CONDITION(cond), 0);		\
		if (IS_ENABLED(CONFIG_LOCKDEP) && (cond)) {		\
			rcu_read_lock_sched_notrace();			\
			rcu_dereference_sched((*klpe___tracepoint_##name).funcs); \
			rcu_read_unlock_sched_notrace();		\
		}							\
	}								\

/*
 * Declare a livepatch tracepoint wrapper via KLPR___DECLARE_TRACE, using
 * cpu_online(raw_smp_processor_id()) as the trace condition and prepending
 * a "void *__data" parameter to the prototype/arguments handed to the
 * probe side.
 */
#define KLPR_DECLARE_TRACE(name, proto, args)				\
	KLPR___DECLARE_TRACE(name, PARAMS(proto), PARAMS(args),		\
			cpu_online(raw_smp_processor_id()),		\
			PARAMS(void *__data, proto),			\
			PARAMS(__data, args))

/*
 * Livepatch counterpart of DEFINE_EVENT(): only declares the wrapper for
 * <name>; the template argument is accepted but not used in the expansion.
 */
#define KLPR_DEFINE_EVENT(template, name, proto, args)          \
	KLPR_DECLARE_TRACE(name, PARAMS(proto), PARAMS(args))

/*
 * Livepatch counterpart of TRACE_EVENT(): expands to exactly the same
 * KLPR_DECLARE_TRACE() declaration for <name>.
 */
#define KLPR_TRACE_EVENT(name, proto, args)	\
	KLPR_DECLARE_TRACE(name, PARAMS(proto), PARAMS(args))

/* klp-ccp: from arch/x86/kvm/irq.h */
#include <linux/mm_types.h>
#include <linux/hrtimer.h>
#include <linux/kvm_host.h>

/* klp-ccp: from arch/x86/include/asm/kvm_host.h */
/*
 * Pointer to the kvm module's kvm_x86_ops pointer; callers in this file use
 * (*klpe_kvm_x86_ops)->callback(). Listed in klp_funcs under the symbol
 * name "kvm_x86_ops".
 */
static struct kvm_x86_ops *(*klpe_kvm_x86_ops);

/* klp-ccp: from include/linux/kvm_host.h */
/* Function pointer listed in klp_funcs as "kvm_vcpu_gfn_to_hva_prot". */
static unsigned long (*klpe_kvm_vcpu_gfn_to_hva_prot)(struct kvm_vcpu *vcpu, gfn_t gfn, bool *writable);

/* Function pointer listed in klp_funcs as "kvm_vcpu_mark_page_dirty". */
static void (*klpe_kvm_vcpu_mark_page_dirty)(struct kvm_vcpu *vcpu, gfn_t gfn);

/* klp-ccp: from arch/x86/kvm/irq.h */
#include <linux/spinlock.h>

/* klp-ccp: from arch/x86/kvm/ioapic.h */
#include <linux/kvm_host.h>
#include <kvm/iodev.h>

/* ASSERT() expands to an empty statement; its argument is never evaluated. */
#define ASSERT(x) do { } while (0)

/* klp-ccp: from arch/x86/kvm/lapic.h */
#include <kvm/iodev.h>
#include <linux/kvm_host.h>
/* klp-ccp: from arch/x86/kvm/mmu.h */
#include <linux/kvm_host.h>
/* klp-ccp: from arch/x86/kvm/kvm_cache_regs.h */
#include <linux/kvm_host.h>

/*
 * Set of CR4 bits that klpr_kvm_read_cr4_bits() intersects with the
 * requested mask before testing vcpu->arch.cr4_guest_owned_bits.
 */
#define KVM_POSSIBLE_CR4_GUEST_BITS				  \
	(X86_CR4_PVI | X86_CR4_DE | X86_CR4_PCE | X86_CR4_OSFXSR  \
	 | X86_CR4_OSXMMEXCPT | X86_CR4_LA57 | X86_CR4_PGE)

/*
 * Return the bits of vcpu->arch.cr4 selected by @mask.
 *
 * If any requested bit is both in KVM_POSSIBLE_CR4_GUEST_BITS and in
 * vcpu->arch.cr4_guest_owned_bits, the decache_cr4_guest_bits() callback of
 * the resolved kvm_x86_ops is invoked before vcpu->arch.cr4 is read.
 */
static inline ulong klpr_kvm_read_cr4_bits(struct kvm_vcpu *vcpu, ulong mask)
{
	const ulong maybe_guest_owned = mask & KVM_POSSIBLE_CR4_GUEST_BITS;

	if ((vcpu->arch.cr4_guest_owned_bits & maybe_guest_owned) != 0)
		(*klpe_kvm_x86_ops)->decache_cr4_guest_bits(vcpu);

	return mask & vcpu->arch.cr4;
}

/* klp-ccp: from arch/x86/kvm/mmu.h */
#define PT_WRITABLE_SHIFT 1
#define PT_USER_SHIFT 2

#define PT_PRESENT_MASK (1ULL << 0)
#define PT_WRITABLE_MASK (1ULL << PT_WRITABLE_SHIFT)
#define PT_USER_MASK (1ULL << PT_USER_SHIFT)

#define PT_ACCESSED_SHIFT 5

#define PT_DIRTY_SHIFT 6

#define PT_PAGE_SIZE_SHIFT 7
#define PT_PAGE_SIZE_MASK (1ULL << PT_PAGE_SIZE_SHIFT)

#define PT64_NX_SHIFT 63

#define PT32_DIR_PSE36_SIZE 4
#define PT32_DIR_PSE36_SHIFT 13
#define PT32_DIR_PSE36_MASK \
	(((1ULL << PT32_DIR_PSE36_SIZE) - 1) << PT32_DIR_PSE36_SHIFT)

#define PT32E_ROOT_LEVEL 3

/*
 * Decide whether an access described by @pfec to a page with permissions
 * @pte_access (and protection key @pte_pkey) faults.
 *
 * The current CPL and RFLAGS are obtained through the get_cpl()/get_rflags()
 * callbacks of the resolved kvm_x86_ops. The verdict comes from a lookup in
 * mmu->permissions[] and, when mmu->pkru_mask is non-zero, from an
 * additional check against vcpu->arch.pkru.
 *
 * Returns 0 when there is no fault. Otherwise returns an error code that
 * contains PFERR_PRESENT_MASK, plus PFERR_PK_MASK when the protection-key
 * check was the one that failed.
 *
 * NOTE(review): the result is narrowed to u8 by the return type; this
 * relies on both PFERR_* masks fitting into 8 bits - confirm against their
 * definitions.
 */
static inline u8 klpr_permission_fault(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
				  unsigned pte_access, unsigned pte_pkey,
				  unsigned pfec)
{
	int cpl = (*klpe_kvm_x86_ops)->get_cpl(vcpu);
	unsigned long rflags = (*klpe_kvm_x86_ops)->get_rflags(vcpu);

	/*
	 * If CPL < 3, SMAP prevention is disabled if EFLAGS.AC = 1.
	 *
	 * If CPL = 3, SMAP applies to all supervisor-mode data accesses
	 * (these are implicit supervisor accesses) regardless of the value
	 * of EFLAGS.AC.
	 *
	 * This computes (cpl < 3) && (rflags & X86_EFLAGS_AC), leaving
	 * the result in X86_EFLAGS_AC. We then insert it in place of
	 * the PFERR_RSVD_MASK bit; this bit will always be zero in pfec,
	 * but it will be one in index if SMAP checks are being overridden.
	 * It is important to keep this branchless.
	 */
	unsigned long smap = (cpl - 3) & (rflags & X86_EFLAGS_AC);
	int index = (pfec >> 1) +
		    (smap >> (X86_EFLAGS_AC_BIT - PFERR_RSVD_BIT + 1));
	/* Bit number pte_access of the selected permissions[] entry. */
	bool fault = (mmu->permissions[index] >> pte_access) & 1;
	u32 errcode = PFERR_PRESENT_MASK;

	/* Callers must not pass PK or RSVD in pfec; both are derived here. */
	WARN_ON(pfec & (PFERR_PK_MASK | PFERR_RSVD_MASK));
	if (unlikely(mmu->pkru_mask)) {
		u32 pkru_bits, offset;

		/*
		* PKRU defines 32 bits, there are 16 domains and 2
		* attribute bits per domain in pkru.  pte_pkey is the
		* index of the protection domain, so pte_pkey * 2 is
		* the index of the first bit for the domain.
		*/
		pkru_bits = (vcpu->arch.pkru >> (pte_pkey * 2)) & 3;

		/* clear present bit, replace PFEC.RSVD with ACC_USER_MASK. */
		offset = (pfec & ~1) +
			((pte_access & PT_USER_MASK) << (PFERR_RSVD_BIT - PT_USER_SHIFT));

		/* Keep only the PKRU attribute bits that apply to this access. */
		pkru_bits &= mmu->pkru_mask >> offset;
		/* -pkru_bits has all PFERR_PK_MASK bits set iff pkru_bits != 0. */
		errcode |= -pkru_bits & PFERR_PK_MASK;
		fault |= (pkru_bits != 0);
	}

	/* -(u32)fault is all-ones when faulting and zero otherwise. */
	return -(u32)fault & errcode;
}

static int (*klpe_kvm_arch_write_log_dirty)(struct kvm_vcpu *vcpu);

/* klp-ccp: from arch/x86/kvm/x86.h */
#include <linux/kvm_host.h>
/* klp-ccp: from arch/x86/kvm/cpuid.h */
#include <asm/processor.h>
/* klp-ccp: from arch/x86/kvm/mmu.c */
#include <linux/kvm_host.h>
#include <linux/types.h>
#include <linux/string.h>
#include <linux/mm.h>
#include <linux/highmem.h>
#include <linux/moduleparam.h>
#include <linux/export.h>
#include <linux/compiler.h>
#include <linux/srcu.h>
#include <linux/slab.h>
#include <linux/sched/signal.h>
#include <linux/uaccess.h>
#include <linux/hash.h>
#include <linux/kern_levels.h>
#include <linux/kthread.h>
#include <asm/page.h>
#include <asm/pat.h>
#include <asm/cmpxchg.h>
#include <asm/io.h>
#include <asm/vmx.h>
#include <asm/kvm_page_track.h>

/* klp-ccp: from arch/x86/kvm/trace.h */
#include <linux/tracepoint.h>
#include <asm/vmx.h>
#include <asm/clocksource.h>
#include <asm/pvclock-abi.h>

/* klp-ccp: from arch/x86/kvm/mmu.c */
/* Debug printing is compiled out: pgprintk() discards its arguments. */
#define pgprintk(x...) do { } while (0)

#define PTE_PREFETCH_NUM		8

/* Width, in bits, of one index field used by the PT64_* helpers below. */
#define PT64_LEVEL_BITS 9

/*
 * Bit offset of the index field for @level (level 1 starts at PAGE_SHIFT).
 * NOTE: @level is not parenthesized in the expansion; pass a simple
 * expression.
 */
#define PT64_LEVEL_SHIFT(level) \
		(PAGE_SHIFT + (level - 1) * PT64_LEVEL_BITS)

/* Extract the PT64_LEVEL_BITS-wide index of @address for @level. */
#define PT64_INDEX(address, level)\
	(((address) >> PT64_LEVEL_SHIFT(level)) & ((1 << PT64_LEVEL_BITS) - 1))

/* Width, in bits, of one index field used by the PT32_* helpers below. */
#define PT32_LEVEL_BITS 10

/*
 * Bit offset of the index field for @level (level 1 starts at PAGE_SHIFT).
 * NOTE: @level is not parenthesized in the expansion; pass a simple
 * expression.
 */
#define PT32_LEVEL_SHIFT(level) \
		(PAGE_SHIFT + (level - 1) * PT32_LEVEL_BITS)

/*
 * Bits of PT32_BASE_ADDR_MASK below the @level shift. PT32_BASE_ADDR_MASK
 * is defined further down; that is fine because macros expand at use.
 */
#define PT32_LVL_OFFSET_MASK(level) \
	(PT32_BASE_ADDR_MASK & ((1ULL << (PAGE_SHIFT + (((level) - 1) \
						* PT32_LEVEL_BITS))) - 1))

/* Extract the PT32_LEVEL_BITS-wide index of @address for @level. */
#define PT32_INDEX(address, level)\
	(((address) >> PT32_LEVEL_SHIFT(level)) & ((1 << PT32_LEVEL_BITS) - 1))

/* physical_mask with the low (in-page) PAGE_SIZE-1 bits cleared. */
#define PT64_BASE_ADDR_MASK (physical_mask & ~(u64)(PAGE_SIZE-1))

/* Bits of PT64_BASE_ADDR_MASK at or above the @level shift. */
#define PT64_LVL_ADDR_MASK(level) \
	(PT64_BASE_ADDR_MASK & ~((1ULL << (PAGE_SHIFT + (((level) - 1) \
						* PT64_LEVEL_BITS))) - 1))
/* Bits of PT64_BASE_ADDR_MASK below the @level shift. */
#define PT64_LVL_OFFSET_MASK(level) \
	(PT64_BASE_ADDR_MASK & ((1ULL << (PAGE_SHIFT + (((level) - 1) \
						* PT64_LEVEL_BITS))) - 1))

#define PT32_BASE_ADDR_MASK PAGE_MASK

/* Bits of PAGE_MASK at or above the @level shift. */
#define PT32_LVL_ADDR_MASK(level) \
	(PAGE_MASK & ~((1ULL << (PAGE_SHIFT + (((level) - 1) \
					    * PT32_LEVEL_BITS))) - 1))

/* Access-permission bits; write and user alias the PT_* masks. */
#define ACC_EXEC_MASK    1
#define ACC_WRITE_MASK   PT_WRITABLE_MASK
#define ACC_USER_MASK    PT_USER_MASK

/* klp-ccp: from arch/x86/kvm/mmutrace.h */
/*
 * Each invocation below declares a klpe___tracepoint_<name> pointer and a
 * klpr_trace_<name>() inline wrapper (see KLPR___DECLARE_TRACE). The
 * pointers are listed in klp_funcs under "__tracepoint_<name>".
 */

/* klpr_trace_kvm_mmu_pagetable_walk(u64 addr, u32 pferr) */
KLPR_TRACE_EVENT(
	kvm_mmu_pagetable_walk,
	TP_PROTO(u64 addr, u32 pferr),
	TP_ARGS(addr, pferr)
);

/* klpr_trace_kvm_mmu_paging_element(u64 pte, int level) */
KLPR_TRACE_EVENT(
	kvm_mmu_paging_element,
	TP_PROTO(u64 pte, int level),
	TP_ARGS(pte, level)
);

/* klpr_trace_kvm_mmu_set_accessed_bit(table_gfn, index, size) */
KLPR_DEFINE_EVENT(kvm_mmu_set_bit_class, kvm_mmu_set_accessed_bit,

	TP_PROTO(unsigned long table_gfn, unsigned index, unsigned size),

	TP_ARGS(table_gfn, index, size)
);

/* klpr_trace_kvm_mmu_set_dirty_bit(table_gfn, index, size) */
KLPR_DEFINE_EVENT(kvm_mmu_set_bit_class, kvm_mmu_set_dirty_bit,

	TP_PROTO(unsigned long table_gfn, unsigned index, unsigned size),

	TP_ARGS(table_gfn, index, size)
);

/* klpr_trace_kvm_mmu_walker_error(u32 pferr) */
KLPR_TRACE_EVENT(
	kvm_mmu_walker_error,
	TP_PROTO(u32 pferr),
	TP_ARGS(pferr)
);

/* klp-ccp: from arch/x86/kvm/mmu.c */
/* Report whether PSE-36 is available; this build always answers yes (1). */
static int is_cpuid_PSE36(void)
{
	enum { pse36_available = 1 };

	return pse36_available;
}

/*
 * Isolate the PT32_DIR_PSE36_MASK field of a 32-bit guest PTE and shift it
 * left by (32 - PT32_DIR_PSE36_SHIFT - PAGE_SHIFT) bits, yielding the value
 * as a gfn_t.
 */
static gfn_t pse36_gfn_delta(u32 gpte)
{
	const int lshift = 32 - PT32_DIR_PSE36_SHIFT - PAGE_SHIFT;

	return (PT32_DIR_PSE36_MASK & gpte) << lshift;
}

#ifdef CONFIG_KVM_MMU_AUDIT

/* klp-ccp: from arch/x86/kvm/mmu_audit.c */
#include <linux/ratelimit.h>

/* klp-ccp: from arch/x86/kvm/mmu.c */
#else
#error "klp-ccp: non-taken branch"
#endif

/*
 * Check @pte against the reserved-bit description in @rsvd_check.
 *
 * True when @pte has any bit set in rsvd_bits_mask[bit 7 of pte][level - 1],
 * or when the bit of bad_mt_xwr selected by the low six bits of @pte is set.
 */
static bool
__is_rsvd_bits_set(struct rsvd_bits_validate *rsvd_check, u64 pte, int level)
{
	const int bit7 = (pte >> 7) & 1;
	const int low6 = pte & 0x3f;
	const u64 rsvd_hit = pte & rsvd_check->rsvd_bits_mask[bit7][level - 1];
	const bool bad_combo = (rsvd_check->bad_mt_xwr & (1ull << low6)) != 0;

	/* Bitwise OR on purpose: both checks are always evaluated. */
	return rsvd_hit | bad_combo;
}

/* Apply __is_rsvd_bits_set() using @mmu's guest_rsvd_check description. */
static bool is_rsvd_bits_set(struct kvm_mmu *mmu, u64 gpte, int level)
{
	struct rsvd_bits_validate *guest_check = &mmu->guest_rsvd_check;

	return __is_rsvd_bits_set(guest_check, gpte, level);
}

/*
 * Tell whether @gpte, found at @level, ends the page-table walk.
 *
 * Works without branches by steering bit 7 (PT_PAGE_SIZE_MASK) of a copy of
 * @gpte with two unsigned subtractions, then testing that bit.
 */
static inline bool is_last_gpte(struct kvm_mmu *mmu,
				unsigned level, unsigned gpte)
{
	unsigned bits = gpte;

	/*
	 * level - last_nonleaf_level has bit 7 set iff
	 * level < mmu->last_nonleaf_level. When it is clear no large pages
	 * exist at this level, so the AND drops PT_PAGE_SIZE_MASK.
	 */
	bits &= level - mmu->last_nonleaf_level;

	/*
	 * level - PT_PAGE_TABLE_LEVEL - 1 has bit 7 set iff
	 * level <= PT_PAGE_TABLE_LEVEL, i.e. level == PT_PAGE_TABLE_LEVEL
	 * here. That level always terminates, so the OR forces
	 * PT_PAGE_SIZE_MASK on.
	 */
	bits |= level - PT_PAGE_TABLE_LEVEL - 1;

	return (bits & PT_PAGE_SIZE_MASK) != 0;
}

/*
 * The paging template header is included three times, each time with a
 * different PTTYPE selected beforehand: PTTYPE_EPT, then 64, then 32.
 */
#define PTTYPE_EPT 18 /* arbitrary */
#define PTTYPE PTTYPE_EPT

#include "bsc1198133_paging_tmpl.h"

#undef PTTYPE
#define PTTYPE 64

#include "bsc1198133_paging_tmpl.h"

#undef PTTYPE
#define PTTYPE 32

#include "bsc1198133_paging_tmpl.h"



/*
 * Name of the module this livepatch depends on: compared against
 * mod->name in the module notifier and passed to find_module() at init.
 */
#define LP_MODULE "kvm"

#include <linux/kernel.h>
#include <linux/module.h>
#include "livepatch_bsc1198133.h"
#include "../kallsyms_relocs.h"

/*
 * Table handed to __klp_resolve_kallsyms_relocs(). Every entry holds a
 * symbol name, the address of the klpe_<symbol> pointer declared earlier in
 * this file, and the string "kvm".
 *
 * KLPR_KVM_RELOC() builds one entry from the bare symbol so that the name
 * string and the pointer identifier cannot drift apart.
 */
#define KLPR_KVM_RELOC(sym)	{ #sym, (void *)&klpe_##sym, "kvm" }

static struct klp_kallsyms_reloc klp_funcs[] = {
	KLPR_KVM_RELOC(__tracepoint_kvm_mmu_pagetable_walk),
	KLPR_KVM_RELOC(__tracepoint_kvm_mmu_paging_element),
	KLPR_KVM_RELOC(__tracepoint_kvm_mmu_set_accessed_bit),
	KLPR_KVM_RELOC(__tracepoint_kvm_mmu_set_dirty_bit),
	KLPR_KVM_RELOC(__tracepoint_kvm_mmu_walker_error),
	KLPR_KVM_RELOC(kvm_arch_write_log_dirty),
	KLPR_KVM_RELOC(kvm_vcpu_gfn_to_hva_prot),
	KLPR_KVM_RELOC(kvm_vcpu_mark_page_dirty),
	KLPR_KVM_RELOC(kvm_x86_ops),
};

#undef KLPR_KVM_RELOC

/*
 * Module notifier callback: resolve the klp_funcs table when the LP_MODULE
 * module is reported as MODULE_STATE_COMING.
 *
 * Any other action or module name is ignored and 0 is returned. Otherwise
 * the resolver runs under module_mutex, a WARN is raised if it fails, and
 * its result is returned.
 *
 * NOTE(review): the raw resolver result is returned rather than a NOTIFY_*
 * code - confirm this is what the notifier chain expects.
 */
static int livepatch_bsc1198133_module_notify(struct notifier_block *nb,
					unsigned long action, void *data)
{
	struct module *coming = data;
	int err;

	if (action != MODULE_STATE_COMING)
		return 0;
	if (strcmp(coming->name, LP_MODULE) != 0)
		return 0;

	mutex_lock(&module_mutex);
	err = __klp_resolve_kallsyms_relocs(klp_funcs, ARRAY_SIZE(klp_funcs));
	mutex_unlock(&module_mutex);

	WARN(err, "%s: delayed kallsyms lookup failed. System is broken and can crash.\n",
		__func__);

	return err;
}

/*
 * Notifier block registered by livepatch_bsc1198133_init().
 * The priority is one above INT_MIN - presumably so this callback runs
 * after nearly every other module notifier; confirm against the notifier
 * chain ordering rules.
 */
static struct notifier_block livepatch_bsc1198133_module_nb = {
	.priority	= INT_MIN + 1,
	.notifier_call	= livepatch_bsc1198133_module_notify,
};

/*
 * Set up the livepatch.
 *
 * With module_mutex held: if LP_MODULE is already present, resolve the
 * klp_funcs table right away; provided that did not fail (or was not
 * needed), register the module notifier.
 *
 * Returns 0 on success, otherwise the non-zero result of the resolver or of
 * register_module_notifier().
 */
int livepatch_bsc1198133_init(void)
{
	int ret = 0;

	mutex_lock(&module_mutex);

	if (find_module(LP_MODULE))
		ret = __klp_resolve_kallsyms_relocs(klp_funcs,
						    ARRAY_SIZE(klp_funcs));

	/* Only hook the notifier when immediate resolution did not fail. */
	if (!ret)
		ret = register_module_notifier(&livepatch_bsc1198133_module_nb);

	mutex_unlock(&module_mutex);

	return ret;
}

/* Tear down the livepatch: drop the module notifier added by init. */
void livepatch_bsc1198133_cleanup(void)
{
	struct notifier_block *nb = &livepatch_bsc1198133_module_nb;

	unregister_module_notifier(nb);
}

#endif /* IS_ENABLED(CONFIG_X86_64) */
