#!/usr/bin/env python3

"""
SPDX-License-Identifier: Apache-2.0
Copyright 2025 Keylime Project

A simple script to gather data for a one-shot attestation and send it to the verifier
"""

import argparse
import base64
import json
import logging
import os
import re
import secrets
import ssl
import string
import sys
import tempfile
import time

import requests
from packaging.version import Version

from keylime import cmd_exec, config, crypto
from keylime.requests_client import RequestsClient

# Verifier REST API version; interpolated into the request path as "/v<API_VERSION>/...".
API_VERSION = 2.5
# Set to True by the --debug flag; when True the temporary AK context file is kept on disk.
DEBUG = False
# Exit status treated as success for tpm2-tools commands.
# (The name is misspelled, but it is referenced throughout this file.)
EXIT_SUCESS = 0
# Maximum number of tpm2_quote attempts, and the delay in seconds between them,
# when a PCR race condition is reported.
MAX_RETRIES = 10
RETRY_INTERVAL = 3
# Paths of the IMA runtime measurement log and the measured boot event log.
IMA_ML = "/sys/kernel/security/ima/ascii_runtime_measurements"
MEASUREDBOOT_ML = "/sys/kernel/security/tpm0/binary_bios_measurements"

# various metadata about the TPM that we save
# NOTE(review): TPM_PUBLIC_KEY is not read or assigned anywhere in the visible code.
TPM_PUBLIC_KEY = None
# Integer EK handle, set by setup_ek_handle().
EK_HANDLE = None
# Base64-encoded public EK, set by setup_ek_handle().
EK = None
# Randomly generated AK authorization password, set by setup_ak_handle().
AK_PW = None
# Base64-encoded public AK, set by setup_ak_handle().
AK = None
# Path of the temporary AK context file, set by setup_ak_handle().
AK_CONTEXT_PATH = None

# Configure the root logger at DEBUG level with a "LEVEL - message" format.
logging.basicConfig(level=logging.DEBUG, format="%(levelname)s - %(message)s")
logger = logging.getLogger("keylime_oneshot_attestation")


def random_password(length=20) -> str:
    """Return a random alphanumeric string of ``length`` characters.

    Each character is drawn uniformly from the ASCII letters and digits with
    ``secrets.choice``, which uses the operating system's CSPRNG. This avoids
    the modulo bias of mapping raw random bytes onto the 62-character alphabet
    with ``byte % 62`` (256 is not a multiple of 62, so some characters would
    be more likely than others).

    In this file the result is used both as the AK authorization password and
    as the quote nonce.

    Args:
        length: Number of characters to generate. Values <= 0 yield "".

    Returns:
        The generated string.
    """
    alphabet = string.ascii_uppercase + string.digits + string.ascii_lowercase
    return "".join(secrets.choice(alphabet) for _ in range(length))


# get the IMA log file data (the whole log as one string)
def get_ima_measurement_list():
    """Read the IMA ASCII runtime measurement log.

    Returns:
        The full text of ``IMA_ML`` as a single ``str`` when the file exists.
        When it does not exist an empty list is returned; that fallback value
        is kept unchanged so the payload sent to the verifier stays the same.
    """
    if not os.path.exists(IMA_ML):
        return []

    # The context manager guarantees the handle is closed after reading.
    with open(IMA_ML, "r", encoding="utf-8") as ima_fh:
        return ima_fh.read()


# get the measured boot binary file data as base64 encoded data
def get_mb_log():
    """Return the measured boot event log, base64-encoded as text.

    Returns:
        The base64 encoding (as a ``str``) of the bytes in ``MEASUREDBOOT_ML``,
        or ``None`` when that file does not exist.
    """
    if not os.path.exists(MEASUREDBOOT_ML):
        return None

    with open(MEASUREDBOOT_ML, "rb") as log_file:
        raw_log = log_file.read()
    return base64.b64encode(raw_log).decode("utf-8")


def get_cmd_env():
    """Build the environment used when running tpm2-tools commands.

    Returns:
        A copy of the current process environment in which ``TPM2TOOLS_TCTI``
        falls back to ``device:/dev/tpmrm0`` unless it is already defined.
    """
    cmd_env = dict(os.environ)
    # setdefault leaves an existing TPM2TOOLS_TCTI setting untouched.
    cmd_env.setdefault("TPM2TOOLS_TCTI", "device:/dev/tpmrm0")
    return cmd_env


def tpm_cmd(cmd=None, expectedcode=EXIT_SUCESS, raiseOnError=True, outputpaths=None):
    """Run a tpm2-tools command through ``cmd_exec.run``.

    A ``tpm2_quote`` invocation is retried (up to ``MAX_RETRIES`` attempts,
    sleeping ``RETRY_INTERVAL`` seconds between them) when its stderr reports
    a PCR composite validation error.

    Args:
        cmd: Argument list. The first element is the tool name, with or
            without the ``tpm2_`` prefix. The caller's list is not modified.
        expectedcode: Exit status regarded as success.
        raiseOnError: When true, raise if the exit status differs from
            ``expectedcode``.
        outputpaths: A path, or list of paths, handed to ``cmd_exec.run``.

    Returns:
        The dict returned by ``cmd_exec.run``. This file reads its "code",
        "retout", "reterr" and "fileouts" entries.

    Raises:
        ValueError: If ``cmd`` is empty or missing.
        Exception: If the exit status is not ``expectedcode`` and
            ``raiseOnError`` is true.
    """
    if not cmd:
        raise ValueError("tpm_cmd requires a non-empty command list")

    env = get_cmd_env()

    # Convert single outputpath to list
    if isinstance(outputpaths, str):
        outputpaths = [outputpaths]

    # Work on a copy so that adding the prefix never mutates the caller's list.
    cmd = list(cmd)
    if not cmd[0].startswith("tpm2_"):
        cmd[0] = "tpm2_" + cmd[0]

    numtries = 0
    while True:
        # NOTE(review): the logged command line includes any passwords passed as arguments.
        logger.info('Executing command "%s"', " ".join(cmd))
        cmd_return = cmd_exec.run(
            cmd=cmd, expectedcode=expectedcode, raiseOnError=False, outputpaths=outputpaths, env=env
        )
        rc = cmd_return["code"]
        retout = cmd_return["retout"]
        reterr = cmd_return["reterr"]

        # keep trying to get quote if a PCR race condition occurred in quote
        if cmd[0] == "tpm2_quote" and cmd_exec.list_contains_substring(
            reterr, "Error validating calculated PCR composite with quote"
        ):
            numtries += 1
            if numtries >= MAX_RETRIES:
                logger.error("Agent did not return proper quote due to PCR race condition.")
                break
            logger.info(
                "Failed to get quote %d/%d times, trying again in %f seconds...", numtries, MAX_RETRIES, RETRY_INTERVAL
            )
            time.sleep(RETRY_INTERVAL)
            continue

        break

    # Don't bother continuing if TPM call failed and we're raising on error
    if rc != expectedcode and raiseOnError:
        raise Exception(f"Command: {cmd} returned {rc}, expected {expectedcode}, output {retout}, stderr {reterr}")

    return cmd_return


def check_tpm2_tools_version():
    """Check that the installed tpm2-tools release is at least 5.5.

    Runs ``tpm2_startup --version``, extracts the ``version="..."`` field from
    its output and compares the part before the first "-" against 5.5.
    On any failure (command error, missing or unparsable version, version too
    old) an error is logged and the process exits with status 1.
    """
    # raiseOnError=False so that a failing command reaches the logged-error
    # path below instead of tpm_cmd raising before the check can run.
    cmd_return = tpm_cmd(["startup", "--version"], raiseOnError=False)
    rc = cmd_return["code"]
    output = "".join(config.convert(cmd_return["retout"]))
    errout = "".join(config.convert(cmd_return["reterr"]))
    if rc != EXIT_SUCESS:
        logger.error("Error establishing tpm2-tools version using TPM2_Startup: %s: %s", rc, errout)
        sys.exit(1)

    # Extract the `version="x.x.x"` from tools
    version_match = re.search(r'version="([^"]+)"', output)
    if version_match is None:
        logger.error("Could not determine tpm2-tools version from TPM2_Startup output '%s'", output)
        sys.exit(1)

    # Keep only the release number that precedes any "-suffix".
    release = version_match.group(1).split("-")[0]
    logger.info("TPM2-TOOLS Version: %s", release)

    try:
        supported = Version(release) >= Version("5.5")
    except ValueError:
        # packaging raises InvalidVersion (a ValueError subclass) for unparsable strings.
        logger.error("Could not parse tpm2-tools version '%s'", release)
        sys.exit(1)

    if not supported:
        logger.error("TPM2-TOOLS Version %s is not supported by this utility.", release)
        sys.exit(1)


def start_tpm() -> None:
    """Run ``tpm2_startup -c`` and fail if it does not succeed.

    Raises:
        Exception: If the command exits with a status other than
            ``EXIT_SUCESS``; the message carries the status and stderr.
    """
    # raiseOnError=False so the failure is reported with the message below
    # rather than by tpm_cmd's generic exception.
    startup_return = tpm_cmd(["startup", "-c"], raiseOnError=False)
    rc = startup_return["code"]
    if rc != EXIT_SUCESS:
        errout = config.convert(startup_return["reterr"])
        raise Exception(f"Error initializing emulated TPM with TPM2_Startup: {rc}: {errout}")


def setup_ek_handle(tpm_endorsement_pw, tpm_owner_pw, ek_handle, encrypt_algo):
    """Create a new EK or load an existing one, and record it in module globals.

    On success ``EK_HANDLE`` holds the integer handle and ``EK`` holds the
    base64-encoded public EK read from the command's output file.

    Args:
        tpm_endorsement_pw: Passed to ``tpm2_createek`` as ``-P``.
        tpm_owner_pw: Passed to ``tpm2_createek`` as ``-w``.
        ek_handle: Hexadecimal handle of an existing EK. An empty value or
            the string "generate" requests creation of a new EK.
        encrypt_algo: Key algorithm passed to ``tpm2_createek`` as ``-G``.

    Raises:
        ValueError: If ``ek_handle`` is not a valid hexadecimal number.
        Exception: If the TPM command fails or its YAML output cannot be
            parsed or lacks a ``persistent-handle`` entry.
    """
    global EK_HANDLE, EK

    # do we need to generate a new EK or use an existing one
    if not ek_handle or ek_handle == "generate":
        logger.info("Removing all saved sessions from TPM")
        tpm_cmd(["flushcontext", "-s"], raiseOnError=False)

        # create a new EK handle
        with tempfile.NamedTemporaryFile() as tmppath:
            cmd = [
                "tpm2_createek",
                "-c",
                "-",
                "-G",
                encrypt_algo,
                "-u",
                tmppath.name,
                "-w",
                str(tpm_owner_pw),
                "-P",
                str(tpm_endorsement_pw),
            ]
            cmd_return = tpm_cmd(cmd, raiseOnError=False, outputpaths=tmppath.name)
            rc = cmd_return["code"]

            # Check the exit status before touching any output, so a failed
            # command is reported as such rather than as a missing-output error.
            if rc != EXIT_SUCESS:
                raise Exception("tpm2_createek failed with code " + str(rc) + ": " + str(cmd_return["reterr"]))

            retyaml = config.yaml_to_dict(cmd_return["retout"], logger=logger)
            if retyaml is None:
                raise Exception("Could not read YAML output of tpm2_createek.")
            if "persistent-handle" not in retyaml:
                raise Exception("No persistent-handle in YAML output of tpm2_createek.")

            # Globals are only updated once every check has passed.
            ek_binary = cmd_return["fileouts"][tmppath.name]
            EK = base64.b64encode(ek_binary).decode("ascii")
            EK_HANDLE = retyaml["persistent-handle"]
            logger.info("Created EK with handle: %s", hex(EK_HANDLE))

            # Make sure that all transient objects are flushed
            tpm_cmd(["flushcontext", "-t"], raiseOnError=False)
    else:
        # use an existing EK handle
        EK_HANDLE = int(ek_handle, 16)
        logger.info("Using an already created EK with handle: %s", hex(EK_HANDLE))

        with tempfile.NamedTemporaryFile() as tmppath:
            cmd = ["readpublic", "-c", hex(EK_HANDLE), "-o", tmppath.name, "-f", "tss"]
            cmd_return = tpm_cmd(cmd, raiseOnError=False, outputpaths=tmppath.name)
            rc = cmd_return["code"]

            # As above: report a command failure before reading its output.
            if rc != EXIT_SUCESS:
                raise Exception("tpm2_readpublic failed with code " + str(rc) + ": " + str(cmd_return["reterr"]))

            ek_binary = cmd_return["fileouts"][tmppath.name]
            EK = base64.b64encode(ek_binary).decode("ascii")


def setup_ak_handle(tpm_endorsement_pw, tpm_encryption_algorithm, tpm_signing_algorithm, tpm_hash_algorithm):
    """Create a new AK under the current EK and record it in module globals.

    On success ``AK_PW`` holds the generated AK password, ``AK`` the
    base64-encoded public AK and ``AK_CONTEXT_PATH`` the path of the AK
    context file. That file is deliberately left on disk for the later
    ``tpm2_quote`` call; the caller is responsible for deleting it. If
    anything fails here, the context file is removed before the error
    propagates.

    Args:
        tpm_endorsement_pw: Passed to ``tpm2_createak`` as ``-P``.
        tpm_encryption_algorithm: Passed as ``-G``.
        tpm_signing_algorithm: Passed as ``-s``.
        tpm_hash_algorithm: Passed as ``-g``.

    Raises:
        Exception: If ``tpm2_createak`` fails, its YAML output cannot be
            parsed or lacks ``loaded-key``, or no public AK was written.
    """
    global AK_PW, AK, AK_CONTEXT_PATH
    logger.info("Creating a new AK identity")

    AK_PW = random_password(20)
    # make a temp file for the output
    with tempfile.NamedTemporaryFile() as akpubfile, tempfile.NamedTemporaryFile(delete=False) as akcontextfile:
        try:
            # ok lets write out the key now
            cmd = [
                "createak",
                "-C",
                hex(int(EK_HANDLE)),
                "-c",
                akcontextfile.name,
                "-G",
                tpm_encryption_algorithm,
                "-g",
                tpm_hash_algorithm,
                "-s",
                tpm_signing_algorithm,
                "-u",
                akpubfile.name,
                "-p",
                AK_PW,
                "-P",
                tpm_endorsement_pw,
            ]
            # raiseOnError=False so a failure is reported by the message below,
            # which does not echo the command line (and its passwords).
            cmd_return = tpm_cmd(cmd, raiseOnError=False, outputpaths=akpubfile.name)
            reterr = cmd_return["reterr"]
            rc = cmd_return["code"]

            if rc != EXIT_SUCESS:
                raise Exception("tpm2_createak failed with code " + str(rc) + ": " + str(reterr))

            jsonout = config.yaml_to_dict(cmd_return["retout"], logger=logger)
            if jsonout is None:
                raise Exception(
                    "Unable to parse YAML output of tpm2_createak. Is your tpm2-tools installation up to date?"
                )
            ak_binary = cmd_return["fileouts"][akpubfile.name]
            if not ak_binary:
                raise Exception(
                    "Unable to read public AK from create identity. Is your tpm2-tools installation up to date?"
                )
            if "loaded-key" not in jsonout:
                raise Exception("tpm2_createak failed to create AK: return " + str(reterr))

            AK = base64.b64encode(ak_binary).decode("ascii")
            AK_CONTEXT_PATH = akcontextfile.name
        except BaseException:
            # The context file was created with delete=False, so nothing else
            # would remove it when AK creation does not complete.
            if os.path.isfile(akcontextfile.name):
                os.unlink(akcontextfile.name)
            raise

    # Make sure that all transient objects are flushed
    tpm_cmd(["flushcontext", "-t"], raiseOnError=False)


def create_quote(tpm_endorsement_pw, tpm_owner_pw, ek_handle, nonce, encrypt_algo, signing_algo, hash_algo):
    """Set up the EK and AK, then obtain a TPM quote over PCRs 0-15.

    The temporary AK context file is removed afterwards, whether or not the
    quote succeeded, unless ``DEBUG`` is set.

    Args:
        tpm_endorsement_pw: Endorsement hierarchy password.
        tpm_owner_pw: Owner hierarchy password.
        ek_handle: Existing EK handle (hex string), or empty/"generate".
        nonce: String whose UTF-8 bytes are hex-encoded and passed to
            ``tpm2_quote`` as the qualifying data (``-q``).
        encrypt_algo: Key algorithm for the EK and AK.
        signing_algo: Signing algorithm for the AK.
        hash_algo: Hash algorithm for the AK, the PCR bank and the quote.

    Returns:
        The string "r" followed by the base64-encoded quote, signature and
        PCR blobs joined with ":".

    Raises:
        Exception: Propagated from EK/AK setup or from a failed quote command.
    """
    setup_ek_handle(tpm_endorsement_pw, tpm_owner_pw, ek_handle, encrypt_algo)
    setup_ak_handle(tpm_endorsement_pw, encrypt_algo, signing_algo, hash_algo)

    try:
        with tempfile.NamedTemporaryFile() as quotepath, \
                tempfile.NamedTemporaryFile() as sigpath, \
                tempfile.NamedTemporaryFile() as pcrpath:
            # get the quote from the TPM
            cmd = [
                "quote",
                "-c",
                AK_CONTEXT_PATH,
                "-l",
                hash_algo + ":0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15",
                "-q",
                bytes(nonce, encoding="utf8").hex(),
                "-m",
                quotepath.name,
                "-s",
                sigpath.name,
                "-o",
                pcrpath.name,
                "-g",
                hash_algo,
                "-p",
                AK_PW,
            ]

            # The command is not printed here: it contains the AK password.
            cmd_return = tpm_cmd(cmd, outputpaths=[quotepath.name, sigpath.name, pcrpath.name])
            # Make sure that all transient objects are flushed
            tpm_cmd(["tpm2_flushcontext", "-t"], raiseOnError=False)

            fileouts = cmd_return["fileouts"]
            quote = ":".join(
                base64.b64encode(fileouts[path]).decode("utf-8")
                for path in (quotepath.name, sigpath.name, pcrpath.name)
            )
    finally:
        # get rid of the AK context file even when quoting failed
        if DEBUG:
            logger.info("Keeping temporary AK context file: %s", AK_CONTEXT_PATH)
        else:
            logger.info("Deleteing temporary AK file(s)")
            if AK_CONTEXT_PATH and os.path.isfile(AK_CONTEXT_PATH):
                os.unlink(AK_CONTEXT_PATH)

    return "r" + quote


def main():
    """Collect one-shot attestation evidence and submit it to the verifier.

    Parses the command line, checks the tpm2-tools version, starts the TPM,
    creates a quote over a fresh nonce, reads the IMA and measured boot logs,
    and POSTs everything as JSON to ``/v<API_VERSION>/verify/evidence`` on the
    verifier over TLS. The verifier's JSON response is logged.

    Exits with status 1 when no policy is given, when the request cannot be
    sent, when the verifier answers with a status other than 200, or when
    its response body is not valid JSON.
    """
    global DEBUG
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--ek-handle",
        dest="ek_handle",
        required=False,
        help="The lable for the TPM's EK handle to reuse. Leave blank or set to 'generate' to generate a new EK handle.",
        default="generate",
    )
    arg_parser.add_argument(
        "--tpm-owner-pw",
        dest="tpm_owner_pw",
        required=False,
        help="The password for the TPM Owner hierarchy",
        default="",
    )
    arg_parser.add_argument(
        "--tpm-endorsement-pw",
        dest="tpm_endorsement_pw",
        required=False,
        help="The password for the TPM Endorsement hierarchy",
        default="",
    )
    arg_parser.add_argument(
        "--tpm-encryption-algorithm",
        dest="tpm_encryption_algorithm",
        required=False,
        help="The encryption algorithm to be used by the TPM. Defaults to 'rsa'",
        default="rsa",
    )
    arg_parser.add_argument(
        "--tpm-signing-algorithm",
        dest="tpm_signing_algorithm",
        required=False,
        help="The signing algorithm to be used by the TPM. Defaults to 'rsassa'",
        default="rsassa",
    )
    arg_parser.add_argument(
        "--tpm-hash-algorithm",
        dest="tpm_hash_algorithm",
        required=False,
        help="The hash algorithm to be used by the TPM. Defaults to 'sha256'",
        default="sha256",
    )
    arg_parser.add_argument(
        "--runtime-policy",
        dest="runtime_policy",
        type=argparse.FileType("r"),
        required=False,
        help="File for the keylime runtime policy.",
    )
    arg_parser.add_argument(
        "--tpm-policy",
        dest="tpm_policy",
        type=argparse.FileType("r"),
        required=False,
        help="File for the keylime TPM policy.",
    )
    arg_parser.add_argument(
        "--mb-policy",
        dest="mb_policy",
        type=argparse.FileType("r"),
        required=False,
        help="File for the keylime measured boot policy.",
    )
    arg_parser.add_argument("--debug", dest="debug", action="store_true", required=False, help="Show debug messages")

    arg_parser.add_argument(
        "--verifier-host",
        dest="verifier_host",
        required=True,
        help="The hostname of the Keylime verifier server",
    )
    arg_parser.add_argument(
        "--verifier-port",
        dest="verifier_port",
        required=False,
        # Matches the default documented in the help text below.
        default="8881",
        help="The port of the Keylime verifier server. Defaults to 8881",
    )
    arg_parser.add_argument(
        "--verifier-cacert",
        dest="verifier_cacert",
        required=False,
        default="/var/lib/keylime/cv_ca/cacert.crt",
        help="CA certificate for verifying the Keylime Verifier's TLS certificate. Defaults to /var/lib/keylime/cv_ca/cacert.crt",
    )

    args = arg_parser.parse_args()

    if args.debug:
        DEBUG = True
        logger.info("Set logging to DEBUG")
        logger.setLevel(logging.DEBUG)
    else:
        # Logging is configured at DEBUG when this module is loaded; without
        # --debug keep this script's own logger at INFO so that the flag
        # actually controls whether debug messages are shown.
        logger.setLevel(logging.INFO)

    runtime_policy = ""
    tpm_policy = ""
    mb_policy = ""

    if args.runtime_policy:
        runtime_policy = args.runtime_policy.read()
    if args.tpm_policy:
        tpm_policy = args.tpm_policy.read()
    if args.mb_policy:
        mb_policy = args.mb_policy.read()

    if not runtime_policy and not tpm_policy and not mb_policy:
        logger.error("At least one policy (--runtime-policy, --tpm-policy, or --mb-policy) must be provided")
        sys.exit(1)

    check_tpm2_tools_version()
    start_tpm()

    nonce = random_password(20)

    tpm_owner_pw = args.tpm_owner_pw
    tpm_endorsement_pw = args.tpm_endorsement_pw
    hash_algo = args.tpm_hash_algorithm

    quote = create_quote(
        tpm_endorsement_pw,
        tpm_owner_pw,
        args.ek_handle,
        nonce,
        args.tpm_encryption_algorithm,
        args.tpm_signing_algorithm,
        hash_algo,
    )
    logger.info("TPM Quote: %s", quote)
    ima_measurement_list = get_ima_measurement_list()
    mb_log = get_mb_log()

    # send all of the data to the verifier
    oneshot_tpm_data = {
        "quote": quote,
        "nonce": nonce,
        "hash_alg": hash_algo,
        "tpm_ek": EK,
        "tpm_ak": AK,
        "tpm_policy": tpm_policy,
        "runtime_policy": runtime_policy,
        "mb_policy": mb_policy,
        "ima_measurement_list": ima_measurement_list,
        "mb_log": mb_log,
    }

    oneshot_data = {
        "type": "tpm",
        "data": oneshot_tpm_data,
    }

    verifier_url = f"https://{args.verifier_host}:{args.verifier_port}"
    oneshot_url = f"/v{API_VERSION}/verify/evidence"
    response = None

    # The verifier's certificate must chain to the given CA, but its hostname
    # is not checked.
    tls_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
    tls_context.check_hostname = False
    if sys.version_info >= (3, 7):
        tls_context.minimum_version = ssl.TLSVersion.TLSv1_2  # pylint: disable=E1101
    else:
        tls_context.options &= ~ssl.OP_NO_TLSv1_2
    tls_context.load_verify_locations(cafile=args.verifier_cacert)
    tls_context.verify_mode = ssl.CERT_REQUIRED
    try:
        logger.debug("Sending data to the verifier at %s: %s", verifier_url + oneshot_url, oneshot_data)
        client = RequestsClient(base_url=verifier_url, tls_enabled=True, tls_context=tls_context, ignore_hostname=True)

        response = client.post(
            oneshot_url,
            data=json.dumps(oneshot_data),
            timeout=5,
            verify=args.verifier_cacert,
        )
    except Exception as e:
        logger.error("Failed to send oneshot data to verifier at %s: %s", verifier_url, e)
        sys.exit(1)

    logger.info("Verifier HTTP Response Code: %s", response.status_code)
    if response.status_code != 200:
        logger.error(
            "Failure from verifier (%s) for oneshot data to verifier: %s",
            verifier_url + oneshot_url,
            response.status_code,
        )
        sys.exit(1)

    try:
        response_body = response.json()
        logger.info("Verifier Response Body: %s", response_body)
    except Exception as e:
        logger.error("Failed to decode JSON payload from verifier: %s", e)
        sys.exit(1)


# Run the attestation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
