# Copyright: (c) 2025, Luca Bilke <luca@bil.ke> # MIT License (see LICENSE) from __future__ import annotations import copy from dataclasses import replace from pathlib import Path from typing import TYPE_CHECKING, Any, cast import yaml from ansible_collections.snailed.ez_compose.plugins.module_utils.models import Result, State if TYPE_CHECKING: from ansible.module_utils.basic import AnsibleModule # pyright: ignore[reportMissingTypeStubs] def recursive_update( default: dict[Any, Any], update: dict[Any, Any], ) -> dict[Any, Any]: default = copy.deepcopy(default) for k, v in update.items(): if isinstance(v, dict): v = cast(dict[Any, Any], v) default[k] = recursive_update(default.get(k) or {}, v) elif isinstance(v, list): v = cast(list[Any], v) new = cast(list[Any], (default.get(k) or [])) new.extend(v) default[k] = new else: default[k] = v return default def get_state(module: AnsibleModule) -> State: """Create a new state object, loading the compose file into "before" if it exists.""" compose_filepath = f"{module.params['project_dir']}/{module.params['name']}/docker-compose.yml" try: with Path(compose_filepath).open("r") as fp: before = yaml.safe_load(fp) except FileNotFoundError: before: dict[str, Any] = {} return State( module=module, result=Result(), compose_filepath=compose_filepath, before=before, after={ "name": module.params["name"], "services": {}, "networks": {}, "volumes": {}, }, ) def update_project(state: State) -> State: """Ensure that networks/volumes that exist in services also exist in the project.""" project = copy.deepcopy(state.after) project_services: dict[str, Any] = project.get("services", {}) project_networks: dict[str, Any] = project.get("networks", {}) project_volumes: dict[str, Any] = project.get("volumes", {}) for project_service in [x for x in project_services.values() if x]: if service_volumes := project_service.get("volumes"): service_volume_names = [x["source"] for x in service_volumes] project_volumes.update( { service_volume_name: None for service_volume_name in service_volume_names if service_volume_name not in project_volumes }, ) if service_network_names := project_service.get("networks", {}).keys(): project_networks.update( { service_network_name: None for service_network_name in service_network_names if service_network_name not in project_networks }, ) return replace(state, after=project) def set_result(state: State) -> State: # noqa: C901 def _changed(before: Any, after: Any) -> bool: # noqa: ANN401, C901, PLR0911 if type(before) is not type(after): return True if isinstance(before, dict): before = cast(dict[str, Any], before) if len(before) != len(after): return True for key in before: if key not in after: return True if _changed(before[key], after[key]): return True return False if isinstance(before, list): before = sorted(cast(list[Any], before)) after = sorted(after) if len(before) != len(after): return True for index in before.enumerate(): if _changed(before[index], after[index]): return True return before != after result = Result( _changed(state.before, state.after), { "before": state.before, "after": state.after, }, ) return replace(state, result=result) def write_compose(state: State) -> State: file = state.compose_filepath with Path(file).open(mode="w") as stream: yaml.dump(state.after, stream) return state