This commit is contained in:
Luca Bilke 2025-01-18 19:16:11 +01:00
parent c5c0e3f6f3
commit 729cd27ed9
8 changed files with 84 additions and 36 deletions

View file

@ -15,15 +15,22 @@ if TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule # pyright: ignore[reportMissingTypeStubs]
def clean_none(obj: dict[str, Any]) -> dict[str, Any]:
obj = copy.deepcopy(obj)
for k in copy.deepcopy(obj):
if obj.get(k, "sentinel") is None:
del obj[k]
return obj
def recursive_update(
default: dict[Any, Any],
update: dict[Any, Any],
) -> dict[Any, Any]:
default: dict[str, Any],
update: dict[str, Any],
) -> dict[str, Any]:
default = copy.deepcopy(default)
for k, v in update.items():
if isinstance(v, dict):
v = cast(dict[Any, Any], v)
v = cast(dict[str, Any], v)
default[k] = recursive_update(default.get(k) or {}, v)
elif isinstance(v, list):

View file

@ -5,7 +5,10 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable
from ansible_collections.snailed.ez_compose.plugins.module_utils.service import common
from ansible_collections.snailed.ez_compose.plugins.module_utils.common import (
clean_none,
recursive_update,
)
if TYPE_CHECKING:
from ansible_collections.snailed.ez_compose.plugins.module_utils.models import State
@ -14,11 +17,26 @@ if TYPE_CHECKING:
BASE_ARGS: dict[str, Any] = {}
def apply_update(
service: dict[str, Any],
update: dict[str, Any],
) -> dict[str, Any]:
return recursive_update(service, update)
def get_default_args(state: State, helper_name: str) -> dict[str, Any]:
settings: dict[str, Any] = state.module.params.get("settings") or {}
label_default_args: dict[str, Any] = settings.get("label_default_args") or {}
default_args: dict[str, Any] = label_default_args.get(helper_name) or {}
return clean_none(default_args)
def run_helper(
state: State,
service_name: str,
service: dict[str, Any],
params: dict[str, Any],
helper: Callable[[State, str, dict[str, Any]], dict[str, Any]] = lambda _a, _b, _c: {},
) -> State:
) -> dict[str, Any]:
update = helper(state, service_name, params)
return common.apply_update(state, service_name, update)
return apply_update(service, update)

View file

@ -20,7 +20,7 @@ def helper(state: State, service_name: str, params: dict[str, Any]) -> dict[str,
project_name: str = state.module.params["name"]
middleware: str = params["middleware"]
settings: dict[str, str] = params["settings"]
proxy_type: str = state.module.params["proxy_type"]
proxy_type: str = params["proxy_type"]
name: str = (
params.get("name") or f"{project_name}_{service_name}_{proxy_type}_{middleware.lower()}"
)

View file

@ -7,7 +7,10 @@ import copy
from dataclasses import replace
from typing import TYPE_CHECKING, Any, Callable
from ansible_collections.snailed.ez_compose.plugins.module_utils.common import recursive_update
from ansible_collections.snailed.ez_compose.plugins.module_utils.common import (
clean_none,
recursive_update,
)
if TYPE_CHECKING:
from ansible_collections.snailed.ez_compose.plugins.module_utils.models import State
@ -27,29 +30,39 @@ def get_base_definition(state: State, service_name: str) -> dict[str, Any]:
"environment": {},
"labels": {},
"volumes": [],
"networks": {
"internal": None,
},
# TODO: this should be set per service helper
# "networks": {
# "internal": None,
# },
}
def get_default_definition(state: State, service_name: str) -> dict[str, Any]:
return (
state.module.params.get("settings", {})
.get("default_definition", {})
.get(service_name, {})
)
settings: dict[str, Any] = state.module.params.get("settings") or {}
default_definition: dict[str, Any] = settings.get("default_definition") or {}
service_default_definitions: dict[str, Any] = settings.get("service_default_definitions") or {}
service_default_definition: dict[str, Any] = service_default_definitions.get(service_name) or {}
return default_definition | service_default_definition
def get_default_args(state: State, helper_name: str) -> dict[str, Any]:
settings: dict[str, Any] = state.module.params.get("settings") or {}
service_default_args: dict[str, Any] = settings.get("service_default_args") or {}
default_args: dict[str, Any] = service_default_args.get(helper_name) or {}
return clean_none(default_args)
def apply_update(state: State, service_name: str, update: dict[str, Any]) -> State:
project = copy.deepcopy(state.after)
service = project["services"].get(service_name, {})
project["services"][service_name] = recursive_update(service, update)
service = recursive_update(service, update)
volumes: list[dict[str, Any]] = project["services"][service_name].get("volumes") or []
# FIX: this silently throws out misconfigured volumes
volumes: list[dict[str, Any]] = service.get("volumes") or []
unique_volumes = list({vol["source"]: vol for vol in volumes if "target" in vol}.values())
project["services"][service_name]["volumes"] = unique_volumes
service["volumes"] = unique_volumes
project["services"][service_name] = service
return replace(state, after=project)
@ -62,14 +75,16 @@ def run_helper(
if not params.get("name"):
params["name"] = str.split(helper.__module__, ".")[-1]
base_definition = get_base_definition(state, params["name"])
state = apply_update(state, params["name"], base_definition)
default_definition = get_default_definition(state, params["name"])
state = apply_update(state, params["name"], default_definition)
helper_update = helper(state, params)
state = apply_update(state, params["name"], helper_update)
if not (overwrite := params.get("overwrite")):
overwrite = params.get("definition", {})
base_definition = get_base_definition(state, params["name"])
default_definition = get_default_definition(state, params["name"])
helper_update = helper(state, params)
state = apply_update(state, params["name"], base_definition)
state = apply_update(state, params["name"], default_definition)
state = apply_update(state, params["name"], helper_update)
return apply_update(state, params["name"], overwrite)

View file

@ -6,6 +6,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any
from ansible_collections.snailed.ez_compose.plugins.module_utils import label, spec
from ansible_collections.snailed.ez_compose.plugins.module_utils.common import clean_none
if TYPE_CHECKING:
from ansible_collections.snailed.ez_compose.plugins.module_utils.models import (
@ -25,15 +26,14 @@ def helper(state: State, params: dict[str, Any]) -> dict[str, Any]:
update: dict[str, Any] = {}
networks = update.get("networks", {})
if internal_network:
networks = update.get("networks", {})
networks["internal"] = None
update["networks"] = networks
update["networks"] = networks
for name, args in [(x, y) for x, y in params.get("label_helpers", {}).items() if y]:
label_params = label.common.get_default_args(state, name) | clean_none(args)
helper = getattr(label, name).helper
state = label.common.run_helper(state, params["name"], args, helper)
update = label.common.run_helper(state, params["name"], update, label_params, helper)
return update

View file

@ -92,6 +92,7 @@ def settings_spec() -> dict[str, Any]:
for arg in service_args.values():
arg.pop("required", None)
arg.pop("default", None)
settings["options"]["service_default_args"]["options"][module_name] = {
"type": "dict",
@ -103,6 +104,7 @@ def settings_spec() -> dict[str, Any]:
for arg in label_args.values():
arg.pop("required", None)
arg.pop("default", None)
settings["options"]["label_default_args"]["options"][module_name] = {
"type": "dict",

View file

@ -491,9 +491,12 @@ def main() -> None:
for name, services_params in [(x, y) for x, y in module.params["services"].items() if y]:
for index, service_params in enumerate(services_params):
service_params["_index"] = index
params = service.common.get_default_args(state, name) | common.clean_none(
service_params
)
params["_index"] = index
helper = getattr(service, name).helper
state = service.common.run_helper(state, service_params, helper)
state = service.common.run_helper(state, params, helper)
state = common.update_project(state)
state = common.set_result(state)

View file

@ -11,13 +11,16 @@ ignore = [
"TD002",
"TD003",
"INP001",
"COM812",
"D100",
"D101",
"D103",
"D104",
"D203",
"D213",
]
[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"tests/units/**" = [
"S101",
"PT009",