#!/usr/bin/python3
# This file is part of Cockpit.
#
# Copyright (C) 2018 Red Hat, Inc.
#
# Cockpit is free software; you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation; either version 2.1 of the License, or
# (at your option) any later version.
#
# Cockpit is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Cockpit; If not, see <http://www.gnu.org/licenses/>.
#
# This command is meant to be used as an authentication command launched
# by cockpit-ws. It asks for authentication from cockpit-ws and expects
# to receive a basic auth header in response. If COCKPIT_SSH_KEY_PATH is set
# we try to unlock that key with the given password by loading it into the
# SSH agent with ssh-add. If successful we reply to cockpit-ws with an
# "ssh-agent" authorize response, for use with cockpit-ssh.
# Once finished we exec cockpit-ssh to actually establish the ssh connection.
# All communication with cockpit-ws happens on stdin and stdout using the
# cockpit protocol
# (https://github.com/cockpit-project/cockpit/blob/main/doc/protocol.md)

import base64
import json
import os
import subprocess
import sys
import time

# Absolute path of the cockpit-ssh executable that main() exec()s as its
# final step.
COCKPIT_SSH_COMMAND = "/usr/libexec/cockpit-ssh"


def usage():
    """Print a one-line usage summary to stderr and exit with EX_USAGE.

    Never returns: terminates the process by raising SystemExit.
    """
    print(f"usage {sys.argv[0]} [user@]host[:port]", file=sys.stderr)
    raise SystemExit(os.EX_USAGE)


def send_frame(content):
    """Serialize *content* as JSON and write it to fd 1 as one frame.

    The frame is the decimal length, a newline, a second newline and then
    the JSON bytes.  The advertised length is one more than the JSON size,
    i.e. it counts the second newline (presumably the empty channel line
    of a control message -- see the protocol document linked in the file
    header).

    Parameters:
        content: any object accepted by json.dumps().
    """
    payload = json.dumps(content).encode()
    length = str(len(payload) + 1).encode()

    # Three separate writes, in this order: length, separators, payload.
    for chunk in (length, b"\n\n", payload):
        os.write(1, chunk)


def send_auth_command(challenge, response):
    """Send an "authorize" control message via send_frame().

    Parameters:
        challenge: if truthy, included as "challenge" together with a
            freshly generated "cookie".
        response: if truthy, included as "response".
    """
    message = {"command": "authorize"}

    if challenge:
        # The cookie is only attached alongside a challenge; it is built
        # from our pid and the current time.
        message.update(cookie=f"session{os.getpid()}{time.time()}",
                       challenge=challenge)

    if response:
        message["response"] = response

    send_frame(message)


def send_problem_init(problem, message, auth_methods):
    """Send an "init" control message that reports a problem.

    Parameters:
        problem: problem code, always included as "problem".
        message: optional text; included as "message" when truthy.
        auth_methods: optional mapping; included as "auth-method-results"
            when truthy.
    """
    payload = {"command": "init", "problem": problem}

    # Optional fields are appended in this fixed order, skipping empty ones.
    optional = (
        ("message", message),
        ("auth-method-results", auth_methods),
    )
    payload.update((key, value) for key, value in optional if value)

    send_frame(payload)


def read_size(fd):
    """Read the decimal length prefix of a frame from *fd*.

    Bytes are consumed one at a time up to and including the terminating
    newline, so nothing past the prefix is read.

    Returns:
        The parsed length, or 0 if end of file is hit before the newline.

    Raises:
        ValueError: if an eighth digit is read before the newline, or if
            a byte is read that int() rejects.
    """
    value = 0
    digits = 0

    byte = os.read(fd, 1)
    while byte:
        if byte == b'\n':
            return value

        value = value * 10 + int(byte)
        digits += 1
        if digits > 7:
            raise ValueError("Invalid frame: size too long")

        byte = os.read(fd, 1)

    # End of file before the terminating newline.
    return 0


def read_frame(fd):
    """Read one length-prefixed frame from *fd* and return it as text.

    The length is obtained from read_size(); exactly that many bytes are
    then read (looping over short reads) and decoded with the default
    codec.

    Returns:
        The decoded frame body; an empty string when read_size() yields 0.

    Raises:
        ValueError: if the stream ends before the announced number of
            bytes has arrived, or if the bytes cannot be decoded
            (UnicodeDecodeError is a ValueError subclass).
    """
    remaining = read_size(fd)

    chunks = []
    while remaining > 0:
        chunk = os.read(fd, remaining)
        if not chunk:
            # os.read() returns b"" at end of file; without this check the
            # loop would spin forever waiting for bytes that never come.
            raise ValueError("Invalid frame: unexpected end of input")
        remaining -= len(chunk)
        chunks.append(chunk)

    return b"".join(chunks).decode()


def read_auth_reply():
    """Read one frame and return the "response" of an authorize reply.

    NOTE(review): the frame is read from fd 1, the same descriptor
    send_frame() writes to; presumably cockpit-ws hands us one
    bidirectional socket as both stdin and stdout -- confirm.

    Returns:
        The value of the reply's "response" field.

    Raises:
        ValueError: if the frame is not valid JSON, or if it is not an
            "authorize" command carrying both a non-empty "cookie" and a
            non-empty "response".
    """
    reply = json.loads(read_frame(1))

    is_authorize = reply.get("command") == "authorize"
    answer = reply.get("response")
    if not (is_authorize and reply.get("cookie") and answer):
        raise ValueError("Did not receive a valid authorize command")

    return answer


def decode_basic_header(response):
    """Split a "Basic <base64>" authorization value into user and password.

    Parameters:
        response: the header value, e.g. "Basic dXNlcjpwYXNz".

    Returns:
        A (user, password) tuple.  Only the first ':' separates the two,
        so the password may itself contain colons.

    Raises:
        ValueError: if *response* is empty, does not start with "Basic ",
            holds undecodable base64 or non-UTF-8 data, or has no ':'
            separator.
    """
    prefix = "Basic "

    # Validate with explicit checks rather than assert: assertions are
    # removed when Python runs with -O, which would let bad input through.
    if not response:
        raise ValueError("Empty authorization response")
    if not response.startswith(prefix):
        # Deliberately do not echo the header in the error text; it may
        # carry credentials and the message is forwarded and logged.
        raise ValueError("Authorization response is not a Basic auth header")

    decoded = base64.b64decode(response[len(prefix):].encode()).decode()
    user, password = decoded.split(':', 1)
    return user, password


def load_key(fname, password):
    """Add the private key *fname* to the SSH agent using *password*.

    A one-shot askpass helper is written to /run/askpass and the password
    to /run/password (mode 0600, created exclusively); ssh-add is then run
    with SSH_ASKPASS pointing at the helper.  The password file is always
    removed afterwards.  On success an authorize response "ssh-agent" is
    sent via send_auth_command().

    Parameters:
        fname: path of the private key file handed to ssh-add.
        password: passphrase used to unlock the key.

    Returns:
        True if ssh-add exited with status 0, otherwise False (ssh-add's
        stderr is then reported on our stderr).

    Raises:
        FileExistsError: if /run/password already exists (O_EXCL).
        OSError: for other failures creating the helper files or
            starting ssh-add.
    """
    # ssh-add has the annoying behavior that it re-asks without any limit if the password is wrong
    # to mitigate, self-destruct to only allow one iteration
    with open("/run/askpass", "w") as fd:
        fd.write("""#!/bin/sh
rm -f $0
cat /run/password""")

        os.fchmod(fd.fileno(), 0o755)

    env = os.environ.copy()
    env["SSH_ASKPASS_REQUIRE"] = "force"
    env["SSH_ASKPASS"] = "/run/askpass"

    pass_fd = os.open("/run/password", os.O_CREAT | os.O_EXCL | os.O_WRONLY | os.O_CLOEXEC, mode=0o600)
    try:
        # Wrap the descriptor so it is closed even when writing fails, and
        # so the buffered writer takes care of short writes.
        with os.fdopen(pass_fd, "wb") as pass_file:
            pass_file.write(password.encode())

        # "-t 30" limits how long the agent keeps the key.
        p = subprocess.run(["ssh-add", "-t", "30", fname],
                           check=False, env=env,
                           capture_output=True, encoding="UTF-8")
    finally:
        # Never leave the cleartext password behind, whatever happened.
        os.unlink("/run/password")

    if p.returncode == 0:
        send_auth_command(None, "ssh-agent")
        return True
    else:
        print("Couldn't load private key:", p.stderr, file=sys.stderr)
        return False


def main(args):
    """Prepare authentication if requested, then exec cockpit-ssh.

    Parameters:
        args: full argument vector; must be [program, "[user@]host[:port]"],
            otherwise usage() is called and the process exits.

    When COCKPIT_SSH_KEY_PATH is set, an authorize challenge "*" is sent,
    the Basic auth reply is decoded and the key is loaded via load_key().
    A malformed reply is reported as "internal-error" and the exception is
    re-raised; a key that cannot be loaded is reported as
    "authentication-failed" and the function returns without exec'ing.
    Otherwise the process is replaced by COCKPIT_SSH_COMMAND.
    """
    if len(args) != 2:
        usage()

    target = args[1]
    key_path = os.environ.get("COCKPIT_SSH_KEY_PATH")

    if key_path:
        send_auth_command("*", None)
        try:
            user, password = decode_basic_header(read_auth_reply())
        except (ValueError, TypeError, AssertionError) as exc:
            send_problem_init("internal-error", str(exc), {})
            raise

        if not load_key(key_path, password):
            send_problem_init("authentication-failed", "Couldn't open private key",
                              {"password": "denied"})
            return

        # Connect as the user named in the Basic auth header.
        target = f"{user}@{target}"

    os.execlpe(COCKPIT_SSH_COMMAND, COCKPIT_SSH_COMMAND, target, os.environ)


# Script entry point: pass the complete argument vector (including
# argv[0]), since main() expects exactly two entries.
if __name__ == '__main__':
    main(sys.argv)
