# Copyright: (c) 2024, Luca Bilke <luca@bil.ke> # MIT License (see LICENSE) from __future__ import annotations import copy from dataclasses import dataclass, field, replace from pathlib import Path from typing import TYPE_CHECKING, Any import yaml if TYPE_CHECKING: from ansible.module_utils.basic import AnsibleModule # type: ignore[reportMissingStubFile] @dataclass(frozen=True) class Result: changed: bool = False diff: dict[str, Any] = field(default_factory=dict) @dataclass(frozen=True) class State: module: AnsibleModule result: Result compose_filepath: str before: dict[str, Any] after: dict[str, Any] def recursive_update( default: dict[Any, Any], update: dict[Any, Any], ) -> dict[Any, Any]: for key in update: # noqa: PLC0206 if isinstance(update[key], dict) and isinstance(default.get(key), dict): default[key] = recursive_update(default[key], update[key]) elif isinstance(update[key], list) and isinstance(default.get(key), list): default_set = set(default[key]) custom_set = set(update[key]) default[key] = list(default_set.union(custom_set)) else: default[key] = update[key] return default def get_state(module: AnsibleModule) -> State: """Create a new state object, loading the compose file into "before" if it exists.""" compose_filepath = f"{module.params['project_dir']}/{module.params['name']}/docker-compose.yml" try: with Path(compose_filepath).open("r") as fp: before = yaml.safe_load(fp) except FileNotFoundError: before: dict[str, Any] = {} return State( module=module, result=Result(), compose_filepath=compose_filepath, before=before, after={ "name": module.params["name"], }, ) def update_project(state: State) -> State: """Ensure that networks/volumes that exist in services also exist in the project.""" project = copy.deepcopy(state.after) project_services = project.get("services", {}) project_networks = project.get("networks", {}) project_volumes = project.get("volumes", {}) for service in project_services: if service_volumes := service.get("volumes"): service_volume_names = [x["source"] for x in service_volumes] project_volumes.update( { service_volume_name: None for service_volume_name in service_volume_names if service_volume_name not in project_volumes }, ) if service_network_names := service.get("networks").keys(): project_networks.update( { service_network_name: None for service_network_name in service_network_names if service_network_name not in project_networks }, ) return replace(state, after=project) def write_compose(state: State) -> State: file = state.compose_filepath with Path(file).open(mode="w") as stream: yaml.dump(state.after, stream) return state