/*
 * Copyright (C) 2025, Stephan Mueller <smueller@chronox.de>
 *
 * License: see LICENSE file in root directory
 *
 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE, ALL OF
 * WHICH ARE HEREBY DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT
 * OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
 * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 * USE OF THIS SOFTWARE, EVEN IF NOT ADVISED OF THE POSSIBILITY OF SUCH
 * DAMAGE.
 */
/*
 * This code is derived in parts from the code distribution provided with
 * https://pqc-hqc.org/
 *
 * The code is referenced as Public Domain
 */
/**
 * @file vector.c
 * @brief Implementation of vectors sampling and some utilities for the HQC scheme
 */

#include "../shake_prng.h"
#include "hqc_type.h"
#include "parsing_avx2.h"
#include "vector_avx2.h"

#if (LC_HQC_TYPE == 128)

/* Right-shift applied to the 64-bit product in barrett_reduce() */
#define SHBIT32 46

/*
 * Per-index multipliers consumed by barrett_reduce(): entry i is the constant
 * used when reducing modulo (LC_HQC_PARAM_N - i). The table is only ever read,
 * hence it is declared const.
 *
 * NOTE(review): the values look like floor(2^SHBIT32 / (LC_HQC_PARAM_N - i));
 * LC_HQC_PARAM_N is defined elsewhere - confirm against the parameter set.
 */
static const uint64_t v_val[LC_HQC_PARAM_OMEGA_R] = {
	3982610457, 3982835871, 3983061310, 3983286775, 3983512265, 3983737781,
	3983963323, 3984188890, 3984414482, 3984640100, 3984865744, 3985091413,
	3985317108, 3985542828, 3985768574, 3985994345, 3986220142, 3986445965,
	3986671813, 3986897687, 3987123586, 3987349511, 3987575461, 3987801438,
	3988027439, 3988253467, 3988479520, 3988705599, 3988931703, 3989157833,
	3989383988, 3989610169, 3989836376, 3990062609, 3990288867, 3990515151,
	3990741460, 3990967795, 3991194156, 3991420543, 3991646955, 3991873393,
	3992099856, 3992326346, 3992552861, 3992779401, 3993005968, 3993232560,
	3993459178, 3993685821, 3993912490, 3994139185, 3994365906, 3994592653,
	3994819425, 3995046223, 3995273047, 3995499896, 3995726771, 3995953672,
	3996180599, 3996407552, 3996634530, 3996861534, 3997088564, 3997315620,
	3997542701, 3997769808, 3997996942, 3998224101, 3998451285, 3998678496,
	3998905732, 3999132994, 3999360282
};

#elif (LC_HQC_TYPE == 192)

/* Right-shift applied to the 64-bit product in barrett_reduce() */
#define SHBIT32 47

/*
 * Per-index multipliers consumed by barrett_reduce(): entry i is the constant
 * used when reducing modulo (LC_HQC_PARAM_N - i). The table is only ever read,
 * hence it is declared const.
 *
 * NOTE(review): the values look like floor(2^SHBIT32 / (LC_HQC_PARAM_N - i));
 * LC_HQC_PARAM_N is defined elsewhere - confirm against the parameter set.
 */
static const uint64_t v_val[LC_HQC_PARAM_OMEGA_R] = {
	3925622391, 3925731892, 3925841400, 3925950913, 3926060433, 3926169959,
	3926279491, 3926389028, 3926498573, 3926608123, 3926717679, 3926827242,
	3926936810, 3927046385, 3927155966, 3927265552, 3927375145, 3927484745,
	3927594350, 3927703961, 3927813579, 3927923202, 3928032832, 3928142468,
	3928252110, 3928361758, 3928471412, 3928581072, 3928690739, 3928800411,
	3928910090, 3929019775, 3929129466, 3929239163, 3929348866, 3929458575,
	3929568291, 3929678012, 3929787740, 3929897474, 3930007214, 3930116960,
	3930226712, 3930336471, 3930446235, 3930556006, 3930665782, 3930775565,
	3930885354, 3930995149, 3931104951, 3931214758, 3931324572, 3931434391,
	3931544217, 3931654049, 3931763887, 3931873731, 3931983582, 3932093438,
	3932203301, 3932313170, 3932423044, 3932532925, 3932642813, 3932752706,
	3932862605, 3932972511, 3933082423, 3933192341, 3933302265, 3933412195,
	3933522131, 3933632074, 3933742022, 3933851977, 3933961938, 3934071905,
	3934181878, 3934291858, 3934401843, 3934511835, 3934621833, 3934731837,
	3934841847, 3934951863, 3935061886, 3935171914, 3935281949, 3935391990,
	3935502037, 3935612090, 3935722149, 3935832215, 3935942286, 3936052364,
	3936162448, 3936272538, 3936382635, 3936492737, 3936602846, 3936712960,
	3936823081, 3936933208, 3937043342, 3937153481, 3937263627, 3937373778,
	3937483936, 3937594100, 3937704271, 3937814447, 3937924630, 3938034818
};

#elif (LC_HQC_TYPE == 256)

/* Right-shift applied to the 64-bit product in barrett_reduce() */
#define SHBIT32 47

/*
 * Per-index multipliers consumed by barrett_reduce(): entry i is the constant
 * used when reducing modulo (LC_HQC_PARAM_N - i). The table is only ever read,
 * hence it is declared const.
 *
 * NOTE(review): the values look like floor(2^SHBIT32 / (LC_HQC_PARAM_N - i));
 * LC_HQC_PARAM_N is defined elsewhere - confirm against the parameter set.
 */
static const uint64_t v_val[LC_HQC_PARAM_OMEGA_R] = {
	2441790661, 2441833027, 2441875394, 2441917763, 2441960133, 2442002504,
	2442044877, 2442087252, 2442129628, 2442172005, 2442214384, 2442256765,
	2442299147, 2442341530, 2442383915, 2442426301, 2442468689, 2442511078,
	2442553469, 2442595861, 2442638255, 2442680650, 2442723047, 2442765445,
	2442807844, 2442850245, 2442892648, 2442935052, 2442977457, 2443019864,
	2443062272, 2443104682, 2443147094, 2443189506, 2443231921, 2443274336,
	2443316754, 2443359172, 2443401593, 2443444014, 2443486437, 2443528862,
	2443571288, 2443613715, 2443656144, 2443698575, 2443741007, 2443783440,
	2443825875, 2443868312, 2443910749, 2443953189, 2443995630, 2444038072,
	2444080516, 2444122961, 2444165407, 2444207856, 2444250305, 2444292756,
	2444335209, 2444377663, 2444420119, 2444462576, 2444505034, 2444547494,
	2444589955, 2444632418, 2444674883, 2444717349, 2444759816, 2444802285,
	2444844755, 2444887227, 2444929700, 2444972175, 2445014651, 2445057129,
	2445099608, 2445142088, 2445184571, 2445227054, 2445269539, 2445312026,
	2445354514, 2445397003, 2445439494, 2445481987, 2445524480, 2445566976,
	2445609473, 2445651971, 2445694471, 2445736972, 2445779475, 2445821979,
	2445864485, 2445906992, 2445949501, 2445992011, 2446034523, 2446077036,
	2446119550, 2446162066, 2446204584, 2446247103, 2446289623, 2446332145,
	2446374669, 2446417194, 2446459720, 2446502248, 2446544778, 2446587308,
	2446629841, 2446672375, 2446714910, 2446757447, 2446799985, 2446842525,
	2446885066, 2446927608, 2446970153, 2447012698, 2447055245, 2447097794,
	2447140344, 2447182896, 2447225449, 2447268003, 2447310559, 2447353117,
	2447395676, 2447438236, 2447480798, 2447523361, 2447565926, 2447608493,
	2447651060, 2447693630, 2447736201, 2447778773, 2447821347, 2447863922,
	2447906499, 2447949077, 2447991657, 2448034238, 2448076820
};
#endif

/**
 * @brief Constant-time Barrett reduction
 *
 * The quotient of a by (LC_HQC_PARAM_N - i) is estimated by multiplying with
 * the precomputed constant v_val[i] and shifting the 64-bit product right by
 * SHBIT32 bits. The remainder is what is left of a after subtracting
 * quotient * (LC_HQC_PARAM_N - i). Only multiplication, shift and subtraction
 * are used; there is no branch on the input.
 *
 * @param[in] a An integer to be reduced
 * @param[in] i An array index, used to select v_val[i] and the modulus
 * @return an integer equal to a % (LC_HQC_PARAM_N - i)
 */
static inline uint16_t barrett_reduce(uint32_t a, uint16_t i)
{
	/* v * a + v written as v * (a + 1); both agree modulo 2^64 */
	const uint64_t scaled = v_val[i] * ((uint64_t)a + 1);
	const uint32_t quotient = (uint32_t)(scaled >> SHBIT32);

	return (uint16_t)(a - quotient * (LC_HQC_PARAM_N - i));
}

/**
 * @brief Constant-time equality test of two 32-bit integers
 *
 * For v1 != v2 at least one of (v1 - v2) and (v2 - v1) has its most
 * significant bit set, so the top bit of their OR is 1 exactly when the
 * inputs differ. That bit is inverted to form the result. No branch is taken.
 *
 * See https://gist.github.com/sneves/10845247
 *
 * @param[in] v1 integer 1
 * @param[in] v2 integer 2
 * @return 1 if v1 is equal to v2 and 0 otherwise
 */
static inline uint32_t compare_u32(uint32_t v1, uint32_t v2)
{
	const uint32_t forward = v1 - v2;
	const uint32_t backward = v2 - v1;
	const uint32_t differs = (uint32_t)(forward | backward) >> 31;

	return differs ^ 1;
}

/**
 * @brief Generates a vector of a given Hamming weight
 *
 * Implementation of Algorithm 5 in https://eprint.iacr.org/2021/1631.pdf
 *
 * The generated bits are XORed into v256; the function does not clear v256
 * beforehand.
 *
 * @param[in,out] shake256 Hash context handed to seedexpander() to obtain
 *			   4 * weight random bytes
 * @param[in,out] v256 Vector the generated bits are XORed into; the first
 *		       LC_HQC_CEIL_DIVIDE(LC_HQC_PARAM_N, 256) elements are
 *		       read and written
 * @param[in] weight Number of bit positions to generate. It must not exceed
 *		     LC_HQC_PARAM_OMEGA_R because barrett_reduce() indexes
 *		     v_val with every value below weight.
 * @param[in,out] ws Workspace providing the rand_u32, tmp, bloc256 and bit256
 *		     scratch arrays; each is accessed for indices 0 to
 *		     weight - 1
 */
void vect_set_random_fixed_weight_avx2(
	struct lc_hash_ctx *shake256, __m256i *v256, uint16_t weight,
	struct vect_set_random_fixed_weight_ws *ws)
{
	/* Lane indices 0..3 of the four 64-bit lanes of a 256-bit word */
	const __m256i posCmp256 = (__m256i){ 0UL, 1UL, 2UL, 3UL };

	seedexpander(shake256, (uint8_t *)ws->rand_u32, 4 * weight);

	/*
	 * Candidate position i is i plus the random word reduced by
	 * barrett_reduce(), which is documented to return
	 * rand % (LC_HQC_PARAM_N - i). Hence tmp[i] >= i.
	 */
	for (uint16_t i = 0; i < weight; ++i)
		ws->tmp[i] = i + barrett_reduce(ws->rand_u32[i], i);

	/*
	 * Remove duplicates: the loop body runs for i = weight - 2 down to 0
	 * (the condition is tested before the decrement). If tmp[i] equals any
	 * later entry, it is replaced by i. As every later entry j satisfies
	 * tmp[j] >= j > i, the replacement cannot collide with them. The
	 * selection is done with a mask instead of a branch.
	 */
	for (int32_t i = (weight - 1); i-- > 0;) {
		uint32_t found = 0;

		for (size_t j = (size_t)i + 1; j < weight; ++j)
			found |= compare_u32(ws->tmp[j], ws->tmp[i]);

		/* mask is all-ones if a duplicate was found, else zero */
		uint32_t mask = -found;
		ws->tmp[i] = (mask & (uint32_t)i) ^ (~mask & ws->tmp[i]);
	}

	/*
	 * NOTE(review): LC_FPU_ENABLE / LC_FPU_DISABLE are defined elsewhere;
	 * presumably they bracket the region that uses AVX2 registers.
	 */
	LC_FPU_ENABLE;

	for (uint16_t i = 0; i < weight; i++) {
		/*
		 * Split the bit position tmp[i] into the index of its 256-bit
		 * word (position >> 8), the 64-bit lane within that word
		 * ((position >> 6) & 3) and the bit within the lane
		 * (position & 63).
		 */
		uint64_t bloc = ws->tmp[i] >> 6;

		ws->bloc256[i] = _mm256_set1_epi64x((long long)(bloc >> 2));
		uint64_t pos = (bloc & 0x3UL);
		__m256i pos256 = _mm256_set1_epi64x((long long)pos);
		/* All-ones in the lane whose index equals pos, zero elsewhere */
		__m256i mask256 = _mm256_cmpeq_epi64(pos256, posCmp256);
		uint64_t bit64 = 1ULL << (ws->tmp[i] & 0x3f);
		__m256i bl256 = _mm256_set1_epi64x((long long)bit64);
		/* 256-bit word with only the selected bit set */
		ws->bit256[i] = bl256 & mask256;
	}

#define LOOP_SIZE LC_HQC_CEIL_DIVIDE(LC_HQC_PARAM_N, 256)

	/*
	 * For every 256-bit word of the output, all weight entries are
	 * visited and the bit of an entry is accumulated only if its word
	 * index matches; the selection is done by masking.
	 */
	for (uint32_t i = 0; i < LOOP_SIZE; i++) {
		__m256i aux = _mm256_setzero_si256();
		__m256i i256 = _mm256_set1_epi64x(i);

		for (uint32_t j = 0; j < weight; j++) {
			__m256i mask256 =
				_mm256_cmpeq_epi64(ws->bloc256[j], i256);
			aux ^= ws->bit256[j] & mask256;
		}

		/*
		 * Read with an unaligned load so that the read and the
		 * unaligned store below have the same alignment requirement
		 * on v256.
		 */
		__m256i cur = _mm256_loadu_si256(&v256[i]);

		_mm256_storeu_si256(&v256[i], _mm256_xor_si256(cur, aux));
	}

#undef LOOP_SIZE

	LC_FPU_DISABLE;
}
