from __future__ import annotations

import gzip
from io import BytesIO

from starlette.datastructures import MutableHeaders

from starlette_compress._utils import is_start_message_satisfied

TYPE_CHECKING = False
if TYPE_CHECKING:
    from starlette.types import ASGIApp, Message, Receive, Scope, Send


class GZipResponder:
    """ASGI wrapper that gzip-compresses eligible HTTP response bodies.

    The ``http.response.start`` message is held back until the first body
    message arrives, so the response headers (``Content-Encoding``, ``Vary``,
    ``Content-Length``) can be rewritten before anything is sent downstream.
    Whether a response is eligible at all is decided by
    ``is_start_message_satisfied`` from the start message alone.
    """

    __slots__ = (
        'app',
        'level',
        'minimum_size',
    )

    def __init__(self, app: ASGIApp, minimum_size: int, level: int) -> None:
        """
        :param app: The wrapped ASGI application.
        :param minimum_size: A response delivered as a single body message
            shorter than this many bytes is sent uncompressed. Streamed
            (multi-message) bodies are always compressed.
        :param level: Passed to gzip as ``compresslevel``.
        """
        self.app = app
        self.minimum_size = minimum_size
        self.level = level

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the wrapped app, compressing its response on the way out."""
        # Start message captured for possible compression (not yet sent).
        start_message: Message | None = None
        # Whether start_message has already been forwarded to ``send``.
        started = False
        compressor: gzip.GzipFile | None = None
        buffer: BytesIO | None = None

        async def wrapper(message: Message) -> None:
            nonlocal start_message, started, compressor, buffer

            message_type: str = message['type']

            # handle start message
            if message_type == 'http.response.start':
                if start_message is not None:
                    raise AssertionError(
                        'Unexpected repeated http.response.start message'
                    )

                if is_start_message_satisfied(message):
                    # capture start message and wait for response body
                    start_message = message
                else:
                    await send(message)
                return

            # start message was not captured: nothing to compress
            if start_message is None:
                await send(message)
                return

            if message_type != 'http.response.body':
                # A non-body message (e.g. an extension such as
                # http.response.pathsend) whose payload we cannot compress.
                # If the start message is still held, release it unchanged
                # first so the downstream server sees a valid message order.
                if not started:
                    started = True
                    await send(start_message)
                await send(message)
                return

            body: bytes = message.get('body', b'')
            more_body: bool = message.get('more_body', False)

            if compressor is None:
                if started:
                    # The start message already went out uncompressed, so
                    # the rest of the response must pass through untouched.
                    await send(message)
                    return
                started = True

                # skip compression for small responses
                if not more_body and len(body) < self.minimum_size:
                    await send(start_message)
                    await send(message)
                    return

                headers = MutableHeaders(raw=start_message['headers'])
                headers['Content-Encoding'] = 'gzip'
                headers.add_vary_header('Accept-Encoding')

                if not more_body:
                    # one-shot: the whole body is known, so the compressed
                    # length can be advertised up front
                    compressed_body = gzip.compress(body, compresslevel=self.level)
                    headers['Content-Length'] = str(len(compressed_body))
                    message['body'] = compressed_body
                    await send(start_message)
                    await send(message)
                    return

                # begin streaming: final length is unknown, so drop the
                # original (uncompressed) Content-Length
                del headers['Content-Length']
                await send(start_message)
                buffer = BytesIO()
                compressor = gzip.GzipFile(
                    mode='wb', compresslevel=self.level, fileobj=buffer
                )

            if buffer is None:
                raise AssertionError('Compressor is set but buffer is not')

            # streaming
            compressor.write(body)
            if not more_body:
                # finish the gzip stream; closing the GzipFile does not close
                # the BytesIO passed as fileobj, so it can still be read below
                compressor.close()
            compressed_body = buffer.getvalue()
            if more_body:
                if compressed_body:
                    # reset the buffer so each chunk is sent only once
                    buffer.seek(0)
                    buffer.truncate()
                else:
                    # the compressor produced no output yet; don't send an
                    # empty intermediate chunk
                    return
            await send(
                {
                    'type': 'http.response.body',
                    'body': compressed_body,
                    'more_body': more_body,
                }
            )

        await self.app(scope, receive, wrapper)
