perf: Batch snapshot collection to 1 SSH call per host (#130)

## Summary

Optimize `cf refresh` SSH calls from O(stacks) to O(hosts):
- Discovery: 1 SSH call per host (unchanged)
- Snapshots: 1 SSH call per host (was 1 per stack)

For 50 stacks across 4 hosts: 54 → 8 SSH calls.

## Changes

**Performance:**
- Use `docker ps` + `docker image inspect` instead of `docker compose images` per stack
- Batch snapshot collection by host in `collect_stacks_entries_on_host()`

**Architecture:**
- Add `build_discovery_results()` to `operations.py` (business logic)
- Keep progress bar wrapper in `cli/management.py` (presentation)
- Remove dead code: `discover_all_stacks_on_all_hosts()`, `collect_all_stacks_entries()`
This commit is contained in:
Bas Nijholt
2025-12-22 22:19:32 -08:00
committed by GitHub
parent 6fbc7430cb
commit 5f2e081298
6 changed files with 482 additions and 327 deletions

View File

@@ -21,7 +21,7 @@ repos:
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.9 rev: v0.14.9
hooks: hooks:
- id: ruff - id: ruff-check
args: [--fix] args: [--fix]
- id: ruff-format - id: ruff-format

View File

@@ -37,24 +37,23 @@ from compose_farm.console import (
) )
from compose_farm.executor import ( from compose_farm.executor import (
CommandResult, CommandResult,
get_running_stacks_on_host,
is_local, is_local,
run_command, run_command,
) )
from compose_farm.logs import ( from compose_farm.logs import (
DEFAULT_LOG_PATH, DEFAULT_LOG_PATH,
SnapshotEntry, SnapshotEntry,
collect_stack_entries, collect_stacks_entries_on_host,
isoformat, isoformat,
load_existing_entries, load_existing_entries,
merge_entries, merge_entries,
write_toml, write_toml,
) )
from compose_farm.operations import ( from compose_farm.operations import (
StackDiscoveryResult, build_discovery_results,
check_host_compatibility, check_host_compatibility,
check_stack_requirements, check_stack_requirements,
discover_all_stacks_on_all_hosts,
discover_stack_host,
) )
from compose_farm.state import get_orphaned_stacks, load_state, save_state from compose_farm.state import get_orphaned_stacks, load_state, save_state
from compose_farm.traefik import generate_traefik_config, render_traefik_config from compose_farm.traefik import generate_traefik_config, render_traefik_config
@@ -62,38 +61,39 @@ from compose_farm.traefik import generate_traefik_config, render_traefik_config
# --- Sync helpers --- # --- Sync helpers ---
def _discover_stacks(cfg: Config, stacks: list[str] | None = None) -> dict[str, str | list[str]]:
"""Discover running stacks with a progress bar."""
stack_list = stacks if stacks is not None else list(cfg.stacks)
results = run_parallel_with_progress(
"Discovering",
stack_list,
lambda s: discover_stack_host(cfg, s),
)
return {svc: host for svc, host in results if host is not None}
def _snapshot_stacks( def _snapshot_stacks(
cfg: Config, cfg: Config,
stacks: list[str], discovered: dict[str, str | list[str]],
log_path: Path | None, log_path: Path | None,
) -> Path: ) -> Path:
"""Capture image digests with a progress bar.""" """Capture image digests using batched SSH calls (1 per host).
Args:
cfg: Configuration
discovered: Dict mapping stack -> host(s) where it's running
log_path: Optional path to write the log file
Returns:
Path to the written log file.
"""
effective_log_path = log_path or DEFAULT_LOG_PATH effective_log_path = log_path or DEFAULT_LOG_PATH
now_dt = datetime.now(UTC) now_dt = datetime.now(UTC)
now_iso = isoformat(now_dt) now_iso = isoformat(now_dt)
async def collect_stack(stack: str) -> tuple[str, list[SnapshotEntry]]: # Group stacks by host for batched SSH calls
try: stacks_by_host: dict[str, set[str]] = {}
return stack, await collect_stack_entries(cfg, stack, now=now_dt) for stack, hosts in discovered.items():
except RuntimeError: # Use first host for multi-host stacks (they use the same images)
return stack, [] host = hosts[0] if isinstance(hosts, list) else hosts
stacks_by_host.setdefault(host, set()).add(stack)
results = run_parallel_with_progress( # Collect entries with 1 SSH call per host (with progress bar)
"Capturing", async def collect_on_host(host: str) -> tuple[str, list[SnapshotEntry]]:
stacks, entries = await collect_stacks_entries_on_host(cfg, host, stacks_by_host[host], now=now_dt)
collect_stack, return host, entries
)
results = run_parallel_with_progress("Capturing", list(stacks_by_host.keys()), collect_on_host)
snapshot_entries = [entry for _, entries in results for entry in entries] snapshot_entries = [entry for _, entries in results for entry in entries]
if not snapshot_entries: if not snapshot_entries:
@@ -155,39 +155,20 @@ def _discover_stacks_full(
) -> tuple[dict[str, str | list[str]], dict[str, list[str]], dict[str, list[str]]]: ) -> tuple[dict[str, str | list[str]], dict[str, list[str]], dict[str, list[str]]]:
"""Discover running stacks with full host scanning for stray detection. """Discover running stacks with full host scanning for stray detection.
Uses an optimized approach that queries each host once for all running stacks, Queries each host once for all running stacks (with progress bar),
instead of checking each stack on each host individually. This reduces SSH then delegates to build_discovery_results for categorization.
calls from (stacks * hosts) to just (hosts).
Returns:
Tuple of (discovered, strays, duplicates):
- discovered: stack -> host(s) where running correctly
- strays: stack -> list of unauthorized hosts
- duplicates: stack -> list of all hosts (for single-host stacks on multiple)
""" """
# Use the efficient batch discovery (1 SSH call per host instead of per stack) all_hosts = list(cfg.hosts.keys())
results: list[StackDiscoveryResult] = asyncio.run(discover_all_stacks_on_all_hosts(cfg, stacks))
discovered: dict[str, str | list[str]] = {} # Query each host for running stacks (with progress bar)
strays: dict[str, list[str]] = {} async def get_stacks_on_host(host: str) -> tuple[str, set[str]]:
duplicates: dict[str, list[str]] = {} running = await get_running_stacks_on_host(cfg, host)
return host, running
for result in results: host_results = run_parallel_with_progress("Discovering", all_hosts, get_stacks_on_host)
correct_hosts = [h for h in result.running_hosts if h in result.configured_hosts] running_on_host: dict[str, set[str]] = dict(host_results)
if correct_hosts:
if result.is_multi_host:
discovered[result.stack] = correct_hosts
else:
discovered[result.stack] = correct_hosts[0]
if result.is_stray: return build_discovery_results(cfg, running_on_host, stacks)
strays[result.stack] = result.stray_hosts
if result.is_duplicate:
duplicates[result.stack] = result.running_hosts
return discovered, strays, duplicates
def _report_stray_stacks( def _report_stray_stacks(
@@ -554,10 +535,10 @@ def refresh(
save_state(cfg, new_state) save_state(cfg, new_state)
print_success(f"State updated: {len(new_state)} stacks tracked.") print_success(f"State updated: {len(new_state)} stacks tracked.")
# Capture image digests for running stacks # Capture image digests for running stacks (1 SSH call per host)
if discovered: if discovered:
try: try:
path = _snapshot_stacks(cfg, list(discovered.keys()), log_path) path = _snapshot_stacks(cfg, discovered, log_path)
print_success(f"Digests written to {path}") print_success(f"Digests written to {path}")
except RuntimeError as exc: except RuntimeError as exc:
print_warning(str(exc)) print_warning(str(exc))

View File

@@ -6,21 +6,22 @@ import json
import tomllib import tomllib
from dataclasses import dataclass from dataclasses import dataclass
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING
from .executor import run_compose from .executor import run_command
from .paths import xdg_config_home from .paths import xdg_config_home
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Iterable from collections.abc import Iterable
from pathlib import Path from pathlib import Path
from .config import Config from .config import Config
from .executor import CommandResult
# Separator used to split output sections
_SECTION_SEPARATOR = "---CF-SEP---"
DEFAULT_LOG_PATH = xdg_config_home() / "compose-farm" / "dockerfarm-log.toml" DEFAULT_LOG_PATH = xdg_config_home() / "compose-farm" / "dockerfarm-log.toml"
_DIGEST_HEX_LENGTH = 64
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -56,87 +57,97 @@ def _escape(value: str) -> str:
return value.replace("\\", "\\\\").replace('"', '\\"') return value.replace("\\", "\\\\").replace('"', '\\"')
def _parse_images_output(raw: str) -> list[dict[str, Any]]: def _parse_image_digests(image_json: str) -> dict[str, str]:
"""Parse `docker compose images --format json` output. """Parse docker image inspect JSON to build image tag -> digest map."""
if not image_json:
Handles both a JSON array and newline-separated JSON objects for robustness. return {}
"""
raw = raw.strip()
if not raw:
return []
try: try:
parsed = json.loads(raw) image_data = json.loads(image_json)
except json.JSONDecodeError: except json.JSONDecodeError:
objects = [] return {}
for line in raw.splitlines():
if not line.strip():
continue
objects.append(json.loads(line))
return objects
if isinstance(parsed, list): image_digests: dict[str, str] = {}
return parsed for img in image_data:
if isinstance(parsed, dict): tags = img.get("RepoTags") or []
return [parsed] digests = img.get("RepoDigests") or []
return [] digest = digests[0].split("@")[-1] if digests else img.get("Id", "")
for tag in tags:
image_digests[tag] = digest
if img.get("Id"):
image_digests[img["Id"]] = digest
return image_digests
def _extract_image_fields(record: dict[str, Any]) -> tuple[str, str]: async def collect_stacks_entries_on_host(
"""Extract image name and digest with fallbacks."""
image = record.get("Image") or record.get("Repository") or record.get("Name") or ""
tag = record.get("Tag") or record.get("Version")
if tag and ":" not in image.rsplit("/", 1)[-1]:
image = f"{image}:{tag}"
digest = (
record.get("Digest")
or record.get("Image ID")
or record.get("ImageID")
or record.get("ID")
or ""
)
if digest and not digest.startswith("sha256:") and len(digest) == _DIGEST_HEX_LENGTH:
digest = f"sha256:{digest}"
return image, digest
async def collect_stack_entries(
config: Config, config: Config,
stack: str, host_name: str,
stacks: set[str],
*, *,
now: datetime, now: datetime,
run_compose_fn: Callable[..., Awaitable[CommandResult]] = run_compose,
) -> list[SnapshotEntry]: ) -> list[SnapshotEntry]:
"""Run `docker compose images` for a stack and normalize results.""" """Collect image entries for stacks on one host using 2 docker commands.
result = await run_compose_fn(config, stack, "images --format json", stream=False)
Uses `docker ps` to get running containers + their compose project labels,
then `docker image inspect` to get digests for all unique images.
Much faster than running N `docker compose images` commands.
"""
if not stacks:
return []
host = config.hosts[host_name]
# Single SSH call with 2 docker commands:
# 1. Get project|image pairs from running containers
# 2. Get image info (including digests) for all unique images
command = (
f"docker ps --format '{{{{.Label \"com.docker.compose.project\"}}}}|{{{{.Image}}}}' && "
f"echo '{_SECTION_SEPARATOR}' && "
"docker image inspect $(docker ps --format '{{.Image}}' | sort -u) 2>/dev/null || true"
)
result = await run_command(host, command, host_name, stream=False, prefix="")
if not result.success: if not result.success:
msg = result.stderr or f"compose images exited with {result.exit_code}" return []
error = f"[{stack}] Unable to read images: {msg}"
raise RuntimeError(error)
records = _parse_images_output(result.stdout) # Split output into two sections
# Use first host for snapshots (multi-host stacks use same images on all hosts) parts = result.stdout.split(_SECTION_SEPARATOR)
host_name = config.get_hosts(stack)[0] if len(parts) != 2: # noqa: PLR2004
compose_path = config.get_compose_path(stack) return []
entries: list[SnapshotEntry] = [] container_lines, image_json = parts[0].strip(), parts[1].strip()
for record in records:
image, digest = _extract_image_fields(record) # Parse project|image pairs, filtering to only stacks we care about
if not digest: stack_images: dict[str, set[str]] = {}
for line in container_lines.splitlines():
if "|" not in line:
continue continue
project, image = line.split("|", 1)
if project in stacks:
stack_images.setdefault(project, set()).add(image)
if not stack_images:
return []
# Parse image inspect JSON to build image -> digest map
image_digests = _parse_image_digests(image_json)
# Build entries
entries: list[SnapshotEntry] = []
for stack, images in stack_images.items():
for image in images:
digest = image_digests.get(image, "")
if digest:
entries.append( entries.append(
SnapshotEntry( SnapshotEntry(
stack=stack, stack=stack,
host=host_name, host=host_name,
compose_file=compose_path, compose_file=config.get_compose_path(stack),
image=image, image=image,
digest=digest, digest=digest,
captured_at=now, captured_at=now,
) )
) )
return entries return entries

View File

@@ -16,7 +16,6 @@ from .executor import (
check_networks_exist, check_networks_exist,
check_paths_exist, check_paths_exist,
check_stack_running, check_stack_running,
get_running_stacks_on_host,
run_command, run_command,
run_compose, run_compose,
run_compose_on_host, run_compose_on_host,
@@ -77,31 +76,6 @@ def get_stack_paths(cfg: Config, stack: str) -> list[str]:
return paths return paths
async def discover_stack_host(cfg: Config, stack: str) -> tuple[str, str | list[str] | None]:
"""Discover where a stack is running.
For multi-host stacks, checks all assigned hosts in parallel.
For single-host, checks assigned host first, then others.
Returns (stack_name, host_or_hosts_or_none).
"""
assigned_hosts = cfg.get_hosts(stack)
if cfg.is_multi_host(stack):
# Check all assigned hosts in parallel
checks = await asyncio.gather(*[check_stack_running(cfg, stack, h) for h in assigned_hosts])
running = [h for h, is_running in zip(assigned_hosts, checks, strict=True) if is_running]
return stack, running if running else None
# Single-host: check assigned host first, then others
if await check_stack_running(cfg, stack, assigned_hosts[0]):
return stack, assigned_hosts[0]
for host in cfg.hosts:
if host != assigned_hosts[0] and await check_stack_running(cfg, stack, host):
return stack, host
return stack, None
class StackDiscoveryResult(NamedTuple): class StackDiscoveryResult(NamedTuple):
"""Result of discovering where a stack is running across all hosts.""" """Result of discovering where a stack is running across all hosts."""
@@ -135,50 +109,6 @@ class StackDiscoveryResult(NamedTuple):
return not self.is_multi_host and len(self.running_hosts) > 1 return not self.is_multi_host and len(self.running_hosts) > 1
async def discover_all_stacks_on_all_hosts(
cfg: Config,
stacks: list[str] | None = None,
) -> list[StackDiscoveryResult]:
"""Discover where stacks are running with minimal SSH calls.
Instead of checking each stack on each host individually (stacks * hosts calls),
this queries each host once for all running stacks (hosts calls total).
Args:
cfg: Configuration
stacks: Optional list of stacks to check. If None, checks all stacks in config.
Returns:
List of StackDiscoveryResult for each stack.
"""
stack_list = stacks if stacks is not None else list(cfg.stacks)
all_hosts = list(cfg.hosts.keys())
# Query each host once to get all running stacks (N SSH calls, where N = number of hosts)
host_stacks = await asyncio.gather(
*[get_running_stacks_on_host(cfg, host) for host in all_hosts]
)
# Build a map of host -> running stacks
running_on_host: dict[str, set[str]] = dict(zip(all_hosts, host_stacks, strict=True))
# Build results for each stack
results = []
for stack in stack_list:
configured_hosts = cfg.get_hosts(stack)
running_hosts = [host for host in all_hosts if stack in running_on_host[host]]
results.append(
StackDiscoveryResult(
stack=stack,
configured_hosts=configured_hosts,
running_hosts=running_hosts,
)
)
return results
async def check_stack_requirements( async def check_stack_requirements(
cfg: Config, cfg: Config,
stack: str, stack: str,
@@ -544,3 +474,60 @@ async def stop_stray_stacks(
""" """
return await _stop_stacks_on_hosts(cfg, strays, label="stray") return await _stop_stacks_on_hosts(cfg, strays, label="stray")
def build_discovery_results(
cfg: Config,
running_on_host: dict[str, set[str]],
stacks: list[str] | None = None,
) -> tuple[dict[str, str | list[str]], dict[str, list[str]], dict[str, list[str]]]:
"""Build discovery results from per-host running stacks.
Takes the raw data of which stacks are running on which hosts and
categorizes them into discovered (running correctly), strays (wrong host),
and duplicates (single-host stack on multiple hosts).
Args:
cfg: Config object.
running_on_host: Dict mapping host -> set of running stack names.
stacks: Optional list of stacks to check. Defaults to all configured stacks.
Returns:
Tuple of (discovered, strays, duplicates):
- discovered: stack -> host(s) where running correctly
- strays: stack -> list of unauthorized hosts
- duplicates: stack -> list of all hosts (for single-host stacks on multiple)
"""
stack_list = stacks if stacks is not None else list(cfg.stacks)
all_hosts = list(running_on_host.keys())
# Build StackDiscoveryResult for each stack
results: list[StackDiscoveryResult] = [
StackDiscoveryResult(
stack=stack,
configured_hosts=cfg.get_hosts(stack),
running_hosts=[h for h in all_hosts if stack in running_on_host[h]],
)
for stack in stack_list
]
discovered: dict[str, str | list[str]] = {}
strays: dict[str, list[str]] = {}
duplicates: dict[str, list[str]] = {}
for result in results:
correct_hosts = [h for h in result.running_hosts if h in result.configured_hosts]
if correct_hosts:
if result.is_multi_host:
discovered[result.stack] = correct_hosts
else:
discovered[result.stack] = correct_hosts[0]
if result.is_stray:
strays[result.stack] = result.stray_hosts
if result.is_duplicate:
duplicates[result.stack] = result.running_hosts
return discovered, strays, duplicates

View File

@@ -10,8 +10,8 @@ import pytest
from compose_farm.config import Config, Host from compose_farm.config import Config, Host
from compose_farm.executor import CommandResult from compose_farm.executor import CommandResult
from compose_farm.logs import ( from compose_farm.logs import (
_parse_images_output, _SECTION_SEPARATOR,
collect_stack_entries, collect_stacks_entries_on_host,
isoformat, isoformat,
load_existing_entries, load_existing_entries,
merge_entries, merge_entries,
@@ -19,53 +19,231 @@ from compose_farm.logs import (
) )
def test_parse_images_output_handles_list_and_lines() -> None: def _make_mock_output(
data = [ project_images: dict[str, list[str]], image_info: list[dict[str, object]]
{"Service": "svc", "Image": "redis", "Digest": "sha256:abc"}, ) -> str:
{"Service": "svc", "Image": "db", "Digest": "sha256:def"}, """Build mock output matching the 2-docker-command format."""
# Section 1: project|image pairs from docker ps
ps_lines = [
f"{project}|{image}" for project, images in project_images.items() for image in images
] ]
as_array = _parse_images_output(json.dumps(data))
assert len(as_array) == 2
as_lines = _parse_images_output("\n".join(json.dumps(item) for item in data)) # Section 2: JSON array from docker image inspect
assert len(as_lines) == 2 image_json = json.dumps(image_info)
return f"{chr(10).join(ps_lines)}\n{_SECTION_SEPARATOR}\n{image_json}"
@pytest.mark.asyncio class TestCollectStacksEntriesOnHost:
async def test_snapshot_preserves_first_seen(tmp_path: Path) -> None: """Tests for collect_stacks_entries_on_host (2 docker commands per host)."""
@pytest.fixture
def config_with_stacks(self, tmp_path: Path) -> Config:
"""Create a config with multiple stacks."""
compose_dir = tmp_path / "compose"
compose_dir.mkdir()
for stack in ["plex", "jellyfin", "sonarr"]:
stack_dir = compose_dir / stack
stack_dir.mkdir()
(stack_dir / "docker-compose.yml").write_text("services: {}\n")
return Config(
compose_dir=compose_dir,
hosts={"host1": Host(address="localhost"), "host2": Host(address="localhost")},
stacks={"plex": "host1", "jellyfin": "host1", "sonarr": "host2"},
)
@pytest.mark.asyncio
async def test_single_ssh_call(
self, config_with_stacks: Config, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Verify only 1 SSH call is made regardless of stack count."""
call_count = {"count": 0}
async def mock_run_command(
host: Host, command: str, stack: str, *, stream: bool, prefix: str
) -> CommandResult:
call_count["count"] += 1
output = _make_mock_output(
{"plex": ["plex:latest"], "jellyfin": ["jellyfin:latest"]},
[
{
"RepoTags": ["plex:latest"],
"Id": "sha256:aaa",
"RepoDigests": ["plex@sha256:aaa"],
},
{
"RepoTags": ["jellyfin:latest"],
"Id": "sha256:bbb",
"RepoDigests": ["jellyfin@sha256:bbb"],
},
],
)
return CommandResult(stack=stack, exit_code=0, success=True, stdout=output)
monkeypatch.setattr("compose_farm.logs.run_command", mock_run_command)
now = datetime(2025, 1, 1, tzinfo=UTC)
entries = await collect_stacks_entries_on_host(
config_with_stacks, "host1", {"plex", "jellyfin"}, now=now
)
assert call_count["count"] == 1
assert len(entries) == 2
@pytest.mark.asyncio
async def test_filters_to_requested_stacks(
self, config_with_stacks: Config, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Only return entries for stacks we asked for, even if others are running."""
async def mock_run_command(
host: Host, command: str, stack: str, *, stream: bool, prefix: str
) -> CommandResult:
# Docker ps shows 3 stacks, but we only want plex
output = _make_mock_output(
{
"plex": ["plex:latest"],
"jellyfin": ["jellyfin:latest"],
"other": ["other:latest"],
},
[
{
"RepoTags": ["plex:latest"],
"Id": "sha256:aaa",
"RepoDigests": ["plex@sha256:aaa"],
},
{
"RepoTags": ["jellyfin:latest"],
"Id": "sha256:bbb",
"RepoDigests": ["j@sha256:bbb"],
},
{
"RepoTags": ["other:latest"],
"Id": "sha256:ccc",
"RepoDigests": ["o@sha256:ccc"],
},
],
)
return CommandResult(stack=stack, exit_code=0, success=True, stdout=output)
monkeypatch.setattr("compose_farm.logs.run_command", mock_run_command)
now = datetime(2025, 1, 1, tzinfo=UTC)
entries = await collect_stacks_entries_on_host(
config_with_stacks, "host1", {"plex"}, now=now
)
assert len(entries) == 1
assert entries[0].stack == "plex"
@pytest.mark.asyncio
async def test_multiple_images_per_stack(
self, config_with_stacks: Config, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Stack with multiple containers/images returns multiple entries."""
async def mock_run_command(
host: Host, command: str, stack: str, *, stream: bool, prefix: str
) -> CommandResult:
output = _make_mock_output(
{"plex": ["plex:latest", "redis:7"]},
[
{
"RepoTags": ["plex:latest"],
"Id": "sha256:aaa",
"RepoDigests": ["p@sha256:aaa"],
},
{"RepoTags": ["redis:7"], "Id": "sha256:bbb", "RepoDigests": ["r@sha256:bbb"]},
],
)
return CommandResult(stack=stack, exit_code=0, success=True, stdout=output)
monkeypatch.setattr("compose_farm.logs.run_command", mock_run_command)
now = datetime(2025, 1, 1, tzinfo=UTC)
entries = await collect_stacks_entries_on_host(
config_with_stacks, "host1", {"plex"}, now=now
)
assert len(entries) == 2
images = {e.image for e in entries}
assert images == {"plex:latest", "redis:7"}
@pytest.mark.asyncio
async def test_empty_stacks_returns_empty(self, config_with_stacks: Config) -> None:
"""Empty stack set returns empty entries without making SSH call."""
now = datetime(2025, 1, 1, tzinfo=UTC)
entries = await collect_stacks_entries_on_host(config_with_stacks, "host1", set(), now=now)
assert entries == []
@pytest.mark.asyncio
async def test_ssh_failure_returns_empty(
self, config_with_stacks: Config, monkeypatch: pytest.MonkeyPatch
) -> None:
"""SSH failure returns empty list instead of raising."""
async def mock_run_command(
host: Host, command: str, stack: str, *, stream: bool, prefix: str
) -> CommandResult:
return CommandResult(stack=stack, exit_code=1, success=False, stdout="", stderr="error")
monkeypatch.setattr("compose_farm.logs.run_command", mock_run_command)
now = datetime(2025, 1, 1, tzinfo=UTC)
entries = await collect_stacks_entries_on_host(
config_with_stacks, "host1", {"plex"}, now=now
)
assert entries == []
class TestSnapshotMerging:
"""Tests for merge_entries preserving first_seen."""
@pytest.fixture
def config(self, tmp_path: Path) -> Config:
compose_dir = tmp_path / "compose" compose_dir = tmp_path / "compose"
compose_dir.mkdir() compose_dir.mkdir()
stack_dir = compose_dir / "svc" stack_dir = compose_dir / "svc"
stack_dir.mkdir() stack_dir.mkdir()
(stack_dir / "docker-compose.yml").write_text("services: {}\n") (stack_dir / "docker-compose.yml").write_text("services: {}\n")
config = Config( return Config(
compose_dir=compose_dir, compose_dir=compose_dir,
hosts={"local": Host(address="localhost")}, hosts={"local": Host(address="localhost")},
stacks={"svc": "local"}, stacks={"svc": "local"},
) )
sample_output = json.dumps([{"Service": "svc", "Image": "redis", "Digest": "sha256:abc"}]) @pytest.mark.asyncio
async def test_preserves_first_seen(
self, tmp_path: Path, config: Config, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Repeated snapshots preserve first_seen timestamp."""
async def fake_run_compose( async def mock_run_command(
_cfg: Config, stack: str, compose_cmd: str, *, stream: bool = True host: Host, command: str, stack: str, *, stream: bool, prefix: str
) -> CommandResult: ) -> CommandResult:
assert compose_cmd == "images --format json" output = _make_mock_output(
assert stream is False or stream is True {"svc": ["redis:latest"]},
return CommandResult( [
stack=stack, {
exit_code=0, "RepoTags": ["redis:latest"],
success=True, "Id": "sha256:abc",
stdout=sample_output, "RepoDigests": ["r@sha256:abc"],
stderr="", }
],
) )
return CommandResult(stack=stack, exit_code=0, success=True, stdout=output)
monkeypatch.setattr("compose_farm.logs.run_command", mock_run_command)
log_path = tmp_path / "dockerfarm-log.toml" log_path = tmp_path / "dockerfarm-log.toml"
# First snapshot # First snapshot
first_time = datetime(2025, 1, 1, tzinfo=UTC) first_time = datetime(2025, 1, 1, tzinfo=UTC)
first_entries = await collect_stack_entries( first_entries = await collect_stacks_entries_on_host(
config, "svc", now=first_time, run_compose_fn=fake_run_compose config, "local", {"svc"}, now=first_time
) )
first_iso = isoformat(first_time) first_iso = isoformat(first_time)
merged = merge_entries([], first_entries, now_iso=first_iso) merged = merge_entries([], first_entries, now_iso=first_iso)
@@ -77,8 +255,8 @@ async def test_snapshot_preserves_first_seen(tmp_path: Path) -> None:
# Second snapshot # Second snapshot
second_time = datetime(2025, 2, 1, tzinfo=UTC) second_time = datetime(2025, 2, 1, tzinfo=UTC)
second_entries = await collect_stack_entries( second_entries = await collect_stacks_entries_on_host(
config, "svc", now=second_time, run_compose_fn=fake_run_compose config, "local", {"svc"}, now=second_time
) )
second_iso = isoformat(second_time) second_iso = isoformat(second_time)
existing = load_existing_entries(log_path) existing = load_existing_entries(log_path)

View File

@@ -12,9 +12,8 @@ from compose_farm.cli import lifecycle
from compose_farm.config import Config, Host from compose_farm.config import Config, Host
from compose_farm.executor import CommandResult from compose_farm.executor import CommandResult
from compose_farm.operations import ( from compose_farm.operations import (
StackDiscoveryResult,
_migrate_stack, _migrate_stack,
discover_all_stacks_on_all_hosts, build_discovery_results,
) )
@@ -115,63 +114,18 @@ class TestUpdateCommandSequence:
assert "up -d" in source assert "up -d" in source
class TestDiscoverAllStacksOnAllHosts: class TestBuildDiscoveryResults:
"""Tests for discover_all_stacks_on_all_hosts function.""" """Tests for build_discovery_results function."""
async def test_returns_discovery_results_for_all_stacks(self, basic_config: Config) -> None: @pytest.fixture
"""Function returns StackDiscoveryResult for each stack.""" def config(self, tmp_path: Path) -> Config:
with patch( """Create a test config with multiple stacks."""
"compose_farm.operations.get_running_stacks_on_host",
return_value={"test-service"},
):
results = await discover_all_stacks_on_all_hosts(basic_config)
assert len(results) == 1
assert isinstance(results[0], StackDiscoveryResult)
assert results[0].stack == "test-service"
async def test_detects_stray_stacks(self, tmp_path: Path) -> None:
"""Function detects stacks running on wrong hosts."""
compose_dir = tmp_path / "compose"
(compose_dir / "plex").mkdir(parents=True)
(compose_dir / "plex" / "docker-compose.yml").write_text("services: {}")
config = Config(
compose_dir=compose_dir,
hosts={
"host1": Host(address="localhost"),
"host2": Host(address="localhost"),
},
stacks={"plex": "host1"}, # Should run on host1
)
# Mock: plex is running on host2 (wrong host)
async def mock_get_running(cfg: Config, host: str) -> set[str]:
if host == "host2":
return {"plex"}
return set()
with patch(
"compose_farm.operations.get_running_stacks_on_host",
side_effect=mock_get_running,
):
results = await discover_all_stacks_on_all_hosts(config)
assert len(results) == 1
assert results[0].stack == "plex"
assert results[0].running_hosts == ["host2"]
assert results[0].configured_hosts == ["host1"]
assert results[0].is_stray is True
assert results[0].stray_hosts == ["host2"]
async def test_queries_each_host_once(self, tmp_path: Path) -> None:
"""Function makes exactly one call per host, not per stack."""
compose_dir = tmp_path / "compose" compose_dir = tmp_path / "compose"
for stack in ["plex", "jellyfin", "sonarr"]: for stack in ["plex", "jellyfin", "sonarr"]:
(compose_dir / stack).mkdir(parents=True) (compose_dir / stack).mkdir(parents=True)
(compose_dir / stack / "docker-compose.yml").write_text("services: {}") (compose_dir / stack / "docker-compose.yml").write_text("services: {}")
config = Config( return Config(
compose_dir=compose_dir, compose_dir=compose_dir,
hosts={ hosts={
"host1": Host(address="localhost"), "host1": Host(address="localhost"),
@@ -180,17 +134,61 @@ class TestDiscoverAllStacksOnAllHosts:
stacks={"plex": "host1", "jellyfin": "host1", "sonarr": "host2"}, stacks={"plex": "host1", "jellyfin": "host1", "sonarr": "host2"},
) )
call_count = {"count": 0} def test_discovers_correctly_running_stacks(self, config: Config) -> None:
"""Stacks running on correct hosts are discovered."""
running_on_host = {
"host1": {"plex", "jellyfin"},
"host2": {"sonarr"},
}
async def mock_get_running(cfg: Config, host: str) -> set[str]: discovered, strays, duplicates = build_discovery_results(config, running_on_host)
call_count["count"] += 1
return set()
with patch( assert discovered == {"plex": "host1", "jellyfin": "host1", "sonarr": "host2"}
"compose_farm.operations.get_running_stacks_on_host", assert strays == {}
side_effect=mock_get_running, assert duplicates == {}
):
await discover_all_stacks_on_all_hosts(config)
# Should call once per host (2), not once per stack (3) def test_detects_stray_stacks(self, config: Config) -> None:
assert call_count["count"] == 2 """Stacks running on wrong hosts are marked as strays."""
running_on_host = {
"host1": set(),
"host2": {"plex"}, # plex should be on host1
}
discovered, strays, _duplicates = build_discovery_results(config, running_on_host)
assert "plex" not in discovered
assert strays == {"plex": ["host2"]}
def test_detects_duplicates(self, config: Config) -> None:
"""Single-host stacks running on multiple hosts are duplicates."""
running_on_host = {
"host1": {"plex"},
"host2": {"plex"}, # plex running on both hosts
}
discovered, strays, duplicates = build_discovery_results(
config, running_on_host, stacks=["plex"]
)
# plex is correctly running on host1
assert discovered == {"plex": "host1"}
# plex is also a stray on host2
assert strays == {"plex": ["host2"]}
# plex is a duplicate (single-host stack on multiple hosts)
assert duplicates == {"plex": ["host1", "host2"]}
def test_filters_to_requested_stacks(self, config: Config) -> None:
"""Only returns results for requested stacks."""
running_on_host = {
"host1": {"plex", "jellyfin"},
"host2": {"sonarr"},
}
discovered, _strays, _duplicates = build_discovery_results(
config, running_on_host, stacks=["plex"]
)
# Only plex should be in results
assert discovered == {"plex": "host1"}
assert "jellyfin" not in discovered
assert "sonarr" not in discovered