/*
 * Copyright (C) 2025, Stephan Mueller <smueller@chronox.de>
 *
 * License: see LICENSE file in root directory
 *
 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE, ALL OF
 * WHICH ARE HEREBY DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT
 * OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
 * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 * USE OF THIS SOFTWARE, EVEN IF NOT ADVISED OF THE POSSIBILITY OF SUCH
 * DAMAGE.
 */

#include "alignment.h"
#include "chacha20_c.h"
#include "compare.h"
#include "lc_chacha20_poly1305.h"
#include "math_helper.h"
#include "visibility.h"

/*
 * Known-answer test for ChaCha20-Poly1305 using the AEAD test vector from
 * RFC 7539 section 2.8.2.
 *
 * Three passes are made with the same key, IV and AAD:
 *  1. one-shot encryption, checking ciphertext and tag;
 *  2. one-shot in-place decryption of that ciphertext, checking the
 *     decryption result code and the recovered plaintext;
 *  3. stream encryption via enc_init/enc_update/enc_final, feeding the
 *     plaintext in chunks of growing size (1, 2, 3, ... bytes) and checking
 *     ciphertext and tag again.
 *
 * @param argc When >= 2, the ChaCha20 implementation inside the context is
 *	       replaced with lc_chacha20_c before any key is set.
 *
 * @return 0 when every check passed, non-zero otherwise (accumulated count
 *	   of failed checks / failed API calls).
 */
static int lc_chacha20_poly1305_test(int argc)
{
	/* Test vector from RFC7539 */
	static const uint8_t aad[] = { 0x50, 0x51, 0x52, 0x53, 0xc0, 0xc1,
				       0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7 };
	static const uint8_t in[] = {
		0x4c, 0x61, 0x64, 0x69, 0x65, 0x73, 0x20, 0x61, 0x6e, 0x64,
		0x20, 0x47, 0x65, 0x6e, 0x74, 0x6c, 0x65, 0x6d, 0x65, 0x6e,
		0x20, 0x6f, 0x66, 0x20, 0x74, 0x68, 0x65, 0x20, 0x63, 0x6c,
		0x61, 0x73, 0x73, 0x20, 0x6f, 0x66, 0x20, 0x27, 0x39, 0x39,
		0x3a, 0x20, 0x49, 0x66, 0x20, 0x49, 0x20, 0x63, 0x6f, 0x75,
		0x6c, 0x64, 0x20, 0x6f, 0x66, 0x66, 0x65, 0x72, 0x20, 0x79,
		0x6f, 0x75, 0x20, 0x6f, 0x6e, 0x6c, 0x79, 0x20, 0x6f, 0x6e,
		0x65, 0x20, 0x74, 0x69, 0x70, 0x20, 0x66, 0x6f, 0x72, 0x20,
		0x74, 0x68, 0x65, 0x20, 0x66, 0x75, 0x74, 0x75, 0x72, 0x65,
		0x2c, 0x20, 0x73, 0x75, 0x6e, 0x73, 0x63, 0x72, 0x65, 0x65,
		0x6e, 0x20, 0x77, 0x6f, 0x75, 0x6c, 0x64, 0x20, 0x62, 0x65,
		0x20, 0x69, 0x74, 0x2e
	};
	static const uint8_t key[] = { 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86,
				       0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d,
				       0x8e, 0x8f, 0x90, 0x91, 0x92, 0x93, 0x94,
				       0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b,
				       0x9c, 0x9d, 0x9e, 0x9f };
	static const uint8_t iv[] = { 0x40, 0x41, 0x42, 0x43,
				      0x44, 0x45, 0x46, 0x47 };
	static const uint8_t exp_ct[] = {
		0xd3, 0x1a, 0x8d, 0x34, 0x64, 0x8e, 0x60, 0xdb, 0x7b, 0x86,
		0xaf, 0xbc, 0x53, 0xef, 0x7e, 0xc2, 0xa4, 0xad, 0xed, 0x51,
		0x29, 0x6e, 0x08, 0xfe, 0xa9, 0xe2, 0xb5, 0xa7, 0x36, 0xee,
		0x62, 0xd6, 0x3d, 0xbe, 0xa4, 0x5e, 0x8c, 0xa9, 0x67, 0x12,
		0x82, 0xfa, 0xfb, 0x69, 0xda, 0x92, 0x72, 0x8b, 0x1a, 0x71,
		0xde, 0x0a, 0x9e, 0x06, 0x0b, 0x29, 0x05, 0xd6, 0xa5, 0xb6,
		0x7e, 0xcd, 0x3b, 0x36, 0x92, 0xdd, 0xbd, 0x7f, 0x2d, 0x77,
		0x8b, 0x8c, 0x98, 0x03, 0xae, 0xe3, 0x28, 0x09, 0x1b, 0x58,
		0xfa, 0xb3, 0x24, 0xe4, 0xfa, 0xd6, 0x75, 0x94, 0x55, 0x85,
		0x80, 0x8b, 0x48, 0x31, 0xd7, 0xbc, 0x3f, 0xf4, 0xde, 0xf0,
		0x8e, 0x4b, 0x7a, 0x9d, 0xe5, 0x76, 0xd2, 0x65, 0x86, 0xce,
		0xc6, 0x4b, 0x61, 0x16
	};
	static const uint8_t exp_tag[] = { 0x1a, 0xe1, 0x0b, 0x59, 0x4f, 0x09,
					   0xe2, 0x6a, 0x7e, 0x90, 0x2e, 0xcb,
					   0xd0, 0x60, 0x06, 0x91 };
	uint8_t act_ct[sizeof(exp_ct)] __align(sizeof(uint32_t));
	uint8_t act_tag[sizeof(exp_tag)] __align(sizeof(uint32_t));
	size_t len, i;
	const uint8_t *in_p;
	uint8_t *out_p;
	int ret = 0, rc;
	LC_CHACHA20_POLY1305_CTX_ON_STACK(cc20p1305);

	/* Optionally force the generic C ChaCha20 implementation. */
	if (argc >= 2) {
		struct lc_chacha20_poly1305_cryptor *c = cc20p1305->aead_state;
		c->chacha20.sym = lc_chacha20_c;
	}

	/* Pass 1: one-shot encryption */
	rc = lc_aead_setkey(cc20p1305, key, sizeof(key), iv, sizeof(iv));
	if (rc)
		goto fail;
	rc = lc_aead_encrypt(cc20p1305, in, act_ct, sizeof(in), aad,
			     sizeof(aad), act_tag, sizeof(act_tag));
	if (rc)
		goto fail;
	ret += lc_compare(act_ct, exp_ct, sizeof(exp_ct),
			  "ChaCha20 Poly 1305 encrypt ciphertext");
	ret += lc_compare(act_tag, exp_tag, sizeof(exp_tag),
			  "ChaCha20 Poly 1305 encrypt tag");
	lc_aead_zero(cc20p1305);

	/* Pass 2: one-shot in-place decryption of the ciphertext from pass 1 */
	rc = lc_aead_setkey(cc20p1305, key, sizeof(key), iv, sizeof(iv));
	if (rc)
		goto fail;
	rc = lc_aead_decrypt(cc20p1305, act_ct, act_ct, sizeof(act_ct), aad,
			     sizeof(aad), act_tag, sizeof(act_tag));
	if (rc)
		ret += 1;
	ret += lc_compare(act_ct, in, sizeof(in),
			  "ChaCha20 Poly 1305 decrypt plaintext");
	lc_aead_zero(cc20p1305);

	/* Pass 3: stream cipher API */

	/*
	 * Wipe the output buffers first: act_tag still holds the correct tag
	 * from pass 1, so without this a stream implementation that never
	 * writes the tag in enc_final would still pass the comparison below.
	 */
	for (i = 0; i < sizeof(act_ct); i++)
		act_ct[i] = 0;
	for (i = 0; i < sizeof(act_tag); i++)
		act_tag[i] = 0;

	rc = lc_aead_setkey(cc20p1305, key, sizeof(key), iv, sizeof(iv));
	if (rc)
		goto fail;

	rc = lc_aead_enc_init(cc20p1305, aad, sizeof(aad));
	if (rc)
		goto fail;

	/* Feed the plaintext in chunks of 1, 2, 3, ... bytes. */
	len = sizeof(in);
	i = 1;
	in_p = in;
	out_p = act_ct;
	while (len) {
		size_t todo = min_size(len, i);

		rc = lc_aead_enc_update(cc20p1305, in_p, out_p, todo);
		if (rc)
			goto fail;

		len -= todo;
		in_p += todo;
		out_p += todo;
		i++;
	}

	rc = lc_aead_enc_final(cc20p1305, act_tag, sizeof(act_tag));
	if (rc)
		goto fail;

	ret += lc_compare(act_ct, exp_ct, sizeof(exp_ct),
			  "ChaCha20 Poly 1305 stream encrypt ciphertext");
	ret += lc_compare(act_tag, exp_tag, sizeof(exp_tag),
			  "ChaCha20 Poly 1305 stream encrypt tag");
	lc_aead_zero(cc20p1305);

	return ret;

fail:
	/* An API call reported an error: count it and wipe the context. */
	lc_aead_zero(cc20p1305);
	return ret + 1;
}

/*
 * Test entry point. Any extra command-line argument (argc >= 2) makes
 * lc_chacha20_poly1305_test() use the C ChaCha20 implementation.
 *
 * @return 0 when all checks passed, non-zero otherwise.
 */
LC_TEST_FUNC(int, main, int argc, char *argv[])
{
	int ret;

	/* argc is consumed below; only argv is unused. */
	(void)argv;

	ret = lc_chacha20_poly1305_test(argc);

	return ret;
}
