#! /bin/bash -x
# Copyright (c) 2018-2020, NVIDIA CORPORATION. All rights reserved.

# Abort on any unhandled command failure (-e) and on expansion of an unset variable (-u).
set -eu

# Base runtime directory; the driver rootfs is bind-mounted at ${RUN_DIR}/driver.
RUN_DIR=/run/nvidia
# PID/lock file named after this script's basename.
PID_FILE=${RUN_DIR}/${0##*/}.pid
# Required inputs: the script stops with the given message when either is unset or empty.
DRIVER_VERSION=${DRIVER_VERSION:?"Missing DRIVER_VERSION env"}
DRIVER_BRANCH=${DRIVER_BRANCH:?"Missing DRIVER_BRANCH env"}
# Optional; compared against "vgpu" later in the script.
DRIVER_TYPE=${DRIVER_TYPE:-""}
# Overwritten by _find_vgpu_driver_version with the device count reported by vgpu-util.
NUM_VGPU_DEVICES=0
# Extra 'modprobe' arguments per kernel module, filled in by _get_module_params.
NVIDIA_MODULE_PARAMS=()
NVIDIA_UVM_MODULE_PARAMS=()
NVIDIA_MODESET_MODULE_PARAMS=()
NVIDIA_PEERMEM_MODULE_PARAMS=()
# One of 'auto' (default), 'open' or 'proprietary'; consumed by _resolve_kernel_type.
KERNEL_MODULE_TYPE=${KERNEL_MODULE_TYPE:-auto}

# Translate the target architecture name: amd64 -> x86_64, arm64 -> aarch64.
# Falls back to the machine name reported by 'uname -m' when TARGETARCH is not provided.
TARGETARCH=${TARGETARCH:-$(uname -m)}
DRIVER_ARCH=${TARGETARCH/amd64/x86_64}
DRIVER_ARCH=${DRIVER_ARCH/arm64/aarch64}
echo "DRIVER_ARCH is $DRIVER_ARCH"

_update_ca_certificates() {
    # Refresh the system trust store, but only when the anchors directory
    # exists and holds at least one entry; otherwise succeed without doing anything.
    local anchors_dir=/etc/pki/trust/anchors

    [ -d "${anchors_dir}" ] || return 0
    [ -n "$(ls -A "${anchors_dir}")" ] || return 0
    update-ca-certificates
}

_assert_nvswitch_system() {
    # Check whether the driver exposes any NVSwitch device entries.
    # Arguments: $1 - optional directory to inspect
    #                 (defaults to /proc/driver/nvidia-nvswitch)
    # Returns:   0 when 'ls' lists at least one entry under <dir>/devices/,
    #            1 when the directory is missing or nothing is listed.
    local proc_dir="${1:-/proc/driver/nvidia-nvswitch}"
    # Declared local so the listing does not leak into the global namespace.
    local entries

    [ -d "${proc_dir}" ] || return 1
    # 'ls' fails when the glob matches nothing; that is an expected outcome
    # here, so it must not trip 'set -e'.
    entries=$(ls -1 "${proc_dir}"/devices/* 2>/dev/null) || true
    if [ -z "${entries}" ]; then
        return 1
    fi
    return 0
}

# For each kernel module configuration file mounted into the container,
# parse the file contents and extract the custom module parameters that
# are to be passed as input to 'modprobe'.
#
# Assumptions:
# - Configuration files are named <module-name>.conf (i.e. nvidia.conf, nvidia-uvm.conf).
# - Configuration files are mounted inside the container at /drivers.
# - Each line in the file contains at least one parameter, where parameters on the same line
#   are space delimited. It is up to the user to properly format the file to ensure
#   the correct set of parameters are passed to 'modprobe'.
_get_module_params() {
    # Arguments: $1 - optional directory containing the <module-name>.conf files
    #                 (defaults to /drivers).
    # Globals:   DRIVER_BRANCH (read); NVIDIA_MODULE_PARAMS, NVIDIA_UVM_MODULE_PARAMS,
    #            NVIDIA_MODESET_MODULE_PARAMS, NVIDIA_PEERMEM_MODULE_PARAMS (appended to).
    # Outputs:   one summary line on stdout for every configuration file found.
    # Returns:   0.
    local base_path="${1:-/drivers}"
    local module conf_file summary
    # Scratch arrays are local so that nothing leaks into the caller's scope.
    local -a params=() collected=()

    # Starting from R580, we need to enable the CDMM (Coherent Driver Memory Management) module parameter.
    # This prevents the GPU memory for coherent systems (GH200, GB200 etc) from being exposed as a NUMA node
    # and thereby preventing over-reporting of a Kubernetes node's memory. This is needed for Kubernetes use-cases
    if [[ "${DRIVER_BRANCH}" -ge 580 ]]; then
        NVIDIA_MODULE_PARAMS+=("NVreg_CoherentGPUMemoryMode=driver")
    fi

    for module in nvidia nvidia-uvm nvidia-modeset nvidia-peermem; do
        conf_file="${base_path}/${module}.conf"
        [ -f "${conf_file}" ] || continue

        # Split every line on whitespace. The '||' clause keeps a final line
        # that is not terminated by a newline.
        collected=()
        while read -r -a params || [ "${#params[@]}" -gt 0 ]; do
            # Blank lines produce no words; never expand an empty array.
            if [ "${#params[@]}" -gt 0 ]; then
                collected+=("${params[@]}")
            fi
        done <"${conf_file}"

        # Append to the array that belongs to this module and build the
        # summary from the array's full contents.
        case "${module}" in
            nvidia)
                if [ "${#collected[@]}" -gt 0 ]; then
                    NVIDIA_MODULE_PARAMS+=("${collected[@]}")
                fi
                summary="${NVIDIA_MODULE_PARAMS[*]-}"
                ;;
            nvidia-uvm)
                if [ "${#collected[@]}" -gt 0 ]; then
                    NVIDIA_UVM_MODULE_PARAMS+=("${collected[@]}")
                fi
                summary="${NVIDIA_UVM_MODULE_PARAMS[*]-}"
                ;;
            nvidia-modeset)
                if [ "${#collected[@]}" -gt 0 ]; then
                    NVIDIA_MODESET_MODULE_PARAMS+=("${collected[@]}")
                fi
                summary="${NVIDIA_MODESET_MODULE_PARAMS[*]-}"
                ;;
            nvidia-peermem)
                if [ "${#collected[@]}" -gt 0 ]; then
                    NVIDIA_PEERMEM_MODULE_PARAMS+=("${collected[@]}")
                fi
                summary="${NVIDIA_PEERMEM_MODULE_PARAMS[*]-}"
                ;;
        esac
        echo "Module parameters provided for ${module}: ${summary}"
    done
    return 0
}

# Load the kernel modules and start persistenced.
_load_driver() {
    # Globals read: RUN_DIR, DRIVER_VERSION, DRIVER_TYPE, KERNEL_TYPE,
    #               VGPU_LICENSE_SERVER_TYPE (optional), NVIDIA_MODULE_PARAMS,
    #               NVIDIA_UVM_MODULE_PARAMS, NVIDIA_MODESET_MODULE_PARAMS.
    # Returns 0 once every step has run; the script runs under 'set -e', so an
    # unhandled command failure aborts it instead of returning.
    local param
    local current_kernel

    echo "Parsing kernel module parameters..."
    _get_module_params

    local nv_fw_search_path="$RUN_DIR/driver/lib/firmware"
    local set_fw_path="true"
    local fw_path_config_file="/sys/module/firmware_class/parameters/path"
    # An explicit NVreg_EnableGpuFirmware=0 disables the firmware path setup below.
    for param in "${NVIDIA_MODULE_PARAMS[@]}"; do
        if [[ "$param" == "NVreg_EnableGpuFirmware=0" ]]; then
            set_fw_path="false"
        fi
    done

    if [[ "$set_fw_path" == "true" ]]; then
        echo "Configuring the following firmware search path in '$fw_path_config_file': $nv_fw_search_path"
        # Any non-blank content means a search path is already configured; keep it.
        if grep -q '[^[:space:]]' "$fw_path_config_file"; then
            echo "WARNING: A search path is already configured in $fw_path_config_file"
            echo "         Retaining the current configuration"
        else
            echo -n "$nv_fw_search_path" > "$fw_path_config_file" || echo "WARNING: Failed to configure the firmware search path"
        fi
        if [ -d "/opt/lib/firmware/nvidia/${DRIVER_VERSION}" ]; then
            # ':?' guards make sure the destructive 'rm -rf' never runs on an empty path.
            rm -rf "${nv_fw_search_path:?}/nvidia/${DRIVER_VERSION:?}"
            mkdir -p "${nv_fw_search_path:?}/nvidia/${DRIVER_VERSION:?}"
            cp "/opt/lib/firmware/nvidia/${DRIVER_VERSION}"/gsp_*.bin "${nv_fw_search_path:?}/nvidia/${DRIVER_VERSION:?}"
        fi
    fi

    # Declared separately from the assignment so a failing 'uname' is not masked by 'local'.
    current_kernel=$(uname -r)
    cp -rfx "/run/host/lib/modules/${current_kernel}" /lib/modules
    mkdir -p "/lib/modules/${current_kernel}/updates"
    cp -fx "/opt/${KERNEL_TYPE}/"* "/lib/modules/${current_kernel}/updates"
    # do not ship drm driver
    rm -f "/lib/modules/${current_kernel}/updates/nvidia-drm."*
    depmod -a

    echo "Loading ipmi and i2c_core kernel modules..."
    modprobe -a i2c_core ipmi_msghandler ipmi_devintf


    echo "Loading NVIDIA driver kernel modules..."
    # nounset is relaxed around the modprobe calls, presumably so that empty
    # parameter arrays expand cleanly on older bash releases (confirm).
    set -o xtrace +o nounset
    modprobe nvidia "${NVIDIA_MODULE_PARAMS[@]}"
    modprobe nvidia-uvm "${NVIDIA_UVM_MODULE_PARAMS[@]}"
    modprobe nvidia-modeset "${NVIDIA_MODESET_MODULE_PARAMS[@]}"
    set +o xtrace -o nounset

    echo "Starting NVIDIA persistence daemon..."
    nvidia-persistenced --persistence-mode

    if [ "${DRIVER_TYPE}" = "vgpu" ]; then
        echo "Copying gridd.conf..."
        cp /drivers/gridd.conf /etc/nvidia/gridd.conf
        # VGPU_LICENSE_SERVER_TYPE is optional; default to empty so that an
        # unset variable does not abort the script under 'set -u'.
        if [ "${VGPU_LICENSE_SERVER_TYPE:-}" = "NLS" ]; then
            echo "Copying ClientConfigToken..."
            mkdir -p  /etc/nvidia/ClientConfigToken/
            cp /drivers/ClientConfigToken/* /etc/nvidia/ClientConfigToken/
        fi

        echo "Starting nvidia-gridd.."
        LD_LIBRARY_PATH=/usr/lib64/nvidia/gridd nvidia-gridd

        # Start virtual topology daemon
        _start_vgpu_topology_daemon
    fi

    if _assert_nvswitch_system; then
        echo "Starting NVIDIA fabric manager daemon..."
        nv-fabricmanager -c /usr/share/nvidia/nvswitch/fabricmanager.cfg
        echo "Starting NVLink Subnet Manager daemon..."
        nvlsm -c /usr/share/nvidia/nvswitch/fabricmanager.cfg
    fi
    return 0
}

# Stop a daemon by sending SIGTERM and waiting for it to exit.
_stop_daemon() {
    # Arguments: $1 - human-readable daemon name used in messages
    #            $2 - path of the daemon's PID file
    # Returns:   0 when there is no PID file or the process is gone,
    #            1 when the process is still alive after the grace period.
    local name="$1"
    local pid_file="$2"
    local pid attempt

    # No PID file: nothing to stop.
    [ -f "${pid_file}" ] || return 0

    echo "Stopping ${name}..."
    pid=$(< "${pid_file}")

    # The process may already be gone; a failing 'kill' is not an error here.
    kill -SIGTERM "${pid}" 2>/dev/null || true
    # Poll for up to 50 x 0.1s for the process to disappear.
    for (( attempt = 1; attempt <= 50; attempt++ )); do
        kill -0 "${pid}" 2>/dev/null || break
        sleep 0.1
    done
    if kill -0 "${pid}" 2>/dev/null; then
        echo "Could not stop ${name}" >&2
        return 1
    fi
    return 0
}

# Stop persistenced and unload the kernel modules if they are currently loaded.
_unload_driver() {
    # Returns 1 when a daemon cannot be stopped or a module is still in use,
    # 0 otherwise (including when no NVIDIA module is loaded).
    local -a to_remove=()
    local dependents=0
    local core_refs=0
    local dep_in_use=0
    local dep refs

    # Stop the daemons first; give up on the first one that refuses to stop.
    _stop_daemon "NVIDIA persistence daemon" /var/run/nvidia-persistenced/nvidia-persistenced.pid || return 1
    _stop_daemon "NVIDIA grid daemon" /var/run/nvidia-gridd/nvidia-gridd.pid || return 1
    _stop_daemon "NVIDIA fabric manager daemon" /var/run/nvidia-fabricmanager/nv-fabricmanager.pid || return 1
    _stop_daemon "NVLink Subnet Manager daemon" /var/run/nvidia-fabricmanager/nvlsm.pid || return 1

    echo "Unloading NVIDIA driver kernel modules..."
    # Modules layered on top of 'nvidia', in removal order. Each loaded one is
    # counted, because the core module's reference count is later compared
    # against the number of loaded dependents.
    for dep in nvidia_peermem nvidia_modeset nvidia_uvm; do
        [ -f "/sys/module/${dep}/refcnt" ] || continue
        refs=$(< "/sys/module/${dep}/refcnt")
        if [ "${refs}" -gt 0 ]; then
            dep_in_use=1
        fi
        # sysfs uses underscores, the module names passed to rmmod use dashes
        to_remove+=("${dep//_/-}")
        dependents=$((dependents + 1))
    done
    if [ -f /sys/module/nvidia/refcnt ]; then
        core_refs=$(< /sys/module/nvidia/refcnt)
        to_remove+=("nvidia")
    fi

    # More references on 'nvidia' than loaded dependents, or any reference on
    # a dependent, means something is still using the driver.
    if [ "${core_refs}" -gt "${dependents}" ] || [ "${dep_in_use}" -ne 0 ]; then
        echo "Could not unload NVIDIA driver kernel modules, driver is in use" >&2
        return 1
    fi

    [ "${#to_remove[@]}" -gt 0 ] || return 0
    rmmod "${to_remove[@]}"
    return 0
}

# Mount the driver rootfs into the run directory with the exception of sysfs.
_mount_rootfs() {
    local driver_root="${RUN_DIR}/driver"

    echo "Mounting NVIDIA driver rootfs..."
    # Adjust the propagation of /sys (runbindable, then private) before the
    # root filesystem is recursively bind-mounted onto the driver directory.
    mount --make-runbindable /sys
    mount --make-private /sys
    mkdir -p "${driver_root}"
    mount --rbind / "${driver_root}"

    echo "Check SELinux status"
    if [ ! -e /sys/fs/selinux ]; then
        echo "SELinux is disabled, skipping..."
        return 0
    fi
    # SELinux present: relabel the device files under the bind mount.
    echo "SELinux is enabled"
    echo "Change device files security context for selinux compatibility"
    chcon -R -t container_file_t "${driver_root}/dev"
}

# Unmount the driver rootfs from the run directory.
_unmount_rootfs() {
    local driver_root="${RUN_DIR}/driver"

    echo "Unmounting NVIDIA driver rootfs..."
    # Nothing to do unless some mount target mentions the driver directory.
    if ! findmnt -r -o TARGET | grep "${driver_root}" > /dev/null; then
        return 0
    fi
    # Lazy, recursive unmount of everything below the driver directory.
    umount -l -R "${driver_root}"
}

_shutdown() {
    # Tear down in order: unload the driver, unmount the rootfs, remove the PID file.
    # Returns 1, leaving the mount and the PID file untouched, when the driver
    # cannot be unloaded; returns 0 otherwise.
    _unload_driver || return 1
    _unmount_rootfs
    rm -f "${PID_FILE}"
    return 0
}

# _resolve_kernel_type determines which kernel module type, open or proprietary, to install.
# This function assumes that the nvidia-installer binary is in the PATH, so this function
# should only be invoked after the userspace driver components have been installed.
#
# KERNEL_MODULE_TYPE is the frontend interface that users can use to configure which module
# to install. Valid values for KERNEL_MODULE_TYPE are 'auto' (default), 'open', and 'proprietary'.
# When 'auto' is configured, we use the nvidia-installer to recommend the module type to install.
_resolve_kernel_type() {
  # Globals: KERNEL_MODULE_TYPE (read), KERNEL_TYPE (set to 'open' or 'proprietary').
  # Returns: 0 on success, 1 when KERNEL_MODULE_TYPE holds an unsupported value.
  # Kept local so the intermediate value does not leak into the global namespace.
  local recommended

  case "${KERNEL_MODULE_TYPE}" in
    proprietary | open)
      KERNEL_TYPE="${KERNEL_MODULE_TYPE}"
      ;;
    auto)
      # A missing or failing selector simply yields an empty recommendation,
      # which triggers the driver-branch fallback below.
      recommended=$(/usr/local/bin/nvidia-driver-selector.sh) || true
      if [ -z "${recommended}" ]; then
        echo "failed to retrieve the recommended kernel module type from nvidia-installer, falling back to using the driver branch"
        _resolve_kernel_type_from_driver_branch
        return 0
      fi
      # Anything other than an explicit 'open' recommendation maps to proprietary.
      if [[ "${recommended}" == "open" ]]; then
        KERNEL_TYPE=open
      else
        KERNEL_TYPE=proprietary
      fi
      ;;
    *)
      # Diagnostics go to stderr, like the other error paths in this script.
      echo "invalid value for the KERNEL_MODULE_TYPE variable: ${KERNEL_MODULE_TYPE}" >&2
      return 1
      ;;
  esac
  return 0
}

_resolve_kernel_type_from_driver_branch() {
  # Driver branches below 560 select the proprietary kernel modules;
  # 560 and newer select the open ones.
  if [[ "${DRIVER_BRANCH}" -lt 560 ]]; then
    KERNEL_TYPE=proprietary
  else
    KERNEL_TYPE=open
  fi
}

_find_vgpu_driver_version() {
    # Globals: DISABLE_VGPU_VERSION_CHECK (read, optional),
    #          NUM_VGPU_DEVICES (written), DRIVER_VERSION (written on a match).
    # Returns: 1 only when vgpu devices exist but 'vgpu-util match' fails;
    #          0 in every other case.
    local count=""
    local version=""
    local drivers_path="/drivers"

    # Default to empty so that an unset variable does not abort under 'set -u'.
    if [ "${DISABLE_VGPU_VERSION_CHECK:-}" = "true" ]; then
        echo "vgpu version compatibility check is disabled"
        return 0
    fi
    # check if vgpu devices are present
    # (the command is tested directly so the check also works under 'set -e')
    if ! count=$(vgpu-util count); then
        echo "cannot find vgpu devices on host, please check /var/log/vgpu-util.log for more details..."
        return 0
    fi
    # the value is taken from the text after '=' in the vgpu-util output
    NUM_VGPU_DEVICES=$(echo "$count" | awk -F= '{print $2}')
    if [ "${NUM_VGPU_DEVICES}" -eq 0 ]; then
        # no vgpu devices found, treat as passthrough
        return 0
    fi
    echo "found $NUM_VGPU_DEVICES vgpu devices on host"

    # find compatible guest driver using driver catalog
    if [ -d "/mnt/shared-nvidia-driver-toolkit/drivers" ]; then
        drivers_path="/mnt/shared-nvidia-driver-toolkit/drivers"
    fi
    if ! version=$(vgpu-util match -i "${drivers_path}" -c "${drivers_path}/vgpuDriverCatalog.yaml"); then
        echo "cannot find match for compatible vgpu driver from available list, please check /var/log/vgpu-util.log for more details..."
        return 1
    fi
    DRIVER_VERSION=$(echo "$version" | awk -F= '{print $2}')
    echo "vgpu driver version selected: ${DRIVER_VERSION}"
    return 0
}

_start_vgpu_topology_daemon() {
    # The daemon is optional: succeed silently when 'nvidia-topologyd' cannot be found.
    if ! type nvidia-topologyd > /dev/null 2>&1; then
        return 0
    fi
    echo "Starting nvidia-topologyd.."
    nvidia-topologyd
}

_prepare() {
    # For the vgpu driver type, run the vgpu driver version lookup first.
    # Any failing step terminates the script with status 1.
    if [ "${DRIVER_TYPE}" = "vgpu" ]; then
        if ! _find_vgpu_driver_version; then
            exit 1
        fi
    fi
    # Determine the kernel module type
    if ! _resolve_kernel_type; then
        exit 1
    fi
}

# Run the common preparation, then make sure this is the only running
# instance and start from a clean state (driver unloaded, rootfs unmounted).
_prepare_exclusive() {
    _prepare

    # Open the PID file on file descriptor 3 and try to take a non-blocking
    # lock on it; the descriptor stays open for the rest of the script.
    # NOTE(review): '3>' truncates the file before the lock is attempted, so a
    # second instance empties the PID file of the lock holder - confirm that
    # this is acceptable.
    exec 3> "${PID_FILE}"
    if ! flock -n 3; then
        echo "An instance of the NVIDIA driver is already running, aborting"
        exit 1
    fi
    # Record this shell's PID in the locked file.
    echo $$ >&3

    # A signal is turned into 'exit 1', which in turn fires the EXIT trap,
    # so _shutdown runs on every exit path from here on.
    trap "echo 'Caught signal'; exit 1" HUP INT QUIT PIPE TERM
    trap "_shutdown" EXIT

    # Remove whatever a previous run may have left loaded or mounted.
    _unload_driver || exit 1
    _unmount_rootfs
}



# Mount the rootfs, load the driver, then block until a signal arrives.
_load() {
    _mount_rootfs
    _load_driver

    echo "Done, now waiting for signal"
    # Background job that keeps the script alive; 'wait' on it below can be
    # interrupted by a trapped signal.
    sleep infinity &
    # The trap string is double-quoted, so $! is expanded right here and the
    # handler kills exactly the 'sleep' started above. The script exits 0 only
    # when _shutdown succeeds; otherwise it goes back to waiting.
    trap "echo 'Caught signal'; _shutdown && { kill $!; exit 0; }" HUP INT QUIT PIPE TERM
    # Clear any EXIT trap (one is installed by _prepare_exclusive).
    trap - EXIT
    # NOTE(review): the loop has no break, so the 'exit 0' after it appears to
    # be unreachable; the script leaves through the signal trap instead.
    while true; do wait $! || continue; done
    exit 0
}

init() {
    # Exclusive start-up: run the three stages strictly in this order.
    local stage
    for stage in _prepare_exclusive _update_ca_certificates _load; do
        "${stage}"
    done
}

load() {
    # Non-exclusive start-up: prepare, then load the driver.
    local stage
    for stage in _prepare _load; do
        "${stage}"
    done
}

# Wait for MOFED drivers to be loaded and load nvidia-peermem whenever it gets unloaded during MOFED driver updates
reload_nvidia_peermem() {
    # Globals: USE_HOST_MOFED (read, optional),
    #          NVIDIA_PEERMEM_MODULE_PARAMS (read after _get_module_params has run).
    # Does not return to the caller: ends with 'exit 1' when the module cannot
    # be loaded, otherwise waits until a signal arrives and exits from the trap.

    # USE_HOST_MOFED is optional; default to empty so that an unset variable
    # does not abort the script under 'set -u'.
    if [ "${USE_HOST_MOFED:-}" = "true" ]; then
        # MOFED readiness is detected through the loaded mlx5_core module
        until  lsmod | grep mlx5_core > /dev/null 2>&1 && [ -f /run/nvidia/validations/.driver-ctr-ready ];
        do
            echo "waiting for mellanox ofed and nvidia drivers to be installed"
            sleep 10
        done
    else
        # use driver readiness flag created by MOFED container
        until  [ -f /run/mellanox/drivers/.driver-ready ] && [ -f /run/nvidia/validations/.driver-ctr-ready ];
        do
            echo "waiting for mellanox ofed and nvidia drivers to be installed"
            sleep 10
        done
    fi
    # get any parameters provided for nvidia-peermem
    # (nounset is switched off afterwards, presumably so that an empty
    # parameter array expands cleanly on older bash releases - confirm)
    _get_module_params && set +o nounset
    if chroot /run/nvidia/driver modprobe nvidia-peermem "${NVIDIA_PEERMEM_MODULE_PARAMS[@]}"; then
        # modprobe reported success; double-check through sysfs
        if [ -f /sys/module/nvidia_peermem/refcnt ]; then
            echo "successfully loaded nvidia-peermem module, now waiting for signal"
            sleep infinity &
            trap "echo 'Caught signal'; exit 1" HUP INT QUIT PIPE TERM
            # NOTE(review): the loop has no break, so the 'exit 0' after it
            # appears to be unreachable.
            while true; do wait $! || continue; done
            exit 0
        fi
    fi
    echo "failed to load nvidia-peermem module"
    exit 1
}

# Probe used by gpu-operator for liveness/startup checks: verifies that the nvidia-peermem module is loaded once the MOFED drivers are ready
probe_nvidia_peermem() {
    # Returns 1 only when mlx5_core is loaded but nvidia-peermem is not.
    if ! lsmod | grep mlx5_core > /dev/null 2>&1; then
        echo "MOFED drivers are not ready, skipping probe to avoid container restarts..."
        return 0
    fi
    # MOFED is up, so the peermem module is expected to be present in sysfs.
    if [ -f /sys/module/nvidia_peermem/refcnt ]; then
        return 0
    fi
    echo "nvidia-peermem module is not loaded"
    return 1
}

usage() {
    # Print the command summary to stderr and terminate the script with status 1.
    # The list covers every command accepted by the dispatcher at the bottom
    # of this file.
    cat >&2 <<EOF
Usage: $0 COMMAND [ARG...]

Commands:
  init   [-a | --accept-license]
  load
  reload_nvidia_peermem
  probe_nvidia_peermem
EOF
    exit 1
}

# Entry point: the first argument names the command to run.
(( $# > 0 )) || usage
command=$1
shift
case "${command}" in
    init)
        options=$(getopt -l accept-license -o a -- "$@") || usage
        ;;
    load | reload_nvidia_peermem | probe_nvidia_peermem)
        options=""
        ;;
    *)
        usage
        ;;
esac
# Word splitting is intentional: it turns the getopt output back into
# positional parameters.
set -- ${options}

# Consume the recognised flags; stop at '--' or at the first other word.
while (( $# > 0 )); do
    case "$1" in
    -a | --accept-license) shift ;; # accepted for backward compatibility
    --) shift; break ;;
    *) break ;;
    esac
done
# Anything left over is an unexpected argument.
(( $# == 0 )) || usage

"${command}"
