#!/usr/bin/python3

import argparse
import io
import os
import random
import sys
import urllib.error
import urllib.request

from azuremetadata import azuremetadatautils, azuremetadata


class PreserveArgumentOrder(argparse.Action):
    """Argparse action that records options in the order they were given.

    Each time an option using this action is seen, a ``(dest, value)``
    tuple is appended to the ``ordered_args`` list on the namespace. The
    list is created on first use, so it is absent from the namespace when
    no such option was passed. The option's own ``dest`` attribute is not
    assigned by this action.
    """

    def __call__(self, parser, namespace, value, option_string=None):
        # setdefault creates the list the first time and returns the
        # existing one on every later call.
        seen = vars(namespace).setdefault('ordered_args', [])
        seen.append((self.dest, value))


# Early pass over the command line: only the API version and the device are
# picked up here, every argument this parser does not know is ignored.
api_version_parser = argparse.ArgumentParser(add_help=False)
api_version_parser.add_argument('-a', '--api', const=None, nargs='?')
api_version_parser.add_argument('--device', const=None, nargs='?')
api_args, _ = api_version_parser.parse_known_args()

# Main parser. The built-in help is disabled and replaced by a plain
# -h/--help flag so the generated help text can be captured as a string.
parser = argparse.ArgumentParser(add_help=False)
for _flags, _description in (
    (('-h', '--help'), "Display help"),
    (('-x', '--xml'), "Output as XML"),
    (('-j', '--json'), "Output as JSON"),
):
    parser.add_argument(*_flags, action="store_true", help=_description)
parser.add_argument('-o', '--output', help="Output file path")
parser.add_argument(
    '-a', '--api',
    help="API version or 'latest' for newest API version"
)
# Only root can read the tag, thus hide the argument from everybody else
if os.geteuid() == 0:
    parser.add_argument(
        '--device',
        nargs='?',
        help="Device to read disk tag from (default: root device)"
    )
parser.add_argument(
    '--listapis', action="store_true", help="List available API versions"
)

# Snapshot of the help text for the arguments registered so far
help_header = parser.format_help()

# IMDS is not intended to be used behind a proxy and IMDS does not support it
os.environ['no_proxy'] = '169.254.169.254'

def _read_classic_vm_id(pool_path, key_size=512, value_size=2048):
    """Look up the VirtualMachineId entry in a Hyper-V KVP pool file.

    The file is read as a sequence of fixed-size records, each made of a
    ``key_size`` byte key followed by a ``value_size`` byte value. Only the
    bytes before the first NUL of a key or value are used.

    Returns the lower-cased, UTF-8 decoded value of the first record whose
    key is ``VirtualMachineId``, or None when the end of the file is
    reached without finding such a record.
    """
    with open(pool_path, 'rb') as kvp_pool:
        while True:
            key = kvp_pool.read(key_size)
            value = kvp_pool.read(value_size)
            if not key or not value:
                # End of file: stop instead of re-reading empty data forever
                return None
            if key.split(b"\x00")[0] == b'VirtualMachineId':
                return value.split(b"\x00")[0].lower().decode('utf-8')


# Collect the metadata, register one query option per available parameter,
# then print either everything or the result of the requested queries.
try:
    metadata = azuremetadata.AzureMetadata(api_args.api)
    data = {}
    # Handle instances in ASM, aka Classic
    # Heuristic data: When requesting attestedData in ASM it triggers an
    # internal server error, that's the best thing we have to go on to
    # figure out whether or not we are in ASM.
    # ASM gets retired in 2023, rip this code out, it's ugly!
    # NOTE(review): the probe below requests instance/compute and reacts to
    # a 404, not to an internal server error on attestedData - confirm which
    # of the two descriptions is the intended one.
    instance_compute_url = 'http://169.254.169.254/metadata/instance/compute'
    req = urllib.request.Request(
        '%s?api-version=2019-11-01' % instance_compute_url,
        headers={'Metadata': 'true'}
    )
    try:
        # Only success or failure matters here, so release the connection
        # right away instead of leaving the response open.
        urllib.request.urlopen(req, timeout=2).close()
    except urllib.error.HTTPError as e:
        if e.getcode() == 404:
            # Set the api version to the first implementation, it works in ASM
            metadata.set_api_version('2017-04-02')
            # Special code for SUSE, ugh because we know what we are
            # looking for there is unfortunately no better way.
            kvp_pool_file_path = '/var/lib/hyperv/.kvp_pool_3'
            vm_id = None
            if os.path.exists(kvp_pool_file_path):
                vm_id = _read_classic_vm_id(kvp_pool_file_path)
            if vm_id is None:
                # No pool file, or no VirtualMachineId record in it: fall
                # back to a random id. The upper bound must be an int,
                # random.randint() rejects floats such as 1e9 on
                # Python 3.12 and later.
                vm_id = random.randint(0, 10**9)
            data['subscriptionId'] = 'classic-{}'.format(vm_id)
            # ensuring that --attestedData --signature and
            # --signature are available
            data['attestedData'] = {}
            data['attestedData']['signature'] = ''
            data['signature'] = ''
    # End code removal in 2023

    # Only root can read the tag only add the value if we are root
    if os.geteuid() == 0:
        data['billingTag'] = metadata.get_disk_tag(api_args.device)

    # Values fetched from the metadata object are merged last and therefore
    # override the entries prepared above when the keys collide.
    data.update(metadata.get_all())
    util = azuremetadatautils.AzureMetadataUtils(data)

    # One optional query argument per available parameter; the custom
    # action keeps track of the order in which they were given.
    for key in util.available_params.keys():
        parser.add_argument(
            '--{}'.format(key), nargs='?', type=int, const=True,
            action=PreserveArgumentOrder
        )

    args = parser.parse_args()
    ordered_args = getattr(args, 'ordered_args', [])

    if args.help:
        print(help_header)
        print("\nquery arguments:")
        util.print_help()
        # sys.exit() instead of exit(): the latter is a helper injected by
        # the site module and is not guaranteed to exist.
        sys.exit()

    if args.listapis:
        api_versions = metadata.list_api_versions()
        print('Available API versions:')
        for api in api_versions:
            print('    {}'.format(api))
        sys.exit()

    # fh stays None when no output file was requested
    fh = None
    if args.output:
        fh = open(args.output, 'w')

    try:
        if not ordered_args:
            # No query arguments: print all collected data
            util.print_pretty(
                print_xml=args.xml, print_json=args.json, file=fh
            )
        else:
            result = util.query(ordered_args)
            util.print_pretty(
                print_xml=args.xml, print_json=args.json, data=result, file=fh
            )
    finally:
        if fh:
            fh.close()

except azuremetadatautils.QueryException as e:
    print(e, file=sys.stderr)
    sys.exit(1)
