From 506b64ec95ef37bde6a7ef71ca8ca11fcf0e58b8 Mon Sep 17 00:00:00 2001 From: xingjianll <4396kevinliu@gmail.com> Date: Sun, 5 Apr 2026 06:53:43 -0400 Subject: [PATCH] feat: TalkTool speaking guard + FishTTS done_speaking signal - FishTTS: restructured to one websocket call per utterance, sends done_speaking signal after real audio duration has elapsed (not generation time). Uses PCM sample count to calculate exact playback duration. - TalkTool: tracks speaking state, rejects new talk calls with error while still speaking. Clears on done_speaking signal from TTS. - Both done_speaking sender/receiver are optional for backward compat. Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/src/lib/audio/tts_fish.py | 88 ++++++++++++++++++++++--------- backend/src/lib/llm/talk_tool.py | 23 ++++++++ 2 files changed, 87 insertions(+), 24 deletions(-) diff --git a/backend/src/lib/audio/tts_fish.py b/backend/src/lib/audio/tts_fish.py index 03cc90b..bf5aaea 100644 --- a/backend/src/lib/audio/tts_fish.py +++ b/backend/src/lib/audio/tts_fish.py @@ -27,6 +27,7 @@ class FishTTSInputs(NamedTuple): class FishTTSOutputs(NamedTuple): audio: Sender[AudioFrame] + done_speaking: Sender[TextFrame] | None = None class FishTTS(ThreadedComponent[FishTTSInputs, FishTTSOutputs]): @@ -58,29 +59,68 @@ def handle_interrupts() -> None: threading.Thread(target=handle_interrupts, daemon=True).start() - def text_stream(): - for frame in inputs.text: - if frame is None or interrupted.is_set() or self.stop_event.is_set(): - break - yield frame.text - - remainder = b"" - for chunk in self._client.tts.stream_websocket( - text_stream(), - reference_id=self.config.reference_id, - format="pcm", - latency="balanced", - model=self.config.model, - ): - if interrupted.is_set() or self.stop_event.is_set(): + # Outer loop: one websocket call per utterance + for frame in inputs.text: + if frame is None or self.stop_event.is_set(): break - if chunk: - data = remainder + chunk - usable = len(data) - (len(data) % 2) - if usable > 0: - outputs.audio.send( - AudioFrame.new( - data=data[:usable], sample_rate=44100, channels=1 + + # Collect tokens until EOS + tokens: list[str] = [frame.text] + inputs.text.blocking = False + for extra in inputs.text: + if extra is None: + break + if isinstance(extra, EOS): + break + tokens.append(extra.text) + inputs.text.blocking = True + + if interrupted.is_set(): + interrupted.clear() + continue + + def token_iter(): + for t in tokens: + if interrupted.is_set() or self.stop_event.is_set(): + break + yield t + + import time as _time + + remainder = b"" + first_chunk_time: float | None = None + total_samples = 0 + sample_rate = 44100 + + for chunk in self._client.tts.stream_websocket( + token_iter(), + reference_id=self.config.reference_id, + format="pcm", + latency="balanced", + model=self.config.model, + ): + if interrupted.is_set() or self.stop_event.is_set(): + break + if chunk: + data = remainder + chunk + usable = len(data) - (len(data) % 2) + if usable > 0: + if first_chunk_time is None: + first_chunk_time = _time.monotonic() + total_samples += usable // 2 # 16-bit PCM = 2 bytes per sample + outputs.audio.send( + AudioFrame.new( + data=data[:usable], sample_rate=sample_rate, channels=1 + ) ) - ) - remainder = data[usable:] + remainder = data[usable:] + + # Wait until audio finishes playing, then signal done + if outputs.done_speaking is not None and first_chunk_time is not None: + audio_duration = total_samples / sample_rate + elapsed = _time.monotonic() - first_chunk_time + remaining = audio_duration - elapsed + if remaining > 0: + self.stop_event.wait(remaining) + if not self.stop_event.is_set(): + outputs.done_speaking.send(TextFrame.new(text="done")) diff --git a/backend/src/lib/llm/talk_tool.py b/backend/src/lib/llm/talk_tool.py index dadade4..3782c2b 100644 --- a/backend/src/lib/llm/talk_tool.py +++ b/backend/src/lib/llm/talk_tool.py @@ -18,6 +18,7 @@ class TalkToolConfig(BaseModel): class TalkToolInputs(NamedTuple): tool_call: Receiver[ToolCall] + done_speaking: Receiver[TextFrame] | None = None class TalkToolOutputs(NamedTuple): @@ -55,11 +56,32 @@ def setup(self, outputs: TalkToolOutputs) -> None: ) def run(self, inputs: TalkToolInputs, outputs: TalkToolOutputs) -> None: + speaking = False + + if inputs.done_speaking is not None: + inputs.done_speaking.blocking = False + for call in inputs.tool_call: if call is None: break if call.name != "talk": continue + + # Drain done_speaking signal + if inputs.done_speaking is not None: + done = next(inputs.done_speaking, None) + if done is not None: + speaking = False + + if speaking: + outputs.tool_result.send( + ToolResult.new( + call_id=call.call_id, + content="Error: still speaking. Wait until done.", + ) + ) + continue + try: text = json.loads(call.arguments).get("text", "") except (json.JSONDecodeError, AttributeError): @@ -67,4 +89,5 @@ def run(self, inputs: TalkToolInputs, outputs: TalkToolOutputs) -> None: if text: print(f"[Talk] {text}") outputs.text.send(TextFrame.new(text=text)) + speaking = True outputs.tool_result.send(ToolResult.new(call_id=call.call_id, content=""))