# Copyright: (c) 2025, Luca Bilke <luca@bil.ke> # MIT License (see LICENSE) from __future__ import annotations import copy from dataclasses import replace from typing import TYPE_CHECKING, Any, Callable from ansible_collections.snailed.ez_compose.plugins.module_utils.common import ( clean_none, recursive_update, ) if TYPE_CHECKING: from ansible_collections.snailed.ez_compose.plugins.module_utils.models import State BASE_ARGS: dict[str, Any] = { "name": {"type": "str"}, "overwrite": {"type": "dict"}, } def get_base_definition(state: State, service_name: str) -> dict[str, Any]: project_name: str = state.module.params["name"] return { "container_name": f"{project_name}_{service_name}", "hostname": f"{project_name}_{service_name}", "restart": "unless-stopped", "environment": {}, "labels": {}, "volumes": [], # TODO: this should be set per service helper # "networks": { # "internal": None, # }, } def get_default_definition(state: State, service_name: str) -> dict[str, Any]: settings: dict[str, Any] = state.module.params.get("settings") or {} default_definition: dict[str, Any] = settings.get("default_definition") or {} service_default_definitions: dict[str, Any] = settings.get("service_default_definitions") or {} service_default_definition: dict[str, Any] = service_default_definitions.get(service_name) or {} return default_definition | service_default_definition def get_default_args(state: State, helper_name: str) -> dict[str, Any]: settings: dict[str, Any] = state.module.params.get("settings") or {} service_default_args: dict[str, Any] = settings.get("service_default_args") or {} default_args: dict[str, Any] = service_default_args.get(helper_name) or {} return clean_none(default_args) def apply_update(state: State, service_name: str, update: dict[str, Any]) -> State: project = copy.deepcopy(state.after) service = project["services"].get(service_name, {}) service = recursive_update(service, update) volumes: list[dict[str, Any]] = service.get("volumes") or [] unique_volumes = list({vol["source"]: vol for vol in volumes if "target" in vol}.values()) service["volumes"] = unique_volumes project["services"][service_name] = service return replace(state, after=project) def run_helper( state: State, params: dict[str, Any], helper: Callable[[State, dict[str, Any]], dict[str, Any]] = lambda _a, _b: {}, ) -> State: if not params.get("name"): params["name"] = str.split(helper.__module__, ".")[-1] base_definition = get_base_definition(state, params["name"]) state = apply_update(state, params["name"], base_definition) default_definition = get_default_definition(state, params["name"]) state = apply_update(state, params["name"], default_definition) helper_update = helper(state, params) state = apply_update(state, params["name"], helper_update) if not (overwrite := params.get("overwrite")): overwrite = params.get("definition", {}) return apply_update(state, params["name"], overwrite)