# Copyright: (c) 2024, Luca Bilke <luca@bil.ke> # MIT License (see LICENSE) from __future__ import annotations import copy from dataclasses import replace from pathlib import Path from typing import TYPE_CHECKING, Any import yaml from ansible_collections.snailed.ez_compose.plugins.module_utils.models import Result, State if TYPE_CHECKING: from ansible.module_utils.basic import AnsibleModule # pyright: ignore[reportMissingTypeStubs] def recursive_update( default: dict[Any, Any], update: dict[Any, Any], ) -> dict[Any, Any]: default = copy.deepcopy(default) for key in update: # noqa: PLC0206 if isinstance(update[key], dict) and ( isinstance(default.get(key), dict) or default.get(key) is None ): default[key] = recursive_update(default.get(key, {}), update[key]) elif isinstance(update[key], list) and ( isinstance(default.get(key), list) or default.get(key) is None ): # default_set = set(default.get(key, [])) # custom_set = set(update[key]) # default[key] = list(default_set.union(custom_set)) default[key] = default.get(key, []).extend(update[key]) else: default[key] = update[key] return default def get_state(module: AnsibleModule) -> State: """Create a new state object, loading the compose file into "before" if it exists.""" compose_filepath = f"{module.params['project_dir']}/{module.params['name']}/docker-compose.yml" try: with Path(compose_filepath).open("r") as fp: before = yaml.safe_load(fp) except FileNotFoundError: before: dict[str, Any] = {} return State( module=module, result=Result(), compose_filepath=compose_filepath, before=before, after={ "name": module.params["name"], "services": {}, "networks": {}, "volumes": {}, }, ) def update_project(state: State) -> State: """Ensure that networks/volumes that exist in services also exist in the project.""" project = copy.deepcopy(state.after) project_services: dict[str, Any] = project.get("services", {}) project_networks: dict[str, Any] = project.get("networks", {}) project_volumes: dict[str, Any] = project.get("volumes", {}) for project_service in [x for x in project_services.values() if x]: if service_volumes := project_service.get("volumes"): service_volume_names = [x["source"] for x in service_volumes] project_volumes.update( { service_volume_name: None for service_volume_name in service_volume_names if service_volume_name not in project_volumes }, ) if service_network_names := project_service.get("networks", {}).keys(): project_networks.update( { service_network_name: None for service_network_name in service_network_names if service_network_name not in project_networks }, ) return replace(state, after=project) def write_compose(state: State) -> State: file = state.compose_filepath with Path(file).open(mode="w") as stream: yaml.dump(state.after, stream) return state