# Copyright: (c) 2024, Luca Bilke <luca@bil.ke>
# MIT License (see LICENSE)

from __future__ import annotations

import copy
import os
from dataclasses import asdict, dataclass, field, replace
from typing import Any, Callable

import yaml
from ansible.module_utils.basic import AnsibleModule

# Root directory under which each project keeps its compose file
# (get_state reads "<PROJECTS_DIR>/<project_name>/docker-compose.yml").
PROJECTS_DIR = "/var/lib/ez_compose"

# Ansible argument spec common to every service module; run_service merges
# module-specific `extra_args` on top of it.
BASE_SERVICE_ARGS = {
    "project_name": dict(type="str", required=True),
    "name": dict(type="str", required=True),
    "image": dict(type="str"),
    "state": dict(type="str", default="present", choices=["present", "absent"]),
    "defaults": dict(type="dict"),
}

# Ansible argument spec common to every label module; run_label merges
# module-specific `extra_args` on top of it.
BASE_LABEL_ARGS = {
    "project_name": dict(type="str", required=True),
    "name": dict(type="str", required=True),
    "state": dict(type="str", default="present", choices=["present", "absent"]),
}


@dataclass(frozen=True)
class Result:
    """Outcome of a module run, passed field-by-field to ``exit_json``."""

    # Reported to Ansible as the `changed` key.
    changed: bool = False
    # Reported to Ansible as the `diff` key; each instance gets its own dict.
    diff: dict[str, Any] = field(default_factory=dict)


@dataclass(frozen=True)
class Settings:
    """Immutable container for ez_compose settings.

    NOTE(review): nothing in the visible code reads this class, and its default
    differs from PROJECTS_DIR ("/var/lib/ez_compose"), which get_state actually
    uses. Confirm which directory is intended before relying on it.
    """

    projects_dir: str = "/usr/local/share/ez_compose/"


@dataclass(frozen=True)
class State:
    """Snapshot threaded through the module's processing steps.

    The functions in this module never mutate a State in place; they return
    a modified copy built with ``dataclasses.replace``.
    """

    # The running Ansible module (params, check_mode, exit_json).
    module: AnsibleModule
    # Outcome reported back to Ansible when the run finishes.
    result: Result
    # Path of the project's docker-compose.yml.
    compose_filepath: str
    # Project as loaded from disk ({} when the file does not exist yet).
    before: dict[str, Any]
    # Project as it should look after this run.
    after: dict[str, Any]


def _recursive_update(default: dict[Any, Any], update: dict[Any, Any]) -> dict[Any, Any]:
    for key in update:
        if isinstance(update[key], dict) and isinstance(default.get(key), dict):
            default[key] = _recursive_update(default[key], update[key])

        elif isinstance(update[key], list) and isinstance(default.get(key), list):
            default_set = set(default[key])
            custom_set = set(update[key])
            default[key] = list(default_set.union(custom_set))

        else:
            default[key] = update[key]

    return default


def get_state(module: AnsibleModule) -> State:
    """Load the project's compose file and wrap it in a fresh State.

    The file is ``<PROJECTS_DIR>/<project_name>/docker-compose.yml``. A missing
    or empty file yields an empty project (``{}``).
    """
    compose_filepath = f"{PROJECTS_DIR}/{module.params['project_name']}/docker-compose.yml"

    before: dict[str, Any]
    try:
        with open(compose_filepath, encoding="utf-8") as fp:
            # safe_load returns None for an empty file; normalise to {}.
            before = yaml.safe_load(fp) or {}

    except FileNotFoundError:
        before = {}

    return State(
        module=module,
        result=Result(),
        compose_filepath=compose_filepath,
        before=before,
        # Independent copy: `before` must stay a faithful snapshot of the file
        # even if a later step mutates `after` in place.
        after=copy.deepcopy(before),
    )


def apply_service_base(state: State) -> State:
    """Merge the baseline definition every ez_compose service gets.

    The baseline sets container name and hostname to ``<project>_<name>``,
    restart policy ``unless-stopped``, and empty environment/labels/volumes
    containers. It also attaches the project's ``<project>_internal`` network
    and sets the image when one was given.
    """
    service_name = state.module.params["name"]
    project_name = state.module.params["project_name"]
    image = state.module.params["image"]
    prefixed_name = f"{project_name}_{service_name}"

    update: dict[str, Any] = {
        # Compose has no `service_name` service attribute (it is rejected as an
        # unknown property); `container_name` is the spec'd key.
        "container_name": prefixed_name,
        "hostname": prefixed_name,
        "restart": "unless-stopped",
        "environment": {},
        "labels": {},
        "volumes": [],
        "networks": {
            f"{project_name}_internal": None,
        },
    }

    # `image` is optional; writing `image: null` would clobber an existing
    # image and produce an invalid compose file.
    if image is not None:
        update["image"] = image

    return update_service(state, update)


def set_service_defaults(state: State) -> State:
    """Merge the user-supplied ``defaults`` parameter into the module's service.

    ``defaults`` is optional; when it is unset or empty the state is returned
    unchanged.
    """
    defaults = state.module.params["defaults"]
    if not defaults:
        return state

    container_name = state.module.params["name"]
    project = copy.deepcopy(state.after)

    services = project.get("services") or {}
    project["services"] = services
    service = services.get(container_name) or {}
    services[container_name] = service

    # In-place merge: `service` is the dict stored in `project`.
    _ = _recursive_update(service, defaults)

    return replace(state, after=project)


def update_service(state: State, update: dict[str, Any]) -> State:
    """Return a new State with ``update`` merged into the module's service.

    Creates the ``services`` mapping and the service entry when they do not
    exist yet (e.g. on the first run against a new project).
    """
    project = copy.deepcopy(state.after)
    service_name = state.module.params["name"]

    services = project.get("services") or {}
    project["services"] = services
    service = services.get(service_name) or {}
    services[service_name] = service

    # In-place merge: `service` is the dict stored in `project`.
    _ = _recursive_update(service, update)

    return replace(state, after=project)


def remove_service(state: State) -> State:
    """Return a State without the module's service.

    Removing a service that is not present is a no-op, so ``state: absent``
    stays idempotent.
    """
    project = copy.deepcopy(state.after)
    service_name = state.module.params["name"]

    services = project.get("services") or {}
    if service_name not in services:
        return state

    del services[service_name]

    return replace(state, after=project)


def remove_labels(state: State, label_names: list[str]) -> State:
    """Return a State with ``label_names`` removed from the module's service.

    Labels may be a mapping (``{name: value}``) or the Compose list form
    (``["name=value", ...]``). If no labels remain, the ``labels`` key is
    dropped. If the service does not exist, the state is returned unchanged
    rather than creating an empty service.
    """
    project = copy.deepcopy(state.after)
    service_name = state.module.params["name"]
    service = (project.get("services") or {}).get(service_name)

    if not isinstance(service, dict):
        return state

    to_remove = set(label_names)
    labels = service.get("labels") or {}

    # Build a filtered copy instead of deleting while iterating (which raises
    # "dictionary changed size during iteration").
    kept: dict[str, Any] | list[Any]
    if isinstance(labels, dict):
        kept = {name: value for name, value in labels.items() if name not in to_remove}
    else:
        kept = [entry for entry in labels if str(entry).split("=", 1)[0] not in to_remove]

    if kept:
        service["labels"] = kept
    else:
        service.pop("labels", None)

    return replace(state, after=project)


def _named_volumes(service: dict[str, Any]) -> list[str]:
    """Return the named volumes (not bind mounts or tmpfs) a service mounts."""
    names: list[str] = []
    for entry in service.get("volumes") or []:
        if isinstance(entry, dict):
            # Long syntax: only `type: volume` mounts refer to top-level volumes.
            source = entry.get("source")
            if source and entry.get("type", "volume") == "volume":
                names.append(source)
        elif isinstance(entry, str):
            # Short syntax `SOURCE:TARGET[:MODE]`. Sources that look like host
            # paths are bind mounts. Variables cannot be resolved here, so
            # they are skipped as well.
            source, sep, _ = entry.partition(":")
            if sep and source and not source.startswith((".", "/", "~", "$")):
                names.append(source)
    return names


def _network_names(service: dict[str, Any]) -> list[str]:
    """Return the networks a service attaches to (mapping or list form)."""
    return list(service.get("networks") or [])


def update_project(state: State) -> State:
    """Declare at project level every named volume and network the services use.

    Entries already present at the top level are kept as-is; missing ones are
    added with a ``None`` (default) definition. Empty top-level ``volumes`` or
    ``networks`` sections are not created.
    """
    project = copy.deepcopy(state.after)

    services = project.get("services") or {}
    project_volumes = dict(project.get("volumes") or {})
    project_networks = dict(project.get("networks") or {})

    for service in services.values():
        service = service or {}

        for volume_name in _named_volumes(service):
            project_volumes.setdefault(volume_name, None)

        for network_name in _network_names(service):
            project_networks.setdefault(network_name, None)

    # The dicts above are fresh copies, so they must be written back.
    if project_volumes:
        project["volumes"] = project_volumes
    if project_networks:
        project["networks"] = project_networks

    return replace(state, after=project)


def write_compose(state: State) -> State:
    """Serialise ``state.after`` as YAML to the project's compose file.

    Creates the project directory if needed and returns ``state`` unchanged.
    """
    file = state.compose_filepath

    # get_state treats a missing compose file as an empty project, so on the
    # first run for a project its directory may not exist yet.
    os.makedirs(os.path.dirname(file), exist_ok=True)

    with open(file, mode="w", encoding="utf-8") as stream:
        yaml.dump(state.after, stream)

    return state


def run_service(
    extra_args: dict[str, Any] | None = None,
    helper: Callable[[State], State] = lambda _: _,
) -> None:
    """Entry point for service modules.

    Builds the AnsibleModule from BASE_SERVICE_ARGS plus ``extra_args``. With
    ``state=absent`` it removes the service. Otherwise it applies the base
    definition, the user defaults, the module-specific ``helper``, and the
    project-level volume/network bookkeeping, in that order. It then hands
    off to ``exit``.
    """
    module = AnsibleModule(
        # `None` default instead of a shared mutable `{}` default.
        argument_spec={**BASE_SERVICE_ARGS, **(extra_args or {})},
        supports_check_mode=True,
    )

    state = get_state(module)

    if module.params["state"] == "absent":
        state = remove_service(state)

    else:
        for step in (apply_service_base, set_service_defaults, helper, update_project):
            state = step(state)

    exit(state)


def run_label(
    extra_args: dict[str, Any], helper: Callable[[State], State], label_names: list[str]
) -> None:
    """Entry point for label modules.

    Builds the AnsibleModule from BASE_LABEL_ARGS overlaid with ``extra_args``.
    With ``state=absent`` it strips ``label_names`` from the service; otherwise
    it runs ``helper``. It then hands off to ``exit``.
    """
    spec = dict(BASE_LABEL_ARGS)
    spec.update(extra_args)

    module = AnsibleModule(argument_spec=spec, supports_check_mode=True)
    current = get_state(module)

    removing = module.params["state"] == "absent"
    current = remove_labels(current, label_names) if removing else helper(current)

    exit(current)


def exit(state: State) -> None:  # NOTE: shadows the builtin; name kept for callers.
    """Report the run to Ansible, writing the compose file when it changed.

    ``changed`` is True when the computed project (``state.after``) differs
    from what was loaded (``state.before``), or when a step already set
    ``state.result.changed``. The diff carries before/after for ``--diff``
    output. The file is written only when its content changed and the module
    is not in check mode. ``AnsibleModule.exit_json`` terminates the module
    process.
    """
    content_changed = state.before != state.after

    diff = (
        {
            "before": state.before,
            "after": state.after,
            "before_header": state.compose_filepath,
            "after_header": state.compose_filepath,
        }
        if content_changed
        else state.result.diff
    )
    state = replace(
        state,
        result=Result(changed=state.result.changed or content_changed, diff=diff),
    )

    # Skip the write when nothing changed (idempotence) or in check mode.
    if content_changed and not state.module.check_mode:
        _ = write_compose(state)

    state.module.exit_json(**asdict(state.result))  # type: ignore[reportUnkownMemberType]