mirror of
https://github.com/basnijholt/compose-farm.git
synced 2026-02-03 14:13:26 +00:00
Compare commits
34 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d8353dbb7e | ||
|
|
2e6146a94b | ||
|
|
87849a8161 | ||
|
|
c8bf792a9a | ||
|
|
d37295fbee | ||
|
|
266f541d35 | ||
|
|
aabdd550ba | ||
|
|
8ff60a1e3e | ||
|
|
2497bd727a | ||
|
|
e37d9d87ba | ||
|
|
80a1906d90 | ||
|
|
282de12336 | ||
|
|
2c5308aea3 | ||
|
|
5057202938 | ||
|
|
5e1b9987dd | ||
|
|
d9c26f7f2c | ||
|
|
adfcd4bb31 | ||
|
|
95f7d9c3cf | ||
|
|
4c1674cfd8 | ||
|
|
f65ca8420e | ||
|
|
85aff2c271 | ||
|
|
61ca24bb8e | ||
|
|
ed36588358 | ||
|
|
80c8079a8c | ||
|
|
763bedf9f6 | ||
|
|
641f7e91a8 | ||
|
|
4e8e925d59 | ||
|
|
d84858dcfb | ||
|
|
3121ee04eb | ||
|
|
a795132a04 | ||
|
|
a6e491575a | ||
|
|
78bf90afd9 | ||
|
|
76b60bdd96 | ||
|
|
98bfb1bf6d |
22
CLAUDE.md
22
CLAUDE.md
@@ -43,6 +43,10 @@ Icons use [Lucide](https://lucide.dev/). Add new icons as macros in `web/templat
|
||||
7. **State tracking**: Tracks where services are deployed for auto-migration
|
||||
8. **Pre-flight checks**: Verifies NFS mounts and Docker networks exist before starting/migrating
|
||||
|
||||
## Code Style
|
||||
|
||||
- **Imports at top level**: Never add imports inside functions unless they are explicitly marked with `# noqa: PLC0415` and a comment explaining it speeds up CLI startup. Heavy modules like `pydantic`, `yaml`, and `rich.table` are lazily imported to keep `cf --help` fast.
|
||||
|
||||
## Communication Notes
|
||||
|
||||
- Clarify ambiguous wording (e.g., homophones like "right"/"write", "their"/"there").
|
||||
@@ -53,6 +57,24 @@ Icons use [Lucide](https://lucide.dev/). Add new icons as macros in `web/templat
|
||||
- **NEVER merge anything into main.** Always commit directly or use fast-forward/rebase.
|
||||
- Never force push.
|
||||
|
||||
## Releases
|
||||
|
||||
Use `gh release create` to create releases. The tag is created automatically.
|
||||
|
||||
```bash
|
||||
# Check current version
|
||||
git tag --sort=-v:refname | head -1
|
||||
|
||||
# Create release (minor version bump: v0.21.1 -> v0.22.0)
|
||||
gh release create v0.22.0 --title "v0.22.0" --notes "release notes here"
|
||||
```
|
||||
|
||||
Versioning:
|
||||
- **Patch** (v0.21.0 → v0.21.1): Bug fixes
|
||||
- **Minor** (v0.21.1 → v0.22.0): New features, non-breaking changes
|
||||
|
||||
Write release notes manually describing what changed. Group by features and bug fixes.
|
||||
|
||||
## Commands Quick Reference
|
||||
|
||||
CLI available as `cf` or `compose-farm`.
|
||||
|
||||
89
README.md
89
README.md
@@ -23,6 +23,9 @@ A minimal CLI tool to run Docker Compose commands across multiple hosts via SSH.
|
||||
- [Best practices](#best-practices)
|
||||
- [What Compose Farm doesn't do](#what-compose-farm-doesnt-do)
|
||||
- [Installation](#installation)
|
||||
- [SSH Authentication](#ssh-authentication)
|
||||
- [SSH Agent (default)](#ssh-agent-default)
|
||||
- [Dedicated SSH Key (recommended for Docker/Web UI)](#dedicated-ssh-key-recommended-for-dockerweb-ui)
|
||||
- [Configuration](#configuration)
|
||||
- [Multi-Host Services](#multi-host-services)
|
||||
- [Config Command](#config-command)
|
||||
@@ -159,6 +162,62 @@ docker run --rm \
|
||||
|
||||
</details>
|
||||
|
||||
## SSH Authentication
|
||||
|
||||
Compose Farm uses SSH to run commands on remote hosts. There are two authentication methods:
|
||||
|
||||
### SSH Agent (default)
|
||||
|
||||
Works out of the box if you have an SSH agent running with your keys loaded:
|
||||
|
||||
```bash
|
||||
# Verify your agent has keys
|
||||
ssh-add -l
|
||||
|
||||
# Run compose-farm commands
|
||||
cf up --all
|
||||
```
|
||||
|
||||
### Dedicated SSH Key (recommended for Docker/Web UI)
|
||||
|
||||
When running compose-farm in Docker, the SSH agent connection can be lost (e.g., after container restart). The `cf ssh` command sets up a dedicated key that persists:
|
||||
|
||||
```bash
|
||||
# Generate key and copy to all configured hosts
|
||||
cf ssh setup
|
||||
|
||||
# Check status
|
||||
cf ssh status
|
||||
```
|
||||
|
||||
This creates `~/.ssh/compose-farm/id_ed25519` (ED25519, no passphrase) and copies the public key to each host's `authorized_keys`. Compose Farm tries the SSH agent first, then falls back to this key.
|
||||
|
||||
<details><summary>🐳 Docker volume options for SSH keys</summary>
|
||||
|
||||
When running in Docker, mount a volume to persist the SSH keys. Choose ONE option and use it for both `cf` and `web` services:
|
||||
|
||||
**Option 1: Host path (default)** - keys at `~/.ssh/compose-farm/id_ed25519`
|
||||
```yaml
|
||||
volumes:
|
||||
- ~/.ssh/compose-farm:/root/.ssh
|
||||
```
|
||||
|
||||
**Option 2: Named volume** - managed by Docker
|
||||
```yaml
|
||||
volumes:
|
||||
- cf-ssh:/root/.ssh
|
||||
```
|
||||
|
||||
Run setup once after starting the container (while the SSH agent still works):
|
||||
|
||||
```bash
|
||||
docker compose exec web cf ssh setup
|
||||
```
|
||||
|
||||
The keys will persist across restarts.
|
||||
|
||||
</details>
|
||||
|
||||
## Configuration
|
||||
|
||||
Create `~/.config/compose-farm/compose-farm.yaml` (or `./compose-farm.yaml` in your working directory):
|
||||
@@ -344,10 +403,11 @@ Full `--help` output for each command. See the [Usage](#usage) table above for a
|
||||
│ check Validate configuration, traefik labels, mounts, and networks. │
|
||||
│ init-network Create Docker network on hosts with consistent settings. │
|
||||
│ config Manage compose-farm configuration files. │
|
||||
│ ssh Manage SSH keys for passwordless authentication. │
|
||||
╰──────────────────────────────────────────────────────────────────────────────╯
|
||||
╭─ Monitoring ─────────────────────────────────────────────────────────────────╮
|
||||
│ logs Show service logs. │
|
||||
│ ps Show status of all services. │
|
||||
│ ps Show status of services. │
|
||||
│ stats Show overview statistics for hosts and services. │
|
||||
╰──────────────────────────────────────────────────────────────────────────────╯
|
||||
╭─ Server ─────────────────────────────────────────────────────────────────────╮
|
||||
@@ -773,6 +833,21 @@ Full `--help` output for each command. See the [Usage](#usage) table above for a
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
<summary>See the output of <code>cf ssh --help</code></summary>
|
||||
|
||||
<!-- CODE:BASH:START -->
|
||||
<!-- echo '```yaml' -->
|
||||
<!-- export NO_COLOR=1 -->
|
||||
<!-- export TERM=dumb -->
|
||||
<!-- export TERMINAL_WIDTH=90 -->
|
||||
<!-- cf ssh --help -->
|
||||
<!-- echo '```' -->
|
||||
<!-- CODE:END -->
|
||||
|
||||
</details>
|
||||
|
||||
**Monitoring**
|
||||
|
||||
<details>
|
||||
@@ -829,11 +904,19 @@ Full `--help` output for each command. See the [Usage](#usage) table above for a
|
||||
<!-- ⚠️ This content is auto-generated by `markdown-code-runner`. -->
|
||||
```yaml
|
||||
|
||||
Usage: cf ps [OPTIONS]
|
||||
Usage: cf ps [OPTIONS] [SERVICES]...
|
||||
|
||||
Show status of all services.
|
||||
Show status of services.
|
||||
|
||||
Without arguments: shows all services (same as --all). With service names:
|
||||
shows only those services. With --host: shows services on that host.
|
||||
|
||||
╭─ Arguments ──────────────────────────────────────────────────────────────────╮
|
||||
│ services [SERVICES]... Services to operate on │
|
||||
╰──────────────────────────────────────────────────────────────────────────────╯
|
||||
╭─ Options ────────────────────────────────────────────────────────────────────╮
|
||||
│ --all -a Run on all services │
|
||||
│ --host -H TEXT Filter to services on this host │
|
||||
│ --config -c PATH Path to config file │
|
||||
│ --help -h Show this message and exit. │
|
||||
╰──────────────────────────────────────────────────────────────────────────────╯
|
||||
|
||||
@@ -5,6 +5,12 @@ services:
|
||||
- ${SSH_AUTH_SOCK}:/ssh-agent:ro
|
||||
# Compose directory (contains compose files AND compose-farm.yaml config)
|
||||
- ${CF_COMPOSE_DIR:-/opt/stacks}:${CF_COMPOSE_DIR:-/opt/stacks}
|
||||
# SSH keys for passwordless auth (generated by `cf ssh setup`)
|
||||
# Choose ONE option below (use the same option for both cf and web services):
|
||||
# Option 1: Host path (default) - keys at ~/.ssh/compose-farm/id_ed25519
|
||||
- ${CF_SSH_DIR:-~/.ssh/compose-farm}:/root/.ssh
|
||||
# Option 2: Named volume - managed by Docker, shared between services
|
||||
# - cf-ssh:/root/.ssh
|
||||
environment:
|
||||
- SSH_AUTH_SOCK=/ssh-agent
|
||||
# Config file path (state stored alongside it)
|
||||
@@ -12,13 +18,21 @@ services:
|
||||
|
||||
web:
|
||||
image: ghcr.io/basnijholt/compose-farm:latest
|
||||
restart: unless-stopped
|
||||
command: web --host 0.0.0.0 --port 9000
|
||||
volumes:
|
||||
- ${SSH_AUTH_SOCK}:/ssh-agent:ro
|
||||
- ${CF_COMPOSE_DIR:-/opt/stacks}:${CF_COMPOSE_DIR:-/opt/stacks}
|
||||
# SSH keys - use the SAME option as cf service above
|
||||
# Option 1: Host path (default)
|
||||
- ${CF_SSH_DIR:-~/.ssh/compose-farm}:/root/.ssh
|
||||
# Option 2: Named volume
|
||||
# - cf-ssh:/root/.ssh
|
||||
environment:
|
||||
- SSH_AUTH_SOCK=/ssh-agent
|
||||
- CF_CONFIG=${CF_COMPOSE_DIR:-/opt/stacks}/compose-farm.yaml
|
||||
# Used to detect self-updates and run via SSH to survive container restart
|
||||
- CF_WEB_SERVICE=compose-farm
|
||||
labels:
|
||||
- traefik.enable=true
|
||||
- traefik.http.routers.compose-farm.rule=Host(`compose-farm.${DOMAIN}`)
|
||||
@@ -32,3 +46,7 @@ services:
|
||||
networks:
|
||||
mynetwork:
|
||||
external: true
|
||||
|
||||
volumes:
|
||||
cf-ssh:
|
||||
# Only used if Option 2 is selected above
|
||||
|
||||
@@ -8,6 +8,7 @@ from compose_farm.cli import (
|
||||
lifecycle, # noqa: F401
|
||||
management, # noqa: F401
|
||||
monitoring, # noqa: F401
|
||||
ssh, # noqa: F401
|
||||
web, # noqa: F401
|
||||
)
|
||||
|
||||
|
||||
@@ -18,7 +18,15 @@ from rich.progress import (
|
||||
TimeElapsedColumn,
|
||||
)
|
||||
|
||||
from compose_farm.console import console, err_console
|
||||
from compose_farm.console import (
|
||||
MSG_HOST_NOT_FOUND,
|
||||
MSG_SERVICE_NOT_FOUND,
|
||||
console,
|
||||
print_error,
|
||||
print_hint,
|
||||
print_success,
|
||||
print_warning,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Coroutine, Generator
|
||||
@@ -27,6 +35,7 @@ if TYPE_CHECKING:
|
||||
from compose_farm.executor import CommandResult
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
|
||||
# --- Shared CLI Options ---
|
||||
@@ -56,6 +65,13 @@ _MISSING_PATH_PREVIEW_LIMIT = 2
|
||||
_STATS_PREVIEW_LIMIT = 3 # Max number of pending migrations to show by name
|
||||
|
||||
|
||||
def format_host(host: str | list[str]) -> str:
|
||||
"""Format a host value for display."""
|
||||
if isinstance(host, list):
|
||||
return ", ".join(host)
|
||||
return host
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def progress_bar(
|
||||
label: str, total: int, *, initial_description: str = "[dim]connecting...[/]"
|
||||
@@ -81,6 +97,37 @@ def progress_bar(
|
||||
yield progress, task_id
|
||||
|
||||
|
||||
def run_parallel_with_progress(
|
||||
label: str,
|
||||
items: list[_T],
|
||||
async_fn: Callable[[_T], Coroutine[None, None, _R]],
|
||||
) -> list[_R]:
|
||||
"""Run async tasks in parallel with a progress bar.
|
||||
|
||||
Args:
|
||||
label: Progress bar label (e.g., "Discovering", "Querying hosts")
|
||||
items: List of items to process
|
||||
async_fn: Async function to call for each item, returns tuple where
|
||||
first element is used for progress description
|
||||
|
||||
Returns:
|
||||
List of results from async_fn in completion order.
|
||||
|
||||
"""
|
||||
|
||||
async def gather() -> list[_R]:
|
||||
with progress_bar(label, len(items)) as (progress, task_id):
|
||||
tasks = [asyncio.create_task(async_fn(item)) for item in items]
|
||||
results: list[_R] = []
|
||||
for coro in asyncio.as_completed(tasks):
|
||||
result = await coro
|
||||
results.append(result)
|
||||
progress.update(task_id, advance=1, description=f"[cyan]{result[0]}[/]") # type: ignore[index]
|
||||
return results
|
||||
|
||||
return asyncio.run(gather())
|
||||
|
||||
|
||||
def load_config_or_exit(config_path: Path | None) -> Config:
|
||||
"""Load config or exit with a friendly error message."""
|
||||
# Lazy import: pydantic adds ~50ms to startup, only load when actually needed
|
||||
@@ -89,7 +136,7 @@ def load_config_or_exit(config_path: Path | None) -> Config:
|
||||
try:
|
||||
return load_config(config_path)
|
||||
except FileNotFoundError as e:
|
||||
err_console.print(f"[red]✗[/] {e}")
|
||||
print_error(str(e))
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
|
||||
@@ -97,29 +144,54 @@ def get_services(
|
||||
services: list[str],
|
||||
all_services: bool,
|
||||
config_path: Path | None,
|
||||
*,
|
||||
host: str | None = None,
|
||||
default_all: bool = False,
|
||||
) -> tuple[list[str], Config]:
|
||||
"""Resolve service list and load config.
|
||||
|
||||
Handles three mutually exclusive selection methods:
|
||||
- Explicit service names
|
||||
- --all flag
|
||||
- --host filter
|
||||
|
||||
Args:
|
||||
services: Explicit service names
|
||||
all_services: Whether --all was specified
|
||||
config_path: Path to config file
|
||||
host: Filter to services on this host
|
||||
default_all: If True, default to all services when nothing specified (for ps)
|
||||
|
||||
Supports "." as shorthand for the current directory name.
|
||||
|
||||
"""
|
||||
validate_service_selection(services, all_services, host)
|
||||
config = load_config_or_exit(config_path)
|
||||
|
||||
if host is not None:
|
||||
validate_host(config, host)
|
||||
svc_list = [s for s in config.services if host in config.get_hosts(s)]
|
||||
if not svc_list:
|
||||
print_warning(f"No services configured for host [magenta]{host}[/]")
|
||||
raise typer.Exit(0)
|
||||
return svc_list, config
|
||||
|
||||
if all_services:
|
||||
return list(config.services.keys()), config
|
||||
|
||||
if not services:
|
||||
err_console.print("[red]✗[/] Specify services or use --all")
|
||||
if default_all:
|
||||
return list(config.services.keys()), config
|
||||
print_error("Specify services or use [bold]--all[/] / [bold]--host[/]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Resolve "." to current directory name
|
||||
resolved = [Path.cwd().name if svc == "." else svc for svc in services]
|
||||
|
||||
# Validate all services exist in config
|
||||
unknown = [svc for svc in resolved if svc not in config.services]
|
||||
if unknown:
|
||||
for svc in unknown:
|
||||
err_console.print(f"[red]✗[/] Unknown service: [cyan]{svc}[/]")
|
||||
err_console.print("[dim]Hint: Add the service to compose-farm.yaml or use --all[/]")
|
||||
raise typer.Exit(1)
|
||||
validate_services(
|
||||
config, resolved, hint="Add the service to compose-farm.yaml or use [bold]--all[/]"
|
||||
)
|
||||
|
||||
return resolved, config
|
||||
|
||||
@@ -143,21 +215,19 @@ def report_results(results: list[CommandResult]) -> None:
|
||||
console.print() # Blank line before summary
|
||||
if failed:
|
||||
for r in failed:
|
||||
err_console.print(
|
||||
f"[red]✗[/] [cyan]{r.service}[/] failed with exit code {r.exit_code}"
|
||||
)
|
||||
print_error(f"[cyan]{r.service}[/] failed with exit code {r.exit_code}")
|
||||
console.print()
|
||||
console.print(
|
||||
f"[green]✓[/] {len(succeeded)}/{len(results)} services succeeded, "
|
||||
f"[red]✗[/] {len(failed)} failed"
|
||||
)
|
||||
else:
|
||||
console.print(f"[green]✓[/] All {len(results)} services succeeded")
|
||||
print_success(f"All {len(results)} services succeeded")
|
||||
|
||||
elif failed:
|
||||
# Single service failed
|
||||
r = failed[0]
|
||||
err_console.print(f"[red]✗[/] [cyan]{r.service}[/] failed with exit code {r.exit_code}")
|
||||
print_error(f"[cyan]{r.service}[/] failed with exit code {r.exit_code}")
|
||||
|
||||
if failed:
|
||||
raise typer.Exit(1)
|
||||
@@ -197,28 +267,69 @@ def maybe_regenerate_traefik(
|
||||
cfg.traefik_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
cfg.traefik_file.write_text(new_content)
|
||||
console.print() # Ensure we're on a new line after streaming output
|
||||
console.print(f"[green]✓[/] Traefik config updated: {cfg.traefik_file}")
|
||||
print_success(f"Traefik config updated: {cfg.traefik_file}")
|
||||
|
||||
for warning in warnings:
|
||||
err_console.print(f"[yellow]![/] {warning}")
|
||||
print_warning(warning)
|
||||
except (FileNotFoundError, ValueError) as exc:
|
||||
err_console.print(f"[yellow]![/] Failed to update traefik config: {exc}")
|
||||
print_warning(f"Failed to update traefik config: {exc}")
|
||||
|
||||
|
||||
def validate_services(cfg: Config, services: list[str], *, hint: str | None = None) -> None:
|
||||
"""Validate that all services exist in config. Exits with error if any not found."""
|
||||
invalid = [s for s in services if s not in cfg.services]
|
||||
if invalid:
|
||||
for svc in invalid:
|
||||
print_error(MSG_SERVICE_NOT_FOUND.format(name=svc))
|
||||
if hint:
|
||||
print_hint(hint)
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
def validate_host(cfg: Config, host: str) -> None:
|
||||
"""Validate that a host exists in config. Exits with error if not found."""
|
||||
if host not in cfg.hosts:
|
||||
print_error(MSG_HOST_NOT_FOUND.format(name=host))
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
def validate_hosts(cfg: Config, hosts: list[str]) -> None:
|
||||
"""Validate that all hosts exist in config. Exits with error if any not found."""
|
||||
invalid = [h for h in hosts if h not in cfg.hosts]
|
||||
if invalid:
|
||||
for h in invalid:
|
||||
print_error(MSG_HOST_NOT_FOUND.format(name=h))
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
def validate_host_for_service(cfg: Config, service: str, host: str) -> None:
|
||||
"""Validate that a host is valid for a service."""
|
||||
if host not in cfg.hosts:
|
||||
err_console.print(f"[red]✗[/] Host '{host}' not found in config")
|
||||
raise typer.Exit(1)
|
||||
validate_host(cfg, host)
|
||||
allowed_hosts = cfg.get_hosts(service)
|
||||
if host not in allowed_hosts:
|
||||
err_console.print(
|
||||
f"[red]✗[/] Service '{service}' is not configured for host '{host}' "
|
||||
print_error(
|
||||
f"Service [cyan]{service}[/] is not configured for host [magenta]{host}[/] "
|
||||
f"(configured: {', '.join(allowed_hosts)})"
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
def validate_service_selection(
|
||||
services: list[str] | None,
|
||||
all_services: bool,
|
||||
host: str | None,
|
||||
) -> None:
|
||||
"""Validate that only one service selection method is used.
|
||||
|
||||
The three selection methods (explicit services, --all, --host) are mutually
|
||||
exclusive. This ensures consistent behavior across all commands.
|
||||
"""
|
||||
methods = sum([bool(services), all_services, host is not None])
|
||||
if methods > 1:
|
||||
print_error("Use only one of: service names, [bold]--all[/], or [bold]--host[/]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
def run_host_operation(
|
||||
cfg: Config,
|
||||
svc_list: list[str],
|
||||
|
||||
@@ -14,7 +14,7 @@ from typing import Annotated
|
||||
import typer
|
||||
|
||||
from compose_farm.cli.app import app
|
||||
from compose_farm.console import console, err_console
|
||||
from compose_farm.console import MSG_CONFIG_NOT_FOUND, console, print_error, print_success
|
||||
from compose_farm.paths import config_search_paths, default_config_path, find_config_path
|
||||
|
||||
config_app = typer.Typer(
|
||||
@@ -66,8 +66,8 @@ def _generate_template() -> str:
|
||||
template_file = resources.files("compose_farm") / "example-config.yaml"
|
||||
return template_file.read_text(encoding="utf-8")
|
||||
except FileNotFoundError as e:
|
||||
err_console.print("[red]Example config template is missing from the package.[/red]")
|
||||
err_console.print("Reinstall compose-farm or report this issue.")
|
||||
print_error("Example config template is missing from the package")
|
||||
console.print("Reinstall compose-farm or report this issue.")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
|
||||
@@ -80,6 +80,23 @@ def _get_config_file(path: Path | None) -> Path | None:
|
||||
return config_path.resolve() if config_path else None
|
||||
|
||||
|
||||
def _report_no_config_found() -> None:
|
||||
"""Report that no config file was found in search paths."""
|
||||
console.print("[yellow]No config file found.[/yellow]")
|
||||
console.print("\nSearched locations:")
|
||||
for p in config_search_paths():
|
||||
status = "[green]exists[/green]" if p.exists() else "[dim]not found[/dim]"
|
||||
console.print(f" - {p} ({status})")
|
||||
console.print("\nRun [bold cyan]cf config init[/bold cyan] to create one.")
|
||||
|
||||
|
||||
def _report_config_path_not_exists(config_file: Path) -> None:
|
||||
"""Report that an explicit config path doesn't exist."""
|
||||
console.print("[yellow]Config file not found.[/yellow]")
|
||||
console.print(f"\nProvided path does not exist: [cyan]{config_file}[/cyan]")
|
||||
console.print("\nRun [bold cyan]cf config init[/bold cyan] to create one.")
|
||||
|
||||
|
||||
@config_app.command("init")
|
||||
def config_init(
|
||||
path: _PathOption = None,
|
||||
@@ -107,7 +124,7 @@ def config_init(
|
||||
template_content = _generate_template()
|
||||
target_path.write_text(template_content, encoding="utf-8")
|
||||
|
||||
console.print(f"[green]✓[/] Config file created at: {target_path}")
|
||||
print_success(f"Config file created at: {target_path}")
|
||||
console.print("\n[dim]Edit the file to customize your settings:[/dim]")
|
||||
console.print(" [cyan]cf config edit[/cyan]")
|
||||
|
||||
@@ -123,17 +140,11 @@ def config_edit(
|
||||
config_file = _get_config_file(path)
|
||||
|
||||
if config_file is None:
|
||||
console.print("[yellow]No config file found.[/yellow]")
|
||||
console.print("\nRun [bold cyan]cf config init[/bold cyan] to create one.")
|
||||
console.print("\nSearched locations:")
|
||||
for p in config_search_paths():
|
||||
console.print(f" - {p}")
|
||||
_report_no_config_found()
|
||||
raise typer.Exit(1)
|
||||
|
||||
if not config_file.exists():
|
||||
console.print("[yellow]Config file not found.[/yellow]")
|
||||
console.print(f"\nProvided path does not exist: [cyan]{config_file}[/cyan]")
|
||||
console.print("\nRun [bold cyan]cf config init[/bold cyan] to create one.")
|
||||
_report_config_path_not_exists(config_file)
|
||||
raise typer.Exit(1)
|
||||
|
||||
editor = _get_editor()
|
||||
@@ -142,21 +153,21 @@ def config_edit(
|
||||
try:
|
||||
editor_cmd = shlex.split(editor, posix=os.name != "nt")
|
||||
except ValueError as e:
|
||||
err_console.print("[red]Invalid editor command. Check $EDITOR/$VISUAL.[/red]")
|
||||
print_error("Invalid editor command. Check [bold]$EDITOR[/]/[bold]$VISUAL[/]")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
if not editor_cmd:
|
||||
err_console.print("[red]Editor command is empty.[/red]")
|
||||
print_error("Editor command is empty")
|
||||
raise typer.Exit(1)
|
||||
|
||||
try:
|
||||
subprocess.run([*editor_cmd, str(config_file)], check=True)
|
||||
except FileNotFoundError:
|
||||
err_console.print(f"[red]Editor '{editor_cmd[0]}' not found.[/red]")
|
||||
err_console.print("Set $EDITOR environment variable to your preferred editor.")
|
||||
print_error(f"Editor [cyan]{editor_cmd[0]}[/] not found")
|
||||
console.print("Set [bold]$EDITOR[/] environment variable to your preferred editor.")
|
||||
raise typer.Exit(1) from None
|
||||
except subprocess.CalledProcessError as e:
|
||||
err_console.print(f"[red]Editor exited with error code {e.returncode}[/red]")
|
||||
print_error(f"Editor exited with error code {e.returncode}")
|
||||
raise typer.Exit(e.returncode) from None
|
||||
|
||||
|
||||
@@ -169,18 +180,11 @@ def config_show(
|
||||
config_file = _get_config_file(path)
|
||||
|
||||
if config_file is None:
|
||||
console.print("[yellow]No config file found.[/yellow]")
|
||||
console.print("\nSearched locations:")
|
||||
for p in config_search_paths():
|
||||
status = "[green]exists[/green]" if p.exists() else "[dim]not found[/dim]"
|
||||
console.print(f" - {p} ({status})")
|
||||
console.print("\nRun [bold cyan]cf config init[/bold cyan] to create one.")
|
||||
_report_no_config_found()
|
||||
raise typer.Exit(0)
|
||||
|
||||
if not config_file.exists():
|
||||
console.print("[yellow]Config file not found.[/yellow]")
|
||||
console.print(f"\nProvided path does not exist: [cyan]{config_file}[/cyan]")
|
||||
console.print("\nRun [bold cyan]cf config init[/bold cyan] to create one.")
|
||||
_report_config_path_not_exists(config_file)
|
||||
raise typer.Exit(1)
|
||||
|
||||
content = config_file.read_text(encoding="utf-8")
|
||||
@@ -207,11 +211,7 @@ def config_path(
|
||||
config_file = _get_config_file(path)
|
||||
|
||||
if config_file is None:
|
||||
console.print("[yellow]No config file found.[/yellow]")
|
||||
console.print("\nSearched locations:")
|
||||
for p in config_search_paths():
|
||||
status = "[green]exists[/green]" if p.exists() else "[dim]not found[/dim]"
|
||||
console.print(f" - {p} ({status})")
|
||||
_report_no_config_found()
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Just print the path for easy piping
|
||||
@@ -226,7 +226,7 @@ def config_validate(
|
||||
config_file = _get_config_file(path)
|
||||
|
||||
if config_file is None:
|
||||
err_console.print("[red]✗[/] No config file found")
|
||||
print_error(MSG_CONFIG_NOT_FOUND)
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Lazy import: pydantic adds ~50ms to startup, only load when actually needed
|
||||
@@ -235,13 +235,13 @@ def config_validate(
|
||||
try:
|
||||
cfg = load_config(config_file)
|
||||
except FileNotFoundError as e:
|
||||
err_console.print(f"[red]✗[/] {e}")
|
||||
print_error(str(e))
|
||||
raise typer.Exit(1) from e
|
||||
except Exception as e:
|
||||
err_console.print(f"[red]✗[/] Invalid config: {e}")
|
||||
print_error(f"Invalid config: {e}")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
console.print(f"[green]✓[/] Valid config: {config_file}")
|
||||
print_success(f"Valid config: {config_file}")
|
||||
console.print(f" Hosts: {len(cfg.hosts)}")
|
||||
console.print(f" Services: {len(cfg.services)}")
|
||||
|
||||
@@ -268,11 +268,11 @@ def config_symlink(
|
||||
target_path = (target or Path("compose-farm.yaml")).expanduser().resolve()
|
||||
|
||||
if not target_path.exists():
|
||||
err_console.print(f"[red]✗[/] Target config file not found: {target_path}")
|
||||
print_error(f"Target config file not found: {target_path}")
|
||||
raise typer.Exit(1)
|
||||
|
||||
if not target_path.is_file():
|
||||
err_console.print(f"[red]✗[/] Target is not a file: {target_path}")
|
||||
print_error(f"Target is not a file: {target_path}")
|
||||
raise typer.Exit(1)
|
||||
|
||||
symlink_path = default_config_path()
|
||||
@@ -282,7 +282,7 @@ def config_symlink(
|
||||
if symlink_path.is_symlink():
|
||||
current_target = symlink_path.resolve() if symlink_path.exists() else None
|
||||
if current_target == target_path:
|
||||
console.print(f"[green]✓[/] Symlink already points to: {target_path}")
|
||||
print_success(f"Symlink already points to: {target_path}")
|
||||
return
|
||||
# Update existing symlink
|
||||
if not force:
|
||||
@@ -294,8 +294,8 @@ def config_symlink(
|
||||
symlink_path.unlink()
|
||||
else:
|
||||
# Regular file exists
|
||||
err_console.print(f"[red]✗[/] A regular file exists at: {symlink_path}")
|
||||
err_console.print(" Back it up or remove it first, then retry.")
|
||||
print_error(f"A regular file exists at: {symlink_path}")
|
||||
console.print(" Back it up or remove it first, then retry.")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Create parent directories
|
||||
@@ -304,7 +304,7 @@ def config_symlink(
|
||||
# Create symlink with absolute path
|
||||
symlink_path.symlink_to(target_path)
|
||||
|
||||
console.print("[green]✓[/] Created symlink:")
|
||||
print_success("Created symlink:")
|
||||
console.print(f" {symlink_path}")
|
||||
console.print(f" -> {target_path}")
|
||||
|
||||
|
||||
@@ -15,24 +15,22 @@ from compose_farm.cli.common import (
|
||||
ConfigOption,
|
||||
HostOption,
|
||||
ServicesArg,
|
||||
format_host,
|
||||
get_services,
|
||||
load_config_or_exit,
|
||||
maybe_regenerate_traefik,
|
||||
report_results,
|
||||
run_async,
|
||||
run_host_operation,
|
||||
)
|
||||
from compose_farm.console import console, err_console
|
||||
from compose_farm.console import MSG_DRY_RUN, console, print_error, print_success
|
||||
from compose_farm.executor import run_on_services, run_sequential_on_services
|
||||
from compose_farm.operations import stop_orphaned_services, up_services
|
||||
from compose_farm.state import (
|
||||
add_service_to_host,
|
||||
get_orphaned_services,
|
||||
get_service_host,
|
||||
get_services_needing_migration,
|
||||
get_services_not_in_state,
|
||||
remove_service,
|
||||
remove_service_from_host,
|
||||
)
|
||||
|
||||
|
||||
@@ -44,14 +42,7 @@ def up(
|
||||
config: ConfigOption = None,
|
||||
) -> None:
|
||||
"""Start services (docker compose up -d). Auto-migrates if host changed."""
|
||||
svc_list, cfg = get_services(services or [], all_services, config)
|
||||
|
||||
# Per-host operation: run on specific host only
|
||||
if host:
|
||||
run_host_operation(cfg, svc_list, host, "up -d", "Starting", add_service_to_host)
|
||||
return
|
||||
|
||||
# Normal operation: use up_services with migration logic
|
||||
svc_list, cfg = get_services(services or [], all_services, config, host=host)
|
||||
results = run_async(up_services(cfg, svc_list, raw=True))
|
||||
maybe_regenerate_traefik(cfg, results)
|
||||
report_results(results)
|
||||
@@ -71,17 +62,19 @@ def down(
|
||||
config: ConfigOption = None,
|
||||
) -> None:
|
||||
"""Stop services (docker compose down)."""
|
||||
# Handle --orphaned flag
|
||||
# Handle --orphaned flag (mutually exclusive with other selection methods)
|
||||
if orphaned:
|
||||
if services or all_services or host:
|
||||
err_console.print("[red]✗[/] Cannot use --orphaned with services, --all, or --host")
|
||||
print_error(
|
||||
"Cannot combine [bold]--orphaned[/] with services, [bold]--all[/], or [bold]--host[/]"
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
cfg = load_config_or_exit(config)
|
||||
orphaned_services = get_orphaned_services(cfg)
|
||||
|
||||
if not orphaned_services:
|
||||
console.print("[green]✓[/] No orphaned services to stop")
|
||||
print_success("No orphaned services to stop")
|
||||
return
|
||||
|
||||
console.print(
|
||||
@@ -92,14 +85,7 @@ def down(
|
||||
report_results(results)
|
||||
return
|
||||
|
||||
svc_list, cfg = get_services(services or [], all_services, config)
|
||||
|
||||
# Per-host operation: run on specific host only
|
||||
if host:
|
||||
run_host_operation(cfg, svc_list, host, "down", "Stopping", remove_service_from_host)
|
||||
return
|
||||
|
||||
# Normal operation
|
||||
svc_list, cfg = get_services(services or [], all_services, config, host=host)
|
||||
raw = len(svc_list) == 1
|
||||
results = run_async(run_on_services(cfg, svc_list, "down", raw=raw))
|
||||
|
||||
@@ -162,13 +148,6 @@ def update(
|
||||
report_results(results)
|
||||
|
||||
|
||||
def _format_host(host: str | list[str]) -> str:
|
||||
"""Format a host value for display."""
|
||||
if isinstance(host, list):
|
||||
return ", ".join(host)
|
||||
return host
|
||||
|
||||
|
||||
def _report_pending_migrations(cfg: Config, migrations: list[str]) -> None:
|
||||
"""Report services that need migration."""
|
||||
console.print(f"[cyan]Services to migrate ({len(migrations)}):[/]")
|
||||
@@ -182,14 +161,14 @@ def _report_pending_orphans(orphaned: dict[str, str | list[str]]) -> None:
|
||||
"""Report orphaned services that will be stopped."""
|
||||
console.print(f"[yellow]Orphaned services to stop ({len(orphaned)}):[/]")
|
||||
for svc, hosts in orphaned.items():
|
||||
console.print(f" [cyan]{svc}[/] on [magenta]{_format_host(hosts)}[/]")
|
||||
console.print(f" [cyan]{svc}[/] on [magenta]{format_host(hosts)}[/]")
|
||||
|
||||
|
||||
def _report_pending_starts(cfg: Config, missing: list[str]) -> None:
|
||||
"""Report services that will be started."""
|
||||
console.print(f"[green]Services to start ({len(missing)}):[/]")
|
||||
for svc in missing:
|
||||
target = _format_host(cfg.get_hosts(svc))
|
||||
target = format_host(cfg.get_hosts(svc))
|
||||
console.print(f" [cyan]{svc}[/] on [magenta]{target}[/]")
|
||||
|
||||
|
||||
@@ -197,7 +176,7 @@ def _report_pending_refresh(cfg: Config, to_refresh: list[str]) -> None:
|
||||
"""Report services that will be refreshed."""
|
||||
console.print(f"[blue]Services to refresh ({len(to_refresh)}):[/]")
|
||||
for svc in to_refresh:
|
||||
target = _format_host(cfg.get_hosts(svc))
|
||||
target = format_host(cfg.get_hosts(svc))
|
||||
console.print(f" [cyan]{svc}[/] on [magenta]{target}[/]")
|
||||
|
||||
|
||||
@@ -245,7 +224,7 @@ def apply(
|
||||
has_refresh = bool(to_refresh)
|
||||
|
||||
if not has_orphans and not has_migrations and not has_missing and not has_refresh:
|
||||
console.print("[green]✓[/] Nothing to apply - reality matches config")
|
||||
print_success("Nothing to apply - reality matches config")
|
||||
return
|
||||
|
||||
# Report what will be done
|
||||
@@ -259,7 +238,7 @@ def apply(
|
||||
_report_pending_refresh(cfg, to_refresh)
|
||||
|
||||
if dry_run:
|
||||
console.print("\n[dim](dry-run: no changes made)[/]")
|
||||
console.print(f"\n{MSG_DRY_RUN}")
|
||||
return
|
||||
|
||||
# Execute changes
|
||||
|
||||
@@ -8,7 +8,6 @@ from pathlib import Path # noqa: TC003
|
||||
from typing import TYPE_CHECKING, Annotated
|
||||
|
||||
import typer
|
||||
from rich.progress import Progress, TaskID # noqa: TC002
|
||||
|
||||
from compose_farm.cli.app import app
|
||||
from compose_farm.cli.common import (
|
||||
@@ -17,16 +16,25 @@ from compose_farm.cli.common import (
|
||||
ConfigOption,
|
||||
LogPathOption,
|
||||
ServicesArg,
|
||||
format_host,
|
||||
get_services,
|
||||
load_config_or_exit,
|
||||
progress_bar,
|
||||
run_async,
|
||||
run_parallel_with_progress,
|
||||
validate_hosts,
|
||||
validate_services,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from compose_farm.config import Config
|
||||
|
||||
from compose_farm.console import console, err_console
|
||||
from compose_farm.console import (
|
||||
MSG_DRY_RUN,
|
||||
console,
|
||||
print_error,
|
||||
print_success,
|
||||
print_warning,
|
||||
)
|
||||
from compose_farm.executor import (
|
||||
CommandResult,
|
||||
is_local,
|
||||
@@ -54,21 +62,12 @@ from compose_farm.traefik import generate_traefik_config, render_traefik_config
|
||||
|
||||
def _discover_services(cfg: Config) -> dict[str, str | list[str]]:
|
||||
"""Discover running services with a progress bar."""
|
||||
|
||||
async def gather_with_progress(
|
||||
progress: Progress, task_id: TaskID
|
||||
) -> dict[str, str | list[str]]:
|
||||
tasks = [asyncio.create_task(discover_service_host(cfg, s)) for s in cfg.services]
|
||||
discovered: dict[str, str | list[str]] = {}
|
||||
for coro in asyncio.as_completed(tasks):
|
||||
service, host = await coro
|
||||
if host is not None:
|
||||
discovered[service] = host
|
||||
progress.update(task_id, advance=1, description=f"[cyan]{service}[/]")
|
||||
return discovered
|
||||
|
||||
with progress_bar("Discovering", len(cfg.services)) as (progress, task_id):
|
||||
return asyncio.run(gather_with_progress(progress, task_id))
|
||||
results = run_parallel_with_progress(
|
||||
"Discovering",
|
||||
list(cfg.services),
|
||||
lambda s: discover_service_host(cfg, s),
|
||||
)
|
||||
return {svc: host for svc, host in results if host is not None}
|
||||
|
||||
|
||||
def _snapshot_services(
|
||||
@@ -77,36 +76,22 @@ def _snapshot_services(
|
||||
log_path: Path | None,
|
||||
) -> Path:
|
||||
"""Capture image digests with a progress bar."""
|
||||
|
||||
async def collect_service(service: str, now: datetime) -> list[SnapshotEntry]:
|
||||
try:
|
||||
return await collect_service_entries(cfg, service, now=now)
|
||||
except RuntimeError:
|
||||
return []
|
||||
|
||||
async def gather_with_progress(
|
||||
progress: Progress, task_id: TaskID, now: datetime, svc_list: list[str]
|
||||
) -> list[SnapshotEntry]:
|
||||
# Map tasks to service names so we can update description
|
||||
task_to_service = {asyncio.create_task(collect_service(s, now)): s for s in svc_list}
|
||||
all_entries: list[SnapshotEntry] = []
|
||||
for coro in asyncio.as_completed(list(task_to_service.keys())):
|
||||
entries = await coro
|
||||
all_entries.extend(entries)
|
||||
# Find which service just completed (by checking done tasks)
|
||||
for t, svc in task_to_service.items():
|
||||
if t.done() and not hasattr(t, "_reported"):
|
||||
t._reported = True # type: ignore[attr-defined]
|
||||
progress.update(task_id, advance=1, description=f"[cyan]{svc}[/]")
|
||||
break
|
||||
return all_entries
|
||||
|
||||
effective_log_path = log_path or DEFAULT_LOG_PATH
|
||||
now_dt = datetime.now(UTC)
|
||||
now_iso = isoformat(now_dt)
|
||||
|
||||
with progress_bar("Capturing", len(services)) as (progress, task_id):
|
||||
snapshot_entries = asyncio.run(gather_with_progress(progress, task_id, now_dt, services))
|
||||
async def collect_service(service: str) -> tuple[str, list[SnapshotEntry]]:
|
||||
try:
|
||||
return service, await collect_service_entries(cfg, service, now=now_dt)
|
||||
except RuntimeError:
|
||||
return service, []
|
||||
|
||||
results = run_parallel_with_progress(
|
||||
"Capturing",
|
||||
services,
|
||||
collect_service,
|
||||
)
|
||||
snapshot_entries = [entry for _, entries in results for entry in entries]
|
||||
|
||||
if not snapshot_entries:
|
||||
msg = "No image digests were captured"
|
||||
@@ -119,13 +104,6 @@ def _snapshot_services(
|
||||
return effective_log_path
|
||||
|
||||
|
||||
def _format_host(host: str | list[str]) -> str:
|
||||
"""Format a host value for display."""
|
||||
if isinstance(host, list):
|
||||
return ", ".join(host)
|
||||
return host
|
||||
|
||||
|
||||
def _report_sync_changes(
|
||||
added: list[str],
|
||||
removed: list[str],
|
||||
@@ -137,14 +115,14 @@ def _report_sync_changes(
|
||||
if added:
|
||||
console.print(f"\nNew services found ({len(added)}):")
|
||||
for service in sorted(added):
|
||||
host_str = _format_host(discovered[service])
|
||||
host_str = format_host(discovered[service])
|
||||
console.print(f" [green]+[/] [cyan]{service}[/] on [magenta]{host_str}[/]")
|
||||
|
||||
if changed:
|
||||
console.print(f"\nServices on different hosts ({len(changed)}):")
|
||||
for service, old_host, new_host in sorted(changed):
|
||||
old_str = _format_host(old_host)
|
||||
new_str = _format_host(new_host)
|
||||
old_str = format_host(old_host)
|
||||
new_str = format_host(new_host)
|
||||
console.print(
|
||||
f" [yellow]~[/] [cyan]{service}[/]: [magenta]{old_str}[/] → [magenta]{new_str}[/]"
|
||||
)
|
||||
@@ -152,7 +130,7 @@ def _report_sync_changes(
|
||||
if removed:
|
||||
console.print(f"\nServices no longer running ({len(removed)}):")
|
||||
for service in sorted(removed):
|
||||
host_str = _format_host(current_state[service])
|
||||
host_str = format_host(current_state[service])
|
||||
console.print(f" [red]-[/] [cyan]{service}[/] (was on [magenta]{host_str}[/])")
|
||||
|
||||
|
||||
@@ -171,21 +149,21 @@ def _check_ssh_connectivity(cfg: Config) -> list[str]:
|
||||
|
||||
async def check_host(host_name: str) -> tuple[str, bool]:
|
||||
host = cfg.hosts[host_name]
|
||||
result = await run_command(host, "echo ok", host_name, stream=False)
|
||||
return host_name, result.success
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
run_command(host, "echo ok", host_name, stream=False),
|
||||
timeout=5.0,
|
||||
)
|
||||
return host_name, result.success
|
||||
except TimeoutError:
|
||||
return host_name, False
|
||||
|
||||
async def gather_with_progress(progress: Progress, task_id: TaskID) -> list[str]:
|
||||
tasks = [asyncio.create_task(check_host(h)) for h in remote_hosts]
|
||||
unreachable: list[str] = []
|
||||
for coro in asyncio.as_completed(tasks):
|
||||
host_name, success = await coro
|
||||
if not success:
|
||||
unreachable.append(host_name)
|
||||
progress.update(task_id, advance=1, description=f"[cyan]{host_name}[/]")
|
||||
return unreachable
|
||||
|
||||
with progress_bar("Checking SSH connectivity", len(remote_hosts)) as (progress, task_id):
|
||||
return asyncio.run(gather_with_progress(progress, task_id))
|
||||
results = run_parallel_with_progress(
|
||||
"Checking SSH connectivity",
|
||||
remote_hosts,
|
||||
check_host,
|
||||
)
|
||||
return [host for host, success in results if not success]
|
||||
|
||||
|
||||
def _check_service_requirements(
|
||||
@@ -222,27 +200,21 @@ def _check_service_requirements(
|
||||
|
||||
return service, mount_errors, network_errors, device_errors
|
||||
|
||||
async def gather_with_progress(
|
||||
progress: Progress, task_id: TaskID
|
||||
) -> tuple[list[tuple[str, str, str]], list[tuple[str, str, str]], list[tuple[str, str, str]]]:
|
||||
tasks = [asyncio.create_task(check_service(s)) for s in services]
|
||||
all_mount_errors: list[tuple[str, str, str]] = []
|
||||
all_network_errors: list[tuple[str, str, str]] = []
|
||||
all_device_errors: list[tuple[str, str, str]] = []
|
||||
results = run_parallel_with_progress(
|
||||
"Checking requirements",
|
||||
services,
|
||||
check_service,
|
||||
)
|
||||
|
||||
for coro in asyncio.as_completed(tasks):
|
||||
service, mount_errs, net_errs, dev_errs = await coro
|
||||
all_mount_errors.extend(mount_errs)
|
||||
all_network_errors.extend(net_errs)
|
||||
all_device_errors.extend(dev_errs)
|
||||
progress.update(task_id, advance=1, description=f"[cyan]{service}[/]")
|
||||
all_mount_errors: list[tuple[str, str, str]] = []
|
||||
all_network_errors: list[tuple[str, str, str]] = []
|
||||
all_device_errors: list[tuple[str, str, str]] = []
|
||||
for _, mount_errs, net_errs, dev_errs in results:
|
||||
all_mount_errors.extend(mount_errs)
|
||||
all_network_errors.extend(net_errs)
|
||||
all_device_errors.extend(dev_errs)
|
||||
|
||||
return all_mount_errors, all_network_errors, all_device_errors
|
||||
|
||||
with progress_bar(
|
||||
"Checking requirements", len(services), initial_description="[dim]checking...[/]"
|
||||
) as (progress, task_id):
|
||||
return asyncio.run(gather_with_progress(progress, task_id))
|
||||
return all_mount_errors, all_network_errors, all_device_errors
|
||||
|
||||
|
||||
def _report_config_status(cfg: Config) -> bool:
|
||||
@@ -263,7 +235,7 @@ def _report_config_status(cfg: Config) -> bool:
|
||||
console.print(f" [red]-[/] [cyan]{name}[/]")
|
||||
|
||||
if not unmanaged and not missing_from_disk:
|
||||
console.print("[green]✓[/] Config matches disk")
|
||||
print_success("Config matches disk")
|
||||
|
||||
return bool(missing_from_disk)
|
||||
|
||||
@@ -275,11 +247,10 @@ def _report_orphaned_services(cfg: Config) -> bool:
|
||||
if orphaned:
|
||||
console.print("\n[yellow]Orphaned services[/] (in state but not in config):")
|
||||
console.print(
|
||||
"[dim]Run 'cf apply' to stop them, or 'cf down --orphaned' for just orphans.[/]"
|
||||
"[dim]Run [bold]cf apply[/bold] to stop them, or [bold]cf down --orphaned[/bold] for just orphans.[/]"
|
||||
)
|
||||
for name, hosts in sorted(orphaned.items()):
|
||||
host_str = ", ".join(hosts) if isinstance(hosts, list) else hosts
|
||||
console.print(f" [yellow]![/] [cyan]{name}[/] on [magenta]{host_str}[/]")
|
||||
console.print(f" [yellow]![/] [cyan]{name}[/] on [magenta]{format_host(hosts)}[/]")
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -295,54 +266,24 @@ def _report_traefik_status(cfg: Config, services: list[str]) -> None:
|
||||
if warnings:
|
||||
console.print(f"\n[yellow]Traefik issues[/] ({len(warnings)}):")
|
||||
for warning in warnings:
|
||||
console.print(f" [yellow]![/] {warning}")
|
||||
print_warning(warning)
|
||||
else:
|
||||
console.print("[green]✓[/] Traefik labels valid")
|
||||
print_success("Traefik labels valid")
|
||||
|
||||
|
||||
def _report_mount_errors(mount_errors: list[tuple[str, str, str]]) -> None:
|
||||
"""Report mount errors grouped by service."""
|
||||
def _report_requirement_errors(errors: list[tuple[str, str, str]], category: str) -> None:
|
||||
"""Report requirement errors (mounts, networks, devices) grouped by service."""
|
||||
by_service: dict[str, list[tuple[str, str]]] = {}
|
||||
for svc, host, path in mount_errors:
|
||||
by_service.setdefault(svc, []).append((host, path))
|
||||
for svc, host, item in errors:
|
||||
by_service.setdefault(svc, []).append((host, item))
|
||||
|
||||
console.print(f"[red]Missing mounts[/] ({len(mount_errors)}):")
|
||||
console.print(f"[red]Missing {category}[/] ({len(errors)}):")
|
||||
for svc, items in sorted(by_service.items()):
|
||||
host = items[0][0]
|
||||
paths = [p for _, p in items]
|
||||
missing = [i for _, i in items]
|
||||
console.print(f" [cyan]{svc}[/] on [magenta]{host}[/]:")
|
||||
for path in paths:
|
||||
console.print(f" [red]✗[/] {path}")
|
||||
|
||||
|
||||
def _report_network_errors(network_errors: list[tuple[str, str, str]]) -> None:
|
||||
"""Report network errors grouped by service."""
|
||||
by_service: dict[str, list[tuple[str, str]]] = {}
|
||||
for svc, host, net in network_errors:
|
||||
by_service.setdefault(svc, []).append((host, net))
|
||||
|
||||
console.print(f"[red]Missing networks[/] ({len(network_errors)}):")
|
||||
for svc, items in sorted(by_service.items()):
|
||||
host = items[0][0]
|
||||
networks = [n for _, n in items]
|
||||
console.print(f" [cyan]{svc}[/] on [magenta]{host}[/]:")
|
||||
for net in networks:
|
||||
console.print(f" [red]✗[/] {net}")
|
||||
|
||||
|
||||
def _report_device_errors(device_errors: list[tuple[str, str, str]]) -> None:
|
||||
"""Report device errors grouped by service."""
|
||||
by_service: dict[str, list[tuple[str, str]]] = {}
|
||||
for svc, host, dev in device_errors:
|
||||
by_service.setdefault(svc, []).append((host, dev))
|
||||
|
||||
console.print(f"[red]Missing devices[/] ({len(device_errors)}):")
|
||||
for svc, items in sorted(by_service.items()):
|
||||
host = items[0][0]
|
||||
devices = [d for _, d in items]
|
||||
console.print(f" [cyan]{svc}[/] on [magenta]{host}[/]:")
|
||||
for dev in devices:
|
||||
console.print(f" [red]✗[/] {dev}")
|
||||
for item in missing:
|
||||
console.print(f" [red]✗[/] {item}")
|
||||
|
||||
|
||||
def _report_ssh_status(unreachable_hosts: list[str]) -> bool:
|
||||
@@ -350,9 +291,9 @@ def _report_ssh_status(unreachable_hosts: list[str]) -> bool:
|
||||
if unreachable_hosts:
|
||||
console.print(f"[red]Unreachable hosts[/] ({len(unreachable_hosts)}):")
|
||||
for host in sorted(unreachable_hosts):
|
||||
console.print(f" [red]✗[/] [magenta]{host}[/]")
|
||||
print_error(f"[magenta]{host}[/]")
|
||||
return True
|
||||
console.print("[green]✓[/] All hosts reachable")
|
||||
print_success("All hosts reachable")
|
||||
return False
|
||||
|
||||
|
||||
@@ -394,16 +335,16 @@ def _run_remote_checks(cfg: Config, svc_list: list[str], *, show_host_compat: bo
|
||||
mount_errors, network_errors, device_errors = _check_service_requirements(cfg, svc_list)
|
||||
|
||||
if mount_errors:
|
||||
_report_mount_errors(mount_errors)
|
||||
_report_requirement_errors(mount_errors, "mounts")
|
||||
has_errors = True
|
||||
if network_errors:
|
||||
_report_network_errors(network_errors)
|
||||
_report_requirement_errors(network_errors, "networks")
|
||||
has_errors = True
|
||||
if device_errors:
|
||||
_report_device_errors(device_errors)
|
||||
_report_requirement_errors(device_errors, "devices")
|
||||
has_errors = True
|
||||
if not mount_errors and not network_errors and not device_errors:
|
||||
console.print("[green]✓[/] All mounts, networks, and devices exist")
|
||||
print_success("All mounts, networks, and devices exist")
|
||||
|
||||
if show_host_compat:
|
||||
for service in svc_list:
|
||||
@@ -440,7 +381,7 @@ def traefik_file(
|
||||
try:
|
||||
dynamic, warnings = generate_traefik_config(cfg, svc_list)
|
||||
except (FileNotFoundError, ValueError) as exc:
|
||||
err_console.print(f"[red]✗[/] {exc}")
|
||||
print_error(str(exc))
|
||||
raise typer.Exit(1) from exc
|
||||
|
||||
rendered = render_traefik_config(dynamic)
|
||||
@@ -448,12 +389,12 @@ def traefik_file(
|
||||
if output:
|
||||
output.parent.mkdir(parents=True, exist_ok=True)
|
||||
output.write_text(rendered)
|
||||
console.print(f"[green]✓[/] Traefik config written to {output}")
|
||||
print_success(f"Traefik config written to {output}")
|
||||
else:
|
||||
console.print(rendered)
|
||||
|
||||
for warning in warnings:
|
||||
err_console.print(f"[yellow]![/] {warning}")
|
||||
print_warning(warning)
|
||||
|
||||
|
||||
@app.command(rich_help_panel="Configuration")
|
||||
@@ -492,24 +433,24 @@ def refresh(
|
||||
if state_changed:
|
||||
_report_sync_changes(added, removed, changed, discovered, current_state)
|
||||
else:
|
||||
console.print("[green]✓[/] State is already in sync.")
|
||||
print_success("State is already in sync.")
|
||||
|
||||
if dry_run:
|
||||
console.print("\n[dim](dry-run: no changes made)[/]")
|
||||
console.print(f"\n{MSG_DRY_RUN}")
|
||||
return
|
||||
|
||||
# Update state file
|
||||
if state_changed:
|
||||
save_state(cfg, discovered)
|
||||
console.print(f"\n[green]✓[/] State updated: {len(discovered)} services tracked.")
|
||||
print_success(f"State updated: {len(discovered)} services tracked.")
|
||||
|
||||
# Capture image digests for running services
|
||||
if discovered:
|
||||
try:
|
||||
path = _snapshot_services(cfg, list(discovered.keys()), log_path)
|
||||
console.print(f"[green]✓[/] Digests written to {path}")
|
||||
print_success(f"Digests written to {path}")
|
||||
except RuntimeError as exc:
|
||||
err_console.print(f"[yellow]![/] {exc}")
|
||||
print_warning(str(exc))
|
||||
|
||||
|
||||
@app.command(rich_help_panel="Configuration")
|
||||
@@ -533,11 +474,7 @@ def check(
|
||||
# Determine which services to check and whether to show host compatibility
|
||||
if services:
|
||||
svc_list = list(services)
|
||||
invalid = [s for s in svc_list if s not in cfg.services]
|
||||
if invalid:
|
||||
for svc in invalid:
|
||||
err_console.print(f"[red]✗[/] Service '{svc}' not found in config")
|
||||
raise typer.Exit(1)
|
||||
validate_services(cfg, svc_list)
|
||||
show_host_compat = True
|
||||
else:
|
||||
svc_list = list(cfg.services.keys())
|
||||
@@ -587,11 +524,7 @@ def init_network(
|
||||
cfg = load_config_or_exit(config)
|
||||
|
||||
target_hosts = list(hosts) if hosts else list(cfg.hosts.keys())
|
||||
invalid = [h for h in target_hosts if h not in cfg.hosts]
|
||||
if invalid:
|
||||
for h in invalid:
|
||||
err_console.print(f"[red]✗[/] Host '{h}' not found in config")
|
||||
raise typer.Exit(1)
|
||||
validate_hosts(cfg, target_hosts)
|
||||
|
||||
async def create_network_on_host(host_name: str) -> CommandResult:
|
||||
host = cfg.hosts[host_name]
|
||||
@@ -616,9 +549,8 @@ def init_network(
|
||||
if result.success:
|
||||
console.print(f"[cyan]\\[{host_name}][/] [green]✓[/] Created network '{network}'")
|
||||
else:
|
||||
err_console.print(
|
||||
f"[cyan]\\[{host_name}][/] [red]✗[/] Failed to create network: "
|
||||
f"{result.stderr.strip()}"
|
||||
print_error(
|
||||
f"[cyan]\\[{host_name}][/] Failed to create network: {result.stderr.strip()}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -2,12 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from typing import TYPE_CHECKING, Annotated
|
||||
|
||||
import typer
|
||||
from rich.progress import Progress, TaskID # noqa: TC002
|
||||
from rich.table import Table
|
||||
|
||||
from compose_farm.cli.app import app
|
||||
@@ -19,47 +17,18 @@ from compose_farm.cli.common import (
|
||||
ServicesArg,
|
||||
get_services,
|
||||
load_config_or_exit,
|
||||
progress_bar,
|
||||
report_results,
|
||||
run_async,
|
||||
run_parallel_with_progress,
|
||||
)
|
||||
from compose_farm.console import console, err_console
|
||||
from compose_farm.console import console
|
||||
from compose_farm.executor import run_command, run_on_services
|
||||
from compose_farm.state import get_services_needing_migration, load_state
|
||||
from compose_farm.state import get_services_needing_migration, group_services_by_host, load_state
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping
|
||||
|
||||
from compose_farm.config import Config
|
||||
|
||||
|
||||
def _group_services_by_host(
|
||||
services: dict[str, str | list[str]],
|
||||
hosts: Mapping[str, object],
|
||||
all_hosts: list[str] | None = None,
|
||||
) -> dict[str, list[str]]:
|
||||
"""Group services by their assigned host(s).
|
||||
|
||||
For multi-host services (list or "all"), the service appears in multiple host lists.
|
||||
"""
|
||||
by_host: dict[str, list[str]] = {h: [] for h in hosts}
|
||||
for service, host_value in services.items():
|
||||
if isinstance(host_value, list):
|
||||
# Explicit list of hosts
|
||||
for host_name in host_value:
|
||||
if host_name in by_host:
|
||||
by_host[host_name].append(service)
|
||||
elif host_value == "all" and all_hosts:
|
||||
# "all" keyword - add to all hosts
|
||||
for host_name in all_hosts:
|
||||
if host_name in by_host:
|
||||
by_host[host_name].append(service)
|
||||
elif host_value in by_host:
|
||||
# Single host
|
||||
by_host[host_value].append(service)
|
||||
return by_host
|
||||
|
||||
|
||||
def _get_container_counts(cfg: Config) -> dict[str, int]:
|
||||
"""Get container counts from all hosts with a progress bar."""
|
||||
|
||||
@@ -72,18 +41,12 @@ def _get_container_counts(cfg: Config) -> dict[str, int]:
|
||||
count = int(result.stdout.strip())
|
||||
return host_name, count
|
||||
|
||||
async def gather_with_progress(progress: Progress, task_id: TaskID) -> dict[str, int]:
|
||||
hosts = list(cfg.hosts.keys())
|
||||
tasks = [asyncio.create_task(get_count(h)) for h in hosts]
|
||||
results: dict[str, int] = {}
|
||||
for coro in asyncio.as_completed(tasks):
|
||||
host_name, count = await coro
|
||||
results[host_name] = count
|
||||
progress.update(task_id, advance=1, description=f"[cyan]{host_name}[/]")
|
||||
return results
|
||||
|
||||
with progress_bar("Querying hosts", len(cfg.hosts)) as (progress, task_id):
|
||||
return asyncio.run(gather_with_progress(progress, task_id))
|
||||
results = run_parallel_with_progress(
|
||||
"Querying hosts",
|
||||
list(cfg.hosts.keys()),
|
||||
get_count,
|
||||
)
|
||||
return dict(results)
|
||||
|
||||
|
||||
def _build_host_table(
|
||||
@@ -163,24 +126,7 @@ def logs(
|
||||
config: ConfigOption = None,
|
||||
) -> None:
|
||||
"""Show service logs."""
|
||||
if all_services and host is not None:
|
||||
err_console.print("[red]✗[/] Cannot use --all and --host together")
|
||||
raise typer.Exit(1)
|
||||
|
||||
cfg = load_config_or_exit(config)
|
||||
|
||||
# Determine service list based on options
|
||||
if host is not None:
|
||||
if host not in cfg.hosts:
|
||||
err_console.print(f"[red]✗[/] Host '{host}' not found in config")
|
||||
raise typer.Exit(1)
|
||||
# Include services where host is in the list of configured hosts
|
||||
svc_list = [s for s in cfg.services if host in cfg.get_hosts(s)]
|
||||
if not svc_list:
|
||||
err_console.print(f"[yellow]![/] No services configured for host '{host}'")
|
||||
return
|
||||
else:
|
||||
svc_list, cfg = get_services(services or [], all_services, config)
|
||||
svc_list, cfg = get_services(services or [], all_services, config, host=host)
|
||||
|
||||
# Default to fewer lines when showing multiple services
|
||||
many_services = all_services or host is not None or len(svc_list) > 1
|
||||
@@ -194,11 +140,19 @@ def logs(
|
||||
|
||||
@app.command(rich_help_panel="Monitoring")
|
||||
def ps(
|
||||
services: ServicesArg = None,
|
||||
all_services: AllOption = False,
|
||||
host: HostOption = None,
|
||||
config: ConfigOption = None,
|
||||
) -> None:
|
||||
"""Show status of all services."""
|
||||
cfg = load_config_or_exit(config)
|
||||
results = run_async(run_on_services(cfg, list(cfg.services.keys()), "ps"))
|
||||
"""Show status of services.
|
||||
|
||||
Without arguments: shows all services (same as --all).
|
||||
With service names: shows only those services.
|
||||
With --host: shows services on that host.
|
||||
"""
|
||||
svc_list, cfg = get_services(services or [], all_services, config, host=host, default_all=True)
|
||||
results = run_async(run_on_services(cfg, svc_list, "ps"))
|
||||
report_results(results)
|
||||
|
||||
|
||||
@@ -220,8 +174,8 @@ def stats(
|
||||
pending = get_services_needing_migration(cfg)
|
||||
|
||||
all_hosts = list(cfg.hosts.keys())
|
||||
services_by_host = _group_services_by_host(cfg.services, cfg.hosts, all_hosts)
|
||||
running_by_host = _group_services_by_host(state, cfg.hosts, all_hosts)
|
||||
services_by_host = group_services_by_host(cfg.services, cfg.hosts, all_hosts)
|
||||
running_by_host = group_services_by_host(state, cfg.hosts, all_hosts)
|
||||
|
||||
container_counts: dict[str, int] = {}
|
||||
if live:
|
||||
|
||||
282
src/compose_farm/cli/ssh.py
Normal file
282
src/compose_farm/cli/ssh.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""SSH key management commands for compose-farm."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import subprocess
|
||||
from typing import TYPE_CHECKING, Annotated
|
||||
|
||||
import typer
|
||||
|
||||
from compose_farm.cli.app import app
|
||||
from compose_farm.cli.common import ConfigOption, load_config_or_exit, run_parallel_with_progress
|
||||
from compose_farm.console import console, err_console
|
||||
from compose_farm.executor import run_command
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from compose_farm.config import Host
|
||||
|
||||
from compose_farm.ssh_keys import (
|
||||
SSH_KEY_PATH,
|
||||
SSH_PUBKEY_PATH,
|
||||
get_pubkey_content,
|
||||
get_ssh_env,
|
||||
key_exists,
|
||||
)
|
||||
|
||||
_DEFAULT_SSH_PORT = 22
|
||||
_PUBKEY_DISPLAY_THRESHOLD = 60
|
||||
|
||||
ssh_app = typer.Typer(
|
||||
name="ssh",
|
||||
help="Manage SSH keys for passwordless authentication.",
|
||||
no_args_is_help=True,
|
||||
)
|
||||
|
||||
_ForceOption = Annotated[
|
||||
bool,
|
||||
typer.Option("--force", "-f", help="Regenerate key even if it exists."),
|
||||
]
|
||||
|
||||
|
||||
def _generate_key(*, force: bool = False) -> bool:
|
||||
"""Generate an ED25519 SSH key with no passphrase.
|
||||
|
||||
Returns True if key was generated, False if skipped.
|
||||
"""
|
||||
if key_exists() and not force:
|
||||
console.print(f"[yellow]![/] SSH key already exists: {SSH_KEY_PATH}")
|
||||
console.print("[dim]Use --force to regenerate[/]")
|
||||
return False
|
||||
|
||||
# Create .ssh directory if it doesn't exist
|
||||
SSH_KEY_PATH.parent.mkdir(parents=True, exist_ok=True, mode=0o700)
|
||||
|
||||
# Remove existing key if forcing regeneration
|
||||
if force:
|
||||
SSH_KEY_PATH.unlink(missing_ok=True)
|
||||
SSH_PUBKEY_PATH.unlink(missing_ok=True)
|
||||
|
||||
console.print(f"[dim]Generating SSH key at {SSH_KEY_PATH}...[/]")
|
||||
|
||||
try:
|
||||
subprocess.run(
|
||||
[ # noqa: S607
|
||||
"ssh-keygen",
|
||||
"-t",
|
||||
"ed25519",
|
||||
"-N",
|
||||
"", # No passphrase
|
||||
"-f",
|
||||
str(SSH_KEY_PATH),
|
||||
"-C",
|
||||
"compose-farm",
|
||||
],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
err_console.print(f"[red]Failed to generate SSH key:[/] {e.stderr.decode()}")
|
||||
return False
|
||||
except FileNotFoundError:
|
||||
err_console.print("[red]ssh-keygen not found. Is OpenSSH installed?[/]")
|
||||
return False
|
||||
|
||||
# Set correct permissions
|
||||
SSH_KEY_PATH.chmod(0o600)
|
||||
SSH_PUBKEY_PATH.chmod(0o644)
|
||||
|
||||
console.print(f"[green]Generated SSH key:[/] {SSH_KEY_PATH}")
|
||||
return True
|
||||
|
||||
|
||||
def _copy_key_to_host(host_name: str, address: str, user: str, port: int) -> bool:
|
||||
"""Copy public key to a host's authorized_keys.
|
||||
|
||||
Uses ssh-copy-id which handles agent vs password fallback automatically.
|
||||
Returns True on success, False on failure.
|
||||
"""
|
||||
target = f"{user}@{address}"
|
||||
console.print(f"[dim]Copying key to {host_name} ({target})...[/]")
|
||||
|
||||
cmd = ["ssh-copy-id"]
|
||||
|
||||
# Disable strict host key checking (consistent with executor.py)
|
||||
cmd.extend(["-o", "StrictHostKeyChecking=no"])
|
||||
cmd.extend(["-o", "UserKnownHostsFile=/dev/null"])
|
||||
|
||||
if port != _DEFAULT_SSH_PORT:
|
||||
cmd.extend(["-p", str(port)])
|
||||
|
||||
cmd.extend(["-i", str(SSH_PUBKEY_PATH), target])
|
||||
|
||||
try:
|
||||
# Don't capture output so user can see password prompt
|
||||
result = subprocess.run(cmd, check=False, env=get_ssh_env())
|
||||
if result.returncode == 0:
|
||||
console.print(f"[green]Key copied to {host_name}[/]")
|
||||
return True
|
||||
err_console.print(f"[red]Failed to copy key to {host_name}[/]")
|
||||
return False
|
||||
except FileNotFoundError:
|
||||
err_console.print("[red]ssh-copy-id not found. Is OpenSSH installed?[/]")
|
||||
return False
|
||||
|
||||
|
||||
@ssh_app.command("keygen")
|
||||
def ssh_keygen(
|
||||
force: _ForceOption = False,
|
||||
) -> None:
|
||||
"""Generate SSH key (does not distribute to hosts).
|
||||
|
||||
Creates an ED25519 key at ~/.ssh/compose-farm/id_ed25519 with no passphrase.
|
||||
Use 'cf ssh setup' to also distribute the key to all configured hosts.
|
||||
"""
|
||||
success = _generate_key(force=force)
|
||||
if not success and not key_exists():
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
@ssh_app.command("setup")
|
||||
def ssh_setup(
|
||||
config: ConfigOption = None,
|
||||
force: _ForceOption = False,
|
||||
) -> None:
|
||||
"""Generate SSH key and distribute to all configured hosts.
|
||||
|
||||
Creates an ED25519 key at ~/.ssh/compose-farm/id_ed25519 (no passphrase)
|
||||
and copies the public key to authorized_keys on each host.
|
||||
|
||||
For each host, tries SSH agent first. If agent is unavailable,
|
||||
prompts for password.
|
||||
"""
|
||||
cfg = load_config_or_exit(config)
|
||||
|
||||
# Skip localhost hosts
|
||||
remote_hosts = {
|
||||
name: host
|
||||
for name, host in cfg.hosts.items()
|
||||
if host.address.lower() not in ("localhost", "127.0.0.1")
|
||||
}
|
||||
|
||||
if not remote_hosts:
|
||||
console.print("[yellow]No remote hosts configured.[/]")
|
||||
raise typer.Exit(0)
|
||||
|
||||
# Generate key if needed
|
||||
if not key_exists() or force:
|
||||
if not _generate_key(force=force):
|
||||
raise typer.Exit(1)
|
||||
else:
|
||||
console.print(f"[dim]Using existing key: {SSH_KEY_PATH}[/]")
|
||||
|
||||
console.print()
|
||||
console.print(f"[bold]Distributing key to {len(remote_hosts)} host(s)...[/]")
|
||||
console.print()
|
||||
|
||||
# Copy key to each host
|
||||
succeeded = 0
|
||||
failed = 0
|
||||
|
||||
for host_name, host in remote_hosts.items():
|
||||
if _copy_key_to_host(host_name, host.address, host.user, host.port):
|
||||
succeeded += 1
|
||||
else:
|
||||
failed += 1
|
||||
|
||||
console.print()
|
||||
if failed == 0:
|
||||
console.print(
|
||||
f"[green]Setup complete.[/] {succeeded}/{len(remote_hosts)} hosts configured."
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
f"[yellow]Setup partially complete.[/] {succeeded}/{len(remote_hosts)} hosts configured, "
|
||||
f"[red]{failed} failed[/]."
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
@ssh_app.command("status")
|
||||
def ssh_status(
|
||||
config: ConfigOption = None,
|
||||
) -> None:
|
||||
"""Show SSH key status and host connectivity."""
|
||||
from rich.table import Table # noqa: PLC0415
|
||||
|
||||
cfg = load_config_or_exit(config)
|
||||
|
||||
# Key status
|
||||
console.print("[bold]SSH Key Status[/]")
|
||||
console.print()
|
||||
|
||||
if key_exists():
|
||||
console.print(f" [green]Key exists:[/] {SSH_KEY_PATH}")
|
||||
pubkey = get_pubkey_content()
|
||||
if pubkey:
|
||||
# Show truncated public key
|
||||
if len(pubkey) > _PUBKEY_DISPLAY_THRESHOLD:
|
||||
console.print(f" [dim]Public key:[/] {pubkey[:30]}...{pubkey[-20:]}")
|
||||
else:
|
||||
console.print(f" [dim]Public key:[/] {pubkey}")
|
||||
else:
|
||||
console.print(f" [yellow]No key found:[/] {SSH_KEY_PATH}")
|
||||
console.print(" [dim]Run 'cf ssh setup' to generate and distribute a key[/]")
|
||||
|
||||
console.print()
|
||||
console.print("[bold]Host Connectivity[/]")
|
||||
console.print()
|
||||
|
||||
# Skip localhost hosts
|
||||
remote_hosts = {
|
||||
name: host
|
||||
for name, host in cfg.hosts.items()
|
||||
if host.address.lower() not in ("localhost", "127.0.0.1")
|
||||
}
|
||||
|
||||
if not remote_hosts:
|
||||
console.print(" [dim]No remote hosts configured[/]")
|
||||
return
|
||||
|
||||
async def check_host(item: tuple[str, Host]) -> tuple[str, str, str]:
|
||||
"""Check connectivity to a single host."""
|
||||
host_name, host = item
|
||||
target = f"{host.user}@{host.address}"
|
||||
if host.port != _DEFAULT_SSH_PORT:
|
||||
target += f":{host.port}"
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
run_command(host, "echo ok", host_name, stream=False),
|
||||
timeout=5.0,
|
||||
)
|
||||
status = "[green]OK[/]" if result.success else "[red]Auth failed[/]"
|
||||
except TimeoutError:
|
||||
status = "[red]Timeout (5s)[/]"
|
||||
except Exception as e:
|
||||
status = f"[red]Error: {e}[/]"
|
||||
|
||||
return host_name, target, status
|
||||
|
||||
# Check connectivity in parallel with progress bar
|
||||
results = run_parallel_with_progress(
|
||||
"Checking hosts",
|
||||
list(remote_hosts.items()),
|
||||
check_host,
|
||||
)
|
||||
|
||||
# Build table from results
|
||||
table = Table(show_header=True, header_style="bold")
|
||||
table.add_column("Host")
|
||||
table.add_column("Address")
|
||||
table.add_column("Status")
|
||||
|
||||
# Sort by host name for consistent order
|
||||
for host_name, target, status in sorted(results, key=lambda r: r[0]):
|
||||
table.add_row(host_name, target, status)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
# Register ssh subcommand on the shared app
|
||||
app.add_typer(ssh_app, name="ssh", rich_help_panel="Configuration")
|
||||
@@ -7,14 +7,14 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
import stat
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import yaml
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from .config import Config
|
||||
|
||||
# Port parsing constants
|
||||
@@ -141,23 +141,42 @@ def _resolve_host_path(host_path: str, compose_dir: Path) -> str | None:
|
||||
return None # Named volume
|
||||
|
||||
|
||||
def _is_socket(path: str) -> bool:
|
||||
"""Check if a path is a socket (e.g., SSH agent socket)."""
|
||||
try:
|
||||
return stat.S_ISSOCK(Path(path).stat().st_mode)
|
||||
except (FileNotFoundError, PermissionError, OSError):
|
||||
return False
|
||||
|
||||
|
||||
def _parse_volume_item(
|
||||
item: str | dict[str, Any],
|
||||
env: dict[str, str],
|
||||
compose_dir: Path,
|
||||
) -> str | None:
|
||||
"""Parse a single volume item and return host path if it's a bind mount."""
|
||||
"""Parse a single volume item and return host path if it's a bind mount.
|
||||
|
||||
Skips socket paths (e.g., SSH_AUTH_SOCK) since they're machine-local
|
||||
and shouldn't be validated on remote hosts.
|
||||
"""
|
||||
host_path: str | None = None
|
||||
|
||||
if isinstance(item, str):
|
||||
interpolated = _interpolate(item, env)
|
||||
parts = interpolated.split(":")
|
||||
if len(parts) >= _MIN_VOLUME_PARTS:
|
||||
return _resolve_host_path(parts[0], compose_dir)
|
||||
host_path = _resolve_host_path(parts[0], compose_dir)
|
||||
elif isinstance(item, dict) and item.get("type") == "bind":
|
||||
source = item.get("source")
|
||||
if source:
|
||||
interpolated = _interpolate(str(source), env)
|
||||
return _resolve_host_path(interpolated, compose_dir)
|
||||
return None
|
||||
host_path = _resolve_host_path(interpolated, compose_dir)
|
||||
|
||||
# Skip sockets - they're machine-local (e.g., SSH agent)
|
||||
if host_path and _is_socket(host_path):
|
||||
return None
|
||||
|
||||
return host_path
|
||||
|
||||
|
||||
def parse_host_volumes(config: Config, service: str) -> list[str]:
|
||||
|
||||
@@ -4,3 +4,35 @@ from rich.console import Console
|
||||
|
||||
console = Console(highlight=False)
|
||||
err_console = Console(stderr=True, highlight=False)
|
||||
|
||||
|
||||
# --- Message Constants ---
|
||||
# Standardized message templates for consistent user-facing output
|
||||
|
||||
MSG_SERVICE_NOT_FOUND = "Service [cyan]{name}[/] not found in config"
|
||||
MSG_HOST_NOT_FOUND = "Host [magenta]{name}[/] not found in config"
|
||||
MSG_CONFIG_NOT_FOUND = "Config file not found"
|
||||
MSG_DRY_RUN = "[dim](dry-run: no changes made)[/]"
|
||||
|
||||
|
||||
# --- Message Helper Functions ---
|
||||
|
||||
|
||||
def print_error(msg: str) -> None:
|
||||
"""Print error message with ✗ prefix to stderr."""
|
||||
err_console.print(f"[red]✗[/] {msg}")
|
||||
|
||||
|
||||
def print_success(msg: str) -> None:
|
||||
"""Print success message with ✓ prefix to stdout."""
|
||||
console.print(f"[green]✓[/] {msg}")
|
||||
|
||||
|
||||
def print_warning(msg: str) -> None:
|
||||
"""Print warning message with ! prefix to stderr."""
|
||||
err_console.print(f"[yellow]![/] {msg}")
|
||||
|
||||
|
||||
def print_hint(msg: str) -> None:
|
||||
"""Print hint message in dim style to stdout."""
|
||||
console.print(f"[dim]Hint: {msg}[/]")
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, Any
|
||||
from rich.markup import escape
|
||||
|
||||
from .console import console, err_console
|
||||
from .ssh_keys import get_key_path, get_ssh_auth_sock, get_ssh_env
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
@@ -22,6 +23,43 @@ LOCAL_ADDRESSES = frozenset({"local", "localhost", "127.0.0.1", "::1"})
|
||||
_DEFAULT_SSH_PORT = 22
|
||||
|
||||
|
||||
def build_ssh_command(host: Host, command: str, *, tty: bool = False) -> list[str]:
|
||||
"""Build SSH command args for executing a command on a remote host.
|
||||
|
||||
Args:
|
||||
host: Host configuration with address, port, user
|
||||
command: Command to run on the remote host
|
||||
tty: Whether to allocate a TTY (for interactive/progress bar commands)
|
||||
|
||||
Returns:
|
||||
List of command args suitable for subprocess
|
||||
|
||||
"""
|
||||
ssh_args = [
|
||||
"ssh",
|
||||
"-o",
|
||||
"StrictHostKeyChecking=no",
|
||||
"-o",
|
||||
"UserKnownHostsFile=/dev/null",
|
||||
"-o",
|
||||
"LogLevel=ERROR",
|
||||
]
|
||||
if tty:
|
||||
ssh_args.insert(1, "-tt") # Force TTY allocation
|
||||
|
||||
key_path = get_key_path()
|
||||
if key_path:
|
||||
ssh_args.extend(["-i", str(key_path)])
|
||||
|
||||
if host.port != _DEFAULT_SSH_PORT:
|
||||
ssh_args.extend(["-p", str(host.port)])
|
||||
|
||||
ssh_args.append(f"{host.user}@{host.address}")
|
||||
ssh_args.append(command)
|
||||
|
||||
return ssh_args
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_local_ips() -> frozenset[str]:
|
||||
"""Get all IP addresses of the current machine."""
|
||||
@@ -71,6 +109,25 @@ def is_local(host: Host) -> bool:
|
||||
return addr in _get_local_ips()
|
||||
|
||||
|
||||
def ssh_connect_kwargs(host: Host) -> dict[str, Any]:
|
||||
"""Get kwargs for asyncssh.connect() from a Host config."""
|
||||
kwargs: dict[str, Any] = {
|
||||
"host": host.address,
|
||||
"port": host.port,
|
||||
"username": host.user,
|
||||
"known_hosts": None,
|
||||
}
|
||||
# Add SSH agent path (auto-detect forwarded agent if needed)
|
||||
agent_path = get_ssh_auth_sock()
|
||||
if agent_path:
|
||||
kwargs["agent_path"] = agent_path
|
||||
# Add key file fallback for when SSH agent is unavailable
|
||||
key_path = get_key_path()
|
||||
if key_path:
|
||||
kwargs["client_keys"] = [str(key_path)]
|
||||
return kwargs
|
||||
|
||||
|
||||
async def _run_local_command(
|
||||
command: str,
|
||||
service: str,
|
||||
@@ -152,12 +209,10 @@ async def _run_ssh_command(
|
||||
"""Run a command on a remote host via SSH with streaming output."""
|
||||
if raw:
|
||||
# Use native ssh with TTY for proper progress bar rendering
|
||||
ssh_args = ["ssh", "-t"]
|
||||
if host.port != _DEFAULT_SSH_PORT:
|
||||
ssh_args.extend(["-p", str(host.port)])
|
||||
ssh_args.extend([f"{host.user}@{host.address}", command])
|
||||
ssh_args = build_ssh_command(host, command, tty=True)
|
||||
# Run in thread to avoid blocking the event loop
|
||||
result = await asyncio.to_thread(subprocess.run, ssh_args, check=False)
|
||||
# Use get_ssh_env() to auto-detect SSH agent socket
|
||||
result = await asyncio.to_thread(subprocess.run, ssh_args, check=False, env=get_ssh_env())
|
||||
return CommandResult(
|
||||
service=service,
|
||||
exit_code=result.returncode,
|
||||
@@ -168,12 +223,7 @@ async def _run_ssh_command(
|
||||
|
||||
proc: asyncssh.SSHClientProcess[Any]
|
||||
try:
|
||||
async with asyncssh.connect( # noqa: SIM117 - conn needed before create_process
|
||||
host.address,
|
||||
port=host.port,
|
||||
username=host.user,
|
||||
known_hosts=None,
|
||||
) as conn:
|
||||
async with asyncssh.connect(**ssh_connect_kwargs(host)) as conn: # noqa: SIM117
|
||||
async with conn.create_process(command) as proc:
|
||||
if stream:
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import asyncio
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
|
||||
from .compose import parse_devices, parse_external_networks, parse_host_volumes
|
||||
from .console import console, err_console
|
||||
from .console import console, err_console, print_error, print_success, print_warning
|
||||
from .executor import (
|
||||
CommandResult,
|
||||
check_networks_exist,
|
||||
@@ -145,9 +145,7 @@ async def _cleanup_and_rollback(
|
||||
raw: bool = False,
|
||||
) -> None:
|
||||
"""Clean up failed start and attempt rollback to old host if it was running."""
|
||||
err_console.print(
|
||||
f"{prefix} [yellow]![/] Cleaning up failed start on [magenta]{target_host}[/]"
|
||||
)
|
||||
print_warning(f"{prefix} Cleaning up failed start on [magenta]{target_host}[/]")
|
||||
await run_compose(cfg, service, "down", raw=raw)
|
||||
|
||||
if not was_running:
|
||||
@@ -156,12 +154,12 @@ async def _cleanup_and_rollback(
|
||||
)
|
||||
return
|
||||
|
||||
err_console.print(f"{prefix} [yellow]![/] Rolling back to [magenta]{current_host}[/]...")
|
||||
print_warning(f"{prefix} Rolling back to [magenta]{current_host}[/]...")
|
||||
rollback_result = await run_compose_on_host(cfg, service, current_host, "up -d", raw=raw)
|
||||
if rollback_result.success:
|
||||
console.print(f"{prefix} [green]✓[/] Rollback succeeded on [magenta]{current_host}[/]")
|
||||
print_success(f"{prefix} Rollback succeeded on [magenta]{current_host}[/]")
|
||||
else:
|
||||
err_console.print(f"{prefix} [red]✗[/] Rollback failed - service is down")
|
||||
print_error(f"{prefix} Rollback failed - service is down")
|
||||
|
||||
|
||||
def _report_preflight_failures(
|
||||
@@ -170,17 +168,15 @@ def _report_preflight_failures(
|
||||
preflight: PreflightResult,
|
||||
) -> None:
|
||||
"""Report pre-flight check failures."""
|
||||
err_console.print(
|
||||
f"[cyan]\\[{service}][/] [red]✗[/] Cannot start on [magenta]{target_host}[/]:"
|
||||
)
|
||||
print_error(f"[cyan]\\[{service}][/] Cannot start on [magenta]{target_host}[/]:")
|
||||
for path in preflight.missing_paths:
|
||||
err_console.print(f" [red]✗[/] missing path: {path}")
|
||||
print_error(f" missing path: {path}")
|
||||
for net in preflight.missing_networks:
|
||||
err_console.print(f" [red]✗[/] missing network: {net}")
|
||||
print_error(f" missing network: {net}")
|
||||
if preflight.missing_networks:
|
||||
err_console.print(f" [dim]hint: cf init-network {target_host}[/]")
|
||||
err_console.print(f" [dim]Hint: cf init-network {target_host}[/]")
|
||||
for dev in preflight.missing_devices:
|
||||
err_console.print(f" [red]✗[/] missing device: {dev}")
|
||||
print_error(f" missing device: {dev}")
|
||||
|
||||
|
||||
async def _up_multi_host_service(
|
||||
@@ -252,8 +248,8 @@ async def _migrate_service(
|
||||
for cmd, label in [("pull --ignore-buildable", "Pull"), ("build", "Build")]:
|
||||
result = await _run_compose_step(cfg, service, cmd, raw=raw)
|
||||
if not result.success:
|
||||
err_console.print(
|
||||
f"{prefix} [red]✗[/] {label} failed on [magenta]{target_host}[/], "
|
||||
print_error(
|
||||
f"{prefix} {label} failed on [magenta]{target_host}[/], "
|
||||
"leaving service on current host"
|
||||
)
|
||||
return result
|
||||
@@ -293,9 +289,8 @@ async def _up_single_service(
|
||||
return failure
|
||||
did_migration = True
|
||||
else:
|
||||
err_console.print(
|
||||
f"{prefix} [yellow]![/] was on "
|
||||
f"[magenta]{current_host}[/] (not in config), skipping down"
|
||||
print_warning(
|
||||
f"{prefix} was on [magenta]{current_host}[/] (not in config), skipping down"
|
||||
)
|
||||
|
||||
# Start on target host
|
||||
@@ -391,9 +386,7 @@ async def stop_orphaned_services(cfg: Config) -> list[CommandResult]:
|
||||
for host in host_list:
|
||||
# Skip hosts no longer in config
|
||||
if host not in cfg.hosts:
|
||||
console.print(
|
||||
f" [yellow]![/] {service}@{host}: host no longer in config, skipping"
|
||||
)
|
||||
print_warning(f"{service}@{host}: host no longer in config, skipping")
|
||||
results.append(
|
||||
CommandResult(
|
||||
service=f"{service}@{host}",
|
||||
@@ -413,11 +406,11 @@ async def stop_orphaned_services(cfg: Config) -> list[CommandResult]:
|
||||
result = await task
|
||||
results.append(result)
|
||||
if result.success:
|
||||
console.print(f" [green]✓[/] {service}@{host}: stopped")
|
||||
print_success(f"{service}@{host}: stopped")
|
||||
else:
|
||||
console.print(f" [red]✗[/] {service}@{host}: {result.stderr or 'failed'}")
|
||||
print_error(f"{service}@{host}: {result.stderr or 'failed'}")
|
||||
except Exception as e:
|
||||
console.print(f" [red]✗[/] {service}@{host}: {e}")
|
||||
print_error(f"{service}@{host}: {e}")
|
||||
results.append(
|
||||
CommandResult(
|
||||
service=f"{service}@{host}",
|
||||
|
||||
67
src/compose_farm/ssh_keys.py
Normal file
67
src/compose_farm/ssh_keys.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""SSH key utilities for compose-farm."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Default key paths for compose-farm SSH key
|
||||
# Keys are stored in a subdirectory for cleaner docker volume mounting
|
||||
SSH_KEY_DIR = Path.home() / ".ssh" / "compose-farm"
|
||||
SSH_KEY_PATH = SSH_KEY_DIR / "id_ed25519"
|
||||
SSH_PUBKEY_PATH = SSH_KEY_PATH.with_suffix(".pub")
|
||||
|
||||
|
||||
def get_ssh_auth_sock() -> str | None:
|
||||
"""Get SSH_AUTH_SOCK, auto-detecting forwarded agent if needed.
|
||||
|
||||
Checks in order:
|
||||
1. SSH_AUTH_SOCK environment variable (if socket exists)
|
||||
2. Forwarded agent sockets in ~/.ssh/agent/ (most recent first)
|
||||
|
||||
Returns the socket path or None if no valid socket found.
|
||||
"""
|
||||
sock = os.environ.get("SSH_AUTH_SOCK")
|
||||
if sock and Path(sock).is_socket():
|
||||
return sock
|
||||
|
||||
# Try to find a forwarded SSH agent socket
|
||||
agent_dir = Path.home() / ".ssh" / "agent"
|
||||
if agent_dir.is_dir():
|
||||
sockets = sorted(
|
||||
agent_dir.glob("s.*.sshd.*"), key=lambda p: p.stat().st_mtime, reverse=True
|
||||
)
|
||||
for s in sockets:
|
||||
if s.is_socket():
|
||||
return str(s)
|
||||
return None
|
||||
|
||||
|
||||
def get_ssh_env() -> dict[str, str]:
|
||||
"""Get environment dict for SSH subprocess with auto-detected agent.
|
||||
|
||||
Returns a copy of the current environment with SSH_AUTH_SOCK set
|
||||
to the auto-detected agent socket (if found).
|
||||
"""
|
||||
env = os.environ.copy()
|
||||
sock = get_ssh_auth_sock()
|
||||
if sock:
|
||||
env["SSH_AUTH_SOCK"] = sock
|
||||
return env
|
||||
|
||||
|
||||
def key_exists() -> bool:
|
||||
"""Check if the compose-farm SSH key pair exists."""
|
||||
return SSH_KEY_PATH.exists() and SSH_PUBKEY_PATH.exists()
|
||||
|
||||
|
||||
def get_key_path() -> Path | None:
|
||||
"""Get the SSH key path if it exists, None otherwise."""
|
||||
return SSH_KEY_PATH if key_exists() else None
|
||||
|
||||
|
||||
def get_pubkey_content() -> str | None:
|
||||
"""Get the public key content if it exists, None otherwise."""
|
||||
if not SSH_PUBKEY_PATH.exists():
|
||||
return None
|
||||
return SSH_PUBKEY_PATH.read_text().strip()
|
||||
@@ -8,11 +8,44 @@ from typing import TYPE_CHECKING, Any
|
||||
import yaml
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Mapping
|
||||
|
||||
from .config import Config
|
||||
|
||||
|
||||
def group_services_by_host(
|
||||
services: dict[str, str | list[str]],
|
||||
hosts: Mapping[str, object],
|
||||
all_hosts: list[str] | None = None,
|
||||
) -> dict[str, list[str]]:
|
||||
"""Group services by their assigned host(s).
|
||||
|
||||
For multi-host services (list or "all"), the service appears in multiple host lists.
|
||||
"""
|
||||
by_host: dict[str, list[str]] = {h: [] for h in hosts}
|
||||
for service, host_value in services.items():
|
||||
if isinstance(host_value, list):
|
||||
for host_name in host_value:
|
||||
if host_name in by_host:
|
||||
by_host[host_name].append(service)
|
||||
elif host_value == "all" and all_hosts:
|
||||
for host_name in all_hosts:
|
||||
if host_name in by_host:
|
||||
by_host[host_name].append(service)
|
||||
elif host_value in by_host:
|
||||
by_host[host_value].append(service)
|
||||
return by_host
|
||||
|
||||
|
||||
def group_running_services_by_host(
|
||||
state: dict[str, str | list[str]],
|
||||
hosts: Mapping[str, object],
|
||||
) -> dict[str, list[str]]:
|
||||
"""Group running services by host, filtering out hosts with no services."""
|
||||
by_host = group_services_by_host(state, hosts)
|
||||
return {h: svcs for h, svcs in by_host.items() if svcs}
|
||||
|
||||
|
||||
def load_state(config: Config) -> dict[str, str | list[str]]:
|
||||
"""Load the current deployment state.
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from pydantic import ValidationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from compose_farm.config import Config
|
||||
@@ -30,3 +31,10 @@ def get_config() -> Config:
|
||||
def get_templates() -> Jinja2Templates:
|
||||
"""Get Jinja2 templates instance."""
|
||||
return Jinja2Templates(directory=str(TEMPLATES_DIR))
|
||||
|
||||
|
||||
def extract_config_error(exc: Exception) -> str:
|
||||
"""Extract a user-friendly error message from a config exception."""
|
||||
if isinstance(exc, ValidationError):
|
||||
return "; ".join(err.get("msg", str(err)) for err in exc.errors())
|
||||
return str(exc)
|
||||
|
||||
@@ -2,19 +2,20 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Annotated, Any
|
||||
import shlex
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any
|
||||
|
||||
import asyncssh
|
||||
import yaml
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException
|
||||
from fastapi import APIRouter, Body, HTTPException, Query
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from compose_farm.executor import run_compose_on_host
|
||||
from compose_farm.executor import is_local, run_compose_on_host, ssh_connect_kwargs
|
||||
from compose_farm.paths import find_config_path
|
||||
from compose_farm.state import load_state
|
||||
from compose_farm.web.deps import get_config, get_templates
|
||||
@@ -30,6 +31,51 @@ def _validate_yaml(content: str) -> None:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid YAML: {e}") from e
|
||||
|
||||
|
||||
def _backup_file(file_path: Path) -> Path | None:
|
||||
"""Create a timestamped backup of a file if it exists and content differs.
|
||||
|
||||
Backups are stored in a .backups directory alongside the file.
|
||||
Returns the backup path if created, None if no backup was needed.
|
||||
"""
|
||||
if not file_path.exists():
|
||||
return None
|
||||
|
||||
# Create backup directory
|
||||
backup_dir = file_path.parent / ".backups"
|
||||
backup_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Generate timestamped backup filename
|
||||
timestamp = datetime.now(tz=UTC).strftime("%Y%m%d_%H%M%S")
|
||||
backup_name = f"{file_path.name}.{timestamp}"
|
||||
backup_path = backup_dir / backup_name
|
||||
|
||||
# Copy current content to backup
|
||||
backup_path.write_text(file_path.read_text())
|
||||
|
||||
# Clean up old backups (keep last 200)
|
||||
backups = sorted(backup_dir.glob(f"{file_path.name}.*"), reverse=True)
|
||||
for old_backup in backups[200:]:
|
||||
old_backup.unlink()
|
||||
|
||||
return backup_path
|
||||
|
||||
|
||||
def _save_with_backup(file_path: Path, content: str) -> bool:
|
||||
"""Save content to file, creating a backup first if content changed.
|
||||
|
||||
Returns True if file was saved, False if content was unchanged.
|
||||
"""
|
||||
# Check if content actually changed
|
||||
if file_path.exists():
|
||||
current_content = file_path.read_text()
|
||||
if current_content == content:
|
||||
return False # No change, skip save
|
||||
_backup_file(file_path)
|
||||
|
||||
file_path.write_text(content)
|
||||
return True
|
||||
|
||||
|
||||
def _get_service_compose_path(name: str) -> Path:
|
||||
"""Get compose path for service, raising HTTPException if not found."""
|
||||
config = get_config()
|
||||
@@ -183,8 +229,9 @@ async def save_compose(
|
||||
"""Save compose file content."""
|
||||
compose_path = _get_service_compose_path(name)
|
||||
_validate_yaml(content)
|
||||
compose_path.write_text(content)
|
||||
return {"success": True, "message": "Compose file saved"}
|
||||
saved = _save_with_backup(compose_path, content)
|
||||
msg = "Compose file saved" if saved else "No changes to save"
|
||||
return {"success": True, "message": msg}
|
||||
|
||||
|
||||
@router.put("/service/{name}/env")
|
||||
@@ -193,8 +240,9 @@ async def save_env(
|
||||
) -> dict[str, Any]:
|
||||
"""Save .env file content."""
|
||||
env_path = _get_service_compose_path(name).parent / ".env"
|
||||
env_path.write_text(content)
|
||||
return {"success": True, "message": ".env file saved"}
|
||||
saved = _save_with_backup(env_path, content)
|
||||
msg = ".env file saved" if saved else "No changes to save"
|
||||
return {"success": True, "message": msg}
|
||||
|
||||
|
||||
@router.put("/config")
|
||||
@@ -207,6 +255,106 @@ async def save_config(
|
||||
raise HTTPException(status_code=404, detail="Config file not found")
|
||||
|
||||
_validate_yaml(content)
|
||||
config_path.write_text(content)
|
||||
saved = _save_with_backup(config_path, content)
|
||||
msg = "Config saved" if saved else "No changes to save"
|
||||
return {"success": True, "message": msg}
|
||||
|
||||
return {"success": True, "message": "Config saved"}
|
||||
|
||||
async def _read_file_local(path: str) -> str:
|
||||
"""Read a file from the local filesystem."""
|
||||
expanded = Path(path).expanduser()
|
||||
return await asyncio.to_thread(expanded.read_text, encoding="utf-8")
|
||||
|
||||
|
||||
async def _write_file_local(path: str, content: str) -> bool:
|
||||
"""Write content to a file on the local filesystem with backup.
|
||||
|
||||
Returns True if file was saved, False if content was unchanged.
|
||||
"""
|
||||
expanded = Path(path).expanduser()
|
||||
return await asyncio.to_thread(_save_with_backup, expanded, content)
|
||||
|
||||
|
||||
async def _read_file_remote(host: Any, path: str) -> str:
|
||||
"""Read a file from a remote host via SSH."""
|
||||
# Expand ~ on remote by using shell
|
||||
cmd = f"cat {shlex.quote(path)}"
|
||||
if path.startswith("~/"):
|
||||
cmd = f"cat ~/{shlex.quote(path[2:])}"
|
||||
|
||||
async with asyncssh.connect(**ssh_connect_kwargs(host)) as conn:
|
||||
result = await conn.run(cmd, check=True)
|
||||
stdout = result.stdout or ""
|
||||
return stdout.decode() if isinstance(stdout, bytes) else stdout
|
||||
|
||||
|
||||
async def _write_file_remote(host: Any, path: str, content: str) -> None:
|
||||
"""Write content to a file on a remote host via SSH."""
|
||||
# Expand ~ on remote by using shell
|
||||
target_path = f"~/{path[2:]}" if path.startswith("~/") else path
|
||||
cmd = f"cat > {shlex.quote(target_path)}"
|
||||
|
||||
async with asyncssh.connect(**ssh_connect_kwargs(host)) as conn:
|
||||
result = await conn.run(cmd, input=content, check=True)
|
||||
if result.returncode != 0:
|
||||
stderr = result.stderr.decode() if isinstance(result.stderr, bytes) else result.stderr
|
||||
msg = f"Failed to write file: {stderr}"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
|
||||
def _get_console_host(host: str, path: str) -> Any:
|
||||
"""Validate and return host config for console file operations."""
|
||||
config = get_config()
|
||||
host_config = config.hosts.get(host)
|
||||
|
||||
if not host_config:
|
||||
raise HTTPException(status_code=404, detail=f"Host '{host}' not found")
|
||||
if not path:
|
||||
raise HTTPException(status_code=400, detail="Path is required")
|
||||
|
||||
return host_config
|
||||
|
||||
|
||||
@router.get("/console/file")
|
||||
async def read_console_file(
|
||||
host: Annotated[str, Query(description="Host name")],
|
||||
path: Annotated[str, Query(description="File path")],
|
||||
) -> dict[str, Any]:
|
||||
"""Read a file from a host for the console editor."""
|
||||
host_config = _get_console_host(host, path)
|
||||
|
||||
try:
|
||||
if is_local(host_config):
|
||||
content = await _read_file_local(path)
|
||||
else:
|
||||
content = await _read_file_remote(host_config, path)
|
||||
return {"success": True, "content": content}
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"File not found: {path}") from None
|
||||
except PermissionError:
|
||||
raise HTTPException(status_code=403, detail=f"Permission denied: {path}") from None
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.put("/console/file")
|
||||
async def write_console_file(
|
||||
host: Annotated[str, Query(description="Host name")],
|
||||
path: Annotated[str, Query(description="File path")],
|
||||
content: Annotated[str, Body(media_type="text/plain")],
|
||||
) -> dict[str, Any]:
|
||||
"""Write a file to a host from the console editor."""
|
||||
host_config = _get_console_host(host, path)
|
||||
|
||||
try:
|
||||
if is_local(host_config):
|
||||
saved = await _write_file_local(path, content)
|
||||
msg = f"Saved: {path}" if saved else "No changes to save"
|
||||
else:
|
||||
await _write_file_remote(host_config, path, content)
|
||||
msg = f"Saved: {path}" # Remote doesn't track changes
|
||||
return {"success": True, "message": msg}
|
||||
except PermissionError:
|
||||
raise HTTPException(status_code=403, detail=f"Permission denied: {path}") from None
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
@@ -7,19 +7,57 @@ from fastapi import APIRouter, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from pydantic import ValidationError
|
||||
|
||||
from compose_farm.executor import is_local
|
||||
from compose_farm.paths import find_config_path
|
||||
from compose_farm.state import (
|
||||
get_orphaned_services,
|
||||
get_service_host,
|
||||
get_services_needing_migration,
|
||||
get_services_not_in_state,
|
||||
group_running_services_by_host,
|
||||
load_state,
|
||||
)
|
||||
from compose_farm.web.deps import get_config, get_templates
|
||||
from compose_farm.web.deps import (
|
||||
extract_config_error,
|
||||
get_config,
|
||||
get_templates,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/console", response_class=HTMLResponse)
|
||||
async def console(request: Request) -> HTMLResponse:
|
||||
"""Console page with terminal and editor."""
|
||||
config = get_config()
|
||||
templates = get_templates()
|
||||
|
||||
# Find local host and sort it first
|
||||
local_host = None
|
||||
for name, host in config.hosts.items():
|
||||
if is_local(host):
|
||||
local_host = name
|
||||
break
|
||||
|
||||
# Sort hosts with local first
|
||||
hosts = sorted(config.hosts.keys())
|
||||
if local_host:
|
||||
hosts = [local_host] + [h for h in hosts if h != local_host]
|
||||
|
||||
# Get config path for default editor file
|
||||
config_path = str(config.config_path) if config.config_path else ""
|
||||
|
||||
return templates.TemplateResponse(
|
||||
"console.html",
|
||||
{
|
||||
"request": request,
|
||||
"hosts": hosts,
|
||||
"local_host": local_host,
|
||||
"config_path": config_path,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/", response_class=HTMLResponse)
|
||||
async def index(request: Request) -> HTMLResponse:
|
||||
"""Dashboard page - combined view of all cluster info."""
|
||||
@@ -30,11 +68,7 @@ async def index(request: Request) -> HTMLResponse:
|
||||
try:
|
||||
config = get_config()
|
||||
except (ValidationError, FileNotFoundError) as e:
|
||||
# Extract error message
|
||||
if isinstance(e, ValidationError):
|
||||
config_error = "; ".join(err.get("msg", str(err)) for err in e.errors())
|
||||
else:
|
||||
config_error = str(e)
|
||||
config_error = extract_config_error(e)
|
||||
|
||||
# Read raw config content for the editor
|
||||
config_path = find_config_path()
|
||||
@@ -70,14 +104,8 @@ async def index(request: Request) -> HTMLResponse:
|
||||
migrations = get_services_needing_migration(config)
|
||||
not_started = get_services_not_in_state(config)
|
||||
|
||||
# Group services by host
|
||||
services_by_host: dict[str, list[str]] = {}
|
||||
for svc, host in deployed.items():
|
||||
if isinstance(host, list):
|
||||
for h in host:
|
||||
services_by_host.setdefault(h, []).append(svc)
|
||||
else:
|
||||
services_by_host.setdefault(host, []).append(svc)
|
||||
# Group services by host (filter out hosts with no running services)
|
||||
services_by_host = group_running_services_by_host(deployed, config.hosts)
|
||||
|
||||
# Config file content
|
||||
config_content = ""
|
||||
@@ -186,10 +214,7 @@ async def config_error_partial(request: Request) -> HTMLResponse:
|
||||
get_config()
|
||||
return HTMLResponse("") # No error
|
||||
except (ValidationError, FileNotFoundError) as e:
|
||||
if isinstance(e, ValidationError):
|
||||
error = "; ".join(err.get("msg", str(err)) for err in e.errors())
|
||||
else:
|
||||
error = str(e)
|
||||
error = extract_config_error(e)
|
||||
return templates.TemplateResponse(
|
||||
"partials/config_error.html", {"request": request, "config_error": error}
|
||||
)
|
||||
@@ -246,15 +271,7 @@ async def services_by_host_partial(request: Request, expanded: bool = True) -> H
|
||||
templates = get_templates()
|
||||
|
||||
deployed = load_state(config)
|
||||
|
||||
# Group services by host
|
||||
services_by_host: dict[str, list[str]] = {}
|
||||
for svc, host in deployed.items():
|
||||
if isinstance(host, list):
|
||||
for h in host:
|
||||
services_by_host.setdefault(h, []).append(svc)
|
||||
else:
|
||||
services_by_host.setdefault(host, []).append(svc)
|
||||
services_by_host = group_running_services_by_host(deployed, config.hosts)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
"partials/services_by_host.html",
|
||||
|
||||
@@ -17,6 +17,35 @@ const editors = {};
|
||||
let monacoLoaded = false;
|
||||
let monacoLoading = false;
|
||||
|
||||
// Language detection from file path
|
||||
const LANGUAGE_MAP = {
|
||||
'yaml': 'yaml', 'yml': 'yaml',
|
||||
'json': 'json',
|
||||
'js': 'javascript', 'mjs': 'javascript',
|
||||
'ts': 'typescript', 'tsx': 'typescript',
|
||||
'py': 'python',
|
||||
'sh': 'shell', 'bash': 'shell',
|
||||
'md': 'markdown',
|
||||
'html': 'html', 'htm': 'html',
|
||||
'css': 'css',
|
||||
'sql': 'sql',
|
||||
'toml': 'toml',
|
||||
'ini': 'ini', 'conf': 'ini',
|
||||
'dockerfile': 'dockerfile',
|
||||
'env': 'plaintext'
|
||||
};
|
||||
|
||||
/**
|
||||
* Get Monaco language from file path
|
||||
* @param {string} path - File path
|
||||
* @returns {string} Monaco language identifier
|
||||
*/
|
||||
function getLanguageFromPath(path) {
|
||||
const ext = path.split('.').pop().toLowerCase();
|
||||
return LANGUAGE_MAP[ext] || 'plaintext';
|
||||
}
|
||||
window.getLanguageFromPath = getLanguageFromPath;
|
||||
|
||||
// Terminal color theme (dark mode matching PicoCSS)
|
||||
const TERMINAL_THEME = {
|
||||
background: '#1a1a2e',
|
||||
@@ -87,6 +116,7 @@ function createWebSocket(path) {
|
||||
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||
return new WebSocket(`${protocol}//${window.location.host}${path}`);
|
||||
}
|
||||
window.createWebSocket = createWebSocket;
|
||||
|
||||
/**
|
||||
* Initialize a terminal and connect to WebSocket for streaming
|
||||
@@ -223,10 +253,16 @@ function loadMonaco(callback) {
|
||||
* @param {HTMLElement} container - Container element
|
||||
* @param {string} content - Initial content
|
||||
* @param {string} language - Editor language (yaml, plaintext, etc.)
|
||||
* @param {boolean} readonly - Whether editor is read-only
|
||||
* @param {object} opts - Options: { readonly, onSave }
|
||||
* @returns {object} Monaco editor instance
|
||||
*/
|
||||
function createEditor(container, content, language, readonly = false) {
|
||||
function createEditor(container, content, language, opts = {}) {
|
||||
// Support legacy boolean readonly parameter
|
||||
if (typeof opts === 'boolean') {
|
||||
opts = { readonly: opts };
|
||||
}
|
||||
const { readonly = false, onSave = null } = opts;
|
||||
|
||||
const options = {
|
||||
value: content,
|
||||
language: language,
|
||||
@@ -249,12 +285,17 @@ function createEditor(container, content, language, readonly = false) {
|
||||
// Add Command+S / Ctrl+S handler for editable editors
|
||||
if (!readonly) {
|
||||
editor.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.KeyS, function() {
|
||||
saveAllEditors();
|
||||
if (onSave) {
|
||||
onSave(editor);
|
||||
} else {
|
||||
saveAllEditors();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return editor;
|
||||
}
|
||||
window.createEditor = createEditor;
|
||||
|
||||
/**
|
||||
* Initialize all Monaco editors on the page
|
||||
@@ -370,6 +411,16 @@ function initPage() {
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
initPage();
|
||||
initKeyboardShortcuts();
|
||||
|
||||
// Handle ?action= parameter (from command palette navigation)
|
||||
const params = new URLSearchParams(window.location.search);
|
||||
const action = params.get('action');
|
||||
if (action && window.location.pathname === '/') {
|
||||
// Clear the URL parameter
|
||||
history.replaceState({}, '', '/');
|
||||
// Trigger the action
|
||||
htmx.ajax('POST', `/api/${action}`, {swap: 'none'});
|
||||
}
|
||||
});
|
||||
|
||||
// Re-initialize after HTMX swaps main content
|
||||
@@ -441,41 +492,59 @@ document.body.addEventListener('htmx:afterRequest', function(evt) {
|
||||
const fab = document.getElementById('cmd-fab');
|
||||
if (!dialog || !input || !list) return;
|
||||
|
||||
const colors = { service: '#22c55e', action: '#eab308', nav: '#3b82f6' };
|
||||
// Load icons from template (rendered server-side from icons.html)
|
||||
const iconTemplate = document.getElementById('cmd-icons');
|
||||
const icons = {};
|
||||
if (iconTemplate) {
|
||||
iconTemplate.content.querySelectorAll('[data-icon]').forEach(el => {
|
||||
icons[el.dataset.icon] = el.innerHTML;
|
||||
});
|
||||
}
|
||||
|
||||
const colors = { service: '#22c55e', action: '#eab308', nav: '#3b82f6', app: '#a855f7' };
|
||||
let commands = [];
|
||||
let filtered = [];
|
||||
let selected = 0;
|
||||
|
||||
const post = (url) => () => htmx.ajax('POST', url, {swap: 'none'});
|
||||
const nav = (url) => () => window.location.href = url;
|
||||
const cmd = (type, name, desc, action) => ({ type, name, desc, action });
|
||||
// Navigate to dashboard and trigger action (or just POST if already on dashboard)
|
||||
const dashboardAction = (endpoint) => () => {
|
||||
if (window.location.pathname === '/') {
|
||||
htmx.ajax('POST', `/api/${endpoint}`, {swap: 'none'});
|
||||
} else {
|
||||
window.location.href = `/?action=${endpoint}`;
|
||||
}
|
||||
};
|
||||
const cmd = (type, name, desc, action, icon = null) => ({ type, name, desc, action, icon });
|
||||
|
||||
function buildCommands() {
|
||||
const actions = [
|
||||
cmd('action', 'Apply', 'Make reality match config', post('/api/apply')),
|
||||
cmd('action', 'Refresh', 'Update state from reality', post('/api/refresh')),
|
||||
cmd('nav', 'Dashboard', 'Go to dashboard', nav('/')),
|
||||
cmd('action', 'Apply', 'Make reality match config', dashboardAction('apply'), icons.check),
|
||||
cmd('action', 'Refresh', 'Update state from reality', dashboardAction('refresh'), icons.refresh_cw),
|
||||
cmd('app', 'Dashboard', 'Go to dashboard', nav('/'), icons.home),
|
||||
cmd('app', 'Console', 'Go to console', nav('/console'), icons.terminal),
|
||||
];
|
||||
|
||||
// Add service-specific actions if on a service page
|
||||
const match = window.location.pathname.match(/^\/service\/(.+)$/);
|
||||
if (match) {
|
||||
const svc = decodeURIComponent(match[1]);
|
||||
const svcCmd = (name, desc, endpoint) => cmd('service', name, `${desc} ${svc}`, post(`/api/service/${svc}/${endpoint}`));
|
||||
const svcCmd = (name, desc, endpoint, icon) => cmd('service', name, `${desc} ${svc}`, post(`/api/service/${svc}/${endpoint}`), icon);
|
||||
actions.unshift(
|
||||
svcCmd('Up', 'Start', 'up'),
|
||||
svcCmd('Down', 'Stop', 'down'),
|
||||
svcCmd('Restart', 'Restart', 'restart'),
|
||||
svcCmd('Pull', 'Pull', 'pull'),
|
||||
svcCmd('Update', 'Pull + restart', 'update'),
|
||||
svcCmd('Logs', 'View logs for', 'logs'),
|
||||
svcCmd('Up', 'Start', 'up', icons.play),
|
||||
svcCmd('Down', 'Stop', 'down', icons.square),
|
||||
svcCmd('Restart', 'Restart', 'restart', icons.rotate_cw),
|
||||
svcCmd('Pull', 'Pull', 'pull', icons.cloud_download),
|
||||
svcCmd('Update', 'Pull + restart', 'update', icons.refresh_cw),
|
||||
svcCmd('Logs', 'View logs for', 'logs', icons.file_text),
|
||||
);
|
||||
}
|
||||
|
||||
// Add nav commands for all services from sidebar
|
||||
const services = [...document.querySelectorAll('#sidebar-services li[data-svc] a[href]')].map(a => {
|
||||
const name = a.getAttribute('href').replace('/service/', '');
|
||||
return cmd('nav', name, 'Go to service', nav(`/service/${name}`));
|
||||
return cmd('nav', name, 'Go to service', nav(`/service/${name}`), icons.box);
|
||||
});
|
||||
|
||||
commands = [...actions, ...services];
|
||||
@@ -490,10 +559,13 @@ document.body.addEventListener('htmx:afterRequest', function(evt) {
|
||||
function render() {
|
||||
list.innerHTML = filtered.map((c, i) => `
|
||||
<a class="flex justify-between items-center px-3 py-2 rounded-r cursor-pointer hover:bg-base-200 border-l-4 ${i === selected ? 'bg-base-300' : ''}" style="border-left-color: ${colors[c.type] || '#666'}" data-idx="${i}">
|
||||
<span><span class="opacity-50 text-xs mr-2">${c.type}</span>${c.name}</span>
|
||||
<span class="flex items-center gap-2">${c.icon || ''}<span>${c.name}</span></span>
|
||||
<span class="opacity-40 text-xs">${c.desc}</span>
|
||||
</a>
|
||||
`).join('') || '<div class="opacity-50 p-2">No matches</div>';
|
||||
// Scroll selected item into view
|
||||
const sel = list.querySelector(`[data-idx="${selected}"]`);
|
||||
if (sel) sel.scrollIntoView({ block: 'nearest' });
|
||||
}
|
||||
|
||||
function open() {
|
||||
|
||||
@@ -4,12 +4,17 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from compose_farm.executor import build_ssh_command
|
||||
from compose_farm.ssh_keys import get_ssh_auth_sock
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from compose_farm.config import Config
|
||||
|
||||
# Environment variable to identify the web service (for self-update detection)
|
||||
CF_WEB_SERVICE = os.environ.get("CF_WEB_SERVICE", "")
|
||||
|
||||
# ANSI escape codes for terminal output
|
||||
RED = "\x1b[31m"
|
||||
GREEN = "\x1b[32m"
|
||||
@@ -17,25 +22,6 @@ DIM = "\x1b[2m"
|
||||
RESET = "\x1b[0m"
|
||||
CRLF = "\r\n"
|
||||
|
||||
|
||||
def _get_ssh_auth_sock() -> str | None:
|
||||
"""Get SSH_AUTH_SOCK, auto-detecting forwarded agent if needed."""
|
||||
sock = os.environ.get("SSH_AUTH_SOCK")
|
||||
if sock and Path(sock).is_socket():
|
||||
return sock
|
||||
|
||||
# Try to find a forwarded SSH agent socket
|
||||
agent_dir = Path.home() / ".ssh" / "agent"
|
||||
if agent_dir.is_dir():
|
||||
sockets = sorted(
|
||||
agent_dir.glob("s.*.sshd.*"), key=lambda p: p.stat().st_mtime, reverse=True
|
||||
)
|
||||
for s in sockets:
|
||||
if s.is_socket():
|
||||
return str(s)
|
||||
return None
|
||||
|
||||
|
||||
# In-memory task registry
|
||||
tasks: dict[str, dict[str, Any]] = {}
|
||||
|
||||
@@ -69,7 +55,7 @@ async def run_cli_streaming(
|
||||
env = {"FORCE_COLOR": "1", "TERM": "xterm-256color", "COLUMNS": "120"}
|
||||
|
||||
# Ensure SSH agent is available (auto-detect if needed)
|
||||
ssh_sock = _get_ssh_auth_sock()
|
||||
ssh_sock = get_ssh_auth_sock()
|
||||
if ssh_sock:
|
||||
env["SSH_AUTH_SOCK"] = ssh_sock
|
||||
|
||||
@@ -97,6 +83,76 @@ async def run_cli_streaming(
|
||||
tasks[task_id]["status"] = "failed"
|
||||
|
||||
|
||||
def _is_self_update(service: str, command: str) -> bool:
|
||||
"""Check if this is a self-update (updating the web service itself).
|
||||
|
||||
Self-updates need special handling because running 'down' on the container
|
||||
we're running in would kill the process before 'up' can execute.
|
||||
"""
|
||||
if not CF_WEB_SERVICE or service != CF_WEB_SERVICE:
|
||||
return False
|
||||
# Commands that involve 'down' need SSH: update, restart, down
|
||||
return command in ("update", "restart", "down")
|
||||
|
||||
|
||||
async def _run_cli_via_ssh(
|
||||
config: Config,
|
||||
args: list[str],
|
||||
task_id: str,
|
||||
) -> None:
|
||||
"""Run a cf CLI command via SSH to the host.
|
||||
|
||||
Used for self-updates to ensure the command survives container restart.
|
||||
"""
|
||||
try:
|
||||
# Get the host for the web service
|
||||
host = config.get_host(CF_WEB_SERVICE)
|
||||
|
||||
# Build the remote command
|
||||
remote_cmd = f"cf {' '.join(args)} --config={config.config_path}"
|
||||
|
||||
# Show what we're doing
|
||||
await stream_to_task(
|
||||
task_id,
|
||||
f"{DIM}$ ssh {host.user}@{host.address} {remote_cmd}{RESET}{CRLF}",
|
||||
)
|
||||
await stream_to_task(
|
||||
task_id,
|
||||
f"{GREEN}Running via SSH (self-update protection){RESET}{CRLF}",
|
||||
)
|
||||
|
||||
# Build SSH command using shared helper
|
||||
ssh_args = build_ssh_command(host, remote_cmd)
|
||||
|
||||
# Set up environment with SSH agent
|
||||
env = {**os.environ, "FORCE_COLOR": "1", "TERM": "xterm-256color"}
|
||||
ssh_sock = get_ssh_auth_sock()
|
||||
if ssh_sock:
|
||||
env["SSH_AUTH_SOCK"] = ssh_sock
|
||||
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*ssh_args,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
env=env,
|
||||
)
|
||||
|
||||
# Stream output
|
||||
if process.stdout:
|
||||
async for line in process.stdout:
|
||||
text = line.decode("utf-8", errors="replace")
|
||||
if text.endswith("\n") and not text.endswith("\r\n"):
|
||||
text = text[:-1] + "\r\n"
|
||||
await stream_to_task(task_id, text)
|
||||
|
||||
exit_code = await process.wait()
|
||||
tasks[task_id]["status"] = "completed" if exit_code == 0 else "failed"
|
||||
|
||||
except Exception as e:
|
||||
await stream_to_task(task_id, f"{RED}Error: {e}{RESET}{CRLF}")
|
||||
tasks[task_id]["status"] = "failed"
|
||||
|
||||
|
||||
async def run_compose_streaming(
|
||||
config: Config,
|
||||
service: str,
|
||||
@@ -111,4 +167,9 @@ async def run_compose_streaming(
|
||||
|
||||
# Build CLI args
|
||||
cli_args = [cli_cmd, service, *extra_args]
|
||||
await run_cli_streaming(config, cli_args, task_id)
|
||||
|
||||
# Use SSH for self-updates to survive container restart
|
||||
if _is_self_update(service, cli_cmd):
|
||||
await _run_cli_via_ssh(config, cli_args, task_id)
|
||||
else:
|
||||
await run_cli_streaming(config, cli_args, task_id)
|
||||
|
||||
241
src/compose_farm/web/templates/console.html
Normal file
241
src/compose_farm/web/templates/console.html
Normal file
@@ -0,0 +1,241 @@
|
||||
{% extends "base.html" %}
|
||||
{% from "partials/components.html" import page_header %}
|
||||
{% from "partials/icons.html" import terminal, save %}
|
||||
{% block title %}Console - Compose Farm{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
<div class="max-w-6xl">
|
||||
{{ page_header("Console", "Terminal and editor access") }}
|
||||
|
||||
<!-- Host Selector -->
|
||||
<div class="flex items-center gap-4 mb-4">
|
||||
<label class="font-semibold">Host:</label>
|
||||
<select id="console-host-select" class="select select-sm select-bordered">
|
||||
{% for name in hosts %}
|
||||
<option value="{{ name }}">{{ name }}{% if name == local_host %} (local){% endif %}</option>
|
||||
{% endfor %}
|
||||
</select>
|
||||
<button id="console-connect-btn" class="btn btn-sm btn-primary" onclick="connectConsole()">Connect</button>
|
||||
<span id="console-status" class="text-sm opacity-60"></span>
|
||||
</div>
|
||||
|
||||
<!-- Terminal -->
|
||||
<div class="mb-6">
|
||||
<div class="flex items-center gap-2 mb-2">
|
||||
<h3 class="font-semibold flex items-center gap-2">{{ terminal() }} Terminal</h3>
|
||||
<span class="text-xs opacity-50">Full shell access to selected host</span>
|
||||
</div>
|
||||
<div id="console-terminal" class="w-full bg-base-300 rounded-lg overflow-hidden resize-y" style="height: 384px; min-height: 200px;"></div>
|
||||
</div>
|
||||
|
||||
<!-- Editor -->
|
||||
<div class="mb-6">
|
||||
<div class="flex items-center justify-between mb-2">
|
||||
<div class="flex items-center gap-4">
|
||||
<h3 class="font-semibold">Editor</h3>
|
||||
<input type="text" id="console-file-path" class="input input-sm input-bordered w-96" placeholder="Enter file path (e.g., ~/docker-compose.yaml)" value="{{ config_path }}">
|
||||
<button class="btn btn-sm btn-outline" onclick="loadFile()">Open</button>
|
||||
</div>
|
||||
<div class="flex items-center gap-2">
|
||||
<span id="editor-status" class="text-sm opacity-60"></span>
|
||||
<button id="console-save-btn" class="btn btn-sm btn-primary" onclick="saveFile()">{{ save() }} Save</button>
|
||||
</div>
|
||||
</div>
|
||||
<div id="console-editor" class="resize-y overflow-hidden rounded-lg" style="height: 512px; min-height: 200px;"></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
// Use var to allow re-declaration on HTMX navigation
|
||||
var consoleTerminal = null;
|
||||
var consoleWs = null;
|
||||
var consoleEditor = null;
|
||||
var currentFilePath = null;
|
||||
var currentHost = null;
|
||||
|
||||
function connectConsole() {
|
||||
const hostSelect = document.getElementById('console-host-select');
|
||||
const host = hostSelect.value;
|
||||
const statusEl = document.getElementById('console-status');
|
||||
const terminalEl = document.getElementById('console-terminal');
|
||||
|
||||
if (!host) {
|
||||
statusEl.textContent = 'Please select a host';
|
||||
return;
|
||||
}
|
||||
|
||||
currentHost = host;
|
||||
|
||||
// Clean up existing connection
|
||||
if (consoleWs) {
|
||||
consoleWs.close();
|
||||
consoleWs = null;
|
||||
}
|
||||
if (consoleTerminal) {
|
||||
consoleTerminal.dispose();
|
||||
consoleTerminal = null;
|
||||
}
|
||||
|
||||
statusEl.textContent = 'Connecting...';
|
||||
|
||||
// Create WebSocket
|
||||
consoleWs = createWebSocket(`/ws/shell/${host}`);
|
||||
|
||||
// Resize callback - createTerminal's ResizeObserver calls this on container resize
|
||||
const sendSize = (cols, rows) => {
|
||||
if (consoleWs && consoleWs.readyState === WebSocket.OPEN) {
|
||||
consoleWs.send(JSON.stringify({ type: 'resize', cols, rows }));
|
||||
}
|
||||
};
|
||||
|
||||
// Create terminal with resize callback
|
||||
const { term } = createTerminal(terminalEl, { cursorBlink: true }, sendSize);
|
||||
consoleTerminal = term;
|
||||
|
||||
consoleWs.onopen = () => {
|
||||
statusEl.textContent = `Connected to ${host}`;
|
||||
sendSize(term.cols, term.rows);
|
||||
term.focus();
|
||||
// Auto-load the default file once editor is ready
|
||||
const pathInput = document.getElementById('console-file-path');
|
||||
if (pathInput && pathInput.value) {
|
||||
const tryLoad = () => {
|
||||
if (consoleEditor) {
|
||||
loadFile();
|
||||
} else {
|
||||
setTimeout(tryLoad, 100);
|
||||
}
|
||||
};
|
||||
tryLoad();
|
||||
}
|
||||
};
|
||||
|
||||
consoleWs.onmessage = (event) => term.write(event.data);
|
||||
|
||||
consoleWs.onclose = () => {
|
||||
statusEl.textContent = 'Disconnected';
|
||||
term.write(`${ANSI.CRLF}${ANSI.DIM}[Connection closed]${ANSI.RESET}${ANSI.CRLF}`);
|
||||
};
|
||||
|
||||
consoleWs.onerror = (error) => {
|
||||
statusEl.textContent = 'Connection error';
|
||||
term.write(`${ANSI.RED}[WebSocket Error]${ANSI.RESET}${ANSI.CRLF}`);
|
||||
console.error('Console WebSocket error:', error);
|
||||
};
|
||||
|
||||
// Send input to WebSocket
|
||||
term.onData((data) => {
|
||||
if (consoleWs && consoleWs.readyState === WebSocket.OPEN) {
|
||||
consoleWs.send(data);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
function initConsoleEditor() {
|
||||
const editorEl = document.getElementById('console-editor');
|
||||
if (!editorEl || consoleEditor) return;
|
||||
|
||||
loadMonaco(() => {
|
||||
consoleEditor = createEditor(editorEl, '', 'plaintext', { onSave: saveFile });
|
||||
});
|
||||
}
|
||||
|
||||
async function loadFile() {
|
||||
const pathInput = document.getElementById('console-file-path');
|
||||
const path = pathInput.value.trim();
|
||||
const statusEl = document.getElementById('editor-status');
|
||||
|
||||
if (!path) {
|
||||
statusEl.textContent = 'Enter a file path';
|
||||
return;
|
||||
}
|
||||
|
||||
if (!currentHost) {
|
||||
statusEl.textContent = 'Connect to a host first';
|
||||
return;
|
||||
}
|
||||
|
||||
statusEl.textContent = `Loading ${path}...`;
|
||||
|
||||
try {
|
||||
const response = await fetch(`/api/console/file?host=${encodeURIComponent(currentHost)}&path=${encodeURIComponent(path)}`);
|
||||
const data = await response.json();
|
||||
|
||||
if (!response.ok || !data.success) {
|
||||
statusEl.textContent = data.detail || 'Failed to load file';
|
||||
return;
|
||||
}
|
||||
|
||||
const language = getLanguageFromPath(path);
|
||||
|
||||
if (consoleEditor) {
|
||||
consoleEditor.setValue(data.content);
|
||||
monaco.editor.setModelLanguage(consoleEditor.getModel(), language);
|
||||
currentFilePath = path; // Only set after content is loaded
|
||||
statusEl.textContent = `Loaded: ${path}`;
|
||||
} else {
|
||||
statusEl.textContent = 'Editor not ready';
|
||||
}
|
||||
} catch (e) {
|
||||
statusEl.textContent = `Error: ${e.message}`;
|
||||
}
|
||||
}
|
||||
|
||||
async function saveFile() {
|
||||
const statusEl = document.getElementById('editor-status');
|
||||
|
||||
if (!currentFilePath) {
|
||||
statusEl.textContent = 'No file loaded';
|
||||
return;
|
||||
}
|
||||
|
||||
if (!currentHost) {
|
||||
statusEl.textContent = 'Not connected to a host';
|
||||
return;
|
||||
}
|
||||
|
||||
if (!consoleEditor) {
|
||||
statusEl.textContent = 'Editor not ready';
|
||||
return;
|
||||
}
|
||||
|
||||
statusEl.textContent = `Saving ${currentFilePath}...`;
|
||||
|
||||
try {
|
||||
const content = consoleEditor.getValue();
|
||||
const response = await fetch(`/api/console/file?host=${encodeURIComponent(currentHost)}&path=${encodeURIComponent(currentFilePath)}`, {
|
||||
method: 'PUT',
|
||||
headers: { 'Content-Type': 'text/plain' },
|
||||
body: content
|
||||
});
|
||||
const data = await response.json();
|
||||
|
||||
if (!response.ok || !data.success) {
|
||||
statusEl.textContent = data.detail || 'Failed to save file';
|
||||
return;
|
||||
}
|
||||
|
||||
statusEl.textContent = `Saved: ${currentFilePath}`;
|
||||
} catch (e) {
|
||||
statusEl.textContent = `Error: ${e.message}`;
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize editor and auto-connect to first host
|
||||
function init() {
|
||||
initConsoleEditor();
|
||||
const hostSelect = document.getElementById('console-host-select');
|
||||
if (hostSelect && hostSelect.options.length > 0) {
|
||||
connectConsole();
|
||||
}
|
||||
}
|
||||
|
||||
// On HTMX navigation, dependencies (app.js) are already loaded.
|
||||
// On hard refresh, this script runs before app.js, so wait for DOMContentLoaded.
|
||||
if (typeof createTerminal === 'function') {
|
||||
init();
|
||||
} else {
|
||||
document.addEventListener('DOMContentLoaded', init);
|
||||
}
|
||||
</script>
|
||||
{% endblock content %}
|
||||
@@ -1,4 +1,18 @@
|
||||
{% from "partials/icons.html" import search, command %}
|
||||
{% from "partials/icons.html" import search, play, square, rotate_cw, cloud_download, refresh_cw, file_text, check, home, terminal, box %}
|
||||
|
||||
<!-- Icons for command palette (referenced by JS) -->
|
||||
<template id="cmd-icons">
|
||||
<span data-icon="play">{{ play() }}</span>
|
||||
<span data-icon="square">{{ square() }}</span>
|
||||
<span data-icon="rotate_cw">{{ rotate_cw() }}</span>
|
||||
<span data-icon="cloud_download">{{ cloud_download() }}</span>
|
||||
<span data-icon="refresh_cw">{{ refresh_cw() }}</span>
|
||||
<span data-icon="file_text">{{ file_text() }}</span>
|
||||
<span data-icon="check">{{ check() }}</span>
|
||||
<span data-icon="home">{{ home() }}</span>
|
||||
<span data-icon="terminal">{{ terminal() }}</span>
|
||||
<span data-icon="box">{{ box() }}</span>
|
||||
</template>
|
||||
<dialog id="cmd-palette" class="modal">
|
||||
<div class="modal-box max-w-lg p-0">
|
||||
<label class="input input-lg bg-base-100 border-0 border-b border-base-300 w-full rounded-none rounded-t-box sticky top-0 z-10 focus-within:outline-none">
|
||||
@@ -15,5 +29,7 @@
|
||||
|
||||
<!-- Floating button to open command palette -->
|
||||
<button id="cmd-fab" class="btn btn-circle glass shadow-lg fixed bottom-6 right-6 z-50 hover:ring hover:ring-base-content/50" title="Command Palette (⌘K)">
|
||||
{{ command(24) }}
|
||||
<span class="flex items-center gap-0.5 text-sm font-semibold">
|
||||
<span class="opacity-70">⌘</span><span>K</span>
|
||||
</span>
|
||||
</button>
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
{% from "partials/icons.html" import home, search %}
|
||||
<!-- Dashboard Link -->
|
||||
{% from "partials/icons.html" import home, search, terminal %}
|
||||
<!-- Navigation Links -->
|
||||
<div class="mb-4">
|
||||
<ul class="menu" hx-boost="true" hx-target="#main-content" hx-select="#main-content" hx-swap="outerHTML">
|
||||
<li><a href="/" class="font-semibold">{{ home() }} Dashboard</a></li>
|
||||
<li><a href="/console" class="font-semibold">{{ terminal() }} Console</a></li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any
|
||||
import asyncssh
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
|
||||
from compose_farm.executor import is_local
|
||||
from compose_farm.executor import is_local, ssh_connect_kwargs
|
||||
from compose_farm.web.deps import get_config
|
||||
from compose_farm.web.streaming import CRLF, DIM, GREEN, RED, RESET, tasks
|
||||
|
||||
@@ -121,6 +121,14 @@ async def _bridge_websocket_to_ssh(
|
||||
proc.terminate()
|
||||
|
||||
|
||||
def _make_controlling_tty(slave_fd: int) -> None:
|
||||
"""Set up the slave PTY as the controlling terminal for the child process."""
|
||||
# Create a new session
|
||||
os.setsid()
|
||||
# Make the slave fd the controlling terminal
|
||||
fcntl.ioctl(slave_fd, termios.TIOCSCTTY, 0)
|
||||
|
||||
|
||||
async def _run_local_exec(websocket: WebSocket, exec_cmd: str) -> None:
|
||||
"""Run docker exec locally with PTY."""
|
||||
master_fd, slave_fd = pty.openpty()
|
||||
@@ -131,6 +139,8 @@ async def _run_local_exec(websocket: WebSocket, exec_cmd: str) -> None:
|
||||
stdout=slave_fd,
|
||||
stderr=slave_fd,
|
||||
close_fds=True,
|
||||
preexec_fn=lambda: _make_controlling_tty(slave_fd),
|
||||
start_new_session=False, # We handle setsid in preexec_fn
|
||||
)
|
||||
os.close(slave_fd)
|
||||
|
||||
@@ -141,13 +151,14 @@ async def _run_local_exec(websocket: WebSocket, exec_cmd: str) -> None:
|
||||
await _bridge_websocket_to_fd(websocket, master_fd, proc)
|
||||
|
||||
|
||||
async def _run_remote_exec(websocket: WebSocket, host: Host, exec_cmd: str) -> None:
|
||||
async def _run_remote_exec(
|
||||
websocket: WebSocket, host: Host, exec_cmd: str, *, agent_forwarding: bool = False
|
||||
) -> None:
|
||||
"""Run docker exec on remote host via SSH with PTY."""
|
||||
# ssh_connect_kwargs includes agent_path and client_keys fallback
|
||||
async with asyncssh.connect(
|
||||
host.address,
|
||||
port=host.port,
|
||||
username=host.user,
|
||||
known_hosts=None,
|
||||
**ssh_connect_kwargs(host),
|
||||
agent_forwarding=agent_forwarding,
|
||||
) as conn:
|
||||
proc: asyncssh.SSHClientProcess[Any] = await conn.create_process(
|
||||
exec_cmd,
|
||||
@@ -202,6 +213,48 @@ async def exec_websocket(
|
||||
await websocket.close()
|
||||
|
||||
|
||||
async def _run_shell_session(
|
||||
websocket: WebSocket,
|
||||
host_name: str,
|
||||
) -> None:
|
||||
"""Run an interactive shell session on a host over WebSocket."""
|
||||
config = get_config()
|
||||
host = config.hosts.get(host_name)
|
||||
if not host:
|
||||
await websocket.send_text(f"{RED}Host '{host_name}' not found{RESET}{CRLF}")
|
||||
return
|
||||
|
||||
# Start interactive shell in home directory (avoid login shell to prevent job control warnings)
|
||||
shell_cmd = "cd ~ && exec bash -i 2>/dev/null || exec sh -i"
|
||||
|
||||
if is_local(host):
|
||||
await _run_local_exec(websocket, shell_cmd)
|
||||
else:
|
||||
await _run_remote_exec(websocket, host, shell_cmd, agent_forwarding=True)
|
||||
|
||||
|
||||
@router.websocket("/ws/shell/{host}")
|
||||
async def shell_websocket(
|
||||
websocket: WebSocket,
|
||||
host: str,
|
||||
) -> None:
|
||||
"""WebSocket endpoint for interactive host shell access."""
|
||||
await websocket.accept()
|
||||
|
||||
try:
|
||||
await websocket.send_text(f"{DIM}[Connecting to {host}...]{RESET}{CRLF}")
|
||||
await _run_shell_session(websocket, host)
|
||||
await websocket.send_text(f"{CRLF}{DIM}[Disconnected]{RESET}{CRLF}")
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception as e:
|
||||
with contextlib.suppress(Exception):
|
||||
await websocket.send_text(f"{RED}Error: {e}{RESET}{CRLF}")
|
||||
finally:
|
||||
with contextlib.suppress(Exception):
|
||||
await websocket.close()
|
||||
|
||||
|
||||
@router.websocket("/ws/terminal/{task_id}")
|
||||
async def terminal_websocket(websocket: WebSocket, task_id: str) -> None:
|
||||
"""WebSocket endpoint for terminal streaming."""
|
||||
|
||||
@@ -150,7 +150,7 @@ class TestLogsHostFilter:
|
||||
mock_run_async, _ = _mock_run_async_factory(["svc1", "svc2"])
|
||||
|
||||
with (
|
||||
patch("compose_farm.cli.monitoring.load_config_or_exit", return_value=cfg),
|
||||
patch("compose_farm.cli.common.load_config_or_exit", return_value=cfg),
|
||||
patch("compose_farm.cli.monitoring.run_async", side_effect=mock_run_async),
|
||||
patch("compose_farm.cli.monitoring.run_on_services") as mock_run,
|
||||
):
|
||||
@@ -174,7 +174,7 @@ class TestLogsHostFilter:
|
||||
mock_run_async, _ = _mock_run_async_factory(["svc1", "svc2"])
|
||||
|
||||
with (
|
||||
patch("compose_farm.cli.monitoring.load_config_or_exit", return_value=cfg),
|
||||
patch("compose_farm.cli.common.load_config_or_exit", return_value=cfg),
|
||||
patch("compose_farm.cli.monitoring.run_async", side_effect=mock_run_async),
|
||||
patch("compose_farm.cli.monitoring.run_on_services") as mock_run,
|
||||
):
|
||||
|
||||
114
tests/test_cli_ssh.py
Normal file
114
tests/test_cli_ssh.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Tests for CLI ssh commands."""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from compose_farm.cli.app import app
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
class TestSshKeygen:
|
||||
"""Tests for cf ssh keygen command."""
|
||||
|
||||
def test_keygen_generates_key(self, tmp_path: Path) -> None:
|
||||
"""Generate SSH key when none exists."""
|
||||
key_path = tmp_path / "compose-farm"
|
||||
pubkey_path = tmp_path / "compose-farm.pub"
|
||||
|
||||
with (
|
||||
patch("compose_farm.cli.ssh.SSH_KEY_PATH", key_path),
|
||||
patch("compose_farm.cli.ssh.SSH_PUBKEY_PATH", pubkey_path),
|
||||
patch("compose_farm.cli.ssh.key_exists", return_value=False),
|
||||
):
|
||||
result = runner.invoke(app, ["ssh", "keygen"])
|
||||
|
||||
# Command runs (may fail if ssh-keygen not available in test env)
|
||||
assert result.exit_code in (0, 1)
|
||||
|
||||
def test_keygen_skips_if_exists(self, tmp_path: Path) -> None:
|
||||
"""Skip key generation if key already exists."""
|
||||
key_path = tmp_path / "compose-farm"
|
||||
pubkey_path = tmp_path / "compose-farm.pub"
|
||||
|
||||
with (
|
||||
patch("compose_farm.cli.ssh.SSH_KEY_PATH", key_path),
|
||||
patch("compose_farm.cli.ssh.SSH_PUBKEY_PATH", pubkey_path),
|
||||
patch("compose_farm.cli.ssh.key_exists", return_value=True),
|
||||
):
|
||||
result = runner.invoke(app, ["ssh", "keygen"])
|
||||
|
||||
assert "already exists" in result.output
|
||||
|
||||
|
||||
class TestSshStatus:
|
||||
"""Tests for cf ssh status command."""
|
||||
|
||||
def test_status_shows_no_key(self, tmp_path: Path) -> None:
|
||||
"""Show message when no key exists."""
|
||||
config_file = tmp_path / "compose-farm.yaml"
|
||||
config_file.write_text("""
|
||||
hosts:
|
||||
local:
|
||||
address: localhost
|
||||
services:
|
||||
test: local
|
||||
""")
|
||||
|
||||
with patch("compose_farm.cli.ssh.key_exists", return_value=False):
|
||||
result = runner.invoke(app, ["ssh", "status", f"--config={config_file}"])
|
||||
|
||||
assert "No key found" in result.output
|
||||
|
||||
def test_status_shows_key_exists(self, tmp_path: Path) -> None:
|
||||
"""Show key info when key exists."""
|
||||
config_file = tmp_path / "compose-farm.yaml"
|
||||
config_file.write_text("""
|
||||
hosts:
|
||||
local:
|
||||
address: localhost
|
||||
services:
|
||||
test: local
|
||||
""")
|
||||
|
||||
with (
|
||||
patch("compose_farm.cli.ssh.key_exists", return_value=True),
|
||||
patch("compose_farm.cli.ssh.get_pubkey_content", return_value="ssh-ed25519 AAAA..."),
|
||||
):
|
||||
result = runner.invoke(app, ["ssh", "status", f"--config={config_file}"])
|
||||
|
||||
assert "Key exists" in result.output
|
||||
|
||||
|
||||
class TestSshSetup:
|
||||
"""Tests for cf ssh setup command."""
|
||||
|
||||
def test_setup_no_remote_hosts(self, tmp_path: Path) -> None:
|
||||
"""Show message when no remote hosts configured."""
|
||||
config_file = tmp_path / "compose-farm.yaml"
|
||||
config_file.write_text("""
|
||||
hosts:
|
||||
local:
|
||||
address: localhost
|
||||
services:
|
||||
test: local
|
||||
""")
|
||||
|
||||
result = runner.invoke(app, ["ssh", "setup", f"--config={config_file}"])
|
||||
|
||||
assert "No remote hosts" in result.output
|
||||
|
||||
|
||||
class TestSshHelp:
|
||||
"""Tests for cf ssh help."""
|
||||
|
||||
def test_ssh_help(self) -> None:
|
||||
"""Show help for ssh command."""
|
||||
result = runner.invoke(app, ["ssh", "--help"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "setup" in result.output
|
||||
assert "status" in result.output
|
||||
assert "keygen" in result.output
|
||||
245
tests/test_ssh_keys.py
Normal file
245
tests/test_ssh_keys.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""Tests for ssh_keys module."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from compose_farm.config import Host
|
||||
from compose_farm.executor import ssh_connect_kwargs
|
||||
from compose_farm.ssh_keys import (
|
||||
SSH_KEY_PATH,
|
||||
get_key_path,
|
||||
get_pubkey_content,
|
||||
get_ssh_auth_sock,
|
||||
get_ssh_env,
|
||||
key_exists,
|
||||
)
|
||||
|
||||
|
||||
class TestGetSshAuthSock:
|
||||
"""Tests for get_ssh_auth_sock function."""
|
||||
|
||||
def test_returns_env_var_when_socket_exists(self) -> None:
|
||||
"""Return SSH_AUTH_SOCK env var if the socket exists."""
|
||||
mock_path = MagicMock()
|
||||
mock_path.is_socket.return_value = True
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, {"SSH_AUTH_SOCK": "/tmp/agent.sock"}),
|
||||
patch("compose_farm.ssh_keys.Path", return_value=mock_path),
|
||||
):
|
||||
result = get_ssh_auth_sock()
|
||||
assert result == "/tmp/agent.sock"
|
||||
|
||||
def test_returns_none_when_env_var_not_socket(self, tmp_path: Path) -> None:
|
||||
"""Return None if SSH_AUTH_SOCK points to non-socket."""
|
||||
regular_file = tmp_path / "not_a_socket"
|
||||
regular_file.touch()
|
||||
with (
|
||||
patch.dict(os.environ, {"SSH_AUTH_SOCK": str(regular_file)}),
|
||||
patch("compose_farm.ssh_keys.Path.home", return_value=tmp_path),
|
||||
):
|
||||
# Should fall through to agent dir check, which won't exist
|
||||
result = get_ssh_auth_sock()
|
||||
assert result is None
|
||||
|
||||
def test_finds_agent_in_ssh_agent_dir(self, tmp_path: Path) -> None:
|
||||
"""Find agent socket in ~/.ssh/agent/ directory."""
|
||||
# Create agent directory structure with a regular file
|
||||
agent_dir = tmp_path / ".ssh" / "agent"
|
||||
agent_dir.mkdir(parents=True)
|
||||
sock_path = agent_dir / "s.12345.sshd.67890"
|
||||
sock_path.touch() # Create as regular file
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, {}, clear=False),
|
||||
patch("compose_farm.ssh_keys.Path.home", return_value=tmp_path),
|
||||
patch.object(Path, "is_socket", return_value=True),
|
||||
):
|
||||
os.environ.pop("SSH_AUTH_SOCK", None)
|
||||
result = get_ssh_auth_sock()
|
||||
assert result == str(sock_path)
|
||||
|
||||
def test_returns_none_when_no_agent_found(self, tmp_path: Path) -> None:
|
||||
"""Return None when no SSH agent socket is found."""
|
||||
with (
|
||||
patch.dict(os.environ, {}, clear=False),
|
||||
patch("compose_farm.ssh_keys.Path.home", return_value=tmp_path),
|
||||
):
|
||||
os.environ.pop("SSH_AUTH_SOCK", None)
|
||||
result = get_ssh_auth_sock()
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetSshEnv:
|
||||
"""Tests for get_ssh_env function."""
|
||||
|
||||
def test_returns_env_with_ssh_auth_sock(self) -> None:
|
||||
"""Return env dict with SSH_AUTH_SOCK set."""
|
||||
with patch("compose_farm.ssh_keys.get_ssh_auth_sock", return_value="/tmp/agent.sock"):
|
||||
result = get_ssh_env()
|
||||
assert result["SSH_AUTH_SOCK"] == "/tmp/agent.sock"
|
||||
# Should include other env vars too
|
||||
assert "PATH" in result or len(result) > 1
|
||||
|
||||
def test_returns_env_without_ssh_auth_sock_when_none(self, tmp_path: Path) -> None:
|
||||
"""Return env without SSH_AUTH_SOCK when no agent found."""
|
||||
with (
|
||||
patch.dict(os.environ, {}, clear=False),
|
||||
patch("compose_farm.ssh_keys.Path.home", return_value=tmp_path),
|
||||
):
|
||||
os.environ.pop("SSH_AUTH_SOCK", None)
|
||||
result = get_ssh_env()
|
||||
# SSH_AUTH_SOCK should not be set if no agent found
|
||||
assert result.get("SSH_AUTH_SOCK") is None
|
||||
|
||||
|
||||
class TestKeyExists:
|
||||
"""Tests for key_exists function."""
|
||||
|
||||
def test_returns_true_when_both_keys_exist(self, tmp_path: Path) -> None:
|
||||
"""Return True when both private and public keys exist."""
|
||||
key_path = tmp_path / "compose-farm"
|
||||
pubkey_path = tmp_path / "compose-farm.pub"
|
||||
key_path.touch()
|
||||
pubkey_path.touch()
|
||||
|
||||
with (
|
||||
patch("compose_farm.ssh_keys.SSH_KEY_PATH", key_path),
|
||||
patch("compose_farm.ssh_keys.SSH_PUBKEY_PATH", pubkey_path),
|
||||
):
|
||||
assert key_exists() is True
|
||||
|
||||
def test_returns_false_when_private_key_missing(self, tmp_path: Path) -> None:
|
||||
"""Return False when private key doesn't exist."""
|
||||
key_path = tmp_path / "compose-farm"
|
||||
pubkey_path = tmp_path / "compose-farm.pub"
|
||||
pubkey_path.touch() # Only public key exists
|
||||
|
||||
with (
|
||||
patch("compose_farm.ssh_keys.SSH_KEY_PATH", key_path),
|
||||
patch("compose_farm.ssh_keys.SSH_PUBKEY_PATH", pubkey_path),
|
||||
):
|
||||
assert key_exists() is False
|
||||
|
||||
def test_returns_false_when_public_key_missing(self, tmp_path: Path) -> None:
|
||||
"""Return False when public key doesn't exist."""
|
||||
key_path = tmp_path / "compose-farm"
|
||||
pubkey_path = tmp_path / "compose-farm.pub"
|
||||
key_path.touch() # Only private key exists
|
||||
|
||||
with (
|
||||
patch("compose_farm.ssh_keys.SSH_KEY_PATH", key_path),
|
||||
patch("compose_farm.ssh_keys.SSH_PUBKEY_PATH", pubkey_path),
|
||||
):
|
||||
assert key_exists() is False
|
||||
|
||||
|
||||
class TestGetKeyPath:
|
||||
"""Tests for get_key_path function."""
|
||||
|
||||
def test_returns_path_when_key_exists(self) -> None:
|
||||
"""Return key path when key exists."""
|
||||
with patch("compose_farm.ssh_keys.key_exists", return_value=True):
|
||||
result = get_key_path()
|
||||
assert result == SSH_KEY_PATH
|
||||
|
||||
def test_returns_none_when_key_missing(self) -> None:
|
||||
"""Return None when key doesn't exist."""
|
||||
with patch("compose_farm.ssh_keys.key_exists", return_value=False):
|
||||
result = get_key_path()
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetPubkeyContent:
|
||||
"""Tests for get_pubkey_content function."""
|
||||
|
||||
def test_returns_content_when_exists(self, tmp_path: Path) -> None:
|
||||
"""Return public key content when file exists."""
|
||||
pubkey_content = "ssh-ed25519 AAAA... compose-farm"
|
||||
pubkey_path = tmp_path / "compose-farm.pub"
|
||||
pubkey_path.write_text(pubkey_content + "\n")
|
||||
|
||||
with patch("compose_farm.ssh_keys.SSH_PUBKEY_PATH", pubkey_path):
|
||||
result = get_pubkey_content()
|
||||
assert result == pubkey_content
|
||||
|
||||
def test_returns_none_when_missing(self, tmp_path: Path) -> None:
|
||||
"""Return None when public key doesn't exist."""
|
||||
pubkey_path = tmp_path / "compose-farm.pub" # Doesn't exist
|
||||
|
||||
with patch("compose_farm.ssh_keys.SSH_PUBKEY_PATH", pubkey_path):
|
||||
result = get_pubkey_content()
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestSshConnectKwargs:
|
||||
"""Tests for ssh_connect_kwargs function."""
|
||||
|
||||
def test_basic_kwargs(self) -> None:
|
||||
"""Return basic connection kwargs."""
|
||||
host = Host(address="example.com", port=22, user="testuser")
|
||||
|
||||
with (
|
||||
patch("compose_farm.executor.get_ssh_auth_sock", return_value=None),
|
||||
patch("compose_farm.executor.get_key_path", return_value=None),
|
||||
):
|
||||
result = ssh_connect_kwargs(host)
|
||||
|
||||
assert result["host"] == "example.com"
|
||||
assert result["port"] == 22
|
||||
assert result["username"] == "testuser"
|
||||
assert result["known_hosts"] is None
|
||||
assert "agent_path" not in result
|
||||
assert "client_keys" not in result
|
||||
|
||||
def test_includes_agent_path_when_available(self) -> None:
|
||||
"""Include agent_path when SSH agent is available."""
|
||||
host = Host(address="example.com")
|
||||
|
||||
with (
|
||||
patch("compose_farm.executor.get_ssh_auth_sock", return_value="/tmp/agent.sock"),
|
||||
patch("compose_farm.executor.get_key_path", return_value=None),
|
||||
):
|
||||
result = ssh_connect_kwargs(host)
|
||||
|
||||
assert result["agent_path"] == "/tmp/agent.sock"
|
||||
|
||||
def test_includes_client_keys_when_key_exists(self, tmp_path: Path) -> None:
|
||||
"""Include client_keys when compose-farm key exists."""
|
||||
host = Host(address="example.com")
|
||||
key_path = tmp_path / "compose-farm"
|
||||
|
||||
with (
|
||||
patch("compose_farm.executor.get_ssh_auth_sock", return_value=None),
|
||||
patch("compose_farm.executor.get_key_path", return_value=key_path),
|
||||
):
|
||||
result = ssh_connect_kwargs(host)
|
||||
|
||||
assert result["client_keys"] == [str(key_path)]
|
||||
|
||||
def test_includes_both_agent_and_key(self, tmp_path: Path) -> None:
|
||||
"""Include both agent_path and client_keys when both available."""
|
||||
host = Host(address="example.com")
|
||||
key_path = tmp_path / "compose-farm"
|
||||
|
||||
with (
|
||||
patch("compose_farm.executor.get_ssh_auth_sock", return_value="/tmp/agent.sock"),
|
||||
patch("compose_farm.executor.get_key_path", return_value=key_path),
|
||||
):
|
||||
result = ssh_connect_kwargs(host)
|
||||
|
||||
assert result["agent_path"] == "/tmp/agent.sock"
|
||||
assert result["client_keys"] == [str(key_path)]
|
||||
|
||||
def test_custom_port(self) -> None:
|
||||
"""Handle custom SSH port."""
|
||||
host = Host(address="example.com", port=2222)
|
||||
|
||||
with (
|
||||
patch("compose_farm.executor.get_ssh_auth_sock", return_value=None),
|
||||
patch("compose_farm.executor.get_key_path", return_value=None),
|
||||
):
|
||||
result = ssh_connect_kwargs(host)
|
||||
|
||||
assert result["port"] == 2222
|
||||
54
tests/web/test_backup.py
Normal file
54
tests/web/test_backup.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Tests for file backup functionality."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from compose_farm.web.routes.api import _backup_file, _save_with_backup
|
||||
|
||||
|
||||
def test_backup_creates_timestamped_file(tmp_path: Path) -> None:
|
||||
"""Test that backup creates file in .backups with correct content."""
|
||||
test_file = tmp_path / "test.yaml"
|
||||
test_file.write_text("original content")
|
||||
|
||||
backup_path = _backup_file(test_file)
|
||||
|
||||
assert backup_path is not None
|
||||
assert backup_path.parent.name == ".backups"
|
||||
assert backup_path.name.startswith("test.yaml.")
|
||||
assert backup_path.read_text() == "original content"
|
||||
|
||||
|
||||
def test_backup_returns_none_for_nonexistent_file(tmp_path: Path) -> None:
|
||||
"""Test that backup returns None if file doesn't exist."""
|
||||
assert _backup_file(tmp_path / "nonexistent.yaml") is None
|
||||
|
||||
|
||||
def test_save_creates_new_file(tmp_path: Path) -> None:
|
||||
"""Test that save creates new file without backup."""
|
||||
test_file = tmp_path / "new.yaml"
|
||||
|
||||
assert _save_with_backup(test_file, "content") is True
|
||||
assert test_file.read_text() == "content"
|
||||
assert not (tmp_path / ".backups").exists()
|
||||
|
||||
|
||||
def test_save_skips_unchanged_content(tmp_path: Path) -> None:
|
||||
"""Test that save returns False and creates no backup if unchanged."""
|
||||
test_file = tmp_path / "test.yaml"
|
||||
test_file.write_text("same")
|
||||
|
||||
assert _save_with_backup(test_file, "same") is False
|
||||
assert not (tmp_path / ".backups").exists()
|
||||
|
||||
|
||||
def test_save_creates_backup_before_overwrite(tmp_path: Path) -> None:
|
||||
"""Test that save backs up original before overwriting."""
|
||||
test_file = tmp_path / "test.yaml"
|
||||
test_file.write_text("original")
|
||||
|
||||
assert _save_with_backup(test_file, "new") is True
|
||||
assert test_file.read_text() == "new"
|
||||
|
||||
backups = list((tmp_path / ".backups").glob("test.yaml.*"))
|
||||
assert len(backups) == 1
|
||||
assert backups[0].read_text() == "original"
|
||||
111
tests/web/test_template_context.py
Normal file
111
tests/web/test_template_context.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Tests to verify template context variables match what templates expect.
|
||||
|
||||
Uses runtime validation by actually rendering templates and catching errors.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from compose_farm.config import Config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Config:
|
||||
"""Create a minimal mock config for template testing."""
|
||||
compose_dir = tmp_path / "compose"
|
||||
compose_dir.mkdir()
|
||||
|
||||
# Create minimal service directory
|
||||
svc_dir = compose_dir / "test-service"
|
||||
svc_dir.mkdir()
|
||||
(svc_dir / "compose.yaml").write_text("services:\n app:\n image: nginx\n")
|
||||
|
||||
config_path = tmp_path / "compose-farm.yaml"
|
||||
config_path.write_text(f"""
|
||||
compose_dir: {compose_dir}
|
||||
hosts:
|
||||
local-host:
|
||||
address: localhost
|
||||
services:
|
||||
test-service: local-host
|
||||
""")
|
||||
|
||||
state_path = tmp_path / "compose-farm-state.yaml"
|
||||
state_path.write_text("deployed:\n test-service: local-host\n")
|
||||
|
||||
from compose_farm.config import load_config
|
||||
|
||||
config = load_config(config_path)
|
||||
|
||||
# Patch get_config in all relevant modules
|
||||
from compose_farm.web import deps
|
||||
from compose_farm.web.routes import api, pages
|
||||
|
||||
monkeypatch.setattr(deps, "get_config", lambda: config)
|
||||
monkeypatch.setattr(api, "get_config", lambda: config)
|
||||
monkeypatch.setattr(pages, "get_config", lambda: config)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_config: Config) -> TestClient:
|
||||
"""Create a test client with mocked config."""
|
||||
from compose_farm.web.app import create_app
|
||||
|
||||
return TestClient(create_app())
|
||||
|
||||
|
||||
class TestPageTemplatesRender:
|
||||
"""Test that page templates render without missing variables."""
|
||||
|
||||
def test_index_renders(self, client: TestClient) -> None:
|
||||
"""Test index page renders without errors."""
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert "Compose Farm" in response.text
|
||||
|
||||
def test_console_renders(self, client: TestClient) -> None:
|
||||
"""Test console page renders without errors."""
|
||||
response = client.get("/console")
|
||||
assert response.status_code == 200
|
||||
assert "Console" in response.text
|
||||
assert "Terminal" in response.text
|
||||
|
||||
def test_service_detail_renders(self, client: TestClient) -> None:
|
||||
"""Test service detail page renders without errors."""
|
||||
response = client.get("/service/test-service")
|
||||
assert response.status_code == 200
|
||||
assert "test-service" in response.text
|
||||
|
||||
|
||||
class TestPartialTemplatesRender:
|
||||
"""Test that partial templates render without missing variables."""
|
||||
|
||||
def test_sidebar_renders(self, client: TestClient) -> None:
|
||||
"""Test sidebar partial renders without errors."""
|
||||
response = client.get("/partials/sidebar")
|
||||
assert response.status_code == 200
|
||||
assert "Dashboard" in response.text
|
||||
assert "Console" in response.text
|
||||
|
||||
def test_stats_renders(self, client: TestClient) -> None:
|
||||
"""Test stats partial renders without errors."""
|
||||
response = client.get("/partials/stats")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_pending_renders(self, client: TestClient) -> None:
|
||||
"""Test pending partial renders without errors."""
|
||||
response = client.get("/partials/pending")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_services_by_host_renders(self, client: TestClient) -> None:
|
||||
"""Test services_by_host partial renders without errors."""
|
||||
response = client.get("/partials/services-by-host")
|
||||
assert response.status_code == 200
|
||||
Reference in New Issue
Block a user