#!/usr/bin/python
#
# Licensed under the GNU General Public License Version 3
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# Copyright 2012 Aron Parsons <aronparsons@gmail.com>
#

import os.path
import getpass
import logging
import os
import re
import sys
import time
import xmlrpclib
from optparse import Option, OptionParser
import ConfigParser


class Config(object):

    """
    INI-style settings store with a fallback section.

    Options requested from a section that is not present in the file are
    looked up in the default ("general") section instead; when nothing is
    found the caller-supplied default is returned.
    """
    SECTION_GENERAL = "general"
    OPT_PHASES = "phases"
    OPT_EXCLUDE_CHNL = "exclude channels"

    def __init__(self, path):
        """Remember the file location; nothing is loaded until read()."""
        self.path = path
        self.default_section = Config.SECTION_GENERAL
        self._cfg = ConfigParser.RawConfigParser()

    def get(self, section, option, default=None):
        """
        Return the value of `option`, or `default` when it is not set.

        If `section` does not exist the default section is consulted.  If
        `section` exists but lacks the option, `default` is returned
        without consulting the default section.
        """
        target = section
        if not self._cfg.has_section(target):
            # unknown section: retry against the default section, if any
            target = self.default_section
            if not self._cfg.has_section(target):
                return default

        if self._cfg.has_option(target, option):
            return self._cfg.get(target, option)
        return default

    def set(self, section, option, value):
        """Store `value` under `section`/`option`, creating the section."""
        if not self._cfg.has_section(section):
            self._cfg.add_section(section)
        self._cfg.set(section, option, value)

    def write(self):
        """Write the current settings to the file at self.path."""
        with open(self.path, "wb") as out:
            self._cfg.write(out)

    def read(self):
        """Load settings from the file at self.path."""
        self._cfg.read(self.path)

    def sections(self):
        """Return the names of all sections except the default one."""
        return [name for name in self._cfg.sections()
                if name != self.default_section]


# Per-user directory that holds this tool's settings file and session cache.
CONF_DIR = os.path.expanduser("~/.spacewalk-manage-channel-lifecycle")
# Settings file that is handed to Config.
USER_CONF_FILE = os.path.join(CONF_DIR, "settings.conf")
# File in which the API session key is kept between runs.
SESSION_CACHE = os.path.join(CONF_DIR, "session")

##############################################################################


def setup_config(config):
    """
    Prepare the per-user configuration directory and load the settings.

    A regular file occupying the directory path is removed, the directory
    is created with owner-only permissions when missing, and the settings
    file is then either created with default values or read into `config`.

    config -- object providing set(section, option, value), write() and
              read() (a Config instance in this script).

    Exits with status 1 when the directory cannot be created.
    """
    if os.path.isfile(CONF_DIR):
        os.unlink(CONF_DIR)

    if not os.path.exists(CONF_DIR):
        logging.debug("Creating directory: %s" % CONF_DIR)
        try:
            os.mkdir(CONF_DIR, 0o700)
        except (IOError, OSError):
            # os.mkdir reports failures as OSError, which is not an IOError
            # on Python 2; catching only IOError let the error escape.
            logging.error("Unable to create %s" % CONF_DIR)
            sys.exit(1)

    if not os.path.exists(USER_CONF_FILE):
        logging.debug("Creating configuration file: %s" % USER_CONF_FILE)
        config.set(Config.SECTION_GENERAL, Config.OPT_PHASES, "dev, test, prod")
        config.set(Config.SECTION_GENERAL, Config.OPT_EXCLUDE_CHNL, "")
        config.write()
    else:
        config.read()


def ask(msg, password=False):
    """
    Prompt on the console and return what the user typed.

    msg      -- prompt text; ": " is appended before it is shown.
    password -- when true the input is read with getpass so it is not
                echoed; otherwise raw_input is used.
    """
    prompt = msg + ": "
    if password:
        # An explicit branch is required here: with the "a and b or c"
        # idiom an empty password is falsy and would trigger a second,
        # echoing raw_input prompt.
        return getpass.getpass(prompt)
    return raw_input(prompt)


def parse_enumerated(data):
    """
    Split a comma-separated string into a list of its elements.

    Each element has surrounding whitespace stripped and empty elements are
    dropped.  A false value (None, "") yields an empty list.
    """
    if not data:
        return []

    stripped = [piece.strip() for piece in data.split(",")]
    return [piece for piece in stripped if piece]


def get_next_phase(current_name):
    """
    Return the name of the phase that follows `current_name`.

    The order is taken from the module-level `phases` list.  Exits with
    status 1 when `current_name` is not a known phase or is already the
    last one.
    """
    if current_name not in phases:
        logging.error('Invalid phase name: %s' % current_name)
        sys.exit(1)

    following = phases.index(current_name) + 1
    if following >= len(phases):
        logging.error("Maximum phase exceeded!  You can't move past '%s'." % phases[-1])
        sys.exit(1)

    return phases[following]


def print_channel_tree():
    tree = {}

    # determine parent channels so we can make a pretty tree for the user
    for channel in all_channels:
        if channel.get('parent_label'):
            if not tree.has_key(channel.get('parent_label')):
                tree[channel.get('parent_label')] = []

            tree[channel.get('parent_label')].append(channel.get('label'))
        else:
            if not tree.has_key(channel.get('label')):
                tree[channel.get('label')] = []

    # print the channels in a tree format
    print "Channel tree:"
    idx = 1
    step = len(str(len(tree.keys())))
    for parent in sorted(tree.keys()):
        if len(tree[parent]):
            print
        print " %s. %s" % (str(idx).rjust(step), parent)
        idx += 1
        if len(tree[parent]):
            for child in sorted(tree[parent]):
                print (' ' * step) + '     \\__ %s' % child
            print
    print


def channel_exists(channel, quiet=False):
    """
    Report whether the server can return details for `channel`.

    Returns True when the getDetails call succeeds and False when it raises
    an XML-RPC fault.  The "does not exist" error is logged unless `quiet`
    is true; the fault itself is logged only with --debug.
    """
    try:
        client.channel.software.getDetails(session, channel)
    except xmlrpclib.Fault as fault:
        if options.debug:
            logging.exception(fault)
        if not quiet:
            logging.error('Channel %s does not exist' % channel)
        return False
    else:
        return True


def merge_channels(source_label, dest_label):
    """
    Merge the packages and, unless --no-errata is set, the errata of
    `source_label` into `dest_label`.

    The destination is cleared first when --clear-channel or --rollback is
    in effect.  With --dry-run the merge steps are only logged.

    Returns None when the merge ran, was a dry run, or the source channel
    is excluded; returns False when an API call failed and --tolerant is
    set.  Exits with status 1 on an API failure without --tolerant, or when
    the destination label starts with 'rhel' and --tolerant is not set.
    """
    # NOTE(review): presumably guards vendor channels from being modified;
    # confirm the intent of the 'rhel' prefix check.
    if dest_label.startswith('rhel') and not options.tolerant:
        logging.error("Destination label starts with 'rhel'.  Aborting!")
        logging.error("Pass --tolerant to override this")
        sys.exit(1)

    # exclude filters are exact label matches
    if any(source_label == pattern
           for pattern in options.exclude_channel or []):
        logging.info('Skipping %s due to an exclude filter'
                     % source_label)
        return

    # remove all the packages from the channel if requested
    if options.clear_channel or options.rollback:
        clear_channel(dest_label)

    logging.info('Merging packages from %s into %s' %
                 (source_label, dest_label))

    if not options.dry_run:
        # merge the packages from one channel into another
        try:
            packages = client.channel.software.mergePackages(session,
                                                             source_label,
                                                             dest_label)

            logging.info('Added %i packages' % len(packages))
        except xmlrpclib.Fault as e:
            if options.debug:
                logging.exception(e)
            logging.error('Failed to merge packages')

            if options.tolerant:
                return False
            else:
                sys.exit(1)

    if not options.no_errata:
        logging.info('Merging errata into %s' % dest_label)

        if not options.dry_run:
            # merge the errata from one channel into another
            try:
                errata = client.channel.software.mergeErrata(session,
                                                             source_label,
                                                             dest_label)

                logging.info('Added %i errata' % len(errata))
            except xmlrpclib.Fault as e:
                if options.debug:
                    logging.exception(e)
                logging.error('Failed to merge errata')

                if options.tolerant:
                    return False
                else:
                    sys.exit(1)

    # blank line separating the output of consecutive channels
    print


def clone_channel(source_label, details):
    """
    Clone `source_label` into a new channel described by `details`.

    `details` is passed to the server's clone call as-is; its 'label' entry
    is also used for logging.  Nothing is sent to the server with --dry-run
    or when the source label matches an exclude filter.

    Returns False when the clone call failed and --tolerant is set, None
    otherwise.  Exits with status 1 on failure without --tolerant.
    """
    excluded = options.exclude_channel or []
    if any(source_label == pattern for pattern in excluded):
        logging.info('Skipping %s due to an exclude filter'
                     % source_label)
        return

    logging.info('Cloning %s from %s' % (details['label'], source_label))

    if options.dry_run:
        return

    try:
        client.channel.software.clone(session, source_label, details, False)
    except xmlrpclib.Fault as fault:
        if options.debug:
            logging.exception(fault)
        logging.error('Failed to clone channel')

        if not options.tolerant:
            sys.exit(1)
        return False


def clear_channel(label):
    """
    Remove every erratum and every package from the channel `label`.

    With --dry-run both steps are only logged.  A failure while removing
    errata is logged as a warning and package removal is still attempted.
    A failure while removing packages is logged as an error and the script
    exits with status 1, because the callers merge into the channel right
    after it has been cleared.
    """
    logging.info('Clearing all errata from %s' % label)

    if not options.dry_run:
        # attempt to clear the errata from the channel
        try:
            all_errata = client.channel.software.listErrata(session, label)
            errata_names = [erratum.get('advisory_name')
                            for erratum in all_errata]
            client.channel.software.removeErrata(session,
                                                 label,
                                                 errata_names,
                                                 False)
        except xmlrpclib.Fault as e:
            if options.debug:
                logging.exception(e)
            logging.warning('Failed to clear errata from %s' % label)

    # announced in dry-run mode too, like the errata step above
    logging.info('Clearing all packages from %s' % label)

    if not options.dry_run:
        try:
            all_packages = client.channel.software.listAllPackages(session,
                                                                   label)
            package_ids = [package.get('id') for package in all_packages]
            client.channel.software.removePackages(session, label,
                                                   package_ids)
        except xmlrpclib.Fault as e:
            # previously unhandled: the fault surfaced as a raw traceback
            if options.debug:
                logging.exception(e)
            logging.error('Failed to clear packages from %s' % label)
            sys.exit(1)


def get_current_phase(source):
    """
    Get current phase from the source channel label.

    A label belongs to a phase when it starts with the phase name followed
    by '-', which is how this script builds phase-prefixed labels.  Requiring
    the separator keeps a label such as 'devel-tools' from being mistaken
    for phase 'dev'.  When several phases match, the longest name wins.

    Returns the phase name, or None when the label carries no known phase
    prefix.
    """
    matches = [phase for phase in phases if source.startswith(phase + '-')]
    if not matches:
        return None
    return max(matches, key=len)


def build_channel_labels(source):
    """
    Work out the destination label for `source` under the selected action.

    --archive   prefixes the label with 'archive-YYYYMMDD-' (UTC date).
    --init      prefixes the label with the first phase; exits with status
                1 if that channel already exists.
    --promote   replaces the leading phase name with the next phase, or
                prefixes the first phase when the label has no phase.
    --rollback  removes the 'archive-YYYYMMDD-' part; exits with status 1
                when the label contains none, because the destination would
                then be the source channel itself.

    Returns the tuple (source, destination).
    """
    if options.archive:
        # prefix the existing channel with 'archive-YYYYMMDD-'
        date_string = time.strftime('%Y%m%d', time.gmtime())
        destination = 'archive-%s-%s' % (date_string, source)
    elif options.init:
        destination = '%s-%s' % (phases[0], source)

        if channel_exists(destination, quiet=True):
            logging.error('%s already exists.  Use --promote instead.'
                          % destination)
            sys.exit(1)
    elif options.promote:
        # get the phase label from the channel label
        current_phase = get_current_phase(source)
        if current_phase:
            next_phase = get_next_phase(current_phase)

            # Swap the leading phase name by slicing instead of a regex so
            # that phase names containing regex metacharacters or
            # backslashes are taken literally.
            destination = next_phase + source[len(current_phase):]
        else:
            destination = '%s-%s' % (phases[0], source)
    elif options.rollback:
        # strip off the archive prefix when rolling back
        destination = re.sub(r'archive-\d{8}-', '', source)

        if destination == source:
            # rolling a channel back onto itself would clear the source
            logging.error("%s is not an archive channel "
                          "(expected an 'archive-YYYYMMDD-' prefix)"
                          % source)
            sys.exit(1)
    else:
        # defensive: the option sanity check should make this unreachable
        logging.error('No action selected for %s' % source)
        sys.exit(1)

    return (source, destination)

##############################################################################

# Usage text handed to OptionParser; it lists one example command per action.
usage = '''usage: %prog [options]

Create a 'dev' channel based on the latest packages:
spacewalk-manage-channel-lifecycle -c sles11-sp3-pool-x86_64 --init

Promote the packages from 'dev' to 'test':
spacewalk-manage-channel-lifecycle -c dev-sles11-sp3-pool-x86_64 --promote

Promote the packages from 'test' to 'prod':
spacewalk-manage-channel-lifecycle -c test-sles11-sp3-pool-x86_64 --promote

Archive a production channel:
spacewalk-manage-channel-lifecycle -c prod-sles11-sp3-pool-x86_64 --archive

Rollback the production channel to an archived version:
spacewalk-manage-channel-lifecycle \\
    -c archive-20110520-prod-sles11-sp3-pool-x86_64 --rollback'''

# Command-line options.  --phases and --exclude-channel deliberately have
# empty defaults: when they are not given, the values are taken from the
# configuration file further down.
option_list = [
    Option('-l', '--list-channels', help='list existing channels',
           action='store_true'),
    Option('', '--init', help='initialize a development channel',
           action='store_true'),
    Option('-w', '--workflow', help='use configured workflow', default=""),
    Option('-f', '--list-workflows', help='list configured workflows', default=False,
           action='store_true'),
    Option('', '--promote', help='promote a channel to the next phase',
           action='store_true'),
    Option('', '--archive', help='archive a channel', action='store_true'),
    Option('', '--rollback', help='rollback', action='store_true'),
    Option('-c', '--channel', help='channel to init/promote/archive/rollback'),
    Option('-C', '--clear-channel',
           help='clear all packages/errata from the channel before merging',
           action='store_true'),
    Option('-x', '--exclude-channel', help='skip these channels',
           action='append'),
    Option('', '--no-errata', help="don't merge errata data with --promote",
           action='store_true'),
    Option('', '--no-children', help="don't operate on child channels",
           action='store_true'),
    # the help text was missing its closing bracket
    Option('-P', '--phases', default='',
           help='comma-separated list of phases [default: dev,test,prod]'),
    Option('-u', '--username', help='Spacewalk username'),
    Option('-p', '--password', help='Spacewalk password'),
    Option('-s', '--server',
           help='Spacewalk server [default: %default]', default='localhost'),
    Option('-n', '--dry-run', help="don't perform any operations",
           action='store_true'),
    Option('-t', '--tolerant', help='be tolerant of errors',
           action='store_true'),
    # counted: the number of -d flags is compared against 1 later on
    Option('-d', '--debug', help='enable debug logging', action='count')
]

parser = OptionParser(option_list=option_list, usage=usage)
(options, args) = parser.parse_args()

# any -d flag switches logging from INFO to DEBUG
level = logging.DEBUG if options.debug else logging.INFO
logging.basicConfig(format='%(levelname)s: %(message)s', level=level)

# without --workflow the default section of the configuration is used
if not options.workflow:
    options.workflow = Config.SECTION_GENERAL

config = Config(USER_CONF_FILE)
try:
    setup_config(config)
except ConfigParser.ParsingError as ex:
    logging.error("Unable to process configuration:\n" + str(ex))
    sys.exit(1)

if options.list_workflows:
    workflows = config.sections()
    if not workflows:
        print "There are no additinal configured workflows except default."
        sys.exit(0)

    print "Configured additional workflows:"
    idx = 1
    for workflow in workflows:
        print "  %s. %s" % (idx, workflow)
        idx += 1
    print
    sys.exit(0)

# an action is mandatory, and every action except --list-channels also
# needs a channel to work on
action_requested = any((options.init, options.promote, options.archive,
                        options.rollback, options.list_channels))
if not action_requested:
    logging.error("You must provide an action "
                  "(--init/--promote/--archive/--rollback/--list-channels)")
    sys.exit(1)
if not (options.list_channels or options.channel):
    logging.error('--channel is required')
    sys.exit(1)

# phases: command line first, then the workflow configuration
phases_spec = options.phases
if not phases_spec:
    phases_spec = config.get(options.workflow, Config.OPT_PHASES,
                             "dev,test,prod")
phases = parse_enumerated(phases_spec)

# excluded channels: command line first, then the workflow configuration
if not options.exclude_channel:
    options.exclude_channel = parse_enumerated(
        config.get(options.workflow, Config.OPT_EXCLUDE_CHNL, ""))

if len(phases) < 2:
    logging.error('You must define at least 2 phases')
    sys.exit(1)

# XML-RPC tracing is switched on when -d was given more than once
xmlrpc_debug = bool(options.debug and options.debug > 1)

# connect to the server
client = xmlrpclib.Server('https://%s/rpc/api' % options.server,
                          verbose=xmlrpc_debug)

session = None

# reuse the session key cached by an earlier run, if there is one
if os.path.exists(SESSION_CACHE):
    try:
        # 'with' closes the file even when reading fails
        with open(SESSION_CACHE, 'r') as fh:
            session = fh.readline().strip()
    except IOError as e:
        if options.debug:
            logging.exception(e)
        logging.debug('Failed to read session cache')

    # Validate only when a key was actually read; an unreadable or empty
    # cache simply falls through to the login below.
    if session:
        try:
            client.channel.listMyChannels(session)
        except xmlrpclib.Fault as e:
            if options.debug:
                logging.exception(e)
            logging.warning('Existing session is invalid')
            session = None

if not session:
    # prompt for the username until one is given
    while not options.username:
        options.username = ask('Spacewalk Username')

    # prompt for the password
    if not options.password:
        options.password = ask('Spacewalk Password', password=True)
    if not options.password:
        logging.warning("Empty password is not a good practice!")

    # login to the server
    try:
        session = client.auth.login(options.username, options.password)
    except xmlrpclib.Fault as e:
        if options.debug:
            logging.exception(e)
        logging.error('Failed to log into %s' % options.server)
        sys.exit(1)

    # save the session for subsequent runs
    try:
        with open(SESSION_CACHE, 'w') as fh:
            fh.write(session)
    except IOError as e:
        if options.debug:
            logging.exception(e)
        logging.warning('Failed to write session cache')

# fetch the channel list once; the rest of the script works from it
try:
    all_channels = client.channel.listSoftwareChannels(session)
except xmlrpclib.Fault as e:
    if options.debug:
        logging.exception(e)
    logging.error('Could not retrieve the list of software channels')
    sys.exit(1)
else:
    all_channel_labels = [entry.get('label') for entry in all_channels]

# --list-channels only prints the tree and stops here
if options.list_channels:
    print_channel_tree()
    sys.exit(0)

# stop right away when the requested source channel is unknown
if not channel_exists(options.channel):
    sys.exit(1)

# source and destination labels for the parent channel
parent_source, parent_dest = build_channel_labels(options.channel)

# inform the user that no changes will take place
if options.dry_run:
    logging.info('DRY RUN - No changes are being made to the channels')
    time.sleep(2)
    print

# a destination that does not exist yet is cloned from the source,
# an existing one receives a merge of packages/errata
if parent_dest not in all_channel_labels:
    details = dict(label=parent_dest,
                   name=parent_dest,
                   summary=parent_dest)

    clone_channel(parent_source, details)
else:
    merge_channels(parent_source, parent_dest)

if not options.no_children:
    # labels of the channels whose parent is the source parent channel
    children = [channel.get('label') for channel in all_channels
                if channel.get('parent_label') == parent_source]

    for label in children:
        child_source, child_dest = build_channel_labels(label)

        # same clone-or-merge decision as for the parent channel
        if child_dest not in all_channel_labels:
            details = dict(label=child_dest,
                           name=child_dest,
                           summary=child_dest,
                           parent_label=parent_dest)

            clone_channel(child_source, details)
        else:
            merge_channels(child_source, child_dest)
