Compare commits

...

34 Commits

Author SHA1 Message Date
Bas Nijholt
d8353dbb7e fix: Skip socket paths in preflight volume checks (#37)
Socket paths like SSH_AUTH_SOCK are machine-local and shouldn't be
validated on remote hosts during preflight checks.
2025-12-18 13:59:06 -08:00
Bas Nijholt
2e6146a94b feat(ps): Add service filtering to ps command (#33) 2025-12-18 13:31:18 -08:00
Bas Nijholt
87849a8161 fix(web): Run self-updates via SSH to survive container restart (#35) 2025-12-18 13:10:30 -08:00
Bas Nijholt
c8bf792a9a refactor: Store SSH keys in subdirectory for cleaner volume mounting (#36)
* refactor: Store SSH keys in subdirectory for cleaner volume mounting

Change SSH key location from ~/.ssh/compose-farm (file) to
~/.ssh/compose-farm/id_ed25519 (file in directory).

This allows docker-compose to mount just the compose-farm directory
to /root/.ssh without exposing all host SSH keys to the container.

Also make host path the default option in docker-compose.yml with
clearer comments about the two options.

* docs: Update README for new SSH key directory structure

* docs: Clarify cf ssh setup must run inside container
2025-12-18 13:07:41 -08:00
Bas Nijholt
d37295fbee feat(web): Add distinct color for Dashboard/Console in command palette (#34)
Give Dashboard and Console a purple accent to visually distinguish
them from service navigation items in the Command K palette.
2025-12-18 12:38:28 -08:00
Bas Nijholt
266f541d35 fix(web): Auto-scroll Command K palette when navigating with arrow keys (#32)
When using arrow keys to navigate through the command palette list,
items outside the visible area now scroll into view automatically.
2025-12-18 12:30:29 -08:00
Bas Nijholt
aabdd550ba feat(cli): Add progress bar to ssh status host connectivity check (#31)
Use run_parallel_with_progress for visual feedback during host checks.
Results are now sorted alphabetically for consistent output.

Also adds code style rule to CLAUDE.md about keeping imports at top level.
2025-12-18 12:21:47 -08:00
Bas Nijholt
8ff60a1e3e refactor(ssh): Unify ssh_status to use run_command like check command (#29) 2025-12-18 12:17:47 -08:00
Bas Nijholt
2497bd727a feat(web): Navigate to dashboard for Apply/Refresh from command palette (#28)
When triggering Apply or Refresh from the command palette on a non-dashboard
page, navigate to the dashboard first and then execute the action, opening
the terminal output.
2025-12-18 12:12:50 -08:00
Bas Nijholt
e37d9d87ba feat(web): Add icons to Command K palette items (#27) 2025-12-18 12:08:55 -08:00
Bas Nijholt
80a1906d90 fix(web): Fix console page not initializing on HTMX navigation (#26)
* fix(web): Fix console page not initializing on HTMX navigation

Move inline script from {% block scripts %} to inside {% block content %}
so it's included in HTMX swaps. The script block was outside #main-content,
so hx-select="#main-content" was discarding it during navigation.

Also wrap script in IIFE to prevent let re-declaration errors when
navigating back to the console page.

* refactor(web): Simplify console script using var instead of IIFE
2025-12-18 12:05:30 -08:00
Bas Nijholt
282de12336 feat(cli): Add ssh subcommand for SSH key management (#22) 2025-12-18 11:58:33 -08:00
Bas Nijholt
2c5308aea3 fix(web): Add Console navigation to Command K palette (#25)
The Command K menu was missing an option to navigate to the Console page,
even though it's available in the sidebar.
2025-12-18 11:55:30 -08:00
Bas Nijholt
5057202938 refactor: DRY cleanup and message consistency (#24) 2025-12-18 11:45:32 -08:00
Bas Nijholt
5e1b9987dd fix(web): Set PTY as controlling terminal for local shell sessions (#23)
Local shell sessions weren't receiving SIGINT (Ctrl+C) because the PTY
wasn't set as the controlling terminal. Add preexec_fn that calls
setsid() and TIOCSCTTY to properly set up the terminal.
2025-12-18 11:12:37 -08:00
Bas Nijholt
d9c26f7f2c Merge pull request #21 from basnijholt/refactor/dry-cleanup
refactor: DRY cleanup - consolidate duplicate code patterns
2025-12-18 11:12:24 -08:00
Bas Nijholt
adfcd4bb31 style: Capitalize "Hint:" consistently 2025-12-18 11:05:53 -08:00
Bas Nijholt
95f7d9c3cf style(cli): Unify "not found" message format with color highlighting
- Services use [cyan] highlighting consistently
- Hosts use [magenta] highlighting consistently
- All use the same "X not found in config" pattern
2025-12-18 11:05:05 -08:00
Bas Nijholt
4c1674cfd8 style(cli): Unify error message format with ✗ prefix
All CLI error messages now consistently use the [red]✗[/] prefix
pattern instead of wrapping the entire message in [red]...[/red].
2025-12-18 11:04:28 -08:00
Bas Nijholt
f65ca8420e fix(web): Filter empty hosts from services_by_host
Preserve original behavior where only hosts with running services are
shown in the dashboard, rather than all configured hosts.
2025-12-18 11:00:01 -08:00
Bas Nijholt
85aff2c271 refactor(state): Move group_services_by_host to state.py
Consolidate duplicate service grouping logic from monitoring.py and
pages.py into a shared function in state.py.
2025-12-18 10:55:53 -08:00
Bas Nijholt
61ca24bb8e refactor(cli): Remove unused get_description parameter
All callers used the same pattern (r[0]), so hardcode it in the helper
and remove the parameter entirely.
2025-12-18 10:54:12 -08:00
Bas Nijholt
ed36588358 refactor(cli): Add validate_host and validate_hosts helpers
Extract common host validation patterns into reusable helpers.
Also simplifies validate_host_for_service to use the new validate_host
helper internally.
2025-12-18 10:49:57 -08:00
Bas Nijholt
80c8079a8c refactor(executor): Add ssh_connect_kwargs helper
Extract common asyncssh.connect parameters into a reusable
ssh_connect_kwargs() function. Used by executor.py, api.py, and ws.py.

Lines: 2608 → 2601 (-7)
2025-12-18 10:48:29 -08:00
Bas Nijholt
763bedf9f6 refactor(cli): Extract config not found helpers
Consolidate repeated "config not found" and "path doesn't exist"
messages into _report_no_config_found() and _report_config_path_not_exists()
helper functions. Also unifies the UX to always show status of searched
paths.
2025-12-18 10:46:58 -08:00
Bas Nijholt
641f7e91a8 refactor(cli): Consolidate _report_*_errors() functions
Merge _report_mount_errors, _report_network_errors, and _report_device_errors
into a single _report_requirement_errors function that takes a category
parameter.

Lines: 2634 → 2608 (-26)
2025-12-18 10:43:49 -08:00
Bas Nijholt
4e8e925d59 refactor(cli): Add run_parallel_with_progress helper
Extract common async progress bar pattern into a reusable helper in
common.py. Updates _discover_services, _check_ssh_connectivity,
_check_service_requirements, _get_container_counts, and _snapshot_services
to use the new helper.

Lines: 2642 → 2634 (-8)
2025-12-18 10:42:45 -08:00
Bas Nijholt
d84858dcfb fix(docker): Add restart policy to web service (#19)
* fix(docker): Add restart policy to containers

* fix: Only add restart policy to web service
2025-12-18 10:39:09 -08:00
Bas Nijholt
3121ee04eb feat(web): Show ⌘K shortcut on command palette button (#20) 2025-12-18 10:38:57 -08:00
Bas Nijholt
a795132a04 refactor(cli): Move format_host to common.py
Consolidate duplicate _format_host() function from lifecycle.py and
management.py into a single format_host() function in common.py.

Lines: 2647 → 2642 (-5)
2025-12-18 10:38:52 -08:00
Bas Nijholt
a6e491575a feat(web): Add Console page with terminal and editor (#17) 2025-12-18 10:29:15 -08:00
Bas Nijholt
78bf90afd9 docs: Improve Releases section in CLAUDE.md 2025-12-18 10:04:56 -08:00
Bas Nijholt
76b60bdd96 feat(web): Add Console page with terminal and editor
Add a new Console page accessible from the sidebar that provides:
- Interactive terminal with full shell access to any configured host
- SSH agent forwarding for authentication to remote hosts
- Monaco editor for viewing/editing files on hosts
- Host selector dropdown with local host listed first
- Auto-loads compose-farm config file on page load

Changes:
- Add /console route and console.html template
- Add /ws/shell/{host} WebSocket endpoint for shell sessions
- Add /api/console/file GET/PUT endpoints for remote file operations
- Update sidebar to include Console navigation link
2025-12-18 10:02:54 -08:00
Bas Nijholt
98bfb1bf6d fix(executor): Disable SSH host key checking in raw mode (#18)
Add SSH options to match asyncssh behavior:
- StrictHostKeyChecking=no
- UserKnownHostsFile=/dev/null
- LogLevel=ERROR (suppress warnings)
- Use -tt to force TTY allocation without stdin TTY

Fixes "Host key verification failed" errors when running from web UI.
2025-12-18 09:59:22 -08:00
30 changed files with 2180 additions and 463 deletions

View File

@@ -43,6 +43,10 @@ Icons use [Lucide](https://lucide.dev/). Add new icons as macros in `web/templat
7. **State tracking**: Tracks where services are deployed for auto-migration
8. **Pre-flight checks**: Verifies NFS mounts and Docker networks exist before starting/migrating
## Code Style
- **Imports at top level**: Never add imports inside functions unless they are explicitly marked with `# noqa: PLC0415` and a comment explaining it speeds up CLI startup. Heavy modules like `pydantic`, `yaml`, and `rich.table` are lazily imported to keep `cf --help` fast.
## Communication Notes
- Clarify ambiguous wording (e.g., homophones like "right"/"write", "their"/"there").
@@ -53,6 +57,24 @@ Icons use [Lucide](https://lucide.dev/). Add new icons as macros in `web/templat
- **NEVER merge anything into main.** Always commit directly or use fast-forward/rebase.
- Never force push.
## Releases
Use `gh release create` to create releases. The tag is created automatically.
```bash
# Check current version
git tag --sort=-v:refname | head -1
# Create release (minor version bump: v0.21.1 -> v0.22.0)
gh release create v0.22.0 --title "v0.22.0" --notes "release notes here"
```
Versioning:
- **Patch** (v0.21.0 → v0.21.1): Bug fixes
- **Minor** (v0.21.1 → v0.22.0): New features, non-breaking changes
Write release notes manually describing what changed. Group by features and bug fixes.
## Commands Quick Reference
CLI available as `cf` or `compose-farm`.

View File

@@ -23,6 +23,9 @@ A minimal CLI tool to run Docker Compose commands across multiple hosts via SSH.
- [Best practices](#best-practices)
- [What Compose Farm doesn't do](#what-compose-farm-doesnt-do)
- [Installation](#installation)
- [SSH Authentication](#ssh-authentication)
- [SSH Agent (default)](#ssh-agent-default)
- [Dedicated SSH Key (recommended for Docker/Web UI)](#dedicated-ssh-key-recommended-for-dockerweb-ui)
- [Configuration](#configuration)
- [Multi-Host Services](#multi-host-services)
- [Config Command](#config-command)
@@ -159,6 +162,62 @@ docker run --rm \
</details>
## SSH Authentication
Compose Farm uses SSH to run commands on remote hosts. There are two authentication methods:
### SSH Agent (default)
Works out of the box if you have an SSH agent running with your keys loaded:
```bash
# Verify your agent has keys
ssh-add -l
# Run compose-farm commands
cf up --all
```
### Dedicated SSH Key (recommended for Docker/Web UI)
When running compose-farm in Docker, the SSH agent connection can be lost (e.g., after container restart). The `cf ssh` command sets up a dedicated key that persists:
```bash
# Generate key and copy to all configured hosts
cf ssh setup
# Check status
cf ssh status
```
This creates `~/.ssh/compose-farm/id_ed25519` (ED25519, no passphrase) and copies the public key to each host's `authorized_keys`. Compose Farm tries the SSH agent first, then falls back to this key.
<details><summary>🐳 Docker volume options for SSH keys</summary>
When running in Docker, mount a volume to persist the SSH keys. Choose ONE option and use it for both `cf` and `web` services:
**Option 1: Host path (default)** - keys at `~/.ssh/compose-farm/id_ed25519`
```yaml
volumes:
- ~/.ssh/compose-farm:/root/.ssh
```
**Option 2: Named volume** - managed by Docker
```yaml
volumes:
- cf-ssh:/root/.ssh
```
Run setup once after starting the container (while the SSH agent still works):
```bash
docker compose exec web cf ssh setup
```
The keys will persist across restarts.
</details>
## Configuration
Create `~/.config/compose-farm/compose-farm.yaml` (or `./compose-farm.yaml` in your working directory):
@@ -344,10 +403,11 @@ Full `--help` output for each command. See the [Usage](#usage) table above for a
│ check Validate configuration, traefik labels, mounts, and networks. │
│ init-network Create Docker network on hosts with consistent settings. │
│ config Manage compose-farm configuration files. │
│ ssh Manage SSH keys for passwordless authentication. │
╰──────────────────────────────────────────────────────────────────────────────╯
╭─ Monitoring ─────────────────────────────────────────────────────────────────╮
│ logs Show service logs. │
│ ps Show status of all services. │
│ ps Show status of services.
│ stats Show overview statistics for hosts and services. │
╰──────────────────────────────────────────────────────────────────────────────╯
╭─ Server ─────────────────────────────────────────────────────────────────────╮
@@ -773,6 +833,21 @@ Full `--help` output for each command. See the [Usage](#usage) table above for a
</details>
<details>
<summary>See the output of <code>cf ssh --help</code></summary>
<!-- CODE:BASH:START -->
<!-- echo '```yaml' -->
<!-- export NO_COLOR=1 -->
<!-- export TERM=dumb -->
<!-- export TERMINAL_WIDTH=90 -->
<!-- cf ssh --help -->
<!-- echo '```' -->
<!-- CODE:END -->
</details>
**Monitoring**
<details>
@@ -829,11 +904,19 @@ Full `--help` output for each command. See the [Usage](#usage) table above for a
<!-- ⚠️ This content is auto-generated by `markdown-code-runner`. -->
```yaml
Usage: cf ps [OPTIONS]
Usage: cf ps [OPTIONS] [SERVICES]...
Show status of all services.
Show status of services.
Without arguments: shows all services (same as --all). With service names:
shows only those services. With --host: shows services on that host.
╭─ Arguments ──────────────────────────────────────────────────────────────────╮
│ services [SERVICES]... Services to operate on │
╰──────────────────────────────────────────────────────────────────────────────╯
╭─ Options ────────────────────────────────────────────────────────────────────╮
│ --all -a Run on all services │
│ --host -H TEXT Filter to services on this host │
│ --config -c PATH Path to config file │
│ --help -h Show this message and exit. │
╰──────────────────────────────────────────────────────────────────────────────╯

View File

@@ -5,6 +5,12 @@ services:
- ${SSH_AUTH_SOCK}:/ssh-agent:ro
# Compose directory (contains compose files AND compose-farm.yaml config)
- ${CF_COMPOSE_DIR:-/opt/stacks}:${CF_COMPOSE_DIR:-/opt/stacks}
# SSH keys for passwordless auth (generated by `cf ssh setup`)
# Choose ONE option below (use the same option for both cf and web services):
# Option 1: Host path (default) - keys at ~/.ssh/compose-farm/id_ed25519
- ${CF_SSH_DIR:-~/.ssh/compose-farm}:/root/.ssh
# Option 2: Named volume - managed by Docker, shared between services
# - cf-ssh:/root/.ssh
environment:
- SSH_AUTH_SOCK=/ssh-agent
# Config file path (state stored alongside it)
@@ -12,13 +18,21 @@ services:
web:
image: ghcr.io/basnijholt/compose-farm:latest
restart: unless-stopped
command: web --host 0.0.0.0 --port 9000
volumes:
- ${SSH_AUTH_SOCK}:/ssh-agent:ro
- ${CF_COMPOSE_DIR:-/opt/stacks}:${CF_COMPOSE_DIR:-/opt/stacks}
# SSH keys - use the SAME option as cf service above
# Option 1: Host path (default)
- ${CF_SSH_DIR:-~/.ssh/compose-farm}:/root/.ssh
# Option 2: Named volume
# - cf-ssh:/root/.ssh
environment:
- SSH_AUTH_SOCK=/ssh-agent
- CF_CONFIG=${CF_COMPOSE_DIR:-/opt/stacks}/compose-farm.yaml
# Used to detect self-updates and run via SSH to survive container restart
- CF_WEB_SERVICE=compose-farm
labels:
- traefik.enable=true
- traefik.http.routers.compose-farm.rule=Host(`compose-farm.${DOMAIN}`)
@@ -32,3 +46,7 @@ services:
networks:
mynetwork:
external: true
volumes:
cf-ssh:
# Only used if Option 2 is selected above

View File

@@ -8,6 +8,7 @@ from compose_farm.cli import (
lifecycle, # noqa: F401
management, # noqa: F401
monitoring, # noqa: F401
ssh, # noqa: F401
web, # noqa: F401
)

View File

@@ -18,7 +18,15 @@ from rich.progress import (
TimeElapsedColumn,
)
from compose_farm.console import console, err_console
from compose_farm.console import (
MSG_HOST_NOT_FOUND,
MSG_SERVICE_NOT_FOUND,
console,
print_error,
print_hint,
print_success,
print_warning,
)
if TYPE_CHECKING:
from collections.abc import Callable, Coroutine, Generator
@@ -27,6 +35,7 @@ if TYPE_CHECKING:
from compose_farm.executor import CommandResult
_T = TypeVar("_T")
_R = TypeVar("_R")
# --- Shared CLI Options ---
@@ -56,6 +65,13 @@ _MISSING_PATH_PREVIEW_LIMIT = 2
_STATS_PREVIEW_LIMIT = 3 # Max number of pending migrations to show by name
def format_host(host: str | list[str]) -> str:
"""Format a host value for display."""
if isinstance(host, list):
return ", ".join(host)
return host
@contextlib.contextmanager
def progress_bar(
label: str, total: int, *, initial_description: str = "[dim]connecting...[/]"
@@ -81,6 +97,37 @@ def progress_bar(
yield progress, task_id
def run_parallel_with_progress(
label: str,
items: list[_T],
async_fn: Callable[[_T], Coroutine[None, None, _R]],
) -> list[_R]:
"""Run async tasks in parallel with a progress bar.
Args:
label: Progress bar label (e.g., "Discovering", "Querying hosts")
items: List of items to process
async_fn: Async function to call for each item, returns tuple where
first element is used for progress description
Returns:
List of results from async_fn in completion order.
"""
async def gather() -> list[_R]:
with progress_bar(label, len(items)) as (progress, task_id):
tasks = [asyncio.create_task(async_fn(item)) for item in items]
results: list[_R] = []
for coro in asyncio.as_completed(tasks):
result = await coro
results.append(result)
progress.update(task_id, advance=1, description=f"[cyan]{result[0]}[/]") # type: ignore[index]
return results
return asyncio.run(gather())
def load_config_or_exit(config_path: Path | None) -> Config:
"""Load config or exit with a friendly error message."""
# Lazy import: pydantic adds ~50ms to startup, only load when actually needed
@@ -89,7 +136,7 @@ def load_config_or_exit(config_path: Path | None) -> Config:
try:
return load_config(config_path)
except FileNotFoundError as e:
err_console.print(f"[red]✗[/] {e}")
print_error(str(e))
raise typer.Exit(1) from e
@@ -97,29 +144,54 @@ def get_services(
services: list[str],
all_services: bool,
config_path: Path | None,
*,
host: str | None = None,
default_all: bool = False,
) -> tuple[list[str], Config]:
"""Resolve service list and load config.
Handles three mutually exclusive selection methods:
- Explicit service names
- --all flag
- --host filter
Args:
services: Explicit service names
all_services: Whether --all was specified
config_path: Path to config file
host: Filter to services on this host
default_all: If True, default to all services when nothing specified (for ps)
Supports "." as shorthand for the current directory name.
"""
validate_service_selection(services, all_services, host)
config = load_config_or_exit(config_path)
if host is not None:
validate_host(config, host)
svc_list = [s for s in config.services if host in config.get_hosts(s)]
if not svc_list:
print_warning(f"No services configured for host [magenta]{host}[/]")
raise typer.Exit(0)
return svc_list, config
if all_services:
return list(config.services.keys()), config
if not services:
err_console.print("[red]✗[/] Specify services or use --all")
if default_all:
return list(config.services.keys()), config
print_error("Specify services or use [bold]--all[/] / [bold]--host[/]")
raise typer.Exit(1)
# Resolve "." to current directory name
resolved = [Path.cwd().name if svc == "." else svc for svc in services]
# Validate all services exist in config
unknown = [svc for svc in resolved if svc not in config.services]
if unknown:
for svc in unknown:
err_console.print(f"[red]✗[/] Unknown service: [cyan]{svc}[/]")
err_console.print("[dim]Hint: Add the service to compose-farm.yaml or use --all[/]")
raise typer.Exit(1)
validate_services(
config, resolved, hint="Add the service to compose-farm.yaml or use [bold]--all[/]"
)
return resolved, config
@@ -143,21 +215,19 @@ def report_results(results: list[CommandResult]) -> None:
console.print() # Blank line before summary
if failed:
for r in failed:
err_console.print(
f"[red]✗[/] [cyan]{r.service}[/] failed with exit code {r.exit_code}"
)
print_error(f"[cyan]{r.service}[/] failed with exit code {r.exit_code}")
console.print()
console.print(
f"[green]✓[/] {len(succeeded)}/{len(results)} services succeeded, "
f"[red]✗[/] {len(failed)} failed"
)
else:
console.print(f"[green]✓[/] All {len(results)} services succeeded")
print_success(f"All {len(results)} services succeeded")
elif failed:
# Single service failed
r = failed[0]
err_console.print(f"[red]✗[/] [cyan]{r.service}[/] failed with exit code {r.exit_code}")
print_error(f"[cyan]{r.service}[/] failed with exit code {r.exit_code}")
if failed:
raise typer.Exit(1)
@@ -197,28 +267,69 @@ def maybe_regenerate_traefik(
cfg.traefik_file.parent.mkdir(parents=True, exist_ok=True)
cfg.traefik_file.write_text(new_content)
console.print() # Ensure we're on a new line after streaming output
console.print(f"[green]✓[/] Traefik config updated: {cfg.traefik_file}")
print_success(f"Traefik config updated: {cfg.traefik_file}")
for warning in warnings:
err_console.print(f"[yellow]![/] {warning}")
print_warning(warning)
except (FileNotFoundError, ValueError) as exc:
err_console.print(f"[yellow]![/] Failed to update traefik config: {exc}")
print_warning(f"Failed to update traefik config: {exc}")
def validate_services(cfg: Config, services: list[str], *, hint: str | None = None) -> None:
"""Validate that all services exist in config. Exits with error if any not found."""
invalid = [s for s in services if s not in cfg.services]
if invalid:
for svc in invalid:
print_error(MSG_SERVICE_NOT_FOUND.format(name=svc))
if hint:
print_hint(hint)
raise typer.Exit(1)
def validate_host(cfg: Config, host: str) -> None:
"""Validate that a host exists in config. Exits with error if not found."""
if host not in cfg.hosts:
print_error(MSG_HOST_NOT_FOUND.format(name=host))
raise typer.Exit(1)
def validate_hosts(cfg: Config, hosts: list[str]) -> None:
"""Validate that all hosts exist in config. Exits with error if any not found."""
invalid = [h for h in hosts if h not in cfg.hosts]
if invalid:
for h in invalid:
print_error(MSG_HOST_NOT_FOUND.format(name=h))
raise typer.Exit(1)
def validate_host_for_service(cfg: Config, service: str, host: str) -> None:
"""Validate that a host is valid for a service."""
if host not in cfg.hosts:
err_console.print(f"[red]✗[/] Host '{host}' not found in config")
raise typer.Exit(1)
validate_host(cfg, host)
allowed_hosts = cfg.get_hosts(service)
if host not in allowed_hosts:
err_console.print(
f"[red]✗[/] Service '{service}' is not configured for host '{host}' "
print_error(
f"Service [cyan]{service}[/] is not configured for host [magenta]{host}[/] "
f"(configured: {', '.join(allowed_hosts)})"
)
raise typer.Exit(1)
def validate_service_selection(
services: list[str] | None,
all_services: bool,
host: str | None,
) -> None:
"""Validate that only one service selection method is used.
The three selection methods (explicit services, --all, --host) are mutually
exclusive. This ensures consistent behavior across all commands.
"""
methods = sum([bool(services), all_services, host is not None])
if methods > 1:
print_error("Use only one of: service names, [bold]--all[/], or [bold]--host[/]")
raise typer.Exit(1)
def run_host_operation(
cfg: Config,
svc_list: list[str],

View File

@@ -14,7 +14,7 @@ from typing import Annotated
import typer
from compose_farm.cli.app import app
from compose_farm.console import console, err_console
from compose_farm.console import MSG_CONFIG_NOT_FOUND, console, print_error, print_success
from compose_farm.paths import config_search_paths, default_config_path, find_config_path
config_app = typer.Typer(
@@ -66,8 +66,8 @@ def _generate_template() -> str:
template_file = resources.files("compose_farm") / "example-config.yaml"
return template_file.read_text(encoding="utf-8")
except FileNotFoundError as e:
err_console.print("[red]Example config template is missing from the package.[/red]")
err_console.print("Reinstall compose-farm or report this issue.")
print_error("Example config template is missing from the package")
console.print("Reinstall compose-farm or report this issue.")
raise typer.Exit(1) from e
@@ -80,6 +80,23 @@ def _get_config_file(path: Path | None) -> Path | None:
return config_path.resolve() if config_path else None
def _report_no_config_found() -> None:
"""Report that no config file was found in search paths."""
console.print("[yellow]No config file found.[/yellow]")
console.print("\nSearched locations:")
for p in config_search_paths():
status = "[green]exists[/green]" if p.exists() else "[dim]not found[/dim]"
console.print(f" - {p} ({status})")
console.print("\nRun [bold cyan]cf config init[/bold cyan] to create one.")
def _report_config_path_not_exists(config_file: Path) -> None:
"""Report that an explicit config path doesn't exist."""
console.print("[yellow]Config file not found.[/yellow]")
console.print(f"\nProvided path does not exist: [cyan]{config_file}[/cyan]")
console.print("\nRun [bold cyan]cf config init[/bold cyan] to create one.")
@config_app.command("init")
def config_init(
path: _PathOption = None,
@@ -107,7 +124,7 @@ def config_init(
template_content = _generate_template()
target_path.write_text(template_content, encoding="utf-8")
console.print(f"[green]✓[/] Config file created at: {target_path}")
print_success(f"Config file created at: {target_path}")
console.print("\n[dim]Edit the file to customize your settings:[/dim]")
console.print(" [cyan]cf config edit[/cyan]")
@@ -123,17 +140,11 @@ def config_edit(
config_file = _get_config_file(path)
if config_file is None:
console.print("[yellow]No config file found.[/yellow]")
console.print("\nRun [bold cyan]cf config init[/bold cyan] to create one.")
console.print("\nSearched locations:")
for p in config_search_paths():
console.print(f" - {p}")
_report_no_config_found()
raise typer.Exit(1)
if not config_file.exists():
console.print("[yellow]Config file not found.[/yellow]")
console.print(f"\nProvided path does not exist: [cyan]{config_file}[/cyan]")
console.print("\nRun [bold cyan]cf config init[/bold cyan] to create one.")
_report_config_path_not_exists(config_file)
raise typer.Exit(1)
editor = _get_editor()
@@ -142,21 +153,21 @@ def config_edit(
try:
editor_cmd = shlex.split(editor, posix=os.name != "nt")
except ValueError as e:
err_console.print("[red]Invalid editor command. Check $EDITOR/$VISUAL.[/red]")
print_error("Invalid editor command. Check [bold]$EDITOR[/]/[bold]$VISUAL[/]")
raise typer.Exit(1) from e
if not editor_cmd:
err_console.print("[red]Editor command is empty.[/red]")
print_error("Editor command is empty")
raise typer.Exit(1)
try:
subprocess.run([*editor_cmd, str(config_file)], check=True)
except FileNotFoundError:
err_console.print(f"[red]Editor '{editor_cmd[0]}' not found.[/red]")
err_console.print("Set $EDITOR environment variable to your preferred editor.")
print_error(f"Editor [cyan]{editor_cmd[0]}[/] not found")
console.print("Set [bold]$EDITOR[/] environment variable to your preferred editor.")
raise typer.Exit(1) from None
except subprocess.CalledProcessError as e:
err_console.print(f"[red]Editor exited with error code {e.returncode}[/red]")
print_error(f"Editor exited with error code {e.returncode}")
raise typer.Exit(e.returncode) from None
@@ -169,18 +180,11 @@ def config_show(
config_file = _get_config_file(path)
if config_file is None:
console.print("[yellow]No config file found.[/yellow]")
console.print("\nSearched locations:")
for p in config_search_paths():
status = "[green]exists[/green]" if p.exists() else "[dim]not found[/dim]"
console.print(f" - {p} ({status})")
console.print("\nRun [bold cyan]cf config init[/bold cyan] to create one.")
_report_no_config_found()
raise typer.Exit(0)
if not config_file.exists():
console.print("[yellow]Config file not found.[/yellow]")
console.print(f"\nProvided path does not exist: [cyan]{config_file}[/cyan]")
console.print("\nRun [bold cyan]cf config init[/bold cyan] to create one.")
_report_config_path_not_exists(config_file)
raise typer.Exit(1)
content = config_file.read_text(encoding="utf-8")
@@ -207,11 +211,7 @@ def config_path(
config_file = _get_config_file(path)
if config_file is None:
console.print("[yellow]No config file found.[/yellow]")
console.print("\nSearched locations:")
for p in config_search_paths():
status = "[green]exists[/green]" if p.exists() else "[dim]not found[/dim]"
console.print(f" - {p} ({status})")
_report_no_config_found()
raise typer.Exit(1)
# Just print the path for easy piping
@@ -226,7 +226,7 @@ def config_validate(
config_file = _get_config_file(path)
if config_file is None:
err_console.print("[red]✗[/] No config file found")
print_error(MSG_CONFIG_NOT_FOUND)
raise typer.Exit(1)
# Lazy import: pydantic adds ~50ms to startup, only load when actually needed
@@ -235,13 +235,13 @@ def config_validate(
try:
cfg = load_config(config_file)
except FileNotFoundError as e:
err_console.print(f"[red]✗[/] {e}")
print_error(str(e))
raise typer.Exit(1) from e
except Exception as e:
err_console.print(f"[red]✗[/] Invalid config: {e}")
print_error(f"Invalid config: {e}")
raise typer.Exit(1) from e
console.print(f"[green]✓[/] Valid config: {config_file}")
print_success(f"Valid config: {config_file}")
console.print(f" Hosts: {len(cfg.hosts)}")
console.print(f" Services: {len(cfg.services)}")
@@ -268,11 +268,11 @@ def config_symlink(
target_path = (target or Path("compose-farm.yaml")).expanduser().resolve()
if not target_path.exists():
err_console.print(f"[red]✗[/] Target config file not found: {target_path}")
print_error(f"Target config file not found: {target_path}")
raise typer.Exit(1)
if not target_path.is_file():
err_console.print(f"[red]✗[/] Target is not a file: {target_path}")
print_error(f"Target is not a file: {target_path}")
raise typer.Exit(1)
symlink_path = default_config_path()
@@ -282,7 +282,7 @@ def config_symlink(
if symlink_path.is_symlink():
current_target = symlink_path.resolve() if symlink_path.exists() else None
if current_target == target_path:
console.print(f"[green]✓[/] Symlink already points to: {target_path}")
print_success(f"Symlink already points to: {target_path}")
return
# Update existing symlink
if not force:
@@ -294,8 +294,8 @@ def config_symlink(
symlink_path.unlink()
else:
# Regular file exists
err_console.print(f"[red]✗[/] A regular file exists at: {symlink_path}")
err_console.print(" Back it up or remove it first, then retry.")
print_error(f"A regular file exists at: {symlink_path}")
console.print(" Back it up or remove it first, then retry.")
raise typer.Exit(1)
# Create parent directories
@@ -304,7 +304,7 @@ def config_symlink(
# Create symlink with absolute path
symlink_path.symlink_to(target_path)
console.print("[green]✓[/] Created symlink:")
print_success("Created symlink:")
console.print(f" {symlink_path}")
console.print(f" -> {target_path}")

View File

@@ -15,24 +15,22 @@ from compose_farm.cli.common import (
ConfigOption,
HostOption,
ServicesArg,
format_host,
get_services,
load_config_or_exit,
maybe_regenerate_traefik,
report_results,
run_async,
run_host_operation,
)
from compose_farm.console import console, err_console
from compose_farm.console import MSG_DRY_RUN, console, print_error, print_success
from compose_farm.executor import run_on_services, run_sequential_on_services
from compose_farm.operations import stop_orphaned_services, up_services
from compose_farm.state import (
add_service_to_host,
get_orphaned_services,
get_service_host,
get_services_needing_migration,
get_services_not_in_state,
remove_service,
remove_service_from_host,
)
@@ -44,14 +42,7 @@ def up(
config: ConfigOption = None,
) -> None:
"""Start services (docker compose up -d). Auto-migrates if host changed."""
svc_list, cfg = get_services(services or [], all_services, config)
# Per-host operation: run on specific host only
if host:
run_host_operation(cfg, svc_list, host, "up -d", "Starting", add_service_to_host)
return
# Normal operation: use up_services with migration logic
svc_list, cfg = get_services(services or [], all_services, config, host=host)
results = run_async(up_services(cfg, svc_list, raw=True))
maybe_regenerate_traefik(cfg, results)
report_results(results)
@@ -71,17 +62,19 @@ def down(
config: ConfigOption = None,
) -> None:
"""Stop services (docker compose down)."""
# Handle --orphaned flag
# Handle --orphaned flag (mutually exclusive with other selection methods)
if orphaned:
if services or all_services or host:
err_console.print("[red]✗[/] Cannot use --orphaned with services, --all, or --host")
print_error(
"Cannot combine [bold]--orphaned[/] with services, [bold]--all[/], or [bold]--host[/]"
)
raise typer.Exit(1)
cfg = load_config_or_exit(config)
orphaned_services = get_orphaned_services(cfg)
if not orphaned_services:
console.print("[green]✓[/] No orphaned services to stop")
print_success("No orphaned services to stop")
return
console.print(
@@ -92,14 +85,7 @@ def down(
report_results(results)
return
svc_list, cfg = get_services(services or [], all_services, config)
# Per-host operation: run on specific host only
if host:
run_host_operation(cfg, svc_list, host, "down", "Stopping", remove_service_from_host)
return
# Normal operation
svc_list, cfg = get_services(services or [], all_services, config, host=host)
raw = len(svc_list) == 1
results = run_async(run_on_services(cfg, svc_list, "down", raw=raw))
@@ -162,13 +148,6 @@ def update(
report_results(results)
def _format_host(host: str | list[str]) -> str:
"""Format a host value for display."""
if isinstance(host, list):
return ", ".join(host)
return host
def _report_pending_migrations(cfg: Config, migrations: list[str]) -> None:
"""Report services that need migration."""
console.print(f"[cyan]Services to migrate ({len(migrations)}):[/]")
@@ -182,14 +161,14 @@ def _report_pending_orphans(orphaned: dict[str, str | list[str]]) -> None:
"""Report orphaned services that will be stopped."""
console.print(f"[yellow]Orphaned services to stop ({len(orphaned)}):[/]")
for svc, hosts in orphaned.items():
console.print(f" [cyan]{svc}[/] on [magenta]{_format_host(hosts)}[/]")
console.print(f" [cyan]{svc}[/] on [magenta]{format_host(hosts)}[/]")
def _report_pending_starts(cfg: Config, missing: list[str]) -> None:
"""Report services that will be started."""
console.print(f"[green]Services to start ({len(missing)}):[/]")
for svc in missing:
target = _format_host(cfg.get_hosts(svc))
target = format_host(cfg.get_hosts(svc))
console.print(f" [cyan]{svc}[/] on [magenta]{target}[/]")
@@ -197,7 +176,7 @@ def _report_pending_refresh(cfg: Config, to_refresh: list[str]) -> None:
"""Report services that will be refreshed."""
console.print(f"[blue]Services to refresh ({len(to_refresh)}):[/]")
for svc in to_refresh:
target = _format_host(cfg.get_hosts(svc))
target = format_host(cfg.get_hosts(svc))
console.print(f" [cyan]{svc}[/] on [magenta]{target}[/]")
@@ -245,7 +224,7 @@ def apply(
has_refresh = bool(to_refresh)
if not has_orphans and not has_migrations and not has_missing and not has_refresh:
console.print("[green]✓[/] Nothing to apply - reality matches config")
print_success("Nothing to apply - reality matches config")
return
# Report what will be done
@@ -259,7 +238,7 @@ def apply(
_report_pending_refresh(cfg, to_refresh)
if dry_run:
console.print("\n[dim](dry-run: no changes made)[/]")
console.print(f"\n{MSG_DRY_RUN}")
return
# Execute changes

View File

@@ -8,7 +8,6 @@ from pathlib import Path # noqa: TC003
from typing import TYPE_CHECKING, Annotated
import typer
from rich.progress import Progress, TaskID # noqa: TC002
from compose_farm.cli.app import app
from compose_farm.cli.common import (
@@ -17,16 +16,25 @@ from compose_farm.cli.common import (
ConfigOption,
LogPathOption,
ServicesArg,
format_host,
get_services,
load_config_or_exit,
progress_bar,
run_async,
run_parallel_with_progress,
validate_hosts,
validate_services,
)
if TYPE_CHECKING:
from compose_farm.config import Config
from compose_farm.console import console, err_console
from compose_farm.console import (
MSG_DRY_RUN,
console,
print_error,
print_success,
print_warning,
)
from compose_farm.executor import (
CommandResult,
is_local,
@@ -54,21 +62,12 @@ from compose_farm.traefik import generate_traefik_config, render_traefik_config
def _discover_services(cfg: Config) -> dict[str, str | list[str]]:
"""Discover running services with a progress bar."""
async def gather_with_progress(
progress: Progress, task_id: TaskID
) -> dict[str, str | list[str]]:
tasks = [asyncio.create_task(discover_service_host(cfg, s)) for s in cfg.services]
discovered: dict[str, str | list[str]] = {}
for coro in asyncio.as_completed(tasks):
service, host = await coro
if host is not None:
discovered[service] = host
progress.update(task_id, advance=1, description=f"[cyan]{service}[/]")
return discovered
with progress_bar("Discovering", len(cfg.services)) as (progress, task_id):
return asyncio.run(gather_with_progress(progress, task_id))
results = run_parallel_with_progress(
"Discovering",
list(cfg.services),
lambda s: discover_service_host(cfg, s),
)
return {svc: host for svc, host in results if host is not None}
def _snapshot_services(
@@ -77,36 +76,22 @@ def _snapshot_services(
log_path: Path | None,
) -> Path:
"""Capture image digests with a progress bar."""
async def collect_service(service: str, now: datetime) -> list[SnapshotEntry]:
try:
return await collect_service_entries(cfg, service, now=now)
except RuntimeError:
return []
async def gather_with_progress(
progress: Progress, task_id: TaskID, now: datetime, svc_list: list[str]
) -> list[SnapshotEntry]:
# Map tasks to service names so we can update description
task_to_service = {asyncio.create_task(collect_service(s, now)): s for s in svc_list}
all_entries: list[SnapshotEntry] = []
for coro in asyncio.as_completed(list(task_to_service.keys())):
entries = await coro
all_entries.extend(entries)
# Find which service just completed (by checking done tasks)
for t, svc in task_to_service.items():
if t.done() and not hasattr(t, "_reported"):
t._reported = True # type: ignore[attr-defined]
progress.update(task_id, advance=1, description=f"[cyan]{svc}[/]")
break
return all_entries
effective_log_path = log_path or DEFAULT_LOG_PATH
now_dt = datetime.now(UTC)
now_iso = isoformat(now_dt)
with progress_bar("Capturing", len(services)) as (progress, task_id):
snapshot_entries = asyncio.run(gather_with_progress(progress, task_id, now_dt, services))
async def collect_service(service: str) -> tuple[str, list[SnapshotEntry]]:
try:
return service, await collect_service_entries(cfg, service, now=now_dt)
except RuntimeError:
return service, []
results = run_parallel_with_progress(
"Capturing",
services,
collect_service,
)
snapshot_entries = [entry for _, entries in results for entry in entries]
if not snapshot_entries:
msg = "No image digests were captured"
@@ -119,13 +104,6 @@ def _snapshot_services(
return effective_log_path
def _format_host(host: str | list[str]) -> str:
"""Format a host value for display."""
if isinstance(host, list):
return ", ".join(host)
return host
def _report_sync_changes(
added: list[str],
removed: list[str],
@@ -137,14 +115,14 @@ def _report_sync_changes(
if added:
console.print(f"\nNew services found ({len(added)}):")
for service in sorted(added):
host_str = _format_host(discovered[service])
host_str = format_host(discovered[service])
console.print(f" [green]+[/] [cyan]{service}[/] on [magenta]{host_str}[/]")
if changed:
console.print(f"\nServices on different hosts ({len(changed)}):")
for service, old_host, new_host in sorted(changed):
old_str = _format_host(old_host)
new_str = _format_host(new_host)
old_str = format_host(old_host)
new_str = format_host(new_host)
console.print(
f" [yellow]~[/] [cyan]{service}[/]: [magenta]{old_str}[/] → [magenta]{new_str}[/]"
)
@@ -152,7 +130,7 @@ def _report_sync_changes(
if removed:
console.print(f"\nServices no longer running ({len(removed)}):")
for service in sorted(removed):
host_str = _format_host(current_state[service])
host_str = format_host(current_state[service])
console.print(f" [red]-[/] [cyan]{service}[/] (was on [magenta]{host_str}[/])")
@@ -171,21 +149,21 @@ def _check_ssh_connectivity(cfg: Config) -> list[str]:
async def check_host(host_name: str) -> tuple[str, bool]:
host = cfg.hosts[host_name]
result = await run_command(host, "echo ok", host_name, stream=False)
return host_name, result.success
try:
result = await asyncio.wait_for(
run_command(host, "echo ok", host_name, stream=False),
timeout=5.0,
)
return host_name, result.success
except TimeoutError:
return host_name, False
async def gather_with_progress(progress: Progress, task_id: TaskID) -> list[str]:
tasks = [asyncio.create_task(check_host(h)) for h in remote_hosts]
unreachable: list[str] = []
for coro in asyncio.as_completed(tasks):
host_name, success = await coro
if not success:
unreachable.append(host_name)
progress.update(task_id, advance=1, description=f"[cyan]{host_name}[/]")
return unreachable
with progress_bar("Checking SSH connectivity", len(remote_hosts)) as (progress, task_id):
return asyncio.run(gather_with_progress(progress, task_id))
results = run_parallel_with_progress(
"Checking SSH connectivity",
remote_hosts,
check_host,
)
return [host for host, success in results if not success]
def _check_service_requirements(
@@ -222,27 +200,21 @@ def _check_service_requirements(
return service, mount_errors, network_errors, device_errors
async def gather_with_progress(
progress: Progress, task_id: TaskID
) -> tuple[list[tuple[str, str, str]], list[tuple[str, str, str]], list[tuple[str, str, str]]]:
tasks = [asyncio.create_task(check_service(s)) for s in services]
all_mount_errors: list[tuple[str, str, str]] = []
all_network_errors: list[tuple[str, str, str]] = []
all_device_errors: list[tuple[str, str, str]] = []
results = run_parallel_with_progress(
"Checking requirements",
services,
check_service,
)
for coro in asyncio.as_completed(tasks):
service, mount_errs, net_errs, dev_errs = await coro
all_mount_errors.extend(mount_errs)
all_network_errors.extend(net_errs)
all_device_errors.extend(dev_errs)
progress.update(task_id, advance=1, description=f"[cyan]{service}[/]")
all_mount_errors: list[tuple[str, str, str]] = []
all_network_errors: list[tuple[str, str, str]] = []
all_device_errors: list[tuple[str, str, str]] = []
for _, mount_errs, net_errs, dev_errs in results:
all_mount_errors.extend(mount_errs)
all_network_errors.extend(net_errs)
all_device_errors.extend(dev_errs)
return all_mount_errors, all_network_errors, all_device_errors
with progress_bar(
"Checking requirements", len(services), initial_description="[dim]checking...[/]"
) as (progress, task_id):
return asyncio.run(gather_with_progress(progress, task_id))
return all_mount_errors, all_network_errors, all_device_errors
def _report_config_status(cfg: Config) -> bool:
@@ -263,7 +235,7 @@ def _report_config_status(cfg: Config) -> bool:
console.print(f" [red]-[/] [cyan]{name}[/]")
if not unmanaged and not missing_from_disk:
console.print("[green]✓[/] Config matches disk")
print_success("Config matches disk")
return bool(missing_from_disk)
@@ -275,11 +247,10 @@ def _report_orphaned_services(cfg: Config) -> bool:
if orphaned:
console.print("\n[yellow]Orphaned services[/] (in state but not in config):")
console.print(
"[dim]Run 'cf apply' to stop them, or 'cf down --orphaned' for just orphans.[/]"
"[dim]Run [bold]cf apply[/bold] to stop them, or [bold]cf down --orphaned[/bold] for just orphans.[/]"
)
for name, hosts in sorted(orphaned.items()):
host_str = ", ".join(hosts) if isinstance(hosts, list) else hosts
console.print(f" [yellow]![/] [cyan]{name}[/] on [magenta]{host_str}[/]")
console.print(f" [yellow]![/] [cyan]{name}[/] on [magenta]{format_host(hosts)}[/]")
return True
return False
@@ -295,54 +266,24 @@ def _report_traefik_status(cfg: Config, services: list[str]) -> None:
if warnings:
console.print(f"\n[yellow]Traefik issues[/] ({len(warnings)}):")
for warning in warnings:
console.print(f" [yellow]![/] {warning}")
print_warning(warning)
else:
console.print("[green]✓[/] Traefik labels valid")
print_success("Traefik labels valid")
def _report_mount_errors(mount_errors: list[tuple[str, str, str]]) -> None:
"""Report mount errors grouped by service."""
def _report_requirement_errors(errors: list[tuple[str, str, str]], category: str) -> None:
"""Report requirement errors (mounts, networks, devices) grouped by service."""
by_service: dict[str, list[tuple[str, str]]] = {}
for svc, host, path in mount_errors:
by_service.setdefault(svc, []).append((host, path))
for svc, host, item in errors:
by_service.setdefault(svc, []).append((host, item))
console.print(f"[red]Missing mounts[/] ({len(mount_errors)}):")
console.print(f"[red]Missing {category}[/] ({len(errors)}):")
for svc, items in sorted(by_service.items()):
host = items[0][0]
paths = [p for _, p in items]
missing = [i for _, i in items]
console.print(f" [cyan]{svc}[/] on [magenta]{host}[/]:")
for path in paths:
console.print(f" [red]✗[/] {path}")
def _report_network_errors(network_errors: list[tuple[str, str, str]]) -> None:
"""Report network errors grouped by service."""
by_service: dict[str, list[tuple[str, str]]] = {}
for svc, host, net in network_errors:
by_service.setdefault(svc, []).append((host, net))
console.print(f"[red]Missing networks[/] ({len(network_errors)}):")
for svc, items in sorted(by_service.items()):
host = items[0][0]
networks = [n for _, n in items]
console.print(f" [cyan]{svc}[/] on [magenta]{host}[/]:")
for net in networks:
console.print(f" [red]✗[/] {net}")
def _report_device_errors(device_errors: list[tuple[str, str, str]]) -> None:
"""Report device errors grouped by service."""
by_service: dict[str, list[tuple[str, str]]] = {}
for svc, host, dev in device_errors:
by_service.setdefault(svc, []).append((host, dev))
console.print(f"[red]Missing devices[/] ({len(device_errors)}):")
for svc, items in sorted(by_service.items()):
host = items[0][0]
devices = [d for _, d in items]
console.print(f" [cyan]{svc}[/] on [magenta]{host}[/]:")
for dev in devices:
console.print(f" [red]✗[/] {dev}")
for item in missing:
console.print(f" [red]✗[/] {item}")
def _report_ssh_status(unreachable_hosts: list[str]) -> bool:
@@ -350,9 +291,9 @@ def _report_ssh_status(unreachable_hosts: list[str]) -> bool:
if unreachable_hosts:
console.print(f"[red]Unreachable hosts[/] ({len(unreachable_hosts)}):")
for host in sorted(unreachable_hosts):
console.print(f" [red]✗[/] [magenta]{host}[/]")
print_error(f"[magenta]{host}[/]")
return True
console.print("[green]✓[/] All hosts reachable")
print_success("All hosts reachable")
return False
@@ -394,16 +335,16 @@ def _run_remote_checks(cfg: Config, svc_list: list[str], *, show_host_compat: bo
mount_errors, network_errors, device_errors = _check_service_requirements(cfg, svc_list)
if mount_errors:
_report_mount_errors(mount_errors)
_report_requirement_errors(mount_errors, "mounts")
has_errors = True
if network_errors:
_report_network_errors(network_errors)
_report_requirement_errors(network_errors, "networks")
has_errors = True
if device_errors:
_report_device_errors(device_errors)
_report_requirement_errors(device_errors, "devices")
has_errors = True
if not mount_errors and not network_errors and not device_errors:
console.print("[green]✓[/] All mounts, networks, and devices exist")
print_success("All mounts, networks, and devices exist")
if show_host_compat:
for service in svc_list:
@@ -440,7 +381,7 @@ def traefik_file(
try:
dynamic, warnings = generate_traefik_config(cfg, svc_list)
except (FileNotFoundError, ValueError) as exc:
err_console.print(f"[red]✗[/] {exc}")
print_error(str(exc))
raise typer.Exit(1) from exc
rendered = render_traefik_config(dynamic)
@@ -448,12 +389,12 @@ def traefik_file(
if output:
output.parent.mkdir(parents=True, exist_ok=True)
output.write_text(rendered)
console.print(f"[green]✓[/] Traefik config written to {output}")
print_success(f"Traefik config written to {output}")
else:
console.print(rendered)
for warning in warnings:
err_console.print(f"[yellow]![/] {warning}")
print_warning(warning)
@app.command(rich_help_panel="Configuration")
@@ -492,24 +433,24 @@ def refresh(
if state_changed:
_report_sync_changes(added, removed, changed, discovered, current_state)
else:
console.print("[green]✓[/] State is already in sync.")
print_success("State is already in sync.")
if dry_run:
console.print("\n[dim](dry-run: no changes made)[/]")
console.print(f"\n{MSG_DRY_RUN}")
return
# Update state file
if state_changed:
save_state(cfg, discovered)
console.print(f"\n[green]✓[/] State updated: {len(discovered)} services tracked.")
print_success(f"State updated: {len(discovered)} services tracked.")
# Capture image digests for running services
if discovered:
try:
path = _snapshot_services(cfg, list(discovered.keys()), log_path)
console.print(f"[green]✓[/] Digests written to {path}")
print_success(f"Digests written to {path}")
except RuntimeError as exc:
err_console.print(f"[yellow]![/] {exc}")
print_warning(str(exc))
@app.command(rich_help_panel="Configuration")
@@ -533,11 +474,7 @@ def check(
# Determine which services to check and whether to show host compatibility
if services:
svc_list = list(services)
invalid = [s for s in svc_list if s not in cfg.services]
if invalid:
for svc in invalid:
err_console.print(f"[red]✗[/] Service '{svc}' not found in config")
raise typer.Exit(1)
validate_services(cfg, svc_list)
show_host_compat = True
else:
svc_list = list(cfg.services.keys())
@@ -587,11 +524,7 @@ def init_network(
cfg = load_config_or_exit(config)
target_hosts = list(hosts) if hosts else list(cfg.hosts.keys())
invalid = [h for h in target_hosts if h not in cfg.hosts]
if invalid:
for h in invalid:
err_console.print(f"[red]✗[/] Host '{h}' not found in config")
raise typer.Exit(1)
validate_hosts(cfg, target_hosts)
async def create_network_on_host(host_name: str) -> CommandResult:
host = cfg.hosts[host_name]
@@ -616,9 +549,8 @@ def init_network(
if result.success:
console.print(f"[cyan]\\[{host_name}][/] [green]✓[/] Created network '{network}'")
else:
err_console.print(
f"[cyan]\\[{host_name}][/] [red]✗[/] Failed to create network: "
f"{result.stderr.strip()}"
print_error(
f"[cyan]\\[{host_name}][/] Failed to create network: {result.stderr.strip()}"
)
return result

View File

@@ -2,12 +2,10 @@
from __future__ import annotations
import asyncio
import contextlib
from typing import TYPE_CHECKING, Annotated
import typer
from rich.progress import Progress, TaskID # noqa: TC002
from rich.table import Table
from compose_farm.cli.app import app
@@ -19,47 +17,18 @@ from compose_farm.cli.common import (
ServicesArg,
get_services,
load_config_or_exit,
progress_bar,
report_results,
run_async,
run_parallel_with_progress,
)
from compose_farm.console import console, err_console
from compose_farm.console import console
from compose_farm.executor import run_command, run_on_services
from compose_farm.state import get_services_needing_migration, load_state
from compose_farm.state import get_services_needing_migration, group_services_by_host, load_state
if TYPE_CHECKING:
from collections.abc import Mapping
from compose_farm.config import Config
def _group_services_by_host(
services: dict[str, str | list[str]],
hosts: Mapping[str, object],
all_hosts: list[str] | None = None,
) -> dict[str, list[str]]:
"""Group services by their assigned host(s).
For multi-host services (list or "all"), the service appears in multiple host lists.
"""
by_host: dict[str, list[str]] = {h: [] for h in hosts}
for service, host_value in services.items():
if isinstance(host_value, list):
# Explicit list of hosts
for host_name in host_value:
if host_name in by_host:
by_host[host_name].append(service)
elif host_value == "all" and all_hosts:
# "all" keyword - add to all hosts
for host_name in all_hosts:
if host_name in by_host:
by_host[host_name].append(service)
elif host_value in by_host:
# Single host
by_host[host_value].append(service)
return by_host
def _get_container_counts(cfg: Config) -> dict[str, int]:
"""Get container counts from all hosts with a progress bar."""
@@ -72,18 +41,12 @@ def _get_container_counts(cfg: Config) -> dict[str, int]:
count = int(result.stdout.strip())
return host_name, count
async def gather_with_progress(progress: Progress, task_id: TaskID) -> dict[str, int]:
hosts = list(cfg.hosts.keys())
tasks = [asyncio.create_task(get_count(h)) for h in hosts]
results: dict[str, int] = {}
for coro in asyncio.as_completed(tasks):
host_name, count = await coro
results[host_name] = count
progress.update(task_id, advance=1, description=f"[cyan]{host_name}[/]")
return results
with progress_bar("Querying hosts", len(cfg.hosts)) as (progress, task_id):
return asyncio.run(gather_with_progress(progress, task_id))
results = run_parallel_with_progress(
"Querying hosts",
list(cfg.hosts.keys()),
get_count,
)
return dict(results)
def _build_host_table(
@@ -163,24 +126,7 @@ def logs(
config: ConfigOption = None,
) -> None:
"""Show service logs."""
if all_services and host is not None:
err_console.print("[red]✗[/] Cannot use --all and --host together")
raise typer.Exit(1)
cfg = load_config_or_exit(config)
# Determine service list based on options
if host is not None:
if host not in cfg.hosts:
err_console.print(f"[red]✗[/] Host '{host}' not found in config")
raise typer.Exit(1)
# Include services where host is in the list of configured hosts
svc_list = [s for s in cfg.services if host in cfg.get_hosts(s)]
if not svc_list:
err_console.print(f"[yellow]![/] No services configured for host '{host}'")
return
else:
svc_list, cfg = get_services(services or [], all_services, config)
svc_list, cfg = get_services(services or [], all_services, config, host=host)
# Default to fewer lines when showing multiple services
many_services = all_services or host is not None or len(svc_list) > 1
@@ -194,11 +140,19 @@ def logs(
@app.command(rich_help_panel="Monitoring")
def ps(
services: ServicesArg = None,
all_services: AllOption = False,
host: HostOption = None,
config: ConfigOption = None,
) -> None:
"""Show status of all services."""
cfg = load_config_or_exit(config)
results = run_async(run_on_services(cfg, list(cfg.services.keys()), "ps"))
"""Show status of services.
Without arguments: shows all services (same as --all).
With service names: shows only those services.
With --host: shows services on that host.
"""
svc_list, cfg = get_services(services or [], all_services, config, host=host, default_all=True)
results = run_async(run_on_services(cfg, svc_list, "ps"))
report_results(results)
@@ -220,8 +174,8 @@ def stats(
pending = get_services_needing_migration(cfg)
all_hosts = list(cfg.hosts.keys())
services_by_host = _group_services_by_host(cfg.services, cfg.hosts, all_hosts)
running_by_host = _group_services_by_host(state, cfg.hosts, all_hosts)
services_by_host = group_services_by_host(cfg.services, cfg.hosts, all_hosts)
running_by_host = group_services_by_host(state, cfg.hosts, all_hosts)
container_counts: dict[str, int] = {}
if live:

282
src/compose_farm/cli/ssh.py Normal file
View File

@@ -0,0 +1,282 @@
"""SSH key management commands for compose-farm."""
from __future__ import annotations
import asyncio
import subprocess
from typing import TYPE_CHECKING, Annotated
import typer
from compose_farm.cli.app import app
from compose_farm.cli.common import ConfigOption, load_config_or_exit, run_parallel_with_progress
from compose_farm.console import console, err_console
from compose_farm.executor import run_command
if TYPE_CHECKING:
from compose_farm.config import Host
from compose_farm.ssh_keys import (
SSH_KEY_PATH,
SSH_PUBKEY_PATH,
get_pubkey_content,
get_ssh_env,
key_exists,
)
_DEFAULT_SSH_PORT = 22
_PUBKEY_DISPLAY_THRESHOLD = 60
ssh_app = typer.Typer(
name="ssh",
help="Manage SSH keys for passwordless authentication.",
no_args_is_help=True,
)
_ForceOption = Annotated[
bool,
typer.Option("--force", "-f", help="Regenerate key even if it exists."),
]
def _generate_key(*, force: bool = False) -> bool:
"""Generate an ED25519 SSH key with no passphrase.
Returns True if key was generated, False if skipped.
"""
if key_exists() and not force:
console.print(f"[yellow]![/] SSH key already exists: {SSH_KEY_PATH}")
console.print("[dim]Use --force to regenerate[/]")
return False
# Create .ssh directory if it doesn't exist
SSH_KEY_PATH.parent.mkdir(parents=True, exist_ok=True, mode=0o700)
# Remove existing key if forcing regeneration
if force:
SSH_KEY_PATH.unlink(missing_ok=True)
SSH_PUBKEY_PATH.unlink(missing_ok=True)
console.print(f"[dim]Generating SSH key at {SSH_KEY_PATH}...[/]")
try:
subprocess.run(
[ # noqa: S607
"ssh-keygen",
"-t",
"ed25519",
"-N",
"", # No passphrase
"-f",
str(SSH_KEY_PATH),
"-C",
"compose-farm",
],
check=True,
capture_output=True,
)
except subprocess.CalledProcessError as e:
err_console.print(f"[red]Failed to generate SSH key:[/] {e.stderr.decode()}")
return False
except FileNotFoundError:
err_console.print("[red]ssh-keygen not found. Is OpenSSH installed?[/]")
return False
# Set correct permissions
SSH_KEY_PATH.chmod(0o600)
SSH_PUBKEY_PATH.chmod(0o644)
console.print(f"[green]Generated SSH key:[/] {SSH_KEY_PATH}")
return True
def _copy_key_to_host(host_name: str, address: str, user: str, port: int) -> bool:
"""Copy public key to a host's authorized_keys.
Uses ssh-copy-id which handles agent vs password fallback automatically.
Returns True on success, False on failure.
"""
target = f"{user}@{address}"
console.print(f"[dim]Copying key to {host_name} ({target})...[/]")
cmd = ["ssh-copy-id"]
# Disable strict host key checking (consistent with executor.py)
cmd.extend(["-o", "StrictHostKeyChecking=no"])
cmd.extend(["-o", "UserKnownHostsFile=/dev/null"])
if port != _DEFAULT_SSH_PORT:
cmd.extend(["-p", str(port)])
cmd.extend(["-i", str(SSH_PUBKEY_PATH), target])
try:
# Don't capture output so user can see password prompt
result = subprocess.run(cmd, check=False, env=get_ssh_env())
if result.returncode == 0:
console.print(f"[green]Key copied to {host_name}[/]")
return True
err_console.print(f"[red]Failed to copy key to {host_name}[/]")
return False
except FileNotFoundError:
err_console.print("[red]ssh-copy-id not found. Is OpenSSH installed?[/]")
return False
@ssh_app.command("keygen")
def ssh_keygen(
force: _ForceOption = False,
) -> None:
"""Generate SSH key (does not distribute to hosts).
Creates an ED25519 key at ~/.ssh/compose-farm/id_ed25519 with no passphrase.
Use 'cf ssh setup' to also distribute the key to all configured hosts.
"""
success = _generate_key(force=force)
if not success and not key_exists():
raise typer.Exit(1)
@ssh_app.command("setup")
def ssh_setup(
config: ConfigOption = None,
force: _ForceOption = False,
) -> None:
"""Generate SSH key and distribute to all configured hosts.
Creates an ED25519 key at ~/.ssh/compose-farm/id_ed25519 (no passphrase)
and copies the public key to authorized_keys on each host.
For each host, tries SSH agent first. If agent is unavailable,
prompts for password.
"""
cfg = load_config_or_exit(config)
# Skip localhost hosts
remote_hosts = {
name: host
for name, host in cfg.hosts.items()
if host.address.lower() not in ("localhost", "127.0.0.1")
}
if not remote_hosts:
console.print("[yellow]No remote hosts configured.[/]")
raise typer.Exit(0)
# Generate key if needed
if not key_exists() or force:
if not _generate_key(force=force):
raise typer.Exit(1)
else:
console.print(f"[dim]Using existing key: {SSH_KEY_PATH}[/]")
console.print()
console.print(f"[bold]Distributing key to {len(remote_hosts)} host(s)...[/]")
console.print()
# Copy key to each host
succeeded = 0
failed = 0
for host_name, host in remote_hosts.items():
if _copy_key_to_host(host_name, host.address, host.user, host.port):
succeeded += 1
else:
failed += 1
console.print()
if failed == 0:
console.print(
f"[green]Setup complete.[/] {succeeded}/{len(remote_hosts)} hosts configured."
)
else:
console.print(
f"[yellow]Setup partially complete.[/] {succeeded}/{len(remote_hosts)} hosts configured, "
f"[red]{failed} failed[/]."
)
raise typer.Exit(1)
@ssh_app.command("status")
def ssh_status(
config: ConfigOption = None,
) -> None:
"""Show SSH key status and host connectivity."""
from rich.table import Table # noqa: PLC0415
cfg = load_config_or_exit(config)
# Key status
console.print("[bold]SSH Key Status[/]")
console.print()
if key_exists():
console.print(f" [green]Key exists:[/] {SSH_KEY_PATH}")
pubkey = get_pubkey_content()
if pubkey:
# Show truncated public key
if len(pubkey) > _PUBKEY_DISPLAY_THRESHOLD:
console.print(f" [dim]Public key:[/] {pubkey[:30]}...{pubkey[-20:]}")
else:
console.print(f" [dim]Public key:[/] {pubkey}")
else:
console.print(f" [yellow]No key found:[/] {SSH_KEY_PATH}")
console.print(" [dim]Run 'cf ssh setup' to generate and distribute a key[/]")
console.print()
console.print("[bold]Host Connectivity[/]")
console.print()
# Skip localhost hosts
remote_hosts = {
name: host
for name, host in cfg.hosts.items()
if host.address.lower() not in ("localhost", "127.0.0.1")
}
if not remote_hosts:
console.print(" [dim]No remote hosts configured[/]")
return
async def check_host(item: tuple[str, Host]) -> tuple[str, str, str]:
"""Check connectivity to a single host."""
host_name, host = item
target = f"{host.user}@{host.address}"
if host.port != _DEFAULT_SSH_PORT:
target += f":{host.port}"
try:
result = await asyncio.wait_for(
run_command(host, "echo ok", host_name, stream=False),
timeout=5.0,
)
status = "[green]OK[/]" if result.success else "[red]Auth failed[/]"
except TimeoutError:
status = "[red]Timeout (5s)[/]"
except Exception as e:
status = f"[red]Error: {e}[/]"
return host_name, target, status
# Check connectivity in parallel with progress bar
results = run_parallel_with_progress(
"Checking hosts",
list(remote_hosts.items()),
check_host,
)
# Build table from results
table = Table(show_header=True, header_style="bold")
table.add_column("Host")
table.add_column("Address")
table.add_column("Status")
# Sort by host name for consistent order
for host_name, target, status in sorted(results, key=lambda r: r[0]):
table.add_row(host_name, target, status)
console.print(table)
# Register ssh subcommand on the shared app
app.add_typer(ssh_app, name="ssh", rich_help_panel="Configuration")

View File

@@ -7,14 +7,14 @@ from __future__ import annotations
import os
import re
import stat
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any
import yaml
if TYPE_CHECKING:
from pathlib import Path
from .config import Config
# Port parsing constants
@@ -141,23 +141,42 @@ def _resolve_host_path(host_path: str, compose_dir: Path) -> str | None:
return None # Named volume
def _is_socket(path: str) -> bool:
"""Check if a path is a socket (e.g., SSH agent socket)."""
try:
return stat.S_ISSOCK(Path(path).stat().st_mode)
except (FileNotFoundError, PermissionError, OSError):
return False
def _parse_volume_item(
item: str | dict[str, Any],
env: dict[str, str],
compose_dir: Path,
) -> str | None:
"""Parse a single volume item and return host path if it's a bind mount."""
"""Parse a single volume item and return host path if it's a bind mount.
Skips socket paths (e.g., SSH_AUTH_SOCK) since they're machine-local
and shouldn't be validated on remote hosts.
"""
host_path: str | None = None
if isinstance(item, str):
interpolated = _interpolate(item, env)
parts = interpolated.split(":")
if len(parts) >= _MIN_VOLUME_PARTS:
return _resolve_host_path(parts[0], compose_dir)
host_path = _resolve_host_path(parts[0], compose_dir)
elif isinstance(item, dict) and item.get("type") == "bind":
source = item.get("source")
if source:
interpolated = _interpolate(str(source), env)
return _resolve_host_path(interpolated, compose_dir)
return None
host_path = _resolve_host_path(interpolated, compose_dir)
# Skip sockets - they're machine-local (e.g., SSH agent)
if host_path and _is_socket(host_path):
return None
return host_path
def parse_host_volumes(config: Config, service: str) -> list[str]:

View File

@@ -4,3 +4,35 @@ from rich.console import Console
console = Console(highlight=False)
err_console = Console(stderr=True, highlight=False)
# --- Message Constants ---
# Standardized message templates for consistent user-facing output
MSG_SERVICE_NOT_FOUND = "Service [cyan]{name}[/] not found in config"
MSG_HOST_NOT_FOUND = "Host [magenta]{name}[/] not found in config"
MSG_CONFIG_NOT_FOUND = "Config file not found"
MSG_DRY_RUN = "[dim](dry-run: no changes made)[/]"
# --- Message Helper Functions ---
def print_error(msg: str) -> None:
"""Print error message with ✗ prefix to stderr."""
err_console.print(f"[red]✗[/] {msg}")
def print_success(msg: str) -> None:
"""Print success message with ✓ prefix to stdout."""
console.print(f"[green]✓[/] {msg}")
def print_warning(msg: str) -> None:
"""Print warning message with ! prefix to stderr."""
err_console.print(f"[yellow]![/] {msg}")
def print_hint(msg: str) -> None:
"""Print hint message in dim style to stdout."""
console.print(f"[dim]Hint: {msg}[/]")

View File

@@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, Any
from rich.markup import escape
from .console import console, err_console
from .ssh_keys import get_key_path, get_ssh_auth_sock, get_ssh_env
if TYPE_CHECKING:
from collections.abc import Callable
@@ -22,6 +23,43 @@ LOCAL_ADDRESSES = frozenset({"local", "localhost", "127.0.0.1", "::1"})
_DEFAULT_SSH_PORT = 22
def build_ssh_command(host: Host, command: str, *, tty: bool = False) -> list[str]:
"""Build SSH command args for executing a command on a remote host.
Args:
host: Host configuration with address, port, user
command: Command to run on the remote host
tty: Whether to allocate a TTY (for interactive/progress bar commands)
Returns:
List of command args suitable for subprocess
"""
ssh_args = [
"ssh",
"-o",
"StrictHostKeyChecking=no",
"-o",
"UserKnownHostsFile=/dev/null",
"-o",
"LogLevel=ERROR",
]
if tty:
ssh_args.insert(1, "-tt") # Force TTY allocation
key_path = get_key_path()
if key_path:
ssh_args.extend(["-i", str(key_path)])
if host.port != _DEFAULT_SSH_PORT:
ssh_args.extend(["-p", str(host.port)])
ssh_args.append(f"{host.user}@{host.address}")
ssh_args.append(command)
return ssh_args
@lru_cache(maxsize=1)
def _get_local_ips() -> frozenset[str]:
"""Get all IP addresses of the current machine."""
@@ -71,6 +109,25 @@ def is_local(host: Host) -> bool:
return addr in _get_local_ips()
def ssh_connect_kwargs(host: Host) -> dict[str, Any]:
"""Get kwargs for asyncssh.connect() from a Host config."""
kwargs: dict[str, Any] = {
"host": host.address,
"port": host.port,
"username": host.user,
"known_hosts": None,
}
# Add SSH agent path (auto-detect forwarded agent if needed)
agent_path = get_ssh_auth_sock()
if agent_path:
kwargs["agent_path"] = agent_path
# Add key file fallback for when SSH agent is unavailable
key_path = get_key_path()
if key_path:
kwargs["client_keys"] = [str(key_path)]
return kwargs
async def _run_local_command(
command: str,
service: str,
@@ -152,12 +209,10 @@ async def _run_ssh_command(
"""Run a command on a remote host via SSH with streaming output."""
if raw:
# Use native ssh with TTY for proper progress bar rendering
ssh_args = ["ssh", "-t"]
if host.port != _DEFAULT_SSH_PORT:
ssh_args.extend(["-p", str(host.port)])
ssh_args.extend([f"{host.user}@{host.address}", command])
ssh_args = build_ssh_command(host, command, tty=True)
# Run in thread to avoid blocking the event loop
result = await asyncio.to_thread(subprocess.run, ssh_args, check=False)
# Use get_ssh_env() to auto-detect SSH agent socket
result = await asyncio.to_thread(subprocess.run, ssh_args, check=False, env=get_ssh_env())
return CommandResult(
service=service,
exit_code=result.returncode,
@@ -168,12 +223,7 @@ async def _run_ssh_command(
proc: asyncssh.SSHClientProcess[Any]
try:
async with asyncssh.connect( # noqa: SIM117 - conn needed before create_process
host.address,
port=host.port,
username=host.user,
known_hosts=None,
) as conn:
async with asyncssh.connect(**ssh_connect_kwargs(host)) as conn: # noqa: SIM117
async with conn.create_process(command) as proc:
if stream:

View File

@@ -10,7 +10,7 @@ import asyncio
from typing import TYPE_CHECKING, NamedTuple
from .compose import parse_devices, parse_external_networks, parse_host_volumes
from .console import console, err_console
from .console import console, err_console, print_error, print_success, print_warning
from .executor import (
CommandResult,
check_networks_exist,
@@ -145,9 +145,7 @@ async def _cleanup_and_rollback(
raw: bool = False,
) -> None:
"""Clean up failed start and attempt rollback to old host if it was running."""
err_console.print(
f"{prefix} [yellow]![/] Cleaning up failed start on [magenta]{target_host}[/]"
)
print_warning(f"{prefix} Cleaning up failed start on [magenta]{target_host}[/]")
await run_compose(cfg, service, "down", raw=raw)
if not was_running:
@@ -156,12 +154,12 @@ async def _cleanup_and_rollback(
)
return
err_console.print(f"{prefix} [yellow]![/] Rolling back to [magenta]{current_host}[/]...")
print_warning(f"{prefix} Rolling back to [magenta]{current_host}[/]...")
rollback_result = await run_compose_on_host(cfg, service, current_host, "up -d", raw=raw)
if rollback_result.success:
console.print(f"{prefix} [green]✓[/] Rollback succeeded on [magenta]{current_host}[/]")
print_success(f"{prefix} Rollback succeeded on [magenta]{current_host}[/]")
else:
err_console.print(f"{prefix} [red]✗[/] Rollback failed - service is down")
print_error(f"{prefix} Rollback failed - service is down")
def _report_preflight_failures(
@@ -170,17 +168,15 @@ def _report_preflight_failures(
preflight: PreflightResult,
) -> None:
"""Report pre-flight check failures."""
err_console.print(
f"[cyan]\\[{service}][/] [red]✗[/] Cannot start on [magenta]{target_host}[/]:"
)
print_error(f"[cyan]\\[{service}][/] Cannot start on [magenta]{target_host}[/]:")
for path in preflight.missing_paths:
err_console.print(f" [red]✗[/] missing path: {path}")
print_error(f" missing path: {path}")
for net in preflight.missing_networks:
err_console.print(f" [red]✗[/] missing network: {net}")
print_error(f" missing network: {net}")
if preflight.missing_networks:
err_console.print(f" [dim]hint: cf init-network {target_host}[/]")
err_console.print(f" [dim]Hint: cf init-network {target_host}[/]")
for dev in preflight.missing_devices:
err_console.print(f" [red]✗[/] missing device: {dev}")
print_error(f" missing device: {dev}")
async def _up_multi_host_service(
@@ -252,8 +248,8 @@ async def _migrate_service(
for cmd, label in [("pull --ignore-buildable", "Pull"), ("build", "Build")]:
result = await _run_compose_step(cfg, service, cmd, raw=raw)
if not result.success:
err_console.print(
f"{prefix} [red]✗[/] {label} failed on [magenta]{target_host}[/], "
print_error(
f"{prefix} {label} failed on [magenta]{target_host}[/], "
"leaving service on current host"
)
return result
@@ -293,9 +289,8 @@ async def _up_single_service(
return failure
did_migration = True
else:
err_console.print(
f"{prefix} [yellow]![/] was on "
f"[magenta]{current_host}[/] (not in config), skipping down"
print_warning(
f"{prefix} was on [magenta]{current_host}[/] (not in config), skipping down"
)
# Start on target host
@@ -391,9 +386,7 @@ async def stop_orphaned_services(cfg: Config) -> list[CommandResult]:
for host in host_list:
# Skip hosts no longer in config
if host not in cfg.hosts:
console.print(
f" [yellow]![/] {service}@{host}: host no longer in config, skipping"
)
print_warning(f"{service}@{host}: host no longer in config, skipping")
results.append(
CommandResult(
service=f"{service}@{host}",
@@ -413,11 +406,11 @@ async def stop_orphaned_services(cfg: Config) -> list[CommandResult]:
result = await task
results.append(result)
if result.success:
console.print(f" [green]✓[/] {service}@{host}: stopped")
print_success(f"{service}@{host}: stopped")
else:
console.print(f" [red]✗[/] {service}@{host}: {result.stderr or 'failed'}")
print_error(f"{service}@{host}: {result.stderr or 'failed'}")
except Exception as e:
console.print(f" [red]✗[/] {service}@{host}: {e}")
print_error(f"{service}@{host}: {e}")
results.append(
CommandResult(
service=f"{service}@{host}",

View File

@@ -0,0 +1,67 @@
"""SSH key utilities for compose-farm."""
from __future__ import annotations
import os
from pathlib import Path
# Default key paths for compose-farm SSH key
# Keys are stored in a subdirectory for cleaner docker volume mounting
SSH_KEY_DIR = Path.home() / ".ssh" / "compose-farm"
SSH_KEY_PATH = SSH_KEY_DIR / "id_ed25519"
SSH_PUBKEY_PATH = SSH_KEY_PATH.with_suffix(".pub")
def get_ssh_auth_sock() -> str | None:
"""Get SSH_AUTH_SOCK, auto-detecting forwarded agent if needed.
Checks in order:
1. SSH_AUTH_SOCK environment variable (if socket exists)
2. Forwarded agent sockets in ~/.ssh/agent/ (most recent first)
Returns the socket path or None if no valid socket found.
"""
sock = os.environ.get("SSH_AUTH_SOCK")
if sock and Path(sock).is_socket():
return sock
# Try to find a forwarded SSH agent socket
agent_dir = Path.home() / ".ssh" / "agent"
if agent_dir.is_dir():
sockets = sorted(
agent_dir.glob("s.*.sshd.*"), key=lambda p: p.stat().st_mtime, reverse=True
)
for s in sockets:
if s.is_socket():
return str(s)
return None
def get_ssh_env() -> dict[str, str]:
"""Get environment dict for SSH subprocess with auto-detected agent.
Returns a copy of the current environment with SSH_AUTH_SOCK set
to the auto-detected agent socket (if found).
"""
env = os.environ.copy()
sock = get_ssh_auth_sock()
if sock:
env["SSH_AUTH_SOCK"] = sock
return env
def key_exists() -> bool:
"""Check if the compose-farm SSH key pair exists."""
return SSH_KEY_PATH.exists() and SSH_PUBKEY_PATH.exists()
def get_key_path() -> Path | None:
"""Get the SSH key path if it exists, None otherwise."""
return SSH_KEY_PATH if key_exists() else None
def get_pubkey_content() -> str | None:
"""Get the public key content if it exists, None otherwise."""
if not SSH_PUBKEY_PATH.exists():
return None
return SSH_PUBKEY_PATH.read_text().strip()

View File

@@ -8,11 +8,44 @@ from typing import TYPE_CHECKING, Any
import yaml
if TYPE_CHECKING:
from collections.abc import Generator
from collections.abc import Generator, Mapping
from .config import Config
def group_services_by_host(
services: dict[str, str | list[str]],
hosts: Mapping[str, object],
all_hosts: list[str] | None = None,
) -> dict[str, list[str]]:
"""Group services by their assigned host(s).
For multi-host services (list or "all"), the service appears in multiple host lists.
"""
by_host: dict[str, list[str]] = {h: [] for h in hosts}
for service, host_value in services.items():
if isinstance(host_value, list):
for host_name in host_value:
if host_name in by_host:
by_host[host_name].append(service)
elif host_value == "all" and all_hosts:
for host_name in all_hosts:
if host_name in by_host:
by_host[host_name].append(service)
elif host_value in by_host:
by_host[host_value].append(service)
return by_host
def group_running_services_by_host(
state: dict[str, str | list[str]],
hosts: Mapping[str, object],
) -> dict[str, list[str]]:
"""Group running services by host, filtering out hosts with no services."""
by_host = group_services_by_host(state, hosts)
return {h: svcs for h, svcs in by_host.items() if svcs}
def load_state(config: Config) -> dict[str, str | list[str]]:
"""Load the current deployment state.

View File

@@ -10,6 +10,7 @@ from pathlib import Path
from typing import TYPE_CHECKING
from fastapi.templating import Jinja2Templates
from pydantic import ValidationError
if TYPE_CHECKING:
from compose_farm.config import Config
@@ -30,3 +31,10 @@ def get_config() -> Config:
def get_templates() -> Jinja2Templates:
"""Get Jinja2 templates instance."""
return Jinja2Templates(directory=str(TEMPLATES_DIR))
def extract_config_error(exc: Exception) -> str:
"""Extract a user-friendly error message from a config exception."""
if isinstance(exc, ValidationError):
return "; ".join(err.get("msg", str(err)) for err in exc.errors())
return str(exc)

View File

@@ -2,19 +2,20 @@
from __future__ import annotations
import asyncio
import contextlib
import json
from typing import TYPE_CHECKING, Annotated, Any
import shlex
from datetime import UTC, datetime
from pathlib import Path
from typing import Annotated, Any
import asyncssh
import yaml
if TYPE_CHECKING:
from pathlib import Path
from fastapi import APIRouter, Body, HTTPException
from fastapi import APIRouter, Body, HTTPException, Query
from fastapi.responses import HTMLResponse
from compose_farm.executor import run_compose_on_host
from compose_farm.executor import is_local, run_compose_on_host, ssh_connect_kwargs
from compose_farm.paths import find_config_path
from compose_farm.state import load_state
from compose_farm.web.deps import get_config, get_templates
@@ -30,6 +31,51 @@ def _validate_yaml(content: str) -> None:
raise HTTPException(status_code=400, detail=f"Invalid YAML: {e}") from e
def _backup_file(file_path: Path) -> Path | None:
"""Create a timestamped backup of a file if it exists and content differs.
Backups are stored in a .backups directory alongside the file.
Returns the backup path if created, None if no backup was needed.
"""
if not file_path.exists():
return None
# Create backup directory
backup_dir = file_path.parent / ".backups"
backup_dir.mkdir(exist_ok=True)
# Generate timestamped backup filename
timestamp = datetime.now(tz=UTC).strftime("%Y%m%d_%H%M%S")
backup_name = f"{file_path.name}.{timestamp}"
backup_path = backup_dir / backup_name
# Copy current content to backup
backup_path.write_text(file_path.read_text())
# Clean up old backups (keep last 200)
backups = sorted(backup_dir.glob(f"{file_path.name}.*"), reverse=True)
for old_backup in backups[200:]:
old_backup.unlink()
return backup_path
def _save_with_backup(file_path: Path, content: str) -> bool:
"""Save content to file, creating a backup first if content changed.
Returns True if file was saved, False if content was unchanged.
"""
# Check if content actually changed
if file_path.exists():
current_content = file_path.read_text()
if current_content == content:
return False # No change, skip save
_backup_file(file_path)
file_path.write_text(content)
return True
def _get_service_compose_path(name: str) -> Path:
"""Get compose path for service, raising HTTPException if not found."""
config = get_config()
@@ -183,8 +229,9 @@ async def save_compose(
"""Save compose file content."""
compose_path = _get_service_compose_path(name)
_validate_yaml(content)
compose_path.write_text(content)
return {"success": True, "message": "Compose file saved"}
saved = _save_with_backup(compose_path, content)
msg = "Compose file saved" if saved else "No changes to save"
return {"success": True, "message": msg}
@router.put("/service/{name}/env")
@@ -193,8 +240,9 @@ async def save_env(
) -> dict[str, Any]:
"""Save .env file content."""
env_path = _get_service_compose_path(name).parent / ".env"
env_path.write_text(content)
return {"success": True, "message": ".env file saved"}
saved = _save_with_backup(env_path, content)
msg = ".env file saved" if saved else "No changes to save"
return {"success": True, "message": msg}
@router.put("/config")
@@ -207,6 +255,106 @@ async def save_config(
raise HTTPException(status_code=404, detail="Config file not found")
_validate_yaml(content)
config_path.write_text(content)
saved = _save_with_backup(config_path, content)
msg = "Config saved" if saved else "No changes to save"
return {"success": True, "message": msg}
return {"success": True, "message": "Config saved"}
async def _read_file_local(path: str) -> str:
"""Read a file from the local filesystem."""
expanded = Path(path).expanduser()
return await asyncio.to_thread(expanded.read_text, encoding="utf-8")
async def _write_file_local(path: str, content: str) -> bool:
"""Write content to a file on the local filesystem with backup.
Returns True if file was saved, False if content was unchanged.
"""
expanded = Path(path).expanduser()
return await asyncio.to_thread(_save_with_backup, expanded, content)
async def _read_file_remote(host: Any, path: str) -> str:
"""Read a file from a remote host via SSH."""
# Expand ~ on remote by using shell
cmd = f"cat {shlex.quote(path)}"
if path.startswith("~/"):
cmd = f"cat ~/{shlex.quote(path[2:])}"
async with asyncssh.connect(**ssh_connect_kwargs(host)) as conn:
result = await conn.run(cmd, check=True)
stdout = result.stdout or ""
return stdout.decode() if isinstance(stdout, bytes) else stdout
async def _write_file_remote(host: Any, path: str, content: str) -> None:
"""Write content to a file on a remote host via SSH."""
# Expand ~ on remote by using shell
target_path = f"~/{path[2:]}" if path.startswith("~/") else path
cmd = f"cat > {shlex.quote(target_path)}"
async with asyncssh.connect(**ssh_connect_kwargs(host)) as conn:
result = await conn.run(cmd, input=content, check=True)
if result.returncode != 0:
stderr = result.stderr.decode() if isinstance(result.stderr, bytes) else result.stderr
msg = f"Failed to write file: {stderr}"
raise RuntimeError(msg)
def _get_console_host(host: str, path: str) -> Any:
"""Validate and return host config for console file operations."""
config = get_config()
host_config = config.hosts.get(host)
if not host_config:
raise HTTPException(status_code=404, detail=f"Host '{host}' not found")
if not path:
raise HTTPException(status_code=400, detail="Path is required")
return host_config
@router.get("/console/file")
async def read_console_file(
host: Annotated[str, Query(description="Host name")],
path: Annotated[str, Query(description="File path")],
) -> dict[str, Any]:
"""Read a file from a host for the console editor."""
host_config = _get_console_host(host, path)
try:
if is_local(host_config):
content = await _read_file_local(path)
else:
content = await _read_file_remote(host_config, path)
return {"success": True, "content": content}
except FileNotFoundError:
raise HTTPException(status_code=404, detail=f"File not found: {path}") from None
except PermissionError:
raise HTTPException(status_code=403, detail=f"Permission denied: {path}") from None
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
@router.put("/console/file")
async def write_console_file(
host: Annotated[str, Query(description="Host name")],
path: Annotated[str, Query(description="File path")],
content: Annotated[str, Body(media_type="text/plain")],
) -> dict[str, Any]:
"""Write a file to a host from the console editor."""
host_config = _get_console_host(host, path)
try:
if is_local(host_config):
saved = await _write_file_local(path, content)
msg = f"Saved: {path}" if saved else "No changes to save"
else:
await _write_file_remote(host_config, path, content)
msg = f"Saved: {path}" # Remote doesn't track changes
return {"success": True, "message": msg}
except PermissionError:
raise HTTPException(status_code=403, detail=f"Permission denied: {path}") from None
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e

View File

@@ -7,19 +7,57 @@ from fastapi import APIRouter, Request
from fastapi.responses import HTMLResponse
from pydantic import ValidationError
from compose_farm.executor import is_local
from compose_farm.paths import find_config_path
from compose_farm.state import (
get_orphaned_services,
get_service_host,
get_services_needing_migration,
get_services_not_in_state,
group_running_services_by_host,
load_state,
)
from compose_farm.web.deps import get_config, get_templates
from compose_farm.web.deps import (
extract_config_error,
get_config,
get_templates,
)
router = APIRouter()
@router.get("/console", response_class=HTMLResponse)
async def console(request: Request) -> HTMLResponse:
"""Console page with terminal and editor."""
config = get_config()
templates = get_templates()
# Find local host and sort it first
local_host = None
for name, host in config.hosts.items():
if is_local(host):
local_host = name
break
# Sort hosts with local first
hosts = sorted(config.hosts.keys())
if local_host:
hosts = [local_host] + [h for h in hosts if h != local_host]
# Get config path for default editor file
config_path = str(config.config_path) if config.config_path else ""
return templates.TemplateResponse(
"console.html",
{
"request": request,
"hosts": hosts,
"local_host": local_host,
"config_path": config_path,
},
)
@router.get("/", response_class=HTMLResponse)
async def index(request: Request) -> HTMLResponse:
"""Dashboard page - combined view of all cluster info."""
@@ -30,11 +68,7 @@ async def index(request: Request) -> HTMLResponse:
try:
config = get_config()
except (ValidationError, FileNotFoundError) as e:
# Extract error message
if isinstance(e, ValidationError):
config_error = "; ".join(err.get("msg", str(err)) for err in e.errors())
else:
config_error = str(e)
config_error = extract_config_error(e)
# Read raw config content for the editor
config_path = find_config_path()
@@ -70,14 +104,8 @@ async def index(request: Request) -> HTMLResponse:
migrations = get_services_needing_migration(config)
not_started = get_services_not_in_state(config)
# Group services by host
services_by_host: dict[str, list[str]] = {}
for svc, host in deployed.items():
if isinstance(host, list):
for h in host:
services_by_host.setdefault(h, []).append(svc)
else:
services_by_host.setdefault(host, []).append(svc)
# Group services by host (filter out hosts with no running services)
services_by_host = group_running_services_by_host(deployed, config.hosts)
# Config file content
config_content = ""
@@ -186,10 +214,7 @@ async def config_error_partial(request: Request) -> HTMLResponse:
get_config()
return HTMLResponse("") # No error
except (ValidationError, FileNotFoundError) as e:
if isinstance(e, ValidationError):
error = "; ".join(err.get("msg", str(err)) for err in e.errors())
else:
error = str(e)
error = extract_config_error(e)
return templates.TemplateResponse(
"partials/config_error.html", {"request": request, "config_error": error}
)
@@ -246,15 +271,7 @@ async def services_by_host_partial(request: Request, expanded: bool = True) -> H
templates = get_templates()
deployed = load_state(config)
# Group services by host
services_by_host: dict[str, list[str]] = {}
for svc, host in deployed.items():
if isinstance(host, list):
for h in host:
services_by_host.setdefault(h, []).append(svc)
else:
services_by_host.setdefault(host, []).append(svc)
services_by_host = group_running_services_by_host(deployed, config.hosts)
return templates.TemplateResponse(
"partials/services_by_host.html",

View File

@@ -17,6 +17,35 @@ const editors = {};
let monacoLoaded = false;
let monacoLoading = false;
// Language detection from file path
const LANGUAGE_MAP = {
'yaml': 'yaml', 'yml': 'yaml',
'json': 'json',
'js': 'javascript', 'mjs': 'javascript',
'ts': 'typescript', 'tsx': 'typescript',
'py': 'python',
'sh': 'shell', 'bash': 'shell',
'md': 'markdown',
'html': 'html', 'htm': 'html',
'css': 'css',
'sql': 'sql',
'toml': 'toml',
'ini': 'ini', 'conf': 'ini',
'dockerfile': 'dockerfile',
'env': 'plaintext'
};
/**
* Get Monaco language from file path
* @param {string} path - File path
* @returns {string} Monaco language identifier
*/
function getLanguageFromPath(path) {
const ext = path.split('.').pop().toLowerCase();
return LANGUAGE_MAP[ext] || 'plaintext';
}
window.getLanguageFromPath = getLanguageFromPath;
// Terminal color theme (dark mode matching PicoCSS)
const TERMINAL_THEME = {
background: '#1a1a2e',
@@ -87,6 +116,7 @@ function createWebSocket(path) {
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
return new WebSocket(`${protocol}//${window.location.host}${path}`);
}
window.createWebSocket = createWebSocket;
/**
* Initialize a terminal and connect to WebSocket for streaming
@@ -223,10 +253,16 @@ function loadMonaco(callback) {
* @param {HTMLElement} container - Container element
* @param {string} content - Initial content
* @param {string} language - Editor language (yaml, plaintext, etc.)
* @param {boolean} readonly - Whether editor is read-only
* @param {object} opts - Options: { readonly, onSave }
* @returns {object} Monaco editor instance
*/
function createEditor(container, content, language, readonly = false) {
function createEditor(container, content, language, opts = {}) {
// Support legacy boolean readonly parameter
if (typeof opts === 'boolean') {
opts = { readonly: opts };
}
const { readonly = false, onSave = null } = opts;
const options = {
value: content,
language: language,
@@ -249,12 +285,17 @@ function createEditor(container, content, language, readonly = false) {
// Add Command+S / Ctrl+S handler for editable editors
if (!readonly) {
editor.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.KeyS, function() {
saveAllEditors();
if (onSave) {
onSave(editor);
} else {
saveAllEditors();
}
});
}
return editor;
}
window.createEditor = createEditor;
/**
* Initialize all Monaco editors on the page
@@ -370,6 +411,16 @@ function initPage() {
document.addEventListener('DOMContentLoaded', function() {
initPage();
initKeyboardShortcuts();
// Handle ?action= parameter (from command palette navigation)
const params = new URLSearchParams(window.location.search);
const action = params.get('action');
if (action && window.location.pathname === '/') {
// Clear the URL parameter
history.replaceState({}, '', '/');
// Trigger the action
htmx.ajax('POST', `/api/${action}`, {swap: 'none'});
}
});
// Re-initialize after HTMX swaps main content
@@ -441,41 +492,59 @@ document.body.addEventListener('htmx:afterRequest', function(evt) {
const fab = document.getElementById('cmd-fab');
if (!dialog || !input || !list) return;
const colors = { service: '#22c55e', action: '#eab308', nav: '#3b82f6' };
// Load icons from template (rendered server-side from icons.html)
const iconTemplate = document.getElementById('cmd-icons');
const icons = {};
if (iconTemplate) {
iconTemplate.content.querySelectorAll('[data-icon]').forEach(el => {
icons[el.dataset.icon] = el.innerHTML;
});
}
const colors = { service: '#22c55e', action: '#eab308', nav: '#3b82f6', app: '#a855f7' };
let commands = [];
let filtered = [];
let selected = 0;
const post = (url) => () => htmx.ajax('POST', url, {swap: 'none'});
const nav = (url) => () => window.location.href = url;
const cmd = (type, name, desc, action) => ({ type, name, desc, action });
// Navigate to dashboard and trigger action (or just POST if already on dashboard)
const dashboardAction = (endpoint) => () => {
if (window.location.pathname === '/') {
htmx.ajax('POST', `/api/${endpoint}`, {swap: 'none'});
} else {
window.location.href = `/?action=${endpoint}`;
}
};
const cmd = (type, name, desc, action, icon = null) => ({ type, name, desc, action, icon });
function buildCommands() {
const actions = [
cmd('action', 'Apply', 'Make reality match config', post('/api/apply')),
cmd('action', 'Refresh', 'Update state from reality', post('/api/refresh')),
cmd('nav', 'Dashboard', 'Go to dashboard', nav('/')),
cmd('action', 'Apply', 'Make reality match config', dashboardAction('apply'), icons.check),
cmd('action', 'Refresh', 'Update state from reality', dashboardAction('refresh'), icons.refresh_cw),
cmd('app', 'Dashboard', 'Go to dashboard', nav('/'), icons.home),
cmd('app', 'Console', 'Go to console', nav('/console'), icons.terminal),
];
// Add service-specific actions if on a service page
const match = window.location.pathname.match(/^\/service\/(.+)$/);
if (match) {
const svc = decodeURIComponent(match[1]);
const svcCmd = (name, desc, endpoint) => cmd('service', name, `${desc} ${svc}`, post(`/api/service/${svc}/${endpoint}`));
const svcCmd = (name, desc, endpoint, icon) => cmd('service', name, `${desc} ${svc}`, post(`/api/service/${svc}/${endpoint}`), icon);
actions.unshift(
svcCmd('Up', 'Start', 'up'),
svcCmd('Down', 'Stop', 'down'),
svcCmd('Restart', 'Restart', 'restart'),
svcCmd('Pull', 'Pull', 'pull'),
svcCmd('Update', 'Pull + restart', 'update'),
svcCmd('Logs', 'View logs for', 'logs'),
svcCmd('Up', 'Start', 'up', icons.play),
svcCmd('Down', 'Stop', 'down', icons.square),
svcCmd('Restart', 'Restart', 'restart', icons.rotate_cw),
svcCmd('Pull', 'Pull', 'pull', icons.cloud_download),
svcCmd('Update', 'Pull + restart', 'update', icons.refresh_cw),
svcCmd('Logs', 'View logs for', 'logs', icons.file_text),
);
}
// Add nav commands for all services from sidebar
const services = [...document.querySelectorAll('#sidebar-services li[data-svc] a[href]')].map(a => {
const name = a.getAttribute('href').replace('/service/', '');
return cmd('nav', name, 'Go to service', nav(`/service/${name}`));
return cmd('nav', name, 'Go to service', nav(`/service/${name}`), icons.box);
});
commands = [...actions, ...services];
@@ -490,10 +559,13 @@ document.body.addEventListener('htmx:afterRequest', function(evt) {
function render() {
list.innerHTML = filtered.map((c, i) => `
<a class="flex justify-between items-center px-3 py-2 rounded-r cursor-pointer hover:bg-base-200 border-l-4 ${i === selected ? 'bg-base-300' : ''}" style="border-left-color: ${colors[c.type] || '#666'}" data-idx="${i}">
<span><span class="opacity-50 text-xs mr-2">${c.type}</span>${c.name}</span>
<span class="flex items-center gap-2">${c.icon || ''}<span>${c.name}</span></span>
<span class="opacity-40 text-xs">${c.desc}</span>
</a>
`).join('') || '<div class="opacity-50 p-2">No matches</div>';
// Scroll selected item into view
const sel = list.querySelector(`[data-idx="${selected}"]`);
if (sel) sel.scrollIntoView({ block: 'nearest' });
}
function open() {

View File

@@ -4,12 +4,17 @@ from __future__ import annotations
import asyncio
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any
from compose_farm.executor import build_ssh_command
from compose_farm.ssh_keys import get_ssh_auth_sock
if TYPE_CHECKING:
from compose_farm.config import Config
# Environment variable to identify the web service (for self-update detection)
CF_WEB_SERVICE = os.environ.get("CF_WEB_SERVICE", "")
# ANSI escape codes for terminal output
RED = "\x1b[31m"
GREEN = "\x1b[32m"
@@ -17,25 +22,6 @@ DIM = "\x1b[2m"
RESET = "\x1b[0m"
CRLF = "\r\n"
def _get_ssh_auth_sock() -> str | None:
"""Get SSH_AUTH_SOCK, auto-detecting forwarded agent if needed."""
sock = os.environ.get("SSH_AUTH_SOCK")
if sock and Path(sock).is_socket():
return sock
# Try to find a forwarded SSH agent socket
agent_dir = Path.home() / ".ssh" / "agent"
if agent_dir.is_dir():
sockets = sorted(
agent_dir.glob("s.*.sshd.*"), key=lambda p: p.stat().st_mtime, reverse=True
)
for s in sockets:
if s.is_socket():
return str(s)
return None
# In-memory task registry
tasks: dict[str, dict[str, Any]] = {}
@@ -69,7 +55,7 @@ async def run_cli_streaming(
env = {"FORCE_COLOR": "1", "TERM": "xterm-256color", "COLUMNS": "120"}
# Ensure SSH agent is available (auto-detect if needed)
ssh_sock = _get_ssh_auth_sock()
ssh_sock = get_ssh_auth_sock()
if ssh_sock:
env["SSH_AUTH_SOCK"] = ssh_sock
@@ -97,6 +83,76 @@ async def run_cli_streaming(
tasks[task_id]["status"] = "failed"
def _is_self_update(service: str, command: str) -> bool:
"""Check if this is a self-update (updating the web service itself).
Self-updates need special handling because running 'down' on the container
we're running in would kill the process before 'up' can execute.
"""
if not CF_WEB_SERVICE or service != CF_WEB_SERVICE:
return False
# Commands that involve 'down' need SSH: update, restart, down
return command in ("update", "restart", "down")
async def _run_cli_via_ssh(
config: Config,
args: list[str],
task_id: str,
) -> None:
"""Run a cf CLI command via SSH to the host.
Used for self-updates to ensure the command survives container restart.
"""
try:
# Get the host for the web service
host = config.get_host(CF_WEB_SERVICE)
# Build the remote command
remote_cmd = f"cf {' '.join(args)} --config={config.config_path}"
# Show what we're doing
await stream_to_task(
task_id,
f"{DIM}$ ssh {host.user}@{host.address} {remote_cmd}{RESET}{CRLF}",
)
await stream_to_task(
task_id,
f"{GREEN}Running via SSH (self-update protection){RESET}{CRLF}",
)
# Build SSH command using shared helper
ssh_args = build_ssh_command(host, remote_cmd)
# Set up environment with SSH agent
env = {**os.environ, "FORCE_COLOR": "1", "TERM": "xterm-256color"}
ssh_sock = get_ssh_auth_sock()
if ssh_sock:
env["SSH_AUTH_SOCK"] = ssh_sock
process = await asyncio.create_subprocess_exec(
*ssh_args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
env=env,
)
# Stream output
if process.stdout:
async for line in process.stdout:
text = line.decode("utf-8", errors="replace")
if text.endswith("\n") and not text.endswith("\r\n"):
text = text[:-1] + "\r\n"
await stream_to_task(task_id, text)
exit_code = await process.wait()
tasks[task_id]["status"] = "completed" if exit_code == 0 else "failed"
except Exception as e:
await stream_to_task(task_id, f"{RED}Error: {e}{RESET}{CRLF}")
tasks[task_id]["status"] = "failed"
async def run_compose_streaming(
config: Config,
service: str,
@@ -111,4 +167,9 @@ async def run_compose_streaming(
# Build CLI args
cli_args = [cli_cmd, service, *extra_args]
await run_cli_streaming(config, cli_args, task_id)
# Use SSH for self-updates to survive container restart
if _is_self_update(service, cli_cmd):
await _run_cli_via_ssh(config, cli_args, task_id)
else:
await run_cli_streaming(config, cli_args, task_id)

View File

@@ -0,0 +1,241 @@
{% extends "base.html" %}
{% from "partials/components.html" import page_header %}
{% from "partials/icons.html" import terminal, save %}
{% block title %}Console - Compose Farm{% endblock %}
{% block content %}
<div class="max-w-6xl">
{{ page_header("Console", "Terminal and editor access") }}
<!-- Host Selector -->
<div class="flex items-center gap-4 mb-4">
<label class="font-semibold">Host:</label>
<select id="console-host-select" class="select select-sm select-bordered">
{% for name in hosts %}
<option value="{{ name }}">{{ name }}{% if name == local_host %} (local){% endif %}</option>
{% endfor %}
</select>
<button id="console-connect-btn" class="btn btn-sm btn-primary" onclick="connectConsole()">Connect</button>
<span id="console-status" class="text-sm opacity-60"></span>
</div>
<!-- Terminal -->
<div class="mb-6">
<div class="flex items-center gap-2 mb-2">
<h3 class="font-semibold flex items-center gap-2">{{ terminal() }} Terminal</h3>
<span class="text-xs opacity-50">Full shell access to selected host</span>
</div>
<div id="console-terminal" class="w-full bg-base-300 rounded-lg overflow-hidden resize-y" style="height: 384px; min-height: 200px;"></div>
</div>
<!-- Editor -->
<div class="mb-6">
<div class="flex items-center justify-between mb-2">
<div class="flex items-center gap-4">
<h3 class="font-semibold">Editor</h3>
<input type="text" id="console-file-path" class="input input-sm input-bordered w-96" placeholder="Enter file path (e.g., ~/docker-compose.yaml)" value="{{ config_path }}">
<button class="btn btn-sm btn-outline" onclick="loadFile()">Open</button>
</div>
<div class="flex items-center gap-2">
<span id="editor-status" class="text-sm opacity-60"></span>
<button id="console-save-btn" class="btn btn-sm btn-primary" onclick="saveFile()">{{ save() }} Save</button>
</div>
</div>
<div id="console-editor" class="resize-y overflow-hidden rounded-lg" style="height: 512px; min-height: 200px;"></div>
</div>
</div>
<script>
// Use var to allow re-declaration on HTMX navigation
var consoleTerminal = null;
var consoleWs = null;
var consoleEditor = null;
var currentFilePath = null;
var currentHost = null;
function connectConsole() {
const hostSelect = document.getElementById('console-host-select');
const host = hostSelect.value;
const statusEl = document.getElementById('console-status');
const terminalEl = document.getElementById('console-terminal');
if (!host) {
statusEl.textContent = 'Please select a host';
return;
}
currentHost = host;
// Clean up existing connection
if (consoleWs) {
consoleWs.close();
consoleWs = null;
}
if (consoleTerminal) {
consoleTerminal.dispose();
consoleTerminal = null;
}
statusEl.textContent = 'Connecting...';
// Create WebSocket
consoleWs = createWebSocket(`/ws/shell/${host}`);
// Resize callback - createTerminal's ResizeObserver calls this on container resize
const sendSize = (cols, rows) => {
if (consoleWs && consoleWs.readyState === WebSocket.OPEN) {
consoleWs.send(JSON.stringify({ type: 'resize', cols, rows }));
}
};
// Create terminal with resize callback
const { term } = createTerminal(terminalEl, { cursorBlink: true }, sendSize);
consoleTerminal = term;
consoleWs.onopen = () => {
statusEl.textContent = `Connected to ${host}`;
sendSize(term.cols, term.rows);
term.focus();
// Auto-load the default file once editor is ready
const pathInput = document.getElementById('console-file-path');
if (pathInput && pathInput.value) {
const tryLoad = () => {
if (consoleEditor) {
loadFile();
} else {
setTimeout(tryLoad, 100);
}
};
tryLoad();
}
};
consoleWs.onmessage = (event) => term.write(event.data);
consoleWs.onclose = () => {
statusEl.textContent = 'Disconnected';
term.write(`${ANSI.CRLF}${ANSI.DIM}[Connection closed]${ANSI.RESET}${ANSI.CRLF}`);
};
consoleWs.onerror = (error) => {
statusEl.textContent = 'Connection error';
term.write(`${ANSI.RED}[WebSocket Error]${ANSI.RESET}${ANSI.CRLF}`);
console.error('Console WebSocket error:', error);
};
// Send input to WebSocket
term.onData((data) => {
if (consoleWs && consoleWs.readyState === WebSocket.OPEN) {
consoleWs.send(data);
}
});
}
function initConsoleEditor() {
const editorEl = document.getElementById('console-editor');
if (!editorEl || consoleEditor) return;
loadMonaco(() => {
consoleEditor = createEditor(editorEl, '', 'plaintext', { onSave: saveFile });
});
}
async function loadFile() {
const pathInput = document.getElementById('console-file-path');
const path = pathInput.value.trim();
const statusEl = document.getElementById('editor-status');
if (!path) {
statusEl.textContent = 'Enter a file path';
return;
}
if (!currentHost) {
statusEl.textContent = 'Connect to a host first';
return;
}
statusEl.textContent = `Loading ${path}...`;
try {
const response = await fetch(`/api/console/file?host=${encodeURIComponent(currentHost)}&path=${encodeURIComponent(path)}`);
const data = await response.json();
if (!response.ok || !data.success) {
statusEl.textContent = data.detail || 'Failed to load file';
return;
}
const language = getLanguageFromPath(path);
if (consoleEditor) {
consoleEditor.setValue(data.content);
monaco.editor.setModelLanguage(consoleEditor.getModel(), language);
currentFilePath = path; // Only set after content is loaded
statusEl.textContent = `Loaded: ${path}`;
} else {
statusEl.textContent = 'Editor not ready';
}
} catch (e) {
statusEl.textContent = `Error: ${e.message}`;
}
}
async function saveFile() {
const statusEl = document.getElementById('editor-status');
if (!currentFilePath) {
statusEl.textContent = 'No file loaded';
return;
}
if (!currentHost) {
statusEl.textContent = 'Not connected to a host';
return;
}
if (!consoleEditor) {
statusEl.textContent = 'Editor not ready';
return;
}
statusEl.textContent = `Saving ${currentFilePath}...`;
try {
const content = consoleEditor.getValue();
const response = await fetch(`/api/console/file?host=${encodeURIComponent(currentHost)}&path=${encodeURIComponent(currentFilePath)}`, {
method: 'PUT',
headers: { 'Content-Type': 'text/plain' },
body: content
});
const data = await response.json();
if (!response.ok || !data.success) {
statusEl.textContent = data.detail || 'Failed to save file';
return;
}
statusEl.textContent = `Saved: ${currentFilePath}`;
} catch (e) {
statusEl.textContent = `Error: ${e.message}`;
}
}
// Initialize editor and auto-connect to first host
function init() {
initConsoleEditor();
const hostSelect = document.getElementById('console-host-select');
if (hostSelect && hostSelect.options.length > 0) {
connectConsole();
}
}
// On HTMX navigation, dependencies (app.js) are already loaded.
// On hard refresh, this script runs before app.js, so wait for DOMContentLoaded.
if (typeof createTerminal === 'function') {
init();
} else {
document.addEventListener('DOMContentLoaded', init);
}
</script>
{% endblock content %}

View File

@@ -1,4 +1,18 @@
{% from "partials/icons.html" import search, command %}
{% from "partials/icons.html" import search, play, square, rotate_cw, cloud_download, refresh_cw, file_text, check, home, terminal, box %}
<!-- Icons for command palette (referenced by JS) -->
<template id="cmd-icons">
<span data-icon="play">{{ play() }}</span>
<span data-icon="square">{{ square() }}</span>
<span data-icon="rotate_cw">{{ rotate_cw() }}</span>
<span data-icon="cloud_download">{{ cloud_download() }}</span>
<span data-icon="refresh_cw">{{ refresh_cw() }}</span>
<span data-icon="file_text">{{ file_text() }}</span>
<span data-icon="check">{{ check() }}</span>
<span data-icon="home">{{ home() }}</span>
<span data-icon="terminal">{{ terminal() }}</span>
<span data-icon="box">{{ box() }}</span>
</template>
<dialog id="cmd-palette" class="modal">
<div class="modal-box max-w-lg p-0">
<label class="input input-lg bg-base-100 border-0 border-b border-base-300 w-full rounded-none rounded-t-box sticky top-0 z-10 focus-within:outline-none">
@@ -15,5 +29,7 @@
<!-- Floating button to open command palette -->
<button id="cmd-fab" class="btn btn-circle glass shadow-lg fixed bottom-6 right-6 z-50 hover:ring hover:ring-base-content/50" title="Command Palette (⌘K)">
{{ command(24) }}
<span class="flex items-center gap-0.5 text-sm font-semibold">
<span class="opacity-70"></span><span>K</span>
</span>
</button>

View File

@@ -1,8 +1,9 @@
{% from "partials/icons.html" import home, search %}
<!-- Dashboard Link -->
{% from "partials/icons.html" import home, search, terminal %}
<!-- Navigation Links -->
<div class="mb-4">
<ul class="menu" hx-boost="true" hx-target="#main-content" hx-select="#main-content" hx-swap="outerHTML">
<li><a href="/" class="font-semibold">{{ home() }} Dashboard</a></li>
<li><a href="/console" class="font-semibold">{{ terminal() }} Console</a></li>
</ul>
</div>

View File

@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any
import asyncssh
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from compose_farm.executor import is_local
from compose_farm.executor import is_local, ssh_connect_kwargs
from compose_farm.web.deps import get_config
from compose_farm.web.streaming import CRLF, DIM, GREEN, RED, RESET, tasks
@@ -121,6 +121,14 @@ async def _bridge_websocket_to_ssh(
proc.terminate()
def _make_controlling_tty(slave_fd: int) -> None:
"""Set up the slave PTY as the controlling terminal for the child process."""
# Create a new session
os.setsid()
# Make the slave fd the controlling terminal
fcntl.ioctl(slave_fd, termios.TIOCSCTTY, 0)
async def _run_local_exec(websocket: WebSocket, exec_cmd: str) -> None:
"""Run docker exec locally with PTY."""
master_fd, slave_fd = pty.openpty()
@@ -131,6 +139,8 @@ async def _run_local_exec(websocket: WebSocket, exec_cmd: str) -> None:
stdout=slave_fd,
stderr=slave_fd,
close_fds=True,
preexec_fn=lambda: _make_controlling_tty(slave_fd),
start_new_session=False, # We handle setsid in preexec_fn
)
os.close(slave_fd)
@@ -141,13 +151,14 @@ async def _run_local_exec(websocket: WebSocket, exec_cmd: str) -> None:
await _bridge_websocket_to_fd(websocket, master_fd, proc)
async def _run_remote_exec(websocket: WebSocket, host: Host, exec_cmd: str) -> None:
async def _run_remote_exec(
websocket: WebSocket, host: Host, exec_cmd: str, *, agent_forwarding: bool = False
) -> None:
"""Run docker exec on remote host via SSH with PTY."""
# ssh_connect_kwargs includes agent_path and client_keys fallback
async with asyncssh.connect(
host.address,
port=host.port,
username=host.user,
known_hosts=None,
**ssh_connect_kwargs(host),
agent_forwarding=agent_forwarding,
) as conn:
proc: asyncssh.SSHClientProcess[Any] = await conn.create_process(
exec_cmd,
@@ -202,6 +213,48 @@ async def exec_websocket(
await websocket.close()
async def _run_shell_session(
websocket: WebSocket,
host_name: str,
) -> None:
"""Run an interactive shell session on a host over WebSocket."""
config = get_config()
host = config.hosts.get(host_name)
if not host:
await websocket.send_text(f"{RED}Host '{host_name}' not found{RESET}{CRLF}")
return
# Start interactive shell in home directory (avoid login shell to prevent job control warnings)
shell_cmd = "cd ~ && exec bash -i 2>/dev/null || exec sh -i"
if is_local(host):
await _run_local_exec(websocket, shell_cmd)
else:
await _run_remote_exec(websocket, host, shell_cmd, agent_forwarding=True)
@router.websocket("/ws/shell/{host}")
async def shell_websocket(
websocket: WebSocket,
host: str,
) -> None:
"""WebSocket endpoint for interactive host shell access."""
await websocket.accept()
try:
await websocket.send_text(f"{DIM}[Connecting to {host}...]{RESET}{CRLF}")
await _run_shell_session(websocket, host)
await websocket.send_text(f"{CRLF}{DIM}[Disconnected]{RESET}{CRLF}")
except WebSocketDisconnect:
pass
except Exception as e:
with contextlib.suppress(Exception):
await websocket.send_text(f"{RED}Error: {e}{RESET}{CRLF}")
finally:
with contextlib.suppress(Exception):
await websocket.close()
@router.websocket("/ws/terminal/{task_id}")
async def terminal_websocket(websocket: WebSocket, task_id: str) -> None:
"""WebSocket endpoint for terminal streaming."""

View File

@@ -150,7 +150,7 @@ class TestLogsHostFilter:
mock_run_async, _ = _mock_run_async_factory(["svc1", "svc2"])
with (
patch("compose_farm.cli.monitoring.load_config_or_exit", return_value=cfg),
patch("compose_farm.cli.common.load_config_or_exit", return_value=cfg),
patch("compose_farm.cli.monitoring.run_async", side_effect=mock_run_async),
patch("compose_farm.cli.monitoring.run_on_services") as mock_run,
):
@@ -174,7 +174,7 @@ class TestLogsHostFilter:
mock_run_async, _ = _mock_run_async_factory(["svc1", "svc2"])
with (
patch("compose_farm.cli.monitoring.load_config_or_exit", return_value=cfg),
patch("compose_farm.cli.common.load_config_or_exit", return_value=cfg),
patch("compose_farm.cli.monitoring.run_async", side_effect=mock_run_async),
patch("compose_farm.cli.monitoring.run_on_services") as mock_run,
):

114
tests/test_cli_ssh.py Normal file
View File

@@ -0,0 +1,114 @@
"""Tests for CLI ssh commands."""
from pathlib import Path
from unittest.mock import patch
from typer.testing import CliRunner
from compose_farm.cli.app import app
runner = CliRunner()
class TestSshKeygen:
"""Tests for cf ssh keygen command."""
def test_keygen_generates_key(self, tmp_path: Path) -> None:
"""Generate SSH key when none exists."""
key_path = tmp_path / "compose-farm"
pubkey_path = tmp_path / "compose-farm.pub"
with (
patch("compose_farm.cli.ssh.SSH_KEY_PATH", key_path),
patch("compose_farm.cli.ssh.SSH_PUBKEY_PATH", pubkey_path),
patch("compose_farm.cli.ssh.key_exists", return_value=False),
):
result = runner.invoke(app, ["ssh", "keygen"])
# Command runs (may fail if ssh-keygen not available in test env)
assert result.exit_code in (0, 1)
def test_keygen_skips_if_exists(self, tmp_path: Path) -> None:
"""Skip key generation if key already exists."""
key_path = tmp_path / "compose-farm"
pubkey_path = tmp_path / "compose-farm.pub"
with (
patch("compose_farm.cli.ssh.SSH_KEY_PATH", key_path),
patch("compose_farm.cli.ssh.SSH_PUBKEY_PATH", pubkey_path),
patch("compose_farm.cli.ssh.key_exists", return_value=True),
):
result = runner.invoke(app, ["ssh", "keygen"])
assert "already exists" in result.output
class TestSshStatus:
"""Tests for cf ssh status command."""
def test_status_shows_no_key(self, tmp_path: Path) -> None:
"""Show message when no key exists."""
config_file = tmp_path / "compose-farm.yaml"
config_file.write_text("""
hosts:
local:
address: localhost
services:
test: local
""")
with patch("compose_farm.cli.ssh.key_exists", return_value=False):
result = runner.invoke(app, ["ssh", "status", f"--config={config_file}"])
assert "No key found" in result.output
def test_status_shows_key_exists(self, tmp_path: Path) -> None:
"""Show key info when key exists."""
config_file = tmp_path / "compose-farm.yaml"
config_file.write_text("""
hosts:
local:
address: localhost
services:
test: local
""")
with (
patch("compose_farm.cli.ssh.key_exists", return_value=True),
patch("compose_farm.cli.ssh.get_pubkey_content", return_value="ssh-ed25519 AAAA..."),
):
result = runner.invoke(app, ["ssh", "status", f"--config={config_file}"])
assert "Key exists" in result.output
class TestSshSetup:
"""Tests for cf ssh setup command."""
def test_setup_no_remote_hosts(self, tmp_path: Path) -> None:
"""Show message when no remote hosts configured."""
config_file = tmp_path / "compose-farm.yaml"
config_file.write_text("""
hosts:
local:
address: localhost
services:
test: local
""")
result = runner.invoke(app, ["ssh", "setup", f"--config={config_file}"])
assert "No remote hosts" in result.output
class TestSshHelp:
"""Tests for cf ssh help."""
def test_ssh_help(self) -> None:
"""Show help for ssh command."""
result = runner.invoke(app, ["ssh", "--help"])
assert result.exit_code == 0
assert "setup" in result.output
assert "status" in result.output
assert "keygen" in result.output

245
tests/test_ssh_keys.py Normal file
View File

@@ -0,0 +1,245 @@
"""Tests for ssh_keys module."""
import os
from pathlib import Path
from unittest.mock import MagicMock, patch
from compose_farm.config import Host
from compose_farm.executor import ssh_connect_kwargs
from compose_farm.ssh_keys import (
SSH_KEY_PATH,
get_key_path,
get_pubkey_content,
get_ssh_auth_sock,
get_ssh_env,
key_exists,
)
class TestGetSshAuthSock:
"""Tests for get_ssh_auth_sock function."""
def test_returns_env_var_when_socket_exists(self) -> None:
"""Return SSH_AUTH_SOCK env var if the socket exists."""
mock_path = MagicMock()
mock_path.is_socket.return_value = True
with (
patch.dict(os.environ, {"SSH_AUTH_SOCK": "/tmp/agent.sock"}),
patch("compose_farm.ssh_keys.Path", return_value=mock_path),
):
result = get_ssh_auth_sock()
assert result == "/tmp/agent.sock"
def test_returns_none_when_env_var_not_socket(self, tmp_path: Path) -> None:
"""Return None if SSH_AUTH_SOCK points to non-socket."""
regular_file = tmp_path / "not_a_socket"
regular_file.touch()
with (
patch.dict(os.environ, {"SSH_AUTH_SOCK": str(regular_file)}),
patch("compose_farm.ssh_keys.Path.home", return_value=tmp_path),
):
# Should fall through to agent dir check, which won't exist
result = get_ssh_auth_sock()
assert result is None
def test_finds_agent_in_ssh_agent_dir(self, tmp_path: Path) -> None:
"""Find agent socket in ~/.ssh/agent/ directory."""
# Create agent directory structure with a regular file
agent_dir = tmp_path / ".ssh" / "agent"
agent_dir.mkdir(parents=True)
sock_path = agent_dir / "s.12345.sshd.67890"
sock_path.touch() # Create as regular file
with (
patch.dict(os.environ, {}, clear=False),
patch("compose_farm.ssh_keys.Path.home", return_value=tmp_path),
patch.object(Path, "is_socket", return_value=True),
):
os.environ.pop("SSH_AUTH_SOCK", None)
result = get_ssh_auth_sock()
assert result == str(sock_path)
def test_returns_none_when_no_agent_found(self, tmp_path: Path) -> None:
"""Return None when no SSH agent socket is found."""
with (
patch.dict(os.environ, {}, clear=False),
patch("compose_farm.ssh_keys.Path.home", return_value=tmp_path),
):
os.environ.pop("SSH_AUTH_SOCK", None)
result = get_ssh_auth_sock()
assert result is None
class TestGetSshEnv:
"""Tests for get_ssh_env function."""
def test_returns_env_with_ssh_auth_sock(self) -> None:
"""Return env dict with SSH_AUTH_SOCK set."""
with patch("compose_farm.ssh_keys.get_ssh_auth_sock", return_value="/tmp/agent.sock"):
result = get_ssh_env()
assert result["SSH_AUTH_SOCK"] == "/tmp/agent.sock"
# Should include other env vars too
assert "PATH" in result or len(result) > 1
def test_returns_env_without_ssh_auth_sock_when_none(self, tmp_path: Path) -> None:
"""Return env without SSH_AUTH_SOCK when no agent found."""
with (
patch.dict(os.environ, {}, clear=False),
patch("compose_farm.ssh_keys.Path.home", return_value=tmp_path),
):
os.environ.pop("SSH_AUTH_SOCK", None)
result = get_ssh_env()
# SSH_AUTH_SOCK should not be set if no agent found
assert result.get("SSH_AUTH_SOCK") is None
class TestKeyExists:
"""Tests for key_exists function."""
def test_returns_true_when_both_keys_exist(self, tmp_path: Path) -> None:
"""Return True when both private and public keys exist."""
key_path = tmp_path / "compose-farm"
pubkey_path = tmp_path / "compose-farm.pub"
key_path.touch()
pubkey_path.touch()
with (
patch("compose_farm.ssh_keys.SSH_KEY_PATH", key_path),
patch("compose_farm.ssh_keys.SSH_PUBKEY_PATH", pubkey_path),
):
assert key_exists() is True
def test_returns_false_when_private_key_missing(self, tmp_path: Path) -> None:
"""Return False when private key doesn't exist."""
key_path = tmp_path / "compose-farm"
pubkey_path = tmp_path / "compose-farm.pub"
pubkey_path.touch() # Only public key exists
with (
patch("compose_farm.ssh_keys.SSH_KEY_PATH", key_path),
patch("compose_farm.ssh_keys.SSH_PUBKEY_PATH", pubkey_path),
):
assert key_exists() is False
def test_returns_false_when_public_key_missing(self, tmp_path: Path) -> None:
"""Return False when public key doesn't exist."""
key_path = tmp_path / "compose-farm"
pubkey_path = tmp_path / "compose-farm.pub"
key_path.touch() # Only private key exists
with (
patch("compose_farm.ssh_keys.SSH_KEY_PATH", key_path),
patch("compose_farm.ssh_keys.SSH_PUBKEY_PATH", pubkey_path),
):
assert key_exists() is False
class TestGetKeyPath:
"""Tests for get_key_path function."""
def test_returns_path_when_key_exists(self) -> None:
"""Return key path when key exists."""
with patch("compose_farm.ssh_keys.key_exists", return_value=True):
result = get_key_path()
assert result == SSH_KEY_PATH
def test_returns_none_when_key_missing(self) -> None:
"""Return None when key doesn't exist."""
with patch("compose_farm.ssh_keys.key_exists", return_value=False):
result = get_key_path()
assert result is None
class TestGetPubkeyContent:
"""Tests for get_pubkey_content function."""
def test_returns_content_when_exists(self, tmp_path: Path) -> None:
"""Return public key content when file exists."""
pubkey_content = "ssh-ed25519 AAAA... compose-farm"
pubkey_path = tmp_path / "compose-farm.pub"
pubkey_path.write_text(pubkey_content + "\n")
with patch("compose_farm.ssh_keys.SSH_PUBKEY_PATH", pubkey_path):
result = get_pubkey_content()
assert result == pubkey_content
def test_returns_none_when_missing(self, tmp_path: Path) -> None:
"""Return None when public key doesn't exist."""
pubkey_path = tmp_path / "compose-farm.pub" # Doesn't exist
with patch("compose_farm.ssh_keys.SSH_PUBKEY_PATH", pubkey_path):
result = get_pubkey_content()
assert result is None
class TestSshConnectKwargs:
"""Tests for ssh_connect_kwargs function."""
def test_basic_kwargs(self) -> None:
"""Return basic connection kwargs."""
host = Host(address="example.com", port=22, user="testuser")
with (
patch("compose_farm.executor.get_ssh_auth_sock", return_value=None),
patch("compose_farm.executor.get_key_path", return_value=None),
):
result = ssh_connect_kwargs(host)
assert result["host"] == "example.com"
assert result["port"] == 22
assert result["username"] == "testuser"
assert result["known_hosts"] is None
assert "agent_path" not in result
assert "client_keys" not in result
def test_includes_agent_path_when_available(self) -> None:
"""Include agent_path when SSH agent is available."""
host = Host(address="example.com")
with (
patch("compose_farm.executor.get_ssh_auth_sock", return_value="/tmp/agent.sock"),
patch("compose_farm.executor.get_key_path", return_value=None),
):
result = ssh_connect_kwargs(host)
assert result["agent_path"] == "/tmp/agent.sock"
def test_includes_client_keys_when_key_exists(self, tmp_path: Path) -> None:
"""Include client_keys when compose-farm key exists."""
host = Host(address="example.com")
key_path = tmp_path / "compose-farm"
with (
patch("compose_farm.executor.get_ssh_auth_sock", return_value=None),
patch("compose_farm.executor.get_key_path", return_value=key_path),
):
result = ssh_connect_kwargs(host)
assert result["client_keys"] == [str(key_path)]
def test_includes_both_agent_and_key(self, tmp_path: Path) -> None:
"""Include both agent_path and client_keys when both available."""
host = Host(address="example.com")
key_path = tmp_path / "compose-farm"
with (
patch("compose_farm.executor.get_ssh_auth_sock", return_value="/tmp/agent.sock"),
patch("compose_farm.executor.get_key_path", return_value=key_path),
):
result = ssh_connect_kwargs(host)
assert result["agent_path"] == "/tmp/agent.sock"
assert result["client_keys"] == [str(key_path)]
def test_custom_port(self) -> None:
"""Handle custom SSH port."""
host = Host(address="example.com", port=2222)
with (
patch("compose_farm.executor.get_ssh_auth_sock", return_value=None),
patch("compose_farm.executor.get_key_path", return_value=None),
):
result = ssh_connect_kwargs(host)
assert result["port"] == 2222

54
tests/web/test_backup.py Normal file
View File

@@ -0,0 +1,54 @@
"""Tests for file backup functionality."""
from pathlib import Path
from compose_farm.web.routes.api import _backup_file, _save_with_backup
def test_backup_creates_timestamped_file(tmp_path: Path) -> None:
"""Test that backup creates file in .backups with correct content."""
test_file = tmp_path / "test.yaml"
test_file.write_text("original content")
backup_path = _backup_file(test_file)
assert backup_path is not None
assert backup_path.parent.name == ".backups"
assert backup_path.name.startswith("test.yaml.")
assert backup_path.read_text() == "original content"
def test_backup_returns_none_for_nonexistent_file(tmp_path: Path) -> None:
"""Test that backup returns None if file doesn't exist."""
assert _backup_file(tmp_path / "nonexistent.yaml") is None
def test_save_creates_new_file(tmp_path: Path) -> None:
"""Test that save creates new file without backup."""
test_file = tmp_path / "new.yaml"
assert _save_with_backup(test_file, "content") is True
assert test_file.read_text() == "content"
assert not (tmp_path / ".backups").exists()
def test_save_skips_unchanged_content(tmp_path: Path) -> None:
"""Test that save returns False and creates no backup if unchanged."""
test_file = tmp_path / "test.yaml"
test_file.write_text("same")
assert _save_with_backup(test_file, "same") is False
assert not (tmp_path / ".backups").exists()
def test_save_creates_backup_before_overwrite(tmp_path: Path) -> None:
"""Test that save backs up original before overwriting."""
test_file = tmp_path / "test.yaml"
test_file.write_text("original")
assert _save_with_backup(test_file, "new") is True
assert test_file.read_text() == "new"
backups = list((tmp_path / ".backups").glob("test.yaml.*"))
assert len(backups) == 1
assert backups[0].read_text() == "original"

View File

@@ -0,0 +1,111 @@
"""Tests to verify template context variables match what templates expect.
Uses runtime validation by actually rendering templates and catching errors.
"""
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING
import pytest
from fastapi.testclient import TestClient
if TYPE_CHECKING:
from compose_farm.config import Config
@pytest.fixture
def mock_config(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Config:
"""Create a minimal mock config for template testing."""
compose_dir = tmp_path / "compose"
compose_dir.mkdir()
# Create minimal service directory
svc_dir = compose_dir / "test-service"
svc_dir.mkdir()
(svc_dir / "compose.yaml").write_text("services:\n app:\n image: nginx\n")
config_path = tmp_path / "compose-farm.yaml"
config_path.write_text(f"""
compose_dir: {compose_dir}
hosts:
local-host:
address: localhost
services:
test-service: local-host
""")
state_path = tmp_path / "compose-farm-state.yaml"
state_path.write_text("deployed:\n test-service: local-host\n")
from compose_farm.config import load_config
config = load_config(config_path)
# Patch get_config in all relevant modules
from compose_farm.web import deps
from compose_farm.web.routes import api, pages
monkeypatch.setattr(deps, "get_config", lambda: config)
monkeypatch.setattr(api, "get_config", lambda: config)
monkeypatch.setattr(pages, "get_config", lambda: config)
return config
@pytest.fixture
def client(mock_config: Config) -> TestClient:
"""Create a test client with mocked config."""
from compose_farm.web.app import create_app
return TestClient(create_app())
class TestPageTemplatesRender:
"""Test that page templates render without missing variables."""
def test_index_renders(self, client: TestClient) -> None:
"""Test index page renders without errors."""
response = client.get("/")
assert response.status_code == 200
assert "Compose Farm" in response.text
def test_console_renders(self, client: TestClient) -> None:
"""Test console page renders without errors."""
response = client.get("/console")
assert response.status_code == 200
assert "Console" in response.text
assert "Terminal" in response.text
def test_service_detail_renders(self, client: TestClient) -> None:
"""Test service detail page renders without errors."""
response = client.get("/service/test-service")
assert response.status_code == 200
assert "test-service" in response.text
class TestPartialTemplatesRender:
"""Test that partial templates render without missing variables."""
def test_sidebar_renders(self, client: TestClient) -> None:
"""Test sidebar partial renders without errors."""
response = client.get("/partials/sidebar")
assert response.status_code == 200
assert "Dashboard" in response.text
assert "Console" in response.text
def test_stats_renders(self, client: TestClient) -> None:
"""Test stats partial renders without errors."""
response = client.get("/partials/stats")
assert response.status_code == 200
def test_pending_renders(self, client: TestClient) -> None:
"""Test pending partial renders without errors."""
response = client.get("/partials/pending")
assert response.status_code == 200
def test_services_by_host_renders(self, client: TestClient) -> None:
"""Test services_by_host partial renders without errors."""
response = client.get("/partials/services-by-host")
assert response.status_code == 200