diff --git a/README.md b/README.md index 6ed5754..cf9f68f 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ after the hook scope and behavior are settled. ## Docs - [Codex hook design](docs/codex-hook-design.md) +- [Escape cancel playback design](docs/escape-cancel-playback-design.md) - [Codex plugin scaffold](codex/README.md) ## Proposed config diff --git a/codex/README.md b/codex/README.md index 38670e8..4d7d3a4 100644 --- a/codex/README.md +++ b/codex/README.md @@ -20,6 +20,7 @@ src/ playback.py session_start.py stop.py + tts_playback_supervisor.py pyproject.toml ``` @@ -121,3 +122,20 @@ The current packaged runtime uses code defaults when launched through does not execute from the Codex plugin cache. A future iteration should add a stable user config path or environment override if per-user settings are needed with the Git-backed launcher. + +## Playback Cancellation + +The Stop hook writes the generated WAV and starts the packaged playback +supervisor with `python -m tts_hook.tts_playback_supervisor`. The supervisor +owns playback until the audio player exits, then removes the temporary WAV file. + +Press Escape in the focused Codex terminal or tmux pane to cancel the current +playback. Cancellation is best-effort and terminal-scoped: it only affects the +audio player launched for the current hook result, and it depends on the +supervisor being able to read `/dev/tty`. + +If `/dev/tty` is unavailable or cannot be configured for cbreak input, playback +continues normally without Escape cancellation. There is no fallback cancel +command or global hotkey. + +The cancel key is fixed to Escape. It is not configurable. diff --git a/codex/tests/test_playback_supervisor_behavior.py b/codex/tests/test_playback_supervisor_behavior.py new file mode 100644 index 0000000..62dc6d9 --- /dev/null +++ b/codex/tests/test_playback_supervisor_behavior.py @@ -0,0 +1,325 @@ +from __future__ import annotations + +from io import StringIO +from pathlib import Path +from typing import Any +import subprocess +import sys + +import pytest + +ROOT = Path(__file__).resolve().parents[1] +REPO_ROOT = ROOT.parent +sys.path.insert(0, str(REPO_ROOT / "src")) + +from tts_hook import tts_playback_supervisor as supervisor_module # noqa: E402 +from tts_hook.playback import PlaybackProcessResult # noqa: E402 + + +class FakeStream: + def __init__(self, data: bytes = b"") -> None: + self.data = data + self.closed = False + + def fileno(self) -> int: + return 42 + + def read(self, size: int) -> bytes: + return self.data[:size] + + def close(self) -> None: + self.closed = True + + +class FakeTermios: + TCSADRAIN = 1 + + def __init__(self) -> None: + self.restored: list[tuple[int, int, object]] = [] + + def tcgetattr(self, fd: int) -> list[int]: + assert fd == 42 + return [1, 2, 3] + + def tcsetattr(self, fd: int, when: int, attrs: object) -> None: + self.restored.append((fd, when, attrs)) + + +class FakeTty: + def __init__(self) -> None: + self.cbreak_fds: list[int] = [] + + def setcbreak(self, fd: int) -> None: + self.cbreak_fds.append(fd) + + +class FakeProcess: + def __init__(self, poll_results: list[int | None] | None = None) -> None: + self.poll_results = poll_results or [0] + self.wait_calls = 0 + self.pid = 321 + + def poll(self) -> int | None: + if len(self.poll_results) > 1: + return self.poll_results.pop(0) + return self.poll_results[0] + + def wait(self, timeout: float | None = None) -> int: + self.wait_calls += 1 + self.poll_results = [0] + return 0 + + +class FakeReader: + def __init__(self, escape_results: list[bool]) -> None: + self.escape_results = escape_results + self.closed = False + + def escape_pressed(self, timeout_seconds: float) -> bool: + if len(self.escape_results) > 1: + return self.escape_results.pop(0) + return self.escape_results[0] + + def close(self) -> None: + self.closed = True + + +def test_tty_escape_reader_detects_escape_and_restores_terminal() -> None: + stream = FakeStream(supervisor_module.ESCAPE) + termios = FakeTermios() + tty = FakeTty() + reader = supervisor_module.TtyEscapeReader( + opener=lambda *args, **kwargs: stream, + select_fn=lambda readable, _writable, _errors, _timeout: (readable, [], []), + termios_module=termios, + tty_module=tty, + ) + + with reader as active_reader: + assert active_reader.escape_pressed(0.01) is True + + assert tty.cbreak_fds == [42] + assert termios.restored == [(42, termios.TCSADRAIN, [1, 2, 3])] + assert stream.closed is True + + +def test_tty_escape_reader_ignores_non_escape_input_and_timeout() -> None: + stream = FakeStream(b"x") + reader = supervisor_module.TtyEscapeReader( + opener=lambda *args, **kwargs: stream, + select_fn=lambda readable, _writable, _errors, _timeout: (readable, [], []), + termios_module=FakeTermios(), + tty_module=FakeTty(), + ) + + with reader: + assert reader.escape_pressed(0.01) is False + + timeout_reader = supervisor_module.TtyEscapeReader( + opener=lambda *args, **kwargs: FakeStream(supervisor_module.ESCAPE), + select_fn=lambda _readable, _writable, _errors, _timeout: ([], [], []), + termios_module=FakeTermios(), + tty_module=FakeTty(), + ) + with timeout_reader: + assert timeout_reader.escape_pressed(0.01) is False + + +def test_open_tty_failure_disables_cancel_support() -> None: + stderr = StringIO() + + class FailingReader: + def __enter__(self) -> object: + raise OSError("no tty") + + def close(self) -> None: + return + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr(supervisor_module, "TtyEscapeReader", FailingReader) + + reader = supervisor_module.open_tty_escape_reader(stderr=stderr) + + assert reader is None + assert "Escape cancel unavailable" in stderr.getvalue() + + +def test_tty_configuration_failure_restores_terminal_and_disables_cancel() -> None: + stream = FakeStream() + termios = FakeTermios() + stderr = StringIO() + + class FailingTty: + def setcbreak(self, fd: int) -> None: + raise supervisor_module.termios.error("cannot configure tty") + + original_reader = supervisor_module.TtyEscapeReader + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + supervisor_module, + "TtyEscapeReader", + lambda: original_reader( + opener=lambda *args, **kwargs: stream, + termios_module=termios, + tty_module=FailingTty(), + ), + ) + + reader = supervisor_module.open_tty_escape_reader(stderr=stderr) + + assert reader is None + assert termios.restored == [(42, termios.TCSADRAIN, [1, 2, 3])] + assert stream.closed is True + assert "Escape cancel unavailable" in stderr.getvalue() + + +def test_supervisor_no_tty_waits_for_playback_and_cleans_wav( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + wav = tmp_path / "audio.wav" + wav.write_bytes(b"RIFF") + process = FakeProcess([None, 0]) + monkeypatch.setattr( + supervisor_module, + "launch_audio_player_process", + lambda wav_path, *, player: PlaybackProcessResult(ok=True, command=("player", str(wav_path)), process=process), + ) + monkeypatch.setattr(supervisor_module, "open_tty_escape_reader", lambda *, stderr=None: None) + monkeypatch.setattr(supervisor_module, "install_signal_handlers", lambda: None) + + exit_code = supervisor_module.supervise_playback(wav, player="auto", stderr=StringIO()) + + assert exit_code == 0 + assert process.wait_calls == 1 + assert not wav.exists() + + +def test_supervisor_poll_failure_continues_playback_without_cancel( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + wav = tmp_path / "audio.wav" + wav.write_bytes(b"RIFF") + process = FakeProcess([None, None, 0]) + stderr = StringIO() + + class FailingReader: + closed = False + + def escape_pressed(self, timeout_seconds: float) -> bool: + raise OSError("poll failed") + + def close(self) -> None: + self.closed = True + + reader = FailingReader() + monkeypatch.setattr( + supervisor_module, + "launch_audio_player_process", + lambda wav_path, *, player: PlaybackProcessResult(ok=True, command=("player", str(wav_path)), process=process), + ) + monkeypatch.setattr(supervisor_module, "open_tty_escape_reader", lambda *, stderr=None: reader) + monkeypatch.setattr(supervisor_module, "install_signal_handlers", lambda: None) + + exit_code = supervisor_module.supervise_playback(wav, player="auto", stderr=stderr) + + assert exit_code == 0 + assert process.wait_calls == 1 + assert reader.closed is True + assert "Disabling Escape cancel support" in stderr.getvalue() + assert not wav.exists() + + +def test_supervisor_escape_cancels_current_playback_and_cleans_wav( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + wav = tmp_path / "audio.wav" + wav.write_bytes(b"RIFF") + process = FakeProcess([None, None]) + reader = FakeReader([True]) + terminated: list[FakeProcess] = [] + monkeypatch.setattr( + supervisor_module, + "launch_audio_player_process", + lambda wav_path, *, player: PlaybackProcessResult(ok=True, command=("player", str(wav_path)), process=process), + ) + monkeypatch.setattr(supervisor_module, "open_tty_escape_reader", lambda *, stderr=None: reader) + monkeypatch.setattr(supervisor_module, "terminate_process_group", lambda active_process: terminated.append(active_process)) + monkeypatch.setattr(supervisor_module, "install_signal_handlers", lambda: None) + + exit_code = supervisor_module.supervise_playback(wav, player="auto", stderr=StringIO()) + + assert exit_code == 0 + assert terminated == [process] + assert reader.closed is True + assert not wav.exists() + + +def test_supervisor_natural_completion_does_not_cancel( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + wav = tmp_path / "audio.wav" + wav.write_bytes(b"RIFF") + process = FakeProcess([None, 0]) + reader = FakeReader([False]) + terminated: list[Any] = [] + monkeypatch.setattr( + supervisor_module, + "launch_audio_player_process", + lambda wav_path, *, player: PlaybackProcessResult(ok=True, command=("player", str(wav_path)), process=process), + ) + monkeypatch.setattr(supervisor_module, "open_tty_escape_reader", lambda *, stderr=None: reader) + monkeypatch.setattr(supervisor_module, "terminate_process_group", lambda active_process: terminated.append(active_process)) + monkeypatch.setattr(supervisor_module, "install_signal_handlers", lambda: None) + + exit_code = supervisor_module.supervise_playback(wav, player="auto", stderr=StringIO()) + + assert exit_code == 0 + assert terminated == [] + assert reader.closed is True + assert not wav.exists() + + +def test_cleanup_failure_is_nonfatal(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + wav = tmp_path / "audio.wav" + wav.write_bytes(b"RIFF") + process = FakeProcess([0]) + stderr = StringIO() + monkeypatch.setattr( + supervisor_module, + "launch_audio_player_process", + lambda wav_path, *, player: PlaybackProcessResult(ok=True, command=("player", str(wav_path)), process=process), + ) + monkeypatch.setattr(supervisor_module, "open_tty_escape_reader", lambda *, stderr=None: None) + monkeypatch.setattr(supervisor_module, "install_signal_handlers", lambda: None) + + def fail_unlink(self: Path, *args: object, **kwargs: object) -> None: + raise OSError("permission denied") + + monkeypatch.setattr(Path, "unlink", fail_unlink) + + exit_code = supervisor_module.supervise_playback(wav, player="auto", stderr=stderr) + + assert exit_code == 0 + assert "Could not delete temporary WAV file" in stderr.getvalue() + + +def test_supervisor_diagnostics_do_not_write_hook_json_to_stdout(tmp_path: Path) -> None: + missing = tmp_path / "missing.wav" + + result = subprocess.run( + [sys.executable, "-m", "tts_hook.tts_playback_supervisor", str(missing)], + cwd=REPO_ROOT, + env={"PYTHONPATH": str(REPO_ROOT / "src")}, + text=True, + capture_output=True, + check=False, + ) + + assert result.returncode == 2 + assert result.stdout == "" + assert "WAV file does not exist" in result.stderr + assert '{"continue"' not in result.stdout diff --git a/codex/tests/test_playback_supervisor_foundation.py b/codex/tests/test_playback_supervisor_foundation.py new file mode 100644 index 0000000..19e112b --- /dev/null +++ b/codex/tests/test_playback_supervisor_foundation.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any, Sequence +import signal +import subprocess +import sys + +import pytest + +ROOT = Path(__file__).resolve().parents[1] +REPO_ROOT = ROOT.parent +sys.path.insert(0, str(REPO_ROOT / "src")) + +from tts_hook import playback as playback_module # noqa: E402 +from tts_hook.playback import ( # noqa: E402 + build_playback_command, + launch_audio_player_process, + launch_process_group, + play_audio_file, + terminate_process_group, +) + + +class CapturingPopen: + calls: list[dict[str, Any]] = [] + + def __init__(self, command: Sequence[str], **kwargs: Any) -> None: + self.command = tuple(command) + self.kwargs = kwargs + self.pid = 12345 + self.returncode = None + CapturingPopen.calls.append({"command": self.command, "kwargs": kwargs}) + + def wait(self, timeout: float | None = None) -> int: + self.returncode = 0 + return 0 + + +class FakeProcess: + def __init__(self, *, pid: int = 222, already_exited: bool = False, timeout: bool = False) -> None: + self.pid = pid + self._already_exited = already_exited + self._timeout = timeout + self.wait_calls = 0 + + def poll(self) -> int | None: + return 0 if self._already_exited else None + + def wait(self, timeout: float | None = None) -> int: + self.wait_calls += 1 + if self._timeout and timeout is not None: + raise subprocess.TimeoutExpired(["player"], timeout) + return 0 + + +def test_build_playback_command_reuses_auto_player_selection(tmp_path: Path) -> None: + bin_dir = tmp_path / "bin" + bin_dir.mkdir() + ffplay = bin_dir / "ffplay" + ffplay.write_text("#!/bin/sh\n", encoding="utf-8") + ffplay.chmod(0o755) + wav = tmp_path / "audio.wav" + + command = build_playback_command(wav, player="auto", path_env=str(bin_dir)) + + assert command == (str(ffplay), "-nodisp", "-autoexit", str(wav)) + + +def test_play_audio_file_and_supervisor_launch_use_same_command_path( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + wav = tmp_path / "audio.wav" + wav.write_bytes(b"RIFF") + seen: list[tuple[Path, str, str | None]] = [] + + def fake_build(wav_path: Path, *, player: str = "auto", path_env: str | None = None) -> tuple[str, ...]: + seen.append((wav_path, player, path_env)) + return ("/bin/echo", str(wav_path)) + + monkeypatch.setattr(playback_module, "build_playback_command", fake_build) + monkeypatch.setattr(playback_module, "launch_process_group", lambda command: CapturingPopen(command)) + CapturingPopen.calls.clear() + + direct = play_audio_file(wav, player="auto", blocking=False, path_env="/tmp/bin") + supervisor = launch_audio_player_process(wav, player="auto", path_env="/tmp/bin") + + assert direct.ok is True + assert supervisor.ok is True + assert seen == [(wav, "auto", "/tmp/bin"), (wav, "auto", "/tmp/bin")] + assert CapturingPopen.calls[0]["command"] == CapturingPopen.calls[1]["command"] + + +def test_launch_process_group_suppresses_output_and_starts_new_session(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(subprocess, "Popen", CapturingPopen) + CapturingPopen.calls.clear() + + process = launch_process_group(("/bin/echo", "audio.wav")) + + assert process.pid == 12345 + assert CapturingPopen.calls == [ + { + "command": ("/bin/echo", "audio.wav"), + "kwargs": { + "stdin": subprocess.DEVNULL, + "stdout": subprocess.DEVNULL, + "stderr": subprocess.DEVNULL, + "start_new_session": True, + }, + } + ] + + +def test_terminate_process_group_returns_when_process_already_exited(monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[tuple[int, signal.Signals]] = [] + monkeypatch.setattr(playback_module.os, "killpg", lambda pgid, sig: calls.append((pgid, sig))) + + graceful = terminate_process_group(FakeProcess(already_exited=True)) + + assert graceful is True + assert calls == [] + + +def test_terminate_process_group_sends_sigterm_and_waits(monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[tuple[int, signal.Signals]] = [] + monkeypatch.setattr(playback_module.os, "getpgid", lambda pid: 777) + monkeypatch.setattr(playback_module.os, "killpg", lambda pgid, sig: calls.append((pgid, sig))) + process = FakeProcess() + + graceful = terminate_process_group(process, timeout_seconds=0.1) + + assert graceful is True + assert process.wait_calls == 1 + assert calls == [(777, signal.SIGTERM)] + + +def test_terminate_process_group_escalates_to_sigkill_after_timeout(monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[tuple[int, signal.Signals]] = [] + monkeypatch.setattr(playback_module.os, "getpgid", lambda pid: 888) + monkeypatch.setattr(playback_module.os, "killpg", lambda pgid, sig: calls.append((pgid, sig))) + process = FakeProcess(timeout=True) + + graceful = terminate_process_group(process, timeout_seconds=0.1) + + assert graceful is False + assert process.wait_calls == 2 + assert calls == [(888, signal.SIGTERM), (888, signal.SIGKILL)] + + +def test_supervisor_entrypoint_help_imports_tts_hook_modules() -> None: + result = subprocess.run( + [sys.executable, "-m", "tts_hook.tts_playback_supervisor", "--help"], + cwd=REPO_ROOT, + env={"PYTHONPATH": str(REPO_ROOT / "src")}, + text=True, + capture_output=True, + check=False, + ) + + assert result.returncode == 0 + assert "usage:" in result.stdout + assert "Codex" not in result.stdout diff --git a/codex/tests/test_stop_hook.py b/codex/tests/test_stop_hook.py index afa7623..a6011d2 100644 --- a/codex/tests/test_stop_hook.py +++ b/codex/tests/test_stop_hook.py @@ -6,7 +6,6 @@ from threading import Thread from typing import Any import json -import os import subprocess import sys import time @@ -17,10 +16,11 @@ REPO_ROOT = ROOT.parent sys.path.insert(0, str(REPO_ROOT / "src")) +from tts_hook import stop as stop_hook_module # noqa: E402 from tts_hook.config import load_config # noqa: E402 from tts_hook.logging import HookLogger # noqa: E402 from tts_hook.playback import choose_player_command, play_audio_file # noqa: E402 -from tts_hook.stop import extract_assistant_message, main, speak_last_assistant_message # noqa: E402 +from tts_hook.stop import extract_assistant_message, main, spawn_playback_supervisor, speak_last_assistant_message # noqa: E402 FIXTURES = ROOT / "tests" / "fixtures" / "stop" @@ -72,6 +72,26 @@ def make_fake_player(bin_dir: Path, name: str = "pw-play", *, sleep_seconds: flo return marker +def write_fake_supervisor(plugin_root: Path, *, sleep_seconds: float = 0.0) -> Path: + marker = plugin_root / "supervisor_args.json" + supervisor = plugin_root / "src" / "tts_hook" / "tts_playback_supervisor.py" + supervisor.parent.mkdir(parents=True, exist_ok=True) + supervisor.write_text( + "#!/usr/bin/env python3\n" + "import json\n" + "import pathlib\n" + "import sys\n" + "import time\n" + "print('unexpected supervisor stdout')\n" + "sys.stderr.write('unexpected supervisor stderr\\n')\n" + f"pathlib.Path({str(marker)!r}).write_text(json.dumps(sys.argv[1:]), encoding='utf-8')\n" + f"time.sleep({sleep_seconds})\n", + encoding="utf-8", + ) + supervisor.chmod(0o755) + return marker + + @pytest.fixture def kokoro_speech_server() -> tuple[ThreadingHTTPServer, dict[str, Any]]: state: dict[str, Any] = { @@ -149,21 +169,52 @@ def test_malformed_stop_input_returns_valid_json_and_stderr_only(tmp_path: Path) assert "not valid JSON" in stderr.getvalue() +def test_config_error_returns_valid_json_and_does_not_spawn_supervisor( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + (tmp_path / "tts-hook.toml").write_text("[kokoro]\nport = \"bad\"\n", encoding="utf-8") + stdout = StringIO() + stderr = StringIO() + spawned: list[Path] = [] + + def fake_spawn(wav_path: Path, config: Any) -> Any: + spawned.append(wav_path) + return stop_hook_module.PlaybackResult(ok=True, command=("supervisor", str(wav_path)), pid=123) + + monkeypatch.setattr(stop_hook_module, "spawn_playback_supervisor", fake_spawn) + + exit_code = main( + stdin=StringIO(read_fixture("normal.json")), + stdout=stdout, + stderr=stderr, + plugin_root=tmp_path, + ) + + assert exit_code == 0 + assert json.loads(stdout.getvalue()) == {"continue": True} + assert "Could not load plugin-local TTS config" in stderr.getvalue() + assert spawned == [] + + def test_successful_kokoro_response_writes_unique_wavs_and_preserves_full_payload( tmp_path: Path, kokoro_speech_server: tuple[ThreadingHTTPServer, dict[str, Any]], monkeypatch: pytest.MonkeyPatch, ) -> None: server, state = kokoro_speech_server - bin_dir = tmp_path / "bin" - bin_dir.mkdir() - marker = make_fake_player(bin_dir) - monkeypatch.setenv("PATH", str(bin_dir)) write_config(tmp_path, port=server.server_port, log_path=tmp_path / "hook.log", voice="af_sarah") config = load_config(tmp_path) logger = HookLogger.from_config(config, stderr=StringIO()) message = read_fixture("long_multiparagraph.json") payload = json.loads(message) + spawned: list[tuple[Path, str, bool]] = [] + + def fake_spawn(wav_path: Path, config: Any) -> Any: + spawned.append((wav_path, config.playback.player, config.playback.blocking)) + return stop_hook_module.PlaybackResult(ok=True, command=("supervisor", str(wav_path)), pid=123) + + monkeypatch.setattr(stop_hook_module, "spawn_playback_supervisor", fake_spawn) first = speak_last_assistant_message(payload, config, logger) second = speak_last_assistant_message(payload, config, logger) @@ -175,10 +226,7 @@ def test_successful_kokoro_response_writes_unique_wavs_and_preserves_full_payloa assert second.suffix == ".wav" assert first.read_bytes() == b"RIFFfake-wave" assert second.read_bytes() == b"RIFFfake-wave" - deadline = time.monotonic() + 2 - while time.monotonic() < deadline and not marker.exists(): - time.sleep(0.05) - assert marker.exists() + assert spawned == [(first, "auto", False), (second, "auto", False)] assert state["requests"][0]["path"] == "/v1/audio/speech" assert state["requests"][0]["content_type"] == "application/json" assert state["requests"][0]["payload"] == { @@ -208,14 +256,18 @@ def test_kokoro_request_failure_logs_without_breaking_hook(tmp_path: Path) -> No assert "Kokoro speech request failed" in stderr.getvalue() -def test_no_playback_command_logs_warning_and_returns_valid_hook_json( +def test_supervisor_spawn_failure_logs_warning_and_returns_valid_hook_json( tmp_path: Path, kokoro_speech_server: tuple[ThreadingHTTPServer, dict[str, Any]], monkeypatch: pytest.MonkeyPatch, ) -> None: server, _state = kokoro_speech_server - monkeypatch.setenv("PATH", "") write_config(tmp_path, port=server.server_port, log_path=tmp_path / "hook.log") + monkeypatch.setattr( + stop_hook_module.subprocess, + "Popen", + lambda *args, **kwargs: (_ for _ in ()).throw(OSError("spawn failed")), + ) stdout = StringIO() stderr = StringIO() @@ -228,7 +280,49 @@ def test_no_playback_command_logs_warning_and_returns_valid_hook_json( assert exit_code == 0 assert json.loads(stdout.getvalue()) == {"continue": True} - assert "Playback did not start" in stderr.getvalue() + assert "Playback supervisor did not start" in stderr.getvalue() + + +def test_spawn_playback_supervisor_passes_wav_player_and_blocking_option( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + write_config(tmp_path, port=9, log_path=tmp_path / "hook.log", player="ffplay -volume 20", blocking=True) + config = load_config(tmp_path) + wav = tmp_path / "audio.wav" + captured: dict[str, Any] = {} + + class FakeProcess: + pid = 456 + + def fake_popen(command: tuple[str, ...], **kwargs: Any) -> FakeProcess: + captured["command"] = command + captured["kwargs"] = kwargs + return FakeProcess() + + monkeypatch.setattr(stop_hook_module.subprocess, "Popen", fake_popen) + + result = spawn_playback_supervisor(wav, config) + + assert result.ok is True + assert result.pid == 456 + assert result.command == captured["command"] + assert captured["command"] == ( + sys.executable, + "-m", + "tts_hook.tts_playback_supervisor", + str(wav), + "--player", + "ffplay -volume 20", + "--blocking", + ) + assert captured["kwargs"] == { + "stdin": subprocess.DEVNULL, + "stdout": subprocess.DEVNULL, + "stderr": subprocess.DEVNULL, + "start_new_session": True, + "cwd": tmp_path, + } def test_auto_player_selection_order(tmp_path: Path) -> None: @@ -295,14 +389,16 @@ def test_stop_fixture_subprocess_posts_audio_and_spawns_playback( server, state = kokoro_speech_server bin_dir = tmp_path / "bin" bin_dir.mkdir() - marker = make_fake_player(bin_dir, sleep_seconds=0.5) - write_config(tmp_path, port=server.server_port, log_path=tmp_path / "hook.log") + player_marker = make_fake_player(bin_dir, sleep_seconds=0.5) + supervisor_marker = write_fake_supervisor(tmp_path, sleep_seconds=1.0) + write_config(tmp_path, port=server.server_port, log_path=tmp_path / "hook.log", player="pw-play", blocking=True) stdout = StringIO() stderr = StringIO() start = time.monotonic() with pytest.MonkeyPatch.context() as monkeypatch: - monkeypatch.setenv("PATH", f"{bin_dir}:{os.environ.get('PATH', '')}") + monkeypatch.setenv("PATH", str(bin_dir)) + monkeypatch.setenv("PYTHONPATH", str(tmp_path / "src")) exit_code = main( stdin=StringIO(read_fixture("normal.json")), stdout=stdout, @@ -318,9 +414,12 @@ def test_stop_fixture_subprocess_posts_audio_and_spawns_playback( assert elapsed < 0.5 assert state["requests"][0]["payload"]["input"] == "Codex finished the requested change." deadline = time.monotonic() + 2 - while time.monotonic() < deadline and not marker.exists(): + while time.monotonic() < deadline and not supervisor_marker.exists(): time.sleep(0.05) - assert marker.exists() + supervisor_args = json.loads(supervisor_marker.read_text(encoding="utf-8")) + assert supervisor_args[1:] == ["--player", "pw-play", "--blocking"] + assert Path(supervisor_args[0]).suffix == ".wav" + assert not player_marker.exists() def test_no_max_chars_policy_exists() -> None: diff --git a/src/tts_hook/playback.py b/src/tts_hook/playback.py index 368a746..7e3067e 100644 --- a/src/tts_hook/playback.py +++ b/src/tts_hook/playback.py @@ -8,6 +8,7 @@ import os import shutil +import signal import subprocess AUTO_PLAYERS: tuple[tuple[str, ...], ...] = ( @@ -28,6 +29,22 @@ class PlaybackResult: error: str | None = None +@dataclass(frozen=True) +class PlaybackProcessResult: + """Success or failure details for a process-group playback launch.""" + + ok: bool + command: tuple[str, ...] = () + process: subprocess.Popen[bytes] | None = None + error: str | None = None + + @property + def pid(self) -> int | None: + """Return the child PID when playback was launched.""" + + return None if self.process is None else self.process.pid + + def choose_player_command(player: str, *, path_env: str | None = None) -> tuple[str, ...] | None: """Return the configured playback command, or the first available auto player.""" @@ -50,6 +67,84 @@ def choose_player_command(player: str, *, path_env: str | None = None) -> tuple[ return None +def build_playback_command( + wav_path: Path, + *, + player: str = "auto", + path_env: str | None = None, +) -> tuple[str, ...] | None: + """Return the full audio playback command for ``wav_path``.""" + + command_prefix = choose_player_command(player, path_env=path_env) + if command_prefix is None: + return None + return (*command_prefix, str(wav_path)) + + +def launch_process_group(command: Sequence[str]) -> subprocess.Popen[bytes]: + """Launch ``command`` in a new process group with child output suppressed.""" + + return subprocess.Popen( # noqa: S603 + tuple(command), + stdin=subprocess.DEVNULL, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + start_new_session=True, + ) + + +def launch_audio_player_process( + wav_path: Path, + *, + player: str = "auto", + path_env: str | None = None, +) -> PlaybackProcessResult: + """Start host playback for a WAV file in its own process group.""" + + command = build_playback_command(wav_path, player=player, path_env=path_env) + if command is None: + return PlaybackProcessResult(ok=False, error=f"No playback command found for player={player!r}") + + try: + process = launch_process_group(command) + except OSError as exc: + return PlaybackProcessResult(ok=False, command=command, error=str(exc)) + + return PlaybackProcessResult(ok=True, command=command, process=process) + + +def terminate_process_group(process: subprocess.Popen[bytes], *, timeout_seconds: float = 0.5) -> bool: + """Terminate a playback process group, escalating to SIGKILL after timeout. + + Returns ``True`` when the process had already exited or stopped after + SIGTERM, and ``False`` when SIGKILL escalation was needed. + """ + + if process.poll() is not None: + return True + + try: + process_group_id = os.getpgid(process.pid) + except ProcessLookupError: + return True + + try: + os.killpg(process_group_id, signal.SIGTERM) + except ProcessLookupError: + return True + + try: + process.wait(timeout=timeout_seconds) + return True + except subprocess.TimeoutExpired: + try: + os.killpg(process_group_id, signal.SIGKILL) + except ProcessLookupError: + return False + process.wait() + return False + + def play_audio_file( wav_path: Path, *, @@ -63,18 +158,21 @@ def play_audio_file( reserved for Codex hook JSON. """ - command_prefix = choose_player_command(player, path_env=path_env) - if command_prefix is None: + command = build_playback_command(wav_path, player=player, path_env=path_env) + if command is None: return PlaybackResult(ok=False, error=f"No playback command found for player={player!r}") - command = (*command_prefix, str(wav_path)) try: - process = subprocess.Popen( # noqa: S603 - command, - stdin=subprocess.DEVNULL, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - start_new_session=not blocking, + process = ( + subprocess.Popen( # noqa: S603 + command, + stdin=subprocess.DEVNULL, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + start_new_session=False, + ) + if blocking + else launch_process_group(command) ) except OSError as exc: return PlaybackResult(ok=False, command=command, error=str(exc)) @@ -91,4 +189,3 @@ def command_display(command: Sequence[str]) -> str: """Return a concise command string for logs.""" return " ".join(command) - diff --git a/src/tts_hook/stop.py b/src/tts_hook/stop.py index f957ddc..38510da 100644 --- a/src/tts_hook/stop.py +++ b/src/tts_hook/stop.py @@ -5,13 +5,14 @@ from pathlib import Path from tempfile import NamedTemporaryFile from typing import Any, TextIO +import subprocess import sys from .config import ConfigError, TtsHookConfig, load_config from .hook_io import continue_result, read_hook_json, write_hook_json from .kokoro import synthesize_speech from .logging import HookLogger -from .playback import command_display, play_audio_file +from .playback import PlaybackResult, command_display def main( @@ -65,19 +66,44 @@ def speak_last_assistant_message( logger.warning(f"Could not write Kokoro WAV response; skipping playback. {exc}", stderr=True) return None - playback = play_audio_file( - wav_path, - player=config.playback.player, - blocking=config.playback.blocking, - ) + playback = spawn_playback_supervisor(wav_path, config) if not playback.ok: - logger.warning(f"Playback did not start; continuing. {playback.error or 'Unknown playback error.'}", stderr=True) + logger.warning(f"Playback supervisor did not start; continuing. {playback.error or 'Unknown playback error.'}", stderr=True) return wav_path - logger.info(f"Started playback with {command_display(playback.command)}") + logger.info(f"Started playback supervisor with {command_display(playback.command)}") return wav_path +def spawn_playback_supervisor(wav_path: Path, config: TtsHookConfig) -> PlaybackResult: + """Start the playback supervisor without waiting for playback completion.""" + + command = [ + sys.executable, + "-m", + "tts_hook.tts_playback_supervisor", + str(wav_path), + "--player", + config.playback.player, + ] + if config.playback.blocking: + command.append("--blocking") + + try: + process = subprocess.Popen( # noqa: S603 + tuple(command), + stdin=subprocess.DEVNULL, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + start_new_session=True, + cwd=config.plugin_root, + ) + except OSError as exc: + return PlaybackResult(ok=False, command=tuple(command), error=str(exc)) + + return PlaybackResult(ok=True, command=tuple(command), pid=process.pid) + + def extract_assistant_message(payload: dict[str, Any]) -> str: """Return the full final assistant message with only outer whitespace removed.""" diff --git a/src/tts_hook/tts_playback_supervisor.py b/src/tts_hook/tts_playback_supervisor.py new file mode 100644 index 0000000..f7bf6b7 --- /dev/null +++ b/src/tts_hook/tts_playback_supervisor.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 +"""Standalone playback supervisor entrypoint for generated TTS WAV files.""" + +from __future__ import annotations + +from argparse import ArgumentParser, Namespace +from pathlib import Path +from types import FrameType +from typing import BinaryIO, TextIO +import select +import signal +import sys +import termios +import tty + +from .playback import launch_audio_player_process, terminate_process_group + +ESCAPE = b"\x1b" +POLL_INTERVAL_SECONDS = 0.05 + +_active_supervisor: "PlaybackSupervisor | None" = None + + +def build_parser() -> ArgumentParser: + """Create the command-line parser for the playback supervisor.""" + + parser = ArgumentParser(description="Play a generated TTS WAV file under a playback supervisor.") + parser.add_argument("wav_path", type=Path, help="Path to the generated WAV file to play.") + parser.add_argument( + "--player", + default="auto", + help="Playback command to use, or 'auto' to select the first available supported player.", + ) + parser.add_argument( + "--blocking", + action="store_true", + help="Accepted for playback config compatibility; the supervisor still owns playback waiting.", + ) + return parser + + +def main(argv: list[str] | None = None) -> int: + """Validate CLI arguments and run supervised playback.""" + + args = build_parser().parse_args(argv) + return run(args) + + +def cli() -> int: + """Console-script entrypoint.""" + + return main() + + +def run(args: Namespace) -> int: + """Run supervised playback for parsed supervisor arguments.""" + + return supervise_playback(args.wav_path, player=args.player) + + +class TtyEscapeReader: + """Read Escape key presses from a controlling terminal in cbreak mode.""" + + def __init__( + self, + *, + tty_path: str = "/dev/tty", + opener=open, + select_fn=select.select, + termios_module=termios, + tty_module=tty, + ) -> None: + self._tty_path = tty_path + self._opener = opener + self._select = select_fn + self._termios = termios_module + self._tty = tty_module + self._stream: BinaryIO | None = None + self._original_attrs: object | None = None + + def __enter__(self) -> "TtyEscapeReader": + self._stream = self._opener(self._tty_path, "rb", buffering=0) + fd = self._stream.fileno() + self._original_attrs = self._termios.tcgetattr(fd) + self._tty.setcbreak(fd) + return self + + def __exit__(self, exc_type: object, exc: object, traceback: object) -> None: + self.close() + + def close(self) -> None: + """Restore terminal settings and close the terminal stream.""" + + stream = self._stream + if stream is None: + return + + try: + if self._original_attrs is not None: + self._termios.tcsetattr(stream.fileno(), self._termios.TCSADRAIN, self._original_attrs) + finally: + self._stream = None + stream.close() + + def escape_pressed(self, timeout_seconds: float) -> bool: + """Return ``True`` if Escape is available on the controlling terminal.""" + + if self._stream is None: + return False + + ready, _writable, _errors = self._select([self._stream], [], [], timeout_seconds) + if not ready: + return False + return self._stream.read(1) == ESCAPE + + +class PlaybackSupervisor: + """Own one playback process, optional Escape reader, and temp WAV cleanup.""" + + def __init__(self, wav_path: Path, *, player: str = "auto", stderr: TextIO | None = None) -> None: + self.wav_path = wav_path + self.player = player + self.stderr = stderr or sys.stderr + self.process = None + self._cancel_requested = False + + def run(self) -> int: + """Run supervised playback until completion or Escape cancellation.""" + + if not self.wav_path.is_file(): + _write_stderr(f"WAV file does not exist: {self.wav_path}", stderr=self.stderr) + return 2 + + global _active_supervisor + _active_supervisor = self + try: + return self._run_started_playback() + finally: + _active_supervisor = None + self._cleanup_wav() + + def request_cancel(self) -> None: + """Request cancellation of the active playback process.""" + + self._cancel_requested = True + if self.process is not None: + terminate_process_group(self.process) + + def _run_started_playback(self) -> int: + playback = launch_audio_player_process(self.wav_path, player=self.player) + if not playback.ok: + _write_stderr(playback.error or "Playback did not start.", stderr=self.stderr) + return 1 + if playback.process is None: + _write_stderr("Playback did not return a process handle.", stderr=self.stderr) + return 1 + + self.process = playback.process + reader = open_tty_escape_reader(stderr=self.stderr) + try: + self._wait_for_playback(reader) + finally: + if reader is not None: + reader.close() + return 0 + + def _wait_for_playback(self, reader: TtyEscapeReader | None) -> None: + if self.process is None: + return + + while self.process.poll() is None: + if self._cancel_requested: + return + if reader is None: + self.process.wait() + return + try: + if reader.escape_pressed(POLL_INTERVAL_SECONDS): + self.request_cancel() + return + except (OSError, select.error) as exc: + _write_stderr(f"Disabling Escape cancel support: {exc}", stderr=self.stderr) + reader.close() + reader = None + + def _cleanup_wav(self) -> None: + try: + self.wav_path.unlink(missing_ok=True) + except OSError as exc: + _write_stderr(f"Could not delete temporary WAV file {self.wav_path}: {exc}", stderr=self.stderr) + + +def supervise_playback(wav_path: Path, *, player: str = "auto", stderr: TextIO | None = None) -> int: + """Run a playback supervisor for one generated WAV file.""" + + install_signal_handlers() + return PlaybackSupervisor(wav_path, player=player, stderr=stderr).run() + + +def open_tty_escape_reader(*, stderr: TextIO | None = None) -> TtyEscapeReader | None: + """Best-effort open of `/dev/tty` for Escape cancellation.""" + + reader = TtyEscapeReader() + try: + return reader.__enter__() + except (OSError, termios.error) as exc: + _write_stderr(f"Escape cancel unavailable: {exc}", stderr=stderr) + reader.close() + return None + + +def install_signal_handlers() -> None: + """Install best-effort cleanup handlers for supervisor termination.""" + + for signum in (signal.SIGTERM, signal.SIGINT): + signal.signal(signum, _handle_termination_signal) + + +def _handle_termination_signal(signum: int, _frame: FrameType | None) -> None: + if _active_supervisor is not None: + _active_supervisor.request_cancel() + _active_supervisor._cleanup_wav() + raise SystemExit(128 + signum) + + +def _write_stderr(message: str, *, stderr: TextIO | None = None) -> None: + stream = stderr or sys.stderr + try: + stream.write(message + "\n") + stream.flush() + except OSError: + return + + +if __name__ == "__main__": + raise SystemExit(main())