# Copyright: (c) 2024, Luca Bilke <luca@bil.ke>
# MIT License (see LICENSE)

"""Shared helpers for Ansible modules that edit a project's docker-compose file.

The helpers load ``<PROJECTS_DIR>/<project_name>/docker-compose.yml`` into an
immutable :class:`State`, derive a new ``after`` document from it through a
chain of pure ``State -> State`` functions, and finally write the result back
and report to Ansible via :func:`exit`.
"""

from __future__ import annotations

import copy
import os
from dataclasses import asdict, dataclass, field, replace
from typing import Any, Callable

import yaml
from ansible.module_utils.basic import AnsibleModule

PROJECTS_DIR = "/var/lib/ez_compose"

# Argument spec shared by every module that manages a whole service.
BASE_SERVICE_ARGS = {
    "project_name": {
        "type": "str",
        "required": True,
    },
    "name": {
        "type": "str",
        "required": True,
    },
    "image": {
        "type": "str",
    },
    "state": {
        "type": "str",
        "default": "present",
        "choices": ["present", "absent"],
    },
    "defaults": {
        "type": "dict",
    },
}

# Argument spec shared by every module that only manages labels of a service.
BASE_LABEL_ARGS = {
    "project_name": {
        "type": "str",
        "required": True,
    },
    "name": {
        "type": "str",
        "required": True,
    },
    "state": {
        "type": "str",
        "default": "present",
        "choices": ["present", "absent"],
    },
}


@dataclass(frozen=True)
class Result:
    """Values handed to ``AnsibleModule.exit_json`` when the module finishes."""

    changed: bool = False
    diff: dict[str, Any] = field(default_factory=dict)


@dataclass(frozen=True)
class Settings:
    """Container for path settings.

    NOTE(review): nothing in this file reads ``Settings`` and its default
    differs from ``PROJECTS_DIR`` -- confirm which path is authoritative.
    """

    projects_dir: str = "/usr/local/share/ez_compose/"


@dataclass(frozen=True)
class State:
    """Immutable snapshot threaded through the ``State -> State`` helpers."""

    module: AnsibleModule
    result: Result
    compose_filepath: str
    # Compose document as loaded from disk (empty dict if there was none).
    before: dict[str, Any]
    # Compose document as it should look once the module has run.
    after: dict[str, Any]


def _identity(state: State) -> State:
    """Return *state* unchanged; the default no-op helper for run_service."""
    return state


def _service_map(project: dict[str, Any]) -> dict[str, Any]:
    """Return the ``services`` mapping of *project*, creating it if absent or null."""
    services = project.get("services")
    if services is None:
        services = {}
        project["services"] = services
    return services


def _recursive_update(default: dict[Any, Any], update: dict[Any, Any]) -> dict[Any, Any]:
    """Merge *update* into *default* in place and return *default*.

    Nested dicts are merged recursively. When both sides hold a list, the
    result is the existing items followed by the new ones with duplicates
    dropped. Any other value in *update* replaces the one in *default*.
    """
    for key, value in update.items():
        current = default.get(key)
        if isinstance(value, dict) and isinstance(current, dict):
            _recursive_update(current, value)
        elif isinstance(value, list) and isinstance(current, list):
            # Compare by equality rather than building sets: list items may be
            # dicts (unhashable), and the original order should be kept so the
            # generated file is stable between runs.
            merged: list[Any] = []
            for item in [*current, *value]:
                if item not in merged:
                    merged.append(item)
            default[key] = merged
        else:
            default[key] = value
    return default


def get_state(module: AnsibleModule) -> State:
    """Load the project's compose file and build the initial :class:`State`.

    A missing or empty compose file yields an empty document.
    """
    compose_filepath = f"{PROJECTS_DIR}/{module.params['project_name']}/docker-compose.yml"

    before: dict[str, Any]
    try:
        with open(compose_filepath, encoding="utf-8") as fp:
            # safe_load returns None for an empty file; normalise to a dict.
            before = yaml.safe_load(fp) or {}
    except FileNotFoundError:
        before = {}

    return State(
        module=module,
        result=Result(),
        compose_filepath=compose_filepath,
        before=before,
        # Independent copy so that ``before`` can never be altered through ``after``.
        after=copy.deepcopy(before),
    )


def apply_service_base(state: State) -> State:
    """Merge the baseline settings every managed service gets into the service."""
    service_name = state.module.params["name"]
    project_name = state.module.params["project_name"]
    image = state.module.params["image"]

    update: dict[str, Any] = {
        # NOTE(review): "service_name" is written as a key of the service
        # definition; confirm whether "container_name" was intended.
        "service_name": f"{project_name}_{service_name}",
        "hostname": f"{project_name}_{service_name}",
        "restart": "unless-stopped",
        "environment": {},
        "labels": {},
        "volumes": [],
        "networks": {
            f"{project_name}_internal": None,
        },
    }
    # "image" is optional in the argument spec; do not overwrite an existing
    # image with None when the caller did not pass one.
    if image is not None:
        update["image"] = image

    return update_service(state, update)


def set_service_defaults(state: State) -> State:
    """Merge the module's ``defaults`` parameter into the service definition."""
    container_name = state.module.params["name"]
    # ``defaults`` has no default in the argument spec, so it may be None.
    defaults = state.module.params.get("defaults") or {}

    project = copy.deepcopy(state.after)
    service = _service_map(project).setdefault(container_name, {})
    _recursive_update(service, defaults)

    return replace(state, after=project)


def update_service(state: State, update: dict[str, Any]) -> State:
    """Return a state whose service named by the ``name`` param has *update* merged in.

    The ``services`` mapping and the service entry are created when missing,
    so this also works for a project that has no compose file yet.
    """
    project = copy.deepcopy(state.after)
    service_name = state.module.params["name"]

    service = _service_map(project).setdefault(service_name, {})
    _recursive_update(service, update)

    return replace(state, after=project)


def remove_service(state: State) -> State:
    """Return a state without the service named by the ``name`` param.

    Removing a service that does not exist is a no-op.
    """
    project = copy.deepcopy(state.after)
    service_name = state.module.params["name"]

    services = project.get("services")
    if isinstance(services, dict):
        services.pop(service_name, None)

    return replace(state, after=project)


def remove_labels(state: State, label_names: list[str]) -> State:
    """Return a state where *label_names* are removed from the service's labels.

    An empty ``labels`` entry is dropped from the service. A service that does
    not exist is left absent rather than being created.
    """
    project = copy.deepcopy(state.after)
    service_name = state.module.params["name"]

    service = (project.get("services") or {}).get(service_name)
    if not isinstance(service, dict):
        return replace(state, after=project)

    unwanted = set(label_names)
    labels = service.get("labels")
    remaining: Any
    if isinstance(labels, dict):
        # Build a filtered copy instead of deleting while iterating.
        remaining = {key: value for key, value in labels.items() if key not in unwanted}
    elif isinstance(labels, list):
        # List-form labels are "key=value" strings; match on the key part.
        remaining = [
            entry for entry in labels if str(entry).partition("=")[0] not in unwanted
        ]
    else:
        remaining = labels

    if remaining:
        service["labels"] = remaining
    else:
        service.pop("labels", None)

    return replace(state, after=project)


def update_project(state: State) -> State:
    """Declare every volume and network used by a service at the project level.

    Existing top-level entries are kept untouched; only missing names are
    added, each with a ``None`` value.
    """
    project = copy.deepcopy(state.after)

    services = project.get("services") or {}
    networks: dict[str, Any] = dict(project.get("networks") or {})
    volumes: dict[str, Any] = dict(project.get("volumes") or {})

    for service in services.values():
        if not isinstance(service, dict):
            continue

        for volume in service.get("volumes") or []:
            # Only mapping-style volume entries carry a "source" key; plain
            # string entries are skipped.
            if not isinstance(volume, dict):
                continue
            # Entries explicitly typed as something other than "volume"
            # (e.g. bind mounts) do not reference a named volume.
            if volume.get("type", "volume") != "volume":
                continue
            source = volume.get("source")
            if source is not None:
                volumes.setdefault(source, None)

        # Iterating works for both the mapping form (yields the names) and
        # the list form of a service's networks.
        for network_name in service.get("networks") or []:
            networks.setdefault(network_name, None)

    # Attach the mappings to the project; without this, names collected for a
    # project that had no top-level section yet would be lost.
    if networks:
        project["networks"] = networks
    if volumes:
        project["volumes"] = volumes

    return replace(state, after=project)


def write_compose(state: State) -> State:
    """Write ``state.after`` to the compose file as YAML and return *state*.

    The project directory is created first if it does not exist.
    """
    file = state.compose_filepath

    os.makedirs(os.path.dirname(file), exist_ok=True)
    with open(file, mode="w", encoding="utf-8") as stream:
        yaml.dump(state.after, stream)

    return state


def run_service(
    extra_args: dict[str, Any] | None = None,
    helper: Callable[[State], State] = _identity,
) -> None:
    """Entry point for modules that manage a whole service.

    Args:
        extra_args: Argument spec entries added on top of ``BASE_SERVICE_ARGS``.
        helper: Module-specific ``State -> State`` step, applied after the base
            settings and ``defaults`` and before the project-level update.
    """
    module = AnsibleModule(
        argument_spec={**BASE_SERVICE_ARGS, **(extra_args or {})},
        supports_check_mode=True,
    )

    state = get_state(module)

    if module.params["state"] == "absent":
        state = remove_service(state)
    else:
        for step in (apply_service_base, set_service_defaults, helper, update_project):
            state = step(state)

    exit(state)


def run_label(
    extra_args: dict[str, Any], helper: Callable[[State], State], label_names: list[str]
) -> None:
    """Entry point for modules that manage a set of labels on a service.

    Args:
        extra_args: Argument spec entries added on top of ``BASE_LABEL_ARGS``.
        helper: ``State -> State`` step applied when ``state`` is ``present``.
        label_names: Labels removed from the service when ``state`` is ``absent``.
    """
    module = AnsibleModule(
        argument_spec={**BASE_LABEL_ARGS, **extra_args},
        supports_check_mode=True,
    )

    state = get_state(module)

    if module.params["state"] == "absent":
        state = remove_labels(state, label_names)
    else:
        state = helper(state)

    exit(state)


def exit(state: State) -> None:  # noqa: A001 - name kept for existing callers
    """Report the outcome to Ansible, writing the compose file when needed.

    ``changed`` is true when a helper already set it or when ``after`` differs
    from ``before``. The file is only written when something changed and the
    module is not running in check mode.
    """
    changed = state.result.changed or state.before != state.after
    diff = state.result.diff or {"before": state.before, "after": state.after}
    result = replace(state.result, changed=changed, diff=diff)

    if changed and not state.module.check_mode:
        write_compose(state)

    state.module.exit_json(**asdict(result))  # type: ignore[reportUnknownMemberType]