Compare commits

...

20 Commits

Author SHA1 Message Date
Bas Nijholt
5057202938 refactor: DRY cleanup and message consistency (#24) 2025-12-18 11:45:32 -08:00
Bas Nijholt
5e1b9987dd fix(web): Set PTY as controlling terminal for local shell sessions (#23)
Local shell sessions weren't receiving SIGINT (Ctrl+C) because the PTY
wasn't set as the controlling terminal. Add preexec_fn that calls
setsid() and TIOCSCTTY to properly set up the terminal.
2025-12-18 11:12:37 -08:00
Bas Nijholt
d9c26f7f2c Merge pull request #21 from basnijholt/refactor/dry-cleanup
refactor: DRY cleanup - consolidate duplicate code patterns
2025-12-18 11:12:24 -08:00
Bas Nijholt
adfcd4bb31 style: Capitalize "Hint:" consistently 2025-12-18 11:05:53 -08:00
Bas Nijholt
95f7d9c3cf style(cli): Unify "not found" message format with color highlighting
- Services use [cyan] highlighting consistently
- Hosts use [magenta] highlighting consistently
- All use the same "X not found in config" pattern
2025-12-18 11:05:05 -08:00
Bas Nijholt
4c1674cfd8 style(cli): Unify error message format with ✗ prefix
All CLI error messages now consistently use the [red]✗[/] prefix
pattern instead of wrapping the entire message in [red]...[/red].
2025-12-18 11:04:28 -08:00
Bas Nijholt
f65ca8420e fix(web): Filter empty hosts from services_by_host
Preserve original behavior where only hosts with running services are
shown in the dashboard, rather than all configured hosts.
2025-12-18 11:00:01 -08:00
Bas Nijholt
85aff2c271 refactor(state): Move group_services_by_host to state.py
Consolidate duplicate service grouping logic from monitoring.py and
pages.py into a shared function in state.py.
2025-12-18 10:55:53 -08:00
Bas Nijholt
61ca24bb8e refactor(cli): Remove unused get_description parameter
All callers used the same pattern (r[0]), so hardcode it in the helper
and remove the parameter entirely.
2025-12-18 10:54:12 -08:00
Bas Nijholt
ed36588358 refactor(cli): Add validate_host and validate_hosts helpers
Extract common host validation patterns into reusable helpers.
Also simplifies validate_host_for_service to use the new validate_host
helper internally.
2025-12-18 10:49:57 -08:00
Bas Nijholt
80c8079a8c refactor(executor): Add ssh_connect_kwargs helper
Extract common asyncssh.connect parameters into a reusable
ssh_connect_kwargs() function. Used by executor.py, api.py, and ws.py.

Lines: 2608 → 2601 (-7)
2025-12-18 10:48:29 -08:00
Bas Nijholt
763bedf9f6 refactor(cli): Extract config not found helpers
Consolidate repeated "config not found" and "path doesn't exist"
messages into _report_no_config_found() and _report_config_path_not_exists()
helper functions. Also unifies the UX to always show status of searched
paths.
2025-12-18 10:46:58 -08:00
Bas Nijholt
641f7e91a8 refactor(cli): Consolidate _report_*_errors() functions
Merge _report_mount_errors, _report_network_errors, and _report_device_errors
into a single _report_requirement_errors function that takes a category
parameter.

Lines: 2634 → 2608 (-26)
2025-12-18 10:43:49 -08:00
Bas Nijholt
4e8e925d59 refactor(cli): Add run_parallel_with_progress helper
Extract common async progress bar pattern into a reusable helper in
common.py. Updates _discover_services, _check_ssh_connectivity,
_check_service_requirements, _get_container_counts, and _snapshot_services
to use the new helper.

Lines: 2642 → 2634 (-8)
2025-12-18 10:42:45 -08:00
Bas Nijholt
d84858dcfb fix(docker): Add restart policy to web service (#19)
* fix(docker): Add restart policy to containers

* fix: Only add restart policy to web service
2025-12-18 10:39:09 -08:00
Bas Nijholt
3121ee04eb feat(web): Show ⌘K shortcut on command palette button (#20) 2025-12-18 10:38:57 -08:00
Bas Nijholt
a795132a04 refactor(cli): Move format_host to common.py
Consolidate duplicate _format_host() function from lifecycle.py and
management.py into a single format_host() function in common.py.

Lines: 2647 → 2642 (-5)
2025-12-18 10:38:52 -08:00
Bas Nijholt
a6e491575a feat(web): Add Console page with terminal and editor (#17) 2025-12-18 10:29:15 -08:00
Bas Nijholt
78bf90afd9 docs: Improve Releases section in CLAUDE.md 2025-12-18 10:04:56 -08:00
Bas Nijholt
76b60bdd96 feat(web): Add Console page with terminal and editor
Add a new Console page accessible from the sidebar that provides:
- Interactive terminal with full shell access to any configured host
- SSH agent forwarding for authentication to remote hosts
- Monaco editor for viewing/editing files on hosts
- Host selector dropdown with local host listed first
- Auto-loads compose-farm config file on page load

Changes:
- Add /console route and console.html template
- Add /ws/shell/{host} WebSocket endpoint for shell sessions
- Add /api/console/file GET/PUT endpoints for remote file operations
- Update sidebar to include Console navigation link
2025-12-18 10:02:54 -08:00
21 changed files with 1091 additions and 374 deletions

View File

@@ -53,6 +53,24 @@ Icons use [Lucide](https://lucide.dev/). Add new icons as macros in `web/templat
- **NEVER merge anything into main.** Always commit directly or use fast-forward/rebase.
- Never force push.
## Releases
Use `gh release create` to create releases. The tag is created automatically.
```bash
# Check current version
git tag --sort=-v:refname | head -1
# Create release (minor version bump: v0.21.1 -> v0.22.0)
gh release create v0.22.0 --title "v0.22.0" --notes "release notes here"
```
Versioning:
- **Patch** (v0.21.0 → v0.21.1): Bug fixes
- **Minor** (v0.21.1 → v0.22.0): New features, non-breaking changes
Write release notes manually describing what changed. Group by features and bug fixes.
## Commands Quick Reference
CLI available as `cf` or `compose-farm`.

View File

@@ -12,6 +12,7 @@ services:
web:
image: ghcr.io/basnijholt/compose-farm:latest
restart: unless-stopped
command: web --host 0.0.0.0 --port 9000
volumes:
- ${SSH_AUTH_SOCK}:/ssh-agent:ro

View File

@@ -18,7 +18,15 @@ from rich.progress import (
TimeElapsedColumn,
)
from compose_farm.console import console, err_console
from compose_farm.console import (
MSG_HOST_NOT_FOUND,
MSG_SERVICE_NOT_FOUND,
console,
print_error,
print_hint,
print_success,
print_warning,
)
if TYPE_CHECKING:
from collections.abc import Callable, Coroutine, Generator
@@ -27,6 +35,7 @@ if TYPE_CHECKING:
from compose_farm.executor import CommandResult
_T = TypeVar("_T")
_R = TypeVar("_R")
# --- Shared CLI Options ---
@@ -56,6 +65,13 @@ _MISSING_PATH_PREVIEW_LIMIT = 2
_STATS_PREVIEW_LIMIT = 3 # Max number of pending migrations to show by name
def format_host(host: str | list[str]) -> str:
"""Format a host value for display."""
if isinstance(host, list):
return ", ".join(host)
return host
@contextlib.contextmanager
def progress_bar(
label: str, total: int, *, initial_description: str = "[dim]connecting...[/]"
@@ -81,6 +97,37 @@ def progress_bar(
yield progress, task_id
def run_parallel_with_progress(
label: str,
items: list[_T],
async_fn: Callable[[_T], Coroutine[None, None, _R]],
) -> list[_R]:
"""Run async tasks in parallel with a progress bar.
Args:
label: Progress bar label (e.g., "Discovering", "Querying hosts")
items: List of items to process
async_fn: Async function to call for each item, returns tuple where
first element is used for progress description
Returns:
List of results from async_fn in completion order.
"""
async def gather() -> list[_R]:
with progress_bar(label, len(items)) as (progress, task_id):
tasks = [asyncio.create_task(async_fn(item)) for item in items]
results: list[_R] = []
for coro in asyncio.as_completed(tasks):
result = await coro
results.append(result)
progress.update(task_id, advance=1, description=f"[cyan]{result[0]}[/]") # type: ignore[index]
return results
return asyncio.run(gather())
def load_config_or_exit(config_path: Path | None) -> Config:
"""Load config or exit with a friendly error message."""
# Lazy import: pydantic adds ~50ms to startup, only load when actually needed
@@ -89,7 +136,7 @@ def load_config_or_exit(config_path: Path | None) -> Config:
try:
return load_config(config_path)
except FileNotFoundError as e:
err_console.print(f"[red]✗[/] {e}")
print_error(str(e))
raise typer.Exit(1) from e
@@ -107,19 +154,16 @@ def get_services(
if all_services:
return list(config.services.keys()), config
if not services:
err_console.print("[red]✗[/] Specify services or use --all")
print_error("Specify services or use [bold]--all[/]")
raise typer.Exit(1)
# Resolve "." to current directory name
resolved = [Path.cwd().name if svc == "." else svc for svc in services]
# Validate all services exist in config
unknown = [svc for svc in resolved if svc not in config.services]
if unknown:
for svc in unknown:
err_console.print(f"[red]✗[/] Unknown service: [cyan]{svc}[/]")
err_console.print("[dim]Hint: Add the service to compose-farm.yaml or use --all[/]")
raise typer.Exit(1)
validate_services(
config, resolved, hint="Add the service to compose-farm.yaml or use [bold]--all[/]"
)
return resolved, config
@@ -143,21 +187,19 @@ def report_results(results: list[CommandResult]) -> None:
console.print() # Blank line before summary
if failed:
for r in failed:
err_console.print(
f"[red]✗[/] [cyan]{r.service}[/] failed with exit code {r.exit_code}"
)
print_error(f"[cyan]{r.service}[/] failed with exit code {r.exit_code}")
console.print()
console.print(
f"[green]✓[/] {len(succeeded)}/{len(results)} services succeeded, "
f"[red]✗[/] {len(failed)} failed"
)
else:
console.print(f"[green]✓[/] All {len(results)} services succeeded")
print_success(f"All {len(results)} services succeeded")
elif failed:
# Single service failed
r = failed[0]
err_console.print(f"[red]✗[/] [cyan]{r.service}[/] failed with exit code {r.exit_code}")
print_error(f"[cyan]{r.service}[/] failed with exit code {r.exit_code}")
if failed:
raise typer.Exit(1)
@@ -197,23 +239,48 @@ def maybe_regenerate_traefik(
cfg.traefik_file.parent.mkdir(parents=True, exist_ok=True)
cfg.traefik_file.write_text(new_content)
console.print() # Ensure we're on a new line after streaming output
console.print(f"[green]✓[/] Traefik config updated: {cfg.traefik_file}")
print_success(f"Traefik config updated: {cfg.traefik_file}")
for warning in warnings:
err_console.print(f"[yellow]![/] {warning}")
print_warning(warning)
except (FileNotFoundError, ValueError) as exc:
err_console.print(f"[yellow]![/] Failed to update traefik config: {exc}")
print_warning(f"Failed to update traefik config: {exc}")
def validate_services(cfg: Config, services: list[str], *, hint: str | None = None) -> None:
"""Validate that all services exist in config. Exits with error if any not found."""
invalid = [s for s in services if s not in cfg.services]
if invalid:
for svc in invalid:
print_error(MSG_SERVICE_NOT_FOUND.format(name=svc))
if hint:
print_hint(hint)
raise typer.Exit(1)
def validate_host(cfg: Config, host: str) -> None:
"""Validate that a host exists in config. Exits with error if not found."""
if host not in cfg.hosts:
print_error(MSG_HOST_NOT_FOUND.format(name=host))
raise typer.Exit(1)
def validate_hosts(cfg: Config, hosts: list[str]) -> None:
"""Validate that all hosts exist in config. Exits with error if any not found."""
invalid = [h for h in hosts if h not in cfg.hosts]
if invalid:
for h in invalid:
print_error(MSG_HOST_NOT_FOUND.format(name=h))
raise typer.Exit(1)
def validate_host_for_service(cfg: Config, service: str, host: str) -> None:
"""Validate that a host is valid for a service."""
if host not in cfg.hosts:
err_console.print(f"[red]✗[/] Host '{host}' not found in config")
raise typer.Exit(1)
validate_host(cfg, host)
allowed_hosts = cfg.get_hosts(service)
if host not in allowed_hosts:
err_console.print(
f"[red]✗[/] Service '{service}' is not configured for host '{host}' "
print_error(
f"Service [cyan]{service}[/] is not configured for host [magenta]{host}[/] "
f"(configured: {', '.join(allowed_hosts)})"
)
raise typer.Exit(1)

View File

@@ -14,7 +14,7 @@ from typing import Annotated
import typer
from compose_farm.cli.app import app
from compose_farm.console import console, err_console
from compose_farm.console import MSG_CONFIG_NOT_FOUND, console, print_error, print_success
from compose_farm.paths import config_search_paths, default_config_path, find_config_path
config_app = typer.Typer(
@@ -66,8 +66,8 @@ def _generate_template() -> str:
template_file = resources.files("compose_farm") / "example-config.yaml"
return template_file.read_text(encoding="utf-8")
except FileNotFoundError as e:
err_console.print("[red]Example config template is missing from the package.[/red]")
err_console.print("Reinstall compose-farm or report this issue.")
print_error("Example config template is missing from the package")
console.print("Reinstall compose-farm or report this issue.")
raise typer.Exit(1) from e
@@ -80,6 +80,23 @@ def _get_config_file(path: Path | None) -> Path | None:
return config_path.resolve() if config_path else None
def _report_no_config_found() -> None:
"""Report that no config file was found in search paths."""
console.print("[yellow]No config file found.[/yellow]")
console.print("\nSearched locations:")
for p in config_search_paths():
status = "[green]exists[/green]" if p.exists() else "[dim]not found[/dim]"
console.print(f" - {p} ({status})")
console.print("\nRun [bold cyan]cf config init[/bold cyan] to create one.")
def _report_config_path_not_exists(config_file: Path) -> None:
"""Report that an explicit config path doesn't exist."""
console.print("[yellow]Config file not found.[/yellow]")
console.print(f"\nProvided path does not exist: [cyan]{config_file}[/cyan]")
console.print("\nRun [bold cyan]cf config init[/bold cyan] to create one.")
@config_app.command("init")
def config_init(
path: _PathOption = None,
@@ -107,7 +124,7 @@ def config_init(
template_content = _generate_template()
target_path.write_text(template_content, encoding="utf-8")
console.print(f"[green]✓[/] Config file created at: {target_path}")
print_success(f"Config file created at: {target_path}")
console.print("\n[dim]Edit the file to customize your settings:[/dim]")
console.print(" [cyan]cf config edit[/cyan]")
@@ -123,17 +140,11 @@ def config_edit(
config_file = _get_config_file(path)
if config_file is None:
console.print("[yellow]No config file found.[/yellow]")
console.print("\nRun [bold cyan]cf config init[/bold cyan] to create one.")
console.print("\nSearched locations:")
for p in config_search_paths():
console.print(f" - {p}")
_report_no_config_found()
raise typer.Exit(1)
if not config_file.exists():
console.print("[yellow]Config file not found.[/yellow]")
console.print(f"\nProvided path does not exist: [cyan]{config_file}[/cyan]")
console.print("\nRun [bold cyan]cf config init[/bold cyan] to create one.")
_report_config_path_not_exists(config_file)
raise typer.Exit(1)
editor = _get_editor()
@@ -142,21 +153,21 @@ def config_edit(
try:
editor_cmd = shlex.split(editor, posix=os.name != "nt")
except ValueError as e:
err_console.print("[red]Invalid editor command. Check $EDITOR/$VISUAL.[/red]")
print_error("Invalid editor command. Check [bold]$EDITOR[/]/[bold]$VISUAL[/]")
raise typer.Exit(1) from e
if not editor_cmd:
err_console.print("[red]Editor command is empty.[/red]")
print_error("Editor command is empty")
raise typer.Exit(1)
try:
subprocess.run([*editor_cmd, str(config_file)], check=True)
except FileNotFoundError:
err_console.print(f"[red]Editor '{editor_cmd[0]}' not found.[/red]")
err_console.print("Set $EDITOR environment variable to your preferred editor.")
print_error(f"Editor [cyan]{editor_cmd[0]}[/] not found")
console.print("Set [bold]$EDITOR[/] environment variable to your preferred editor.")
raise typer.Exit(1) from None
except subprocess.CalledProcessError as e:
err_console.print(f"[red]Editor exited with error code {e.returncode}[/red]")
print_error(f"Editor exited with error code {e.returncode}")
raise typer.Exit(e.returncode) from None
@@ -169,18 +180,11 @@ def config_show(
config_file = _get_config_file(path)
if config_file is None:
console.print("[yellow]No config file found.[/yellow]")
console.print("\nSearched locations:")
for p in config_search_paths():
status = "[green]exists[/green]" if p.exists() else "[dim]not found[/dim]"
console.print(f" - {p} ({status})")
console.print("\nRun [bold cyan]cf config init[/bold cyan] to create one.")
_report_no_config_found()
raise typer.Exit(0)
if not config_file.exists():
console.print("[yellow]Config file not found.[/yellow]")
console.print(f"\nProvided path does not exist: [cyan]{config_file}[/cyan]")
console.print("\nRun [bold cyan]cf config init[/bold cyan] to create one.")
_report_config_path_not_exists(config_file)
raise typer.Exit(1)
content = config_file.read_text(encoding="utf-8")
@@ -207,11 +211,7 @@ def config_path(
config_file = _get_config_file(path)
if config_file is None:
console.print("[yellow]No config file found.[/yellow]")
console.print("\nSearched locations:")
for p in config_search_paths():
status = "[green]exists[/green]" if p.exists() else "[dim]not found[/dim]"
console.print(f" - {p} ({status})")
_report_no_config_found()
raise typer.Exit(1)
# Just print the path for easy piping
@@ -226,7 +226,7 @@ def config_validate(
config_file = _get_config_file(path)
if config_file is None:
err_console.print("[red]✗[/] No config file found")
print_error(MSG_CONFIG_NOT_FOUND)
raise typer.Exit(1)
# Lazy import: pydantic adds ~50ms to startup, only load when actually needed
@@ -235,13 +235,13 @@ def config_validate(
try:
cfg = load_config(config_file)
except FileNotFoundError as e:
err_console.print(f"[red]✗[/] {e}")
print_error(str(e))
raise typer.Exit(1) from e
except Exception as e:
err_console.print(f"[red]✗[/] Invalid config: {e}")
print_error(f"Invalid config: {e}")
raise typer.Exit(1) from e
console.print(f"[green]✓[/] Valid config: {config_file}")
print_success(f"Valid config: {config_file}")
console.print(f" Hosts: {len(cfg.hosts)}")
console.print(f" Services: {len(cfg.services)}")
@@ -268,11 +268,11 @@ def config_symlink(
target_path = (target or Path("compose-farm.yaml")).expanduser().resolve()
if not target_path.exists():
err_console.print(f"[red]✗[/] Target config file not found: {target_path}")
print_error(f"Target config file not found: {target_path}")
raise typer.Exit(1)
if not target_path.is_file():
err_console.print(f"[red]✗[/] Target is not a file: {target_path}")
print_error(f"Target is not a file: {target_path}")
raise typer.Exit(1)
symlink_path = default_config_path()
@@ -282,7 +282,7 @@ def config_symlink(
if symlink_path.is_symlink():
current_target = symlink_path.resolve() if symlink_path.exists() else None
if current_target == target_path:
console.print(f"[green]✓[/] Symlink already points to: {target_path}")
print_success(f"Symlink already points to: {target_path}")
return
# Update existing symlink
if not force:
@@ -294,8 +294,8 @@ def config_symlink(
symlink_path.unlink()
else:
# Regular file exists
err_console.print(f"[red]✗[/] A regular file exists at: {symlink_path}")
err_console.print(" Back it up or remove it first, then retry.")
print_error(f"A regular file exists at: {symlink_path}")
console.print(" Back it up or remove it first, then retry.")
raise typer.Exit(1)
# Create parent directories
@@ -304,7 +304,7 @@ def config_symlink(
# Create symlink with absolute path
symlink_path.symlink_to(target_path)
console.print("[green]✓[/] Created symlink:")
print_success("Created symlink:")
console.print(f" {symlink_path}")
console.print(f" -> {target_path}")

View File

@@ -15,6 +15,7 @@ from compose_farm.cli.common import (
ConfigOption,
HostOption,
ServicesArg,
format_host,
get_services,
load_config_or_exit,
maybe_regenerate_traefik,
@@ -22,7 +23,7 @@ from compose_farm.cli.common import (
run_async,
run_host_operation,
)
from compose_farm.console import console, err_console
from compose_farm.console import MSG_DRY_RUN, console, print_error, print_success
from compose_farm.executor import run_on_services, run_sequential_on_services
from compose_farm.operations import stop_orphaned_services, up_services
from compose_farm.state import (
@@ -74,14 +75,16 @@ def down(
# Handle --orphaned flag
if orphaned:
if services or all_services or host:
err_console.print("[red]✗[/] Cannot use --orphaned with services, --all, or --host")
print_error(
"Cannot combine [bold]--orphaned[/] with services, [bold]--all[/], or [bold]--host[/]"
)
raise typer.Exit(1)
cfg = load_config_or_exit(config)
orphaned_services = get_orphaned_services(cfg)
if not orphaned_services:
console.print("[green]✓[/] No orphaned services to stop")
print_success("No orphaned services to stop")
return
console.print(
@@ -162,13 +165,6 @@ def update(
report_results(results)
def _format_host(host: str | list[str]) -> str:
"""Format a host value for display."""
if isinstance(host, list):
return ", ".join(host)
return host
def _report_pending_migrations(cfg: Config, migrations: list[str]) -> None:
"""Report services that need migration."""
console.print(f"[cyan]Services to migrate ({len(migrations)}):[/]")
@@ -182,14 +178,14 @@ def _report_pending_orphans(orphaned: dict[str, str | list[str]]) -> None:
"""Report orphaned services that will be stopped."""
console.print(f"[yellow]Orphaned services to stop ({len(orphaned)}):[/]")
for svc, hosts in orphaned.items():
console.print(f" [cyan]{svc}[/] on [magenta]{_format_host(hosts)}[/]")
console.print(f" [cyan]{svc}[/] on [magenta]{format_host(hosts)}[/]")
def _report_pending_starts(cfg: Config, missing: list[str]) -> None:
"""Report services that will be started."""
console.print(f"[green]Services to start ({len(missing)}):[/]")
for svc in missing:
target = _format_host(cfg.get_hosts(svc))
target = format_host(cfg.get_hosts(svc))
console.print(f" [cyan]{svc}[/] on [magenta]{target}[/]")
@@ -197,7 +193,7 @@ def _report_pending_refresh(cfg: Config, to_refresh: list[str]) -> None:
"""Report services that will be refreshed."""
console.print(f"[blue]Services to refresh ({len(to_refresh)}):[/]")
for svc in to_refresh:
target = _format_host(cfg.get_hosts(svc))
target = format_host(cfg.get_hosts(svc))
console.print(f" [cyan]{svc}[/] on [magenta]{target}[/]")
@@ -245,7 +241,7 @@ def apply(
has_refresh = bool(to_refresh)
if not has_orphans and not has_migrations and not has_missing and not has_refresh:
console.print("[green]✓[/] Nothing to apply - reality matches config")
print_success("Nothing to apply - reality matches config")
return
# Report what will be done
@@ -259,7 +255,7 @@ def apply(
_report_pending_refresh(cfg, to_refresh)
if dry_run:
console.print("\n[dim](dry-run: no changes made)[/]")
console.print(f"\n{MSG_DRY_RUN}")
return
# Execute changes

View File

@@ -8,7 +8,6 @@ from pathlib import Path # noqa: TC003
from typing import TYPE_CHECKING, Annotated
import typer
from rich.progress import Progress, TaskID # noqa: TC002
from compose_farm.cli.app import app
from compose_farm.cli.common import (
@@ -17,16 +16,25 @@ from compose_farm.cli.common import (
ConfigOption,
LogPathOption,
ServicesArg,
format_host,
get_services,
load_config_or_exit,
progress_bar,
run_async,
run_parallel_with_progress,
validate_hosts,
validate_services,
)
if TYPE_CHECKING:
from compose_farm.config import Config
from compose_farm.console import console, err_console
from compose_farm.console import (
MSG_DRY_RUN,
console,
print_error,
print_success,
print_warning,
)
from compose_farm.executor import (
CommandResult,
is_local,
@@ -54,21 +62,12 @@ from compose_farm.traefik import generate_traefik_config, render_traefik_config
def _discover_services(cfg: Config) -> dict[str, str | list[str]]:
"""Discover running services with a progress bar."""
async def gather_with_progress(
progress: Progress, task_id: TaskID
) -> dict[str, str | list[str]]:
tasks = [asyncio.create_task(discover_service_host(cfg, s)) for s in cfg.services]
discovered: dict[str, str | list[str]] = {}
for coro in asyncio.as_completed(tasks):
service, host = await coro
if host is not None:
discovered[service] = host
progress.update(task_id, advance=1, description=f"[cyan]{service}[/]")
return discovered
with progress_bar("Discovering", len(cfg.services)) as (progress, task_id):
return asyncio.run(gather_with_progress(progress, task_id))
results = run_parallel_with_progress(
"Discovering",
list(cfg.services),
lambda s: discover_service_host(cfg, s),
)
return {svc: host for svc, host in results if host is not None}
def _snapshot_services(
@@ -77,36 +76,22 @@ def _snapshot_services(
log_path: Path | None,
) -> Path:
"""Capture image digests with a progress bar."""
async def collect_service(service: str, now: datetime) -> list[SnapshotEntry]:
try:
return await collect_service_entries(cfg, service, now=now)
except RuntimeError:
return []
async def gather_with_progress(
progress: Progress, task_id: TaskID, now: datetime, svc_list: list[str]
) -> list[SnapshotEntry]:
# Map tasks to service names so we can update description
task_to_service = {asyncio.create_task(collect_service(s, now)): s for s in svc_list}
all_entries: list[SnapshotEntry] = []
for coro in asyncio.as_completed(list(task_to_service.keys())):
entries = await coro
all_entries.extend(entries)
# Find which service just completed (by checking done tasks)
for t, svc in task_to_service.items():
if t.done() and not hasattr(t, "_reported"):
t._reported = True # type: ignore[attr-defined]
progress.update(task_id, advance=1, description=f"[cyan]{svc}[/]")
break
return all_entries
effective_log_path = log_path or DEFAULT_LOG_PATH
now_dt = datetime.now(UTC)
now_iso = isoformat(now_dt)
with progress_bar("Capturing", len(services)) as (progress, task_id):
snapshot_entries = asyncio.run(gather_with_progress(progress, task_id, now_dt, services))
async def collect_service(service: str) -> tuple[str, list[SnapshotEntry]]:
try:
return service, await collect_service_entries(cfg, service, now=now_dt)
except RuntimeError:
return service, []
results = run_parallel_with_progress(
"Capturing",
services,
collect_service,
)
snapshot_entries = [entry for _, entries in results for entry in entries]
if not snapshot_entries:
msg = "No image digests were captured"
@@ -119,13 +104,6 @@ def _snapshot_services(
return effective_log_path
def _format_host(host: str | list[str]) -> str:
"""Format a host value for display."""
if isinstance(host, list):
return ", ".join(host)
return host
def _report_sync_changes(
added: list[str],
removed: list[str],
@@ -137,14 +115,14 @@ def _report_sync_changes(
if added:
console.print(f"\nNew services found ({len(added)}):")
for service in sorted(added):
host_str = _format_host(discovered[service])
host_str = format_host(discovered[service])
console.print(f" [green]+[/] [cyan]{service}[/] on [magenta]{host_str}[/]")
if changed:
console.print(f"\nServices on different hosts ({len(changed)}):")
for service, old_host, new_host in sorted(changed):
old_str = _format_host(old_host)
new_str = _format_host(new_host)
old_str = format_host(old_host)
new_str = format_host(new_host)
console.print(
f" [yellow]~[/] [cyan]{service}[/]: [magenta]{old_str}[/] → [magenta]{new_str}[/]"
)
@@ -152,7 +130,7 @@ def _report_sync_changes(
if removed:
console.print(f"\nServices no longer running ({len(removed)}):")
for service in sorted(removed):
host_str = _format_host(current_state[service])
host_str = format_host(current_state[service])
console.print(f" [red]-[/] [cyan]{service}[/] (was on [magenta]{host_str}[/])")
@@ -174,18 +152,12 @@ def _check_ssh_connectivity(cfg: Config) -> list[str]:
result = await run_command(host, "echo ok", host_name, stream=False)
return host_name, result.success
async def gather_with_progress(progress: Progress, task_id: TaskID) -> list[str]:
tasks = [asyncio.create_task(check_host(h)) for h in remote_hosts]
unreachable: list[str] = []
for coro in asyncio.as_completed(tasks):
host_name, success = await coro
if not success:
unreachable.append(host_name)
progress.update(task_id, advance=1, description=f"[cyan]{host_name}[/]")
return unreachable
with progress_bar("Checking SSH connectivity", len(remote_hosts)) as (progress, task_id):
return asyncio.run(gather_with_progress(progress, task_id))
results = run_parallel_with_progress(
"Checking SSH connectivity",
remote_hosts,
check_host,
)
return [host for host, success in results if not success]
def _check_service_requirements(
@@ -222,27 +194,21 @@ def _check_service_requirements(
return service, mount_errors, network_errors, device_errors
async def gather_with_progress(
progress: Progress, task_id: TaskID
) -> tuple[list[tuple[str, str, str]], list[tuple[str, str, str]], list[tuple[str, str, str]]]:
tasks = [asyncio.create_task(check_service(s)) for s in services]
all_mount_errors: list[tuple[str, str, str]] = []
all_network_errors: list[tuple[str, str, str]] = []
all_device_errors: list[tuple[str, str, str]] = []
results = run_parallel_with_progress(
"Checking requirements",
services,
check_service,
)
for coro in asyncio.as_completed(tasks):
service, mount_errs, net_errs, dev_errs = await coro
all_mount_errors.extend(mount_errs)
all_network_errors.extend(net_errs)
all_device_errors.extend(dev_errs)
progress.update(task_id, advance=1, description=f"[cyan]{service}[/]")
all_mount_errors: list[tuple[str, str, str]] = []
all_network_errors: list[tuple[str, str, str]] = []
all_device_errors: list[tuple[str, str, str]] = []
for _, mount_errs, net_errs, dev_errs in results:
all_mount_errors.extend(mount_errs)
all_network_errors.extend(net_errs)
all_device_errors.extend(dev_errs)
return all_mount_errors, all_network_errors, all_device_errors
with progress_bar(
"Checking requirements", len(services), initial_description="[dim]checking...[/]"
) as (progress, task_id):
return asyncio.run(gather_with_progress(progress, task_id))
return all_mount_errors, all_network_errors, all_device_errors
def _report_config_status(cfg: Config) -> bool:
@@ -263,7 +229,7 @@ def _report_config_status(cfg: Config) -> bool:
console.print(f" [red]-[/] [cyan]{name}[/]")
if not unmanaged and not missing_from_disk:
console.print("[green]✓[/] Config matches disk")
print_success("Config matches disk")
return bool(missing_from_disk)
@@ -275,11 +241,10 @@ def _report_orphaned_services(cfg: Config) -> bool:
if orphaned:
console.print("\n[yellow]Orphaned services[/] (in state but not in config):")
console.print(
"[dim]Run 'cf apply' to stop them, or 'cf down --orphaned' for just orphans.[/]"
"[dim]Run [bold]cf apply[/bold] to stop them, or [bold]cf down --orphaned[/bold] for just orphans.[/]"
)
for name, hosts in sorted(orphaned.items()):
host_str = ", ".join(hosts) if isinstance(hosts, list) else hosts
console.print(f" [yellow]![/] [cyan]{name}[/] on [magenta]{host_str}[/]")
console.print(f" [yellow]![/] [cyan]{name}[/] on [magenta]{format_host(hosts)}[/]")
return True
return False
@@ -295,54 +260,24 @@ def _report_traefik_status(cfg: Config, services: list[str]) -> None:
if warnings:
console.print(f"\n[yellow]Traefik issues[/] ({len(warnings)}):")
for warning in warnings:
console.print(f" [yellow]![/] {warning}")
print_warning(warning)
else:
console.print("[green]✓[/] Traefik labels valid")
print_success("Traefik labels valid")
def _report_mount_errors(mount_errors: list[tuple[str, str, str]]) -> None:
"""Report mount errors grouped by service."""
def _report_requirement_errors(errors: list[tuple[str, str, str]], category: str) -> None:
"""Report requirement errors (mounts, networks, devices) grouped by service."""
by_service: dict[str, list[tuple[str, str]]] = {}
for svc, host, path in mount_errors:
by_service.setdefault(svc, []).append((host, path))
for svc, host, item in errors:
by_service.setdefault(svc, []).append((host, item))
console.print(f"[red]Missing mounts[/] ({len(mount_errors)}):")
console.print(f"[red]Missing {category}[/] ({len(errors)}):")
for svc, items in sorted(by_service.items()):
host = items[0][0]
paths = [p for _, p in items]
missing = [i for _, i in items]
console.print(f" [cyan]{svc}[/] on [magenta]{host}[/]:")
for path in paths:
console.print(f" [red]✗[/] {path}")
def _report_network_errors(network_errors: list[tuple[str, str, str]]) -> None:
"""Report network errors grouped by service."""
by_service: dict[str, list[tuple[str, str]]] = {}
for svc, host, net in network_errors:
by_service.setdefault(svc, []).append((host, net))
console.print(f"[red]Missing networks[/] ({len(network_errors)}):")
for svc, items in sorted(by_service.items()):
host = items[0][0]
networks = [n for _, n in items]
console.print(f" [cyan]{svc}[/] on [magenta]{host}[/]:")
for net in networks:
console.print(f" [red]✗[/] {net}")
def _report_device_errors(device_errors: list[tuple[str, str, str]]) -> None:
"""Report device errors grouped by service."""
by_service: dict[str, list[tuple[str, str]]] = {}
for svc, host, dev in device_errors:
by_service.setdefault(svc, []).append((host, dev))
console.print(f"[red]Missing devices[/] ({len(device_errors)}):")
for svc, items in sorted(by_service.items()):
host = items[0][0]
devices = [d for _, d in items]
console.print(f" [cyan]{svc}[/] on [magenta]{host}[/]:")
for dev in devices:
console.print(f" [red]✗[/] {dev}")
for item in missing:
console.print(f" [red]✗[/] {item}")
def _report_ssh_status(unreachable_hosts: list[str]) -> bool:
@@ -350,9 +285,9 @@ def _report_ssh_status(unreachable_hosts: list[str]) -> bool:
if unreachable_hosts:
console.print(f"[red]Unreachable hosts[/] ({len(unreachable_hosts)}):")
for host in sorted(unreachable_hosts):
console.print(f" [red]✗[/] [magenta]{host}[/]")
print_error(f"[magenta]{host}[/]")
return True
console.print("[green]✓[/] All hosts reachable")
print_success("All hosts reachable")
return False
@@ -394,16 +329,16 @@ def _run_remote_checks(cfg: Config, svc_list: list[str], *, show_host_compat: bo
mount_errors, network_errors, device_errors = _check_service_requirements(cfg, svc_list)
if mount_errors:
_report_mount_errors(mount_errors)
_report_requirement_errors(mount_errors, "mounts")
has_errors = True
if network_errors:
_report_network_errors(network_errors)
_report_requirement_errors(network_errors, "networks")
has_errors = True
if device_errors:
_report_device_errors(device_errors)
_report_requirement_errors(device_errors, "devices")
has_errors = True
if not mount_errors and not network_errors and not device_errors:
console.print("[green]✓[/] All mounts, networks, and devices exist")
print_success("All mounts, networks, and devices exist")
if show_host_compat:
for service in svc_list:
@@ -440,7 +375,7 @@ def traefik_file(
try:
dynamic, warnings = generate_traefik_config(cfg, svc_list)
except (FileNotFoundError, ValueError) as exc:
err_console.print(f"[red]✗[/] {exc}")
print_error(str(exc))
raise typer.Exit(1) from exc
rendered = render_traefik_config(dynamic)
@@ -448,12 +383,12 @@ def traefik_file(
if output:
output.parent.mkdir(parents=True, exist_ok=True)
output.write_text(rendered)
console.print(f"[green]✓[/] Traefik config written to {output}")
print_success(f"Traefik config written to {output}")
else:
console.print(rendered)
for warning in warnings:
err_console.print(f"[yellow]![/] {warning}")
print_warning(warning)
@app.command(rich_help_panel="Configuration")
@@ -492,24 +427,24 @@ def refresh(
if state_changed:
_report_sync_changes(added, removed, changed, discovered, current_state)
else:
console.print("[green]✓[/] State is already in sync.")
print_success("State is already in sync.")
if dry_run:
console.print("\n[dim](dry-run: no changes made)[/]")
console.print(f"\n{MSG_DRY_RUN}")
return
# Update state file
if state_changed:
save_state(cfg, discovered)
console.print(f"\n[green]✓[/] State updated: {len(discovered)} services tracked.")
print_success(f"State updated: {len(discovered)} services tracked.")
# Capture image digests for running services
if discovered:
try:
path = _snapshot_services(cfg, list(discovered.keys()), log_path)
console.print(f"[green]✓[/] Digests written to {path}")
print_success(f"Digests written to {path}")
except RuntimeError as exc:
err_console.print(f"[yellow]![/] {exc}")
print_warning(str(exc))
@app.command(rich_help_panel="Configuration")
@@ -533,11 +468,7 @@ def check(
# Determine which services to check and whether to show host compatibility
if services:
svc_list = list(services)
invalid = [s for s in svc_list if s not in cfg.services]
if invalid:
for svc in invalid:
err_console.print(f"[red]✗[/] Service '{svc}' not found in config")
raise typer.Exit(1)
validate_services(cfg, svc_list)
show_host_compat = True
else:
svc_list = list(cfg.services.keys())
@@ -587,11 +518,7 @@ def init_network(
cfg = load_config_or_exit(config)
target_hosts = list(hosts) if hosts else list(cfg.hosts.keys())
invalid = [h for h in target_hosts if h not in cfg.hosts]
if invalid:
for h in invalid:
err_console.print(f"[red]✗[/] Host '{h}' not found in config")
raise typer.Exit(1)
validate_hosts(cfg, target_hosts)
async def create_network_on_host(host_name: str) -> CommandResult:
host = cfg.hosts[host_name]
@@ -616,9 +543,8 @@ def init_network(
if result.success:
console.print(f"[cyan]\\[{host_name}][/] [green]✓[/] Created network '{network}'")
else:
err_console.print(
f"[cyan]\\[{host_name}][/] [red]✗[/] Failed to create network: "
f"{result.stderr.strip()}"
print_error(
f"[cyan]\\[{host_name}][/] Failed to create network: {result.stderr.strip()}"
)
return result

View File

@@ -2,12 +2,10 @@
from __future__ import annotations
import asyncio
import contextlib
from typing import TYPE_CHECKING, Annotated
import typer
from rich.progress import Progress, TaskID # noqa: TC002
from rich.table import Table
from compose_farm.cli.app import app
@@ -19,47 +17,19 @@ from compose_farm.cli.common import (
ServicesArg,
get_services,
load_config_or_exit,
progress_bar,
report_results,
run_async,
run_parallel_with_progress,
validate_host,
)
from compose_farm.console import console, err_console
from compose_farm.console import console, print_error, print_warning
from compose_farm.executor import run_command, run_on_services
from compose_farm.state import get_services_needing_migration, load_state
from compose_farm.state import get_services_needing_migration, group_services_by_host, load_state
if TYPE_CHECKING:
from collections.abc import Mapping
from compose_farm.config import Config
def _group_services_by_host(
services: dict[str, str | list[str]],
hosts: Mapping[str, object],
all_hosts: list[str] | None = None,
) -> dict[str, list[str]]:
"""Group services by their assigned host(s).
For multi-host services (list or "all"), the service appears in multiple host lists.
"""
by_host: dict[str, list[str]] = {h: [] for h in hosts}
for service, host_value in services.items():
if isinstance(host_value, list):
# Explicit list of hosts
for host_name in host_value:
if host_name in by_host:
by_host[host_name].append(service)
elif host_value == "all" and all_hosts:
# "all" keyword - add to all hosts
for host_name in all_hosts:
if host_name in by_host:
by_host[host_name].append(service)
elif host_value in by_host:
# Single host
by_host[host_value].append(service)
return by_host
def _get_container_counts(cfg: Config) -> dict[str, int]:
"""Get container counts from all hosts with a progress bar."""
@@ -72,18 +42,12 @@ def _get_container_counts(cfg: Config) -> dict[str, int]:
count = int(result.stdout.strip())
return host_name, count
async def gather_with_progress(progress: Progress, task_id: TaskID) -> dict[str, int]:
hosts = list(cfg.hosts.keys())
tasks = [asyncio.create_task(get_count(h)) for h in hosts]
results: dict[str, int] = {}
for coro in asyncio.as_completed(tasks):
host_name, count = await coro
results[host_name] = count
progress.update(task_id, advance=1, description=f"[cyan]{host_name}[/]")
return results
with progress_bar("Querying hosts", len(cfg.hosts)) as (progress, task_id):
return asyncio.run(gather_with_progress(progress, task_id))
results = run_parallel_with_progress(
"Querying hosts",
list(cfg.hosts.keys()),
get_count,
)
return dict(results)
def _build_host_table(
@@ -164,20 +128,18 @@ def logs(
) -> None:
"""Show service logs."""
if all_services and host is not None:
err_console.print("[red]✗[/] Cannot use --all and --host together")
print_error("Cannot combine [bold]--all[/] and [bold]--host[/]")
raise typer.Exit(1)
cfg = load_config_or_exit(config)
# Determine service list based on options
if host is not None:
if host not in cfg.hosts:
err_console.print(f"[red]✗[/] Host '{host}' not found in config")
raise typer.Exit(1)
validate_host(cfg, host)
# Include services where host is in the list of configured hosts
svc_list = [s for s in cfg.services if host in cfg.get_hosts(s)]
if not svc_list:
err_console.print(f"[yellow]![/] No services configured for host '{host}'")
print_warning(f"No services configured for host [magenta]{host}[/]")
return
else:
svc_list, cfg = get_services(services or [], all_services, config)
@@ -220,8 +182,8 @@ def stats(
pending = get_services_needing_migration(cfg)
all_hosts = list(cfg.hosts.keys())
services_by_host = _group_services_by_host(cfg.services, cfg.hosts, all_hosts)
running_by_host = _group_services_by_host(state, cfg.hosts, all_hosts)
services_by_host = group_services_by_host(cfg.services, cfg.hosts, all_hosts)
running_by_host = group_services_by_host(state, cfg.hosts, all_hosts)
container_counts: dict[str, int] = {}
if live:

View File

@@ -4,3 +4,35 @@ from rich.console import Console
console = Console(highlight=False)
err_console = Console(stderr=True, highlight=False)
# --- Message Constants ---
# Standardized message templates for consistent user-facing output
MSG_SERVICE_NOT_FOUND = "Service [cyan]{name}[/] not found in config"
MSG_HOST_NOT_FOUND = "Host [magenta]{name}[/] not found in config"
MSG_CONFIG_NOT_FOUND = "Config file not found"
MSG_DRY_RUN = "[dim](dry-run: no changes made)[/]"
# --- Message Helper Functions ---
def print_error(msg: str) -> None:
"""Print error message with ✗ prefix to stderr."""
err_console.print(f"[red]✗[/] {msg}")
def print_success(msg: str) -> None:
"""Print success message with ✓ prefix to stdout."""
console.print(f"[green]✓[/] {msg}")
def print_warning(msg: str) -> None:
"""Print warning message with ! prefix to stderr."""
err_console.print(f"[yellow]![/] {msg}")
def print_hint(msg: str) -> None:
"""Print hint message in dim style to stdout."""
console.print(f"[dim]Hint: {msg}[/]")

View File

@@ -71,6 +71,16 @@ def is_local(host: Host) -> bool:
return addr in _get_local_ips()
def ssh_connect_kwargs(host: Host) -> dict[str, Any]:
"""Get kwargs for asyncssh.connect() from a Host config."""
return {
"host": host.address,
"port": host.port,
"username": host.user,
"known_hosts": None,
}
async def _run_local_command(
command: str,
service: str,
@@ -177,12 +187,7 @@ async def _run_ssh_command(
proc: asyncssh.SSHClientProcess[Any]
try:
async with asyncssh.connect( # noqa: SIM117 - conn needed before create_process
host.address,
port=host.port,
username=host.user,
known_hosts=None,
) as conn:
async with asyncssh.connect(**ssh_connect_kwargs(host)) as conn: # noqa: SIM117
async with conn.create_process(command) as proc:
if stream:

View File

@@ -10,7 +10,7 @@ import asyncio
from typing import TYPE_CHECKING, NamedTuple
from .compose import parse_devices, parse_external_networks, parse_host_volumes
from .console import console, err_console
from .console import console, err_console, print_error, print_success, print_warning
from .executor import (
CommandResult,
check_networks_exist,
@@ -145,9 +145,7 @@ async def _cleanup_and_rollback(
raw: bool = False,
) -> None:
"""Clean up failed start and attempt rollback to old host if it was running."""
err_console.print(
f"{prefix} [yellow]![/] Cleaning up failed start on [magenta]{target_host}[/]"
)
print_warning(f"{prefix} Cleaning up failed start on [magenta]{target_host}[/]")
await run_compose(cfg, service, "down", raw=raw)
if not was_running:
@@ -156,12 +154,12 @@ async def _cleanup_and_rollback(
)
return
err_console.print(f"{prefix} [yellow]![/] Rolling back to [magenta]{current_host}[/]...")
print_warning(f"{prefix} Rolling back to [magenta]{current_host}[/]...")
rollback_result = await run_compose_on_host(cfg, service, current_host, "up -d", raw=raw)
if rollback_result.success:
console.print(f"{prefix} [green]✓[/] Rollback succeeded on [magenta]{current_host}[/]")
print_success(f"{prefix} Rollback succeeded on [magenta]{current_host}[/]")
else:
err_console.print(f"{prefix} [red]✗[/] Rollback failed - service is down")
print_error(f"{prefix} Rollback failed - service is down")
def _report_preflight_failures(
@@ -170,17 +168,15 @@ def _report_preflight_failures(
preflight: PreflightResult,
) -> None:
"""Report pre-flight check failures."""
err_console.print(
f"[cyan]\\[{service}][/] [red]✗[/] Cannot start on [magenta]{target_host}[/]:"
)
print_error(f"[cyan]\\[{service}][/] Cannot start on [magenta]{target_host}[/]:")
for path in preflight.missing_paths:
err_console.print(f" [red]✗[/] missing path: {path}")
print_error(f" missing path: {path}")
for net in preflight.missing_networks:
err_console.print(f" [red]✗[/] missing network: {net}")
print_error(f" missing network: {net}")
if preflight.missing_networks:
err_console.print(f" [dim]hint: cf init-network {target_host}[/]")
err_console.print(f" [dim]Hint: cf init-network {target_host}[/]")
for dev in preflight.missing_devices:
err_console.print(f" [red]✗[/] missing device: {dev}")
print_error(f" missing device: {dev}")
async def _up_multi_host_service(
@@ -252,8 +248,8 @@ async def _migrate_service(
for cmd, label in [("pull --ignore-buildable", "Pull"), ("build", "Build")]:
result = await _run_compose_step(cfg, service, cmd, raw=raw)
if not result.success:
err_console.print(
f"{prefix} [red]✗[/] {label} failed on [magenta]{target_host}[/], "
print_error(
f"{prefix} {label} failed on [magenta]{target_host}[/], "
"leaving service on current host"
)
return result
@@ -293,9 +289,8 @@ async def _up_single_service(
return failure
did_migration = True
else:
err_console.print(
f"{prefix} [yellow]![/] was on "
f"[magenta]{current_host}[/] (not in config), skipping down"
print_warning(
f"{prefix} was on [magenta]{current_host}[/] (not in config), skipping down"
)
# Start on target host
@@ -391,9 +386,7 @@ async def stop_orphaned_services(cfg: Config) -> list[CommandResult]:
for host in host_list:
# Skip hosts no longer in config
if host not in cfg.hosts:
console.print(
f" [yellow]![/] {service}@{host}: host no longer in config, skipping"
)
print_warning(f"{service}@{host}: host no longer in config, skipping")
results.append(
CommandResult(
service=f"{service}@{host}",
@@ -413,11 +406,11 @@ async def stop_orphaned_services(cfg: Config) -> list[CommandResult]:
result = await task
results.append(result)
if result.success:
console.print(f" [green]✓[/] {service}@{host}: stopped")
print_success(f"{service}@{host}: stopped")
else:
console.print(f" [red]✗[/] {service}@{host}: {result.stderr or 'failed'}")
print_error(f"{service}@{host}: {result.stderr or 'failed'}")
except Exception as e:
console.print(f" [red]✗[/] {service}@{host}: {e}")
print_error(f"{service}@{host}: {e}")
results.append(
CommandResult(
service=f"{service}@{host}",

View File

@@ -8,11 +8,44 @@ from typing import TYPE_CHECKING, Any
import yaml
if TYPE_CHECKING:
from collections.abc import Generator
from collections.abc import Generator, Mapping
from .config import Config
def group_services_by_host(
services: dict[str, str | list[str]],
hosts: Mapping[str, object],
all_hosts: list[str] | None = None,
) -> dict[str, list[str]]:
"""Group services by their assigned host(s).
For multi-host services (list or "all"), the service appears in multiple host lists.
"""
by_host: dict[str, list[str]] = {h: [] for h in hosts}
for service, host_value in services.items():
if isinstance(host_value, list):
for host_name in host_value:
if host_name in by_host:
by_host[host_name].append(service)
elif host_value == "all" and all_hosts:
for host_name in all_hosts:
if host_name in by_host:
by_host[host_name].append(service)
elif host_value in by_host:
by_host[host_value].append(service)
return by_host
def group_running_services_by_host(
state: dict[str, str | list[str]],
hosts: Mapping[str, object],
) -> dict[str, list[str]]:
"""Group running services by host, filtering out hosts with no services."""
by_host = group_services_by_host(state, hosts)
return {h: svcs for h, svcs in by_host.items() if svcs}
def load_state(config: Config) -> dict[str, str | list[str]]:
"""Load the current deployment state.

View File

@@ -10,6 +10,7 @@ from pathlib import Path
from typing import TYPE_CHECKING
from fastapi.templating import Jinja2Templates
from pydantic import ValidationError
if TYPE_CHECKING:
from compose_farm.config import Config
@@ -30,3 +31,10 @@ def get_config() -> Config:
def get_templates() -> Jinja2Templates:
"""Get Jinja2 templates instance."""
return Jinja2Templates(directory=str(TEMPLATES_DIR))
def extract_config_error(exc: Exception) -> str:
"""Extract a user-friendly error message from a config exception."""
if isinstance(exc, ValidationError):
return "; ".join(err.get("msg", str(err)) for err in exc.errors())
return str(exc)

View File

@@ -2,19 +2,20 @@
from __future__ import annotations
import asyncio
import contextlib
import json
from typing import TYPE_CHECKING, Annotated, Any
import shlex
from datetime import UTC, datetime
from pathlib import Path
from typing import Annotated, Any
import asyncssh
import yaml
if TYPE_CHECKING:
from pathlib import Path
from fastapi import APIRouter, Body, HTTPException
from fastapi import APIRouter, Body, HTTPException, Query
from fastapi.responses import HTMLResponse
from compose_farm.executor import run_compose_on_host
from compose_farm.executor import is_local, run_compose_on_host, ssh_connect_kwargs
from compose_farm.paths import find_config_path
from compose_farm.state import load_state
from compose_farm.web.deps import get_config, get_templates
@@ -30,6 +31,51 @@ def _validate_yaml(content: str) -> None:
raise HTTPException(status_code=400, detail=f"Invalid YAML: {e}") from e
def _backup_file(file_path: Path) -> Path | None:
"""Create a timestamped backup of a file if it exists and content differs.
Backups are stored in a .backups directory alongside the file.
Returns the backup path if created, None if no backup was needed.
"""
if not file_path.exists():
return None
# Create backup directory
backup_dir = file_path.parent / ".backups"
backup_dir.mkdir(exist_ok=True)
# Generate timestamped backup filename
timestamp = datetime.now(tz=UTC).strftime("%Y%m%d_%H%M%S")
backup_name = f"{file_path.name}.{timestamp}"
backup_path = backup_dir / backup_name
# Copy current content to backup
backup_path.write_text(file_path.read_text())
# Clean up old backups (keep last 200)
backups = sorted(backup_dir.glob(f"{file_path.name}.*"), reverse=True)
for old_backup in backups[200:]:
old_backup.unlink()
return backup_path
def _save_with_backup(file_path: Path, content: str) -> bool:
"""Save content to file, creating a backup first if content changed.
Returns True if file was saved, False if content was unchanged.
"""
# Check if content actually changed
if file_path.exists():
current_content = file_path.read_text()
if current_content == content:
return False # No change, skip save
_backup_file(file_path)
file_path.write_text(content)
return True
def _get_service_compose_path(name: str) -> Path:
"""Get compose path for service, raising HTTPException if not found."""
config = get_config()
@@ -183,8 +229,9 @@ async def save_compose(
"""Save compose file content."""
compose_path = _get_service_compose_path(name)
_validate_yaml(content)
compose_path.write_text(content)
return {"success": True, "message": "Compose file saved"}
saved = _save_with_backup(compose_path, content)
msg = "Compose file saved" if saved else "No changes to save"
return {"success": True, "message": msg}
@router.put("/service/{name}/env")
@@ -193,8 +240,9 @@ async def save_env(
) -> dict[str, Any]:
"""Save .env file content."""
env_path = _get_service_compose_path(name).parent / ".env"
env_path.write_text(content)
return {"success": True, "message": ".env file saved"}
saved = _save_with_backup(env_path, content)
msg = ".env file saved" if saved else "No changes to save"
return {"success": True, "message": msg}
@router.put("/config")
@@ -207,6 +255,106 @@ async def save_config(
raise HTTPException(status_code=404, detail="Config file not found")
_validate_yaml(content)
config_path.write_text(content)
saved = _save_with_backup(config_path, content)
msg = "Config saved" if saved else "No changes to save"
return {"success": True, "message": msg}
return {"success": True, "message": "Config saved"}
async def _read_file_local(path: str) -> str:
"""Read a file from the local filesystem."""
expanded = Path(path).expanduser()
return await asyncio.to_thread(expanded.read_text, encoding="utf-8")
async def _write_file_local(path: str, content: str) -> bool:
"""Write content to a file on the local filesystem with backup.
Returns True if file was saved, False if content was unchanged.
"""
expanded = Path(path).expanduser()
return await asyncio.to_thread(_save_with_backup, expanded, content)
async def _read_file_remote(host: Any, path: str) -> str:
"""Read a file from a remote host via SSH."""
# Expand ~ on remote by using shell
cmd = f"cat {shlex.quote(path)}"
if path.startswith("~/"):
cmd = f"cat ~/{shlex.quote(path[2:])}"
async with asyncssh.connect(**ssh_connect_kwargs(host)) as conn:
result = await conn.run(cmd, check=True)
stdout = result.stdout or ""
return stdout.decode() if isinstance(stdout, bytes) else stdout
async def _write_file_remote(host: Any, path: str, content: str) -> None:
"""Write content to a file on a remote host via SSH."""
# Expand ~ on remote by using shell
target_path = f"~/{path[2:]}" if path.startswith("~/") else path
cmd = f"cat > {shlex.quote(target_path)}"
async with asyncssh.connect(**ssh_connect_kwargs(host)) as conn:
result = await conn.run(cmd, input=content, check=True)
if result.returncode != 0:
stderr = result.stderr.decode() if isinstance(result.stderr, bytes) else result.stderr
msg = f"Failed to write file: {stderr}"
raise RuntimeError(msg)
def _get_console_host(host: str, path: str) -> Any:
"""Validate and return host config for console file operations."""
config = get_config()
host_config = config.hosts.get(host)
if not host_config:
raise HTTPException(status_code=404, detail=f"Host '{host}' not found")
if not path:
raise HTTPException(status_code=400, detail="Path is required")
return host_config
@router.get("/console/file")
async def read_console_file(
host: Annotated[str, Query(description="Host name")],
path: Annotated[str, Query(description="File path")],
) -> dict[str, Any]:
"""Read a file from a host for the console editor."""
host_config = _get_console_host(host, path)
try:
if is_local(host_config):
content = await _read_file_local(path)
else:
content = await _read_file_remote(host_config, path)
return {"success": True, "content": content}
except FileNotFoundError:
raise HTTPException(status_code=404, detail=f"File not found: {path}") from None
except PermissionError:
raise HTTPException(status_code=403, detail=f"Permission denied: {path}") from None
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
@router.put("/console/file")
async def write_console_file(
host: Annotated[str, Query(description="Host name")],
path: Annotated[str, Query(description="File path")],
content: Annotated[str, Body(media_type="text/plain")],
) -> dict[str, Any]:
"""Write a file to a host from the console editor."""
host_config = _get_console_host(host, path)
try:
if is_local(host_config):
saved = await _write_file_local(path, content)
msg = f"Saved: {path}" if saved else "No changes to save"
else:
await _write_file_remote(host_config, path, content)
msg = f"Saved: {path}" # Remote doesn't track changes
return {"success": True, "message": msg}
except PermissionError:
raise HTTPException(status_code=403, detail=f"Permission denied: {path}") from None
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e

View File

@@ -7,19 +7,57 @@ from fastapi import APIRouter, Request
from fastapi.responses import HTMLResponse
from pydantic import ValidationError
from compose_farm.executor import is_local
from compose_farm.paths import find_config_path
from compose_farm.state import (
get_orphaned_services,
get_service_host,
get_services_needing_migration,
get_services_not_in_state,
group_running_services_by_host,
load_state,
)
from compose_farm.web.deps import get_config, get_templates
from compose_farm.web.deps import (
extract_config_error,
get_config,
get_templates,
)
router = APIRouter()
@router.get("/console", response_class=HTMLResponse)
async def console(request: Request) -> HTMLResponse:
"""Console page with terminal and editor."""
config = get_config()
templates = get_templates()
# Find local host and sort it first
local_host = None
for name, host in config.hosts.items():
if is_local(host):
local_host = name
break
# Sort hosts with local first
hosts = sorted(config.hosts.keys())
if local_host:
hosts = [local_host] + [h for h in hosts if h != local_host]
# Get config path for default editor file
config_path = str(config.config_path) if config.config_path else ""
return templates.TemplateResponse(
"console.html",
{
"request": request,
"hosts": hosts,
"local_host": local_host,
"config_path": config_path,
},
)
@router.get("/", response_class=HTMLResponse)
async def index(request: Request) -> HTMLResponse:
"""Dashboard page - combined view of all cluster info."""
@@ -30,11 +68,7 @@ async def index(request: Request) -> HTMLResponse:
try:
config = get_config()
except (ValidationError, FileNotFoundError) as e:
# Extract error message
if isinstance(e, ValidationError):
config_error = "; ".join(err.get("msg", str(err)) for err in e.errors())
else:
config_error = str(e)
config_error = extract_config_error(e)
# Read raw config content for the editor
config_path = find_config_path()
@@ -70,14 +104,8 @@ async def index(request: Request) -> HTMLResponse:
migrations = get_services_needing_migration(config)
not_started = get_services_not_in_state(config)
# Group services by host
services_by_host: dict[str, list[str]] = {}
for svc, host in deployed.items():
if isinstance(host, list):
for h in host:
services_by_host.setdefault(h, []).append(svc)
else:
services_by_host.setdefault(host, []).append(svc)
# Group services by host (filter out hosts with no running services)
services_by_host = group_running_services_by_host(deployed, config.hosts)
# Config file content
config_content = ""
@@ -186,10 +214,7 @@ async def config_error_partial(request: Request) -> HTMLResponse:
get_config()
return HTMLResponse("") # No error
except (ValidationError, FileNotFoundError) as e:
if isinstance(e, ValidationError):
error = "; ".join(err.get("msg", str(err)) for err in e.errors())
else:
error = str(e)
error = extract_config_error(e)
return templates.TemplateResponse(
"partials/config_error.html", {"request": request, "config_error": error}
)
@@ -246,15 +271,7 @@ async def services_by_host_partial(request: Request, expanded: bool = True) -> H
templates = get_templates()
deployed = load_state(config)
# Group services by host
services_by_host: dict[str, list[str]] = {}
for svc, host in deployed.items():
if isinstance(host, list):
for h in host:
services_by_host.setdefault(h, []).append(svc)
else:
services_by_host.setdefault(host, []).append(svc)
services_by_host = group_running_services_by_host(deployed, config.hosts)
return templates.TemplateResponse(
"partials/services_by_host.html",

View File

@@ -17,6 +17,35 @@ const editors = {};
let monacoLoaded = false;
let monacoLoading = false;
// Language detection from file path
const LANGUAGE_MAP = {
'yaml': 'yaml', 'yml': 'yaml',
'json': 'json',
'js': 'javascript', 'mjs': 'javascript',
'ts': 'typescript', 'tsx': 'typescript',
'py': 'python',
'sh': 'shell', 'bash': 'shell',
'md': 'markdown',
'html': 'html', 'htm': 'html',
'css': 'css',
'sql': 'sql',
'toml': 'toml',
'ini': 'ini', 'conf': 'ini',
'dockerfile': 'dockerfile',
'env': 'plaintext'
};
/**
* Get Monaco language from file path
* @param {string} path - File path
* @returns {string} Monaco language identifier
*/
function getLanguageFromPath(path) {
const ext = path.split('.').pop().toLowerCase();
return LANGUAGE_MAP[ext] || 'plaintext';
}
window.getLanguageFromPath = getLanguageFromPath;
// Terminal color theme (dark mode matching PicoCSS)
const TERMINAL_THEME = {
background: '#1a1a2e',
@@ -87,6 +116,7 @@ function createWebSocket(path) {
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
return new WebSocket(`${protocol}//${window.location.host}${path}`);
}
window.createWebSocket = createWebSocket;
/**
* Initialize a terminal and connect to WebSocket for streaming
@@ -223,10 +253,16 @@ function loadMonaco(callback) {
* @param {HTMLElement} container - Container element
* @param {string} content - Initial content
* @param {string} language - Editor language (yaml, plaintext, etc.)
* @param {boolean} readonly - Whether editor is read-only
* @param {object} opts - Options: { readonly, onSave }
* @returns {object} Monaco editor instance
*/
function createEditor(container, content, language, readonly = false) {
function createEditor(container, content, language, opts = {}) {
// Support legacy boolean readonly parameter
if (typeof opts === 'boolean') {
opts = { readonly: opts };
}
const { readonly = false, onSave = null } = opts;
const options = {
value: content,
language: language,
@@ -249,12 +285,17 @@ function createEditor(container, content, language, readonly = false) {
// Add Command+S / Ctrl+S handler for editable editors
if (!readonly) {
editor.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.KeyS, function() {
saveAllEditors();
if (onSave) {
onSave(editor);
} else {
saveAllEditors();
}
});
}
return editor;
}
window.createEditor = createEditor;
/**
* Initialize all Monaco editors on the page

View File

@@ -0,0 +1,246 @@
{% extends "base.html" %}
{% from "partials/components.html" import page_header %}
{% from "partials/icons.html" import terminal, save %}
{% block title %}Console - Compose Farm{% endblock %}
{% block content %}
<div class="max-w-6xl">
{{ page_header("Console", "Terminal and editor access") }}
<!-- Host Selector -->
<div class="flex items-center gap-4 mb-4">
<label class="font-semibold">Host:</label>
<select id="console-host-select" class="select select-sm select-bordered">
{% for name in hosts %}
<option value="{{ name }}">{{ name }}{% if name == local_host %} (local){% endif %}</option>
{% endfor %}
</select>
<button id="console-connect-btn" class="btn btn-sm btn-primary" onclick="connectConsole()">Connect</button>
<span id="console-status" class="text-sm opacity-60"></span>
</div>
<!-- Terminal -->
<div class="mb-6">
<div class="flex items-center gap-2 mb-2">
<h3 class="font-semibold flex items-center gap-2">{{ terminal() }} Terminal</h3>
<span class="text-xs opacity-50">Full shell access to selected host</span>
</div>
<div id="console-terminal" class="w-full bg-base-300 rounded-lg overflow-hidden resize-y" style="height: 384px; min-height: 200px;"></div>
</div>
<!-- Editor -->
<div class="mb-6">
<div class="flex items-center justify-between mb-2">
<div class="flex items-center gap-4">
<h3 class="font-semibold">Editor</h3>
<input type="text" id="console-file-path" class="input input-sm input-bordered w-96" placeholder="Enter file path (e.g., ~/docker-compose.yaml)" value="{{ config_path }}">
<button class="btn btn-sm btn-outline" onclick="loadFile()">Open</button>
</div>
<div class="flex items-center gap-2">
<span id="editor-status" class="text-sm opacity-60"></span>
<button id="console-save-btn" class="btn btn-sm btn-primary" onclick="saveFile()">{{ save() }} Save</button>
</div>
</div>
<div id="console-editor" class="resize-y overflow-hidden rounded-lg" style="height: 512px; min-height: 200px;"></div>
</div>
</div>
{% endblock %}
{% block scripts %}
<script>
let consoleTerminal = null;
let consoleWs = null;
let consoleEditor = null;
let currentFilePath = null;
let currentHost = null;
function connectConsole() {
const hostSelect = document.getElementById('console-host-select');
const host = hostSelect.value;
const statusEl = document.getElementById('console-status');
const terminalEl = document.getElementById('console-terminal');
if (!host) {
statusEl.textContent = 'Please select a host';
return;
}
currentHost = host;
// Clean up existing connection
if (consoleWs) {
consoleWs.close();
consoleWs = null;
}
if (consoleTerminal) {
consoleTerminal.dispose();
consoleTerminal = null;
}
statusEl.textContent = 'Connecting...';
// Create WebSocket
consoleWs = createWebSocket(`/ws/shell/${host}`);
// Resize callback - createTerminal's ResizeObserver calls this on container resize
const sendSize = (cols, rows) => {
if (consoleWs && consoleWs.readyState === WebSocket.OPEN) {
consoleWs.send(JSON.stringify({ type: 'resize', cols, rows }));
}
};
// Create terminal with resize callback
const { term } = createTerminal(terminalEl, { cursorBlink: true }, sendSize);
consoleTerminal = term;
consoleWs.onopen = () => {
statusEl.textContent = `Connected to ${host}`;
sendSize(term.cols, term.rows);
term.focus();
// Auto-load the default file once editor is ready
const pathInput = document.getElementById('console-file-path');
if (pathInput && pathInput.value) {
const tryLoad = () => {
if (consoleEditor) {
loadFile();
} else {
setTimeout(tryLoad, 100);
}
};
tryLoad();
}
};
consoleWs.onmessage = (event) => term.write(event.data);
consoleWs.onclose = () => {
statusEl.textContent = 'Disconnected';
term.write(`${ANSI.CRLF}${ANSI.DIM}[Connection closed]${ANSI.RESET}${ANSI.CRLF}`);
};
consoleWs.onerror = (error) => {
statusEl.textContent = 'Connection error';
term.write(`${ANSI.RED}[WebSocket Error]${ANSI.RESET}${ANSI.CRLF}`);
console.error('Console WebSocket error:', error);
};
// Send input to WebSocket
term.onData((data) => {
if (consoleWs && consoleWs.readyState === WebSocket.OPEN) {
consoleWs.send(data);
}
});
}
function initConsoleEditor() {
const editorEl = document.getElementById('console-editor');
if (!editorEl || consoleEditor) return;
loadMonaco(() => {
consoleEditor = createEditor(editorEl, '', 'plaintext', { onSave: saveFile });
});
}
async function loadFile() {
const pathInput = document.getElementById('console-file-path');
const path = pathInput.value.trim();
const statusEl = document.getElementById('editor-status');
if (!path) {
statusEl.textContent = 'Enter a file path';
return;
}
if (!currentHost) {
statusEl.textContent = 'Connect to a host first';
return;
}
statusEl.textContent = `Loading ${path}...`;
try {
const response = await fetch(`/api/console/file?host=${encodeURIComponent(currentHost)}&path=${encodeURIComponent(path)}`);
const data = await response.json();
if (!response.ok || !data.success) {
statusEl.textContent = data.detail || 'Failed to load file';
return;
}
const language = getLanguageFromPath(path);
if (consoleEditor) {
consoleEditor.setValue(data.content);
monaco.editor.setModelLanguage(consoleEditor.getModel(), language);
currentFilePath = path; // Only set after content is loaded
statusEl.textContent = `Loaded: ${path}`;
} else {
statusEl.textContent = 'Editor not ready';
}
} catch (e) {
statusEl.textContent = `Error: ${e.message}`;
}
}
async function saveFile() {
const statusEl = document.getElementById('editor-status');
if (!currentFilePath) {
statusEl.textContent = 'No file loaded';
return;
}
if (!currentHost) {
statusEl.textContent = 'Not connected to a host';
return;
}
if (!consoleEditor) {
statusEl.textContent = 'Editor not ready';
return;
}
statusEl.textContent = `Saving ${currentFilePath}...`;
try {
const content = consoleEditor.getValue();
const response = await fetch(`/api/console/file?host=${encodeURIComponent(currentHost)}&path=${encodeURIComponent(currentFilePath)}`, {
method: 'PUT',
headers: { 'Content-Type': 'text/plain' },
body: content
});
const data = await response.json();
if (!response.ok || !data.success) {
statusEl.textContent = data.detail || 'Failed to save file';
return;
}
statusEl.textContent = `Saved: ${currentFilePath}`;
} catch (e) {
statusEl.textContent = `Error: ${e.message}`;
}
}
// Initialize on page load
document.addEventListener('DOMContentLoaded', () => {
initConsoleEditor();
// Auto-connect to first host if available
const hostSelect = document.getElementById('console-host-select');
if (hostSelect && hostSelect.options.length > 0) {
connectConsole();
}
});
// Re-init after HTMX swap
document.body.addEventListener('htmx:afterSwap', (evt) => {
if (evt.detail.target.id === 'main-content') {
// Re-init if we're on console page
if (window.location.pathname === '/console') {
consoleEditor = null;
initConsoleEditor();
}
}
});
</script>
{% endblock %}

View File

@@ -1,4 +1,4 @@
{% from "partials/icons.html" import search, command %}
{% from "partials/icons.html" import search %}
<dialog id="cmd-palette" class="modal">
<div class="modal-box max-w-lg p-0">
<label class="input input-lg bg-base-100 border-0 border-b border-base-300 w-full rounded-none rounded-t-box sticky top-0 z-10 focus-within:outline-none">
@@ -15,5 +15,7 @@
<!-- Floating button to open command palette -->
<button id="cmd-fab" class="btn btn-circle glass shadow-lg fixed bottom-6 right-6 z-50 hover:ring hover:ring-base-content/50" title="Command Palette (⌘K)">
{{ command(24) }}
<span class="flex items-center gap-0.5 text-sm font-semibold">
<span class="opacity-70"></span><span>K</span>
</span>
</button>

View File

@@ -1,8 +1,9 @@
{% from "partials/icons.html" import home, search %}
<!-- Dashboard Link -->
{% from "partials/icons.html" import home, search, terminal %}
<!-- Navigation Links -->
<div class="mb-4">
<ul class="menu" hx-boost="true" hx-target="#main-content" hx-select="#main-content" hx-swap="outerHTML">
<li><a href="/" class="font-semibold">{{ home() }} Dashboard</a></li>
<li><a href="/console" class="font-semibold">{{ terminal() }} Console</a></li>
</ul>
</div>

View File

@@ -16,9 +16,9 @@ from typing import TYPE_CHECKING, Any
import asyncssh
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from compose_farm.executor import is_local
from compose_farm.executor import is_local, ssh_connect_kwargs
from compose_farm.web.deps import get_config
from compose_farm.web.streaming import CRLF, DIM, GREEN, RED, RESET, tasks
from compose_farm.web.streaming import CRLF, DIM, GREEN, RED, RESET, _get_ssh_auth_sock, tasks
if TYPE_CHECKING:
from compose_farm.config import Host
@@ -121,6 +121,14 @@ async def _bridge_websocket_to_ssh(
proc.terminate()
def _make_controlling_tty(slave_fd: int) -> None:
"""Set up the slave PTY as the controlling terminal for the child process."""
# Create a new session
os.setsid()
# Make the slave fd the controlling terminal
fcntl.ioctl(slave_fd, termios.TIOCSCTTY, 0)
async def _run_local_exec(websocket: WebSocket, exec_cmd: str) -> None:
"""Run docker exec locally with PTY."""
master_fd, slave_fd = pty.openpty()
@@ -131,6 +139,8 @@ async def _run_local_exec(websocket: WebSocket, exec_cmd: str) -> None:
stdout=slave_fd,
stderr=slave_fd,
close_fds=True,
preexec_fn=lambda: _make_controlling_tty(slave_fd),
start_new_session=False, # We handle setsid in preexec_fn
)
os.close(slave_fd)
@@ -141,13 +151,17 @@ async def _run_local_exec(websocket: WebSocket, exec_cmd: str) -> None:
await _bridge_websocket_to_fd(websocket, master_fd, proc)
async def _run_remote_exec(websocket: WebSocket, host: Host, exec_cmd: str) -> None:
async def _run_remote_exec(
websocket: WebSocket, host: Host, exec_cmd: str, *, agent_forwarding: bool = False
) -> None:
"""Run docker exec on remote host via SSH with PTY."""
# Get SSH agent socket for authentication
agent_path = _get_ssh_auth_sock()
async with asyncssh.connect(
host.address,
port=host.port,
username=host.user,
known_hosts=None,
**ssh_connect_kwargs(host),
agent_forwarding=agent_forwarding,
agent_path=agent_path,
) as conn:
proc: asyncssh.SSHClientProcess[Any] = await conn.create_process(
exec_cmd,
@@ -202,6 +216,48 @@ async def exec_websocket(
await websocket.close()
async def _run_shell_session(
websocket: WebSocket,
host_name: str,
) -> None:
"""Run an interactive shell session on a host over WebSocket."""
config = get_config()
host = config.hosts.get(host_name)
if not host:
await websocket.send_text(f"{RED}Host '{host_name}' not found{RESET}{CRLF}")
return
# Start interactive shell in home directory (avoid login shell to prevent job control warnings)
shell_cmd = "cd ~ && exec bash -i 2>/dev/null || exec sh -i"
if is_local(host):
await _run_local_exec(websocket, shell_cmd)
else:
await _run_remote_exec(websocket, host, shell_cmd, agent_forwarding=True)
@router.websocket("/ws/shell/{host}")
async def shell_websocket(
websocket: WebSocket,
host: str,
) -> None:
"""WebSocket endpoint for interactive host shell access."""
await websocket.accept()
try:
await websocket.send_text(f"{DIM}[Connecting to {host}...]{RESET}{CRLF}")
await _run_shell_session(websocket, host)
await websocket.send_text(f"{CRLF}{DIM}[Disconnected]{RESET}{CRLF}")
except WebSocketDisconnect:
pass
except Exception as e:
with contextlib.suppress(Exception):
await websocket.send_text(f"{RED}Error: {e}{RESET}{CRLF}")
finally:
with contextlib.suppress(Exception):
await websocket.close()
@router.websocket("/ws/terminal/{task_id}")
async def terminal_websocket(websocket: WebSocket, task_id: str) -> None:
"""WebSocket endpoint for terminal streaming."""

54
tests/web/test_backup.py Normal file
View File

@@ -0,0 +1,54 @@
"""Tests for file backup functionality."""
from pathlib import Path
from compose_farm.web.routes.api import _backup_file, _save_with_backup
def test_backup_creates_timestamped_file(tmp_path: Path) -> None:
"""Test that backup creates file in .backups with correct content."""
test_file = tmp_path / "test.yaml"
test_file.write_text("original content")
backup_path = _backup_file(test_file)
assert backup_path is not None
assert backup_path.parent.name == ".backups"
assert backup_path.name.startswith("test.yaml.")
assert backup_path.read_text() == "original content"
def test_backup_returns_none_for_nonexistent_file(tmp_path: Path) -> None:
"""Test that backup returns None if file doesn't exist."""
assert _backup_file(tmp_path / "nonexistent.yaml") is None
def test_save_creates_new_file(tmp_path: Path) -> None:
"""Test that save creates new file without backup."""
test_file = tmp_path / "new.yaml"
assert _save_with_backup(test_file, "content") is True
assert test_file.read_text() == "content"
assert not (tmp_path / ".backups").exists()
def test_save_skips_unchanged_content(tmp_path: Path) -> None:
"""Test that save returns False and creates no backup if unchanged."""
test_file = tmp_path / "test.yaml"
test_file.write_text("same")
assert _save_with_backup(test_file, "same") is False
assert not (tmp_path / ".backups").exists()
def test_save_creates_backup_before_overwrite(tmp_path: Path) -> None:
"""Test that save backs up original before overwriting."""
test_file = tmp_path / "test.yaml"
test_file.write_text("original")
assert _save_with_backup(test_file, "new") is True
assert test_file.read_text() == "new"
backups = list((tmp_path / ".backups").glob("test.yaml.*"))
assert len(backups) == 1
assert backups[0].read_text() == "original"

View File

@@ -0,0 +1,111 @@
"""Tests to verify template context variables match what templates expect.
Uses runtime validation by actually rendering templates and catching errors.
"""
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING
import pytest
from fastapi.testclient import TestClient
if TYPE_CHECKING:
from compose_farm.config import Config
@pytest.fixture
def mock_config(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Config:
"""Create a minimal mock config for template testing."""
compose_dir = tmp_path / "compose"
compose_dir.mkdir()
# Create minimal service directory
svc_dir = compose_dir / "test-service"
svc_dir.mkdir()
(svc_dir / "compose.yaml").write_text("services:\n app:\n image: nginx\n")
config_path = tmp_path / "compose-farm.yaml"
config_path.write_text(f"""
compose_dir: {compose_dir}
hosts:
local-host:
address: localhost
services:
test-service: local-host
""")
state_path = tmp_path / "compose-farm-state.yaml"
state_path.write_text("deployed:\n test-service: local-host\n")
from compose_farm.config import load_config
config = load_config(config_path)
# Patch get_config in all relevant modules
from compose_farm.web import deps
from compose_farm.web.routes import api, pages
monkeypatch.setattr(deps, "get_config", lambda: config)
monkeypatch.setattr(api, "get_config", lambda: config)
monkeypatch.setattr(pages, "get_config", lambda: config)
return config
@pytest.fixture
def client(mock_config: Config) -> TestClient:
"""Create a test client with mocked config."""
from compose_farm.web.app import create_app
return TestClient(create_app())
class TestPageTemplatesRender:
"""Test that page templates render without missing variables."""
def test_index_renders(self, client: TestClient) -> None:
"""Test index page renders without errors."""
response = client.get("/")
assert response.status_code == 200
assert "Compose Farm" in response.text
def test_console_renders(self, client: TestClient) -> None:
"""Test console page renders without errors."""
response = client.get("/console")
assert response.status_code == 200
assert "Console" in response.text
assert "Terminal" in response.text
def test_service_detail_renders(self, client: TestClient) -> None:
"""Test service detail page renders without errors."""
response = client.get("/service/test-service")
assert response.status_code == 200
assert "test-service" in response.text
class TestPartialTemplatesRender:
"""Test that partial templates render without missing variables."""
def test_sidebar_renders(self, client: TestClient) -> None:
"""Test sidebar partial renders without errors."""
response = client.get("/partials/sidebar")
assert response.status_code == 200
assert "Dashboard" in response.text
assert "Console" in response.text
def test_stats_renders(self, client: TestClient) -> None:
"""Test stats partial renders without errors."""
response = client.get("/partials/stats")
assert response.status_code == 200
def test_pending_renders(self, client: TestClient) -> None:
"""Test pending partial renders without errors."""
response = client.get("/partials/pending")
assert response.status_code == 200
def test_services_by_host_renders(self, client: TestClient) -> None:
"""Test services_by_host partial renders without errors."""
response = client.get("/partials/services-by-host")
assert response.status_code == 200