refactor: Introduce StackSelection dataclass for cleaner context passing

Instead of passing filter_host separately through multiple layers,
bundle the selection context into a StackSelection dataclass:

- stacks: list of selected stack names
- config: the loaded Config
- host_filter: optional host filter from -H flag

This provides:
1. Cleaner APIs - context travels together instead of being scattered
2. is_instance_level() method - encapsulates the check for whether
   this is an instance-level operation (host-filtered multi-host stack)
3. Future extensibility - can add more context (dry_run, verbose, etc.)

Updated all callers of get_stacks() to use the new return type.
This commit is contained in:
Bas Nijholt
2026-02-01 12:47:47 -08:00
parent 6b802106a9
commit e6e9eed93e
6 changed files with 103 additions and 59 deletions

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
import contextlib import contextlib
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Annotated, TypeVar from typing import TYPE_CHECKING, Annotated, TypeVar
@@ -38,6 +39,28 @@ _T = TypeVar("_T")
_R = TypeVar("_R") _R = TypeVar("_R")
@dataclass
class StackSelection:
"""Result of stack selection with context for execution.
Bundles together the selected stacks, config, and any filters applied.
This context flows through to execution and state management.
"""
stacks: list[str]
config: Config
host_filter: str | None = None
def is_instance_level(self, stack: str) -> bool:
"""Check if this is an instance-level operation for a stack.
Instance-level means we're operating on just one host of a multi-host
stack (via --host filter). This affects state management - we shouldn't
remove multi-host stacks from state when only one instance was affected.
"""
return self.host_filter is not None and self.config.is_multi_host(stack)
# --- Shared CLI Options --- # --- Shared CLI Options ---
StacksArg = Annotated[ StacksArg = Annotated[
list[str] | None, list[str] | None,
@@ -154,7 +177,7 @@ def get_stacks(
*, *,
host: str | None = None, host: str | None = None,
default_all: bool = False, default_all: bool = False,
) -> tuple[list[str], Config]: ) -> StackSelection:
"""Resolve stack list and load config. """Resolve stack list and load config.
Handles three mutually exclusive selection methods: Handles three mutually exclusive selection methods:
@@ -171,6 +194,9 @@ def get_stacks(
Supports "." as shorthand for the current directory name. Supports "." as shorthand for the current directory name.
Returns:
StackSelection with stacks, config, and host_filter context.
""" """
validate_stack_selection(stacks, all_stacks, host) validate_stack_selection(stacks, all_stacks, host)
config = load_config_or_exit(config_path) config = load_config_or_exit(config_path)
@@ -181,14 +207,14 @@ def get_stacks(
if not stack_list: if not stack_list:
print_warning(f"No stacks configured for host [magenta]{host}[/]") print_warning(f"No stacks configured for host [magenta]{host}[/]")
raise typer.Exit(0) raise typer.Exit(0)
return stack_list, config return StackSelection(stack_list, config, host_filter=host)
if all_stacks: if all_stacks:
return list(config.stacks.keys()), config return StackSelection(list(config.stacks.keys()), config)
if not stacks: if not stacks:
if default_all: if default_all:
return list(config.stacks.keys()), config return StackSelection(list(config.stacks.keys()), config)
print_error("Specify stacks or use [bold]--all[/] / [bold]--host[/]") print_error("Specify stacks or use [bold]--all[/] / [bold]--host[/]")
raise typer.Exit(1) raise typer.Exit(1)
@@ -200,7 +226,7 @@ def get_stacks(
config, resolved, hint="Add the stack to compose-farm.yaml or use [bold]--all[/]" config, resolved, hint="Add the stack to compose-farm.yaml or use [bold]--all[/]"
) )
return resolved, config return StackSelection(resolved, config)
def run_async(coro: Coroutine[None, None, _T]) -> _T: def run_async(coro: Coroutine[None, None, _T]) -> _T:

View File

@@ -62,32 +62,37 @@ def up(
config: ConfigOption = None, config: ConfigOption = None,
) -> None: ) -> None:
"""Start stacks (docker compose up -d). Auto-migrates if host changed.""" """Start stacks (docker compose up -d). Auto-migrates if host changed."""
stack_list, cfg = get_stacks(stacks or [], all_stacks, config, host=host) selection = get_stacks(stacks or [], all_stacks, config, host=host)
if service: if service:
if len(stack_list) != 1: if len(selection.stacks) != 1:
print_error("--service requires exactly one stack") print_error("--service requires exactly one stack")
raise typer.Exit(1) raise typer.Exit(1)
# For service-level up, use run_on_stacks directly (no migration logic) # For service-level up, use run_on_stacks directly (no migration logic)
results = run_async( results = run_async(
run_on_stacks( run_on_stacks(
cfg, stack_list, build_up_cmd(pull=pull, build=build, service=service), raw=True selection.config,
selection.stacks,
build_up_cmd(pull=pull, build=build, service=service),
raw=True,
) )
) )
elif host: elif selection.host_filter:
# For host-filtered up, use run_on_stacks to only affect that host # For host-filtered up, use run_on_stacks to only affect that host
# (skips migration logic, which is intended when explicitly specifying a host) # (skips migration logic, which is intended when explicitly specifying a host)
results = run_async( results = run_async(
run_on_stacks( run_on_stacks(
cfg, selection.config,
stack_list, selection.stacks,
build_up_cmd(pull=pull, build=build), build_up_cmd(pull=pull, build=build),
raw=True, raw=True,
filter_host=host, filter_host=selection.host_filter,
) )
) )
else: else:
results = run_async(up_stacks(cfg, stack_list, raw=True, pull=pull, build=build)) results = run_async(
maybe_regenerate_traefik(cfg, results) up_stacks(selection.config, selection.stacks, raw=True, pull=pull, build=build)
)
maybe_regenerate_traefik(selection.config, results)
report_results(results) report_results(results)
@@ -126,26 +131,33 @@ def down(
report_results(results) report_results(results)
return return
stack_list, cfg = get_stacks(stacks or [], all_stacks, config, host=host) selection = get_stacks(stacks or [], all_stacks, config, host=host)
raw = len(stack_list) == 1 raw = len(selection.stacks) == 1
results = run_async(run_on_stacks(cfg, stack_list, "down", raw=raw, filter_host=host)) results = run_async(
run_on_stacks(
selection.config,
selection.stacks,
"down",
raw=raw,
filter_host=selection.host_filter,
)
)
# Remove from state on success # Remove from state on success
# For multi-host stacks, result.stack is "stack@host", extract base name # For multi-host stacks, result.stack is "stack@host", extract base name
# Skip state removal for host-filtered multi-host stacks (only one instance was stopped)
removed_stacks: set[str] = set() removed_stacks: set[str] = set()
for result in results: for result in results:
if result.success: if result.success:
base_stack = result.stack.split("@")[0] base_stack = result.stack.split("@")[0]
if base_stack not in removed_stacks: if base_stack not in removed_stacks:
# Don't remove multi-host stacks from state when host-filtered # Skip state removal for instance-level operations (host-filtered multi-host)
# because only one instance was stopped, the stack is still running elsewhere # because only one instance was stopped, the stack is still running elsewhere
if host and cfg.is_multi_host(base_stack): if selection.is_instance_level(base_stack):
continue continue
remove_stack(cfg, base_stack) remove_stack(selection.config, base_stack)
removed_stacks.add(base_stack) removed_stacks.add(base_stack)
maybe_regenerate_traefik(cfg, results) maybe_regenerate_traefik(selection.config, results)
report_results(results) report_results(results)
@@ -157,13 +169,13 @@ def stop(
config: ConfigOption = None, config: ConfigOption = None,
) -> None: ) -> None:
"""Stop services without removing containers (docker compose stop).""" """Stop services without removing containers (docker compose stop)."""
stack_list, cfg = get_stacks(stacks or [], all_stacks, config) selection = get_stacks(stacks or [], all_stacks, config)
if service and len(stack_list) != 1: if service and len(selection.stacks) != 1:
print_error("--service requires exactly one stack") print_error("--service requires exactly one stack")
raise typer.Exit(1) raise typer.Exit(1)
cmd = f"stop {service}" if service else "stop" cmd = f"stop {service}" if service else "stop"
raw = len(stack_list) == 1 raw = len(selection.stacks) == 1
results = run_async(run_on_stacks(cfg, stack_list, cmd, raw=raw)) results = run_async(run_on_stacks(selection.config, selection.stacks, cmd, raw=raw))
report_results(results) report_results(results)
@@ -175,13 +187,13 @@ def pull(
config: ConfigOption = None, config: ConfigOption = None,
) -> None: ) -> None:
"""Pull latest images (docker compose pull).""" """Pull latest images (docker compose pull)."""
stack_list, cfg = get_stacks(stacks or [], all_stacks, config) selection = get_stacks(stacks or [], all_stacks, config)
if service and len(stack_list) != 1: if service and len(selection.stacks) != 1:
print_error("--service requires exactly one stack") print_error("--service requires exactly one stack")
raise typer.Exit(1) raise typer.Exit(1)
cmd = f"pull --ignore-buildable {service}" if service else "pull --ignore-buildable" cmd = f"pull --ignore-buildable {service}" if service else "pull --ignore-buildable"
raw = len(stack_list) == 1 raw = len(selection.stacks) == 1
results = run_async(run_on_stacks(cfg, stack_list, cmd, raw=raw)) results = run_async(run_on_stacks(selection.config, selection.stacks, cmd, raw=raw))
report_results(results) report_results(results)
@@ -193,16 +205,16 @@ def restart(
config: ConfigOption = None, config: ConfigOption = None,
) -> None: ) -> None:
"""Restart running containers (docker compose restart).""" """Restart running containers (docker compose restart)."""
stack_list, cfg = get_stacks(stacks or [], all_stacks, config) selection = get_stacks(stacks or [], all_stacks, config)
if service: if service:
if len(stack_list) != 1: if len(selection.stacks) != 1:
print_error("--service requires exactly one stack") print_error("--service requires exactly one stack")
raise typer.Exit(1) raise typer.Exit(1)
cmd = f"restart {service}" cmd = f"restart {service}"
else: else:
cmd = "restart" cmd = "restart"
raw = len(stack_list) == 1 raw = len(selection.stacks) == 1
results = run_async(run_on_stacks(cfg, stack_list, cmd, raw=raw)) results = run_async(run_on_stacks(selection.config, selection.stacks, cmd, raw=raw))
report_results(results) report_results(results)

View File

@@ -453,9 +453,9 @@ def traefik_file(
render_traefik_config, render_traefik_config,
) )
stack_list, cfg = get_stacks(stacks or [], all_stacks, config) selection = get_stacks(stacks or [], all_stacks, config)
try: try:
dynamic, warnings = generate_traefik_config(cfg, stack_list) dynamic, warnings = generate_traefik_config(selection.config, selection.stacks)
except (FileNotFoundError, ValueError) as exc: except (FileNotFoundError, ValueError) as exc:
print_error(str(exc)) print_error(str(exc))
raise typer.Exit(1) from exc raise typer.Exit(1) from exc
@@ -495,22 +495,22 @@ def refresh(
Use 'cf apply' to make reality match your config (stop orphans, migrate). Use 'cf apply' to make reality match your config (stop orphans, migrate).
""" """
stack_list, cfg = get_stacks(stacks or [], all_stacks, config, default_all=True) selection = get_stacks(stacks or [], all_stacks, config, default_all=True)
# Partial refresh merges with existing state; full refresh replaces it # Partial refresh merges with existing state; full refresh replaces it
# Partial = specific stacks provided (not --all, not default) # Partial = specific stacks provided (not --all, not default)
partial_refresh = bool(stacks) and not all_stacks partial_refresh = bool(stacks) and not all_stacks
current_state = load_state(cfg) current_state = load_state(selection.config)
discovered, strays, duplicates = _discover_stacks_full(cfg, stack_list) discovered, strays, duplicates = _discover_stacks_full(selection.config, selection.stacks)
# Calculate changes (only for the stacks we're refreshing) # Calculate changes (only for the stacks we're refreshing)
added = [s for s in discovered if s not in current_state] added = [s for s in discovered if s not in current_state]
# Only mark as "removed" if we're doing a full refresh # Only mark as "removed" if we're doing a full refresh
if partial_refresh: if partial_refresh:
# In partial refresh, a stack not running is just "not found" # In partial refresh, a stack not running is just "not found"
removed = [s for s in stack_list if s in current_state and s not in discovered] removed = [s for s in selection.stacks if s in current_state and s not in discovered]
else: else:
removed = [s for s in current_state if s not in discovered] removed = [s for s in current_state if s not in discovered]
changed = [ changed = [
@@ -526,8 +526,8 @@ def refresh(
else: else:
print_success("State is already in sync.") print_success("State is already in sync.")
_report_stray_stacks(strays, cfg) _report_stray_stacks(strays, selection.config)
_report_duplicate_stacks(duplicates, cfg) _report_duplicate_stacks(duplicates, selection.config)
if dry_run: if dry_run:
console.print(f"\n{MSG_DRY_RUN}") console.print(f"\n{MSG_DRY_RUN}")
@@ -538,13 +538,13 @@ def refresh(
new_state = ( new_state = (
_merge_state(current_state, discovered, removed) if partial_refresh else discovered _merge_state(current_state, discovered, removed) if partial_refresh else discovered
) )
save_state(cfg, new_state) save_state(selection.config, new_state)
print_success(f"State updated: {len(new_state)} stacks tracked.") print_success(f"State updated: {len(new_state)} stacks tracked.")
# Capture image digests for running stacks (1 SSH call per host) # Capture image digests for running stacks (1 SSH call per host)
if discovered: if discovered:
try: try:
path = _snapshot_stacks(cfg, discovered, log_path) path = _snapshot_stacks(selection.config, discovered, log_path)
print_success(f"Digests written to {path}") print_success(f"Digests written to {path}")
except RuntimeError as exc: except RuntimeError as exc:
print_warning(str(exc)) print_warning(str(exc))

View File

@@ -235,20 +235,22 @@ def logs(
config: ConfigOption = None, config: ConfigOption = None,
) -> None: ) -> None:
"""Show stack logs. With --service, shows logs for just that service.""" """Show stack logs. With --service, shows logs for just that service."""
stack_list, cfg = get_stacks(stacks or [], all_stacks, config, host=host) selection = get_stacks(stacks or [], all_stacks, config, host=host)
if service and len(stack_list) != 1: if service and len(selection.stacks) != 1:
print_error("--service requires exactly one stack") print_error("--service requires exactly one stack")
raise typer.Exit(1) raise typer.Exit(1)
# Default to fewer lines when showing multiple stacks # Default to fewer lines when showing multiple stacks
many_stacks = all_stacks or host is not None or len(stack_list) > 1 many_stacks = all_stacks or selection.host_filter is not None or len(selection.stacks) > 1
effective_tail = tail if tail is not None else (20 if many_stacks else 100) effective_tail = tail if tail is not None else (20 if many_stacks else 100)
cmd = f"logs --tail {effective_tail}" cmd = f"logs --tail {effective_tail}"
if follow: if follow:
cmd += " -f" cmd += " -f"
if service: if service:
cmd += f" {service}" cmd += f" {service}"
results = run_async(run_on_stacks(cfg, stack_list, cmd, filter_host=host)) results = run_async(
run_on_stacks(selection.config, selection.stacks, cmd, filter_host=selection.host_filter)
)
report_results(results) report_results(results)
@@ -267,12 +269,14 @@ def ps(
With --host: shows stacks on that host. With --host: shows stacks on that host.
With --service: filters to a specific service within the stack. With --service: filters to a specific service within the stack.
""" """
stack_list, cfg = get_stacks(stacks or [], all_stacks, config, host=host, default_all=True) selection = get_stacks(stacks or [], all_stacks, config, host=host, default_all=True)
if service and len(stack_list) != 1: if service and len(selection.stacks) != 1:
print_error("--service requires exactly one stack") print_error("--service requires exactly one stack")
raise typer.Exit(1) raise typer.Exit(1)
cmd = f"ps {service}" if service else "ps" cmd = f"ps {service}" if service else "ps"
results = run_async(run_on_stacks(cfg, stack_list, cmd, filter_host=host)) results = run_async(
run_on_stacks(selection.config, selection.stacks, cmd, filter_host=selection.host_filter)
)
report_results(results) report_results(results)

View File

@@ -6,6 +6,7 @@ from unittest.mock import patch
import pytest import pytest
import typer import typer
from compose_farm.cli.common import StackSelection
from compose_farm.cli.lifecycle import apply, down from compose_farm.cli.lifecycle import apply, down
from compose_farm.config import Config, Host from compose_farm.config import Config, Host
from compose_farm.executor import CommandResult from compose_farm.executor import CommandResult
@@ -486,7 +487,7 @@ class TestHostFilterMultiHost:
patch("compose_farm.cli.lifecycle.maybe_regenerate_traefik"), patch("compose_farm.cli.lifecycle.maybe_regenerate_traefik"),
patch("compose_farm.cli.lifecycle.report_results"), patch("compose_farm.cli.lifecycle.report_results"),
): ):
mock_get_stacks.return_value = (["multi-host"], cfg) mock_get_stacks.return_value = StackSelection(["multi-host"], cfg, host_filter="host1")
down( down(
stacks=None, stacks=None,
@@ -521,7 +522,7 @@ class TestHostFilterMultiHost:
patch("compose_farm.cli.lifecycle.maybe_regenerate_traefik"), patch("compose_farm.cli.lifecycle.maybe_regenerate_traefik"),
patch("compose_farm.cli.lifecycle.report_results"), patch("compose_farm.cli.lifecycle.report_results"),
): ):
mock_get_stacks.return_value = (["multi-host"], cfg) mock_get_stacks.return_value = StackSelection(["multi-host"], cfg, host_filter="host1")
down( down(
stacks=None, stacks=None,
@@ -554,7 +555,7 @@ class TestHostFilterMultiHost:
patch("compose_farm.cli.lifecycle.maybe_regenerate_traefik"), patch("compose_farm.cli.lifecycle.maybe_regenerate_traefik"),
patch("compose_farm.cli.lifecycle.report_results"), patch("compose_farm.cli.lifecycle.report_results"),
): ):
mock_get_stacks.return_value = (["multi-host"], cfg) mock_get_stacks.return_value = StackSelection(["multi-host"], cfg, host_filter=None)
down( down(
stacks=None, stacks=None,

View File

@@ -8,6 +8,7 @@ import pytest
from compose_farm import executor as executor_module from compose_farm import executor as executor_module
from compose_farm import state as state_module from compose_farm import state as state_module
from compose_farm.cli import management as cli_management_module from compose_farm.cli import management as cli_management_module
from compose_farm.cli.common import StackSelection
from compose_farm.config import Config, Host from compose_farm.config import Config, Host
from compose_farm.executor import CommandResult, check_stack_running from compose_farm.executor import CommandResult, check_stack_running
@@ -204,7 +205,7 @@ class TestRefreshCommand:
with ( with (
patch( patch(
"compose_farm.cli.management.get_stacks", "compose_farm.cli.management.get_stacks",
return_value=(["plex"], mock_config), return_value=StackSelection(["plex"], mock_config),
), ),
patch( patch(
"compose_farm.cli.management.load_state", "compose_farm.cli.management.load_state",
@@ -240,7 +241,7 @@ class TestRefreshCommand:
with ( with (
patch( patch(
"compose_farm.cli.management.get_stacks", "compose_farm.cli.management.get_stacks",
return_value=(["plex", "jellyfin", "grafana"], mock_config), return_value=StackSelection(["plex", "jellyfin", "grafana"], mock_config),
), ),
patch( patch(
"compose_farm.cli.management.load_state", "compose_farm.cli.management.load_state",
@@ -278,7 +279,7 @@ class TestRefreshCommand:
with ( with (
patch( patch(
"compose_farm.cli.management.get_stacks", "compose_farm.cli.management.get_stacks",
return_value=(["plex", "jellyfin", "grafana"], mock_config), return_value=StackSelection(["plex", "jellyfin", "grafana"], mock_config),
), ),
patch( patch(
"compose_farm.cli.management.load_state", "compose_farm.cli.management.load_state",
@@ -312,7 +313,7 @@ class TestRefreshCommand:
with ( with (
patch( patch(
"compose_farm.cli.management.get_stacks", "compose_farm.cli.management.get_stacks",
return_value=(["plex", "jellyfin"], mock_config), return_value=StackSelection(["plex", "jellyfin"], mock_config),
), ),
patch( patch(
"compose_farm.cli.management.load_state", "compose_farm.cli.management.load_state",
@@ -347,7 +348,7 @@ class TestRefreshCommand:
with ( with (
patch( patch(
"compose_farm.cli.management.get_stacks", "compose_farm.cli.management.get_stacks",
return_value=(["plex"], mock_config), return_value=StackSelection(["plex"], mock_config),
), ),
patch( patch(
"compose_farm.cli.management.load_state", "compose_farm.cli.management.load_state",