Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 64 additions & 24 deletions backend/src/lib/audio/tts_fish.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class FishTTSInputs(NamedTuple):

class FishTTSOutputs(NamedTuple):
audio: Sender[AudioFrame]
done_speaking: Sender[TextFrame] | None = None


class FishTTS(ThreadedComponent[FishTTSInputs, FishTTSOutputs]):
Expand Down Expand Up @@ -58,29 +59,68 @@ def handle_interrupts() -> None:

threading.Thread(target=handle_interrupts, daemon=True).start()

def text_stream():
for frame in inputs.text:
if frame is None or interrupted.is_set() or self.stop_event.is_set():
break
yield frame.text

remainder = b""
for chunk in self._client.tts.stream_websocket(
text_stream(),
reference_id=self.config.reference_id,
format="pcm",
latency="balanced",
model=self.config.model,
):
if interrupted.is_set() or self.stop_event.is_set():
# Outer loop: one websocket call per utterance
for frame in inputs.text:
if frame is None or self.stop_event.is_set():
break
if chunk:
data = remainder + chunk
usable = len(data) - (len(data) % 2)
if usable > 0:
outputs.audio.send(
AudioFrame.new(
data=data[:usable], sample_rate=44100, channels=1

# Collect tokens until EOS
tokens: list[str] = [frame.text]
inputs.text.blocking = False
for extra in inputs.text:
if extra is None:
break
if isinstance(extra, EOS):
break
tokens.append(extra.text)
inputs.text.blocking = True

if interrupted.is_set():
interrupted.clear()
continue

def token_iter():
for t in tokens:
if interrupted.is_set() or self.stop_event.is_set():
break
yield t

import time as _time

remainder = b""
first_chunk_time: float | None = None
total_samples = 0
sample_rate = 44100

for chunk in self._client.tts.stream_websocket(
token_iter(),
reference_id=self.config.reference_id,
format="pcm",
latency="balanced",
model=self.config.model,
):
if interrupted.is_set() or self.stop_event.is_set():
break
if chunk:
data = remainder + chunk
usable = len(data) - (len(data) % 2)
if usable > 0:
if first_chunk_time is None:
first_chunk_time = _time.monotonic()
total_samples += usable // 2 # 16-bit PCM = 2 bytes per sample
outputs.audio.send(
AudioFrame.new(
data=data[:usable], sample_rate=sample_rate, channels=1
)
)
)
remainder = data[usable:]
remainder = data[usable:]

# Wait until audio finishes playing, then signal done
if outputs.done_speaking is not None and first_chunk_time is not None:
audio_duration = total_samples / sample_rate
elapsed = _time.monotonic() - first_chunk_time
remaining = audio_duration - elapsed
if remaining > 0:
self.stop_event.wait(remaining)
if not self.stop_event.is_set():
outputs.done_speaking.send(TextFrame.new(text="done"))
23 changes: 23 additions & 0 deletions backend/src/lib/llm/talk_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class TalkToolConfig(BaseModel):

class TalkToolInputs(NamedTuple):
tool_call: Receiver[ToolCall]
done_speaking: Receiver[TextFrame] | None = None


class TalkToolOutputs(NamedTuple):
Expand Down Expand Up @@ -55,16 +56,38 @@ def setup(self, outputs: TalkToolOutputs) -> None:
)

def run(self, inputs: TalkToolInputs, outputs: TalkToolOutputs) -> None:
speaking = False

if inputs.done_speaking is not None:
inputs.done_speaking.blocking = False

for call in inputs.tool_call:
if call is None:
break
if call.name != "talk":
continue

# Drain done_speaking signal
if inputs.done_speaking is not None:
done = next(inputs.done_speaking, None)
if done is not None:
speaking = False

if speaking:
outputs.tool_result.send(
ToolResult.new(
call_id=call.call_id,
content="Error: still speaking. Wait until done.",
)
)
continue

try:
text = json.loads(call.arguments).get("text", "")
except (json.JSONDecodeError, AttributeError):
text = call.arguments
if text:
print(f"[Talk] {text}")
outputs.text.send(TextFrame.new(text=text))
speaking = True
outputs.tool_result.send(ToolResult.new(call_id=call.call_id, content=""))
Loading