# Copyright: (c) 2024, Luca Bilke <luca@bil.ke> # MIT License (see LICENSE) from __future__ import annotations import copy from dataclasses import replace from typing import TYPE_CHECKING, Any, Callable from ansible_collections.snailed.ez_compose.plugins.module_utils.common import ( recursive_update, update_project, ) if TYPE_CHECKING: from ansible_collections.snailed.ez_compose.plugins.module_utils.models import ( State, ) BASE_ARGS: dict[str, Any] = { "name": {"type": "str"}, "overwrite": {"type": "dict"}, } def apply_base(state: State, params: dict[str, Any]) -> State: project_name: str = state.module.params["name"] new: dict[str, Any] = { "container_name": f"{project_name}_{params['name']}", "hostname": f"{project_name}_{params['name']}", "restart": "unless-stopped", "environment": {}, "labels": {}, "volumes": [], "networks": { "internal": None, }, } | ( state.module.params.get("settings", {}) .get("default_definition", {}) .get(params["name"], {}) ) return update(state, params, new) def apply_definition(state: State, params: dict[str, Any], definition: dict[str, Any]) -> State: service_name: str = params["name"] project = copy.deepcopy(state.after) services: dict[str, Any] = project["services"] service: dict[str, Any] = services[service_name] _ = recursive_update(service, definition) services.update({service_name: service}) return replace(state, after=project) def apply_settings(state: State, params: dict[str, Any]) -> State: settings = state.module.params.get("settings", {}) params = settings.get("service_default_args", {}).get(params["name"], {}) | params return update( state, params, settings.get("service_default_definitions", {}).get(params["name"], {}), ) def update(state: State, params: dict[str, Any], update: dict[str, Any]) -> State: service_name: str = params["name"] project = copy.deepcopy(state.after) project["services"][service_name] = project["services"].get(service_name, {}) _ = recursive_update(project["services"][service_name], update) # FIX: this silently throws out misconfigured volumes unique_volumes = dict( { vol["target"]: vol for vol in project["services"][service_name].get("volumes", []) if "target" in vol }.values(), ) project["services"][service_name]["volumes"] = unique_volumes return replace(state, after=project) def run_helper( state: State, params: dict[str, Any], helper: Callable[[State, dict[str, Any]], State] = lambda x, _: x, ) -> State: if not params.get("name"): params["name"] = str.split(helper.__module__, ".")[-1] if not params.get("overwrite"): params["overwrite"] = params.get("definition", {}) state = apply_base(state, params) state = apply_settings(state, params) state = helper(state, params) state = apply_definition(state, params, params["overwrite"]) return update_project(state)