# Copyright: (c) 2025, Luca Bilke <luca@bil.ke>
# MIT License (see LICENSE)

from __future__ import annotations

import copy
from dataclasses import replace
from typing import TYPE_CHECKING, Any, Callable

from ansible_collections.snailed.ez_docker.plugins.module_utils.common import recursive_update

if TYPE_CHECKING:
    from ansible_collections.snailed.ez_docker.plugins.module_utils.models import State

# YAML fragment documenting the options shared by every service helper; it
# mirrors the keys of BASE_ARGS below. The text is parsed as YAML, so its exact
# indentation is significant and it is kept byte-identical.
# NOTE(review): presumably spliced into each helper module's DOCUMENTATION
# ``options`` section — confirm against the helper modules.
BASE_DOCUMENTATION = """
name:
    description:
        - Name of the service.
    type: str
overwrite:
    description:
        - Definition to force.
    type: dict
"""

# Argument-spec entries for the shared options documented above.
#   name      -- service name; when empty, run_helper derives it from the last
#                component of the helper's module name.
#   overwrite -- partial service definition merged last by run_helper; when it
#                is empty, run_helper falls back to the ``definition`` param.
BASE_ARGS: dict[str, Any] = {
    "name": {"type": "str"},
    "overwrite": {"type": "dict"},
}


def get_base_definition(state: State, service_name: str) -> dict[str, Any]:
    """Build the baseline definition every generated service starts from.

    Both the container name and the hostname are ``<project>_<service>``, where
    the project name is ``state.params["name"]``. The restart policy is
    ``unless-stopped``, and ``environment``, ``labels`` and ``volumes`` start
    out empty.

    Args:
        state: Current module state; only ``state.params["name"]`` is read.
        service_name: Name of the service within the project.

    Returns:
        A freshly built dict, including fresh empty containers, on every call.
    """
    project_name: str = state.params["name"]
    qualified_name = f"{project_name}_{service_name}"

    return dict(
        container_name=qualified_name,
        hostname=qualified_name,
        restart="unless-stopped",
        environment={},
        labels={},
        volumes=[],
    )


def get_default_definition(state: State, service_name: str) -> dict[str, Any]:
    """Return the user-configured default definition for ``service_name``.

    Two sources are combined:
    ``settings.default_definition``, which applies to every service, and
    ``settings.service_default_definitions[service_name]``. On a top-level key
    conflict the service-specific value wins. The merge is shallow.

    Ansible can report an option that is declared but unset as ``None`` rather
    than leaving it out. Every level therefore falls back to ``{}`` on any falsy
    value; a ``dict.get`` default alone would let ``None`` through.

    Args:
        state: Current module state; ``state.params["settings"]`` is read.
        service_name: Service whose specific defaults are looked up.

    Returns:
        A new dict. It is empty when nothing is configured.
    """
    settings: dict[str, Any] = state.params.get("settings") or {}
    default_definition: dict[str, Any] = settings.get("default_definition") or {}
    service_default_definitions: dict[str, Any] = settings.get("service_default_definitions") or {}
    service_default_definition: dict[str, Any] = service_default_definitions.get(service_name) or {}

    return default_definition | service_default_definition


def get_default_args(state: State, helper_name: str) -> dict[str, Any]:
    """Return the user-configured default arguments for helper ``helper_name``.

    The value comes from ``settings.service_default_args[helper_name]`` in
    ``state.params``. A missing, ``None`` or empty value at any level yields
    ``{}``. Ansible can report declared-but-unset options as ``None``, which a
    ``dict.get`` default would not catch.

    Args:
        state: Current module state; ``state.params["settings"]`` is read.
        helper_name: Key of the helper in ``service_default_args``.

    Returns:
        The configured dict, or ``{}`` if nothing is configured.
    """
    settings: dict[str, Any] = state.params.get("settings") or {}
    service_default_args: dict[str, Any] = settings.get("service_default_args") or {}
    return service_default_args.get(helper_name) or {}


def apply_update(state: State, service_name: str, update: dict[str, Any]) -> State:
    """Return a new state with ``update`` recursively merged into one service.

    ``state.after`` is deep-copied first, so the input state is never mutated.
    After the merge, the service's ``volumes`` list is normalised:

    - long-syntax (dict) mounts are de-duplicated by their ``target`` path;
      the last entry for a target wins and keeps the slot of the first one;
    - mounts without a ``target`` are discarded.

    Args:
        state: Current module state; ``state.after`` holds the project.
        service_name: Key of the service under ``project["services"]``.
        update: Partial service definition, merged via ``recursive_update``.

    Returns:
        A copy of ``state`` (via ``dataclasses.replace``) whose ``after``
        contains the updated service.
    """
    project = copy.deepcopy(state.after)
    services: dict[str, Any] = project.setdefault("services", {})
    service = recursive_update(services.get(service_name, {}), update)

    # A mount is identified by its container-side ``target``; Compose rejects
    # two mounts at the same target. Keying on ``source`` would merge distinct
    # mounts that share a source and raise KeyError for source-less mounts
    # (tmpfs, anonymous volumes).
    # NOTE(review): non-dict (short-syntax string) entries are dropped, as they
    # effectively were before — confirm that is intended.
    volumes: list[Any] = service.get("volumes") or []
    unique_volumes = {
        vol["target"]: vol for vol in volumes if isinstance(vol, dict) and "target" in vol
    }
    service["volumes"] = list(unique_volumes.values())

    services[service_name] = service

    return replace(state, after=project)


def run_helper(
    state: State,
    params: dict[str, Any],
    helper: Callable[[State, dict[str, Any]], dict[str, Any]] = lambda _a, _b: {},
) -> State:
    """Build one service definition in layers and merge it into ``state``.

    Each layer is applied with ``apply_update``, so later layers win:

    1. the base definition (``get_base_definition``);
    2. user defaults from ``settings`` (``get_default_definition``);
    3. the partial definition returned by ``helper(state, params)``;
    4. ``params["overwrite"]``, or ``params["definition"]`` when no
       overwrite is given.

    If ``params`` has no name, or an empty one, the service name becomes the last
    dotted component of ``helper.__module__``.

    Args:
        state: Current module state.
        params: Helper parameters. The caller's dict is not modified; a
            shallow copy, with the name filled in, is passed to ``helper``.
        helper: Callable that produces the helper-specific partial definition.

    Returns:
        The new state with the service merged in.
    """
    # Copy so that filling in a default name does not leak into the caller's dict.
    params = dict(params)
    if not params.get("name"):
        params["name"] = helper.__module__.rsplit(".", 1)[-1]
    service_name: str = params["name"]

    state = apply_update(state, service_name, get_base_definition(state, service_name))
    state = apply_update(state, service_name, get_default_definition(state, service_name))
    state = apply_update(state, service_name, helper(state, params))

    # Unset options may be present with a None value, so fall back with ``or``
    # rather than relying on dict.get defaults.
    overwrite = params.get("overwrite") or params.get("definition") or {}

    return apply_update(state, service_name, overwrite)