diff --git a/src/compose_farm/cli/management.py b/src/compose_farm/cli/management.py index 9d1b9ee..e4eaf04 100644 --- a/src/compose_farm/cli/management.py +++ b/src/compose_farm/cli/management.py @@ -53,8 +53,8 @@ from compose_farm.operations import ( StackDiscoveryResult, check_host_compatibility, check_stack_requirements, + discover_all_stacks_on_all_hosts, discover_stack_host, - discover_stack_on_all_hosts, ) from compose_farm.state import get_orphaned_stacks, load_state, save_state from compose_farm.traefik import generate_traefik_config, render_traefik_config @@ -155,6 +155,10 @@ def _discover_stacks_full( ) -> tuple[dict[str, str | list[str]], dict[str, list[str]], dict[str, list[str]]]: """Discover running stacks with full host scanning for stray detection. + Uses an optimized approach that queries each host once for all running stacks, + instead of checking each stack on each host individually. This reduces SSH + calls from (stacks * hosts) to just (hosts). + Returns: Tuple of (discovered, strays, duplicates): - discovered: stack -> host(s) where running correctly @@ -162,12 +166,8 @@ def _discover_stacks_full( - duplicates: stack -> list of all hosts (for single-host stacks on multiple) """ - stack_list = stacks if stacks is not None else list(cfg.stacks) - results: list[StackDiscoveryResult] = run_parallel_with_progress( - "Discovering", - stack_list, - lambda s: discover_stack_on_all_hosts(cfg, s), - ) + # Use the efficient batch discovery (1 SSH call per host instead of per stack) + results: list[StackDiscoveryResult] = asyncio.run(discover_all_stacks_on_all_hosts(cfg, stacks)) discovered: dict[str, str | list[str]] = {} strays: dict[str, list[str]] = {} diff --git a/src/compose_farm/executor.py b/src/compose_farm/executor.py index 4b25da6..feeffc5 100644 --- a/src/compose_farm/executor.py +++ b/src/compose_farm/executor.py @@ -497,6 +497,28 @@ async def check_stack_running( return result.success and bool(result.stdout.strip()) +async def
get_running_stacks_on_host( + config: Config, + host_name: str, +) -> set[str]: + """Get all running compose stacks on a host in a single SSH call. + + Uses docker ps with the com.docker.compose.project label to identify running stacks. + Cheaper than checking each stack individually; a failed command yields an empty set. + """ + host = config.hosts[host_name] + + # Get unique project names from running containers + command = "docker ps --format '{{.Label \"com.docker.compose.project\"}}' | sort -u" + result = await run_command(host, command, stack=host_name, stream=False, prefix="") + + if not result.success: + return set() + + # Filter out empty lines and return as set + return {line.strip() for line in result.stdout.splitlines() if line.strip()} + + async def _batch_check_existence( config: Config, host_name: str, diff --git a/src/compose_farm/operations.py b/src/compose_farm/operations.py index 4856427..1b0ff91 100644 --- a/src/compose_farm/operations.py +++ b/src/compose_farm/operations.py @@ -16,6 +16,7 @@ from .executor import ( check_networks_exist, check_paths_exist, check_stack_running, + get_running_stacks_on_host, run_command, run_compose, run_compose_on_host, @@ -134,24 +135,49 @@ class StackDiscoveryResult(NamedTuple): return not self.is_multi_host and len(self.running_hosts) > 1 -async def discover_stack_on_all_hosts(cfg: Config, stack: str) -> StackDiscoveryResult: - """Discover where a stack is running across ALL hosts. +async def discover_all_stacks_on_all_hosts( + cfg: Config, + stacks: list[str] | None = None, +) -> list[StackDiscoveryResult]: + """Discover where stacks are running with minimal SSH calls. + + Instead of checking each stack on each host individually (stacks * hosts calls), + this queries each host once for all running stacks (hosts calls total). + + Args: + cfg: Configuration + stacks: Optional list of stacks to check. If None, checks all stacks in config. + + Returns: + One StackDiscoveryResult per stack checked, in the same order.
- Unlike discover_stack_host(), this checks every host in parallel - to detect strays and duplicates. """ - configured_hosts = cfg.get_hosts(stack) + stack_list = stacks if stacks is not None else list(cfg.stacks) all_hosts = list(cfg.hosts.keys()) - checks = await asyncio.gather(*[check_stack_running(cfg, stack, h) for h in all_hosts]) - running_hosts = [h for h, is_running in zip(all_hosts, checks, strict=True) if is_running] - - return StackDiscoveryResult( - stack=stack, - configured_hosts=configured_hosts, - running_hosts=running_hosts, + # Query each host once to get all running stacks (N SSH calls, where N = number of hosts) + host_stacks = await asyncio.gather( + *[get_running_stacks_on_host(cfg, host) for host in all_hosts] ) + # Build a map of host -> running stacks + running_on_host: dict[str, set[str]] = dict(zip(all_hosts, host_stacks, strict=True)) + + # Build results for each stack + results = [] + for stack in stack_list: + configured_hosts = cfg.get_hosts(stack) + running_hosts = [host for host in all_hosts if stack in running_on_host[host]] + results.append( + StackDiscoveryResult( + stack=stack, + configured_hosts=configured_hosts, + running_hosts=running_hosts, + ) + ) + + return results + async def check_stack_requirements( cfg: Config, diff --git a/tests/test_executor.py b/tests/test_executor.py index f0f53bd..c638d3a 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -11,6 +11,7 @@ from compose_farm.executor import ( _run_local_command, check_networks_exist, check_paths_exist, + get_running_stacks_on_host, is_local, run_command, run_compose, @@ -239,3 +240,31 @@ class TestCheckNetworksExist: result = await check_networks_exist(config, "local", []) assert result == {} + + +@linux_only +class TestGetRunningStacksOnHost: + """Tests for get_running_stacks_on_host function (requires Docker).""" + + async def test_returns_set_of_stacks(self, tmp_path: Path) -> None: + """Function returns a set of stack names.""" + config = Config(
+ compose_dir=tmp_path, + hosts={"local": Host(address="localhost")}, + stacks={}, + ) + + result = await get_running_stacks_on_host(config, "local") + assert isinstance(result, set) + + async def test_filters_empty_lines(self, tmp_path: Path) -> None: + """Empty project names are filtered out.""" + config = Config( + compose_dir=tmp_path, + hosts={"local": Host(address="localhost")}, + stacks={}, + ) + + # Result should not contain empty strings + result = await get_running_stacks_on_host(config, "local") + assert "" not in result diff --git a/tests/test_operations.py b/tests/test_operations.py index f7b6150..d2da95e 100644 --- a/tests/test_operations.py +++ b/tests/test_operations.py @@ -11,7 +11,11 @@ import pytest from compose_farm.cli import lifecycle from compose_farm.config import Config, Host from compose_farm.executor import CommandResult -from compose_farm.operations import _migrate_stack +from compose_farm.operations import ( + StackDiscoveryResult, + _migrate_stack, + discover_all_stacks_on_all_hosts, +) @pytest.fixture @@ -109,3 +113,84 @@ class TestUpdateCommandSequence: # Verify the sequence is pull, build, down, up assert "down" in source assert "up -d" in source + + +class TestDiscoverAllStacksOnAllHosts: + """Tests for discover_all_stacks_on_all_hosts function.""" + + async def test_returns_discovery_results_for_all_stacks(self, basic_config: Config) -> None: + """Function returns StackDiscoveryResult for each stack.""" + with patch( + "compose_farm.operations.get_running_stacks_on_host", + return_value={"test-service"}, + ): + results = await discover_all_stacks_on_all_hosts(basic_config) + + assert len(results) == 1 + assert isinstance(results[0], StackDiscoveryResult) + assert results[0].stack == "test-service" + + async def test_detects_stray_stacks(self, tmp_path: Path) -> None: + """Function detects stacks running on wrong hosts.""" + compose_dir = tmp_path / "compose" + (compose_dir / "plex").mkdir(parents=True) + (compose_dir / "plex" /
"docker-compose.yml").write_text("services: {}") + + config = Config( + compose_dir=compose_dir, + hosts={ + "host1": Host(address="localhost"), + "host2": Host(address="localhost"), + }, + stacks={"plex": "host1"}, # Should run on host1 + ) + + # Mock: plex is running on host2 (wrong host) + async def mock_get_running(cfg: Config, host: str) -> set[str]: + if host == "host2": + return {"plex"} + return set() + + with patch( + "compose_farm.operations.get_running_stacks_on_host", + side_effect=mock_get_running, + ): + results = await discover_all_stacks_on_all_hosts(config) + + assert len(results) == 1 + assert results[0].stack == "plex" + assert results[0].running_hosts == ["host2"] + assert results[0].configured_hosts == ["host1"] + assert results[0].is_stray is True + assert results[0].stray_hosts == ["host2"] + + async def test_queries_each_host_once(self, tmp_path: Path) -> None: + """Function makes exactly one call per host, not per stack.""" + compose_dir = tmp_path / "compose" + for stack in ["plex", "jellyfin", "sonarr"]: + (compose_dir / stack).mkdir(parents=True) + (compose_dir / stack / "docker-compose.yml").write_text("services: {}") + + config = Config( + compose_dir=compose_dir, + hosts={ + "host1": Host(address="localhost"), + "host2": Host(address="localhost"), + }, + stacks={"plex": "host1", "jellyfin": "host1", "sonarr": "host2"}, + ) + + call_count = {"count": 0} + + async def mock_get_running(cfg: Config, host: str) -> set[str]: + call_count["count"] += 1 + return set() + + with patch( + "compose_farm.operations.get_running_stacks_on_host", + side_effect=mock_get_running, + ): + await discover_all_stacks_on_all_hosts(config) + + # Should call once per host (2), not once per stack (3) + assert call_count["count"] == 2