ez_docker/plugins/module_utils/service/common.py
2025-01-17 18:31:17 +01:00

75 lines
2.5 KiB
Python

# Copyright: (c) 2025, Luca Bilke <luca@bil.ke>
# MIT License (see LICENSE)
from __future__ import annotations
import copy
from dataclasses import replace
from typing import TYPE_CHECKING, Any, Callable
from ansible_collections.snailed.ez_compose.plugins.module_utils.common import recursive_update
if TYPE_CHECKING:
from ansible_collections.snailed.ez_compose.plugins.module_utils.models import State
# Ansible-style argument-spec entries common to the service helpers:
# - ``name``: optional service name (run_helper derives one from the helper's
#   module when it is empty);
# - ``overwrite``: a mapping merged into the service definition as the last layer.
BASE_ARGS: dict[str, Any] = dict(
    name=dict(type="str"),
    overwrite=dict(type="dict"),
)
def get_base_definition(state: State, service_name: str) -> dict[str, Any]:
    """Build the baseline compose definition for one service of the project.

    Both the container name and the hostname are ``<project>_<service>``, where
    the project name is taken from ``state.module.params["name"]``. The service
    restarts ``unless-stopped``, starts with empty ``environment``, ``labels``
    and ``volumes``, and is attached to the ``internal`` network.

    Args:
        state: current module state; only ``state.module.params["name"]`` is read.
        service_name: name of the service being defined.

    Returns:
        A freshly built dict on every call, so callers may mutate it freely.
    """
    project = state.module.params["name"]
    qualified_name = f"{project}_{service_name}"
    definition: dict[str, Any] = dict(
        container_name=qualified_name,
        hostname=qualified_name,
        restart="unless-stopped",
        environment={},
        labels={},
        volumes=[],
        networks={"internal": None},
    )
    return definition
def get_default_definition(state: State, service_name: str) -> dict[str, Any]:
    """Return the user-supplied default definition for ``service_name``.

    Looks up ``settings -> default_definition -> <service_name>`` in the module
    parameters. Ansible fills options the user did not supply with ``None``
    rather than omitting them. Therefore any level that is missing *or* ``None``
    yields ``{}`` instead of raising ``AttributeError`` on ``None.get``.

    Args:
        state: current module state; only ``state.module.params`` is read.
        service_name: name of the service whose defaults are wanted.

    Returns:
        The configured default definition, or ``{}`` if none is configured.
    """
    params = state.module.params
    settings = params.get("settings") or {}
    defaults = settings.get("default_definition") or {}
    return defaults.get(service_name) or {}
def apply_update(state: State, service_name: str, update: dict[str, Any]) -> State:
    """Merge ``update`` into one service of ``state.after`` and return a new state.

    The project in ``state.after`` is deep-copied first, so the incoming state
    is not modified. The merge is done by ``recursive_update``. Afterwards the
    service's ``volumes`` list is de-duplicated:

    - Entries sharing the same ``source`` collapse into one. The last entry's
      value is kept, at the position of the first.
    - Entries without a ``source`` (e.g. compose long-syntax ``tmpfs`` mounts)
      are keyed by their ``target`` instead.

    Args:
        state: current module state; ``state.after["services"]`` must exist.
        service_name: key of the service to update (created if absent).
        update: partial service definition to merge; ``None`` is treated as ``{}``.

    Returns:
        A copy of ``state`` whose ``after`` holds the updated project.
    """
    project = copy.deepcopy(state.after)
    services = project["services"]
    service = recursive_update(services.get(service_name, {}), update or {})
    services[service_name] = service

    deduplicated: dict[tuple[str, Any], dict[str, Any]] = {}
    for volume in service.get("volumes") or []:
        # FIX: this still silently throws out misconfigured volumes (entries
        # that are not mappings, e.g. short-syntax strings, or have no "target").
        # The isinstance check avoids a substring test / TypeError on strings.
        if not isinstance(volume, dict) or "target" not in volume:
            continue
        # tmpfs mounts have no "source"; key them by "target" instead of
        # raising KeyError. Tuple tags keep the two key spaces apart.
        if "source" in volume:
            key = ("source", volume["source"])
        else:
            key = ("target", volume["target"])
        deduplicated[key] = volume
    service["volumes"] = list(deduplicated.values())

    return replace(state, after=project)
def run_helper(
    state: State,
    params: dict[str, Any],
    helper: Callable[[State, dict[str, Any]], dict[str, Any]] = lambda _a, _b: {},
) -> State:
    """Build one service definition by layering updates onto ``state.after``.

    Layers are merged with ``apply_update`` in this order:

    1. the base definition (``get_base_definition``),
    2. the configured default definition (``get_default_definition``),
    3. the dict returned by ``helper(state, params)``,
    4. ``params["overwrite"]``, falling back to ``params["definition"]``.

    If ``params`` has no truthy ``"name"``, the last dotted component of
    ``helper.__module__`` is used. That name is written back into ``params``,
    so the caller's dict is mutated. The helper is called with the *incoming*
    state, before any layer has been applied.

    Args:
        state: current module state.
        params: per-service parameters.
        helper: returns the service-specific update; defaults to a no-op.

    Returns:
        The state after all four layers have been merged.
    """
    if not params.get("name"):
        # e.g. helper defined in "...service.redis" -> service name "redis"
        params["name"] = helper.__module__.rsplit(".", 1)[-1]

    # Unset Ansible options arrive as None, so fall through any falsy value.
    overwrite = params.get("overwrite") or params.get("definition") or {}

    base_definition = get_base_definition(state, params["name"])
    default_definition = get_default_definition(state, params["name"])
    helper_update = helper(state, params)

    for update in (base_definition, default_definition, helper_update, overwrite):
        # Guard against layers that come back as None.
        state = apply_update(state, params["name"], update or {})
    return state