#!/usr/bin/python3

# Copyright (c) 2022, SUSE LLC, All rights reserved.
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3.0 of the License, or (at your option) any later version.
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public
# License along with this library.

"""Update the cached SMT server objects to set the configured protocol"""

import glob
import os
import subprocess

import cloudregister.registerutils as utils

from cloudregister import smt

if os.path.exists(utils.OLD_REGISTRATION_DATA_DIR):
    # The case where both new and old location exist cannot happen. This
    # code runs on package update and as such moves the directory.
    #
    # The old directory is moved into the parent of the new location so
    # that we do not create nested dirs. The hard coded path ends in '/';
    # strip any trailing separator before taking the parent so the result
    # does not depend on how the constant is spelled.
    new_parent_dir = os.path.dirname(
        utils.REGISTRATION_DATA_DIR.rstrip(os.sep)
    )
    # Hand the arguments to mv as a list instead of a shell command string
    # so the paths are never interpreted by a shell. As before the exit
    # status is deliberately not checked; mv reports failures on stderr.
    subprocess.run(
        ['mv', utils.OLD_REGISTRATION_DATA_DIR, new_parent_dir]
    )

# Refresh the cached server objects with the data currently provided by
# the region service; the region information was introduced in
# region server 8.1.3
cache_paths = glob.glob('%s/*SMT*' % utils.REGISTRATION_DATA_DIR)

# Proxy settings are only passed along when utils.set_proxy() returns a
# truthy value; the values themselves are read from the environment.
proxies = (
    {
        'http_proxy': os.environ.get('http_proxy'),
        'https_proxy': os.environ.get('https_proxy')
    }
    if utils.set_proxy() else None
)

cfg = utils.get_config()

# One SMT object per entry delivered by the region service
region_rmt_servers = [
    smt.SMT(entry, utils.https_only(cfg))
    for entry in utils.fetch_smt_data(cfg, proxies, True)
]

for cache_path in cache_paths:
    cached_ipv4 = utils.get_smt_from_store(cache_path).get_ipv4()
    # The first region server with the same IPv4 address as the cached
    # object replaces it; cache files without a match are left untouched.
    replacement = next(
        (
            server for server in region_rmt_servers
            if server.get_ipv4() == cached_ipv4
        ),
        None
    )
    if replacement is not None:
        utils.store_smt_data(cache_path, replacement)
