# systemd_ctypes
#
# Copyright (C) 2022 Martin Pitt <martin@piware.de>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import json
import tempfile
import unittest

import dbusmock  # type: ignore[import] # not typed
import pytest

import systemd_ctypes
from systemd_ctypes import bus, introspection

# (bus name, object path, interface) of the dbusmock server spawned for each
# test; unpacked as the leading arguments of method calls and spawn_server().
TEST_ADDR = (
    'org.freedesktop.Test',
    '/',
    'org.freedesktop.Test.Main',
)


class Test_Greeter(bus.Object):
    """Tiny D-Bus object exported by the service tests.

    The tests reach it as interface ``Test.Greeter``, method ``SayHello``
    (presumably derived by ``bus.Object`` from the class/method names).
    """

    # NOTE(review): the two 's' arguments look like the in/out signatures.
    @bus.Interface.Method('s', 's')
    def say_hello(self, name):
        greeting = 'Hello {}!'.format(name)
        return greeting


class TestAPI(dbusmock.DBusTestCase):
    """Exercise the systemd_ctypes bus API against a python-dbusmock server.

    A session bus is started once for the class (``start_session_bus``) and a
    single client connection, ``bus_user``, is shared by all tests.  Every
    test spawns a fresh mock server owning ``TEST_ADDR`` whose stdout (its
    call log) goes to a temporary file that :meth:`assertLog` inspects.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.start_session_bus()
        # Shared by every test: tests must leave it in a clean state (e.g.
        # release any bus names they request).
        cls.bus_user = systemd_ctypes.Bus.default_user()

    def setUp(self):
        self.mock_log = tempfile.NamedTemporaryFile()
        # Registered first so it runs last (cleanups are LIFO): the file is
        # only closed (and thereby deleted) after the server has exited.
        self.addCleanup(self.mock_log.close)
        self.p_mock = self.spawn_server(*TEST_ADDR, stdout=self.mock_log)
        # LIFO order: terminate the server first, then reap it.
        self.addCleanup(self.p_mock.wait)
        self.addCleanup(self.p_mock.terminate)

    def assertLog(self, regex):
        """Assert that the mock server's call log matches the bytes ``regex``."""
        with open(self.mock_log.name, "rb") as f:
            self.assertRegex(f.read(), regex)

    def add_method(self, iface, name, in_sig, out_sig, code):
        """Define method ``name`` on the mock server via dbusmock's ``AddMethod``.

        ``code`` is Python source run by the mock for each call; it reads the
        call arguments from ``args`` and sets the reply in ``ret``.  The tests
        pass an empty ``iface`` for methods they then call on ``TEST_ADDR[2]``.
        """
        result = self.bus_user.call_method('org.freedesktop.Test', '/', dbusmock.MOCK_IFACE, 'AddMethod', 'sssss',
                                           iface, name, in_sig, out_sig, code)
        self.assertEqual(result, ())

    def async_call(self, message):
        """Send ``message`` with ``call_async`` and return the reply message.

        The coroutine is driven to completion by ``event.run_async``, so this
        blocks until the reply (or an error) arrives.
        """
        result = None

        async def _call():
            nonlocal result
            result = await self.bus_user.call_async(message)

        systemd_ctypes.event.run_async(_call())

        return result

    def test_append_arg(self):
        args = [
            ('i', 1234),
            ('s', 'Hello World'),
            ('v', {'t': 's', 'v': 'Hi!'}),
            ('v', {'t': 'i', 'v': 5678}),
            ('ai', [1, 2, 3, 5]),
            ('a{s(ii)}', {'start': (3, 4), 'end': (6, 7)}),
            ('a{sv}', {'one': {'t': 's', 'v': "Hello"}, 'two': {'t': 't', 'v': 1234567890}}),
        ]
        signature = ''.join(typestring for typestring, value in args)
        values = tuple(value for typestring, value in args)

        # Construct it one argument at a time
        message1 = self.bus_user.message_new_method_call(*TEST_ADDR, 'Do')
        for typestring, value in args:
            message1.append_arg(typestring, value)
        message1.seal(0, 0)
        self.assertEqual(message1.get_body(), values)

        # Construct it in one go
        message2 = self.bus_user.message_new_method_call(*TEST_ADDR, 'Do', signature, *values)
        message2.seal(0, 0)
        self.assertEqual(message2.get_body(), values)

    def test_noarg_noret_sync(self):
        self.add_method('', 'Do', '', '', '')
        result = self.bus_user.call_method(*TEST_ADDR, 'Do')
        self.assertEqual(result, ())
        self.assertLog(b'^[0-9.]+ Do$')

    def test_noarg_noret_async(self):
        self.add_method('', 'Do', '', '', '')
        message = self.bus_user.message_new_method_call(*TEST_ADDR, 'Do')
        self.assertEqual(self.async_call(message).get_body(), ())
        self.assertLog(b'^[0-9.]+ Do$')

    def test_strarg_strret_sync(self):
        self.add_method('', 'Reverse', 's', 's', 'ret = "".join(reversed(args[0]))')

        result = self.bus_user.call_method(*TEST_ADDR, 'Reverse', 's', 'ab c')
        self.assertEqual(result, ('c ba',))
        self.assertLog(b'^[0-9.]+ Reverse "ab c"\n$')

    def test_strarg_strret_async(self):
        self.add_method('', 'Reverse', 's', 's', 'ret = "".join(reversed(args[0]))')
        message = self.bus_user.message_new_method_call(*TEST_ADDR, 'Reverse', 's', 'ab c')
        self.assertEqual(self.async_call(message).get_body(), ('c ba',))
        self.assertLog(b'^[0-9.]+ Reverse "ab c"\n$')

    def test_bool(self):
        self.add_method('', 'Not', 'b', 'b', 'ret = not args[0]')

        for val in [False, True]:
            result = self.bus_user.call_method(*TEST_ADDR, 'Not', 'b', val)
            self.assertEqual(result, (not val,))

    def test_int_sync(self):
        self.add_method('', 'Inc', 'yiuxt', 'yiuxt',
                        'ret = (args[0] + 1, args[1] + 1, args[2] + 1, args[3] + 1, args[4] + 1)')

        result = self.bus_user.call_method(*TEST_ADDR, 'Inc', 'yiuxt',
                                           0x7E, 0x7FFFFFFE, 0xFFFFFFFE, 0x7FFFFFFFFFFFFFFE, 0xFFFFFFFFFFFFFFFE)
        self.assertEqual(result, (0x7F, 0x7FFFFFFF, 0xFFFFFFFF, 0x7FFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF))

    def test_int_async(self):
        self.add_method('', 'Dec', 'yiuxt', 'yiuxt',
                        'ret = (args[0] - 1, args[1] - 1, args[2] - 1, args[3] - 1, args[4] - 1)')

        message = self.bus_user.message_new_method_call(*TEST_ADDR, 'Dec', 'yiuxt',
                                                        1, -0x7FFFFFFF, 1, -0x7FFFFFFFFFFFFFFF, 1)
        self.assertEqual(self.async_call(message).get_body(), (0, -0x80000000, 0, -0x8000000000000000, 0))

    def test_int_error(self):
        # int overflow
        self.add_method('', 'Inc', 'i', 'i', 'ret = args[0] + 1')
        with pytest.raises(systemd_ctypes.BusError, match='OverflowError'):
            self.bus_user.call_method(*TEST_ADDR, 'Inc', 'i', 0x7FFFFFFF)

        # uint underflow
        self.add_method('', 'Dec', 'u', 'u', 'ret = args[0] - 1')
        with pytest.raises(systemd_ctypes.BusError,
                           match="OverflowError: can't convert negative value to unsigned int"):
            self.bus_user.call_method(*TEST_ADDR, 'Dec', 'u', 0)

    def test_float(self):
        self.add_method('', 'Sq', 'd', 'd', 'ret = args[0] * args[0]')
        result = self.bus_user.call_method(*TEST_ADDR, 'Sq', 'd', -5.5)
        self.assertAlmostEqual(result[0], 30.25)

    def test_objpath(self):
        self.add_method('', 'Parent', 'o', 'o', "ret = '/'.join(args[0].split('/')[:-1])")
        result = self.bus_user.call_method(*TEST_ADDR, 'Parent', 'o', '/foo/bar/baz')
        self.assertEqual(result, ('/foo/bar',))

    def test_array_output(self):
        self.add_method('', 'Echo', 'u', 'as', 'ret = ["echo"] * args[0]')
        result = self.bus_user.call_method(*TEST_ADDR, 'Echo', 'u', 2)
        self.assertEqual(result, (['echo', 'echo'],))

    def test_array_input(self):
        self.add_method('', 'Count', 'as', 'u', 'ret = len(args[0])')
        result = self.bus_user.call_method(*TEST_ADDR, 'Count', 'as', ['first', 'second'])
        self.assertEqual(result, (2,))

    def test_dict_output(self):
        self.add_method('', 'GetStrs', '', 'a{ss}', "ret = {'a': 'x', 'b': 'y'}")
        result = self.bus_user.call_method(*TEST_ADDR, 'GetStrs')
        self.assertEqual(result, ({'a': 'x', 'b': 'y'},))

        self.add_method('', 'GetInts', '', 'a{ii}', "ret = {1: 42, 2: 99}")
        result = self.bus_user.call_method(*TEST_ADDR, 'GetInts')
        self.assertEqual(result, ({1: 42, 2: 99},))

        self.add_method('', 'GetVariants', '', 'a{sv}',
                        "ret = {'a': dbus.String('x', variant_level=1), 'b': dbus.Boolean(True, variant_level=1)}")
        result = self.bus_user.call_method(*TEST_ADDR, 'GetVariants')
        self.assertEqual(result, ({'a': {'t': 's', 'v': 'x'}, 'b': {'t': 'b', 'v': True}},))

    def test_dict_input(self):
        self.add_method('', 'CountStrs', 'a{ss}', 'u', 'ret = len(args[0])')
        result = self.bus_user.call_method(*TEST_ADDR, 'CountStrs', 'a{ss}', {'a': 'x', 'b': 'y'})
        self.assertEqual(result, (2,))

        # TODO: also cover dict inputs with non-string values (e.g. a{ii}, a{sv})

    def test_binary_encode(self):
        self.add_method('', 'DecodeUTF8', 'ay', 's', 'ret = bytes(args[0]).decode("utf-8")')
        result = self.bus_user.call_method(*TEST_ADDR, 'DecodeUTF8', 'ay', b'G\xc3\xa4nsef\xc3\xbc\xc3\x9fchen')
        self.assertEqual(result, ('Gänsefüßchen',))

    def test_base64_binary_encode(self):
        self.add_method('', 'DecodeUTF8', 'ay', 's', 'ret = bytes(args[0]).decode("utf-8")')
        result = self.bus_user.call_method(*TEST_ADDR, 'DecodeUTF8', 'ay', 'R8OkbnNlZsO8w59jaGVu')
        self.assertEqual(result, ('Gänsefüßchen',))

    def test_binary_decode(self):
        self.add_method('', 'EncodeUTF8', 's', 'ay', 'ret = args[0].encode("utf-8")')
        result = self.bus_user.call_method(*TEST_ADDR, 'EncodeUTF8', 's', 'Gänsefüßchen')
        self.assertEqual(result, (b'G\xc3\xa4nsef\xc3\xbc\xc3\x9fchen',))

    def test_base64_binary_decode(self):
        self.add_method('', 'EncodeUTF8', 's', 'ay', 'ret = args[0].encode("utf-8")')
        result = self.bus_user.call_method(*TEST_ADDR, 'EncodeUTF8', 's', 'Gänsefüßchen')
        result = json.loads(json.dumps(result, cls=systemd_ctypes.JSONEncoder))
        self.assertEqual(result, ['R8OkbnNlZsO8w59jaGVu'])

    def test_unknown_method_sync(self):
        with pytest.raises(systemd_ctypes.BusError, match='.*org.freedesktop.DBus.Error.UnknownMethod:.*'
                           'Do is not a valid method of interface org.freedesktop.Test.Main'):
            self.bus_user.call_method(*TEST_ADDR, 'Do')

    def test_unknown_method_async(self):
        message = self.bus_user.message_new_method_call(*TEST_ADDR, 'Do')
        with pytest.raises(systemd_ctypes.BusError, match='.*org.freedesktop.DBus.Error.UnknownMethod:.*'
                           'Do is not a valid method of interface org.freedesktop.Test.Main'):
            self.async_call(message).get_body()

    def test_call_signature_mismatch(self):
        self.add_method('', 'Inc', 'i', 'i', 'ret = args[0] + 1')
        # specified signature does not match server, but locally consistent args
        with pytest.raises(systemd_ctypes.BusError,
                           match='(InvalidArgs|TypeError).*Fewer items.*signature.*arguments'):
            self.bus_user.call_method(*TEST_ADDR, 'Inc', 'ii', 1, 2)
        # a string where the server expects an int: rejected or fails server-side
        with pytest.raises(systemd_ctypes.BusError, match='InvalidArgs|TypeError'):
            self.bus_user.call_method(*TEST_ADDR, 'Inc', 's', 'hello')

        # specified signature does not match arguments
        with pytest.raises(AssertionError, match=r'call args \(1, 2\) have different length than signature.*'):
            self.bus_user.call_method(*TEST_ADDR, 'Inc', 'i', 1, 2)
        with pytest.raises(TypeError, match=r'.*str.* as.* integer|int.*str'):
            self.bus_user.call_method(*TEST_ADDR, 'Inc', 'i', 'hello')

    def test_custom_error(self):
        self.add_method('', 'Boom', '', '',
                        'raise dbus.exceptions.DBusException("no good", name="com.example.Error.NoGood")')
        with pytest.raises(systemd_ctypes.BusError, match='no good'):
            self.bus_user.call_method(*TEST_ADDR, 'Boom')

    def test_introspect(self):
        self.add_method('', 'Do', 'saiv', 'i', 'ret = 42')
        xml, = self.bus_user.call_method(TEST_ADDR[0], '/', 'org.freedesktop.DBus.Introspectable', 'Introspect')
        parsed = introspection.parse_xml(xml)
        expected = {
            'methods': {'Do': {'in': ['s', 'ai', 'v'], 'out': ['i']}},
            'properties': {},
            'signals': {}
        }
        self.assertEqual(parsed['org.freedesktop.Test.Main'], expected)

    def check_iface_sayhello(self, service_name):
        """Call ``Test.Greeter.SayHello('world')`` on ``service_name`` and check the reply."""
        message = self.bus_user.message_new_method_call(
            service_name, '/', 'Test.Greeter',
            'SayHello', 's', 'world')
        self.assertEqual(self.async_call(message).get_body(), ('Hello world!',))

    def test_service(self):
        test_object = Test_Greeter()
        test_slot = self.bus_user.add_object('/', test_object)
        self.bus_user.request_name('com.example.Test', bus.Bus.NameFlags.DEFAULT)
        try:
            self.check_iface_sayhello('com.example.Test')
        finally:
            # bus_user is shared by all tests: always give the name back and
            # drop the object registration, even if the check failed.
            self.bus_user.release_name('com.example.Test')
            del test_slot

    def test_service_replace(self):
        test_object = Test_Greeter()
        test_slot = self.bus_user.add_object('/', test_object)
        self.bus_user.request_name(TEST_ADDR[0], bus.Bus.NameFlags.REPLACE_EXISTING)
        try:
            self.check_iface_sayhello(TEST_ADDR[0])
        finally:
            # Keeping TEST_ADDR[0] on the shared connection would break the
            # next test's spawn_server() for the same name.
            self.bus_user.release_name(TEST_ADDR[0])
            del test_slot

    def test_request_name_errors(self):
        # name already exists
        with pytest.raises(FileExistsError):
            self.bus_user.request_name(TEST_ADDR[0], bus.Bus.NameFlags.DEFAULT)

        # invalid name
        with pytest.raises(OSError, match='.*Invalid argument'):
            self.bus_user.request_name('', bus.Bus.NameFlags.DEFAULT)

        # invalid flag
        with pytest.raises(OSError, match='.*Invalid argument'):
            self.bus_user.request_name(TEST_ADDR[0], 0xFF)

        # name not taken
        with pytest.raises(ProcessLookupError):
            self.bus_user.release_name('com.example.NotThis')


if __name__ == '__main__':
    # Allow running this module directly with the stdlib unittest runner.
    unittest.main()
