#!/usr/bin/env python3

"""
SPDX-License-Identifier: Apache-2.0
Copyright 2025 Keylime Project

A simple script to gather data for a one-shot attestation and send it to the verifier
"""

import argparse
import base64
import json
import logging
import os
import re
import secrets
import string
import sys
import tempfile
import time

import requests
from packaging.version import Version

from keylime import cmd_exec, config, crypto, keylime_logging, secure_mount, web_util
from keylime.requests_client import RequestsClient

# Verifier REST API version used in the evidence URL built by main().
API_VERSION = 2.4
# Set to True by main() when --debug is passed; create_quote() then keeps the AK context file.
DEBUG = False
# Exit code that tpm2-tools commands are expected to return on success.
# NOTE: the name is misspelled ("SUCESS") but is referenced under this spelling throughout the file.
EXIT_SUCESS = 0
# Maximum number of tpm2_quote attempts tpm_cmd() makes when it detects a PCR race condition.
MAX_RETRIES = 10
# Seconds tpm_cmd() waits between those tpm2_quote attempts.
RETRY_INTERVAL = 3
# IMA runtime measurement list (text), read by get_ima_measurement_list().
IMA_ML = "/sys/kernel/security/ima/ascii_runtime_measurements"
# Binary measured-boot event log, read and base64-encoded by get_mb_log().
MEASUREDBOOT_ML = "/sys/kernel/security/tpm0/binary_bios_measurements"

# various metadata about the TPM that we save
# NOTE(review): TPM_PUBLIC_KEY is not read or written by any code visible in this file.
TPM_PUBLIC_KEY = None
# Integer handle of the EK; set by setup_ek_handle().
EK_HANDLE = None
# Base64 (ASCII) encoding of the EK public key file produced by tpm2-tools; set by setup_ek_handle().
EK = None
# Random password protecting the AK; set by setup_ak_handle().
AK_PW = None
# Base64 (ASCII) encoding of the AK public key file; set by setup_ak_handle().
AK = None
# Path of the temporary file holding the AK context; set by setup_ak_handle().
AK_CONTEXT_PATH = None

# Logging is configured at import time with level DEBUG; main() changes this
# module logger's level while handling --debug.
logging.basicConfig(level=logging.DEBUG, format="%(levelname)s - %(message)s")
logger = logging.getLogger("keylime_oneshot_attestation")


def random_password(length=20) -> str:
    """Return a random alphanumeric string of ``length`` characters.

    This script uses it for the TPM owner password, the AK password and the
    quote nonce. Each character is picked with :func:`secrets.choice`, which
    draws uniformly from the alphabet. The previous ``byte % 62`` mapping
    over-weighted the first 8 characters because 256 is not a multiple of 62.

    :param length: number of characters to generate (0 or less yields "").
    :returns: the generated password.
    """
    alphabet = string.ascii_uppercase + string.digits + string.ascii_lowercase
    return "".join(secrets.choice(alphabet) for _ in range(length))


# get the IMA runtime measurement list as a single string
def get_ima_measurement_list(path=None):
    """Return the contents of the IMA runtime measurement list.

    :param path: file to read; defaults to the module constant ``IMA_ML``.
    :returns: the whole file as a ``str``, or an empty list when the file does
        not exist. The empty-list fallback is kept so the payload sent to the
        verifier is unchanged in that case.
    """
    if path is None:
        path = IMA_ML
    try:
        # ``with`` guarantees the handle is closed; it was previously leaked.
        with open(path, "r", encoding="utf-8") as ima_fh:
            return ima_fh.read()
    except FileNotFoundError:
        return []


# get the measured boot binary log as base64-encoded text
def get_mb_log(path=None):
    """Return the measured-boot event log, base64-encoded.

    :param path: file to read; defaults to the module constant ``MEASUREDBOOT_ML``.
    :returns: the base64 encoding of the file's bytes as an ASCII ``str``, or
        ``None`` when the file does not exist.
    """
    if path is None:
        path = MEASUREDBOOT_ML
    try:
        with open(path, "rb") as mb_fh:
            raw = mb_fh.read()
    except FileNotFoundError:
        return None
    # Decode to str: the result is placed in a dict that main() serializes with
    # json.dumps(), which cannot encode bytes.
    return base64.b64encode(raw).decode("ascii")


def get_cmd_env():
    """Build the environment dict used to run tpm2-tools commands.

    Starts from a copy of ``os.environ`` (the process environment is left
    untouched). ``TPM2TOOLS_TCTI`` falls back to ``device:/dev/tpmrm0``;
    a value already set by the user is kept.
    """
    cmd_env = dict(os.environ)
    cmd_env.setdefault("TPM2TOOLS_TCTI", "device:/dev/tpmrm0")
    return cmd_env


def tpm_cmd(cmd=None, expectedcode=EXIT_SUCESS, raiseOnError=True, outputpaths=None):
    """Run a tpm2-tools command via ``cmd_exec.run`` and return its result.

    :param cmd: argument list. The first element is prefixed with ``tpm2_``
        unless it already starts with it, so ``["startup", "-c"]`` runs
        ``tpm2_startup -c``. The caller's list is not modified.
    :param expectedcode: exit code considered successful.
    :param raiseOnError: raise ``Exception`` when the final exit code differs
        from ``expectedcode``.
    :param outputpaths: a path, or list of paths, passed to ``cmd_exec.run`` so
        their contents are returned under ``fileouts``.
    :returns: the dict returned by ``cmd_exec.run`` (this file reads the keys
        ``code``, ``retout``, ``reterr`` and ``fileouts``).
    :raises ValueError: if ``cmd`` is empty.
    :raises Exception: on an unexpected exit code when ``raiseOnError`` is set.
    """
    if not cmd:
        raise ValueError("tpm_cmd requires a non-empty command list")

    env = get_cmd_env()

    # Convert single outputpath to list
    if isinstance(outputpaths, str):
        outputpaths = [outputpaths]

    # Work on a copy so the caller's list is never mutated by the prefixing below.
    cmd = list(cmd)
    if not cmd[0].startswith("tpm2_"):
        cmd[0] = "tpm2_" + cmd[0]

    numtries = 0
    while True:
        # NOTE(review): callers in this file pass passwords on the command line
        # (e.g. -p/-P/-w), so this log line exposes them.
        logger.info('Executing command "%s"', " ".join(cmd))
        cmd_return = cmd_exec.run(
            cmd=cmd, expectedcode=expectedcode, raiseOnError=False, outputpaths=outputpaths, env=env
        )
        rc = cmd_return["code"]
        retout = cmd_return["retout"]
        reterr = cmd_return["reterr"]

        # keep trying to get quote if a PCR race condition occurred in quote
        if cmd[0] == "tpm2_quote" and cmd_exec.list_contains_substring(
            reterr, "Error validating calculated PCR composite with quote"
        ):
            numtries += 1
            if numtries >= MAX_RETRIES:
                logger.error("Agent did not return proper quote due to PCR race condition.")
                break
            logger.info(
                "Failed to get quote %d/%d times, trying again in %f seconds...", numtries, MAX_RETRIES, RETRY_INTERVAL
            )
            # Requires the module-level ``import time``.
            time.sleep(RETRY_INTERVAL)
            continue

        break

    # Don't bother continuing if TPM call failed and we're raising on error
    if rc != expectedcode and raiseOnError:
        raise Exception(f"Command: {cmd} returned {rc}, expected {expectedcode}, output {retout}, stderr {reterr}")

    return cmd_return


def check_tpm2_tools_version():
    """Exit the process unless the installed tpm2-tools is version 5.5 or newer.

    The version is parsed from the ``version="..."`` field printed by
    ``tpm2_startup --version``. On any failure an error is logged and the
    process exits with status 1. Previously it called ``sys.exit()``, which
    exits with status 0 and hides the failure from calling scripts.
    """
    # raiseOnError=False so a failing command reaches the logged exit below
    # instead of raising a generic exception inside tpm_cmd().
    cmd_return = tpm_cmd(["startup", "--version"], raiseOnError=False)
    rc = cmd_return["code"]
    output = "".join(config.convert(cmd_return["retout"]))
    errout = "".join(config.convert(cmd_return["reterr"]))
    if rc != EXIT_SUCESS:
        logger.error("Error establishing tpm2-tools version using TPM2_Startup: %s: %s", rc, errout)
        sys.exit(1)

    # Extract the `version="x.x.x"` from tools
    version_match = re.search(r'version="([^"]+)"', output)
    if version_match is None:
        logger.error("Could not determine tpm2-tools version from TPM2_Startup output '%s'", output)
        sys.exit(1)
    # Keep only the release number, dropping any "-suffix".
    tools_version = version_match.group(1).split("-")[0]

    logger.info("TPM2-TOOLS Version: %s", tools_version)
    if Version(tools_version) < Version("5.5"):
        logger.error("TPM2-TOOLS Version %s is not supported by this utility.", tools_version)
        sys.exit(1)


def start_tpm() -> None:
    """Run ``tpm2_startup -c`` to issue TPM2_Startup.

    :raises Exception: if the command fails. The message includes the exit
        code and stderr.
    """
    # raiseOnError=False so the descriptive error below is the one raised.
    startup_return = tpm_cmd(["startup", "-c"], raiseOnError=False)
    rc = startup_return["code"]
    if rc != EXIT_SUCESS:
        errout = config.convert(startup_return["reterr"])
        raise Exception(f"Error initializing TPM with TPM2_Startup: {rc}: {errout}")


def _take_tpm_ownership(tpm_pw):
    """Make ``tpm_pw`` the owner (``o``) and endorsement (``e``) hierarchy password.

    First tries ``tpm2_changeauth`` assuming no current password. If either
    hierarchy rejects that, retries with ``tpm_pw`` as the current password,
    which covers a TPM already owned with this password.

    :raises Exception: if either hierarchy is protected by a different password.
    """
    logger.info("Removing all saved sessions from TPM")
    tpm_cmd(["flushcontext", "-s"], raiseOnError=False)

    first_try = [tpm_cmd(["changeauth", "-c", hierarchy, tpm_pw], raiseOnError=False) for hierarchy in ("o", "e")]
    # Both hierarchies must succeed; previously only the endorsement result was checked.
    if all(ret["code"] == EXIT_SUCESS for ret in first_try):
        return

    # if we fail, see if already owned with this pw
    for hierarchy in ("o", "e"):
        cmd_return = tpm_cmd(["changeauth", "-c", hierarchy, "-p", tpm_pw, tpm_pw], raiseOnError=False)
        rc = cmd_return["code"]
        if rc != EXIT_SUCESS:
            # ut-oh, already owned but not with provided pw!
            raise Exception(f"Owner password unknown, TPM reset required. Code {rc}: {cmd_return['reterr']}")


def setup_ek_handle(tpm_pw, ek_handle, encrypt_algo):
    """Populate the module globals ``EK_HANDLE`` (int) and ``EK`` (base64 str).

    When ``ek_handle`` is empty or ``"generate"``, takes ownership of the TPM
    with ``tpm_pw`` and creates a new persistent EK with ``tpm2_createek``.
    Otherwise it reads the public part of the existing EK at the hexadecimal
    handle ``ek_handle`` with ``tpm2_readpublic``.

    :param tpm_pw: owner/endorsement hierarchy password.
    :param ek_handle: hexadecimal handle string, ``""``/``None``, or ``"generate"``.
    :param encrypt_algo: key algorithm passed to ``tpm2_createek -G``.
    :raises Exception: when a TPM command fails or its output is unusable.
    """
    global EK_HANDLE, EK

    # do we need to generate a new EK or use an existing one
    if not ek_handle or ek_handle == "generate":
        _take_tpm_ownership(tpm_pw)
        # The password itself is deliberately not logged (it used to be, in clear text).
        logger.info("TPM Owner password confirmed")

        # create a new EK handle
        with tempfile.NamedTemporaryFile() as tmppath:
            cmd = [
                "tpm2_createek",
                "-c",
                "-",
                "-G",
                encrypt_algo,
                "-u",
                tmppath.name,
                "-w",
                str(tpm_pw),
                "-P",
                str(tpm_pw),
            ]
            cmd_return = tpm_cmd(cmd, raiseOnError=False, outputpaths=tmppath.name)
            rc = cmd_return["code"]
            # Check the exit code before using the output file, which may be
            # missing or empty when the command failed.
            if rc != EXIT_SUCESS:
                raise Exception("tpm2_createek failed with code " + str(rc) + ": " + str(cmd_return["reterr"]))
            EK = base64.b64encode(cmd_return["fileouts"][tmppath.name]).decode("ascii")

            # With "-c -" the new handle is read back from the YAML on stdout.
            retyaml = config.yaml_to_dict(cmd_return["retout"], logger=logger)
            if retyaml is None:
                raise Exception("Could not read YAML output of tpm2_createek.")
            if "persistent-handle" not in retyaml:
                raise Exception("No persistent-handle in YAML output of tpm2_createek.")
            EK_HANDLE = retyaml["persistent-handle"]
            logger.info("Created EK with handle: %s", hex(EK_HANDLE))

            # Make sure that all transient objects are flushed
            tpm_cmd(["flushcontext", "-t"], raiseOnError=False)
    else:
        # use an existing EK handle
        EK_HANDLE = int(ek_handle, 16)
        logger.info("Using an already created EK with handle: %s", hex(EK_HANDLE))

        with tempfile.NamedTemporaryFile() as tmppath:
            cmd = ["readpublic", "-c", hex(EK_HANDLE), "-o", tmppath.name, "-f", "tss"]
            cmd_return = tpm_cmd(cmd, raiseOnError=False, outputpaths=tmppath.name)
            rc = cmd_return["code"]
            if rc != EXIT_SUCESS:
                raise Exception("tpm2_readpublic failed with code " + str(rc) + ": " + str(cmd_return["reterr"]))
            EK = base64.b64encode(cmd_return["fileouts"][tmppath.name]).decode("ascii")


def setup_ak_handle(tpm_pw, tpm_encryption_algorithm, tpm_signing_algorithm, tpm_hash_algorithm):
    """Create a new attestation key (AK) under the EK at ``EK_HANDLE``.

    Sets three module globals:

    - ``AK_PW``: a fresh random password.
    - ``AK``: base64 of the AK public key written by ``tpm2_createak -u``.
    - ``AK_CONTEXT_PATH``: a temporary file holding the AK context.

    The context file is created with ``delete=False`` because the quote step
    loads the AK from it later, so the caller must remove it after success.
    If creation fails here, the file is removed before the exception
    propagates.

    :param tpm_pw: endorsement hierarchy password (``-P``).
    :param tpm_encryption_algorithm: AK key algorithm (``-G``).
    :param tpm_signing_algorithm: AK signing scheme (``-s``).
    :param tpm_hash_algorithm: AK hash algorithm (``-g``).
    :raises Exception: if ``tpm2_createak`` fails or its output cannot be used.
    """
    global AK_PW, AK, AK_CONTEXT_PATH
    logger.info("Creating a new AK identity")

    AK_PW = random_password(20)
    ak_context_path = None
    try:
        # make a temp file for the output
        with tempfile.NamedTemporaryFile() as akpubfile, tempfile.NamedTemporaryFile(delete=False) as akcontextfile:
            ak_context_path = akcontextfile.name
            # ok lets write out the key now
            cmd = [
                "createak",
                "-C",
                hex(int(EK_HANDLE)),
                "-c",
                akcontextfile.name,
                "-G",
                tpm_encryption_algorithm,
                "-g",
                tpm_hash_algorithm,
                "-s",
                tpm_signing_algorithm,
                "-u",
                akpubfile.name,
                "-p",
                AK_PW,
                "-P",
                tpm_pw,
            ]
            # raiseOnError=False so the specific message below is the one raised.
            cmd_return = tpm_cmd(cmd, raiseOnError=False, outputpaths=akpubfile.name)
            reterr = cmd_return["reterr"]
            rc = cmd_return["code"]
            if rc != EXIT_SUCESS:
                raise Exception("tpm2_createak failed with code " + str(rc) + ": " + str(reterr))

            jsonout = config.yaml_to_dict(cmd_return["retout"], logger=logger)
            if jsonout is None:
                raise Exception(
                    "Unable to parse YAML output of tpm2_createak. Is your tpm2-tools installation up to date?"
                )
            if "loaded-key" not in jsonout:
                raise Exception("tpm2_createak failed to create AK: return " + str(reterr))

            ak_binary = cmd_return["fileouts"][akpubfile.name]
            if not ak_binary:
                raise Exception(
                    "Unable to read public AK from create identity. Is your tpm2-tools installation up to date?"
                )
            AK = base64.b64encode(ak_binary).decode("ascii")
            AK_CONTEXT_PATH = ak_context_path
    except Exception:
        # Don't leave the delete=False context file behind when creation fails.
        if ak_context_path and os.path.isfile(ak_context_path):
            os.unlink(ak_context_path)
        raise

    # Make sure that all transient objects are flushed
    tpm_cmd(["flushcontext", "-t"], raiseOnError=False)


def create_quote(tpm_pw, ek_handle, nonce, encrypt_algo, signing_algo, hash_algo):
    """Set up the EK and a fresh AK, then return a TPM quote over PCRs 0-15.

    :param tpm_pw: TPM hierarchy password (see setup_ek_handle/setup_ak_handle).
    :param ek_handle: existing EK handle (hex string), or ``"generate"``/empty.
    :param nonce: text whose UTF-8 bytes, hex-encoded, are passed to ``tpm2_quote -q``.
    :param encrypt_algo: key algorithm for the EK and AK.
    :param signing_algo: AK signing scheme.
    :param hash_algo: hash algorithm for the AK and the PCR bank being quoted.
    :returns: ``"r" + b64(quote) + ":" + b64(signature) + ":" + b64(pcr values)``.

    Unless ``DEBUG`` is set, the temporary AK context file is deleted even
    when the quote fails.
    """
    setup_ek_handle(tpm_pw, ek_handle, encrypt_algo)
    setup_ak_handle(tpm_pw, encrypt_algo, signing_algo, hash_algo)

    try:
        with tempfile.NamedTemporaryFile() as quotepath, tempfile.NamedTemporaryFile() as sigpath, tempfile.NamedTemporaryFile() as pcrpath:
            # get the quote from the TPM
            cmd = [
                "quote",
                "-c",
                AK_CONTEXT_PATH,
                "-l",
                hash_algo + ":0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15",
                "-q",
                bytes(nonce, encoding="utf8").hex(),
                "-m",
                quotepath.name,
                "-s",
                sigpath.name,
                "-o",
                pcrpath.name,
                "-g",
                hash_algo,
                "-p",
                AK_PW,
            ]
            # No print(cmd) here: it wrote the AK password to stdout.
            cmd_return = tpm_cmd(cmd, outputpaths=[quotepath.name, sigpath.name, pcrpath.name])
            # Make sure that all transient objects are flushed
            tpm_cmd(["tpm2_flushcontext", "-t"], raiseOnError=False)

            fileouts = cmd_return["fileouts"]
            quote = ":".join(
                base64.b64encode(fileouts[tmp.name]).decode("utf-8") for tmp in (quotepath, sigpath, pcrpath)
            )
    finally:
        # get rid of the AK context file, also when the quote failed
        if DEBUG:
            logger.info("Keeping temporary AK context file: %s", AK_CONTEXT_PATH)
        else:
            logger.info("Deleting temporary AK file(s)")
            if AK_CONTEXT_PATH and os.path.isfile(AK_CONTEXT_PATH):
                os.unlink(AK_CONTEXT_PATH)

    return "r" + quote


def _build_arg_parser():
    """Return the command-line parser for this script."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--ek-handle",
        dest="ek_handle",
        required=False,
        help="The label for the TPM's EK handle to reuse. Leave blank or set to 'generate' to generate a new EK handle.",
        default="generate",
    )
    arg_parser.add_argument(
        "--tpm-pw",
        dest="tpm_pw",
        required=True,
        help="The password for the TPM. If set to 'generate' then a new random password will be created (if possible).",
    )
    arg_parser.add_argument(
        "--tpm-encryption-algorithm",
        dest="tpm_encryption_algorithm",
        required=False,
        help="The encryption algorithm to be used by the TPM. Defaults to 'rsa'",
        default="rsa",
    )
    arg_parser.add_argument(
        "--tpm-signing-algorithm",
        dest="tpm_signing_algorithm",
        required=False,
        help="The signing algorithm to be used by the TPM. Defaults to 'rsassa'",
        default="rsassa",
    )
    arg_parser.add_argument(
        "--tpm-hash-algorithm",
        dest="tpm_hash_algorithm",
        required=False,
        help="The hash algorithm to be used by the TPM. Defaults to 'sha256'",
        default="sha256",
    )
    arg_parser.add_argument(
        "--runtime-policy",
        dest="runtime_policy",
        type=argparse.FileType("r"),
        required=False,
        help="File for the keylime runtime policy.",
    )
    arg_parser.add_argument(
        "--tpm-policy",
        dest="tpm_policy",
        type=argparse.FileType("r"),
        required=False,
        help="File for the keylime TPM policy.",
    )
    arg_parser.add_argument(
        "--mb-policy",
        dest="mb_policy",
        type=argparse.FileType("r"),
        required=False,
        help="File for the keylime measured boot policy.",
    )
    arg_parser.add_argument("--debug", dest="debug", action="store_true", required=False, help="Show debug messages")

    arg_parser.add_argument(
        "--verifier-host",
        dest="verifier_host",
        required=True,
        help="The hostname of the Keylime verifier server",
    )
    # The default now matches the documented 8881 (it was "8880", contradicting the help text).
    arg_parser.add_argument(
        "--verifier-port",
        dest="verifier_port",
        required=False,
        default="8881",
        help="The port of the Keylime verifier server. Defaults to 8881",
    )
    arg_parser.add_argument(
        "--verifier-cert",
        dest="verifier_cert",
        required=False,
        default="/var/lib/keylime/cv_ca/server-cert.crt",
        help="mTLS cert for the Keylime Verifier. Defaults to /var/lib/keylime/cv_ca/server-cert.crt",
    )
    arg_parser.add_argument(
        "--verifier-cacert",
        dest="verifier_cacert",
        required=False,
        default="/var/lib/keylime/cv_ca/cacert.crt",
        help="mTLS CA cert for the Keylime Verifier. Defaults to /var/lib/keylime/cv_ca/cacert.crt",
    )
    arg_parser.add_argument(
        "--verifier-key",
        dest="verifier_key",
        required=False,
        default="/var/lib/keylime/cv_ca/server-private.pem",
        help="mTLS key for the Keylime Verifier. Defaults to /var/lib/keylime/cv_ca/server-private.pem",
    )
    arg_parser.add_argument(
        "--verifier-key-pw",
        dest="verifier_key_pw",
        required=False,
        default="",
        help="mTLS key password for the Keylime Verifier if necessary",
    )
    return arg_parser


def _read_optional_file(fobj):
    """Return the contents of an ``argparse.FileType`` argument and close it; ``""`` if not given."""
    if fobj is None:
        return ""
    with fobj:
        return fobj.read()


def _post_evidence(args, verifier_url, oneshot_url, oneshot_data):
    """POST ``oneshot_data`` as JSON to the verifier, authenticating with the client cert/key.

    :returns: the client's response object, or ``None`` if the request could not be sent.
    """
    tls_context = web_util.generate_tls_context(
        args.verifier_cert,
        args.verifier_key,
        [args.verifier_cacert],
        args.verifier_key_pw,
        is_client=True,
        logger=logger,
    )
    try:
        logger.debug("Sending data to the verifier at %s: %s", verifier_url + oneshot_url, oneshot_data)
        client = RequestsClient(base_url=verifier_url, tls_enabled=True, tls_context=tls_context, ignore_hostname=True)
        return client.post(
            oneshot_url,
            data=json.dumps(oneshot_data),
            timeout=5,
            cert=(args.verifier_cert, args.verifier_key),
            verify=args.verifier_cacert,
        )
    except Exception as e:
        logger.error("Failed to send oneshot data to verifier at %s: %s", verifier_url, e)
        return None


def main():
    """Collect TPM quote, IMA and measured-boot evidence and submit it to the verifier.

    Exits with status 1 when the evidence cannot be delivered or the verifier
    answers with a status other than 200.
    """
    global DEBUG
    args = _build_arg_parser().parse_args()

    if args.debug:
        DEBUG = True
        logger.info("Set logging to DEBUG")
        logger.setLevel(logging.DEBUG)
    else:
        # Logging is configured at DEBUG when the module is imported, so without
        # an explicit level --debug would have no visible effect.
        logger.setLevel(logging.INFO)

    runtime_policy = _read_optional_file(args.runtime_policy)
    tpm_policy = _read_optional_file(args.tpm_policy)
    mb_policy = _read_optional_file(args.mb_policy)

    check_tpm2_tools_version()
    start_tpm()

    nonce = random_password(20)

    # if the pw is set to "generate" create a new one
    tpm_pw = args.tpm_pw
    hash_algo = args.tpm_hash_algorithm
    if tpm_pw == "generate":
        logger.info("Generating random TPM owner password")
        tpm_pw = random_password(20)

    quote = create_quote(
        tpm_pw, args.ek_handle, nonce, args.tpm_encryption_algorithm, args.tpm_signing_algorithm, hash_algo
    )
    logger.info("TPM Quote: %s", quote)
    ima_measurement_list = get_ima_measurement_list()
    mb_log = get_mb_log()
    # json.dumps() cannot serialize bytes (base64.b64encode returns bytes).
    if isinstance(mb_log, bytes):
        mb_log = mb_log.decode("ascii")

    # send all of the data to the verifier
    oneshot_data = {
        "type": "tpm",
        "data": {
            "quote": quote,
            "nonce": nonce,
            "hash_alg": hash_algo,
            "tpm_ek": EK,
            "tpm_ak": AK,
            "tpm_policy": tpm_policy,
            "runtime_policy": runtime_policy,
            "mb_policy": mb_policy,
            "ima_measurement_list": ima_measurement_list,
            "mb_log": mb_log,
        },
    }

    verifier_url = f"https://{args.verifier_host}:{args.verifier_port}"
    oneshot_url = f"/v{API_VERSION}/verify/evidence"
    response = _post_evidence(args, verifier_url, oneshot_url, oneshot_data)
    if response is None:
        # The failure was already logged; there is no response to inspect.
        sys.exit(1)

    logger.info("Verifier HTTP Response Code: %s", response.status_code)
    succeeded = response.status_code == 200
    if not succeeded:
        logger.error(
            "Failure from verifier (%s) for oneshot data to verifier: %s",
            verifier_url + oneshot_url,
            response.status_code,
        )

    try:
        response_body = response.json()
    except ValueError as e:
        # JSON decode errors are ValueError subclasses; fall back to the raw text.
        logger.error("Failed to decode JSON payload from verifier: %s", e)
        logger.info("Verifier Response Body: %s", response.text)
    else:
        logger.info("Verifier Response Body: %s", response_body)

    if not succeeded:
        sys.exit(1)


if __name__ == "__main__":
    # Pass main()'s return value to sys.exit(): a None result exits with status 0,
    # the same as falling off the end of the script.
    sys.exit(main())
