#!/usr/bin/python3


# Copyright 2022 SUSE LLC
#
# This file is part of ec2imgutils
#
# ec2imgutils is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ec2imgutils is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ec2uploadimg. If not, see <http://www.gnu.org/licenses/>.

import argparse
import boto3
import os
import sys
import signal

import ec2imgutils.ec2utils as utils
import ec2imgutils.ec2uploadimg as ec2upimg
from ec2imgutils.ec2imgutilsExceptions import EC2AccountException
from ec2imgutils.ec2setup import EC2Setup

aborted = False
logger = None


def print_ex(msg=None):
    """Log the exception that is currently being handled.

    The exception value, together with its traceback, is sent to the
    module level logger at debug level. At info level either the supplied
    *msg* is logged or, when no message is given, a summary of the form
    ``ExceptionType: value``.

    Must be called while an exception is being handled, as the data is
    obtained from ``sys.exc_info()``.
    """
    exc_type, value = sys.exc_info()[:2]
    logger.debug(value, exc_info=True)

    summary = msg if msg else "{}: {}".format(exc_type.__name__, value)
    logger.info(summary)


def signal_handler(signum, frame):
    """Handle an interrupt or termination signal.

    The first signal marks the run as aborted and asks the active
    uploader, if there is one, to abort. When neither a setup object nor
    an uploader exists there is nothing to clean up and the process
    exits with status 1. A second signal exits immediately with status 1
    without any clean up.

    signum and frame are supplied by the signal module and are not used.
    """
    global aborted
    if aborted:
        # if it got already aborted before,
        # we kill the upload without proper clean up
        sys.exit(1)
    aborted = True
    # The handler is installed before the upload loop binds the module
    # level names 'uploader' and 'setup'. Look them up defensively so a
    # signal that arrives early does not raise a NameError.
    active_uploader = globals().get('uploader')
    active_setup = globals().get('setup')
    if active_uploader:
        active_uploader.abort()
    if not active_setup and not active_uploader:
        # if there is no setup and no uploader
        # there are no resources we need to cleanup
        # and we just exit
        sys.exit(1)


# Route both interrupt (SIGINT) and termination (SIGTERM) requests
# through signal_handler.
for _signum in (signal.SIGINT, signal.SIGTERM):
    signal.signal(_signum, signal_handler)

# Set up command line argument parsing
# The parser gets its own name; binding it to 'argparse' would shadow the
# imported module for the remainder of the script.
parser = argparse.ArgumentParser(description='Upload image to Amazon EC2')
parser.add_argument(
    '-a', '--account',
    dest='accountName',
    help='Account to use (Optional)',
    metavar='ACCOUNT_NAME'
)
parser.add_argument(
    '--access-id',
    dest='accessKey',
    help='AWS access key (Optional)',
    metavar='AWS_ACCESS_KEY'
)
parser.add_argument(
    '-B', '--backing-store',
    default='ssd',
    dest='backingStore',
    help='The backing store type, (mag|ssd), or one of the EBS storage IDs',
    metavar='EC2_BACKING_STORE'
)
parser.add_argument(
    '--billing-codes',
    dest='billingCodes',
    help='The BillingProduct codes, only available to some accounts',
    metavar='BILLING_CODES'
)
parser.add_argument(
    '--boot-kernel',
    dest='akiID',
    help='AWS kernel ID (aki) to boot the new image (Optional)',
    metavar='AWS_AKI_ID'
)
parser.add_argument(
    '--boot-mode',
    dest='bootMode',
    help='Set the boot mode for the image, legacy-bios or uefi (Optional)',
    choices=['', 'legacy-bios', 'uefi'],
    metavar='BOOT_MODE'
)
parser.add_argument(
    '-d', '--description',
    dest='descript',
    help='Image description, will also be used for the snapshot',
    metavar='IMAGE_DESCRIPTION',
    required=True
)
parser.add_argument(
    '-e', '--ec2-ami',
    dest='amiID',
    help='AWS AMI id for the image to use to upload (Optional)',
    metavar='AWS_AMI_ID',
)
parser.add_argument(
    '--ena-support',
    action='store_true',
    default=False,
    dest='ena',
    help='The image supports the ENA network interface (Optional)'
)
parser.add_argument(
    '-f', '--file',
    default=os.sep.join(['~', '.ec2utils.conf']),
    dest='configFilePath',
    help='Path to configuration file, default ~/.ec2utils.conf (Optional)',
    metavar='CONFIG_FILE'
)
parser.add_argument(
    '--grub2',
    action='store_true',
    default=False,
    dest='grub2',
    help='The image uses the GRUB2 bootloader (Optional)'
)
parser.add_argument(
    '-i', '--instance-id',
    dest='runningID',
    help='ID of running instance to use to upload (Optional)',
    metavar='AWS_INSTANCE_ID',
)
parser.add_argument(
    '-m', '--machine',
    dest='arch',
    help='Machine architecture arm64|i386|x86_64 for the uploaded image',
    metavar='ARCH',
    required=True
)
parser.add_argument(
    '-n', '--name',
    dest='imgName',
    help='Image name',
    metavar='IMAGE_NAME',
    required=True
)
parser.add_argument(
    '-p', '--private-key-file',
    dest='privateKey',
    help='Private SSH key file (Optional)',
    metavar='PRIVATE_KEY'
)
parser.add_argument(
    '-r', '--regions',
    dest='regions',
    help='Comma separated list of regions for image upload',
    metavar='EC2_REGIONS',
    required=True
)
parser.add_argument(
    '--root-volume-size',
    dest='rootVolSize',
    help='Size of root volume for new image in GB (Optional)',
    metavar='ROOT_VOLUME_SIZE',
)
parser.add_argument(
    '--ssh-key-pair',
    dest='sshName',
    help='AWS SSH key pair name (Optional)',
    metavar='AWS_KEY_PAIR_NAME'
)
parser.add_argument(
    '-s', '--secret-key',
    dest='secretKey',
    help='AWS secret access key (Optional)',
    metavar='AWS_SECRET_KEY'
)
help_msg = (
    'A comma separated list of security group ids to apply to the '
    'helper instance. At least port 22 must be open (Optional)'
)
parser.add_argument(
    '--security-group-ids',
    default='',
    dest='securityGroupIds',
    help=help_msg,
    metavar='AWS_SECURITY_GROUP_IDS'
)
parser.add_argument(
    '--session-token',
    dest='sessionToken',
    help='AWS session token, for use of temporary credentials (Optional)',
    metavar='AWS_SESSION_TOKEN'
)
parser.add_argument(
    '--snaponly',
    action='store_true',
    default=False,
    dest='snapOnly',
    help='Stop after snapshot creation (Optional)'
)
parser.add_argument(
    '--sriov-support',
    action='store_true',
    default=False,
    dest='sriov',
    help='The image supports SRIOV network interface (Optional)'
)
parser.add_argument(
    'source',
    help='The path to the image source file'
)
parser.add_argument(
    '--ssh-timeout',
    default=300,
    dest='sshTimeout',
    help='Timeout value to wait for ssh connection, default 300 s (Optional)',
    metavar='SSH_TIME_OUT',
    type=int
)
parser.add_argument(
    '--tpm-support',
    dest='tpm',
    help='The image supports NitroTPM, supported value 2.0/v2.0 (Optional)',
    choices=['', '2.0', 'v2.0'],
    metavar='TPM_SUPPORT'
)
parser.add_argument(
    '-t', '--type',
    dest='instType',
    help='Instance type to use to upload image (Optional)',
    metavar='AWS_UPLOAD_INST_TYPE'
)
parser.add_argument(
    '-u', '--user',
    dest='sshUser',
    help='The user for the ssh connection to the instance (Optional)',
    metavar='AWS_INSTANCE_USER'
)
parser.add_argument(
    '--use-private-ip',
    action='store_true',
    default=False,
    dest='usePrivateIP',
    help='Use the instance private IP address to connect (Optional)'
)
parser.add_argument(
    '--use-root-swap',
    action='store_true',
    default=False,
    dest='rootSwapMethod',
    help='Use the root swap method to create the new image (Optional)'
)
help_msg = (
    'The virtualization type, (para)virtual or (hvm), '
    'default hvm (Optional)'
)
parser.add_argument(
    '-V', '--virt-type',
    default='hvm',
    dest='virtType',
    help=help_msg,
    metavar='AWS_VIRT_TYPE'
)
help_msg = (
    'Enable verbose output (Optional). Multiple -v options increase '
    'the verbosity. The maximum is 2.'
)
parser.add_argument(
    '--verbose',
    '-v',
    action='count',
    default=0,
    dest='verbose',
    help=help_msg)
parser.add_argument(
    '--version',
    action='version',
    version=utils.get_version(),
    help='Program version'
)
help_msg = (
    'The ID, starts with "subnet-" of the VPC subnet in which the '
    'helper instance should run (Optional)'
)
parser.add_argument(
    '--vpc-subnet-id',
    default='',
    dest='vpcSubnetId',
    help=help_msg,
    metavar='VPC_SUBNET_ID'
)
help_msg = (
    'Wait N-number of times for AWS operation timeout, default 1 '
    '= 600 seconds (Optional)'
)
parser.add_argument(
    '--wait-count',
    default=1,
    dest='waitCount',
    help=help_msg,
    metavar='AWS_WAIT_COUNT',
    type=int
)
# Adjacent literals are used instead of a backslash continuation inside
# the string, which would embed the source indentation in the help text.
parser.add_argument(
    '--use-snapshot',
    action='store_true',
    default=False,
    dest='useSnap',
    help='Create the image using an existing snapshot that is specified '
         'as the source variable (Optional)'
)

args = parser.parse_args()
# The verbosity count given with -v/--verbose selects the logger setup
logger = utils.get_logger(args.verbose)

# NOTE(review): the '--version' option is registered with argparse's
# 'version' action, which prints the version and exits while parsing, so
# this branch is presumably only reached when the literal '--version' is
# passed as a positional value (after '--'). Kept for compatibility.
if '--version' in sys.argv:
    version_file_name = 'VERSION'
    base_path = os.path.dirname(utils.__file__)
    # Use a context manager so the file handle is closed again
    with open(os.path.join(base_path, version_file_name), 'r') as ver_file:
        version = ver_file.read()
    logger.info(version)
    sys.exit(0)

# Creating an image from an existing snapshot excludes the options that
# select a different creation method.
if args.useSnap and (args.snapOnly or args.rootSwapMethod):
    logger.error(
        'The options --use-snapshot cannot be specified '
        'with --snaponly or --use-root-swap'
    )
    sys.exit(1)

# The configuration file must exist and be readable, even when all values
# are given on the command line.
config_file = os.path.expanduser(args.configFilePath)
config = None
if not os.path.isfile(config_file):
    logger.error('Configuration file "%s" not found.' % config_file)
    sys.exit(1)

try:
    config = utils.get_config(config_file)
except Exception:
    print_ex()
    sys.exit(1)

# With --use-snapshot the source argument is not treated as a local file,
# so the existence check only applies to the other modes.
if not args.useSnap and not os.path.isfile(args.source):
    logger.error('Could not find specified image file: %s' % args.source)
    sys.exit(1)

# Credentials: the command line value wins, the configuration is the
# fallback.
access_key = args.accessKey
if not access_key:
    try:
        access_key = utils.get_from_config(args.accountName,
                                           config,
                                           None,
                                           'access_key_id',
                                           '--access-id')
    except EC2AccountException as e:
        logger.error(e)
        sys.exit(1)

if not access_key:
    logger.error('Could not determine account access key')
    sys.exit(1)

secret_key = args.secretKey
if not secret_key:
    try:
        secret_key = utils.get_from_config(args.accountName,
                                           config,
                                           None,
                                           'secret_access_key',
                                           '--secret-key')
    except EC2AccountException as e:
        logger.error(e)
        sys.exit(1)

if not secret_key:
    logger.error('Could not determine account secret access key')
    sys.exit(1)

supported_arch = ['arm64', 'i386', 'x86_64']
if args.arch not in supported_arch:
    logger.error(
        'Specified architecture must be one of %s' % str(supported_arch)
    )
    sys.exit(1)

# All arm64 images must support ENA
# Choosing to assume the image has a sufficiently new kernel with the
# ENA driver rather than to error out
if args.arch == 'arm64':
    if not args.ena:
        logger.info('Enabling ENA support, required by arm64')
        args.ena = True

# The --sriov-support flag is translated to the value 'simple'; without
# the flag the value stays False.
sriov_type = args.sriov
if sriov_type:
    sriov_type = 'simple'

# 'para' is accepted as shorthand for 'paravirtual'
virtualization_type = args.virtType
if virtualization_type == 'para':
    virtualization_type = 'paravirtual'

if sriov_type and virtualization_type != 'hvm':
    logger.error('SRIOV support is only possible with HVM images')
    sys.exit(1)

if args.ena and virtualization_type != 'hvm':
    logger.error('ENA is only possible with HVM images')
    sys.exit(1)

if args.amiID and args.runningID:
    logger.error(
        'Specify either AMI ID to start instance or running '
        'ID to use already running instance'
    )
    sys.exit(1)

tpm_boot_options = ['uefi']
if args.tpm and args.bootMode not in tpm_boot_options:
    logger.error(
        'TPM can only be set with --boot-mode set to %s' % str(
            tpm_boot_options)
    )
    sys.exit(1)

# Default root volume size, in GB per the --root-volume-size help text,
# used when no size is given on the command line.
root_volume_size = 10
if args.rootVolSize:
    root_volume_size = args.rootVolSize

# Options that identify a resource in one specific region cannot be
# combined with an upload to several regions.
regions = args.regions.split(',')
if (
    len(regions) > 1 and (args.amiID or args.akiID or args.runningID or
                          args.vpcSubnetId)
   ):
    logger.error(
        'Incompatible arguments: multiple regions specified'
    )
    logger.error(
        'Cannot process region unique argument for --ec2-ami,'
    )
    logger.error('--instance-id, --boot-kernel, or --vpc-subnet-id')
    sys.exit(1)

# Upload the image to every requested region.
#
# For each region the helper resources are set up, the remaining settings
# are resolved from the command line or the configuration, the upload is
# run and the helper resources are cleaned up again. A failure in one
# region is logged and the remaining regions are still processed; the
# script then exits with status 1 so callers can detect the failure.
failed_regions = []
# Bound up front so the signal handler and the 'finally' clause below can
# always refer to these names.
setup = None
uploader = None
try:
    running_id = args.runningID
    for region in regions:
        if aborted:
            # An abort was requested while a previous region was being
            # processed; do not create resources in further regions.
            break
        setup = None
        uploader = None
        # AMI IDs are region specific, start from the command line value
        # for every region instead of reusing the previous region's
        # lookup result.
        ami_id = args.amiID
        try:
            setup = EC2Setup(
                access_key,
                region,
                secret_key,
                args.sessionToken,
                log_callback=logger
            )
            if not ami_id and not running_id:
                try:
                    ami_id = utils.get_from_config(args.accountName,
                                                   config,
                                                   region,
                                                   'ami',
                                                   '--ec2-ami')
                except Exception as e:
                    logger.error('Could not determine helper AMI-ID')
                    logger.error(e)
                    sys.exit(1)

            # A boot kernel (aki) is only used for non hvm images
            bootkernel = None
            if args.virtType != 'hvm':
                bootkernel = args.akiID

            if not bootkernel and args.virtType != 'hvm':
                # Select the configuration key that holds the kernel ID
                if args.grub2:
                    aki_config_key = 'g2_aki_x86_64'
                elif args.arch == 'x86_64':
                    aki_config_key = 'aki_x86_64'
                elif args.arch == 'i386':
                    aki_config_key = 'aki_i386'
                elif args.arch == 'arm64':
                    logger.error(
                        'Images for arm64 must use hvm virtualization'
                    )
                    sys.exit(1)
                else:
                    logger.error(
                        'Could not reliably determine the bootkernel to use '
                        'must specify bootkernel, arch (x86_64|i386) or hvm'
                    )
                    sys.exit(1)
                try:
                    bootkernel = utils.get_from_config(args.accountName,
                                                       config,
                                                       region,
                                                       aki_config_key,
                                                       '--boot-kernel')
                except Exception:
                    print_ex('Could not find bootkernel in config')
                    sys.exit(1)

            inst_type = args.instType
            if not inst_type:
                inst_type = utils.get_from_config(args.accountName,
                                                  config,
                                                  region,
                                                  'instance_type',
                                                  '--type')

            key_pair_name = args.sshName
            ssh_private_key_file = args.privateKey
            if (
                    not key_pair_name
                    and not ssh_private_key_file
                    and not args.accountName
            ):
                # Nothing given and no account to look it up for, have
                # the setup helper create a key pair.
                key_pair_name, ssh_private_key_file = \
                    setup.create_upload_key_pair()
            if not key_pair_name:
                try:
                    key_pair_name = utils.get_from_config(args.accountName,
                                                          config,
                                                          region,
                                                          'ssh_key_name',
                                                          '--ssh-key-pair')
                except Exception as e:
                    logger.error(e)
                    sys.exit(1)
            if not key_pair_name:
                logger.error('Could not determine key pair name')
                sys.exit(1)

            if not ssh_private_key_file:
                try:
                    ssh_private_key_file = utils.get_from_config(
                        args.accountName,
                        config,
                        region,
                        'ssh_private_key',
                        '--private-key-file'
                    )
                except Exception as e:
                    logger.error(e)
                    sys.exit(1)

            if not ssh_private_key_file:
                logger.error(
                    'Could not determine the private ssh key file'
                )
                # Without a key file the path handling below would fail
                # with a TypeError, stop here instead.
                sys.exit(1)

            ssh_private_key_file = os.path.expanduser(ssh_private_key_file)

            if not os.path.exists(ssh_private_key_file):
                logger.error(
                    'SSH private key file "%s" does not exist'
                    % ssh_private_key_file
                )
                sys.exit(1)

            ssh_user = args.sshUser
            if not ssh_user:
                try:
                    ssh_user = utils.get_from_config(args.accountName,
                                                     config,
                                                     region,
                                                     'user',
                                                     '--user')
                except Exception as e:
                    logger.error(e)
                    sys.exit(1)

            if not ssh_user:
                logger.error('Could not determine ssh user to use')
                sys.exit(1)

            vpc_subnet_id = args.vpcSubnetId
            if not vpc_subnet_id and not args.accountName:
                vpc_subnet_id = setup.create_vpc_subnet()
            if not vpc_subnet_id and not (args.amiID or args.runningID):
                # Depending on instance type an instance may possibly only
                # launch inside a subnet. Look in the config for a subnet if
                # no subnet ID is given and the AMI to use was not
                # specified on the command line
                try:
                    vpc_subnet_id = utils.get_from_config(
                        args.accountName,
                        config,
                        region,
                        'subnet_id_%s' % region,
                        '--vpc-subnet-id'
                    )
                except Exception:
                    msg = 'Not using a subnet-id, none given on the '
                    msg += 'command line and none found in config for '
                    msg += '"subnet_id_%s" value' % region
                    logger.error(msg)
                    sys.exit(1)

                if vpc_subnet_id:
                    logger.debug('Using VPC subnet: %s' % vpc_subnet_id)

            security_group_ids = args.securityGroupIds
            if not security_group_ids and not args.runningID:
                if not args.accountName:
                    security_group_ids = setup.create_security_group()
                else:
                    try:
                        security_group_ids = utils.get_from_config(
                            args.accountName,
                            config,
                            region,
                            'security_group_ids_%s' % region,
                            '--security-group-ids'
                        )
                        if security_group_ids:
                            msg = 'Using Security Group IDs: %s'
                            logger.debug(msg % security_group_ids)
                    except Exception:
                        # Not fatal, a security group is created below
                        # when a subnet is known.
                        msg = 'No security group specified in the '
                        msg += 'configuration, "security_group_ids_%s" or '
                        msg += 'given on the command line.'
                        logger.error(msg % region)
                    if not security_group_ids and vpc_subnet_id:
                        # Find the VPC of the subnet and create the
                        # security group there. The session token is
                        # passed along so temporary credentials work.
                        ec2 = boto3.client(
                            aws_access_key_id=access_key,
                            aws_secret_access_key=secret_key,
                            aws_session_token=args.sessionToken,
                            region_name=region,
                            service_name='ec2'
                        )
                        subnet_data = ec2.describe_subnets(
                            SubnetIds=[vpc_subnet_id]
                        )
                        subnets = subnet_data.get('Subnets')
                        if not subnets:
                            msg = 'Unable to obtain VPC information for '
                            msg += 'provided subnet "%s". ' % vpc_subnet_id
                            msg += 'Subnet ID is invalid.'
                            logger.error(msg)
                            sys.exit(1)
                        vpc_id = subnets[0]['VpcId']
                        security_group_ids = setup.create_security_group(
                            vpc_id
                        )

            # Skip the upload when an abort was requested during setup
            if not aborted:
                uploader = ec2upimg.EC2ImageUploader(
                    access_key=access_key,
                    backing_store=args.backingStore,
                    billing_codes=args.billingCodes,
                    bootkernel=bootkernel,
                    ena_support=args.ena,
                    image_arch=args.arch,
                    image_description=args.descript,
                    image_name=args.imgName,
                    image_virt_type=virtualization_type,
                    inst_user_name=ssh_user,
                    launch_ami=ami_id,
                    launch_inst_type=inst_type,
                    region=region,
                    root_volume_size=root_volume_size,
                    running_id=running_id,
                    secret_key=secret_key,
                    security_group_ids=security_group_ids,
                    session_token=args.sessionToken,
                    sriov_type=sriov_type,
                    ssh_key_pair_name=key_pair_name,
                    ssh_key_private_key_file=ssh_private_key_file,
                    ssh_timeout=args.sshTimeout,
                    use_grub2=args.grub2,
                    use_private_ip=args.usePrivateIP,
                    vpc_subnet_id=vpc_subnet_id,
                    wait_count=args.waitCount,
                    log_callback=logger,
                    boot_mode=args.bootMode,
                    tpm_support=args.tpm
                )

                if args.snapOnly:
                    snapshot = uploader.create_snapshot(args.source)
                    print('Created snapshot: ', snapshot['SnapshotId'])
                elif args.rootSwapMethod:
                    ami = uploader.create_image_use_root_swap(args.source)
                    print('Created image: ', ami)
                elif args.useSnap:
                    ami = uploader.create_image_from_snapshot(args.source)
                    print('Created image: ', ami)
                else:
                    ami = uploader.create_image(args.source)
                    print('Created image: ', ami)
        except Exception:
            # Log the failure and continue with the next region; the
            # region is remembered for the exit status.
            print_ex()
            failed_regions.append(region)
        finally:
            # setup is still None when creating EC2Setup itself failed,
            # in that case there is nothing to clean up.
            if setup is not None:
                setup.clean_up()
except Exception:
    print_ex()
    sys.exit(1)

if failed_regions:
    logger.error(
        'Image upload failed for region(s): %s' % ', '.join(failed_regions)
    )
    sys.exit(1)
