From 2610936598de5db62a670bbd62a36c698ef837ac Mon Sep 17 00:00:00 2001 From: 0xrushi <0xrushi@gmail.com> Date: Tue, 24 Feb 2026 08:51:26 -0500 Subject: [PATCH 1/2] strix halo optimizations --- Dockerfile.strixhalo | 65 ++++++++++++++++++++ docker-compose-strixhalo.yml | 61 +++++++++++++++++++ engine.py | 114 ++++++++++++++++++++++++++++------- requirements-rocm-init.txt | 4 ++ requirements-rocm.txt | 3 - 5 files changed, 222 insertions(+), 25 deletions(-) create mode 100644 Dockerfile.strixhalo create mode 100644 docker-compose-strixhalo.yml create mode 100644 requirements-rocm-init.txt diff --git a/Dockerfile.strixhalo b/Dockerfile.strixhalo new file mode 100644 index 0000000..c29f60c --- /dev/null +++ b/Dockerfile.strixhalo @@ -0,0 +1,65 @@ +FROM rocm/dev-ubuntu-22.04:latest + +# Set environment variables +ENV PYTHONDONTWRITEBYTECODE=1 +ENV PYTHONUNBUFFERED=1 +ENV DEBIAN_FRONTEND=noninteractive +# Set the Hugging Face home directory for better model caching +ENV HF_HOME=/app/hf_cache + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + libsndfile1 \ + ffmpeg \ + python3 \ + python3-pip \ + python3-dev \ + python3-venv \ + git \ + rocm-ml-libraries \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +# Create a symlink for python3 to be python for convenience +RUN ln -s /usr/bin/python3 /usr/bin/python + +# Set up working directory +WORKDIR /app + +# Copy requirements first to leverage Docker cache +COPY requirements-rocm-init.txt ./requirements-rocm-init.txt +COPY requirements-rocm.txt ./requirements-rocm.txt + +# Upgrade pip and install Python dependencies +RUN python3 -m pip install --no-cache-dir --upgrade pip && \ + pip3 install --no-cache-dir -r requirements-rocm.txt && \ + python3 -m pip install --no-cache-dir --force-reinstall -r requirements-rocm-init.txt + +# Patch s3tokenizer dtype issues for torch 2.9+ compatibility: +# 1. 
Cast numpy wav to float32 in _prepare_audio (numpy defaults to float64) +# This fixes the cascade: float64 wav → float64 STFT → float64 mel → assertion failure +# 2. Cast _mel_filters to match magnitudes dtype (guards against float64 checkpoint weights) +RUN sed -i \ + 's/wav = torch\.from_numpy(wav)$/wav = torch.from_numpy(wav).float()/' \ + /usr/local/lib/python3.10/dist-packages/chatterbox/models/s3tokenizer/s3tokenizer.py && \ + sed -i \ + 's/mel_spec = self\._mel_filters\.to(self\.device) @ magnitudes/mel_spec = self._mel_filters.to(self.device).to(magnitudes.dtype) @ magnitudes/' \ + /usr/local/lib/python3.10/dist-packages/chatterbox/models/s3tokenizer/s3tokenizer.py + +# Patch voice_encoder dtype issue: melspectrogram() returns numpy float64 → LSTM requires float32 +RUN sed -i \ + 's/utt_embeds = self\.inference(mels\.to(self\.device),/utt_embeds = self.inference(mels.to(self.device).float(),/' \ + /usr/local/lib/python3.10/dist-packages/chatterbox/models/voice_encoder/voice_encoder.py + +# Copy the rest of the application code +COPY . . + +# Create required directories for the application (fixed syntax error) +RUN mkdir -p model_cache reference_audio outputs voices logs hf_cache + +# Expose the port the application will run on +EXPOSE 8004 + +# Command to run the application +CMD ["python3", "server.py"] diff --git a/docker-compose-strixhalo.yml b/docker-compose-strixhalo.yml new file mode 100644 index 0000000..13518b1 --- /dev/null +++ b/docker-compose-strixhalo.yml @@ -0,0 +1,61 @@ + +services: + chatterbox-tts-server: + build: + context: . 
+ dockerfile: Dockerfile.strixhalo + ports: + - "${PORT:-8004}:8004" + volumes: + # Mount local config file for persistence + - ./config.yaml:/app/config.yaml + # Mount local directories for persistent app data + - ./voices:/app/voices + - ./reference_audio:/app/reference_audio + - ./outputs:/app/outputs + - ./logs:/app/logs + - hf_cache:/app/hf_cache + + # --- ROCm GPU Access --- + # Standard ROCm device access - required for AMD GPU acceleration + devices: + - /dev/kfd + - /dev/dri + group_add: + - video + - render + ipc: host + shm_size: 8g + security_opt: + - seccomp=unconfined + + # --- Optional: Enhanced ROCm Access --- + # Uncomment the lines below if you experience GPU access issues + # privileged: true + # cap_add: + # - SYS_PTRACE + # devices: + # - /dev/mem + + restart: unless-stopped + environment: + # Enable faster Hugging Face downloads + - HF_HUB_ENABLE_HF_TRANSFER=1 + - HF_TOKEN=YOUR_TOKEN_HERE + - TTS_ENGINE_DEVICE=cuda + # Enable bfloat16 for T3 (S3Gen stays float32) — halves memory bandwidth on token generation + # Set TTS_BF16=off if you experience precision issues + - TTS_BF16=on + - TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 + - MIOPEN_FIND_MODE=FAST + - HSA_OVERRIDE_GFX_VERSION=11.0.0 + # Critical for Strix Halo unified memory (gfx1151): enables page-fault retry + - HSA_XNACK=1 + # Reduces memory fragmentation on large unified memory pool + - PYTORCH_ALLOC_CONF=expandable_segments:True + # NOTE: HSA_OVERRIDE_GFX_VERSION should only be set by users with unsupported GPUs + # Example usage: HSA_OVERRIDE_GFX_VERSION=10.3.0 docker compose up + # Common values: 10.3.0 (RX 5000/6000), 11.0.0 (RX 7000), 9.0.6 (Vega) + +volumes: + hf_cache: diff --git a/engine.py b/engine.py index f7fdcd4..06098ee 100644 --- a/engine.py +++ b/engine.py @@ -3,6 +3,7 @@ import gc import logging +import os import random import numpy as np import torch @@ -38,7 +39,8 @@ logger = logging.getLogger(__name__) -# Log Turbo availability status at module load time +# Log BF16 setting at module 
load so it's visible in startup logs +# (BF16_ENABLED is resolved after logger is set up — logged in load_model()) if TURBO_AVAILABLE: logger.info("ChatterboxTurboTTS is available in the installed chatterbox package.") else: @@ -79,6 +81,25 @@ "shush", ] +# --- BF16 optimization flag --- +# TTS_BF16: controls whether T3 is converted to bfloat16 (S3Gen stays float32) and whether +# autocast is used during inference. +# auto (default) — enable only if torch.cuda.is_bf16_supported() +# on / 1 / true — force-enable (assumes hardware supports it) +# off / 0 / false — disable even if hardware supports it +def _resolve_bf16_setting() -> bool: + val = os.environ.get("TTS_BF16", "auto").strip().lower() + if val in ("off", "0", "false"): + return False + if val in ("on", "1", "true"): + return True + # auto: detect at runtime + if torch.cuda.is_available(): + return torch.cuda.is_bf16_supported() + return False + +BF16_ENABLED: bool = _resolve_bf16_setting() + # --- Global Module Variables --- chatterbox_model: Optional[ChatterboxTTS] = None MODEL_LOADED: bool = False @@ -90,6 +111,18 @@ loaded_model_type: Optional[str] = None # "original" or "turbo" loaded_model_class_name: Optional[str] = None # "ChatterboxTTS" or "ChatterboxTurboTTS" +# Voice conditioning cache: avoids re-encoding the same voice file on every request. +# Key: (resolved_path, file_mtime, exaggeration) — mtime invalidates if file changes. 
+_conds_cache: dict = {} + + +def _conds_cache_key(path: str, exaggeration: float) -> tuple: + try: + mtime = os.path.getmtime(path) + except OSError: + mtime = 0.0 + return (path, mtime, exaggeration) + def set_seed(seed_value: int): """ @@ -307,6 +340,10 @@ def load_model() -> bool: model_device = resolved_device_str logger.info(f"Final device selection: {model_device}") + logger.info( + f"BF16 optimization: {'enabled' if BF16_ENABLED else 'disabled'} " + f"(TTS_BF16={os.environ.get('TTS_BF16', 'auto')})" + ) # Get the model selector from config model_selector = config_manager.get_string("model.repo_id", "chatterbox-turbo") @@ -329,6 +366,17 @@ def load_model() -> bool: # Load the model using from_pretrained - handles HuggingFace downloads automatically chatterbox_model = model_class.from_pretrained(device=model_device) + # Convert T3 to bfloat16 if enabled. + # Token generation is memory-bandwidth bound; bf16 halves bytes read per + # forward pass. S3Gen is intentionally kept in float32 — it runs only + # 2 CFM timesteps and bf16 causes token/mask size mismatches. 
+ if BF16_ENABLED: + if hasattr(chatterbox_model, "t3"): + chatterbox_model.t3 = chatterbox_model.t3.bfloat16() + logger.info("T3 model converted to bfloat16 for faster token generation.") + else: + logger.info("BF16 optimization disabled (TTS_BF16=off or hardware unsupported).") + # Store model metadata loaded_model_type = model_type loaded_model_class_name = model_class.__name__ @@ -423,25 +471,45 @@ def synthesize( f"language={language}" ) - # Call the core model's generate method - # Multilingual model requires language_id parameter - if loaded_model_type == "multilingual": - wav_tensor = chatterbox_model.generate( - text=text, - language_id=language, - audio_prompt_path=audio_prompt_path, - temperature=temperature, - exaggeration=exaggeration, - cfg_weight=cfg_weight, - ) - else: - wav_tensor = chatterbox_model.generate( - text=text, - audio_prompt_path=audio_prompt_path, - temperature=temperature, - exaggeration=exaggeration, - cfg_weight=cfg_weight, - ) + # Voice conditioning cache: skip re-encoding the same voice file. + # Turbo ignores exaggeration in conds; others include it in the key. + effective_prompt = audio_prompt_path + conds_key = None + if audio_prompt_path and hasattr(chatterbox_model, "conds"): + ex_for_key = 0.0 if loaded_model_type == "turbo" else exaggeration + conds_key = _conds_cache_key(audio_prompt_path, ex_for_key) + if conds_key in _conds_cache: + chatterbox_model.conds = _conds_cache[conds_key] + effective_prompt = None # conds already set, skip prepare_conditionals + logger.debug(f"Voice cache hit: {audio_prompt_path}") + + # Call the core model's generate method. + # autocast promotes float32 inputs to bfloat16 to match the bfloat16 T3 weights (S3Gen stays float32), + # keeping numerically sensitive ops (softmax, norms) in float32 automatically. 
+ with torch.autocast("cuda", dtype=torch.bfloat16, enabled=BF16_ENABLED): + if loaded_model_type == "multilingual": + wav_tensor = chatterbox_model.generate( + text=text, + language_id=language, + audio_prompt_path=effective_prompt, + temperature=temperature, + exaggeration=exaggeration, + cfg_weight=cfg_weight, + ) + else: + wav_tensor = chatterbox_model.generate( + text=text, + audio_prompt_path=effective_prompt, + temperature=temperature, + exaggeration=exaggeration, + cfg_weight=cfg_weight, + ) + + # Store conds in cache after first compute for this voice. + if conds_key is not None and effective_prompt is not None: + if chatterbox_model.conds is not None: + _conds_cache[conds_key] = chatterbox_model.conds + logger.debug(f"Cached voice conditionals for: {audio_prompt_path}") # The ChatterboxTTS.generate method already returns a CPU tensor. return wav_tensor, chatterbox_model.sr @@ -460,7 +528,7 @@ def reload_model() -> bool: Returns: bool: True if the new model loaded successfully, False otherwise. """ - global chatterbox_model, MODEL_LOADED, model_device, loaded_model_type, loaded_model_class_name + global chatterbox_model, MODEL_LOADED, model_device, loaded_model_type, loaded_model_class_name, _conds_cache logger.info("Initiating model hot-swap/reload sequence...") @@ -470,10 +538,12 @@ def reload_model() -> bool: del chatterbox_model chatterbox_model = None - # 2. Reset state flags + # 2. Reset state flags and clear voice cache (conds are model-specific) MODEL_LOADED = False loaded_model_type = None loaded_model_class_name = None + _conds_cache.clear() + logger.info("Voice conditioning cache cleared.") # 3. 
Force Python Garbage Collection gc.collect() diff --git a/requirements-rocm-init.txt b/requirements-rocm-init.txt new file mode 100644 index 0000000..ed8ae9b --- /dev/null +++ b/requirements-rocm-init.txt @@ -0,0 +1,4 @@ +https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/triton-3.5.1%2Brocm7.2.0.gita272dfa8-cp310-cp310-linux_x86_64.whl +https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.9.1%2Brocm7.2.0.lw.git7e1940d4-cp310-cp310-linux_x86_64.whl +https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchvision-0.24.0%2Brocm7.2.0.gitb919bd0c-cp310-cp310-linux_x86_64.whl +https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchaudio-2.9.0%2Brocm7.2.0.gite3c6ee2b-cp310-cp310-linux_x86_64.whl diff --git a/requirements-rocm.txt b/requirements-rocm.txt index 6ecabc7..0aaa939 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -8,9 +8,6 @@ fastapi uvicorn[standard] # Machine Learning & Audio -torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/torch-2.6.0%2Brocm6.4.1.git1ded221d-cp310-cp310-linux_x86_64.whl -torchaudio @ https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/torchaudio-2.6.0%2Brocm6.4.1.gitd8831425-cp310-cp310-linux_x86_64.whl -pytorch-triton-rocm @ https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/pytorch_triton_rocm-3.2.0%2Brocm6.4.1.git6da9e660-cp310-cp310-linux_x86_64.whl numpy>=1.24.0,<1.26.0 # Match chatterbox-v2 requirements soundfile # Requires libsndfile system library (e.g., sudo apt-get install libsndfile1 on Debian/Ubuntu) huggingface_hub From 62ff3f16641d9855f2dd58ebaf9dfb65a4c423f7 Mon Sep 17 00:00:00 2001 From: 0xrushi <0xrushi@gmail.com> Date: Tue, 24 Feb 2026 09:08:51 -0500 Subject: [PATCH 2/2] new reqs --- Dockerfile.strixhalo | 8 ++--- requirements-rocm.txt | 3 ++ ...nit.txt => requirements-strixhalo-init.txt | 0 requirements-strixhalo.txt | 30 +++++++++++++++++++ 4 files changed, 37 insertions(+), 4 deletions(-) rename requirements-rocm-init.txt => requirements-strixhalo-init.txt (100%) create mode 
100644 requirements-strixhalo.txt diff --git a/Dockerfile.strixhalo b/Dockerfile.strixhalo index c29f60c..d5f9607 100644 --- a/Dockerfile.strixhalo +++ b/Dockerfile.strixhalo @@ -28,13 +28,13 @@ RUN ln -s /usr/bin/python3 /usr/bin/python WORKDIR /app # Copy requirements first to leverage Docker cache -COPY requirements-rocm-init.txt ./requirements-rocm-init.txt -COPY requirements-rocm.txt ./requirements-rocm.txt +COPY requirements-strixhalo-init.txt ./requirements-strixhalo-init.txt +COPY requirements-strixhalo.txt ./requirements-strixhalo.txt # Upgrade pip and install Python dependencies RUN python3 -m pip install --no-cache-dir --upgrade pip && \ - pip3 install --no-cache-dir -r requirements-rocm.txt && \ - python3 -m pip install --no-cache-dir --force-reinstall -r requirements-rocm-init.txt + pip3 install --no-cache-dir -r requirements-strixhalo.txt && \ + python3 -m pip install --no-cache-dir --force-reinstall -r requirements-strixhalo-init.txt # Patch s3tokenizer dtype issues for torch 2.9+ compatibility: # 1. 
Cast numpy wav to float32 in _prepare_audio (numpy defaults to float64) diff --git a/requirements-rocm.txt b/requirements-rocm.txt index 0aaa939..6ecabc7 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -8,6 +8,9 @@ fastapi uvicorn[standard] # Machine Learning & Audio +torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/torch-2.6.0%2Brocm6.4.1.git1ded221d-cp310-cp310-linux_x86_64.whl +torchaudio @ https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/torchaudio-2.6.0%2Brocm6.4.1.gitd8831425-cp310-cp310-linux_x86_64.whl +pytorch-triton-rocm @ https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/pytorch_triton_rocm-3.2.0%2Brocm6.4.1.git6da9e660-cp310-cp310-linux_x86_64.whl numpy>=1.24.0,<1.26.0 # Match chatterbox-v2 requirements soundfile # Requires libsndfile system library (e.g., sudo apt-get install libsndfile1 on Debian/Ubuntu) huggingface_hub diff --git a/requirements-rocm-init.txt b/requirements-strixhalo-init.txt similarity index 100% rename from requirements-rocm-init.txt rename to requirements-strixhalo-init.txt diff --git a/requirements-strixhalo.txt b/requirements-strixhalo.txt new file mode 100644 index 0000000..0aaa939 --- /dev/null +++ b/requirements-strixhalo.txt @@ -0,0 +1,30 @@ +# requirements.txt + +# Chatterbox TTS engine - Install from chatterbox-v2 fork +chatterbox-tts @ git+https://github.com/devnen/chatterbox-v2.git@master + +# Core Web Framework +fastapi +uvicorn[standard] + +# Machine Learning & Audio +numpy>=1.24.0,<1.26.0 # Match chatterbox-v2 requirements +soundfile # Requires libsndfile system library (e.g., sudo apt-get install libsndfile1 on Debian/Ubuntu) +huggingface_hub +descript-audio-codec +safetensors + +# Configuration & Utilities +pydantic +python-dotenv # Used ONLY for initial config seeding if config.yaml missing +Jinja2 +python-multipart # For file uploads +requests # For health checks or other potential uses +PyYAML # For parsing presets.yaml AND primary config.yaml +tqdm + +# Audio 
Post-processing +pydub +praat-parselmouth # For unvoiced segment removal +librosa # for changes to sampling +hf-transfer