#!/usr/bin/python
#
# Module that removes channels from an installed satellite
#
#
# Copyright (c) 2008--2012 Red Hat, Inc.
#
# This software is licensed to you under the GNU General Public License,
# version 2 (GPLv2). There is NO WARRANTY for this software, express or
# implied, including the implied warranties of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. You should have received a copy of GPLv2
# along with this software; if not, see
# http://www.gnu.org/licenses/old-licenses/gpl-2.0.txt.
#
# Red Hat trademarks are not licensed under GPLv2. No permission is
# granted to use or replicate Red Hat trademarks that are incorporated
# in this software or its documentation.
#

import sys
import string
import os
import shutil
from optparse import Option, OptionParser

# Refuse to run unless the real user id is 0 (root). This check runs at
# import time, before any of the spacewalk modules below are imported, and
# terminates the process with exit status 8.
if os.getuid() != 0:
    sys.stderr.write('ERROR: must be root to execute\n')
    sys.exit(8)

# rhnLockfile lives in different packages depending on the installed
# version; try the new location first and fall back to the old one.
# Only a failed import triggers the fallback -- any other error (or a
# Ctrl-C during start-up) is allowed to propagate.
try:
    from rhn import rhnLockfile  # new place for rhnLockfile
except ImportError:
    from spacewalk.common import rhnLockfile  # old place for rhnLockfile

from spacewalk.satellite_tools.progress_bar import ProgressBar
from spacewalk.common.rhnLog import initLOG, log_debug, log_error
from spacewalk.common.rhnConfig import CFG, initCFG
from spacewalk.server import rhnSQL
from spacewalk.server.rhnPackage import unlink_package_file

# Command-line options accepted by this tool; handed to OptionParser in main().
options_table = [
    # -- general ----------------------------------------------------------
    Option(
        "-v", "--verbose",
        action="count",
        help="Increase verbosity",
    ),
    Option(
        "-l", "--list",
        action="store_true",
        help="List defined channels and exit",
    ),

    # -- channel selection (both may be repeated) -------------------------
    Option(
        "-c", "--channel",
        action="append",
        default=[],
        help="Delete this channel (can be present multiple times)",
    ),
    Option(
        "-a", "--channel-with-children",
        action="append",
        default=[],
        help="Delete this channel and its child channels (can be present multiple times)",
    ),

    # -- what to remove / what to keep ------------------------------------
    Option(
        "-u", "--unsubscribe",
        action="store_true",
        help="Unsubscribe systems registered to the specified channels",
    ),
    Option(
        "--justdb",
        action="store_true",
        help="Delete only from the database, do not remove files from disk",
    ),
    Option(
        "--force",
        action="store_true",
        help="Remove the channel packages from any other channels too (Not Recommended)",
    ),
    Option(
        "-p", "--skip-packages",
        action="store_true",
        help="Do not remove package metadata or packages from the filesystem (Not Recommended)",
    ),
    Option(
        "--skip-kickstart-trees",
        action="store_true",
        help="Do not remove kickstart trees from the filesystem (Not Recommended).",
    ),
    Option(
        "--just-kickstart-trees",
        action="store_true",
        help="Remove only the kickstart trees for the channels specified.",
    ),
    Option(
        "--skip-channels",
        action="store_true",
        help="Remove only packages from channel not the channel itself.",
    ),
]

# Lockfile objects acquired by main(); released again by releaseLOCK().
LOCK = []
# Directory in which main() creates the "<command>.pid" lock files.
LOCK_DIR = '/var/run'

def main():
    """Parse the command line and list or remove the requested channels.

    Behaviour, in order:

    * Positional arguments are rejected (exit status -1).
    * With no --channel, --channel-with-children or --list the function
      prints "Nothing to do" and exits with status 0.
    * Unless only listing, lock files for this tool, satellite-sync and
      spacewalk-repo-sync are taken; if any is already held the process
      exits with status -1.
    * With --list the channel tree is printed and the process exits with 0.
    * Otherwise the requested labels are validated (unknown labels, parents
      whose children are not also selected, subscribed systems, associated
      kickstarts) -- any failure exits with status -1 -- and the channels are
      handed to delete_channels().

    Returns 0 after a successful removal. If delete_channels() raises, the
    database transaction is rolled back and the exception is re-raised.
    """
    parser = OptionParser(option_list=options_table)

    (options, args) = parser.parse_args()

    if args:
        for arg in args:
            sys.stderr.write("Not a valid option ('%s'), try --help\n" % arg)
        sys.exit(-1)

    if not (options.channel or options.list or options.channel_with_children):
        sys.stderr.write("Nothing to do\n")
        sys.exit(0)

    if not options.list:
        # Make sure neither another instance of this tool nor a sync is
        # running while channels are being removed.
        for command in ['spacewalk-remove-channel', 'satellite-sync',
                        'spacewalk-repo-sync']:
            try:
                LOCK.append(rhnLockfile.Lockfile(
                    os.path.join(LOCK_DIR, "%s.pid" % command)))
            except rhnLockfile.LockfileLockedException:
                print("ERROR: An instance of %s is running, exiting." % command)
                sys.exit(-1)

    initCFG('server')
    initLOG("stdout", options.verbose or 0)

    rhnSQL.initDB()
    rhnSQL.clear_log_id()
    rhnSQL.set_log_auth_login('SETUP')

    dict_label, dict_parents = __listChannels()
    if options.list:
        for c in sorted(dict_parents.keys()):
            print(c)
            for sub in dict_parents[c]:
                print("\t" + sub)
        sys.exit(0)

    # Collect the labels to remove; the dict is used as a set.
    channels = {}
    for channel in options.channel:
        channels[channel] = None

    for parent in options.channel_with_children:
        if parent not in dict_parents:
            print("Unknown parent channel %s" % parent)
            sys.exit(-1)
        channels[parent] = None
        for ch in dict_parents[parent]:
            channels[ch] = None

    labels = list(channels.keys())

    # The channel rows themselves are only removed when neither
    # --skip-channels nor --just-kickstart-trees was given.
    removing_channels = (not options.skip_channels
                         and not options.just_kickstart_trees)

    # Verify if the channel is valid
    child_test_fail = False
    for channel in labels:
        if channel not in dict_label:
            print("Unknown channel %s" % channel)
            sys.exit(-1)
        # Sanity check: verify subchannels are deleted as well if base
        # channels are selected
        if not removing_channels or channel not in dict_parents:
            continue
        # this channel is a parent channel: every child must be selected too
        children = [subch for subch in dict_parents[channel]
                    if subch not in channels]
        if children:
            child_test_fail = True
            print("Error: cannot remove channel %s: subchannel(s) exist: "
                  % channel)
            for child in children:
                print("\t\t\t" + child)

    if child_test_fail:
        sys.exit(-1)

    if removing_channels:
        # Both checks return a true value when removal must not proceed.
        if __serverCheck(labels, options.unsubscribe):
            sys.exit(-1)

        if __kickstartCheck(labels):
            sys.exit(-1)

    try:
        delete_channels(labels, force=options.force,
                        justdb=options.justdb, skip_packages=options.skip_packages,
                        skip_channels=options.skip_channels,
                        skip_kickstart_trees=options.skip_kickstart_trees,
                        just_kickstart_trees=options.just_kickstart_trees)
    except:
        # Roll back on *any* failure (including interrupts), then re-raise.
        rhnSQL.rollback()
        raise
    rhnSQL.commit()
    releaseLOCK()
    return 0

def __serverCheck(labels, unsubscribe):
    """Check whether any systems are subscribed to the given channels.

    labels      -- sequence of channel labels
    unsubscribe -- when true, subscribed systems are unsubscribed via
                   __unsubscribeServers() instead of being reported

    Returns 0 when no system is subscribed. When `unsubscribe` is true the
    result of __unsubscribeServers() is returned. Otherwise the subscribed
    systems are printed and their number (a true value) is returned, which
    the caller treats as "do not proceed".
    """
    sql = """
        select distinct S.org_id, S.id, S.name
        from rhnChannel c inner join
             rhnServerChannel sc on c.id = sc.channel_id inner join
             rhnServer s on s.id = sc.server_id
        where c.label in (%s)
    """
    params, bind_params = _bind_many(labels)
    placeholders = ', '.join(bind_params)
    h = rhnSQL.prepare(sql % placeholders)
    h.execute(**params)
    servers = h.fetchall_dict()
    if not servers:
        return 0

    if unsubscribe:
        return __unsubscribeServers(labels)

    print("\nCurrently there are systems subscribed to one or more of the specified channels.")
    print("If you would like to automatically unsubscribe these systems, simply use the --unsubscribe flag.\n")
    print("The following systems were found to be subscribed:")

    # Columns are padded to a fixed width and separated by a single space.
    row_format = "%s %s %s"
    print(row_format % ('org_id'.ljust(8), 'id'.ljust(14), 'name'))
    print("-" * 32)
    for server in servers:
        print(row_format % (str(server['org_id']).ljust(8),
                            str(server['id']).ljust(14),
                            server['name']))

    return len(servers)


def __unsubscribeServers(labels):
    """Unsubscribe every system from the channels named in `labels`.

    Prints a per-channel count of affected systems, then calls the
    rhn_channel.unsubscribe_server stored procedure once per
    (server, channel) pair while advancing a progress bar.

    Returns None (also when there is nothing to unsubscribe).
    """
    sql = """
        select distinct sc.server_id as server_id, C.id as channel_id, c.parent_channel, c.label
        from rhnChannel c inner join
             rhnServerChannel sc on c.id = sc.channel_id
        where c.label in (%s) order by C.parent_channel
    """
    params, bind_params = _bind_many(labels)
    placeholders = ', '.join(bind_params)
    h = rhnSQL.prepare(sql % placeholders)
    h.execute(**params)
    subscriptions = h.fetchall_dict() or []
    if not subscriptions:
        # Nothing is subscribed (any more); avoid iterating over None and
        # setting up a zero-length progress bar.
        return

    # Number of subscribed systems per channel label.
    channel_counts = {}
    for sub in subscriptions:
        channel_counts[sub['label']] = channel_counts.get(sub['label'], 0) + 1

    print("\nThe following channels will have their systems unsubscribed:")
    for label in sorted(channel_counts.keys()):
        print("%s %s" % (label.ljust(40), str(channel_counts[label]).ljust(8)))

    pb = ProgressBar(prompt='Unsubscribing:    ', endTag=' - complete',
                     finalSize=len(subscriptions), finalBarLength=40,
                     stream=sys.stdout)
    pb.printAll(1)

    unsubscribe_server_proc = rhnSQL.Procedure("rhn_channel.unsubscribe_server")
    for sub in subscriptions:
        unsubscribe_server_proc(sub['server_id'], sub['channel_id'])
        pb.addTo(1)
        pb.printIncrement()
    pb.printComplete()


def __kickstartCheck(labels):
    """Check whether kickstart profiles still reference the given channels.

    A profile counts as associated when its default kickstartable tree
    belongs to one of the channels in `labels`.

    Returns 0 when there are none; otherwise prints the offending profiles
    (org id and label) and returns their number, which the caller treats as
    "do not proceed".
    """
    sql = """
        select K.org_id, K.label
        from rhnKSData K inner join
             rhnKickstartDefaults KD on KD.kickstart_id = K.id inner join
             rhnKickstartableTree KT on KT.id = KD.kstree_id inner join
             rhnChannel c on c.id = KT.channel_id
        where c.label in (%s)
    """
    params, bind_params = _bind_many(labels)
    placeholders = ', '.join(bind_params)
    h = rhnSQL.prepare(sql % placeholders)
    h.execute(**params)
    kickstarts = h.fetchall_dict()

    if not kickstarts:
        return 0

    print("The following kickstarts are associated with one of the specified channels. " +
          "Please remove these or change their associated base channel.\n")
    # Columns are padded to a fixed width and separated by a single space.
    print("%s %s" % ('org_id'.ljust(8), 'label'))
    print("-" * 20)
    for ks in kickstarts:
        print("%s %s" % (str(ks['org_id']).ljust(8), ks['label']))

    return len(kickstarts)


def __listChannels():
    """Read all channels and their parent relationships from the database.

    Returns a tuple (labels, parents):

    labels  -- dict mapping every channel label to the label of its parent
               channel (a false value, as returned by the query, for
               channels without a parent)
    parents -- dict mapping every parent-less channel label to the list of
               labels of its child channels
    """
    sql = """
        select c1.label, c2.label parent_channel
        from rhnChannel c1 left outer join rhnChannel c2 on c1.parent_channel = c2.id
        order by c2.label desc, c1.label asc
    """
    h = rhnSQL.prepare(sql)
    h.execute()
    labels = {}
    parents = {}
    while True:
        row = h.fetchone_dict()
        if not row:
            break
        label = row['label']
        parent_channel = row['parent_channel']
        labels[label] = parent_channel
        # setdefault() keeps this correct regardless of whether the database
        # returns a parent row before or after the rows of its children.
        if parent_channel:
            parents.setdefault(parent_channel, []).append(label)
        else:
            parents.setdefault(label, [])

    return labels, parents


def delete_channels(channelLabels, force=0, justdb=0, skip_packages=0, skip_channels=0,
                    skip_kickstart_trees=0, just_kickstart_trees=0):
    """Remove the given channels and, depending on the flags, their content.

    channelLabels        -- sequence of channel labels; nothing happens when
                            it is empty
    force                -- passed to list_packages(): also select packages
                            that belong to channels outside channelLabels
    justdb               -- only touch the database; leave package files,
                            kickstart trees and cached repodata on disk
    skip_packages        -- do not remove package metadata or package files
    skip_channels        -- keep the channel rows (and the rows of the tables
                            that are only cleaned together with them)
    skip_kickstart_trees -- do not remove kickstart tree directories
    just_kickstart_trees -- restrict the work to kickstart trees: packages
                            and repodata are kept and only
                            rhnKSTreeFile/rhnKickstartableTree rows are
                            deleted

    Returns None. The caller is responsible for commit/rollback of the
    statements issued directly by this function.
    """
    if not channelLabels:
        return

    # Package ids and paths are only needed when packages are removed, so
    # the (potentially large) lookups are skipped otherwise. They must be
    # collected before any row is deleted.
    remove_packages = not skip_packages and not just_kickstart_trees
    if remove_packages:
        rpms_ids = list_packages(channelLabels, force=force, sources=0)
        rpms_paths = _get_package_paths(rpms_ids, sources=0)
        srpms_ids = list_packages(channelLabels, force=force, sources=1)
        srpms_paths = _get_package_paths(srpms_ids, sources=1)

        _delete_srpms(srpms_ids)
        _delete_rpms(rpms_ids)

    if not skip_kickstart_trees and not justdb:
        _delete_ks_files(channelLabels)

    if remove_packages and not justdb:
        _delete_files(list(rpms_paths) + list(srpms_paths))

    # Get the channel ids
    h = rhnSQL.prepare("""
        select id, parent_channel
        from rhnChannel
        where label = :label
        order by parent_channel""")
    channel_ids = []
    for label in channelLabels:
        h.execute(label=label)
        row = h.fetchone_dict()
        if not row:
            # Unknown label: skip it, but keep processing the other labels.
            continue
        channel_id = row['id']
        if row['parent_channel']:
            # Subchannel, we have to remove it first
            channel_ids.insert(0, channel_id)
        else:
            channel_ids.append(channel_id)

    if not channel_ids:
        return

    # Tables that reference the channel only through another table:
    # (table holding channel_field, channel_field, dependent table, link field)
    indirect_tables = [
        ['rhnKickstartableTree', 'channel_id', 'rhnKSTreeFile', 'kstree_id'],
    ]
    query = """
        delete from %(table_2)s where %(link_field)s in (
            select id
              from %(table_1)s
             where %(channel_field)s = :channel_id
        )
    """
    for table_1, channel_field, table_2, link_field in indirect_tables:
        args = {
            'table_1'       : table_1,
            'channel_field' : channel_field,
            'table_2'       : table_2,
            'link_field'    : link_field,
        }
        h = rhnSQL.prepare(query % args)
        h.executemany(channel_id=channel_ids)

    # Tables referencing the channel directly: (table, column). The order is
    # kept as-is; rhnChannel itself comes last.
    tables = [
        ['rhnErrataFileChannel', 'channel_id'],
        ['rhnErrataNotificationQueue', 'channel_id'],
        ['rhnChannelErrata', 'channel_id'],
        ['rhnChannelPackage', 'channel_id'],
        ['rhnRegTokenChannels', 'channel_id'],
        ['rhnServerProfile', 'base_channel'],
        ['rhnKickstartableTree', 'channel_id'],
    ]

    if not skip_channels:
        tables.extend([
            ['suseProductChannel', 'channel_id'],
            ['rhnChannelFamilyMembers', 'channel_id'],
            ['rhnDistChannelMap', 'channel_id'],
            ['rhnReleaseChannelMap', 'channel_id'],
            ['rhnChannel', 'id'],
        ])

    if just_kickstart_trees:
        # Overrides everything above: only the tree rows are removed.
        tables = [['rhnKickstartableTree', 'channel_id']]

    query = "delete from %s where %s = :channel_id"
    for table, field in tables:
        log_debug(3, "Processing table %s" % table)
        h = rhnSQL.prepare(query % (table, field))
        h.executemany(channel_id=channel_ids)

    if not justdb and not just_kickstart_trees:
        __deleteRepoData(channelLabels)


def __deleteRepoData(labels):
    """Remove the cached repository metadata directory of each channel.

    The directories live under '/var/cache/' + CFG.repomd_path_prefix, one
    per channel label; labels without such a directory are ignored.
    """
    cache_root = '/var/cache/' + CFG.repomd_path_prefix
    for label in labels:
        repodata_dir = cache_root + '/' + label
        if not os.path.isdir(repodata_dir):
            continue
        shutil.rmtree(repodata_dir)


def list_packages(channelLabels, sources=0, force=0):
    """Return the ids of the packages belonging to the given channels.

    channelLabels -- sequence of channel labels; an empty sequence yields []
    sources       -- when true, return source package ids (rhnPackageSource)
                     instead of binary package ids
    force         -- when true, return every package of the channels; when
                     false, packages that are also present in a channel
                     outside channelLabels are excluded

    Returns a list of ids.
    """
    if sources:
        packages = "srpms"
    else:
        packages = "rpms"
    log_debug(3, "Listing %s" % packages)
    if not channelLabels:
        return []

    params, bind_params = _bind_many(channelLabels)
    placeholders = ', '.join(bind_params)

    if sources:
        templ = _templ_srpms()
    else:
        templ = _templ_rpms()

    if force:
        query = templ % ("", placeholders)
    else:
        # Subtract the packages that belong to any channel NOT in the list.
        if CFG.DB_BACKEND == 'oracle':
            minus_op = 'MINUS' # ORA syntax
        else:
            minus_op = 'EXCEPT' # ANSI syntax
        query = """
            %s
            %s
            %s
        """ % (
                templ % ("", placeholders),
                minus_op,
                templ % ("not", placeholders),
            )
    h = rhnSQL.prepare(query)
    h.execute(**params)
    return [row['id'] for row in h.fetchall_dict() or []]

def _templ_rpms():
    """Return the SQL template selecting binary package ids by channel label.

    The template has two %s slots: an optional "not" placed in front of the
    `in`, and the comma-separated list of bind placeholders for the labels.
    """
    log_debug(4, "Generating template for querying rpms")
    template = """\
        select cp.package_id id
        from rhnChannel c, rhnChannelPackage cp
        where c.label %s in (%s)
        and cp.channel_id = c.id"""
    return template

def _templ_srpms():
    """Return the SQL template selecting source package ids by channel label.

    The template has two %s slots: an optional "not" placed in front of the
    `in`, and the comma-separated list of bind placeholders for the labels.
    Source packages are matched to the channel's packages through
    source_rpm_id and a matching (or both-null) org_id.
    """
    log_debug(4, "Generating template for querying srpms")
    template = """\
        select  ps.id id
        from    rhnPackage p,
                rhnPackageSource ps,
                rhnChannelPackage cp,
                rhnChannel c
        where   c.label %s in (%s)
            and c.id = cp.channel_id
            and cp.package_id = p.id
            and p.source_rpm_id = ps.source_rpm_id
            and ((p.org_id is null and ps.org_id is null) or
                p.org_id = ps.org_id)"""
    return template

def _delete_srpms(srcPackageIds):
    """Blow away rhnPackageSource and rhnFile entries.
    """
    if not srcPackageIds:
        return
    # nuke the rhnPackageSource entry
    h = rhnSQL.prepare("""
        delete
        from rhnPackageSource
        where id = :id
    """)
    count = h.executemany(id=srcPackageIds)
    if not count:
        count = 0
    log_debug(2, "Successfully deleted %s/%s source package ids" % (
        count, len(srcPackageIds)))

def _delete_rpms(packageIds):
    if not packageIds:
        return
    group = 300
    toDel = packageIds[:]
    print "Deleting package metadata (" + str(len(toDel)) + "):"
    pb = ProgressBar(prompt='Removing:         ', endTag=' - complete',
                     finalSize=len(packageIds), finalBarLength=40, stream=sys.stdout)
    pb.printAll(1)

    while len(toDel) > 0:
        _delete_rpm_group(toDel[:group])
        del toDel[:group]
        pb.addTo(group)
        pb.printIncrement()
    pb.printComplete()

def _delete_rpm_group(packageIds):
    """Delete one batch of packages from the database and commit.

    Rows referencing the packages are removed from every table listed in
    `references` first, then the rhnPackage rows themselves. The transaction
    is committed at the end, so this batch is not undone by a later
    rollback.
    """
    # Tables with a package_id column pointing at rhnPackage.
    references = [
        'rhnChannelPackage',
        'rhnErrataPackage',
        'rhnErrataPackageTMP',
        'rhnPackageChangelogRec',
        'rhnPackageConflicts',
        'rhnPackageFile',
        'rhnPackageObsoletes',
        'rhnPackageProvides',
        'rhnPackageRequires',
        'rhnPackageRecommends',
        'rhnPackageSuggests',
        'rhnPackageSupplements',
        'rhnPackageEnhances',
        'rhnServerNeededCache',
        'susePackageProductFile',
    ]
    deleteStatement = "delete from %s where package_id = :package_id"
    for table in references:
        h = rhnSQL.prepare(deleteStatement % table)
        # Treat a false result (e.g. None) as 0 so that the %d below cannot
        # fail; _delete_srpms() guards its row count the same way.
        count = h.executemany(package_id=packageIds) or 0
        log_debug(3, "Deleted from %s: %d rows" % (table, count))
    deleteStatement = "delete from rhnPackage where id = :package_id"
    h = rhnSQL.prepare(deleteStatement)
    count = h.executemany(package_id=packageIds)
    if count:
        log_debug(2, "DELETED package id %s" % str(packageIds))
    else:
        log_error("No such package id %s" % str(packageIds))
    rhnSQL.commit()

def _delete_files(relpaths):
    for relpath in relpaths:
        path = os.path.join(CFG.MOUNT_POINT, relpath)
        if not os.path.exists(path):
            log_debug(1, "Not removing %s: no such file" % path)
            continue
        unlink_package_file(path)

def _bind_many(l):
    h = {}
    lr = []
    for i in range(len(l)):
        key = 'p_%s' % i
        h[key] = l[i]
        lr.append(':' + key)
    return h, lr

def _get_package_paths(package_ids, sources=0):
    """Look up the distinct, non-empty `path` values of the given packages.

    package_ids -- iterable of ids
    sources     -- when true the ids refer to rhnPackageSource, otherwise to
                   rhnPackage

    Returns the keys of a dict, i.e. each path once, in no particular order.
    """
    table = "rhnPackage"
    if sources:
        table = "rhnPackageSource"
    lookup = rhnSQL.prepare("select path from %s where id = :package_id" % table)

    # dict used as a set so that every path is reported only once
    unique_paths = {}
    for package_id in package_ids:
        lookup.execute(package_id=package_id)
        row = lookup.fetchone_dict()
        if row and row['path']:
            unique_paths[row['path']] = None

    return unique_paths.keys()

def _delete_ks_files(channel_labels):
    """Remove the kickstart tree directories of the given channels from disk.

    A tree is only removed when no kickstartable tree of a channel outside
    `channel_labels` uses the same base path (compared with the
    CFG.MOUNT_POINT prefix stripped). Base paths that do not exist on disk
    are logged at debug level 1 and skipped.
    """
    sql = """
        select kt.base_path
          from rhnChannel c
          join rhnKickstartableTree kt on c.id = kt.channel_id
         where c.label in (%s) and not exists (
                select 1
                  from rhnKickstartableTree ktx
                  join rhnChannel cx on cx.id = ktx.channel_id
                 where replace(ktx.base_path, :mnt_point, '') =
                       replace(kt.base_path, :mnt_point, '')
                   and cx.label not in (%s))
    """

    params, bind_params = _bind_many(channel_labels)
    params['mnt_point'] = CFG.MOUNT_POINT + '/'
    placeholders = ', '.join(bind_params)
    # The same label placeholders are used for both `in` lists.
    h = rhnSQL.prepare(sql % (placeholders, placeholders))
    h.execute(**params)
    trees = h.fetchall_dict() or []

    for tree in trees:
        path = os.path.join(CFG.MOUNT_POINT, str(tree['base_path']))
        if not os.path.exists(path):
            log_debug(1, "Not removing %s: no such file" % path)
            continue
        shutil.rmtree(path)


def releaseLOCK():
    """Release every lock held in LOCK, in the order they were acquired.

    Each lock is removed from LOCK before it is released, so calling this
    function again (e.g. from an error handler after a normal release) does
    not release the same lock twice.
    """
    while LOCK:
        LOCK.pop(0).release()



if __name__ == '__main__':
    # Script entry point: run main() and translate the outcome into an exit
    # status (0 on success or user interrupt, -1 on an unhandled error).
    try:
        sys.exit(main() or 0)
    except KeyboardInterrupt:
        sys.stderr.write("\nUser interrupted process.\n")
        releaseLOCK()
        sys.exit(0)
    except SystemExit:
        # Normal exit; listed before Exception so it is never reported as
        # an unhandled error.
        raise
    except Exception:
        # sys.exc_info() is used instead of "except Exception, e" so that
        # the same source is accepted by old and new interpreters alike.
        e = sys.exc_info()[1]
        releaseLOCK()
        sys.stderr.write("\nERROR: unhandled exception occurred: (%s).\n" % e)
        import traceback
        traceback.print_exc()
        sys.exit(-1)

