Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions Dockerfile.strixhalo
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# TODO(review): pin to a specific ROCm release tag (and ideally a digest)
# instead of :latest so builds are reproducible.
FROM rocm/dev-ubuntu-22.04:latest

# Build-time only: suppress interactive debconf prompts. ARG (not ENV) keeps
# this out of the runtime environment of the final image.
ARG DEBIAN_FRONTEND=noninteractive

# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
# Set the Hugging Face home directory for better model caching
ENV HF_HOME=/app/hf_cache

# Install system dependencies (alphabetical, recommends skipped, apt lists
# removed in the same layer so they never bloat the image)
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    ffmpeg \
    git \
    libsndfile1 \
    python3 \
    python3-dev \
    python3-pip \
    python3-venv \
    rocm-ml-libraries \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Create a symlink for python3 to be python for convenience.
# -sf makes this idempotent: the build no longer fails if /usr/bin/python
# already exists in the base image.
RUN ln -sf /usr/bin/python3 /usr/bin/python

# Set up working directory
WORKDIR /app

# Copy requirements first to leverage Docker cache
COPY requirements-strixhalo-init.txt ./requirements-strixhalo-init.txt
COPY requirements-strixhalo.txt ./requirements-strixhalo.txt

# Upgrade pip and install Python dependencies
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
    pip3 install --no-cache-dir -r requirements-strixhalo.txt && \
    python3 -m pip install --no-cache-dir --force-reinstall -r requirements-strixhalo-init.txt

# Patch s3tokenizer dtype issues for torch 2.9+ compatibility:
# 1. Cast numpy wav to float32 in _prepare_audio (numpy defaults to float64)
#    This fixes the cascade: float64 wav → float64 STFT → float64 mel → assertion failure
# 2. Cast _mel_filters to match magnitudes dtype (guards against float64 checkpoint weights)
# Each sed is followed by a grep -q guard: sed exits 0 even when it substitutes
# nothing, so without the guard an upstream source change would silently
# disable the patch. With it, the build fails loudly instead.
RUN set -e; \
    S3TOK=/usr/local/lib/python3.10/dist-packages/chatterbox/models/s3tokenizer/s3tokenizer.py; \
    sed -i 's/wav = torch\.from_numpy(wav)$/wav = torch.from_numpy(wav).float()/' "$S3TOK"; \
    grep -q 'torch\.from_numpy(wav)\.float()' "$S3TOK"; \
    sed -i 's/mel_spec = self\._mel_filters\.to(self\.device) @ magnitudes/mel_spec = self._mel_filters.to(self.device).to(magnitudes.dtype) @ magnitudes/' "$S3TOK"; \
    grep -q 'to(magnitudes\.dtype) @ magnitudes' "$S3TOK"

# Patch voice_encoder dtype issue: melspectrogram() returns numpy float64 → LSTM requires float32
# (same sed-then-verify pattern as above)
RUN set -e; \
    VENC=/usr/local/lib/python3.10/dist-packages/chatterbox/models/voice_encoder/voice_encoder.py; \
    sed -i 's/utt_embeds = self\.inference(mels\.to(self\.device),/utt_embeds = self.inference(mels.to(self.device).float(),/' "$VENC"; \
    grep -q 'mels\.to(self\.device)\.float()' "$VENC"

# Copy the rest of the application code
COPY . .

# Create required directories for the application
RUN mkdir -p model_cache reference_audio outputs voices logs hf_cache

# NOTE(review): the container runs as root. Dropping to a non-root USER is
# recommended, but needs membership in the video/render groups for /dev/kfd
# and /dev/dri access — confirm against the compose group_add settings first.

# Expose the port the application will run on (documentation only; publish via compose)
EXPOSE 8004

# Command to run the application (exec form: server.py is PID 1 and receives SIGTERM)
CMD ["python3", "server.py"]
61 changes: 61 additions & 0 deletions docker-compose-strixhalo.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@

services:
  chatterbox-tts-server:
    build:
      context: .
      dockerfile: Dockerfile.strixhalo
    ports:
      - "${PORT:-8004}:8004"
    volumes:
      # Mount local config file for persistence
      - ./config.yaml:/app/config.yaml
      # Mount local directories for persistent app data
      - ./voices:/app/voices
      - ./reference_audio:/app/reference_audio
      - ./outputs:/app/outputs
      - ./logs:/app/logs
      - hf_cache:/app/hf_cache

    # --- ROCm GPU Access ---
    # Standard ROCm device access - required for AMD GPU acceleration
    devices:
      - /dev/kfd
      - /dev/dri
    group_add:
      - video
      - render
    ipc: host
    shm_size: 8g
    security_opt:
      - seccomp=unconfined

    # --- Optional: Enhanced ROCm Access ---
    # Uncomment the lines below if you experience GPU access issues
    # privileged: true
    # cap_add:
    #   - SYS_PTRACE
    # devices:
    #   - /dev/mem

    restart: unless-stopped
    environment:
      # Enable faster Hugging Face downloads
      - HF_HUB_ENABLE_HF_TRANSFER=1
      # Read the token from the host environment or a .env file at `up` time.
      # Never hardcode a real token here — this file is committed to git.
      - HF_TOKEN=${HF_TOKEN:-}
      # ROCm builds of PyTorch expose AMD GPUs under the "cuda" device name,
      # so "cuda" is correct here even though this is an AMD stack.
      - TTS_ENGINE_DEVICE=cuda
      # Enable bfloat16 for T3/S3Gen — halves memory bandwidth on token generation
      # Set TTS_BF16=off if you experience precision issues
      - TTS_BF16=on
      - TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1
      - MIOPEN_FIND_MODE=FAST
      # Maps Strix Halo (gfx1151) onto the prebuilt gfx1100 (RDNA3) kernel set
      # — presumably needed because this ROCm build lacks native gfx1151
      # support; remove once it ships gfx1151 kernels (TODO confirm).
      # Other common values: 10.3.0 (RX 5000/6000), 9.0.6 (Vega).
      - HSA_OVERRIDE_GFX_VERSION=11.0.0
      # Critical for Strix Halo unified memory (gfx1151): enables page-fault retry
      - HSA_XNACK=1
      # Reduces memory fragmentation on large unified memory pool
      - PYTORCH_ALLOC_CONF=expandable_segments:True

volumes:
  hf_cache:
114 changes: 92 additions & 22 deletions engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import gc
import logging
import os
import random
import numpy as np
import torch
Expand Down Expand Up @@ -38,7 +39,8 @@

logger = logging.getLogger(__name__)

# Log Turbo availability status at module load time
# Log BF16 setting at module load so it's visible in startup logs
# (BF16_ENABLED is resolved after logger is set up — logged in initialize_tts_model)
if TURBO_AVAILABLE:
logger.info("ChatterboxTurboTTS is available in the installed chatterbox package.")
else:
Expand Down Expand Up @@ -79,6 +81,25 @@
"shush",
]

# --- BF16 optimization flag ---
# TTS_BF16: controls whether T3/S3Gen are converted to bfloat16 and whether
# autocast is used during inference.
# auto (default) — enable only if torch.cuda.is_bf16_supported()
# on / 1 / true — force-enable (assumes hardware supports it)
# off / 0 / false — disable even if hardware supports it
def _resolve_bf16_setting() -> bool:
val = os.environ.get("TTS_BF16", "auto").strip().lower()
if val in ("off", "0", "false"):
return False
if val in ("on", "1", "true"):
return True
# auto: detect at runtime
if torch.cuda.is_available():
return torch.cuda.is_bf16_supported()
return False

BF16_ENABLED: bool = _resolve_bf16_setting()

# --- Global Module Variables ---
chatterbox_model: Optional[ChatterboxTTS] = None
MODEL_LOADED: bool = False
Expand All @@ -90,6 +111,18 @@
loaded_model_type: Optional[str] = None # "original" or "turbo"
loaded_model_class_name: Optional[str] = None # "ChatterboxTTS" or "ChatterboxTurboTTS"

# Voice conditioning cache: avoids re-encoding the same voice file on every request.
# Key: (resolved_path, file_mtime, exaggeration) — mtime invalidates if file changes.
_conds_cache: dict = {}


def _conds_cache_key(path: str, exaggeration: float) -> tuple:
try:
mtime = os.path.getmtime(path)
except OSError:
mtime = 0.0
return (path, mtime, exaggeration)


def set_seed(seed_value: int):
"""
Expand Down Expand Up @@ -307,6 +340,10 @@ def load_model() -> bool:

model_device = resolved_device_str
logger.info(f"Final device selection: {model_device}")
logger.info(
f"BF16 optimization: {'enabled' if BF16_ENABLED else 'disabled'} "
f"(TTS_BF16={os.environ.get('TTS_BF16', 'auto')})"
)

# Get the model selector from config
model_selector = config_manager.get_string("model.repo_id", "chatterbox-turbo")
Expand All @@ -329,6 +366,17 @@ def load_model() -> bool:
# Load the model using from_pretrained - handles HuggingFace downloads automatically
chatterbox_model = model_class.from_pretrained(device=model_device)

# Convert T3 to bfloat16 if enabled.
# Token generation is memory-bandwidth bound; bf16 halves bytes read per
# forward pass. S3Gen is intentionally kept in float32 — it runs only
# 2 CFM timesteps and bf16 causes token/mask size mismatches.
if BF16_ENABLED:
if hasattr(chatterbox_model, "t3"):
chatterbox_model.t3 = chatterbox_model.t3.bfloat16()
logger.info("T3 model converted to bfloat16 for faster token generation.")
else:
logger.info("BF16 optimization disabled (TTS_BF16=off or hardware unsupported).")

# Store model metadata
loaded_model_type = model_type
loaded_model_class_name = model_class.__name__
Expand Down Expand Up @@ -423,25 +471,45 @@ def synthesize(
f"language={language}"
)

# Call the core model's generate method
# Multilingual model requires language_id parameter
if loaded_model_type == "multilingual":
wav_tensor = chatterbox_model.generate(
text=text,
language_id=language,
audio_prompt_path=audio_prompt_path,
temperature=temperature,
exaggeration=exaggeration,
cfg_weight=cfg_weight,
)
else:
wav_tensor = chatterbox_model.generate(
text=text,
audio_prompt_path=audio_prompt_path,
temperature=temperature,
exaggeration=exaggeration,
cfg_weight=cfg_weight,
)
# Voice conditioning cache: skip re-encoding the same voice file.
# Turbo ignores exaggeration in conds; others include it in the key.
effective_prompt = audio_prompt_path
conds_key = None
if audio_prompt_path and hasattr(chatterbox_model, "conds"):
ex_for_key = 0.0 if loaded_model_type == "turbo" else exaggeration
conds_key = _conds_cache_key(audio_prompt_path, ex_for_key)
if conds_key in _conds_cache:
chatterbox_model.conds = _conds_cache[conds_key]
effective_prompt = None # conds already set, skip prepare_conditionals
logger.debug(f"Voice cache hit: {audio_prompt_path}")

# Call the core model's generate method.
# autocast promotes float32 inputs to bfloat16 to match T3/S3Gen weights,
# keeping numerically sensitive ops (softmax, norms) in float32 automatically.
with torch.autocast("cuda", dtype=torch.bfloat16, enabled=BF16_ENABLED):
if loaded_model_type == "multilingual":
wav_tensor = chatterbox_model.generate(
text=text,
language_id=language,
audio_prompt_path=effective_prompt,
temperature=temperature,
exaggeration=exaggeration,
cfg_weight=cfg_weight,
)
else:
wav_tensor = chatterbox_model.generate(
text=text,
audio_prompt_path=effective_prompt,
temperature=temperature,
exaggeration=exaggeration,
cfg_weight=cfg_weight,
)

# Store conds in cache after first compute for this voice.
if conds_key is not None and effective_prompt is not None:
if chatterbox_model.conds is not None:
_conds_cache[conds_key] = chatterbox_model.conds
logger.debug(f"Cached voice conditionals for: {audio_prompt_path}")

# The ChatterboxTTS.generate method already returns a CPU tensor.
return wav_tensor, chatterbox_model.sr
Expand All @@ -460,7 +528,7 @@ def reload_model() -> bool:
Returns:
bool: True if the new model loaded successfully, False otherwise.
"""
global chatterbox_model, MODEL_LOADED, model_device, loaded_model_type, loaded_model_class_name
global chatterbox_model, MODEL_LOADED, model_device, loaded_model_type, loaded_model_class_name, _conds_cache

logger.info("Initiating model hot-swap/reload sequence...")

Expand All @@ -470,10 +538,12 @@ def reload_model() -> bool:
del chatterbox_model
chatterbox_model = None

# 2. Reset state flags
# 2. Reset state flags and clear voice cache (conds are model-specific)
MODEL_LOADED = False
loaded_model_type = None
loaded_model_class_name = None
_conds_cache.clear()
logger.info("Voice conditioning cache cleared.")

# 3. Force Python Garbage Collection
gc.collect()
Expand Down
4 changes: 4 additions & 0 deletions requirements-strixhalo-init.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/triton-3.5.1%2Brocm7.2.0.gita272dfa8-cp310-cp310-linux_x86_64.whl
https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.9.1%2Brocm7.2.0.lw.git7e1940d4-cp310-cp310-linux_x86_64.whl
https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchvision-0.24.0%2Brocm7.2.0.gitb919bd0c-cp310-cp310-linux_x86_64.whl
https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchaudio-2.9.0%2Brocm7.2.0.gite3c6ee2b-cp310-cp310-linux_x86_64.whl
30 changes: 30 additions & 0 deletions requirements-strixhalo.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# requirements.txt

# Chatterbox TTS engine - Install from chatterbox-v2 fork
chatterbox-tts @ git+https://github.com/devnen/chatterbox-v2.git@master

# Core Web Framework
fastapi
uvicorn[standard]

# Machine Learning & Audio
numpy>=1.24.0,<1.26.0 # Match chatterbox-v2 requirements
soundfile # Requires libsndfile system library (e.g., sudo apt-get install libsndfile1 on Debian/Ubuntu)
huggingface_hub
descript-audio-codec
safetensors

# Configuration & Utilities
pydantic
python-dotenv # Used ONLY for initial config seeding if config.yaml missing
Jinja2
python-multipart # For file uploads
requests # For health checks or other potential uses
PyYAML # For parsing presets.yaml AND primary config.yaml
tqdm

# Audio Post-processing
pydub
praat-parselmouth # For unvoiced segment removal
librosa # for changes to sampling
hf-transfer