From 4f23eba74e6c6be8ac12e7c2d4c71610b458abce Mon Sep 17 00:00:00 2001 From: Benjamin Kobjolke Date: Wed, 18 Mar 2026 07:14:05 +0100 Subject: [PATCH 1/4] FEATURE: ONNX2 Support --- config.yaml | 119 +++++----------------------------------------------- engine.py | 55 +++++++++++++++++++++--- server.py | 1 + 3 files changed, 60 insertions(+), 115 deletions(-) diff --git a/config.yaml b/config.yaml index e90edb7..9f595a5 100644 --- a/config.yaml +++ b/config.yaml @@ -1,137 +1,38 @@ -# ----------------------------------------------------------------------------- -# Kitten TTS Server Configuration File (config.yaml) -# -# This file controls all the settings for the server. -# Changes to sections like 'server', 'tts_engine', or 'paths' -# typically require a server restart to take effect. -# ----------------------------------------------------------------------------- - -# --- Server Settings --- -# Controls the web server's network behavior, security, and logging. server: - # The IP address for the server to listen on. - # - "0.0.0.0": Makes the server accessible from other devices on your network. (Recommended for Docker) - # - "127.0.0.1" or "localhost": The server will only be accessible from your own computer. host: 0.0.0.0 - - # The network port the server will run on. - # If you get a "port already in use" error, change this to another number (e.g., 8006). port: 8005 - - # --- Logging Configuration --- - # Path to the server's log file, relative to the project root directory. + use_ngrok: false + use_auth: false + auth_username: user + auth_password: password log_file_path: logs\tts_server.log - - # The maximum size of a single log file in megabytes (MB) before it is rotated. - # This prevents log files from growing indefinitely. log_file_max_size_mb: 10 - - # The number of old log files to keep as backups. - # For example, if this is 5, you will have 'tts_server.log' and up to 5 older backups. 
log_file_backup_count: 5 - -# --- Model Settings --- -# Specifies the core AI model to be used by the TTS engine. model: - # The repository ID of the model on the Hugging Face Hub. - # You can change this to use a different compatible ONNX model in the future. - repo_id: KittenML/kitten-tts-nano-0.1 - -# --- TTS Engine Settings --- -# Configures the hardware and core settings for the speech synthesis engine. + repo_id: KittenML/kitten-tts-mini-0.8 tts_engine: - # Determines which hardware to use for inference. This is a critical performance setting. - # Valid options: "auto", "cuda", "gpu", "cpu" - # - "auto": (Recommended) Automatically uses an NVIDIA GPU if one is detected, otherwise falls back to the CPU. - # - "cuda" or "gpu": Explicitly forces the use of an NVIDIA GPU. The server will fail to start if one is not available. - # - "cpu": Forces the use of the CPU, even if a powerful GPU is present. - device: auto - -# --- File Path Settings --- -# Defines where the application should store various files. + device: cuda paths: - # The directory where downloaded model files will be cached. model_cache: model_cache - - # The default directory where generated audio files will be saved. output: outputs - -# --- Default Generation Parameters --- -# Default values for the speech generation process. These can be overridden in the UI or via API calls. generation_defaults: - # The default speed of the generated speech. - # 1.0 is normal speed. > 1.0 is faster, < 1.0 is slower. - speed: 1.1 - - # Default language for the phonemizer. The current model is trained for English. + speed: 1 language: en - - # This is a legacy/duplicate setting. Please use the 'speed' setting above. speed_factor: 1.1 - -# --- Audio Output Settings --- -# Controls the format and quality of the final audio file. audio_output: - # The default audio format for the output file. - # - "wav": Highest quality, uncompressed, large file size. Best for processing. 
- # - "mp3": Good quality, compressed, small file size. Best for sharing and listening. - # - "opus": Excellent quality, highly compressed, smallest file size. Best for streaming/web. - format: wav - - # The sample rate of the output audio in Hz. - # 24000 is the native sample rate of the KittenTTS model. - # Higher values (e.g., 48000) will resample the audio but won't add more detail. + format: ogg sample_rate: 24000 - -# --- UI State Persistence --- -# Saves the state of the web interface between sessions so you don't lose your work. ui_state: - # The last text that was entered into the main text box. - last_text: 'The solar system consists of the Sun and the astronomical objects gravitationally - bound in orbit around it. - - Mars, often called the Red Planet, is the fourth planet from the Sun. It is a - terrestrial planet with a thin atmosphere, having surface features reminiscent - both of the impact craters of the Moon and the volcanoes, valleys, deserts, and - polar ice caps of Earth. - - ' - # The ID of the last voice you selected from the dropdown. + last_text: YOu dont have any new unread emails. last_voice: expr-voice-2-m - # The last value you set for the "Chunk Size" slider, used for large text processing. last_chunk_size: 200 - - # Remembers whether the "Split text into chunks" checkbox was enabled. last_split_text_enabled: true - - # Set to 'true' to permanently hide the one-time warning about voice consistency when using chunking. hide_chunk_warning: false - - # Set to 'true' to permanently hide the one-time general notice about generation quality. - hide_generation_warning: false - - # The theme for the web interface. Options: "dark", "light". + hide_generation_warning: true theme: dark - -# --- General UI Settings --- -# Controls the appearance and static elements of the web interface. ui: - # The title that appears in the browser tab. title: Kitten TTS Server - - # Controls whether the language selection dropdown is visible in the UI. 
- # Set to 'false' to hide it if you only ever use one language. show_language_select: true - - # (For future use) If you had hundreds of voices, this would limit the number shown in the dropdown to prevent UI lag. max_predefined_voices_in_dropdown: 50 - -# --- Debugging Settings --- -# Tools for troubleshooting and development. debug: - # If 'true' and text chunking is enabled, the server will save each individual audio chunk - # as a separate file in the 'outputs' directory before they are stitched together. - # This is extremely useful for diagnosing issues with a specific part of a long text. save_intermediate_audio: false - - diff --git a/engine.py b/engine.py index 96f7d3b..e4ef6e2 100644 --- a/engine.py +++ b/engine.py @@ -26,8 +26,10 @@ phonemizer_backend: Optional[phonemizer.backend.EspeakBackend] = None text_cleaner: Optional["TextCleaner"] = None MODEL_LOADED: bool = False +voice_aliases: dict = {} +speed_priors: dict = {} -# KittenTTS available voices +# KittenTTS available voices (populated dynamically after model load) KITTEN_TTS_VOICES = [ "expr-voice-2-m", "expr-voice-2-f", @@ -81,7 +83,7 @@ def load_model() -> bool: Returns: bool: True if the model was loaded successfully, False otherwise. """ - global onnx_session, voices_data, phonemizer_backend, text_cleaner, MODEL_LOADED + global onnx_session, voices_data, phonemizer_backend, text_cleaner, MODEL_LOADED, voice_aliases, speed_priors, KITTEN_TTS_VOICES if MODEL_LOADED: logger.info("KittenTTS model is already loaded.") @@ -115,8 +117,10 @@ def load_model() -> bool: with open(config_path, "r") as f: model_config = json.load(f) - if model_config.get("type") != "ONNX1": - raise ValueError("Unsupported model type. Expected ONNX1.") + supported_types = {"ONNX1", "ONNX2"} + model_type = model_config.get("type") + if model_type not in supported_types: + raise ValueError(f"Unsupported model type: '{model_type}'. 
Expected one of: {supported_types}") # Download model and voices files model_path = hf_hub_download( @@ -135,6 +139,18 @@ def load_model() -> bool: voices_data = np.load(voices_path) logger.info(f"Loaded voices data with keys: {list(voices_data.keys())}") + # Parse ONNX2 config fields + voice_aliases = model_config.get("voice_aliases", {}) + speed_priors = model_config.get("speed_priors", {}) + if voice_aliases: + logger.info(f"Loaded voice aliases: {voice_aliases}") + if speed_priors: + logger.info(f"Loaded speed priors: {speed_priors}") + + # Build available voices list from loaded voice data + aliases + KITTEN_TTS_VOICES = list(voices_data.keys()) + list(voice_aliases.keys()) + logger.info(f"Available voices: {KITTEN_TTS_VOICES}") + # Determine device and providers and configure for optimal performance device_setting = config_manager.get_string("tts_engine.device", "auto").lower() available_providers = ort.get_available_providers() @@ -279,10 +295,24 @@ def load_model() -> bool: voices_data = None phonemizer_backend = None text_cleaner = None + voice_aliases = {} + speed_priors = {} MODEL_LOADED = False return False +def _get_voice_embedding(voice: str, text_length: int) -> np.ndarray: + """Get voice embedding, handling both ONNX1 (1D) and ONNX2 (2D) formats.""" + embedding = voices_data[voice] + if embedding.ndim == 1: + # ONNX1: single embedding, just add batch dim + return np.expand_dims(embedding, axis=0).astype(np.float32) + else: + # ONNX2: multiple reference embeddings, select one based on text length + ref_id = min(text_length, embedding.shape[0] - 1) + return embedding[ref_id:ref_id+1].astype(np.float32) + + def synthesize( text: str, voice: str, speed: float = 1.0 ) -> Tuple[Optional[np.ndarray], Optional[int]]: @@ -310,6 +340,19 @@ def synthesize( ) return None, None + # Resolve voice alias to internal voice ID + resolved_voice = voice_aliases.get(voice, voice) + if resolved_voice != voice: + logger.debug(f"Resolved voice alias '{voice}' -> 
'{resolved_voice}'") + + # Apply speed prior for this voice if available + prior = speed_priors.get(resolved_voice, 1.0) + if prior != 1.0: + speed = speed * prior + logger.debug(f"Applied speed prior {prior} for voice '{resolved_voice}', effective speed: {speed}") + + voice = resolved_voice + try: logger.debug(f"Synthesizing with voice='{voice}', speed={speed}") logger.debug(f"Input text (first 100 chars): '{text[:100]}...'") @@ -343,7 +386,7 @@ def synthesize( # --- I/O Binding Path for GPU using NumPy --- # Create standard NumPy arrays on the CPU first. input_ids_np = np.array([tokens], dtype=np.int64) - ref_s_np = voices_data[voice].astype(np.float32) # Ensure correct type + ref_s_np = _get_voice_embedding(voice, len(tokens)) speed_array_np = np.array([speed], dtype=np.float32) # Create OrtValues from the NumPy arrays. I/O binding will handle the copy to GPU. @@ -379,7 +422,7 @@ def synthesize( else: # --- Standard Path for CPU --- input_ids = np.array([tokens], dtype=np.int64) - ref_s = voices_data[voice] + ref_s = _get_voice_embedding(voice, len(tokens)) speed_array = np.array([speed], dtype=np.float32) onnx_inputs = { diff --git a/server.py b/server.py index bf6aa00..3c06b2f 100644 --- a/server.py +++ b/server.py @@ -292,6 +292,7 @@ async def get_ui_initial_data(): "config": full_config, "presets": loaded_presets, "initial_gen_result": initial_gen_result_placeholder, + "available_voices": engine.KITTEN_TTS_VOICES, } except Exception as e: logger.error(f"Error preparing initial UI data for API: {e}", exc_info=True) From f0d9e051ec39658eac4d51b8f6ec4d0e2020340a Mon Sep 17 00:00:00 2001 From: Benjamin Kobjolke Date: Wed, 18 Mar 2026 07:14:18 +0100 Subject: [PATCH 2/4] GIT: ignore --- .gitignore | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b3582f6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.claude +venv +__pycache__ +model_cache +logs From 
c6bf07a97dff78560e0d221405537cb2ea0538ca Mon Sep 17 00:00:00 2001 From: Benjamin Kobjolke Date: Wed, 18 Mar 2026 07:56:29 +0100 Subject: [PATCH 3/4] GIT (config): add .gitattributes for line ending normalization - normalize all text files to LF in the repository - mark common binary file types --- .gitattributes | 13 + README.md | 984 ++++++++--------- config.py | 1578 +++++++++++++-------------- config.yaml | 76 +- docker-compose-cpu.yml | 56 +- docker-compose.yml | 96 +- engine.py | 900 +++++++-------- models.py | 144 +-- requirements-nvidia.txt | 61 +- requirements.txt | 60 +- server.py | 1358 +++++++++++------------ ui/index.html | 680 ++++++------ ui/presets.yaml | 86 +- ui/script.js | 1424 ++++++++++++------------ utils.py | 2288 +++++++++++++++++++-------------------- 15 files changed, 4908 insertions(+), 4896 deletions(-) create mode 100644 .gitattributes diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..fdc2541 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,13 @@ +# Normalize all text files to LF in the repository +* text=auto eol=lf + +# Explicitly mark binary files +*.png binary +*.jpg binary +*.jpeg binary +*.gif binary +*.ico binary +*.pdf binary +*.zip binary +*.gz binary +*.tar binary diff --git a/README.md b/README.md index 731eff3..23693e7 100644 --- a/README.md +++ b/README.md @@ -1,492 +1,492 @@ -# Kitten TTS Server: High-Performance, Lightweight TTS with API and GPU Acceleration - -**Self-host the ultra-lightweight [KittenTTS model](https://github.com/KittenML/KittenTTS) with this enhanced API server. Features an intuitive Web UI, a flexible API, large text processing for audiobooks, and uniquely, high-performance GPU acceleration.** - -This server provides a robust, user-friendly, and powerful interface for the kitten-tts engine, an open-source, realistic text-to-speech model with just 15 million parameters. 
This project significantly enhances the original model by adding a full-featured server, an easy-to-use UI, and an optimized inference pipeline for hardware ranging from NVIDIA GPUs to CPUs and even the Raspberry Pi 5 (RP5) and Raspberry Pi 4 (RP4). - -[![Project Link](https://img.shields.io/badge/GitHub-devnen/Kitten--TTS--Server-blue?style=for-the-badge&logo=github)](https://github.com/devnen/Kitten-TTS-Server) -[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg?style=for-the-badge)](LICENSE) -[![Python Version](https://img.shields.io/badge/Python-3.10+-blue.svg?style=for-the-badge)](https://www.python.org/downloads/) -[![Framework](https://img.shields.io/badge/Framework-FastAPI-green.svg?style=for-the-badge)](https://fastapi.tiangolo.com/) -[![Model Source](https://img.shields.io/badge/Model-KittenML/KittenTTS-orange.svg?style=for-the-badge)](https://github.com/KittenML/KittenTTS) -[![Docker](https://img.shields.io/badge/Docker-Supported-blue.svg?style=for-the-badge)](https://www.docker.com/) -[![Web UI](https://img.shields.io/badge/Web_UI-Included-4285F4?style=for-the-badge&logo=googlechrome&logoColor=white)](#) -[![CUDA Compatible](https://img.shields.io/badge/NVIDIA_CUDA-Compatible-76B900?style=for-the-badge&logo=nvidia&logoColor=white)](https://developer.nvidia.com/cuda-zone) -[![API](https://img.shields.io/badge/OpenAI_Compatible_API-Ready-000000?style=for-the-badge&logo=openai&logoColor=white)](https://platform.openai.com/docs/api-reference) - -
- Kitten TTS Server Web UI - Dark Mode - Kitten TTS Server Web UI - Light Mode -
- ---- - -## 🗣️ Overview: Enhanced KittenTTS Generation - -The [KittenTTS model by KittenML](https://github.com/KittenML/KittenTTS) provides a foundation for generating high-quality speech from a model smaller than 25MB. This project elevates that foundation into a production-ready service by providing a robust [FastAPI](https://fastapi.tiangolo.com/) server that makes KittenTTS significantly easier to use, more powerful, and drastically faster. - -We solve the complexity of setting up and running the model by offering: - -* A **modern Web UI** for easy experimentation, preset loading, and speed adjustment. -* **True GPU Acceleration** for NVIDIA GPUs, a feature not present in the original implementation. -* **Large Text Handling & Audiobook Generation:** Intelligently splits long texts into manageable chunks, processes them sequentially, and seamlessly concatenates the audio. Perfect for creating complete audiobooks. -* **A flexible, dual-API system** including a simple endpoint and an OpenAI-compatible endpoint for easy integration. -* **Built-in Voices:** A fixed list of 8 ready-to-use voices for consistent and reliable output. -* **Cross-platform support** for Windows and Linux, with clear setup instructions. -* **Docker support** for easy, reproducible containerized deployment. - -## 🍓 Raspberry Pi & Edge Device Support - -The ultra-lightweight nature of the KittenTTS model and the efficiency of this server make it a perfect candidate for running on single-board computers (SBCs) and other edge devices. - -* ✅ **Raspberry Pi 5 (RP5):** Confirmed to run with **excellent performance**. The server is fast and responsive, easily handling requests from other devices on the same local network (LAN). This makes it ideal for local network services, home automation, and other DIY projects. - -* ⏳ **Raspberry Pi 4 (RP4):** Testing is currently in progress. Not working on the 32-bit Raspberry Pi OS. 
- -To install, simply follow the standard **Linux installation guide** provided in this README. - -## 🔥 GPU Acceleration included - -A standout feature of this server is the implementation of **high-performance GPU acceleration**, a capability not available in the original KittenTTS project. While the base model is CPU-only, this server unlocks the full potential of your hardware. - -* **Optimized ONNX Runtime Pipeline:** We leverage `onnxruntime-gpu` to move the entire inference process to your NVIDIA graphics card. -* **Eliminated I/O Bottlenecks:** The server uses advanced **I/O Binding**. This technique pre-allocates memory directly on the GPU for both model inputs and outputs, drastically reducing the latency caused by copying data between system RAM and the GPU's VRAM. -* **True Performance Gains:** This isn't just running the model on the GPU; it's an optimized pipeline designed to minimize latency and maximize throughput, making real-time generation significantly faster than on CPU. - -This enhancement transforms KittenTTS from a lightweight-but-modest engine into a high-speed synthesis powerhouse. - -## 🔄 Alternative to Piper TTS - -The [KittenTTS model](https://github.com/KittenML/KittenTTS) serves as an excellent alternative to [Piper TTS](https://github.com/rhasspy/piper) for fast generation on limited compute and edge devices like Raspberry Pi 5. 
- -**KittenTTS Model Advantages:** -- **Extreme Efficiency**: Just 15 million parameters and under 25MB, significantly smaller than most Piper models -- **Universal Compatibility**: CPU-optimized to run without GPU on any device and "works literally everywhere" -- **Real-time Performance**: Optimized for real-time speech synthesis even on resource-constrained hardware - -**This Server Project's Enhancement:** -While KittenTTS provides the ultra-lightweight foundation, this server transforms it into a production-ready Piper replacement by adding GPU acceleration (unavailable in the base model), modern REST/OpenAI APIs, audiobook processing capabilities, and an intuitive web interface—all while maintaining the model's edge device compatibility. - -Perfect for users seeking Piper's offline capabilities with better performance on limited hardware and modern server infrastructure. - -## ✨ Key Features of This Server - -* **🚀 Ultra-Lightweight Model:** Powered by the `KittenTTS` ONNX model, which is under 25MB. -* ⚡ **True GPU Acceleration:** Full support for **NVIDIA (CUDA)** via an optimized `onnxruntime-gpu` pipeline with I/O Binding for maximum performance. -* **📚 Large Text & Audiobook Generation:** - * Automatically handles long texts by intelligently splitting them based on sentence boundaries. - * Processes each chunk individually and seamlessly concatenates the resulting audio. - * **Ideal for audiobooks** - paste entire books and get professional-quality audio. -* **🖥️ Modern Web Interface:** - * Intuitive UI for text input, voice selection, and parameter adjustment. - * Real-time waveform visualization of generated audio. -* **🎤 8 Built-in Voices:** - * Utilizes the 8 built-in voices from the KittenTTS model (4 male, 4 female). - * Easily selectable via a UI dropdown menu. -* **⚙️ Dual API Endpoints:** - * A primary `/tts` endpoint offering full control over all generation parameters. 
- * An OpenAI-compatible `/v1/audio/speech` endpoint for seamless integration into existing workflows. -* **🔧 Easy Configuration:** - * All settings are managed through a single `config.yaml` file. - * The server automatically creates a default config on the first run. -* **💾 UI State Persistence:** The web interface remembers your last-used text, voice, and settings to streamline your workflow. -* **🐳 Docker Support:** Easy, reproducible deployment for both CPU and GPU via Docker Compose. - ---- - -## 🔩 System Prerequisites - -* **Operating System:** Windows 10/11 (64-bit) or Linux (Debian/Ubuntu recommended). -* **Python:** Version 3.10 or later. -* **Git:** For cloning the repository. -* **eSpeak NG:** This is a **required** dependency for text phonemization. - * **Windows:** See installation guide below. - * **Linux:** `sudo apt install espeak-ng` -* **Raspberry Pi:** - * Raspberry Pi 5 - * Raspberry Pi 4 -* **(For GPU Acceleration):** - * An **NVIDIA GPU** with CUDA support. -* **(For Linux Only):** - * `libsndfile1`: Audio library needed by `soundfile`. Install via `sudo apt install libsndfile1`. - * `ffmpeg`: For robust audio operations. Install via `sudo apt install ffmpeg`. - -## 💻 Installation and Setup - -This project uses specific dependency files and a clear process to ensure a smooth, one-command installation for your hardware. - -**1. Clone the Repository** -```bash -git clone https://github.com/devnen/Kitten-TTS-Server.git -cd Kitten-TTS-Server -``` - -**2. Create and Activate a Python Virtual Environment** -This is crucial to avoid conflicts with other Python projects. - -* **Windows (PowerShell):** - ```powershell - python -m venv venv - .\venv\Scripts\activate - ``` - If you see an error about execution policies, run: - `Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser` and try activating again. 
- -* **Linux (Bash):** - ```bash - python3 -m venv venv - source venv/bin/activate - ``` - Your command prompt should now start with `(venv)`. - -**3. Install eSpeak NG (Required)** - -* **Windows:** - 1. Download the installer from the [eSpeak NG Releases page](https://github.com/espeak-ng/espeak-ng/releases/latest). Look for the file named `espeak-ng-X.XX-x64.msi`. - 2. Run the installer with default settings. - 3. **Important:** Restart your terminal (PowerShell/CMD) after installation for the changes to take effect. - -* **Linux (Ubuntu/Debian):** - ```bash - sudo apt update && sudo apt install -y espeak-ng - ``` - -**4. Install Python Dependencies** - -Choose one of the following paths based on your hardware. - ---- - -### **Option 1: CPU-Only Installation** -This is the simplest path and works on any machine. - -```bash -# Make sure your (venv) is active -pip install --upgrade pip -pip install -r requirements.txt -``` - ---- - -### **Option 2: NVIDIA GPU Installation (Recommended for Performance)** -This method ensures all necessary CUDA libraries are correctly installed within your virtual environment for a hassle-free setup. - -```bash -# Make sure your (venv) is active -pip install --upgrade pip - -# Step 1: Install the GPU-enabled ONNX Runtime -pip install onnxruntime-gpu - -# Step 2: Install PyTorch with CUDA support. This command also brings the -# necessary CUDA and cuDNN .dll files that onnxruntime-gpu needs. 
-pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu121 - -# Step 3: Install the remaining dependencies from the requirements file -pip install -r requirements-nvidia.txt -``` - -**After installation, verify that PyTorch can see your GPU:** -```bash -python -c "import torch; print(f'PyTorch version: {torch.__version__}'); print(f'CUDA available: {torch.cuda.is_available()}'); print(f'Device name: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else None}')" -``` -If `CUDA available:` shows `True`, your setup is correct! - ---- - -### **Option 3: Upgrading from CPU to GPU** - -If you initially installed the server for CPU-only usage and now want to enable GPU acceleration, follow these steps to upgrade your environment safely. - -```bash -# Make sure your (venv) is active -pip install --upgrade pip - -# Step 1: Uninstall the CPU-only versions of onnxruntime and torch. -# This is critical to prevent conflicts with the GPU packages. -pip uninstall onnxruntime torch torchaudio -y - -# Step 2: Install the GPU-enabled ONNX Runtime. -pip install onnxruntime-gpu - -# Step 3: Install PyTorch with CUDA support. This command also brings the -# necessary CUDA and cuDNN .dll files that onnxruntime-gpu needs. -pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu121 - -# Step 4: Re-install from the nvidia requirements file to ensure all other -# dependencies are correct and up to date. -pip install -r requirements-nvidia.txt -``` - -**After upgrading, do the following:** - -1. **Verify the installation** by running the same check from Option 2: - ```bash - python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')" - ``` - The output must be `CUDA available: True`. - -2. **Update your configuration** by editing the `config.yaml` file: - ```yaml - tts_engine: - device: auto # Or "cuda", or "gpu" - ``` - -3. **Restart the server** for the changes to take effect. It will now use your NVIDIA GPU. 
- ---- - -## ▶️ Running the Server - -**Important: First-Run Model Download** -The first time you start the server, it will automatically download the KittenTTS model (~25MB) from Hugging Face. This is a one-time process. Subsequent launches will be instant. - -1. **Activate the virtual environment** (if not already active). - * Windows: `.\venv\Scripts\activate` - * Linux: `source venv/bin/activate` - -2. **Run the server:** - ```bash - python server.py - ``` - -3. The server will start and automatically open the Web UI in your default browser. - * **Web UI:** `http://localhost:8005` - * **API Docs:** `http://localhost:8005/docs` - -4. **To stop the server:** Press `CTRL+C` in the terminal. - -### **Raspberry Pi 4 & 5 Installation (CPU-Only)** - -KittenTTS runs excellently on Raspberry Pi devices, making it ideal for local network services and DIY projects. However, installation requirements vary significantly between Pi models due to CPU architecture differences. - -#### **Raspberry Pi 5 - Full Support ✅** - -**Raspberry Pi 5 works out-of-the-box** with the standard Linux installation guide above. No special steps required! 
- -**Tested Configuration:** -- **Hardware:** Raspberry Pi 5 Model B Rev 1.0 -- **OS:** Debian GNU/Linux 12 (bookworm) 64-bit -- **Architecture:** aarch64 (ARM64) -- **Python:** 3.11 -- **Memory:** 4GB RAM -- **Installation:** Follow the standard [Linux Installation](#linux-installation) guide exactly - -**Installation Steps:** -```bash -# Step 1: Install system dependencies -sudo apt update && sudo apt upgrade -y -sudo apt install -y espeak-ng libsndfile1 ffmpeg python3-pip python3-venv git - -# Step 2: Set up Python environment -python -m venv venv -source venv/bin/activate - -# Step 3: Install Python dependencies -pip install -r requirements.txt - -# Step 4: Start the server -python server.py -``` - -> **⏱️ Important:** During the `pip install -r requirements.txt` step, some Python packages (especially audio processing libraries like `librosa`, `praat-parselmouth`, and others) may need to be compiled from source on ARM architecture. This process can take **15-30 minutes** depending on your SD card speed and system load. This is normal - let it complete without interruption. - -#### **Raspberry Pi 4 - Limited Support ⚠️** - -**Raspberry Pi 4 support is currently in development** due to complex dependency compilation issues on 32-bit ARM architecture. 
- -**Known Technical Challenges:** -- **ONNX Runtime:** No official ARM wheels available on PyPI -- **PyTorch Ecosystem:** Limited pre-built wheel availability for armv7l -- **NLP Dependencies:** SpaCy and related libraries fail to compile due to architecture detection issues -- **Audio Processing:** Some native audio libraries require manual compilation - -**Current Status:** -- ✅ **64-bit Raspberry Pi OS:** May work with standard installation (limited testing) -- ⚠️ **32-bit Raspberry Pi OS:** Requires complex manual dependency resolution -- 🔧 **Alternative Solutions:** Being developed for core functionality - -**For Raspberry Pi 4 Users:** -We recommend upgrading to **64-bit Raspberry Pi OS** if possible, as this significantly improves compatibility with modern Python packages. For users requiring 32-bit support, please check our [GitHub Issues](link-to-issues) for the latest progress updates and community-contributed solutions. - -**Alternative Recommendation:** -For the best Raspberry Pi TTS experience, we strongly recommend using a **Raspberry Pi 5** with the standard 64-bit OS, which provides excellent performance and full compatibility. - -## 🐳 Docker Installation - -Run Kitten-TTS-Server easily using Docker. The recommended method uses Docker Compose, which is pre-configured for both CPU and NVIDIA GPU deployment. - -### Prerequisites - -* [Docker](https://docs.docker.com/get-docker/) installed. -* [Docker Compose](https://docs.docker.com/compose/install/) installed (usually included with Docker Desktop). -* **(For GPU acceleration)** - * An NVIDIA GPU. - * Up-to-date NVIDIA drivers for your host operating system. - * The [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) installed. - -### Using Docker Compose (Recommended) - -This method uses the provided `docker-compose.yml` files to automatically build the correct image and manage the container, volumes, and configuration. - -**1. 
Clone the Repository** -```bash -git clone https://github.com/devnen/Kitten-TTS-Server.git -cd Kitten-TTS-Server -``` - -**2. Start the Container Based on Your Hardware** - -Choose one of the following commands: - -#### **For NVIDIA GPU (Recommended for Performance):** -The default `docker-compose.yml` is configured for NVIDIA GPUs. It will build the image with full CUDA support. - -```bash -docker compose up -d --build -``` - -#### **For CPU-only:** -This uses a dedicated compose file that builds the image without GPU dependencies. - -```bash -docker compose -f docker-compose-cpu.yml up -d --build -``` - -⭐ **Note:** The first time you run this, Docker will build the image and the server will download the KittenTTS model, which can take a few minutes. Subsequent starts will be much faster. - -### 3. Access and Manage the Application - -* **Access the Web UI:** Open your browser to `http://localhost:8005` -* **Access the API Docs:** `http://localhost:8005/docs` - -* **View Logs:** - ```bash - # For GPU or CPU version - docker compose logs -f - ``` - -* **Stop the Container:** - ```bash - # This stops and removes the container but keeps your data volumes - docker compose down - ``` - -### How It Works - -* **Build-time Argument:** The `Dockerfile` uses a `RUNTIME` argument (`nvidia` or `cpu`) to conditionally install the correct Python packages, creating an optimized image for your hardware. -* **Persistent Data:** The `docker-compose` files use Docker volumes to persist your important data on your host machine, even if the container is removed: - * `./config.yaml`: Your main server configuration file. - * `./outputs`: All generated audio files are saved here. - * `./logs`: Server log files for troubleshooting. - * `hf_cache` (Named Volume): Persists the downloaded Hugging Face models, saving significant time on rebuilds. 
- -### Verify GPU Access (for NVIDIA users) - -After starting the GPU container, you can verify that Docker and the application can see your graphics card. - -```bash -# Check if the container can see the NVIDIA GPU -docker compose exec kitten-tts-server nvidia-smi - -# Check if PyTorch inside the container can access CUDA -docker compose exec kitten-tts-server python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')" -``` -If `CUDA available:` prints `True`, your GPU setup is working correctly - -## 💡 Usage Guide - -### Generate Your First Audio - -1. Start the server and open the Web UI (`http://localhost:8005`). -2. Type or paste your text into the input box. -3. Select a voice from the dropdown menu. -4. Adjust the speech speed if desired. -5. Click **"Generate Speech"**. -6. The audio will play automatically and be available for download. - -### Generate an Audiobook - -1. Copy the entire plain text of your book or chapter. -2. Paste it into the text area. -3. Ensure **"Split text into chunks"** is enabled. -4. Set a **Chunk Size** between 300 and 500 characters for natural pauses. -5. Click **"Generate Speech"**. The server will process the entire text and stitch the audio together seamlessly. -6. Download your complete audiobook file. - -## 📖 API Documentation - -The server exposes two main endpoints for TTS. See `http://localhost:8005/docs` for an interactive playground. - -### Primary Endpoint: `/tts` - -This endpoint offers the most control. - -* **Method:** `POST` -* **Body:** - ```json - { - "text": "Hello from the KittenTTS API!", - "voice": "expr-voice-5-m", - "speed": 1.0, - "output_format": "mp3", - "split_text": true, - "chunk_size": 300 - } - ``` -* **Response:** Streaming audio file (`audio/wav`, `audio/mp3`, etc.). - -### OpenAI-Compatible Endpoint: `/v1/audio/speech` - -Use this for drop-in compatibility with scripts expecting OpenAI's TTS API structure. 
- -* **Method:** `POST` -* **Body:** - ```json - { - "model": "kitten-tts", - "input": "This is an OpenAI-compatible request.", - "voice": "expr-voice-4-f", - "response_format": "wav", - "speed": 0.9 - } - ``` - -## ⚙️ Configuration - -All server settings are managed in the `config.yaml` file. It's created automatically on first launch if it doesn't exist. - -**Key Settings:** -* `server.host`, `server.port`: Network settings. -* `tts_engine.device`: Set to `auto`, `cuda`, or `cpu`. The server will use your GPU if set to `auto` or `cuda` and a compatible environment is found. -* `generation_defaults.speed`: Default speech speed (1.0 is normal). -* `audio_output.format`: Default audio format (`wav`, `mp3`, `opus`). - -## 🛠️ Troubleshooting - -* **Phonemizer / eSpeak Errors:** - * This is the most common issue. Ensure you have installed **eSpeak NG** correctly for your OS and **restarted your terminal** afterward. The server includes auto-detection logic for common install paths. -* **GPU Not Used / Falls Back to CPU:** - * Follow the **NVIDIA GPU Installation** steps exactly. The most common cause is `torch` being installed without CUDA support. - * Run the verification command from the installation guide to confirm `torch.cuda.is_available()` is `True`. -* **"No module named 'soundfile'" or Audio Errors on Linux:** - * The underlying system library is likely missing. Run `sudo apt install libsndfile1`. -* **"Port already in use" Error:** - * Another application is using port 8005. Stop that application or change the port in `config.yaml` (e.g., `port: 8006`) and restart the server. - -## 🙏 Acknowledgements & Credits - -* **Core Model:** This project is powered by the **[KittenTTS model](https://github.com/KittenML/KittenTTS)** created by **[KittenML](https://github.com/KittenML)**. Our work adds a high-performance server and UI layer on top of their excellent lightweight model. 
-* **Core Libraries:** FastAPI, Uvicorn, ONNX Runtime, PyTorch, Hugging Face Hub, Phonemizer. -* **UI Inspiration:** The UI/server architecture is inspired by our previous work on the [Chatterbox-TTS-Server](https://github.com/devnen/Chatterbox-TTS-Server). - -## 📄 License - -This project is licensed under the **MIT License**. See the [LICENSE](LICENSE) file for details. - -## 🤝 Contributing - -Contributions, issues, and feature requests are welcome! Please feel free to open an issue or submit a pull request. - - - +# Kitten TTS Server: High-Performance, Lightweight TTS with API and GPU Acceleration + +**Self-host the ultra-lightweight [KittenTTS model](https://github.com/KittenML/KittenTTS) with this enhanced API server. Features an intuitive Web UI, a flexible API, large text processing for audiobooks, and uniquely, high-performance GPU acceleration.** + +This server provides a robust, user-friendly, and powerful interface for the kitten-tts engine, an open-source, realistic text-to-speech model with just 15 million parameters. This project significantly enhances the original model by adding a full-featured server, an easy-to-use UI, and an optimized inference pipeline for hardware ranging from NVIDIA GPUs to CPUs and even the Raspberry Pi 5 (RP5) and Raspberry Pi 4 (RP4). 
+ +[![Project Link](https://img.shields.io/badge/GitHub-devnen/Kitten--TTS--Server-blue?style=for-the-badge&logo=github)](https://github.com/devnen/Kitten-TTS-Server) +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg?style=for-the-badge)](LICENSE) +[![Python Version](https://img.shields.io/badge/Python-3.10+-blue.svg?style=for-the-badge)](https://www.python.org/downloads/) +[![Framework](https://img.shields.io/badge/Framework-FastAPI-green.svg?style=for-the-badge)](https://fastapi.tiangolo.com/) +[![Model Source](https://img.shields.io/badge/Model-KittenML/KittenTTS-orange.svg?style=for-the-badge)](https://github.com/KittenML/KittenTTS) +[![Docker](https://img.shields.io/badge/Docker-Supported-blue.svg?style=for-the-badge)](https://www.docker.com/) +[![Web UI](https://img.shields.io/badge/Web_UI-Included-4285F4?style=for-the-badge&logo=googlechrome&logoColor=white)](#) +[![CUDA Compatible](https://img.shields.io/badge/NVIDIA_CUDA-Compatible-76B900?style=for-the-badge&logo=nvidia&logoColor=white)](https://developer.nvidia.com/cuda-zone) +[![API](https://img.shields.io/badge/OpenAI_Compatible_API-Ready-000000?style=for-the-badge&logo=openai&logoColor=white)](https://platform.openai.com/docs/api-reference) + +
+ Kitten TTS Server Web UI - Dark Mode + Kitten TTS Server Web UI - Light Mode +
+ +--- + +## 🗣️ Overview: Enhanced KittenTTS Generation + +The [KittenTTS model by KittenML](https://github.com/KittenML/KittenTTS) provides a foundation for generating high-quality speech from a model smaller than 25MB. This project elevates that foundation into a production-ready service by providing a robust [FastAPI](https://fastapi.tiangolo.com/) server that makes KittenTTS significantly easier to use, more powerful, and drastically faster. + +We solve the complexity of setting up and running the model by offering: + +* A **modern Web UI** for easy experimentation, preset loading, and speed adjustment. +* **True GPU Acceleration** for NVIDIA GPUs, a feature not present in the original implementation. +* **Large Text Handling & Audiobook Generation:** Intelligently splits long texts into manageable chunks, processes them sequentially, and seamlessly concatenates the audio. Perfect for creating complete audiobooks. +* **A flexible, dual-API system** including a simple endpoint and an OpenAI-compatible endpoint for easy integration. +* **Built-in Voices:** A fixed list of 8 ready-to-use voices for consistent and reliable output. +* **Cross-platform support** for Windows and Linux, with clear setup instructions. +* **Docker support** for easy, reproducible containerized deployment. + +## 🍓 Raspberry Pi & Edge Device Support + +The ultra-lightweight nature of the KittenTTS model and the efficiency of this server make it a perfect candidate for running on single-board computers (SBCs) and other edge devices. + +* ✅ **Raspberry Pi 5 (RP5):** Confirmed to run with **excellent performance**. The server is fast and responsive, easily handling requests from other devices on the same local network (LAN). This makes it ideal for local network services, home automation, and other DIY projects. + +* ⏳ **Raspberry Pi 4 (RP4):** Testing is currently in progress. Not working on the 32-bit Raspberry Pi OS. 
+ +To install, simply follow the standard **Linux installation guide** provided in this README. + +## 🔥 GPU Acceleration included + +A standout feature of this server is the implementation of **high-performance GPU acceleration**, a capability not available in the original KittenTTS project. While the base model is CPU-only, this server unlocks the full potential of your hardware. + +* **Optimized ONNX Runtime Pipeline:** We leverage `onnxruntime-gpu` to move the entire inference process to your NVIDIA graphics card. +* **Eliminated I/O Bottlenecks:** The server uses advanced **I/O Binding**. This technique pre-allocates memory directly on the GPU for both model inputs and outputs, drastically reducing the latency caused by copying data between system RAM and the GPU's VRAM. +* **True Performance Gains:** This isn't just running the model on the GPU; it's an optimized pipeline designed to minimize latency and maximize throughput, making real-time generation significantly faster than on CPU. + +This enhancement transforms KittenTTS from a lightweight-but-modest engine into a high-speed synthesis powerhouse. + +## 🔄 Alternative to Piper TTS + +The [KittenTTS model](https://github.com/KittenML/KittenTTS) serves as an excellent alternative to [Piper TTS](https://github.com/rhasspy/piper) for fast generation on limited compute and edge devices like Raspberry Pi 5. 
+ +**KittenTTS Model Advantages:** +- **Extreme Efficiency**: Just 15 million parameters and under 25MB, significantly smaller than most Piper models +- **Universal Compatibility**: CPU-optimized to run without GPU on any device and "works literally everywhere" +- **Real-time Performance**: Optimized for real-time speech synthesis even on resource-constrained hardware + +**This Server Project's Enhancement:** +While KittenTTS provides the ultra-lightweight foundation, this server transforms it into a production-ready Piper replacement by adding GPU acceleration (unavailable in the base model), modern REST/OpenAI APIs, audiobook processing capabilities, and an intuitive web interface—all while maintaining the model's edge device compatibility. + +Perfect for users seeking Piper's offline capabilities with better performance on limited hardware and modern server infrastructure. + +## ✨ Key Features of This Server + +* **🚀 Ultra-Lightweight Model:** Powered by the `KittenTTS` ONNX model, which is under 25MB. +* ⚡ **True GPU Acceleration:** Full support for **NVIDIA (CUDA)** via an optimized `onnxruntime-gpu` pipeline with I/O Binding for maximum performance. +* **📚 Large Text & Audiobook Generation:** + * Automatically handles long texts by intelligently splitting them based on sentence boundaries. + * Processes each chunk individually and seamlessly concatenates the resulting audio. + * **Ideal for audiobooks** - paste entire books and get professional-quality audio. +* **🖥️ Modern Web Interface:** + * Intuitive UI for text input, voice selection, and parameter adjustment. + * Real-time waveform visualization of generated audio. +* **🎤 8 Built-in Voices:** + * Utilizes the 8 built-in voices from the KittenTTS model (4 male, 4 female). + * Easily selectable via a UI dropdown menu. +* **⚙️ Dual API Endpoints:** + * A primary `/tts` endpoint offering full control over all generation parameters. 
+ * An OpenAI-compatible `/v1/audio/speech` endpoint for seamless integration into existing workflows. +* **🔧 Easy Configuration:** + * All settings are managed through a single `config.yaml` file. + * The server automatically creates a default config on the first run. +* **💾 UI State Persistence:** The web interface remembers your last-used text, voice, and settings to streamline your workflow. +* **🐳 Docker Support:** Easy, reproducible deployment for both CPU and GPU via Docker Compose. + +--- + +## 🔩 System Prerequisites + +* **Operating System:** Windows 10/11 (64-bit) or Linux (Debian/Ubuntu recommended). +* **Python:** Version 3.10 or later. +* **Git:** For cloning the repository. +* **eSpeak NG:** This is a **required** dependency for text phonemization. + * **Windows:** See installation guide below. + * **Linux:** `sudo apt install espeak-ng` +* **Raspberry Pi:** + * Raspberry Pi 5 + * Raspberry Pi 4 +* **(For GPU Acceleration):** + * An **NVIDIA GPU** with CUDA support. +* **(For Linux Only):** + * `libsndfile1`: Audio library needed by `soundfile`. Install via `sudo apt install libsndfile1`. + * `ffmpeg`: For robust audio operations. Install via `sudo apt install ffmpeg`. + +## 💻 Installation and Setup + +This project uses specific dependency files and a clear process to ensure a smooth, one-command installation for your hardware. + +**1. Clone the Repository** +```bash +git clone https://github.com/devnen/Kitten-TTS-Server.git +cd Kitten-TTS-Server +``` + +**2. Create and Activate a Python Virtual Environment** +This is crucial to avoid conflicts with other Python projects. + +* **Windows (PowerShell):** + ```powershell + python -m venv venv + .\venv\Scripts\activate + ``` + If you see an error about execution policies, run: + `Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser` and try activating again. 
+ +* **Linux (Bash):** + ```bash + python3 -m venv venv + source venv/bin/activate + ``` + Your command prompt should now start with `(venv)`. + +**3. Install eSpeak NG (Required)** + +* **Windows:** + 1. Download the installer from the [eSpeak NG Releases page](https://github.com/espeak-ng/espeak-ng/releases/latest). Look for the file named `espeak-ng-X.XX-x64.msi`. + 2. Run the installer with default settings. + 3. **Important:** Restart your terminal (PowerShell/CMD) after installation for the changes to take effect. + +* **Linux (Ubuntu/Debian):** + ```bash + sudo apt update && sudo apt install -y espeak-ng + ``` + +**4. Install Python Dependencies** + +Choose one of the following paths based on your hardware. + +--- + +### **Option 1: CPU-Only Installation** +This is the simplest path and works on any machine. + +```bash +# Make sure your (venv) is active +pip install --upgrade pip +pip install -r requirements.txt +``` + +--- + +### **Option 2: NVIDIA GPU Installation (Recommended for Performance)** +This method ensures all necessary CUDA libraries are correctly installed within your virtual environment for a hassle-free setup. + +```bash +# Make sure your (venv) is active +pip install --upgrade pip + +# Step 1: Install the GPU-enabled ONNX Runtime +pip install onnxruntime-gpu + +# Step 2: Install PyTorch with CUDA support. This command also brings the +# necessary CUDA and cuDNN .dll files that onnxruntime-gpu needs. 
+pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu121 + +# Step 3: Install the remaining dependencies from the requirements file +pip install -r requirements-nvidia.txt +``` + +**After installation, verify that PyTorch can see your GPU:** +```bash +python -c "import torch; print(f'PyTorch version: {torch.__version__}'); print(f'CUDA available: {torch.cuda.is_available()}'); print(f'Device name: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else None}')" +``` +If `CUDA available:` shows `True`, your setup is correct! + +--- + +### **Option 3: Upgrading from CPU to GPU** + +If you initially installed the server for CPU-only usage and now want to enable GPU acceleration, follow these steps to upgrade your environment safely. + +```bash +# Make sure your (venv) is active +pip install --upgrade pip + +# Step 1: Uninstall the CPU-only versions of onnxruntime and torch. +# This is critical to prevent conflicts with the GPU packages. +pip uninstall onnxruntime torch torchaudio -y + +# Step 2: Install the GPU-enabled ONNX Runtime. +pip install onnxruntime-gpu + +# Step 3: Install PyTorch with CUDA support. This command also brings the +# necessary CUDA and cuDNN .dll files that onnxruntime-gpu needs. +pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu121 + +# Step 4: Re-install from the nvidia requirements file to ensure all other +# dependencies are correct and up to date. +pip install -r requirements-nvidia.txt +``` + +**After upgrading, do the following:** + +1. **Verify the installation** by running the same check from Option 2: + ```bash + python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')" + ``` + The output must be `CUDA available: True`. + +2. **Update your configuration** by editing the `config.yaml` file: + ```yaml + tts_engine: + device: auto # Or "cuda", or "gpu" + ``` + +3. **Restart the server** for the changes to take effect. It will now use your NVIDIA GPU. 
+ +--- + +## ▶️ Running the Server + +**Important: First-Run Model Download** +The first time you start the server, it will automatically download the KittenTTS model (~25MB) from Hugging Face. This is a one-time process. Subsequent launches will be instant. + +1. **Activate the virtual environment** (if not already active). + * Windows: `.\venv\Scripts\activate` + * Linux: `source venv/bin/activate` + +2. **Run the server:** + ```bash + python server.py + ``` + +3. The server will start and automatically open the Web UI in your default browser. + * **Web UI:** `http://localhost:8005` + * **API Docs:** `http://localhost:8005/docs` + +4. **To stop the server:** Press `CTRL+C` in the terminal. + +### **Raspberry Pi 4 & 5 Installation (CPU-Only)** + +KittenTTS runs excellently on Raspberry Pi devices, making it ideal for local network services and DIY projects. However, installation requirements vary significantly between Pi models due to CPU architecture differences. + +#### **Raspberry Pi 5 - Full Support ✅** + +**Raspberry Pi 5 works out-of-the-box** with the standard Linux installation guide above. No special steps required! 
+ +**Tested Configuration:** +- **Hardware:** Raspberry Pi 5 Model B Rev 1.0 +- **OS:** Debian GNU/Linux 12 (bookworm) 64-bit +- **Architecture:** aarch64 (ARM64) +- **Python:** 3.11 +- **Memory:** 4GB RAM +- **Installation:** Follow the standard [Linux Installation](#linux-installation) guide exactly + +**Installation Steps:** +```bash +# Step 1: Install system dependencies +sudo apt update && sudo apt upgrade -y +sudo apt install -y espeak-ng libsndfile1 ffmpeg python3-pip python3-venv git + +# Step 2: Set up Python environment +python -m venv venv +source venv/bin/activate + +# Step 3: Install Python dependencies +pip install -r requirements.txt + +# Step 4: Start the server +python server.py +``` + +> **⏱️ Important:** During the `pip install -r requirements.txt` step, some Python packages (especially audio processing libraries like `librosa`, `praat-parselmouth`, and others) may need to be compiled from source on ARM architecture. This process can take **15-30 minutes** depending on your SD card speed and system load. This is normal - let it complete without interruption. + +#### **Raspberry Pi 4 - Limited Support ⚠️** + +**Raspberry Pi 4 support is currently in development** due to complex dependency compilation issues on 32-bit ARM architecture. 
+
+**Known Technical Challenges:**
+- **ONNX Runtime:** No official ARM wheels available on PyPI
+- **PyTorch Ecosystem:** Limited pre-built wheel availability for armv7l
+- **NLP Dependencies:** SpaCy and related libraries fail to compile due to architecture detection issues
+- **Audio Processing:** Some native audio libraries require manual compilation
+
+**Current Status:**
+- ✅ **64-bit Raspberry Pi OS:** May work with standard installation (limited testing)
+- ⚠️ **32-bit Raspberry Pi OS:** Requires complex manual dependency resolution
+- 🔧 **Alternative Solutions:** Being developed for core functionality
+
+**For Raspberry Pi 4 Users:**
+We recommend upgrading to **64-bit Raspberry Pi OS** if possible, as this significantly improves compatibility with modern Python packages. For users requiring 32-bit support, please check our [GitHub Issues](https://github.com/devnen/Kitten-TTS-Server/issues) for the latest progress updates and community-contributed solutions.
+
+**Alternative Recommendation:**
+For the best Raspberry Pi TTS experience, we strongly recommend using a **Raspberry Pi 5** with the standard 64-bit OS, which provides excellent performance and full compatibility.
+
+## 🐳 Docker Installation
+
+Run Kitten-TTS-Server easily using Docker. The recommended method uses Docker Compose, which is pre-configured for both CPU and NVIDIA GPU deployment.
+
+### Prerequisites
+
+* [Docker](https://docs.docker.com/get-docker/) installed.
+* [Docker Compose](https://docs.docker.com/compose/install/) installed (usually included with Docker Desktop).
+* **(For GPU acceleration)**
+  * An NVIDIA GPU.
+  * Up-to-date NVIDIA drivers for your host operating system.
+  * The [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) installed.
+
+### Using Docker Compose (Recommended)
+
+This method uses the provided `docker-compose.yml` files to automatically build the correct image and manage the container, volumes, and configuration.
+
+**1. 
Clone the Repository** +```bash +git clone https://github.com/devnen/Kitten-TTS-Server.git +cd Kitten-TTS-Server +``` + +**2. Start the Container Based on Your Hardware** + +Choose one of the following commands: + +#### **For NVIDIA GPU (Recommended for Performance):** +The default `docker-compose.yml` is configured for NVIDIA GPUs. It will build the image with full CUDA support. + +```bash +docker compose up -d --build +``` + +#### **For CPU-only:** +This uses a dedicated compose file that builds the image without GPU dependencies. + +```bash +docker compose -f docker-compose-cpu.yml up -d --build +``` + +⭐ **Note:** The first time you run this, Docker will build the image and the server will download the KittenTTS model, which can take a few minutes. Subsequent starts will be much faster. + +### 3. Access and Manage the Application + +* **Access the Web UI:** Open your browser to `http://localhost:8005` +* **Access the API Docs:** `http://localhost:8005/docs` + +* **View Logs:** + ```bash + # For GPU or CPU version + docker compose logs -f + ``` + +* **Stop the Container:** + ```bash + # This stops and removes the container but keeps your data volumes + docker compose down + ``` + +### How It Works + +* **Build-time Argument:** The `Dockerfile` uses a `RUNTIME` argument (`nvidia` or `cpu`) to conditionally install the correct Python packages, creating an optimized image for your hardware. +* **Persistent Data:** The `docker-compose` files use Docker volumes to persist your important data on your host machine, even if the container is removed: + * `./config.yaml`: Your main server configuration file. + * `./outputs`: All generated audio files are saved here. + * `./logs`: Server log files for troubleshooting. + * `hf_cache` (Named Volume): Persists the downloaded Hugging Face models, saving significant time on rebuilds. 
+ +### Verify GPU Access (for NVIDIA users) + +After starting the GPU container, you can verify that Docker and the application can see your graphics card. + +```bash +# Check if the container can see the NVIDIA GPU +docker compose exec kitten-tts-server nvidia-smi + +# Check if PyTorch inside the container can access CUDA +docker compose exec kitten-tts-server python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')" +``` +If `CUDA available:` prints `True`, your GPU setup is working correctly + +## 💡 Usage Guide + +### Generate Your First Audio + +1. Start the server and open the Web UI (`http://localhost:8005`). +2. Type or paste your text into the input box. +3. Select a voice from the dropdown menu. +4. Adjust the speech speed if desired. +5. Click **"Generate Speech"**. +6. The audio will play automatically and be available for download. + +### Generate an Audiobook + +1. Copy the entire plain text of your book or chapter. +2. Paste it into the text area. +3. Ensure **"Split text into chunks"** is enabled. +4. Set a **Chunk Size** between 300 and 500 characters for natural pauses. +5. Click **"Generate Speech"**. The server will process the entire text and stitch the audio together seamlessly. +6. Download your complete audiobook file. + +## 📖 API Documentation + +The server exposes two main endpoints for TTS. See `http://localhost:8005/docs` for an interactive playground. + +### Primary Endpoint: `/tts` + +This endpoint offers the most control. + +* **Method:** `POST` +* **Body:** + ```json + { + "text": "Hello from the KittenTTS API!", + "voice": "expr-voice-5-m", + "speed": 1.0, + "output_format": "mp3", + "split_text": true, + "chunk_size": 300 + } + ``` +* **Response:** Streaming audio file (`audio/wav`, `audio/mp3`, etc.). + +### OpenAI-Compatible Endpoint: `/v1/audio/speech` + +Use this for drop-in compatibility with scripts expecting OpenAI's TTS API structure. 
+ +* **Method:** `POST` +* **Body:** + ```json + { + "model": "kitten-tts", + "input": "This is an OpenAI-compatible request.", + "voice": "expr-voice-4-f", + "response_format": "wav", + "speed": 0.9 + } + ``` + +## ⚙️ Configuration + +All server settings are managed in the `config.yaml` file. It's created automatically on first launch if it doesn't exist. + +**Key Settings:** +* `server.host`, `server.port`: Network settings. +* `tts_engine.device`: Set to `auto`, `cuda`, or `cpu`. The server will use your GPU if set to `auto` or `cuda` and a compatible environment is found. +* `generation_defaults.speed`: Default speech speed (1.0 is normal). +* `audio_output.format`: Default audio format (`wav`, `mp3`, `opus`). + +## 🛠️ Troubleshooting + +* **Phonemizer / eSpeak Errors:** + * This is the most common issue. Ensure you have installed **eSpeak NG** correctly for your OS and **restarted your terminal** afterward. The server includes auto-detection logic for common install paths. +* **GPU Not Used / Falls Back to CPU:** + * Follow the **NVIDIA GPU Installation** steps exactly. The most common cause is `torch` being installed without CUDA support. + * Run the verification command from the installation guide to confirm `torch.cuda.is_available()` is `True`. +* **"No module named 'soundfile'" or Audio Errors on Linux:** + * The underlying system library is likely missing. Run `sudo apt install libsndfile1`. +* **"Port already in use" Error:** + * Another application is using port 8005. Stop that application or change the port in `config.yaml` (e.g., `port: 8006`) and restart the server. + +## 🙏 Acknowledgements & Credits + +* **Core Model:** This project is powered by the **[KittenTTS model](https://github.com/KittenML/KittenTTS)** created by **[KittenML](https://github.com/KittenML)**. Our work adds a high-performance server and UI layer on top of their excellent lightweight model. 
+* **Core Libraries:** FastAPI, Uvicorn, ONNX Runtime, PyTorch, Hugging Face Hub, Phonemizer. +* **UI Inspiration:** The UI/server architecture is inspired by our previous work on the [Chatterbox-TTS-Server](https://github.com/devnen/Chatterbox-TTS-Server). + +## 📄 License + +This project is licensed under the **MIT License**. See the [LICENSE](LICENSE) file for details. + +## 🤝 Contributing + +Contributions, issues, and feature requests are welcome! Please feel free to open an issue or submit a pull request. + + + diff --git a/config.py b/config.py index 7039e8e..91a1132 100644 --- a/config.py +++ b/config.py @@ -1,789 +1,789 @@ -# File: config.py -# Manages application configuration using a YAML file (config.yaml). -# Handles loading, saving, and providing access to configuration settings. - -import os -import logging -import yaml -import shutil -from copy import deepcopy -from threading import Lock -from typing import Dict, Any, Optional, List, Tuple -import torch # For automatic CUDA/CPU device detection -from pathlib import Path - -# Standard logger setup -logger = logging.getLogger(__name__) - -# --- File Path Constants --- -# Defines the primary configuration file name. -CONFIG_FILE_PATH = Path("config.yaml") - -# --- Default Directory Paths --- -# These paths are used if not specified in config.yaml and are created if they don't exist -# when a default configuration file is generated. -DEFAULT_LOGS_PATH = Path("logs") -DEFAULT_MODEL_FILES_PATH = Path("./model_cache") # For downloaded model files -DEFAULT_OUTPUT_PATH = Path("./outputs") # For server-saved audio outputs (if any) - -# --- Default Configuration Structure --- -# This dictionary defines the complete expected structure of 'config.yaml', -# including default values for all settings. It serves as the template for -# creating a new config.yaml if one does not exist. -DEFAULT_CONFIG: Dict[str, Any] = { - "server": { - "host": "0.0.0.0", # Host address for the server to listen on. 
- "port": 8005, # Port number for the server. - "use_ngrok": False, # Placeholder for ngrok integration (if used). - "use_auth": False, # Placeholder for basic authentication (if used). - "auth_username": "user", # Default username if authentication is enabled. - "auth_password": "password", # Default password if authentication is enabled. - "log_file_path": str( - DEFAULT_LOGS_PATH / "tts_server.log" - ), # Path to the server log file. - "log_file_max_size_mb": 10, # Maximum size of a single log file before rotation. - "log_file_backup_count": 5, # Number of backup log files to keep. - }, - "model": { # Updated section for model source configuration - "repo_id": "KittenML/kitten-tts-nano-0.1", # KittenTTS Hugging Face repository ID - }, - "tts_engine": { - "device": "auto", # TTS processing device: 'auto', 'cuda', or 'cpu'. - # 'auto' will attempt to use 'cuda' if available, otherwise 'cpu'. - }, - "paths": { # General configurable paths for the application. - "model_cache": str( - DEFAULT_MODEL_FILES_PATH - ), # Directory for caching or storing downloaded models. - "output": str( - DEFAULT_OUTPUT_PATH - ), # Default directory for any output files generated by the server. - }, - "generation_defaults": { # Default parameters for TTS audio generation. - "speed": 1.0, # Controls the speed of the generated speech. - "language": "en", # Default language for TTS. - }, - "audio_output": { # Settings related to the format of generated audio. - "format": "wav", # Output audio format (e.g., 'wav', 'mp3'). - "sample_rate": 24000, # Sample rate of the output audio in Hz. - }, - "ui_state": { # Stores user interface preferences and last-used values. - "last_text": "", # Last text entered by the user. - "last_voice": "expr-voice-5-m", # Last selected voice. - "last_chunk_size": 120, # Last used chunk size for text splitting in UI. - "last_split_text_enabled": True, # Whether text splitting was last enabled in UI. 
- "hide_chunk_warning": False, # Flag to hide the chunking warning modal. - "hide_generation_warning": False, # Flag to hide the general generation quality notice modal. - "theme": "dark", # Default UI theme ('dark' or 'light'). - }, - "ui": { # General UI display settings. - "title": "Kitten TTS Server", # Updated title - "show_language_select": True, # Whether to show language selection in the UI. - }, - "debug": { # Settings for debugging purposes - "save_intermediate_audio": False # If true, save intermediate audio files for debugging - }, -} - - -def _ensure_default_paths_exist(): - """ - Creates the default directories specified in DEFAULT_CONFIG if they do not already exist. - This is typically called when generating a default config.yaml file. - """ - paths_to_check = [ - Path( - DEFAULT_CONFIG["server"]["log_file_path"] - ).parent, # Log file's parent directory - Path(DEFAULT_CONFIG["paths"]["model_cache"]), - Path(DEFAULT_CONFIG["paths"]["output"]), - ] - for path in paths_to_check: - try: - path.mkdir(parents=True, exist_ok=True) - except Exception as e: - logger.error(f"Error creating default directory {path}: {e}", exc_info=True) - - -def _deep_merge_dicts(source: Dict, destination: Dict) -> Dict: - """ - Recursively merges the 'source' dictionary into the 'destination' dictionary. - Keys from 'source' will overwrite existing keys in 'destination'. - If a key in 'source' corresponds to a dictionary, a recursive merge is performed. - The 'destination' dictionary is modified in place. 
- """ - for key, value in source.items(): - if isinstance(value, dict): - node = destination.setdefault(key, {}) - if isinstance( - node, dict - ): # Ensure the destination node is a dict for merging - _deep_merge_dicts(value, node) - else: # If destination's node is not a dict, overwrite it entirely - destination[key] = deepcopy(value) - else: - destination[key] = value - return destination - - -def _set_nested_value(d: Dict, keys: List[str], value: Any): - """Helper function to set a value in a nested dictionary using a list of keys.""" - for key in keys[:-1]: - d = d.setdefault(key, {}) - d[keys[-1]] = value - - -def _get_nested_value(d: Dict, keys: List[str], default: Any = None) -> Any: - """Helper function to get a value from a nested dictionary using a list of keys.""" - for key in keys: - if isinstance(d, dict) and key in d: - d = d[key] - else: - return default - return d - - -class YamlConfigManager: - """ - Manages application configuration stored in a YAML file. - This class handles loading, saving, updating, and resetting the configuration. - Operations are thread-safe for file writing. - """ - - def __init__(self): - """Initializes the configuration manager by loading the configuration from YAML.""" - self.config: Dict[str, Any] = {} - self._lock = Lock() # Ensures thread-safety for file write operations. - self.load_config() - - def _load_defaults(self) -> Dict[str, Any]: - """ - Returns a deep copy of the hardcoded default configuration structure. - Also ensures that default directory paths defined in the structure exist. - """ - _ensure_default_paths_exist() # Create necessary default directories. - return deepcopy(DEFAULT_CONFIG) - - def _resolve_paths_and_device(self, config_data: Dict[str, Any]) -> Dict[str, Any]: - """ - Resolves device settings (e.g., 'auto' to 'cuda' or 'cpu') and converts - string paths in the configuration data to Path objects for internal use. - The 'config_data' dictionary is modified in place. 
- """ - # Resolve TTS device setting with robust CUDA detection. - current_device_setting = _get_nested_value( - config_data, ["tts_engine", "device"], "auto" - ) - if current_device_setting == "auto": - resolved_device = self._detect_best_device() - _set_nested_value(config_data, ["tts_engine", "device"], resolved_device) - elif current_device_setting not in ["cuda", "cpu", "gpu"]: - logger.warning( - f"Invalid TTS device '{current_device_setting}' in configuration. " - f"Recognized values are 'auto', 'cpu', 'cuda', 'gpu'. Defaulting to auto-detection." - ) - resolved_device = self._detect_best_device() - _set_nested_value(config_data, ["tts_engine", "device"], resolved_device) - - final_device = _get_nested_value(config_data, ["tts_engine", "device"]) - logger.info(f"TTS processing device resolved to: {final_device}") - - # Convert relevant string paths to Path objects. - path_key_map_for_conversion = { - "server": ["log_file_path"], - "paths": ["model_cache", "output"], - } - for section, keys_list in path_key_map_for_conversion.items(): - if section in config_data: - for key in keys_list: - current_path_val = _get_nested_value(config_data, [section, key]) - if isinstance(current_path_val, str): - _set_nested_value( - config_data, [section, key], Path(current_path_val) - ) - return config_data - - def _detect_best_device(self) -> str: - """ - Robustly detects the best available device for TTS processing. - Tests actual CUDA functionality rather than just checking availability. - - Returns: - str: 'cuda' if CUDA is truly functional, 'cpu' otherwise. - """ - # Test CUDA first as it's generally preferred for ML workloads - if torch.cuda.is_available(): - try: - # Actually test CUDA functionality by creating a tensor and moving it to CUDA - test_tensor = torch.tensor([1.0]) - test_tensor = test_tensor.cuda() - test_tensor = test_tensor.cpu() # Clean up - logger.info("CUDA test successful. 
Using CUDA device.") - return "cuda" - except Exception as e: - logger.warning( - f"CUDA is reported as available but failed functionality test: {e}. " - f"This usually means PyTorch was not compiled with CUDA support." - ) - - logger.info("CUDA not available or functional. Using CPU.") - return "cpu" - - def _prepare_config_for_saving(self, config_dict: Dict[str, Any]) -> Dict[str, Any]: - """ - Prepares a configuration dictionary for YAML serialization by converting - internal Path objects back to their string representations. - Returns a deep copy of the configuration dictionary with paths as strings. - """ - config_copy_for_saving = deepcopy(config_dict) - path_key_map_for_conversion = { - "server": ["log_file_path"], - "paths": ["model_cache", "output"], - } - for section, keys_list in path_key_map_for_conversion.items(): - if section in config_copy_for_saving: - for key in keys_list: - current_path_val = _get_nested_value( - config_copy_for_saving, [section, key] - ) - if isinstance(current_path_val, Path): - _set_nested_value( - config_copy_for_saving, - [section, key], - str(current_path_val), - ) - return config_copy_for_saving - - def load_config(self): - """ - Loads the application configuration from 'config.yaml'. - If the file doesn't exist, it's created using defaults. - The loaded configuration is merged with defaults to ensure all expected keys are present. - Device settings and path types are resolved after loading. - """ - with self._lock: # Ensure thread-safe loading. - base_defaults = self._load_defaults() # Ensures default paths exist. - - if CONFIG_FILE_PATH.exists(): - logger.info(f"Loading configuration from: {CONFIG_FILE_PATH}") - try: - with open(CONFIG_FILE_PATH, "r", encoding="utf-8") as f: - yaml_data = yaml.safe_load(f) - if isinstance(yaml_data, dict): - # Merge loaded YAML data into a copy of defaults. - # YAML data takes precedence, defaults fill in missing parts. 
- effective_config = deepcopy(base_defaults) - _deep_merge_dicts(yaml_data, effective_config) - self.config = effective_config - logger.info( - f"Successfully loaded and merged configuration from {CONFIG_FILE_PATH}." - ) - else: - logger.error( - f"Invalid format in {CONFIG_FILE_PATH}. Expected a dictionary. " - f"Using defaults and attempting to overwrite the invalid file." - ) - self.config = base_defaults # Fallback to defaults. - if not self._save_config_yaml_internal(self.config): - logger.error( - f"Failed to overwrite invalid {CONFIG_FILE_PATH} with defaults." - ) - except yaml.YAMLError as e: - logger.error( - f"Error parsing YAML from {CONFIG_FILE_PATH}: {e}. " - f"Using defaults and attempting to overwrite the corrupted file." - ) - self.config = base_defaults - if not self._save_config_yaml_internal(self.config): - logger.error( - f"Failed to overwrite corrupted {CONFIG_FILE_PATH} with defaults." - ) - except Exception as e: - logger.error( - f"Unexpected error loading {CONFIG_FILE_PATH}: {e}. Using in-memory defaults.", - exc_info=True, - ) - self.config = base_defaults # Use defaults, avoid saving on unexpected errors. - else: - logger.info( - f"{CONFIG_FILE_PATH} not found. Creating initial configuration using defaults..." - ) - # Start with defaults. - self.config = base_defaults - if self._save_config_yaml_internal(self.config): - logger.info( - f"Successfully created and saved initial default configuration to {CONFIG_FILE_PATH}." - ) - else: - logger.error( - f"Failed to save initial configuration to {CONFIG_FILE_PATH}. " - f"Using in-memory defaults." - ) - - # Resolve device and convert path strings to Path objects for the loaded/created config. 
- self.config = self._resolve_paths_and_device(self.config) - logger.debug(f"Current configuration loaded and resolved: {self.config}") - return self.config - - def _save_config_yaml_internal(self, config_dict_to_save: Dict[str, Any]) -> bool: - """ - Internal method to save the provided configuration dictionary to 'config.yaml'. - It includes a backup and restore mechanism for safety during writes. - Assumes the caller holds the necessary lock. - Converts Path objects to strings before YAML serialization. - """ - # Prepare the configuration for saving (e.g., convert Path objects to strings). - prepared_config_for_yaml = self._prepare_config_for_saving(config_dict_to_save) - - temp_file = CONFIG_FILE_PATH.with_suffix(CONFIG_FILE_PATH.suffix + ".tmp") - backup_file = CONFIG_FILE_PATH.with_suffix(CONFIG_FILE_PATH.suffix + ".bak") - - try: - # Atomically write to a temporary file first. - with open(temp_file, "w", encoding="utf-8") as f: - yaml.dump( - prepared_config_for_yaml, - f, - default_flow_style=False, - sort_keys=False, - indent=2, - ) - - # If an existing config file exists, back it up. - if CONFIG_FILE_PATH.exists(): - try: - shutil.move(str(CONFIG_FILE_PATH), str(backup_file)) - logger.debug(f"Backed up existing configuration to {backup_file}") - except Exception as backup_error: - logger.warning( - f"Could not create backup of {CONFIG_FILE_PATH}: {backup_error}" - ) - # Proceed with saving, but warn about missing backup. - - # Rename the temporary file to the actual configuration file. 
- shutil.move(str(temp_file), str(CONFIG_FILE_PATH)) - logger.info(f"Configuration successfully saved to {CONFIG_FILE_PATH}") - return True - - except yaml.YAMLError as e_yaml: - logger.error( - f"Error formatting data for {CONFIG_FILE_PATH} (YAML error): {e_yaml}", - exc_info=True, - ) - return False - except Exception as e_general: - logger.error( - f"Failed to save configuration to {CONFIG_FILE_PATH}: {e_general}", - exc_info=True, - ) - # Attempt to restore from backup if the save operation failed. - if backup_file.exists() and not CONFIG_FILE_PATH.exists(): - try: - shutil.move(str(backup_file), str(CONFIG_FILE_PATH)) - logger.info( - f"Restored configuration from backup {backup_file} due to save failure." - ) - except Exception as restore_error: - logger.error( - f"Failed to restore configuration from backup: {restore_error}" - ) - # Clean up the temporary file if it still exists after a failure. - if temp_file.exists(): - try: - os.remove(str(temp_file)) - except Exception as remove_error: - logger.warning( - f"Could not remove temporary config file {temp_file}: {remove_error}" - ) - return False - finally: - # Clean up the backup file if the main configuration file exists and the save was successful. - if CONFIG_FILE_PATH.exists() and backup_file.exists(): - try: - if ( - CONFIG_FILE_PATH.stat().st_size > 0 - ): # Basic check that the new file is not empty. - os.remove(str(backup_file)) - logger.debug( - f"Removed backup file {backup_file} after successful save." - ) - except Exception as remove_bak_error: - logger.warning( - f"Could not remove backup config file {backup_file}: {remove_bak_error}" - ) - - def save_config_yaml(self) -> bool: - """ - Public method to save the current in-memory configuration to 'config.yaml'. - Ensures thread-safety using a lock. 
- """ - with self._lock: - return self._save_config_yaml_internal(self.config) - - def get(self, key_path: str, default: Any = None) -> Any: - """ - Retrieves a configuration value using a dot-separated key path (e.g., 'server.port'). - If the key path is not found, 'default' is returned. - For mutable types (dicts, lists), a deep copy is returned to prevent - unintended modification of the in-memory configuration. - """ - keys = key_path.split(".") - with self._lock: # Ensure thread-safe access to self.config. - value = _get_nested_value(self.config, keys, default) - return deepcopy(value) if isinstance(value, (dict, list)) else value - - def get_string(self, key_path: str, default: Optional[str] = None) -> str: - """Retrieves a configuration value, ensuring it's a string.""" - # Added this method for explicit string retrieval, common for paths/IDs. - raw_value = self.get(key_path) - if raw_value is None: - if default is not None: - logger.debug( - f"Config string '{key_path}' is None, using provided method default: '{default}'" - ) - return default - logger.error( - f"Mandatory string config '{key_path}' is None, and no method default. Returning empty string." - ) - return "" - if isinstance( - raw_value, (Path, str) - ): # Handle Path objects by converting to string - return str(raw_value) - try: # Attempt conversion for other types if necessary - return str(raw_value) - except Exception: - logger.warning( - f"Could not convert value '{raw_value}' for '{key_path}' to string. Using method default or empty string." - ) - if default is not None: - return default - return "" - - def get_all(self) -> Dict[str, Any]: - """ - Returns a deep copy of the entire current configuration. - Ensures thread-safety during the copy operation. 
- """ - with self._lock: - return deepcopy(self.config) - - def update_and_save(self, partial_update_dict: Dict[str, Any]) -> bool: - """ - Deeply merges a 'partial_update_dict' into the current configuration - and saves the entire updated configuration back to the YAML file. - This allows updating specific nested values without overwriting entire sections. - """ - if not isinstance(partial_update_dict, dict): - logger.error("Invalid partial update data: input must be a dictionary.") - return False - - with self._lock: - try: - # Work on a deep copy of the current config to avoid altering it before a successful save. - config_copy_for_update = deepcopy(self.config) - # Merge the partial update into this copy. - _deep_merge_dicts(partial_update_dict, config_copy_for_update) - - # Before saving, the merged config might need path/device re-resolution - # if those specific keys were part of partial_update_dict. - # For robustness, always re-resolve. - resolved_updated_config = self._resolve_paths_and_device( - config_copy_for_update - ) - - if self._save_config_yaml_internal(resolved_updated_config): - # If save was successful, update the active in-memory config. - self.config = resolved_updated_config - logger.info( - "Configuration updated, saved, and re-resolved successfully." - ) - return True - else: - logger.error("Failed to save updated configuration after merging.") - return False - except Exception as e: - logger.error( - f"Error during configuration update and save process: {e}", - exc_info=True, - ) - return False - - def reset_and_save(self) -> bool: - """ - Resets the application configuration to its hardcoded defaults. - The reset configuration (after resolving paths/device) is then saved to 'config.yaml'. - """ - with self._lock: - logger.warning("Initiating configuration reset to hardcoded defaults...") - # Start with hardcoded defaults (this also ensures default directories are created). 
- reset_config_base = self._load_defaults() - # Resolve device settings and ensure paths are Path objects for the new in-memory config. - final_reset_config = self._resolve_paths_and_device(reset_config_base) - - if self._save_config_yaml_internal( - final_reset_config - ): # Save the fully resolved reset config. - self.config = final_reset_config # Update the active in-memory config. - logger.info( - "Configuration successfully reset to defaults, saved, and resolved." - ) - return True - else: - logger.error( - "Failed to save the reset configuration. Current configuration remains unchanged." - ) - # If save failed, the old self.config is retained. - return False - - # --- Type-specific Getters --- - # These provide convenient, type-checked access to configuration values. - def get_int(self, key_path: str, default: Optional[int] = None) -> int: - """Retrieves a configuration value, converting it to an integer.""" - raw_value = self.get(key_path) - if raw_value is None: - if default is not None: - logger.debug( - f"Config '{key_path}' is None, using provided method default: {default}" - ) - return default - logger.error( - f"Mandatory integer config '{key_path}' is None, and no method default. Returning 0." - ) - return 0 - try: - return int(raw_value) - except (ValueError, TypeError): - logger.warning( - f"Invalid integer value '{raw_value}' for '{key_path}'. Using method default or 0." - ) - if isinstance(default, int): - return default - logger.error( - f"Cannot parse '{raw_value}' as int for '{key_path}' and no valid method default. Returning 0." 
- ) - return 0 - - def get_float(self, key_path: str, default: Optional[float] = None) -> float: - """Retrieves a configuration value, converting it to a float.""" - raw_value = self.get(key_path) - if raw_value is None: - if default is not None: - logger.debug( - f"Config '{key_path}' is None, using provided method default: {default}" - ) - return default - logger.error( - f"Mandatory float config '{key_path}' is None, and no method default. Returning 0.0." - ) - return 0.0 - try: - return float(raw_value) - except (ValueError, TypeError): - logger.warning( - f"Invalid float value '{raw_value}' for '{key_path}'. Using method default or 0.0." - ) - if isinstance(default, float): - return default - logger.error( - f"Cannot parse '{raw_value}' as float for '{key_path}' and no valid method default. Returning 0.0." - ) - return 0.0 - - def get_bool(self, key_path: str, default: Optional[bool] = None) -> bool: - """Retrieves a configuration value, converting it to a boolean.""" - raw_value = self.get(key_path) - if raw_value is None: - if default is not None: - logger.debug( - f"Config '{key_path}' is None, using provided method default: {default}" - ) - return default - logger.error( - f"Mandatory boolean config '{key_path}' is None, and no method default. Returning False." - ) - return False - if isinstance(raw_value, bool): - return raw_value - if isinstance( - raw_value, str - ): # Handle common string representations of booleans. - return raw_value.lower() in ("true", "1", "t", "yes", "y") - try: # Handle numeric representations (e.g., 1 for True, 0 for False). - return bool(int(raw_value)) - except (ValueError, TypeError): - logger.warning( - f"Invalid boolean value '{raw_value}' for '{key_path}'. Using method default or False." - ) - if isinstance(default, bool): - return default - logger.error( - f"Cannot parse '{raw_value}' as bool for '{key_path}' and no valid method default. Returning False." 
- ) - return False - - def get_path( - self, - key_path: str, - default_str_path: Optional[str] = None, - ensure_absolute: bool = False, - ) -> Path: - """ - Retrieves a configuration value expected to be a path, returning it as a Path object. - If 'ensure_absolute' is True, the path is resolved to an absolute path. - """ - value_from_config = self.get(key_path) - - path_obj_to_return: Path - if isinstance(value_from_config, Path): - path_obj_to_return = value_from_config - elif isinstance(value_from_config, str): # Convert string from config to Path. - path_obj_to_return = Path(value_from_config) - elif default_str_path is not None: # Fallback to provided string default. - logger.debug( - f"Config Path '{key_path}' not found or invalid type, using provided default string path: '{default_str_path}'" - ) - path_obj_to_return = Path(default_str_path) - else: # Ultimate fallback if no value and no default. - logger.error( - f"Config Path '{key_path}' not found or invalid type, and no default provided. Returning Path('.')" - ) - path_obj_to_return = Path(".") # Current directory. - - return path_obj_to_return.resolve() if ensure_absolute else path_obj_to_return - - -# --- Singleton Instance --- -# This provides a single, globally accessible instance of the configuration manager. -config_manager = YamlConfigManager() - -# --- Convenience Accessor Functions --- -# These functions provide easy, module-level access to common configuration settings -# using the singleton 'config_manager' instance. 
- - -def _get_default_from_structure(key_path: str) -> Any: - """Internal helper to retrieve a default value directly from the DEFAULT_CONFIG structure.""" - keys = key_path.split(".") - return _get_nested_value(DEFAULT_CONFIG, keys) - - -# Server Settings Accessors -def get_host() -> str: - """Returns the server host address.""" - return config_manager.get_string( - "server.host", _get_default_from_structure("server.host") - ) - - -def get_port() -> int: - """Returns the server port number.""" - return config_manager.get_int( - "server.port", _get_default_from_structure("server.port") - ) - - -# Audio Output Settings Accessors -def get_audio_output_format() -> str: - """Returns the default audio output format (e.g., 'wav').""" - return config_manager.get_string( - "audio_output.format", _get_default_from_structure("audio_output.format") - ) - - -def get_log_file_path() -> Path: - """Returns the absolute path to the server log file.""" - default_path_str = str(_get_default_from_structure("server.log_file_path")) - return config_manager.get_path( - "server.log_file_path", default_path_str, ensure_absolute=True - ) - - -# Model Settings Accessors -def get_model_repo_id() -> str: - """Returns the Hugging Face repository ID for the model.""" - return config_manager.get_string( - "model.repo_id", _get_default_from_structure("model.repo_id") - ) - - -# TTS Engine Settings Accessors -def get_tts_device() -> str: - """Returns the resolved TTS processing device ('cuda' or 'cpu').""" - # Device is resolved during load_config, so direct get is appropriate. 
- return config_manager.get_string( - "tts_engine.device", _get_default_from_structure("tts_engine.device") - ) - - -# General Path Settings Accessors -def get_model_cache_path(ensure_absolute: bool = True) -> Path: - """Returns the path to the model cache directory.""" - default_path_str = str(_get_default_from_structure("paths.model_cache")) - return config_manager.get_path( - "paths.model_cache", default_path_str, ensure_absolute=ensure_absolute - ) - - -def get_output_path(ensure_absolute: bool = True) -> Path: - """Returns the path to the default output directory.""" - default_path_str = str(_get_default_from_structure("paths.output")) - return config_manager.get_path( - "paths.output", default_path_str, ensure_absolute=ensure_absolute - ) - - -# Default Generation Parameter Accessors -def get_gen_default_speed() -> float: - """Returns the default speed for TTS generation.""" - return config_manager.get_float( - "generation_defaults.speed", - _get_default_from_structure("generation_defaults.speed"), - ) - - -def get_gen_default_language() -> str: - """Returns the default language for TTS generation.""" - return config_manager.get_string( - "generation_defaults.language", - _get_default_from_structure("generation_defaults.language"), - ) - - -# Audio Output Settings Accessors -def get_audio_sample_rate() -> int: - """Returns the default audio sample rate.""" - return config_manager.get_int( - "audio_output.sample_rate", - _get_default_from_structure("audio_output.sample_rate"), - ) - - -# UI State Accessors -def get_ui_state() -> Dict[str, Any]: - """Returns the entire UI state dictionary (for UI persistence).""" - return config_manager.get( - "ui_state", deepcopy(_get_default_from_structure("ui_state")) - ) - - -# General UI Settings Accessors -def get_ui_title() -> str: - """Returns the title for the web UI.""" - return config_manager.get_string( - "ui.title", _get_default_from_structure("ui.title") - ) - - -def get_full_config_for_template() -> Dict[str, 
Any]: - """ - Returns a deep copy of the current configuration, with Path objects - converted to strings. This is suitable for serialization (e.g., JSON) - or for passing to web templates or API responses. - """ - config_snapshot = config_manager.get_all() # Gets a deep copy. - # Convert Path objects in this snapshot to strings for serialization. - return config_manager._prepare_config_for_saving(config_snapshot) - - -# --- End File: config.py --- +# File: config.py +# Manages application configuration using a YAML file (config.yaml). +# Handles loading, saving, and providing access to configuration settings. + +import os +import logging +import yaml +import shutil +from copy import deepcopy +from threading import Lock +from typing import Dict, Any, Optional, List, Tuple +import torch # For automatic CUDA/CPU device detection +from pathlib import Path + +# Standard logger setup +logger = logging.getLogger(__name__) + +# --- File Path Constants --- +# Defines the primary configuration file name. +CONFIG_FILE_PATH = Path("config.yaml") + +# --- Default Directory Paths --- +# These paths are used if not specified in config.yaml and are created if they don't exist +# when a default configuration file is generated. +DEFAULT_LOGS_PATH = Path("logs") +DEFAULT_MODEL_FILES_PATH = Path("./model_cache") # For downloaded model files +DEFAULT_OUTPUT_PATH = Path("./outputs") # For server-saved audio outputs (if any) + +# --- Default Configuration Structure --- +# This dictionary defines the complete expected structure of 'config.yaml', +# including default values for all settings. It serves as the template for +# creating a new config.yaml if one does not exist. +DEFAULT_CONFIG: Dict[str, Any] = { + "server": { + "host": "0.0.0.0", # Host address for the server to listen on. + "port": 8005, # Port number for the server. + "use_ngrok": False, # Placeholder for ngrok integration (if used). + "use_auth": False, # Placeholder for basic authentication (if used). 
+ "auth_username": "user", # Default username if authentication is enabled. + "auth_password": "password", # Default password if authentication is enabled. + "log_file_path": str( + DEFAULT_LOGS_PATH / "tts_server.log" + ), # Path to the server log file. + "log_file_max_size_mb": 10, # Maximum size of a single log file before rotation. + "log_file_backup_count": 5, # Number of backup log files to keep. + }, + "model": { # Updated section for model source configuration + "repo_id": "KittenML/kitten-tts-nano-0.1", # KittenTTS Hugging Face repository ID + }, + "tts_engine": { + "device": "auto", # TTS processing device: 'auto', 'cuda', or 'cpu'. + # 'auto' will attempt to use 'cuda' if available, otherwise 'cpu'. + }, + "paths": { # General configurable paths for the application. + "model_cache": str( + DEFAULT_MODEL_FILES_PATH + ), # Directory for caching or storing downloaded models. + "output": str( + DEFAULT_OUTPUT_PATH + ), # Default directory for any output files generated by the server. + }, + "generation_defaults": { # Default parameters for TTS audio generation. + "speed": 1.0, # Controls the speed of the generated speech. + "language": "en", # Default language for TTS. + }, + "audio_output": { # Settings related to the format of generated audio. + "format": "wav", # Output audio format (e.g., 'wav', 'mp3'). + "sample_rate": 24000, # Sample rate of the output audio in Hz. + }, + "ui_state": { # Stores user interface preferences and last-used values. + "last_text": "", # Last text entered by the user. + "last_voice": "expr-voice-5-m", # Last selected voice. + "last_chunk_size": 120, # Last used chunk size for text splitting in UI. + "last_split_text_enabled": True, # Whether text splitting was last enabled in UI. + "hide_chunk_warning": False, # Flag to hide the chunking warning modal. + "hide_generation_warning": False, # Flag to hide the general generation quality notice modal. + "theme": "dark", # Default UI theme ('dark' or 'light'). 
+ }, + "ui": { # General UI display settings. + "title": "Kitten TTS Server", # Updated title + "show_language_select": True, # Whether to show language selection in the UI. + }, + "debug": { # Settings for debugging purposes + "save_intermediate_audio": False # If true, save intermediate audio files for debugging + }, +} + + +def _ensure_default_paths_exist(): + """ + Creates the default directories specified in DEFAULT_CONFIG if they do not already exist. + This is typically called when generating a default config.yaml file. + """ + paths_to_check = [ + Path( + DEFAULT_CONFIG["server"]["log_file_path"] + ).parent, # Log file's parent directory + Path(DEFAULT_CONFIG["paths"]["model_cache"]), + Path(DEFAULT_CONFIG["paths"]["output"]), + ] + for path in paths_to_check: + try: + path.mkdir(parents=True, exist_ok=True) + except Exception as e: + logger.error(f"Error creating default directory {path}: {e}", exc_info=True) + + +def _deep_merge_dicts(source: Dict, destination: Dict) -> Dict: + """ + Recursively merges the 'source' dictionary into the 'destination' dictionary. + Keys from 'source' will overwrite existing keys in 'destination'. + If a key in 'source' corresponds to a dictionary, a recursive merge is performed. + The 'destination' dictionary is modified in place. 
+ """ + for key, value in source.items(): + if isinstance(value, dict): + node = destination.setdefault(key, {}) + if isinstance( + node, dict + ): # Ensure the destination node is a dict for merging + _deep_merge_dicts(value, node) + else: # If destination's node is not a dict, overwrite it entirely + destination[key] = deepcopy(value) + else: + destination[key] = value + return destination + + +def _set_nested_value(d: Dict, keys: List[str], value: Any): + """Helper function to set a value in a nested dictionary using a list of keys.""" + for key in keys[:-1]: + d = d.setdefault(key, {}) + d[keys[-1]] = value + + +def _get_nested_value(d: Dict, keys: List[str], default: Any = None) -> Any: + """Helper function to get a value from a nested dictionary using a list of keys.""" + for key in keys: + if isinstance(d, dict) and key in d: + d = d[key] + else: + return default + return d + + +class YamlConfigManager: + """ + Manages application configuration stored in a YAML file. + This class handles loading, saving, updating, and resetting the configuration. + Operations are thread-safe for file writing. + """ + + def __init__(self): + """Initializes the configuration manager by loading the configuration from YAML.""" + self.config: Dict[str, Any] = {} + self._lock = Lock() # Ensures thread-safety for file write operations. + self.load_config() + + def _load_defaults(self) -> Dict[str, Any]: + """ + Returns a deep copy of the hardcoded default configuration structure. + Also ensures that default directory paths defined in the structure exist. + """ + _ensure_default_paths_exist() # Create necessary default directories. + return deepcopy(DEFAULT_CONFIG) + + def _resolve_paths_and_device(self, config_data: Dict[str, Any]) -> Dict[str, Any]: + """ + Resolves device settings (e.g., 'auto' to 'cuda' or 'cpu') and converts + string paths in the configuration data to Path objects for internal use. + The 'config_data' dictionary is modified in place. 
+ """ + # Resolve TTS device setting with robust CUDA detection. + current_device_setting = _get_nested_value( + config_data, ["tts_engine", "device"], "auto" + ) + if current_device_setting == "auto": + resolved_device = self._detect_best_device() + _set_nested_value(config_data, ["tts_engine", "device"], resolved_device) + elif current_device_setting not in ["cuda", "cpu", "gpu"]: + logger.warning( + f"Invalid TTS device '{current_device_setting}' in configuration. " + f"Recognized values are 'auto', 'cpu', 'cuda', 'gpu'. Defaulting to auto-detection." + ) + resolved_device = self._detect_best_device() + _set_nested_value(config_data, ["tts_engine", "device"], resolved_device) + + final_device = _get_nested_value(config_data, ["tts_engine", "device"]) + logger.info(f"TTS processing device resolved to: {final_device}") + + # Convert relevant string paths to Path objects. + path_key_map_for_conversion = { + "server": ["log_file_path"], + "paths": ["model_cache", "output"], + } + for section, keys_list in path_key_map_for_conversion.items(): + if section in config_data: + for key in keys_list: + current_path_val = _get_nested_value(config_data, [section, key]) + if isinstance(current_path_val, str): + _set_nested_value( + config_data, [section, key], Path(current_path_val) + ) + return config_data + + def _detect_best_device(self) -> str: + """ + Robustly detects the best available device for TTS processing. + Tests actual CUDA functionality rather than just checking availability. + + Returns: + str: 'cuda' if CUDA is truly functional, 'cpu' otherwise. + """ + # Test CUDA first as it's generally preferred for ML workloads + if torch.cuda.is_available(): + try: + # Actually test CUDA functionality by creating a tensor and moving it to CUDA + test_tensor = torch.tensor([1.0]) + test_tensor = test_tensor.cuda() + test_tensor = test_tensor.cpu() # Clean up + logger.info("CUDA test successful. 
Using CUDA device.") + return "cuda" + except Exception as e: + logger.warning( + f"CUDA is reported as available but failed functionality test: {e}. " + f"This usually means PyTorch was not compiled with CUDA support." + ) + + logger.info("CUDA not available or functional. Using CPU.") + return "cpu" + + def _prepare_config_for_saving(self, config_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Prepares a configuration dictionary for YAML serialization by converting + internal Path objects back to their string representations. + Returns a deep copy of the configuration dictionary with paths as strings. + """ + config_copy_for_saving = deepcopy(config_dict) + path_key_map_for_conversion = { + "server": ["log_file_path"], + "paths": ["model_cache", "output"], + } + for section, keys_list in path_key_map_for_conversion.items(): + if section in config_copy_for_saving: + for key in keys_list: + current_path_val = _get_nested_value( + config_copy_for_saving, [section, key] + ) + if isinstance(current_path_val, Path): + _set_nested_value( + config_copy_for_saving, + [section, key], + str(current_path_val), + ) + return config_copy_for_saving + + def load_config(self): + """ + Loads the application configuration from 'config.yaml'. + If the file doesn't exist, it's created using defaults. + The loaded configuration is merged with defaults to ensure all expected keys are present. + Device settings and path types are resolved after loading. + """ + with self._lock: # Ensure thread-safe loading. + base_defaults = self._load_defaults() # Ensures default paths exist. + + if CONFIG_FILE_PATH.exists(): + logger.info(f"Loading configuration from: {CONFIG_FILE_PATH}") + try: + with open(CONFIG_FILE_PATH, "r", encoding="utf-8") as f: + yaml_data = yaml.safe_load(f) + if isinstance(yaml_data, dict): + # Merge loaded YAML data into a copy of defaults. + # YAML data takes precedence, defaults fill in missing parts. 
+ effective_config = deepcopy(base_defaults) + _deep_merge_dicts(yaml_data, effective_config) + self.config = effective_config + logger.info( + f"Successfully loaded and merged configuration from {CONFIG_FILE_PATH}." + ) + else: + logger.error( + f"Invalid format in {CONFIG_FILE_PATH}. Expected a dictionary. " + f"Using defaults and attempting to overwrite the invalid file." + ) + self.config = base_defaults # Fallback to defaults. + if not self._save_config_yaml_internal(self.config): + logger.error( + f"Failed to overwrite invalid {CONFIG_FILE_PATH} with defaults." + ) + except yaml.YAMLError as e: + logger.error( + f"Error parsing YAML from {CONFIG_FILE_PATH}: {e}. " + f"Using defaults and attempting to overwrite the corrupted file." + ) + self.config = base_defaults + if not self._save_config_yaml_internal(self.config): + logger.error( + f"Failed to overwrite corrupted {CONFIG_FILE_PATH} with defaults." + ) + except Exception as e: + logger.error( + f"Unexpected error loading {CONFIG_FILE_PATH}: {e}. Using in-memory defaults.", + exc_info=True, + ) + self.config = base_defaults # Use defaults, avoid saving on unexpected errors. + else: + logger.info( + f"{CONFIG_FILE_PATH} not found. Creating initial configuration using defaults..." + ) + # Start with defaults. + self.config = base_defaults + if self._save_config_yaml_internal(self.config): + logger.info( + f"Successfully created and saved initial default configuration to {CONFIG_FILE_PATH}." + ) + else: + logger.error( + f"Failed to save initial configuration to {CONFIG_FILE_PATH}. " + f"Using in-memory defaults." + ) + + # Resolve device and convert path strings to Path objects for the loaded/created config. 
+ self.config = self._resolve_paths_and_device(self.config) + logger.debug(f"Current configuration loaded and resolved: {self.config}") + return self.config + + def _save_config_yaml_internal(self, config_dict_to_save: Dict[str, Any]) -> bool: + """ + Internal method to save the provided configuration dictionary to 'config.yaml'. + It includes a backup and restore mechanism for safety during writes. + Assumes the caller holds the necessary lock. + Converts Path objects to strings before YAML serialization. + """ + # Prepare the configuration for saving (e.g., convert Path objects to strings). + prepared_config_for_yaml = self._prepare_config_for_saving(config_dict_to_save) + + temp_file = CONFIG_FILE_PATH.with_suffix(CONFIG_FILE_PATH.suffix + ".tmp") + backup_file = CONFIG_FILE_PATH.with_suffix(CONFIG_FILE_PATH.suffix + ".bak") + + try: + # Atomically write to a temporary file first. + with open(temp_file, "w", encoding="utf-8") as f: + yaml.dump( + prepared_config_for_yaml, + f, + default_flow_style=False, + sort_keys=False, + indent=2, + ) + + # If an existing config file exists, back it up. + if CONFIG_FILE_PATH.exists(): + try: + shutil.move(str(CONFIG_FILE_PATH), str(backup_file)) + logger.debug(f"Backed up existing configuration to {backup_file}") + except Exception as backup_error: + logger.warning( + f"Could not create backup of {CONFIG_FILE_PATH}: {backup_error}" + ) + # Proceed with saving, but warn about missing backup. + + # Rename the temporary file to the actual configuration file. 
+ shutil.move(str(temp_file), str(CONFIG_FILE_PATH)) + logger.info(f"Configuration successfully saved to {CONFIG_FILE_PATH}") + return True + + except yaml.YAMLError as e_yaml: + logger.error( + f"Error formatting data for {CONFIG_FILE_PATH} (YAML error): {e_yaml}", + exc_info=True, + ) + return False + except Exception as e_general: + logger.error( + f"Failed to save configuration to {CONFIG_FILE_PATH}: {e_general}", + exc_info=True, + ) + # Attempt to restore from backup if the save operation failed. + if backup_file.exists() and not CONFIG_FILE_PATH.exists(): + try: + shutil.move(str(backup_file), str(CONFIG_FILE_PATH)) + logger.info( + f"Restored configuration from backup {backup_file} due to save failure." + ) + except Exception as restore_error: + logger.error( + f"Failed to restore configuration from backup: {restore_error}" + ) + # Clean up the temporary file if it still exists after a failure. + if temp_file.exists(): + try: + os.remove(str(temp_file)) + except Exception as remove_error: + logger.warning( + f"Could not remove temporary config file {temp_file}: {remove_error}" + ) + return False + finally: + # Clean up the backup file if the main configuration file exists and the save was successful. + if CONFIG_FILE_PATH.exists() and backup_file.exists(): + try: + if ( + CONFIG_FILE_PATH.stat().st_size > 0 + ): # Basic check that the new file is not empty. + os.remove(str(backup_file)) + logger.debug( + f"Removed backup file {backup_file} after successful save." + ) + except Exception as remove_bak_error: + logger.warning( + f"Could not remove backup config file {backup_file}: {remove_bak_error}" + ) + + def save_config_yaml(self) -> bool: + """ + Public method to save the current in-memory configuration to 'config.yaml'. + Ensures thread-safety using a lock. 
+ """ + with self._lock: + return self._save_config_yaml_internal(self.config) + + def get(self, key_path: str, default: Any = None) -> Any: + """ + Retrieves a configuration value using a dot-separated key path (e.g., 'server.port'). + If the key path is not found, 'default' is returned. + For mutable types (dicts, lists), a deep copy is returned to prevent + unintended modification of the in-memory configuration. + """ + keys = key_path.split(".") + with self._lock: # Ensure thread-safe access to self.config. + value = _get_nested_value(self.config, keys, default) + return deepcopy(value) if isinstance(value, (dict, list)) else value + + def get_string(self, key_path: str, default: Optional[str] = None) -> str: + """Retrieves a configuration value, ensuring it's a string.""" + # Added this method for explicit string retrieval, common for paths/IDs. + raw_value = self.get(key_path) + if raw_value is None: + if default is not None: + logger.debug( + f"Config string '{key_path}' is None, using provided method default: '{default}'" + ) + return default + logger.error( + f"Mandatory string config '{key_path}' is None, and no method default. Returning empty string." + ) + return "" + if isinstance( + raw_value, (Path, str) + ): # Handle Path objects by converting to string + return str(raw_value) + try: # Attempt conversion for other types if necessary + return str(raw_value) + except Exception: + logger.warning( + f"Could not convert value '{raw_value}' for '{key_path}' to string. Using method default or empty string." + ) + if default is not None: + return default + return "" + + def get_all(self) -> Dict[str, Any]: + """ + Returns a deep copy of the entire current configuration. + Ensures thread-safety during the copy operation. 
+ """ + with self._lock: + return deepcopy(self.config) + + def update_and_save(self, partial_update_dict: Dict[str, Any]) -> bool: + """ + Deeply merges a 'partial_update_dict' into the current configuration + and saves the entire updated configuration back to the YAML file. + This allows updating specific nested values without overwriting entire sections. + """ + if not isinstance(partial_update_dict, dict): + logger.error("Invalid partial update data: input must be a dictionary.") + return False + + with self._lock: + try: + # Work on a deep copy of the current config to avoid altering it before a successful save. + config_copy_for_update = deepcopy(self.config) + # Merge the partial update into this copy. + _deep_merge_dicts(partial_update_dict, config_copy_for_update) + + # Before saving, the merged config might need path/device re-resolution + # if those specific keys were part of partial_update_dict. + # For robustness, always re-resolve. + resolved_updated_config = self._resolve_paths_and_device( + config_copy_for_update + ) + + if self._save_config_yaml_internal(resolved_updated_config): + # If save was successful, update the active in-memory config. + self.config = resolved_updated_config + logger.info( + "Configuration updated, saved, and re-resolved successfully." + ) + return True + else: + logger.error("Failed to save updated configuration after merging.") + return False + except Exception as e: + logger.error( + f"Error during configuration update and save process: {e}", + exc_info=True, + ) + return False + + def reset_and_save(self) -> bool: + """ + Resets the application configuration to its hardcoded defaults. + The reset configuration (after resolving paths/device) is then saved to 'config.yaml'. + """ + with self._lock: + logger.warning("Initiating configuration reset to hardcoded defaults...") + # Start with hardcoded defaults (this also ensures default directories are created). 
+ reset_config_base = self._load_defaults() + # Resolve device settings and ensure paths are Path objects for the new in-memory config. + final_reset_config = self._resolve_paths_and_device(reset_config_base) + + if self._save_config_yaml_internal( + final_reset_config + ): # Save the fully resolved reset config. + self.config = final_reset_config # Update the active in-memory config. + logger.info( + "Configuration successfully reset to defaults, saved, and resolved." + ) + return True + else: + logger.error( + "Failed to save the reset configuration. Current configuration remains unchanged." + ) + # If save failed, the old self.config is retained. + return False + + # --- Type-specific Getters --- + # These provide convenient, type-checked access to configuration values. + def get_int(self, key_path: str, default: Optional[int] = None) -> int: + """Retrieves a configuration value, converting it to an integer.""" + raw_value = self.get(key_path) + if raw_value is None: + if default is not None: + logger.debug( + f"Config '{key_path}' is None, using provided method default: {default}" + ) + return default + logger.error( + f"Mandatory integer config '{key_path}' is None, and no method default. Returning 0." + ) + return 0 + try: + return int(raw_value) + except (ValueError, TypeError): + logger.warning( + f"Invalid integer value '{raw_value}' for '{key_path}'. Using method default or 0." + ) + if isinstance(default, int): + return default + logger.error( + f"Cannot parse '{raw_value}' as int for '{key_path}' and no valid method default. Returning 0." 
+ ) + return 0 + + def get_float(self, key_path: str, default: Optional[float] = None) -> float: + """Retrieves a configuration value, converting it to a float.""" + raw_value = self.get(key_path) + if raw_value is None: + if default is not None: + logger.debug( + f"Config '{key_path}' is None, using provided method default: {default}" + ) + return default + logger.error( + f"Mandatory float config '{key_path}' is None, and no method default. Returning 0.0." + ) + return 0.0 + try: + return float(raw_value) + except (ValueError, TypeError): + logger.warning( + f"Invalid float value '{raw_value}' for '{key_path}'. Using method default or 0.0." + ) + if isinstance(default, float): + return default + logger.error( + f"Cannot parse '{raw_value}' as float for '{key_path}' and no valid method default. Returning 0.0." + ) + return 0.0 + + def get_bool(self, key_path: str, default: Optional[bool] = None) -> bool: + """Retrieves a configuration value, converting it to a boolean.""" + raw_value = self.get(key_path) + if raw_value is None: + if default is not None: + logger.debug( + f"Config '{key_path}' is None, using provided method default: {default}" + ) + return default + logger.error( + f"Mandatory boolean config '{key_path}' is None, and no method default. Returning False." + ) + return False + if isinstance(raw_value, bool): + return raw_value + if isinstance( + raw_value, str + ): # Handle common string representations of booleans. + return raw_value.lower() in ("true", "1", "t", "yes", "y") + try: # Handle numeric representations (e.g., 1 for True, 0 for False). + return bool(int(raw_value)) + except (ValueError, TypeError): + logger.warning( + f"Invalid boolean value '{raw_value}' for '{key_path}'. Using method default or False." + ) + if isinstance(default, bool): + return default + logger.error( + f"Cannot parse '{raw_value}' as bool for '{key_path}' and no valid method default. Returning False." 
+ ) + return False + + def get_path( + self, + key_path: str, + default_str_path: Optional[str] = None, + ensure_absolute: bool = False, + ) -> Path: + """ + Retrieves a configuration value expected to be a path, returning it as a Path object. + If 'ensure_absolute' is True, the path is resolved to an absolute path. + """ + value_from_config = self.get(key_path) + + path_obj_to_return: Path + if isinstance(value_from_config, Path): + path_obj_to_return = value_from_config + elif isinstance(value_from_config, str): # Convert string from config to Path. + path_obj_to_return = Path(value_from_config) + elif default_str_path is not None: # Fallback to provided string default. + logger.debug( + f"Config Path '{key_path}' not found or invalid type, using provided default string path: '{default_str_path}'" + ) + path_obj_to_return = Path(default_str_path) + else: # Ultimate fallback if no value and no default. + logger.error( + f"Config Path '{key_path}' not found or invalid type, and no default provided. Returning Path('.')" + ) + path_obj_to_return = Path(".") # Current directory. + + return path_obj_to_return.resolve() if ensure_absolute else path_obj_to_return + + +# --- Singleton Instance --- +# This provides a single, globally accessible instance of the configuration manager. +config_manager = YamlConfigManager() + +# --- Convenience Accessor Functions --- +# These functions provide easy, module-level access to common configuration settings +# using the singleton 'config_manager' instance. 
+ + +def _get_default_from_structure(key_path: str) -> Any: + """Internal helper to retrieve a default value directly from the DEFAULT_CONFIG structure.""" + keys = key_path.split(".") + return _get_nested_value(DEFAULT_CONFIG, keys) + + +# Server Settings Accessors +def get_host() -> str: + """Returns the server host address.""" + return config_manager.get_string( + "server.host", _get_default_from_structure("server.host") + ) + + +def get_port() -> int: + """Returns the server port number.""" + return config_manager.get_int( + "server.port", _get_default_from_structure("server.port") + ) + + +# Audio Output Settings Accessors +def get_audio_output_format() -> str: + """Returns the default audio output format (e.g., 'wav').""" + return config_manager.get_string( + "audio_output.format", _get_default_from_structure("audio_output.format") + ) + + +def get_log_file_path() -> Path: + """Returns the absolute path to the server log file.""" + default_path_str = str(_get_default_from_structure("server.log_file_path")) + return config_manager.get_path( + "server.log_file_path", default_path_str, ensure_absolute=True + ) + + +# Model Settings Accessors +def get_model_repo_id() -> str: + """Returns the Hugging Face repository ID for the model.""" + return config_manager.get_string( + "model.repo_id", _get_default_from_structure("model.repo_id") + ) + + +# TTS Engine Settings Accessors +def get_tts_device() -> str: + """Returns the resolved TTS processing device ('cuda' or 'cpu').""" + # Device is resolved during load_config, so direct get is appropriate. 
+ return config_manager.get_string( + "tts_engine.device", _get_default_from_structure("tts_engine.device") + ) + + +# General Path Settings Accessors +def get_model_cache_path(ensure_absolute: bool = True) -> Path: + """Returns the path to the model cache directory.""" + default_path_str = str(_get_default_from_structure("paths.model_cache")) + return config_manager.get_path( + "paths.model_cache", default_path_str, ensure_absolute=ensure_absolute + ) + + +def get_output_path(ensure_absolute: bool = True) -> Path: + """Returns the path to the default output directory.""" + default_path_str = str(_get_default_from_structure("paths.output")) + return config_manager.get_path( + "paths.output", default_path_str, ensure_absolute=ensure_absolute + ) + + +# Default Generation Parameter Accessors +def get_gen_default_speed() -> float: + """Returns the default speed for TTS generation.""" + return config_manager.get_float( + "generation_defaults.speed", + _get_default_from_structure("generation_defaults.speed"), + ) + + +def get_gen_default_language() -> str: + """Returns the default language for TTS generation.""" + return config_manager.get_string( + "generation_defaults.language", + _get_default_from_structure("generation_defaults.language"), + ) + + +# Audio Output Settings Accessors +def get_audio_sample_rate() -> int: + """Returns the default audio sample rate.""" + return config_manager.get_int( + "audio_output.sample_rate", + _get_default_from_structure("audio_output.sample_rate"), + ) + + +# UI State Accessors +def get_ui_state() -> Dict[str, Any]: + """Returns the entire UI state dictionary (for UI persistence).""" + return config_manager.get( + "ui_state", deepcopy(_get_default_from_structure("ui_state")) + ) + + +# General UI Settings Accessors +def get_ui_title() -> str: + """Returns the title for the web UI.""" + return config_manager.get_string( + "ui.title", _get_default_from_structure("ui.title") + ) + + +def get_full_config_for_template() -> Dict[str, 
Any]: + """ + Returns a deep copy of the current configuration, with Path objects + converted to strings. This is suitable for serialization (e.g., JSON) + or for passing to web templates or API responses. + """ + config_snapshot = config_manager.get_all() # Gets a deep copy. + # Convert Path objects in this snapshot to strings for serialization. + return config_manager._prepare_config_for_saving(config_snapshot) + + +# --- End File: config.py --- diff --git a/config.yaml b/config.yaml index 9f595a5..7b64840 100644 --- a/config.yaml +++ b/config.yaml @@ -1,38 +1,38 @@ -server: - host: 0.0.0.0 - port: 8005 - use_ngrok: false - use_auth: false - auth_username: user - auth_password: password - log_file_path: logs\tts_server.log - log_file_max_size_mb: 10 - log_file_backup_count: 5 -model: - repo_id: KittenML/kitten-tts-mini-0.8 -tts_engine: - device: cuda -paths: - model_cache: model_cache - output: outputs -generation_defaults: - speed: 1 - language: en - speed_factor: 1.1 -audio_output: - format: ogg - sample_rate: 24000 -ui_state: - last_text: YOu dont have any new unread emails. - last_voice: expr-voice-2-m - last_chunk_size: 200 - last_split_text_enabled: true - hide_chunk_warning: false - hide_generation_warning: true - theme: dark -ui: - title: Kitten TTS Server - show_language_select: true - max_predefined_voices_in_dropdown: 50 -debug: - save_intermediate_audio: false +server: + host: 0.0.0.0 + port: 8005 + use_ngrok: false + use_auth: false + auth_username: user + auth_password: password + log_file_path: logs\tts_server.log + log_file_max_size_mb: 10 + log_file_backup_count: 5 +model: + repo_id: KittenML/kitten-tts-mini-0.8 +tts_engine: + device: cuda +paths: + model_cache: model_cache + output: outputs +generation_defaults: + speed: 1 + language: en + speed_factor: 1.1 +audio_output: + format: ogg + sample_rate: 24000 +ui_state: + last_text: YOu dont have any new unread emails. 
+ last_voice: expr-voice-2-m + last_chunk_size: 200 + last_split_text_enabled: true + hide_chunk_warning: false + hide_generation_warning: true + theme: dark +ui: + title: Kitten TTS Server + show_language_select: true + max_predefined_voices_in_dropdown: 50 +debug: + save_intermediate_audio: false diff --git a/docker-compose-cpu.yml b/docker-compose-cpu.yml index 3e63221..2543434 100644 --- a/docker-compose-cpu.yml +++ b/docker-compose-cpu.yml @@ -1,29 +1,29 @@ -version: '3.8' - -services: - kitten-tts-server: - build: - context: . - dockerfile: Dockerfile - args: - # This build argument ensures only CPU dependencies are installed - - RUNTIME=cpu - ports: - - "${PORT:-8005}:8005" - volumes: - # Mount local config file for persistence - - ./config.yaml:/app/config.yaml - # Mount local directories for persistent app data - - ./outputs:/app/outputs - - ./logs:/app/logs - # Named volume for Hugging Face model cache to persist across container rebuilds - - hf_cache:/app/hf_cache - - restart: unless-stopped - environment: - # Enable faster Hugging Face downloads inside the container - - HF_HUB_ENABLE_HF_TRANSFER=1 - -# Define the named volume for the Hugging Face cache -volumes: +version: '3.8' + +services: + kitten-tts-server: + build: + context: . 
+ dockerfile: Dockerfile + args: + # This build argument ensures only CPU dependencies are installed + - RUNTIME=cpu + ports: + - "${PORT:-8005}:8005" + volumes: + # Mount local config file for persistence + - ./config.yaml:/app/config.yaml + # Mount local directories for persistent app data + - ./outputs:/app/outputs + - ./logs:/app/logs + # Named volume for Hugging Face model cache to persist across container rebuilds + - hf_cache:/app/hf_cache + + restart: unless-stopped + environment: + # Enable faster Hugging Face downloads inside the container + - HF_HUB_ENABLE_HF_TRANSFER=1 + +# Define the named volume for the Hugging Face cache +volumes: hf_cache: \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 9cd9735..c088a5e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,49 +1,49 @@ -version: '3.8' - -services: - kitten-tts-server: - build: - args: - # Can be nvidia or cpu; Default is Nvidia - - RUNTIME=nvidia - context: . - dockerfile: Dockerfile - ports: - - "${PORT:-8005}:8005" - volumes: - # Mount local config file for persistence - - ./config.yaml:/app/config.yaml - # Mount local directories for persistent app data - - ./outputs:/app/outputs - - ./logs:/app/logs - # Named volume for Hugging Face model cache to persist across container rebuilds - - hf_cache:/app/hf_cache - - # --- GPU Support (NVIDIA) --- - # The 'deploy' key is the modern way to request GPU resources. - # If you get a 'CDI device injection failed' error, comment out the 'deploy' section - # and uncomment the 'runtime: nvidia' line below. 
- - # Method 1: Modern Docker Compose (Recommended) - deploy: - resources: - reservations: - devices: - - driver: nvidia - count: 1 - capabilities: [gpu] - - # Method 2: Legacy Docker Compose (for older setups) - # runtime: nvidia - - restart: unless-stopped - environment: - # Enable faster Hugging Face downloads inside the container - - HF_HUB_ENABLE_HF_TRANSFER=1 - # Make NVIDIA GPUs visible and specify capabilities for PyTorch - - NVIDIA_VISIBLE_DEVICES=all - - NVIDIA_DRIVER_CAPABILITIES=compute,utility - -# Define the named volume for the Hugging Face cache -volumes: +version: '3.8' + +services: + kitten-tts-server: + build: + args: + # Can be nvidia or cpu; Default is Nvidia + - RUNTIME=nvidia + context: . + dockerfile: Dockerfile + ports: + - "${PORT:-8005}:8005" + volumes: + # Mount local config file for persistence + - ./config.yaml:/app/config.yaml + # Mount local directories for persistent app data + - ./outputs:/app/outputs + - ./logs:/app/logs + # Named volume for Hugging Face model cache to persist across container rebuilds + - hf_cache:/app/hf_cache + + # --- GPU Support (NVIDIA) --- + # The 'deploy' key is the modern way to request GPU resources. + # If you get a 'CDI device injection failed' error, comment out the 'deploy' section + # and uncomment the 'runtime: nvidia' line below. 
+ + # Method 1: Modern Docker Compose (Recommended) + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + + # Method 2: Legacy Docker Compose (for older setups) + # runtime: nvidia + + restart: unless-stopped + environment: + # Enable faster Hugging Face downloads inside the container + - HF_HUB_ENABLE_HF_TRANSFER=1 + # Make NVIDIA GPUs visible and specify capabilities for PyTorch + - NVIDIA_VISIBLE_DEVICES=all + - NVIDIA_DRIVER_CAPABILITIES=compute,utility + +# Define the named volume for the Hugging Face cache +volumes: hf_cache: \ No newline at end of file diff --git a/engine.py b/engine.py index e4ef6e2..ea4fb7e 100644 --- a/engine.py +++ b/engine.py @@ -1,450 +1,450 @@ -# File: engine.py -# Core TTS model loading and speech generation logic for KittenTTS ONNX. - -import torch -import os -import logging -import numpy as np -import onnxruntime as ort -from typing import Optional, Tuple -from pathlib import Path -from huggingface_hub import hf_hub_download -import phonemizer - -# This loader can be problematic on Linux, we will bypass it with system-installed eSpeak. -# We still import it as it's a dependency, but we will avoid calling it directly where possible. 
-import espeakng_loader - -# Import the singleton config_manager -from config import config_manager - -logger = logging.getLogger(__name__) - -# --- Global Module Variables --- -onnx_session: Optional[ort.InferenceSession] = None -voices_data: Optional[dict] = None -phonemizer_backend: Optional[phonemizer.backend.EspeakBackend] = None -text_cleaner: Optional["TextCleaner"] = None -MODEL_LOADED: bool = False -voice_aliases: dict = {} -speed_priors: dict = {} - -# KittenTTS available voices (populated dynamically after model load) -KITTEN_TTS_VOICES = [ - "expr-voice-2-m", - "expr-voice-2-f", - "expr-voice-3-m", - "expr-voice-3-f", - "expr-voice-4-m", - "expr-voice-4-f", - "expr-voice-5-m", - "expr-voice-5-f", -] - - -class TextCleaner: - """Text cleaner for KittenTTS - converts text to token indices.""" - - def __init__(self): - _pad = "$" - _punctuation = ';:,.!?¡¿—…"«»"" ' - _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" - _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" - - symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) - - self.word_index_dictionary = {} - for i in range(len(symbols)): - self.word_index_dictionary[symbols[i]] = i - - def __call__(self, text: str): - indexes = [] - for char in text: - try: - indexes.append(self.word_index_dictionary[char]) - except KeyError: - pass - return indexes - - -def basic_english_tokenize(text: str): - """Basic English tokenizer that splits on whitespace and punctuation.""" - import re - - tokens = re.findall(r"\w+|[^\w\s]", text) - return tokens - - -def load_model() -> bool: - """ - Loads the KittenTTS model from Hugging Face Hub and initializes ONNX session. - Updates global variables for model components. - - Returns: - bool: True if the model was loaded successfully, False otherwise. 
- """ - global onnx_session, voices_data, phonemizer_backend, text_cleaner, MODEL_LOADED, voice_aliases, speed_priors, KITTEN_TTS_VOICES - - if MODEL_LOADED: - logger.info("KittenTTS model is already loaded.") - return True - - try: - # Get model repository and cache path from config - model_repo_id = config_manager.get_string( - "model.repo_id", "KittenML/kitten-tts-nano-0.1" - ) - model_cache_path = config_manager.get_path( - "paths.model_cache", "./model_cache", ensure_absolute=True - ) - - logger.info(f"Loading KittenTTS model from: {model_repo_id}") - logger.info(f"Using cache directory: {model_cache_path}") - - # Ensure cache directory exists - model_cache_path.mkdir(parents=True, exist_ok=True) - - # Download config.json first - config_path = hf_hub_download( - repo_id=model_repo_id, - filename="config.json", - cache_dir=str(model_cache_path), - ) - - # Load config to get model filenames - import json - - with open(config_path, "r") as f: - model_config = json.load(f) - - supported_types = {"ONNX1", "ONNX2"} - model_type = model_config.get("type") - if model_type not in supported_types: - raise ValueError(f"Unsupported model type: '{model_type}'. 
Expected one of: {supported_types}") - - # Download model and voices files - model_path = hf_hub_download( - repo_id=model_repo_id, - filename=model_config["model_file"], - cache_dir=str(model_cache_path), - ) - - voices_path = hf_hub_download( - repo_id=model_repo_id, - filename=model_config["voices"], - cache_dir=str(model_cache_path), - ) - - # Load voices data - voices_data = np.load(voices_path) - logger.info(f"Loaded voices data with keys: {list(voices_data.keys())}") - - # Parse ONNX2 config fields - voice_aliases = model_config.get("voice_aliases", {}) - speed_priors = model_config.get("speed_priors", {}) - if voice_aliases: - logger.info(f"Loaded voice aliases: {voice_aliases}") - if speed_priors: - logger.info(f"Loaded speed priors: {speed_priors}") - - # Build available voices list from loaded voice data + aliases - KITTEN_TTS_VOICES = list(voices_data.keys()) + list(voice_aliases.keys()) - logger.info(f"Available voices: {KITTEN_TTS_VOICES}") - - # Determine device and providers and configure for optimal performance - device_setting = config_manager.get_string("tts_engine.device", "auto").lower() - available_providers = ort.get_available_providers() - logger.info(f"Available ONNX Runtime providers: {available_providers}") - - sess_options = ort.SessionOptions() - providers = [] - provider_options = [] - - # A boolean flag to check if we should attempt to use the GPU - attempt_gpu = device_setting in ["auto", "cuda", "gpu"] - is_gpu_available = "CUDAExecutionProvider" in available_providers - - # The primary condition: attempt to use GPU and check if it's available - if attempt_gpu and is_gpu_available: - logger.info( - f"'{device_setting}' mode selected and CUDAExecutionProvider is available." 
- ) - logger.info("Configuring CUDAExecutionProvider for optimal performance.") - providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] - provider_options = [ - { - "device_id": "0", - }, - {}, - ] - else: - # Fallback to CPU for all other cases - if device_setting in ["cuda", "gpu"] and not is_gpu_available: - logger.warning( - f"Configuration explicitly requests GPU ('{device_setting}'), but CUDAExecutionProvider is NOT available." - ) - logger.warning( - "Please ensure NVIDIA drivers and the correct dependencies are installed." - ) - - logger.info("Defaulting to CPUExecutionProvider.") - providers = ["CPUExecutionProvider"] - - # Initialize the ONNX Inference Session with the chosen providers and options - logger.info( - f"Initializing ONNX InferenceSession from {model_path} with providers: {providers}" - ) - - # Only pass provider_options if the GPU provider is being used - if "CUDAExecutionProvider" in providers: - onnx_session = ort.InferenceSession( - str(model_path), - sess_options, - providers=providers, - provider_options=provider_options, - ) - else: - # For CPU-only, do not pass the provider_options argument - onnx_session = ort.InferenceSession( - str(model_path), - sess_options, - providers=providers, - ) - - # --- Cross-Platform eSpeak Configuration --- - # This block ensures that on both Windows and Linux, the correct eSpeak library - # and data files are found and configured, bypassing potential issues with loaders. 
- - # Auto-configure eSpeak for Windows - if os.name == "nt": # Windows - logger.info("Checking for eSpeak NG on Windows...") - possible_paths = [ - Path(r"C:\Program Files\eSpeak NG"), - Path(r"C:\Program Files (x86)\eSpeak NG"), - Path(r"C:\eSpeak NG"), - Path(os.environ.get("ProgramFiles", "")) / "eSpeak NG", - Path(os.environ.get("ProgramFiles(x86)", "")) / "eSpeak NG", - ] - espeak_found = False - for espeak_path in possible_paths: - if espeak_path.exists(): - dll_path = espeak_path / "libespeak-ng.dll" - if dll_path.exists(): - os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = str(dll_path) - from phonemizer.backend.espeak.wrapper import ( - EspeakWrapper as PhonemizeEspeakWrapper, - ) - - PhonemizeEspeakWrapper.set_library(str(dll_path)) - logger.info(f"Auto-configured eSpeak from: {espeak_path}") - espeak_found = True - break - if not espeak_found: - logger.warning("eSpeak NG not found in common Windows locations.") - - # Auto-configure eSpeak for Linux by finding the system-installed library - elif os.name == "posix": # Linux/macOS - logger.info("Checking for system-installed eSpeak NG on Linux...") - # By setting the library path, we let phonemizer handle finding the data path, which is more robust. - espeak_lib_path = "/usr/lib/x86_64-linux-gnu/libespeak-ng.so" - if Path(espeak_lib_path).exists(): - os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = espeak_lib_path - logger.info( - f"Found and configured system eSpeak NG library: {espeak_lib_path}" - ) - else: - logger.warning( - f"Could not find system eSpeak NG library at {espeak_lib_path}. " - "Please ensure 'espeak-ng' is installed via your package manager." 
- ) - - # Initialize phonemizer with better error handling - try: - # Suppress phonemizer warnings during initialization - import logging as log_module - - phonemizer_logger = log_module.getLogger("phonemizer") - original_level = phonemizer_logger.level - phonemizer_logger.setLevel(log_module.ERROR) - - phonemizer_backend = phonemizer.backend.EspeakBackend( - language="en-us", preserve_punctuation=True, with_stress=True - ) - - phonemizer_logger.setLevel(original_level) - logger.info("Phonemizer backend initialized successfully") - except Exception as e: - logger.error(f"Failed to initialize phonemizer: {e}") - logger.error( - "Please ensure eSpeak NG is installed:\n" - " Windows: Download from https://github.com/espeak-ng/espeak-ng/releases\n" - " Linux: Run 'sudo apt install espeak-ng'" - ) - raise - - # Initialize text cleaner - text_cleaner = TextCleaner() - - MODEL_LOADED = True - logger.info("KittenTTS model loaded successfully.") - return True - - except Exception as e: - logger.error(f"Error loading KittenTTS model: {e}", exc_info=True) - onnx_session = None - voices_data = None - phonemizer_backend = None - text_cleaner = None - voice_aliases = {} - speed_priors = {} - MODEL_LOADED = False - return False - - -def _get_voice_embedding(voice: str, text_length: int) -> np.ndarray: - """Get voice embedding, handling both ONNX1 (1D) and ONNX2 (2D) formats.""" - embedding = voices_data[voice] - if embedding.ndim == 1: - # ONNX1: single embedding, just add batch dim - return np.expand_dims(embedding, axis=0).astype(np.float32) - else: - # ONNX2: multiple reference embeddings, select one based on text length - ref_id = min(text_length, embedding.shape[0] - 1) - return embedding[ref_id:ref_id+1].astype(np.float32) - - -def synthesize( - text: str, voice: str, speed: float = 1.0 -) -> Tuple[Optional[np.ndarray], Optional[int]]: - """ - Synthesizes audio from text using the loaded KittenTTS model. - - Args: - text: The text to synthesize. 
- voice: Voice identifier (e.g., 'expr-voice-5-m'). - speed: Speech speed factor (1.0 is normal speed). - - Returns: - A tuple containing the audio waveform (numpy array) and the sample rate (int), - or (None, None) if synthesis fails. - """ - global onnx_session, voices_data, phonemizer_backend, text_cleaner - - if not MODEL_LOADED or onnx_session is None: - logger.error("KittenTTS model is not loaded. Cannot synthesize audio.") - return None, None - - if voice not in KITTEN_TTS_VOICES: - logger.error( - f"Voice '{voice}' not available. Available voices: {KITTEN_TTS_VOICES}" - ) - return None, None - - # Resolve voice alias to internal voice ID - resolved_voice = voice_aliases.get(voice, voice) - if resolved_voice != voice: - logger.debug(f"Resolved voice alias '{voice}' -> '{resolved_voice}'") - - # Apply speed prior for this voice if available - prior = speed_priors.get(resolved_voice, 1.0) - if prior != 1.0: - speed = speed * prior - logger.debug(f"Applied speed prior {prior} for voice '{resolved_voice}', effective speed: {speed}") - - voice = resolved_voice - - try: - logger.debug(f"Synthesizing with voice='{voice}', speed={speed}") - logger.debug(f"Input text (first 100 chars): '{text[:100]}...'") - - # Phonemize the input text - # Suppress the word count mismatch warning by temporarily adjusting log level - import logging as log_module - - phonemizer_logger = log_module.getLogger("phonemizer") - original_level = phonemizer_logger.level - phonemizer_logger.setLevel(log_module.ERROR) - - phonemes_list = phonemizer_backend.phonemize([text]) - - # Restore original log level - phonemizer_logger.setLevel(original_level) - - # Process phonemes to get token IDs - phonemes = basic_english_tokenize(phonemes_list[0]) - phonemes = " ".join(phonemes) - tokens = text_cleaner(phonemes) - - # Add start and end tokens - tokens.insert(0, 0) - tokens.append(0) - - # Determine the execution device from the session to decide where to place tensors - provider = 
onnx_session.get_providers()[0] - - if provider == "CUDAExecutionProvider": - # --- I/O Binding Path for GPU using NumPy --- - # Create standard NumPy arrays on the CPU first. - input_ids_np = np.array([tokens], dtype=np.int64) - ref_s_np = _get_voice_embedding(voice, len(tokens)) - speed_array_np = np.array([speed], dtype=np.float32) - - # Create OrtValues from the NumPy arrays. I/O binding will handle the copy to GPU. - input_ids_ort = ort.OrtValue.ortvalue_from_numpy(input_ids_np, "cuda", 0) - ref_s_ort = ort.OrtValue.ortvalue_from_numpy(ref_s_np, "cuda", 0) - speed_array_ort = ort.OrtValue.ortvalue_from_numpy( - speed_array_np, "cuda", 0 - ) - - # Set up I/O binding - io_binding = onnx_session.io_binding() - - # Bind the OrtValue inputs - io_binding.bind_ortvalue_input("input_ids", input_ids_ort) - io_binding.bind_ortvalue_input("style", ref_s_ort) - io_binding.bind_ortvalue_input("speed", speed_array_ort) - - # Get the actual name of the first output from the loaded model - output_name = onnx_session.get_outputs()[0].name - - # Bind the output to the GPU using the correct name - io_binding.bind_output(output_name, "cuda") - - # Run inference with binding - onnx_session.run_with_iobinding(io_binding) - - # Get the output from the binding - output_ortvalue = io_binding.get_outputs()[0] - - # The output is on the GPU. Copy it back to the CPU to be used by the rest of the app. 
- audio = output_ortvalue.numpy() - - else: - # --- Standard Path for CPU --- - input_ids = np.array([tokens], dtype=np.int64) - ref_s = _get_voice_embedding(voice, len(tokens)) - speed_array = np.array([speed], dtype=np.float32) - - onnx_inputs = { - "input_ids": input_ids, - "style": ref_s, - "speed": speed_array, - } - # Run standard inference - outputs = onnx_session.run(None, onnx_inputs) - audio = outputs[0] - - # KittenTTS uses 24kHz sample rate - sample_rate = 24000 - - logger.info( - f"Successfully generated {len(audio)} audio samples at {sample_rate}Hz" - ) - return audio, sample_rate - - except Exception as e: - logger.error(f"Error during KittenTTS synthesis: {e}", exc_info=True) - return None, None - - -# --- End File: engine.py --- +# File: engine.py +# Core TTS model loading and speech generation logic for KittenTTS ONNX. + +import torch +import os +import logging +import numpy as np +import onnxruntime as ort +from typing import Optional, Tuple +from pathlib import Path +from huggingface_hub import hf_hub_download +import phonemizer + +# This loader can be problematic on Linux, we will bypass it with system-installed eSpeak. +# We still import it as it's a dependency, but we will avoid calling it directly where possible. 
+import espeakng_loader + +# Import the singleton config_manager +from config import config_manager + +logger = logging.getLogger(__name__) + +# --- Global Module Variables --- +onnx_session: Optional[ort.InferenceSession] = None +voices_data: Optional[dict] = None +phonemizer_backend: Optional[phonemizer.backend.EspeakBackend] = None +text_cleaner: Optional["TextCleaner"] = None +MODEL_LOADED: bool = False +voice_aliases: dict = {} +speed_priors: dict = {} + +# KittenTTS available voices (populated dynamically after model load) +KITTEN_TTS_VOICES = [ + "expr-voice-2-m", + "expr-voice-2-f", + "expr-voice-3-m", + "expr-voice-3-f", + "expr-voice-4-m", + "expr-voice-4-f", + "expr-voice-5-m", + "expr-voice-5-f", +] + + +class TextCleaner: + """Text cleaner for KittenTTS - converts text to token indices.""" + + def __init__(self): + _pad = "$" + _punctuation = ';:,.!?¡¿—…"«»"" ' + _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" + + symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + + self.word_index_dictionary = {} + for i in range(len(symbols)): + self.word_index_dictionary[symbols[i]] = i + + def __call__(self, text: str): + indexes = [] + for char in text: + try: + indexes.append(self.word_index_dictionary[char]) + except KeyError: + pass + return indexes + + +def basic_english_tokenize(text: str): + """Basic English tokenizer that splits on whitespace and punctuation.""" + import re + + tokens = re.findall(r"\w+|[^\w\s]", text) + return tokens + + +def load_model() -> bool: + """ + Loads the KittenTTS model from Hugging Face Hub and initializes ONNX session. + Updates global variables for model components. + + Returns: + bool: True if the model was loaded successfully, False otherwise. 
+ """ + global onnx_session, voices_data, phonemizer_backend, text_cleaner, MODEL_LOADED, voice_aliases, speed_priors, KITTEN_TTS_VOICES + + if MODEL_LOADED: + logger.info("KittenTTS model is already loaded.") + return True + + try: + # Get model repository and cache path from config + model_repo_id = config_manager.get_string( + "model.repo_id", "KittenML/kitten-tts-nano-0.1" + ) + model_cache_path = config_manager.get_path( + "paths.model_cache", "./model_cache", ensure_absolute=True + ) + + logger.info(f"Loading KittenTTS model from: {model_repo_id}") + logger.info(f"Using cache directory: {model_cache_path}") + + # Ensure cache directory exists + model_cache_path.mkdir(parents=True, exist_ok=True) + + # Download config.json first + config_path = hf_hub_download( + repo_id=model_repo_id, + filename="config.json", + cache_dir=str(model_cache_path), + ) + + # Load config to get model filenames + import json + + with open(config_path, "r") as f: + model_config = json.load(f) + + supported_types = {"ONNX1", "ONNX2"} + model_type = model_config.get("type") + if model_type not in supported_types: + raise ValueError(f"Unsupported model type: '{model_type}'. 
Expected one of: {supported_types}") + + # Download model and voices files + model_path = hf_hub_download( + repo_id=model_repo_id, + filename=model_config["model_file"], + cache_dir=str(model_cache_path), + ) + + voices_path = hf_hub_download( + repo_id=model_repo_id, + filename=model_config["voices"], + cache_dir=str(model_cache_path), + ) + + # Load voices data + voices_data = np.load(voices_path) + logger.info(f"Loaded voices data with keys: {list(voices_data.keys())}") + + # Parse ONNX2 config fields + voice_aliases = model_config.get("voice_aliases", {}) + speed_priors = model_config.get("speed_priors", {}) + if voice_aliases: + logger.info(f"Loaded voice aliases: {voice_aliases}") + if speed_priors: + logger.info(f"Loaded speed priors: {speed_priors}") + + # Build available voices list from loaded voice data + aliases + KITTEN_TTS_VOICES = list(voices_data.keys()) + list(voice_aliases.keys()) + logger.info(f"Available voices: {KITTEN_TTS_VOICES}") + + # Determine device and providers and configure for optimal performance + device_setting = config_manager.get_string("tts_engine.device", "auto").lower() + available_providers = ort.get_available_providers() + logger.info(f"Available ONNX Runtime providers: {available_providers}") + + sess_options = ort.SessionOptions() + providers = [] + provider_options = [] + + # A boolean flag to check if we should attempt to use the GPU + attempt_gpu = device_setting in ["auto", "cuda", "gpu"] + is_gpu_available = "CUDAExecutionProvider" in available_providers + + # The primary condition: attempt to use GPU and check if it's available + if attempt_gpu and is_gpu_available: + logger.info( + f"'{device_setting}' mode selected and CUDAExecutionProvider is available." 
+ ) + logger.info("Configuring CUDAExecutionProvider for optimal performance.") + providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + provider_options = [ + { + "device_id": "0", + }, + {}, + ] + else: + # Fallback to CPU for all other cases + if device_setting in ["cuda", "gpu"] and not is_gpu_available: + logger.warning( + f"Configuration explicitly requests GPU ('{device_setting}'), but CUDAExecutionProvider is NOT available." + ) + logger.warning( + "Please ensure NVIDIA drivers and the correct dependencies are installed." + ) + + logger.info("Defaulting to CPUExecutionProvider.") + providers = ["CPUExecutionProvider"] + + # Initialize the ONNX Inference Session with the chosen providers and options + logger.info( + f"Initializing ONNX InferenceSession from {model_path} with providers: {providers}" + ) + + # Only pass provider_options if the GPU provider is being used + if "CUDAExecutionProvider" in providers: + onnx_session = ort.InferenceSession( + str(model_path), + sess_options, + providers=providers, + provider_options=provider_options, + ) + else: + # For CPU-only, do not pass the provider_options argument + onnx_session = ort.InferenceSession( + str(model_path), + sess_options, + providers=providers, + ) + + # --- Cross-Platform eSpeak Configuration --- + # This block ensures that on both Windows and Linux, the correct eSpeak library + # and data files are found and configured, bypassing potential issues with loaders. 
+ + # Auto-configure eSpeak for Windows + if os.name == "nt": # Windows + logger.info("Checking for eSpeak NG on Windows...") + possible_paths = [ + Path(r"C:\Program Files\eSpeak NG"), + Path(r"C:\Program Files (x86)\eSpeak NG"), + Path(r"C:\eSpeak NG"), + Path(os.environ.get("ProgramFiles", "")) / "eSpeak NG", + Path(os.environ.get("ProgramFiles(x86)", "")) / "eSpeak NG", + ] + espeak_found = False + for espeak_path in possible_paths: + if espeak_path.exists(): + dll_path = espeak_path / "libespeak-ng.dll" + if dll_path.exists(): + os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = str(dll_path) + from phonemizer.backend.espeak.wrapper import ( + EspeakWrapper as PhonemizeEspeakWrapper, + ) + + PhonemizeEspeakWrapper.set_library(str(dll_path)) + logger.info(f"Auto-configured eSpeak from: {espeak_path}") + espeak_found = True + break + if not espeak_found: + logger.warning("eSpeak NG not found in common Windows locations.") + + # Auto-configure eSpeak for Linux by finding the system-installed library + elif os.name == "posix": # Linux/macOS + logger.info("Checking for system-installed eSpeak NG on Linux...") + # By setting the library path, we let phonemizer handle finding the data path, which is more robust. + espeak_lib_path = "/usr/lib/x86_64-linux-gnu/libespeak-ng.so" + if Path(espeak_lib_path).exists(): + os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = espeak_lib_path + logger.info( + f"Found and configured system eSpeak NG library: {espeak_lib_path}" + ) + else: + logger.warning( + f"Could not find system eSpeak NG library at {espeak_lib_path}. " + "Please ensure 'espeak-ng' is installed via your package manager." 
+ ) + + # Initialize phonemizer with better error handling + try: + # Suppress phonemizer warnings during initialization + import logging as log_module + + phonemizer_logger = log_module.getLogger("phonemizer") + original_level = phonemizer_logger.level + phonemizer_logger.setLevel(log_module.ERROR) + + phonemizer_backend = phonemizer.backend.EspeakBackend( + language="en-us", preserve_punctuation=True, with_stress=True + ) + + phonemizer_logger.setLevel(original_level) + logger.info("Phonemizer backend initialized successfully") + except Exception as e: + logger.error(f"Failed to initialize phonemizer: {e}") + logger.error( + "Please ensure eSpeak NG is installed:\n" + " Windows: Download from https://github.com/espeak-ng/espeak-ng/releases\n" + " Linux: Run 'sudo apt install espeak-ng'" + ) + raise + + # Initialize text cleaner + text_cleaner = TextCleaner() + + MODEL_LOADED = True + logger.info("KittenTTS model loaded successfully.") + return True + + except Exception as e: + logger.error(f"Error loading KittenTTS model: {e}", exc_info=True) + onnx_session = None + voices_data = None + phonemizer_backend = None + text_cleaner = None + voice_aliases = {} + speed_priors = {} + MODEL_LOADED = False + return False + + +def _get_voice_embedding(voice: str, text_length: int) -> np.ndarray: + """Get voice embedding, handling both ONNX1 (1D) and ONNX2 (2D) formats.""" + embedding = voices_data[voice] + if embedding.ndim == 1: + # ONNX1: single embedding, just add batch dim + return np.expand_dims(embedding, axis=0).astype(np.float32) + else: + # ONNX2: multiple reference embeddings, select one based on text length + ref_id = min(text_length, embedding.shape[0] - 1) + return embedding[ref_id:ref_id+1].astype(np.float32) + + +def synthesize( + text: str, voice: str, speed: float = 1.0 +) -> Tuple[Optional[np.ndarray], Optional[int]]: + """ + Synthesizes audio from text using the loaded KittenTTS model. + + Args: + text: The text to synthesize. 
+ voice: Voice identifier (e.g., 'expr-voice-5-m'). + speed: Speech speed factor (1.0 is normal speed). + + Returns: + A tuple containing the audio waveform (numpy array) and the sample rate (int), + or (None, None) if synthesis fails. + """ + global onnx_session, voices_data, phonemizer_backend, text_cleaner + + if not MODEL_LOADED or onnx_session is None: + logger.error("KittenTTS model is not loaded. Cannot synthesize audio.") + return None, None + + if voice not in KITTEN_TTS_VOICES: + logger.error( + f"Voice '{voice}' not available. Available voices: {KITTEN_TTS_VOICES}" + ) + return None, None + + # Resolve voice alias to internal voice ID + resolved_voice = voice_aliases.get(voice, voice) + if resolved_voice != voice: + logger.debug(f"Resolved voice alias '{voice}' -> '{resolved_voice}'") + + # Apply speed prior for this voice if available + prior = speed_priors.get(resolved_voice, 1.0) + if prior != 1.0: + speed = speed * prior + logger.debug(f"Applied speed prior {prior} for voice '{resolved_voice}', effective speed: {speed}") + + voice = resolved_voice + + try: + logger.debug(f"Synthesizing with voice='{voice}', speed={speed}") + logger.debug(f"Input text (first 100 chars): '{text[:100]}...'") + + # Phonemize the input text + # Suppress the word count mismatch warning by temporarily adjusting log level + import logging as log_module + + phonemizer_logger = log_module.getLogger("phonemizer") + original_level = phonemizer_logger.level + phonemizer_logger.setLevel(log_module.ERROR) + + phonemes_list = phonemizer_backend.phonemize([text]) + + # Restore original log level + phonemizer_logger.setLevel(original_level) + + # Process phonemes to get token IDs + phonemes = basic_english_tokenize(phonemes_list[0]) + phonemes = " ".join(phonemes) + tokens = text_cleaner(phonemes) + + # Add start and end tokens + tokens.insert(0, 0) + tokens.append(0) + + # Determine the execution device from the session to decide where to place tensors + provider = 
onnx_session.get_providers()[0] + + if provider == "CUDAExecutionProvider": + # --- I/O Binding Path for GPU using NumPy --- + # Create standard NumPy arrays on the CPU first. + input_ids_np = np.array([tokens], dtype=np.int64) + ref_s_np = _get_voice_embedding(voice, len(tokens)) + speed_array_np = np.array([speed], dtype=np.float32) + + # Create OrtValues from the NumPy arrays. I/O binding will handle the copy to GPU. + input_ids_ort = ort.OrtValue.ortvalue_from_numpy(input_ids_np, "cuda", 0) + ref_s_ort = ort.OrtValue.ortvalue_from_numpy(ref_s_np, "cuda", 0) + speed_array_ort = ort.OrtValue.ortvalue_from_numpy( + speed_array_np, "cuda", 0 + ) + + # Set up I/O binding + io_binding = onnx_session.io_binding() + + # Bind the OrtValue inputs + io_binding.bind_ortvalue_input("input_ids", input_ids_ort) + io_binding.bind_ortvalue_input("style", ref_s_ort) + io_binding.bind_ortvalue_input("speed", speed_array_ort) + + # Get the actual name of the first output from the loaded model + output_name = onnx_session.get_outputs()[0].name + + # Bind the output to the GPU using the correct name + io_binding.bind_output(output_name, "cuda") + + # Run inference with binding + onnx_session.run_with_iobinding(io_binding) + + # Get the output from the binding + output_ortvalue = io_binding.get_outputs()[0] + + # The output is on the GPU. Copy it back to the CPU to be used by the rest of the app. 
+ audio = output_ortvalue.numpy() + + else: + # --- Standard Path for CPU --- + input_ids = np.array([tokens], dtype=np.int64) + ref_s = _get_voice_embedding(voice, len(tokens)) + speed_array = np.array([speed], dtype=np.float32) + + onnx_inputs = { + "input_ids": input_ids, + "style": ref_s, + "speed": speed_array, + } + # Run standard inference + outputs = onnx_session.run(None, onnx_inputs) + audio = outputs[0] + + # KittenTTS uses 24kHz sample rate + sample_rate = 24000 + + logger.info( + f"Successfully generated {len(audio)} audio samples at {sample_rate}Hz" + ) + return audio, sample_rate + + except Exception as e: + logger.error(f"Error during KittenTTS synthesis: {e}", exc_info=True) + return None, None + + +# --- End File: engine.py --- diff --git a/models.py b/models.py index a860666..1906266 100644 --- a/models.py +++ b/models.py @@ -1,72 +1,72 @@ -# File: models.py -# Pydantic models for API request and response validation. - -from typing import Optional, Literal -from pydantic import BaseModel, Field - - -class GenerationParams(BaseModel): - """Common parameters for TTS generation.""" - - speed: Optional[float] = Field( - None, - ge=0.25, - le=4.0, - description="Speed factor for the generated audio. 1.0 is normal speed.", - ) - language: Optional[str] = Field( - None, - description="Language of the text. (Primarily for UI, actual engine may infer)", - ) - - -class CustomTTSRequest(BaseModel): - """Request model for the custom /tts endpoint.""" - - text: str = Field(..., min_length=1, description="Text to be synthesized.") - - voice: str = Field( - ..., - description="Voice identifier (e.g., 'expr-voice-5-m'). Available voices: expr-voice-2-m, expr-voice-2-f, expr-voice-3-m, expr-voice-3-f, expr-voice-4-m, expr-voice-4-f, expr-voice-5-m, expr-voice-5-f", - ) - - output_format: Optional[Literal["wav", "opus", "mp3"]] = Field( - "wav", description="Desired audio output format." 
- ) - - split_text: Optional[bool] = Field( - True, - description="Whether to automatically split long text into chunks for processing.", - ) - chunk_size: Optional[int] = Field( - 120, - ge=50, - le=500, - description="Approximate target character length for text chunks when splitting is enabled (50-500).", - ) - - # Embed generation parameters directly - speed: Optional[float] = Field( - None, description="Overrides default speed if provided." - ) - language: Optional[str] = Field( - None, description="Overrides default language if provided." - ) - - -class ErrorResponse(BaseModel): - """Standard error response model for API errors.""" - - detail: str = Field(..., description="A human-readable explanation of the error.") - - -class UpdateStatusResponse(BaseModel): - """Response model for status updates, e.g., after saving settings.""" - - message: str = Field( - ..., description="A message describing the result of the operation." - ) - restart_needed: Optional[bool] = Field( - False, - description="Indicates if a server restart is recommended or required for changes to take full effect.", - ) +# File: models.py +# Pydantic models for API request and response validation. + +from typing import Optional, Literal +from pydantic import BaseModel, Field + + +class GenerationParams(BaseModel): + """Common parameters for TTS generation.""" + + speed: Optional[float] = Field( + None, + ge=0.25, + le=4.0, + description="Speed factor for the generated audio. 1.0 is normal speed.", + ) + language: Optional[str] = Field( + None, + description="Language of the text. (Primarily for UI, actual engine may infer)", + ) + + +class CustomTTSRequest(BaseModel): + """Request model for the custom /tts endpoint.""" + + text: str = Field(..., min_length=1, description="Text to be synthesized.") + + voice: str = Field( + ..., + description="Voice identifier (e.g., 'expr-voice-5-m'). 
Available voices: expr-voice-2-m, expr-voice-2-f, expr-voice-3-m, expr-voice-3-f, expr-voice-4-m, expr-voice-4-f, expr-voice-5-m, expr-voice-5-f", + ) + + output_format: Optional[Literal["wav", "opus", "mp3"]] = Field( + "wav", description="Desired audio output format." + ) + + split_text: Optional[bool] = Field( + True, + description="Whether to automatically split long text into chunks for processing.", + ) + chunk_size: Optional[int] = Field( + 120, + ge=50, + le=500, + description="Approximate target character length for text chunks when splitting is enabled (50-500).", + ) + + # Embed generation parameters directly + speed: Optional[float] = Field( + None, description="Overrides default speed if provided." + ) + language: Optional[str] = Field( + None, description="Overrides default language if provided." + ) + + +class ErrorResponse(BaseModel): + """Standard error response model for API errors.""" + + detail: str = Field(..., description="A human-readable explanation of the error.") + + +class UpdateStatusResponse(BaseModel): + """Response model for status updates, e.g., after saving settings.""" + + message: str = Field( + ..., description="A message describing the result of the operation." 
+ ) + restart_needed: Optional[bool] = Field( + False, + description="Indicates if a server restart is recommended or required for changes to take full effect.", + ) diff --git a/requirements-nvidia.txt b/requirements-nvidia.txt index f5e50a1..9c056dd 100644 --- a/requirements-nvidia.txt +++ b/requirements-nvidia.txt @@ -1,31 +1,30 @@ -# requirements-nvidia.txt (NVIDIA GPU Installation) - -# KittenTTS Core Dependencies (GPU Version) -onnxruntime-gpu -soundfile -numpy -huggingface_hub -phonemizer -misaki[en]>=0.9.4 -espeakng_loader==0.2.0 - -# Web Framework -fastapi -uvicorn[standard] - -# Configuration & Utilities -PyYAML -python-multipart -requests -Jinja2 -watchdog -aiofiles -unidecode -inflect -tqdm -hf-transfer -pydub -librosa -praat-parselmouth -torch -torchaudio \ No newline at end of file +# requirements-nvidia.txt (NVIDIA GPU Installation) + +# KittenTTS Core Dependencies (GPU Version) +# Note: onnxruntime-gpu, torch, and torchaudio are installed separately +# by install_gpu.bat to ensure correct CUDA versions. 
+soundfile +numpy +huggingface_hub +phonemizer +misaki[en]>=0.9.4 +espeakng_loader==0.2.0 + +# Web Framework +fastapi +uvicorn[standard] + +# Configuration & Utilities +PyYAML +python-multipart +requests +Jinja2 +watchdog +aiofiles +unidecode +inflect +tqdm +hf-transfer +pydub +librosa +praat-parselmouth \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 166d0a3..304e717 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,31 +1,31 @@ -# requirements.txt (CPU Installation) - -# Web Framework -fastapi -uvicorn[standard] - -# KittenTTS Core Dependencies -onnxruntime -soundfile -numpy -huggingface_hub -phonemizer -misaki[en]>=0.9.4 -espeakng_loader==0.2.0 - -# Configuration & Utilities -PyYAML -python-multipart -requests -Jinja2 -watchdog -aiofiles -unidecode -inflect -tqdm -hf-transfer -pydub -librosa -praat-parselmouth -torch +# requirements.txt (CPU Installation) + +# Web Framework +fastapi +uvicorn[standard] + +# KittenTTS Core Dependencies +onnxruntime +soundfile +numpy +huggingface_hub +phonemizer +misaki[en]>=0.9.4 +espeakng_loader==0.2.0 + +# Configuration & Utilities +PyYAML +python-multipart +requests +Jinja2 +watchdog +aiofiles +unidecode +inflect +tqdm +hf-transfer +pydub +librosa +praat-parselmouth +torch torchaudio \ No newline at end of file diff --git a/server.py b/server.py index 3c06b2f..150fbbd 100644 --- a/server.py +++ b/server.py @@ -1,679 +1,679 @@ -# File: server.py -# Main FastAPI application for the TTS Server. -# Handles API requests for text-to-speech generation, UI serving, -# configuration management, and file uploads. 
- -import os -import io -import logging -import logging.handlers # For RotatingFileHandler -import shutil -import time -import uuid -import yaml # For loading presets -import numpy as np -from pathlib import Path -from contextlib import asynccontextmanager -from typing import Optional, List, Dict, Any, Literal -import webbrowser # For automatic browser opening -import threading # For automatic browser opening - -from fastapi import ( - FastAPI, - HTTPException, - Request, - File, - UploadFile, - Form, - BackgroundTasks, -) -from fastapi.responses import ( - HTMLResponse, - JSONResponse, - StreamingResponse, - FileResponse, -) -from fastapi.staticfiles import StaticFiles -from fastapi.templating import Jinja2Templates -from fastapi.middleware.cors import CORSMiddleware - -# --- Internal Project Imports --- -from config import ( - config_manager, - get_host, - get_port, - get_log_file_path, - get_output_path, - get_ui_title, - get_gen_default_speed, - get_gen_default_language, - get_audio_sample_rate, - get_full_config_for_template, - get_audio_output_format, -) - -import engine # TTS Engine interface -from models import ( # Pydantic models - CustomTTSRequest, - ErrorResponse, - UpdateStatusResponse, -) -import utils # Utility functions - -from pydantic import BaseModel, Field - - -class OpenAISpeechRequest(BaseModel): - model: str - input_: str = Field(..., alias="input") - voice: str - response_format: Literal["wav", "opus", "mp3"] = "wav" # Add "mp3" - speed: float = 1.0 - seed: Optional[int] = None - - -# --- Logging Configuration --- -log_file_path_obj = get_log_file_path() -log_file_max_size_mb = config_manager.get_int("server.log_file_max_size_mb", 10) -log_backup_count = config_manager.get_int("server.log_file_backup_count", 5) - -log_file_path_obj.parent.mkdir(parents=True, exist_ok=True) - -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - handlers=[ - 
logging.handlers.RotatingFileHandler( - str(log_file_path_obj), - maxBytes=log_file_max_size_mb * 1024 * 1024, - backupCount=log_backup_count, - encoding="utf-8", - ), - logging.StreamHandler(), - ], -) -logging.getLogger("uvicorn.access").setLevel(logging.WARNING) -logging.getLogger("watchfiles").setLevel(logging.WARNING) -logger = logging.getLogger(__name__) - -# --- Global Variables & Application Setup --- -startup_complete_event = threading.Event() # For coordinating browser opening - - -def _delayed_browser_open(host: str, port: int): - """ - Waits for the startup_complete_event, then opens the web browser - to the server's main page after a short delay. - """ - try: - startup_complete_event.wait(timeout=30) - if not startup_complete_event.is_set(): - logger.warning( - "Server startup did not signal completion within timeout. Browser will not be opened automatically." - ) - return - - time.sleep(1.5) - display_host = "localhost" if host == "0.0.0.0" else host - browser_url = f"http://{display_host}:{port}/" - logger.info(f"Attempting to open web browser to: {browser_url}") - webbrowser.open(browser_url) - except Exception as e: - logger.error(f"Failed to open browser automatically: {e}", exc_info=True) - - -@asynccontextmanager -async def lifespan(app: FastAPI): - """Manages application startup and shutdown events.""" - logger.info("TTS Server: Initializing application...") - try: - logger.info(f"Configuration loaded. Log file at: {get_log_file_path()}") - - paths_to_ensure = [ - get_output_path(), - Path("ui"), - config_manager.get_path( - "paths.model_cache", "./model_cache", ensure_absolute=True - ), - ] - for p in paths_to_ensure: - p.mkdir(parents=True, exist_ok=True) - - if not engine.load_model(): - logger.critical( - "CRITICAL: TTS Model failed to load on startup. Server might not function correctly." 
- ) - else: - logger.info("TTS Model loaded successfully via engine.") - host_address = get_host() - server_port = get_port() - browser_thread = threading.Thread( - target=lambda: _delayed_browser_open(host_address, server_port), - daemon=True, - ) - browser_thread.start() - - logger.info("Application startup sequence complete.") - startup_complete_event.set() - yield - except Exception as e_startup: - logger.error( - f"FATAL ERROR during application startup: {e_startup}", exc_info=True - ) - startup_complete_event.set() - yield - finally: - logger.info("TTS Server: Application shutdown sequence initiated...") - logger.info("TTS Server: Application shutdown complete.") - - -# --- FastAPI Application Instance --- -app = FastAPI( - title=get_ui_title(), - description="Text-to-Speech server with advanced UI and API capabilities.", - version="2.0.2", # Version Bump - lifespan=lifespan, -) - -# --- CORS Middleware --- -app.add_middleware( - CORSMiddleware, - allow_origins=["*", "null"], - allow_credentials=True, - allow_methods=["GET", "POST", "OPTIONS"], - allow_headers=["*"], -) - -# --- Static Files and HTML Templates --- -ui_static_path = Path(__file__).parent / "ui" -if ui_static_path.is_dir(): - app.mount("/ui", StaticFiles(directory=ui_static_path), name="ui_static_assets") -else: - logger.warning( - f"UI static assets directory not found at '{ui_static_path}'. UI may not load correctly." - ) - -# This will serve files from 'ui_static_path/vendor' when requests come to '/vendor/*' -if (ui_static_path / "vendor").is_dir(): - app.mount( - "/vendor", StaticFiles(directory=ui_static_path / "vendor"), name="vendor_files" - ) -else: - logger.warning( - f"Vendor directory not found at '{ui_static_path}' /vendor. Wavesurfer might not load." 
- ) - - -@app.get("/styles.css", include_in_schema=False) -async def get_main_styles(): - styles_file = ui_static_path / "styles.css" - if styles_file.is_file(): - return FileResponse(styles_file) - raise HTTPException(status_code=404, detail="styles.css not found") - - -@app.get("/script.js", include_in_schema=False) -async def get_main_script(): - script_file = ui_static_path / "script.js" - if script_file.is_file(): - return FileResponse(script_file) - raise HTTPException(status_code=404, detail="script.js not found") - - -outputs_static_path = get_output_path(ensure_absolute=True) -try: - app.mount( - "/outputs", - StaticFiles(directory=str(outputs_static_path)), - name="generated_outputs", - ) -except RuntimeError as e_mount_outputs: - logger.error( - f"Failed to mount /outputs directory '{outputs_static_path}': {e_mount_outputs}. " - "Output files may not be accessible via URL." - ) - -templates = Jinja2Templates(directory=str(ui_static_path)) - -# --- API Endpoints --- - - -# --- Main UI Route --- -@app.get("/", response_class=HTMLResponse, include_in_schema=False) -async def get_web_ui(request: Request): - """Serves the main web interface (index.html).""" - logger.info("Request received for main UI page ('/').") - try: - return templates.TemplateResponse("index.html", {"request": request}) - except Exception as e_render: - logger.error(f"Error rendering main UI page: {e_render}", exc_info=True) - return HTMLResponse( - "

Internal Server Error

Could not load the TTS interface. " - "Please check server logs for more details.

", - status_code=500, - ) - - -# --- API Endpoint for Initial UI Data --- -@app.get("/api/ui/initial-data", tags=["UI Helpers"]) -async def get_ui_initial_data(): - """ - Provides all necessary initial data for the UI to render, - including configuration, file lists, and presets. - """ - logger.info("Request received for /api/ui/initial-data.") - try: - full_config = get_full_config_for_template() - loaded_presets = [] - presets_file = ui_static_path / "presets.yaml" - if presets_file.exists(): - with open(presets_file, "r", encoding="utf-8") as f: - yaml_content = yaml.safe_load(f) - if isinstance(yaml_content, list): - loaded_presets = yaml_content - else: - logger.warning( - f"Invalid format in {presets_file}. Expected a list, got {type(yaml_content)}." - ) - else: - logger.info( - f"Presets file not found: {presets_file}. No presets will be loaded for initial data." - ) - - initial_gen_result_placeholder = { - "outputUrl": None, - "filename": None, - "genTime": None, - "submittedVoice": None, - } - - return { - "config": full_config, - "presets": loaded_presets, - "initial_gen_result": initial_gen_result_placeholder, - "available_voices": engine.KITTEN_TTS_VOICES, - } - except Exception as e: - logger.error(f"Error preparing initial UI data for API: {e}", exc_info=True) - raise HTTPException( - status_code=500, detail="Failed to load initial data for UI." - ) - - -# --- Configuration Management API Endpoints --- -@app.post("/save_settings", response_model=UpdateStatusResponse, tags=["Configuration"]) -async def save_settings_endpoint(request: Request): - """ - Saves partial configuration updates to the config.yaml file. - Merges the update with the current configuration. 
- """ - logger.info("Request received for /save_settings.") - try: - partial_update = await request.json() - if not isinstance(partial_update, dict): - raise ValueError("Request body must be a JSON object for /save_settings.") - logger.debug(f"Received partial config data to save: {partial_update}") - - if config_manager.update_and_save(partial_update): - restart_needed = any( - key in partial_update - for key in ["server", "tts_engine", "paths", "model"] - ) - message = "Settings saved successfully." - if restart_needed: - message += " A server restart may be required for some changes to take full effect." - return UpdateStatusResponse(message=message, restart_needed=restart_needed) - else: - logger.error( - "Failed to save configuration via config_manager.update_and_save." - ) - raise HTTPException( - status_code=500, - detail="Failed to save configuration file due to an internal error.", - ) - except ValueError as ve: - logger.error(f"Invalid data format for /save_settings: {ve}") - raise HTTPException(status_code=400, detail=f"Invalid request data: {str(ve)}") - except Exception as e: - logger.error(f"Error processing /save_settings request: {e}", exc_info=True) - raise HTTPException( - status_code=500, - detail=f"Internal server error during settings save: {str(e)}", - ) - - -@app.post( - "/reset_settings", response_model=UpdateStatusResponse, tags=["Configuration"] -) -async def reset_settings_endpoint(): - """Resets the configuration in config.yaml back to hardcoded defaults.""" - logger.warning("Request received to reset all configurations to default values.") - try: - if config_manager.reset_and_save(): - logger.info("Configuration successfully reset to defaults and saved.") - return UpdateStatusResponse( - message="Configuration reset to defaults. Please reload the page. 
A server restart may be beneficial.", - restart_needed=True, - ) - else: - logger.error("Failed to reset and save configuration via config_manager.") - raise HTTPException( - status_code=500, detail="Failed to reset and save configuration file." - ) - except Exception as e: - logger.error(f"Error processing /reset_settings request: {e}", exc_info=True) - raise HTTPException( - status_code=500, - detail=f"Internal server error during settings reset: {str(e)}", - ) - - -@app.post( - "/restart_server", response_model=UpdateStatusResponse, tags=["Configuration"] -) -async def restart_server_endpoint(): - """Attempts to trigger a server restart.""" - logger.info("Request received for /restart_server.") - message = ( - "Server restart initiated. If running locally without a process manager, " - "you may need to restart manually. For managed environments (Docker, systemd), " - "the manager should handle the restart." - ) - logger.warning(message) - return UpdateStatusResponse(message=message, restart_needed=True) - - -# --- TTS Generation Endpoint --- - - -@app.post( - "/tts", - tags=["TTS Generation"], - summary="Generate speech with custom parameters", - responses={ - 200: { - "content": {"audio/wav": {}, "audio/opus": {}}, - "description": "Successful audio generation.", - }, - 400: { - "model": ErrorResponse, - "description": "Invalid request parameters or input.", - }, - 500: { - "model": ErrorResponse, - "description": "Internal server error during generation.", - }, - 503: { - "model": ErrorResponse, - "description": "TTS engine not available or model not loaded.", - }, - }, -) -async def custom_tts_endpoint( - request: CustomTTSRequest, background_tasks: BackgroundTasks -): - """ - Generates speech audio from text using specified parameters. - Returns audio as a stream (WAV or Opus). 
- """ - perf_monitor = utils.PerformanceMonitor( - enabled=config_manager.get_bool("server.enable_performance_monitor", False) - ) - perf_monitor.record("TTS request received") - - if not engine.MODEL_LOADED: - logger.error("TTS request failed: Model not loaded.") - raise HTTPException( - status_code=503, - detail="TTS engine model is not currently loaded or available.", - ) - - logger.info( - f"Received /tts request: voice='{request.voice}', format='{request.output_format}'" - ) - logger.debug( - f"TTS params: speed={request.speed}, split={request.split_text}, chunk_size={request.chunk_size}" - ) - logger.debug(f"Input text (first 100 chars): '{request.text[:100]}...'") - - perf_monitor.record("Parameters resolved") - - all_audio_segments_np: List[np.ndarray] = [] - final_output_sample_rate = get_audio_sample_rate() - engine_output_sample_rate: Optional[int] = None - - if request.split_text and len(request.text) > ( - request.chunk_size * 1.5 if request.chunk_size else 120 * 1.5 - ): - chunk_size_to_use = ( - request.chunk_size if request.chunk_size is not None else 120 - ) - logger.info(f"Splitting text into chunks of size ~{chunk_size_to_use}.") - text_chunks = utils.chunk_text_by_sentences(request.text, chunk_size_to_use) - perf_monitor.record(f"Text split into {len(text_chunks)} chunks") - else: - text_chunks = [request.text] - logger.info( - "Processing text as a single chunk (splitting not enabled or text too short)." - ) - - if not text_chunks: - raise HTTPException( - status_code=400, detail="Text processing resulted in no usable chunks." 
- ) - - for i, chunk in enumerate(text_chunks): - logger.info(f"Synthesizing chunk {i+1}/{len(text_chunks)}...") - try: - chunk_audio_np, chunk_sr_from_engine = engine.synthesize( - text=chunk, - voice=request.voice, - speed=( - request.speed - if request.speed is not None - else get_gen_default_speed() - ), - ) - perf_monitor.record(f"Engine synthesized chunk {i+1}") - - if chunk_audio_np is None or chunk_sr_from_engine is None: - error_detail = f"TTS engine failed to synthesize audio for chunk {i+1}." - logger.error(error_detail) - raise HTTPException(status_code=500, detail=error_detail) - - if engine_output_sample_rate is None: - engine_output_sample_rate = chunk_sr_from_engine - elif engine_output_sample_rate != chunk_sr_from_engine: - logger.warning( - f"Inconsistent sample rate from engine: chunk {i+1} ({chunk_sr_from_engine}Hz) " - f"differs from previous ({engine_output_sample_rate}Hz). Using first chunk's SR." - ) - - # The speed factor is now handled by the engine directly, so no post-processing for speed is needed here. - - all_audio_segments_np.append(chunk_audio_np) - - except HTTPException as http_exc: - raise http_exc - except Exception as e_chunk: - error_detail = f"Error processing audio chunk {i+1}: {str(e_chunk)}" - logger.error(error_detail, exc_info=True) - raise HTTPException(status_code=500, detail=error_detail) - - if not all_audio_segments_np: - logger.error("No audio segments were successfully generated.") - raise HTTPException( - status_code=500, detail="Audio generation resulted in no output." - ) - - if engine_output_sample_rate is None: - logger.error("Engine output sample rate could not be determined.") - raise HTTPException( - status_code=500, detail="Failed to determine engine sample rate." 
- ) - - try: - if len(all_audio_segments_np) > 1: - # Add silence between chunks for natural pauses - silence_duration_ms = 200 # silence between chunks - silence_samples = int( - silence_duration_ms / 1000 * engine_output_sample_rate - ) - silence_array = np.zeros(silence_samples, dtype=np.float32) - - # Apply crossfade and add silence between chunks - crossfade_samples = int(0.01 * engine_output_sample_rate) # 10ms crossfade - - merged_audio = [] - for i, chunk in enumerate(all_audio_segments_np): - if i == 0: - merged_audio.append(chunk) - else: - # Add silence gap between chunks - merged_audio.append(silence_array) - - # Then add the next chunk with optional crossfade - if ( - len(merged_audio[-2]) >= crossfade_samples - and len(chunk) >= crossfade_samples - ): - # Apply fade out to end of previous audio (before silence) - fade_out = np.linspace(1, 0, crossfade_samples) - merged_audio[-2][-crossfade_samples:] *= fade_out - - # Apply fade in to start of current chunk - fade_in = np.linspace(0, 1, crossfade_samples) - chunk_copy = chunk.copy() - chunk_copy[:crossfade_samples] *= fade_in - merged_audio.append(chunk_copy) - else: - merged_audio.append(chunk) - - final_audio_np = np.concatenate(merged_audio) - logger.debug( - f"Added {silence_duration_ms}ms silence between {len(all_audio_segments_np)} chunks" - ) - else: - final_audio_np = all_audio_segments_np[0] - - perf_monitor.record("All audio chunks processed and concatenated") - - except ValueError as e_concat: - logger.error(f"Audio concatenation failed: {e_concat}", exc_info=True) - for idx, seg in enumerate(all_audio_segments_np): - logger.error(f"Segment {idx} shape: {seg.shape}, dtype: {seg.dtype}") - raise HTTPException( - status_code=500, detail=f"Audio concatenation error: {e_concat}" - ) - - output_format_str = ( - request.output_format if request.output_format else get_audio_output_format() - ) - - encoded_audio_bytes = utils.encode_audio( - audio_array=final_audio_np, - 
sample_rate=engine_output_sample_rate, - output_format=output_format_str, - target_sample_rate=final_output_sample_rate, - ) - perf_monitor.record( - f"Final audio encoded to {output_format_str} (target SR: {final_output_sample_rate}Hz from engine SR: {engine_output_sample_rate}Hz)" - ) - - if encoded_audio_bytes is None or len(encoded_audio_bytes) < 100: - logger.error( - f"Failed to encode final audio to format: {output_format_str} or output is too small ({len(encoded_audio_bytes or b'')} bytes)." - ) - raise HTTPException( - status_code=500, - detail=f"Failed to encode audio to {output_format_str} or generated invalid audio.", - ) - - media_type = f"audio/{output_format_str}" - timestamp_str = time.strftime("%Y%m%d_%H%M%S") - suggested_filename_base = f"tts_output_{timestamp_str}" - download_filename = utils.sanitize_filename( - f"{suggested_filename_base}.{output_format_str}" - ) - headers = {"Content-Disposition": f'attachment; filename="{download_filename}"'} - - logger.info( - f"Successfully generated audio: {download_filename}, {len(encoded_audio_bytes)} bytes, type {media_type}." - ) - logger.debug(perf_monitor.report()) - - return StreamingResponse( - io.BytesIO(encoded_audio_bytes), media_type=media_type, headers=headers - ) - - -@app.post("/v1/audio/speech", tags=["OpenAI Compatible"]) -async def openai_speech_endpoint(request: OpenAISpeechRequest): - # Check if the TTS model is loaded - if not engine.MODEL_LOADED: - raise HTTPException( - status_code=503, - detail="TTS engine model is not currently loaded or available.", - ) - - try: - # Synthesize the audio - audio_np, sr = engine.synthesize( - text=request.input_, - voice=request.voice, - speed=request.speed, - ) - - if audio_np is None or sr is None: - raise HTTPException( - status_code=500, detail="TTS engine failed to synthesize audio." 
- ) - - # Ensure it's 1D - if audio_np.ndim == 2: - audio_np = audio_np.squeeze() - - # Encode the audio to the requested format - encoded_audio = utils.encode_audio( - audio_array=audio_np, - sample_rate=sr, - output_format=request.response_format, - target_sample_rate=get_audio_sample_rate(), - ) - - if encoded_audio is None: - raise HTTPException(status_code=500, detail="Failed to encode audio.") - - # Determine the media type - media_type = f"audio/{request.response_format}" - - # Return the streaming response - return StreamingResponse(io.BytesIO(encoded_audio), media_type=media_type) - - except Exception as e: - logger.error(f"Error in openai_speech_endpoint: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=str(e)) - - -# --- Main Execution --- -if __name__ == "__main__": - server_host = get_host() - server_port = get_port() - - logger.info(f"Starting TTS Server directly on http://{server_host}:{server_port}") - logger.info( - f"API documentation will be available at http://{server_host}:{server_port}/docs" - ) - logger.info(f"Web UI will be available at http://{server_host}:{server_port}/") - - import uvicorn - - uvicorn.run( - "server:app", - host=server_host, - port=server_port, - log_level="info", - workers=1, - reload=False, - ) +# File: server.py +# Main FastAPI application for the TTS Server. +# Handles API requests for text-to-speech generation, UI serving, +# configuration management, and file uploads. 
+ +import os +import io +import logging +import logging.handlers # For RotatingFileHandler +import shutil +import time +import uuid +import yaml # For loading presets +import numpy as np +from pathlib import Path +from contextlib import asynccontextmanager +from typing import Optional, List, Dict, Any, Literal +import webbrowser # For automatic browser opening +import threading # For automatic browser opening + +from fastapi import ( + FastAPI, + HTTPException, + Request, + File, + UploadFile, + Form, + BackgroundTasks, +) +from fastapi.responses import ( + HTMLResponse, + JSONResponse, + StreamingResponse, + FileResponse, +) +from fastapi.staticfiles import StaticFiles +from fastapi.templating import Jinja2Templates +from fastapi.middleware.cors import CORSMiddleware + +# --- Internal Project Imports --- +from config import ( + config_manager, + get_host, + get_port, + get_log_file_path, + get_output_path, + get_ui_title, + get_gen_default_speed, + get_gen_default_language, + get_audio_sample_rate, + get_full_config_for_template, + get_audio_output_format, +) + +import engine # TTS Engine interface +from models import ( # Pydantic models + CustomTTSRequest, + ErrorResponse, + UpdateStatusResponse, +) +import utils # Utility functions + +from pydantic import BaseModel, Field + + +class OpenAISpeechRequest(BaseModel): + model: str + input_: str = Field(..., alias="input") + voice: str + response_format: Literal["wav", "opus", "mp3"] = "wav" # Add "mp3" + speed: float = 1.0 + seed: Optional[int] = None + + +# --- Logging Configuration --- +log_file_path_obj = get_log_file_path() +log_file_max_size_mb = config_manager.get_int("server.log_file_max_size_mb", 10) +log_backup_count = config_manager.get_int("server.log_file_backup_count", 5) + +log_file_path_obj.parent.mkdir(parents=True, exist_ok=True) + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + handlers=[ + 
logging.handlers.RotatingFileHandler( + str(log_file_path_obj), + maxBytes=log_file_max_size_mb * 1024 * 1024, + backupCount=log_backup_count, + encoding="utf-8", + ), + logging.StreamHandler(), + ], +) +logging.getLogger("uvicorn.access").setLevel(logging.WARNING) +logging.getLogger("watchfiles").setLevel(logging.WARNING) +logger = logging.getLogger(__name__) + +# --- Global Variables & Application Setup --- +startup_complete_event = threading.Event() # For coordinating browser opening + + +def _delayed_browser_open(host: str, port: int): + """ + Waits for the startup_complete_event, then opens the web browser + to the server's main page after a short delay. + """ + try: + startup_complete_event.wait(timeout=30) + if not startup_complete_event.is_set(): + logger.warning( + "Server startup did not signal completion within timeout. Browser will not be opened automatically." + ) + return + + time.sleep(1.5) + display_host = "localhost" if host == "0.0.0.0" else host + browser_url = f"http://{display_host}:{port}/" + logger.info(f"Attempting to open web browser to: {browser_url}") + webbrowser.open(browser_url) + except Exception as e: + logger.error(f"Failed to open browser automatically: {e}", exc_info=True) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Manages application startup and shutdown events.""" + logger.info("TTS Server: Initializing application...") + try: + logger.info(f"Configuration loaded. Log file at: {get_log_file_path()}") + + paths_to_ensure = [ + get_output_path(), + Path("ui"), + config_manager.get_path( + "paths.model_cache", "./model_cache", ensure_absolute=True + ), + ] + for p in paths_to_ensure: + p.mkdir(parents=True, exist_ok=True) + + if not engine.load_model(): + logger.critical( + "CRITICAL: TTS Model failed to load on startup. Server might not function correctly." 
+ ) + else: + logger.info("TTS Model loaded successfully via engine.") + host_address = get_host() + server_port = get_port() + browser_thread = threading.Thread( + target=lambda: _delayed_browser_open(host_address, server_port), + daemon=True, + ) + browser_thread.start() + + logger.info("Application startup sequence complete.") + startup_complete_event.set() + yield + except Exception as e_startup: + logger.error( + f"FATAL ERROR during application startup: {e_startup}", exc_info=True + ) + startup_complete_event.set() + yield + finally: + logger.info("TTS Server: Application shutdown sequence initiated...") + logger.info("TTS Server: Application shutdown complete.") + + +# --- FastAPI Application Instance --- +app = FastAPI( + title=get_ui_title(), + description="Text-to-Speech server with advanced UI and API capabilities.", + version="2.0.2", # Version Bump + lifespan=lifespan, +) + +# --- CORS Middleware --- +app.add_middleware( + CORSMiddleware, + allow_origins=["*", "null"], + allow_credentials=True, + allow_methods=["GET", "POST", "OPTIONS"], + allow_headers=["*"], +) + +# --- Static Files and HTML Templates --- +ui_static_path = Path(__file__).parent / "ui" +if ui_static_path.is_dir(): + app.mount("/ui", StaticFiles(directory=ui_static_path), name="ui_static_assets") +else: + logger.warning( + f"UI static assets directory not found at '{ui_static_path}'. UI may not load correctly." + ) + +# This will serve files from 'ui_static_path/vendor' when requests come to '/vendor/*' +if (ui_static_path / "vendor").is_dir(): + app.mount( + "/vendor", StaticFiles(directory=ui_static_path / "vendor"), name="vendor_files" + ) +else: + logger.warning( + f"Vendor directory not found at '{ui_static_path}' /vendor. Wavesurfer might not load." 
+ ) + + +@app.get("/styles.css", include_in_schema=False) +async def get_main_styles(): + styles_file = ui_static_path / "styles.css" + if styles_file.is_file(): + return FileResponse(styles_file) + raise HTTPException(status_code=404, detail="styles.css not found") + + +@app.get("/script.js", include_in_schema=False) +async def get_main_script(): + script_file = ui_static_path / "script.js" + if script_file.is_file(): + return FileResponse(script_file) + raise HTTPException(status_code=404, detail="script.js not found") + + +outputs_static_path = get_output_path(ensure_absolute=True) +try: + app.mount( + "/outputs", + StaticFiles(directory=str(outputs_static_path)), + name="generated_outputs", + ) +except RuntimeError as e_mount_outputs: + logger.error( + f"Failed to mount /outputs directory '{outputs_static_path}': {e_mount_outputs}. " + "Output files may not be accessible via URL." + ) + +templates = Jinja2Templates(directory=str(ui_static_path)) + +# --- API Endpoints --- + + +# --- Main UI Route --- +@app.get("/", response_class=HTMLResponse, include_in_schema=False) +async def get_web_ui(request: Request): + """Serves the main web interface (index.html).""" + logger.info("Request received for main UI page ('/').") + try: + return templates.TemplateResponse("index.html", {"request": request}) + except Exception as e_render: + logger.error(f"Error rendering main UI page: {e_render}", exc_info=True) + return HTMLResponse( + "

Internal Server Error

Could not load the TTS interface. " + "Please check server logs for more details.

", + status_code=500, + ) + + +# --- API Endpoint for Initial UI Data --- +@app.get("/api/ui/initial-data", tags=["UI Helpers"]) +async def get_ui_initial_data(): + """ + Provides all necessary initial data for the UI to render, + including configuration, file lists, and presets. + """ + logger.info("Request received for /api/ui/initial-data.") + try: + full_config = get_full_config_for_template() + loaded_presets = [] + presets_file = ui_static_path / "presets.yaml" + if presets_file.exists(): + with open(presets_file, "r", encoding="utf-8") as f: + yaml_content = yaml.safe_load(f) + if isinstance(yaml_content, list): + loaded_presets = yaml_content + else: + logger.warning( + f"Invalid format in {presets_file}. Expected a list, got {type(yaml_content)}." + ) + else: + logger.info( + f"Presets file not found: {presets_file}. No presets will be loaded for initial data." + ) + + initial_gen_result_placeholder = { + "outputUrl": None, + "filename": None, + "genTime": None, + "submittedVoice": None, + } + + return { + "config": full_config, + "presets": loaded_presets, + "initial_gen_result": initial_gen_result_placeholder, + "available_voices": engine.KITTEN_TTS_VOICES, + } + except Exception as e: + logger.error(f"Error preparing initial UI data for API: {e}", exc_info=True) + raise HTTPException( + status_code=500, detail="Failed to load initial data for UI." + ) + + +# --- Configuration Management API Endpoints --- +@app.post("/save_settings", response_model=UpdateStatusResponse, tags=["Configuration"]) +async def save_settings_endpoint(request: Request): + """ + Saves partial configuration updates to the config.yaml file. + Merges the update with the current configuration. 
+ """ + logger.info("Request received for /save_settings.") + try: + partial_update = await request.json() + if not isinstance(partial_update, dict): + raise ValueError("Request body must be a JSON object for /save_settings.") + logger.debug(f"Received partial config data to save: {partial_update}") + + if config_manager.update_and_save(partial_update): + restart_needed = any( + key in partial_update + for key in ["server", "tts_engine", "paths", "model"] + ) + message = "Settings saved successfully." + if restart_needed: + message += " A server restart may be required for some changes to take full effect." + return UpdateStatusResponse(message=message, restart_needed=restart_needed) + else: + logger.error( + "Failed to save configuration via config_manager.update_and_save." + ) + raise HTTPException( + status_code=500, + detail="Failed to save configuration file due to an internal error.", + ) + except ValueError as ve: + logger.error(f"Invalid data format for /save_settings: {ve}") + raise HTTPException(status_code=400, detail=f"Invalid request data: {str(ve)}") + except Exception as e: + logger.error(f"Error processing /save_settings request: {e}", exc_info=True) + raise HTTPException( + status_code=500, + detail=f"Internal server error during settings save: {str(e)}", + ) + + +@app.post( + "/reset_settings", response_model=UpdateStatusResponse, tags=["Configuration"] +) +async def reset_settings_endpoint(): + """Resets the configuration in config.yaml back to hardcoded defaults.""" + logger.warning("Request received to reset all configurations to default values.") + try: + if config_manager.reset_and_save(): + logger.info("Configuration successfully reset to defaults and saved.") + return UpdateStatusResponse( + message="Configuration reset to defaults. Please reload the page. 
A server restart may be beneficial.", + restart_needed=True, + ) + else: + logger.error("Failed to reset and save configuration via config_manager.") + raise HTTPException( + status_code=500, detail="Failed to reset and save configuration file." + ) + except Exception as e: + logger.error(f"Error processing /reset_settings request: {e}", exc_info=True) + raise HTTPException( + status_code=500, + detail=f"Internal server error during settings reset: {str(e)}", + ) + + +@app.post( + "/restart_server", response_model=UpdateStatusResponse, tags=["Configuration"] +) +async def restart_server_endpoint(): + """Attempts to trigger a server restart.""" + logger.info("Request received for /restart_server.") + message = ( + "Server restart initiated. If running locally without a process manager, " + "you may need to restart manually. For managed environments (Docker, systemd), " + "the manager should handle the restart." + ) + logger.warning(message) + return UpdateStatusResponse(message=message, restart_needed=True) + + +# --- TTS Generation Endpoint --- + + +@app.post( + "/tts", + tags=["TTS Generation"], + summary="Generate speech with custom parameters", + responses={ + 200: { + "content": {"audio/wav": {}, "audio/opus": {}}, + "description": "Successful audio generation.", + }, + 400: { + "model": ErrorResponse, + "description": "Invalid request parameters or input.", + }, + 500: { + "model": ErrorResponse, + "description": "Internal server error during generation.", + }, + 503: { + "model": ErrorResponse, + "description": "TTS engine not available or model not loaded.", + }, + }, +) +async def custom_tts_endpoint( + request: CustomTTSRequest, background_tasks: BackgroundTasks +): + """ + Generates speech audio from text using specified parameters. + Returns audio as a stream (WAV or Opus). 
+ """ + perf_monitor = utils.PerformanceMonitor( + enabled=config_manager.get_bool("server.enable_performance_monitor", False) + ) + perf_monitor.record("TTS request received") + + if not engine.MODEL_LOADED: + logger.error("TTS request failed: Model not loaded.") + raise HTTPException( + status_code=503, + detail="TTS engine model is not currently loaded or available.", + ) + + logger.info( + f"Received /tts request: voice='{request.voice}', format='{request.output_format}'" + ) + logger.debug( + f"TTS params: speed={request.speed}, split={request.split_text}, chunk_size={request.chunk_size}" + ) + logger.debug(f"Input text (first 100 chars): '{request.text[:100]}...'") + + perf_monitor.record("Parameters resolved") + + all_audio_segments_np: List[np.ndarray] = [] + final_output_sample_rate = get_audio_sample_rate() + engine_output_sample_rate: Optional[int] = None + + if request.split_text and len(request.text) > ( + request.chunk_size * 1.5 if request.chunk_size else 120 * 1.5 + ): + chunk_size_to_use = ( + request.chunk_size if request.chunk_size is not None else 120 + ) + logger.info(f"Splitting text into chunks of size ~{chunk_size_to_use}.") + text_chunks = utils.chunk_text_by_sentences(request.text, chunk_size_to_use) + perf_monitor.record(f"Text split into {len(text_chunks)} chunks") + else: + text_chunks = [request.text] + logger.info( + "Processing text as a single chunk (splitting not enabled or text too short)." + ) + + if not text_chunks: + raise HTTPException( + status_code=400, detail="Text processing resulted in no usable chunks." 
+ ) + + for i, chunk in enumerate(text_chunks): + logger.info(f"Synthesizing chunk {i+1}/{len(text_chunks)}...") + try: + chunk_audio_np, chunk_sr_from_engine = engine.synthesize( + text=chunk, + voice=request.voice, + speed=( + request.speed + if request.speed is not None + else get_gen_default_speed() + ), + ) + perf_monitor.record(f"Engine synthesized chunk {i+1}") + + if chunk_audio_np is None or chunk_sr_from_engine is None: + error_detail = f"TTS engine failed to synthesize audio for chunk {i+1}." + logger.error(error_detail) + raise HTTPException(status_code=500, detail=error_detail) + + if engine_output_sample_rate is None: + engine_output_sample_rate = chunk_sr_from_engine + elif engine_output_sample_rate != chunk_sr_from_engine: + logger.warning( + f"Inconsistent sample rate from engine: chunk {i+1} ({chunk_sr_from_engine}Hz) " + f"differs from previous ({engine_output_sample_rate}Hz). Using first chunk's SR." + ) + + # The speed factor is now handled by the engine directly, so no post-processing for speed is needed here. + + all_audio_segments_np.append(chunk_audio_np) + + except HTTPException as http_exc: + raise http_exc + except Exception as e_chunk: + error_detail = f"Error processing audio chunk {i+1}: {str(e_chunk)}" + logger.error(error_detail, exc_info=True) + raise HTTPException(status_code=500, detail=error_detail) + + if not all_audio_segments_np: + logger.error("No audio segments were successfully generated.") + raise HTTPException( + status_code=500, detail="Audio generation resulted in no output." + ) + + if engine_output_sample_rate is None: + logger.error("Engine output sample rate could not be determined.") + raise HTTPException( + status_code=500, detail="Failed to determine engine sample rate." 
+ ) + + try: + if len(all_audio_segments_np) > 1: + # Add silence between chunks for natural pauses + silence_duration_ms = 200 # silence between chunks + silence_samples = int( + silence_duration_ms / 1000 * engine_output_sample_rate + ) + silence_array = np.zeros(silence_samples, dtype=np.float32) + + # Apply crossfade and add silence between chunks + crossfade_samples = int(0.01 * engine_output_sample_rate) # 10ms crossfade + + merged_audio = [] + for i, chunk in enumerate(all_audio_segments_np): + if i == 0: + merged_audio.append(chunk) + else: + # Add silence gap between chunks + merged_audio.append(silence_array) + + # Then add the next chunk with optional crossfade + if ( + len(merged_audio[-2]) >= crossfade_samples + and len(chunk) >= crossfade_samples + ): + # Apply fade out to end of previous audio (before silence) + fade_out = np.linspace(1, 0, crossfade_samples) + merged_audio[-2][-crossfade_samples:] *= fade_out + + # Apply fade in to start of current chunk + fade_in = np.linspace(0, 1, crossfade_samples) + chunk_copy = chunk.copy() + chunk_copy[:crossfade_samples] *= fade_in + merged_audio.append(chunk_copy) + else: + merged_audio.append(chunk) + + final_audio_np = np.concatenate(merged_audio) + logger.debug( + f"Added {silence_duration_ms}ms silence between {len(all_audio_segments_np)} chunks" + ) + else: + final_audio_np = all_audio_segments_np[0] + + perf_monitor.record("All audio chunks processed and concatenated") + + except ValueError as e_concat: + logger.error(f"Audio concatenation failed: {e_concat}", exc_info=True) + for idx, seg in enumerate(all_audio_segments_np): + logger.error(f"Segment {idx} shape: {seg.shape}, dtype: {seg.dtype}") + raise HTTPException( + status_code=500, detail=f"Audio concatenation error: {e_concat}" + ) + + output_format_str = ( + request.output_format if request.output_format else get_audio_output_format() + ) + + encoded_audio_bytes = utils.encode_audio( + audio_array=final_audio_np, + 
sample_rate=engine_output_sample_rate, + output_format=output_format_str, + target_sample_rate=final_output_sample_rate, + ) + perf_monitor.record( + f"Final audio encoded to {output_format_str} (target SR: {final_output_sample_rate}Hz from engine SR: {engine_output_sample_rate}Hz)" + ) + + if encoded_audio_bytes is None or len(encoded_audio_bytes) < 100: + logger.error( + f"Failed to encode final audio to format: {output_format_str} or output is too small ({len(encoded_audio_bytes or b'')} bytes)." + ) + raise HTTPException( + status_code=500, + detail=f"Failed to encode audio to {output_format_str} or generated invalid audio.", + ) + + media_type = f"audio/{output_format_str}" + timestamp_str = time.strftime("%Y%m%d_%H%M%S") + suggested_filename_base = f"tts_output_{timestamp_str}" + download_filename = utils.sanitize_filename( + f"{suggested_filename_base}.{output_format_str}" + ) + headers = {"Content-Disposition": f'attachment; filename="{download_filename}"'} + + logger.info( + f"Successfully generated audio: {download_filename}, {len(encoded_audio_bytes)} bytes, type {media_type}." + ) + logger.debug(perf_monitor.report()) + + return StreamingResponse( + io.BytesIO(encoded_audio_bytes), media_type=media_type, headers=headers + ) + + +@app.post("/v1/audio/speech", tags=["OpenAI Compatible"]) +async def openai_speech_endpoint(request: OpenAISpeechRequest): + # Check if the TTS model is loaded + if not engine.MODEL_LOADED: + raise HTTPException( + status_code=503, + detail="TTS engine model is not currently loaded or available.", + ) + + try: + # Synthesize the audio + audio_np, sr = engine.synthesize( + text=request.input_, + voice=request.voice, + speed=request.speed, + ) + + if audio_np is None or sr is None: + raise HTTPException( + status_code=500, detail="TTS engine failed to synthesize audio." 
+ ) + + # Ensure it's 1D + if audio_np.ndim == 2: + audio_np = audio_np.squeeze() + + # Encode the audio to the requested format + encoded_audio = utils.encode_audio( + audio_array=audio_np, + sample_rate=sr, + output_format=request.response_format, + target_sample_rate=get_audio_sample_rate(), + ) + + if encoded_audio is None: + raise HTTPException(status_code=500, detail="Failed to encode audio.") + + # Determine the media type + media_type = f"audio/{request.response_format}" + + # Return the streaming response + return StreamingResponse(io.BytesIO(encoded_audio), media_type=media_type) + + except Exception as e: + logger.error(f"Error in openai_speech_endpoint: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +# --- Main Execution --- +if __name__ == "__main__": + server_host = get_host() + server_port = get_port() + + logger.info(f"Starting TTS Server directly on http://{server_host}:{server_port}") + logger.info( + f"API documentation will be available at http://{server_host}:{server_port}/docs" + ) + logger.info(f"Web UI will be available at http://{server_host}:{server_port}/") + + import uvicorn + + uvicorn.run( + "server:app", + host=server_host, + port=server_port, + log_level="info", + workers=1, + reload=False, + ) diff --git a/ui/index.html b/ui/index.html index d94b557..cfc4763 100644 --- a/ui/index.html +++ b/ui/index.html @@ -1,341 +1,341 @@ - - - - - - - Kitten TTS Server - - - - - - - -
- - - - -
-
- -
- -
-
-
-

Generate Speech

- -
- -

- Enter the text you want to convert to speech. For audiobooks, you can paste long - chapters. -

-
- -
- 0 Characters -
-
-
- -
- - - -
- - -
- -
- -
-
- -
- -
-

- Loading presets...

-
-
- -
-
- - Generation - Parameters - - - - - - -
-
- - -
-
- - -
-
- - -

- MP3 is recommended for smaller file sizes (e.g., audiobooks). -

-
-
- - -
-
-
-
- -
-
- - Server - Configuration - - - - - - -
-

- These settings are loaded from config.yaml - via an API call. - Restart the server to apply changes to - Host, Port, Model, or Path settings if modified here or directly in the - file. -

-
-
-
-
-
-
-
-
-
- -
- - - -
-
-
-
-
-
- - -
-
- -
- -
-

Tips & Tricks

-
-
-
    -
  • For **Audiobooks**, use **MP3** format, enable **Split text**, and set a chunk size - of ~250-500.
  • -
  • **KittenTTS** provides 8 high-quality voices: 4 male and 4 female expressions.
  • -
  • The model is ultra-lightweight (<25MB) and runs efficiently on both CPU and GPU. -
  • -
  • Experiment with **Speed** to adjust playback rate (0.25x to 4.0x).
  • -
  • Check the /docs endpoint for API details.
  • -
-
-
-
-
-
- - -
- - - - - - - - + + + + + + + Kitten TTS Server + + + + + + + +
+ + + + +
+
+ +
+ +
+
+
+

Generate Speech

+ +
+ +

+ Enter the text you want to convert to speech. For audiobooks, you can paste long + chapters. +

+
+ +
+ 0 Characters +
+
+
+ +
+ + + +
+ + +
+ +
+ +
+
+ +
+ +
+

+ Loading presets...

+
+
+ +
+
+ + Generation + Parameters + + + + + + +
+
+ + +
+
+ + +
+
+ + +

+ MP3 is recommended for smaller file sizes (e.g., audiobooks). +

+
+
+ + +
+
+
+
+ +
+
+ + Server + Configuration + + + + + + +
+

+ These settings are loaded from config.yaml + via an API call. + Restart the server to apply changes to + Host, Port, Model, or Path settings if modified here or directly in the + file. +

+
+
+
+
+
+
+
+
+
+ +
+ + + +
+
+
+
+
+
+ + +
+
+ +
+ +
+

Tips & Tricks

+
+
+
    +
  • For **Audiobooks**, use **MP3** format, enable **Split text**, and set a chunk size + of ~250-500.
  • +
  • **KittenTTS** provides 8 high-quality voices: 4 male and 4 female expressions.
  • +
  • The model is ultra-lightweight (<25MB) and runs efficiently on both CPU and GPU. +
  • +
  • Experiment with **Speed** to adjust playback rate (0.25x to 4.0x).
  • +
  • Check the /docs endpoint for API details.
  • +
+
+
+
+
+
+ + +
+ + + + + + + + \ No newline at end of file diff --git a/ui/presets.yaml b/ui/presets.yaml index 2fe57f3..fbff1dc 100644 --- a/ui/presets.yaml +++ b/ui/presets.yaml @@ -1,44 +1,44 @@ -# ui/presets.yaml -# Predefined examples for the Kitten TTS User Interface. -# Each preset has a name, the text to synthesize, and 'params' for generation settings. - -- name: "Standard Narration" - text: | - The solar system consists of the Sun and the astronomical objects gravitationally bound in orbit around it. - Mars, often called the Red Planet, is the fourth planet from the Sun. It is a terrestrial planet with a thin atmosphere, having surface features reminiscent both of the impact craters of the Moon and the volcanoes, valleys, deserts, and polar ice caps of Earth. - params: - voice: "expr-voice-2-f" - speed: 1.0 - language: "en" - -- name: "Expressive Monologue" - text: | - To be, or not to be, that is the question: whether 'tis nobler in the mind to suffer the slings and arrows of outrageous fortune, or to take arms against a sea of troubles and by opposing end them. To die: to sleep; no more. - params: - voice: "expr-voice-5-f" - speed: 0.9 - language: "en" - -- name: "Technical Explanation" - text: | - Quantum entanglement is a physical phenomenon that occurs when pairs or groups of particles are generated, interact, or share spatial proximity in ways such that the quantum state of each particle cannot be described independently of the state of the others. - params: - voice: "expr-voice-4-m" - speed: 1.0 - language: "en" - -- name: "Upbeat Advertisement" - text: | - Are you tired of slow, unreliable connections? Upgrade today to Quantum Fiber, the fastest internet in the galaxy! Experience seamless streaming, lag-free gaming, and instant downloads. Call now! 
- params: - voice: "expr-voice-5-m" - speed: 1.1 - language: "en" - -- name: "Long Story Excerpt (Chunking Test)" - text: | - It was the best of times, it was the worst of times, it was the age of wisdom, it was the age of foolishness, it was the epoch of belief, it was the epoch of incredulity, it was the season of Light, it was the season of Darkness, it was the spring of hope, it was the winter of despair, we had everything before us, we had nothing before us, we were all going direct to Heaven, we were all going direct the other way. - params: - voice: "expr-voice-3-f" - speed: 1.0 +# ui/presets.yaml +# Predefined examples for the Kitten TTS User Interface. +# Each preset has a name, the text to synthesize, and 'params' for generation settings. + +- name: "Standard Narration" + text: | + The solar system consists of the Sun and the astronomical objects gravitationally bound in orbit around it. + Mars, often called the Red Planet, is the fourth planet from the Sun. It is a terrestrial planet with a thin atmosphere, having surface features reminiscent both of the impact craters of the Moon and the volcanoes, valleys, deserts, and polar ice caps of Earth. + params: + voice: "expr-voice-2-f" + speed: 1.0 + language: "en" + +- name: "Expressive Monologue" + text: | + To be, or not to be, that is the question: whether 'tis nobler in the mind to suffer the slings and arrows of outrageous fortune, or to take arms against a sea of troubles and by opposing end them. To die: to sleep; no more. + params: + voice: "expr-voice-5-f" + speed: 0.9 + language: "en" + +- name: "Technical Explanation" + text: | + Quantum entanglement is a physical phenomenon that occurs when pairs or groups of particles are generated, interact, or share spatial proximity in ways such that the quantum state of each particle cannot be described independently of the state of the others. 
+ params: + voice: "expr-voice-4-m" + speed: 1.0 + language: "en" + +- name: "Upbeat Advertisement" + text: | + Are you tired of slow, unreliable connections? Upgrade today to Quantum Fiber, the fastest internet in the galaxy! Experience seamless streaming, lag-free gaming, and instant downloads. Call now! + params: + voice: "expr-voice-5-m" + speed: 1.1 + language: "en" + +- name: "Long Story Excerpt (Chunking Test)" + text: | + It was the best of times, it was the worst of times, it was the age of wisdom, it was the age of foolishness, it was the epoch of belief, it was the epoch of incredulity, it was the season of Light, it was the season of Darkness, it was the spring of hope, it was the winter of despair, we had everything before us, we had nothing before us, we were all going direct to Heaven, we were all going direct the other way. + params: + voice: "expr-voice-3-f" + speed: 1.0 language: "en" \ No newline at end of file diff --git a/ui/script.js b/ui/script.js index 04c4ece..8269dc4 100644 --- a/ui/script.js +++ b/ui/script.js @@ -1,713 +1,713 @@ -// ui/script.js -// Client-side JavaScript for the Kitten TTS Server web interface. -// Handles UI interactions, API communication, audio playback, and settings management. - -document.addEventListener('DOMContentLoaded', async function () { - // --- Global Flags & State --- - let uiReady = false; - let listenersAttached = false; - let isGenerating = false; - let wavesurfer = null; - let currentAudioBlobUrl = null; - let saveStateTimeout = null; - - let currentConfig = {}; - let currentUiState = {}; - let appPresets = []; - let availableVoices = []; - - let hideGenerationWarning = false; - let currentVoice = 'expr-voice-5-m'; - - const IS_LOCAL_FILE = window.location.protocol === 'file:'; - // If you always access the server via localhost - const API_BASE_URL = IS_LOCAL_FILE ? 
'http://localhost:8005' : ''; - - const DEBOUNCE_DELAY_MS = 750; - - // KittenTTS available voices - const KITTEN_TTS_VOICES = [ - 'expr-voice-2-m', 'expr-voice-2-f', 'expr-voice-3-m', 'expr-voice-3-f', - 'expr-voice-4-m', 'expr-voice-4-f', 'expr-voice-5-m', 'expr-voice-5-f' - ]; - - // --- DOM Element Selectors --- - const appTitleLink = document.getElementById('app-title-link'); - const themeToggleButton = document.getElementById('theme-toggle-btn'); - const themeSwitchThumb = themeToggleButton ? themeToggleButton.querySelector('.theme-switch-thumb') : null; - const notificationArea = document.getElementById('notification-area'); - const ttsForm = document.getElementById('tts-form'); - const ttsFormHeader = document.getElementById('tts-form-header'); - const textArea = document.getElementById('text'); - const charCount = document.getElementById('char-count'); - const generateBtn = document.getElementById('generate-btn'); - const splitTextToggle = document.getElementById('split-text-toggle'); - const chunkSizeControls = document.getElementById('chunk-size-controls'); - const chunkSizeSlider = document.getElementById('chunk-size-slider'); - const chunkSizeValue = document.getElementById('chunk-size-value'); - const chunkExplanation = document.getElementById('chunk-explanation'); - const voiceSelect = document.getElementById('voice-select'); - const presetsContainer = document.getElementById('presets-container'); - const presetsPlaceholder = document.getElementById('presets-placeholder'); - const speedSlider = document.getElementById('speed'); - const speedValueDisplay = document.getElementById('speed-value'); - const languageSelectContainer = document.getElementById('language-select-container'); - const languageSelect = document.getElementById('language'); - const outputFormatSelect = document.getElementById('output-format'); - const saveGenDefaultsBtn = document.getElementById('save-gen-defaults-btn'); - const genDefaultsStatus = 
document.getElementById('gen-defaults-status'); - const serverConfigForm = document.getElementById('server-config-form'); - const saveConfigBtn = document.getElementById('save-config-btn'); - const restartServerBtn = document.getElementById('restart-server-btn'); - const configStatus = document.getElementById('config-status'); - const resetSettingsBtn = document.getElementById('reset-settings-btn'); - const audioPlayerContainer = document.getElementById('audio-player-container'); - const loadingOverlay = document.getElementById('loading-overlay'); - const loadingMessage = document.getElementById('loading-message'); - const loadingStatusText = document.getElementById('loading-status'); - const loadingCancelBtn = document.getElementById('loading-cancel-btn'); - const generationWarningModal = document.getElementById('generation-warning-modal'); - const generationWarningAcknowledgeBtn = document.getElementById('generation-warning-acknowledge'); - const hideGenerationWarningCheckbox = document.getElementById('hide-generation-warning-checkbox'); - - // --- Utility Functions --- - function showNotification(message, type = 'info', duration = 5000) { - if (!notificationArea) return null; - const icons = { - success: '', - error: '', - warning: '', - info: '' - }; - const typeClassMap = { success: 'notification-success', error: 'notification-error', warning: 'notification-warning', info: 'notification-info' }; - const notificationDiv = document.createElement('div'); - notificationDiv.className = `notification-base ${typeClassMap[type] || 'notification-info'}`; - notificationDiv.setAttribute('role', 'alert'); - // Create content wrapper - const contentWrapper = document.createElement('div'); - contentWrapper.className = 'flex items-start flex-grow'; - contentWrapper.innerHTML = `${icons[type] || icons['info']} ${message}`; - - // Create close button - const closeButton = document.createElement('button'); - closeButton.type = 'button'; - closeButton.className = 'ml-auto 
-mx-1.5 -my-1.5 bg-transparent rounded-lg p-1.5 inline-flex h-8 w-8 items-center justify-center text-current hover:bg-slate-200 dark:hover:bg-slate-700 focus:outline-none focus:ring-2 focus:ring-slate-400 flex-shrink-0'; - closeButton.innerHTML = 'Close'; - closeButton.onclick = () => { - notificationDiv.style.transition = 'opacity 0.3s ease, transform 0.3s ease'; - notificationDiv.style.opacity = '0'; - notificationDiv.style.transform = 'translateY(-20px)'; - setTimeout(() => notificationDiv.remove(), 300); - }; - - // Add both to notification - notificationDiv.appendChild(contentWrapper); - notificationDiv.appendChild(closeButton); - notificationArea.appendChild(notificationDiv); - if (duration > 0) setTimeout(() => closeButton.click(), duration); - return notificationDiv; - } - - function formatTime(seconds) { - const minutes = Math.floor(seconds / 60); - const secs = Math.floor(seconds % 60).toString().padStart(2, '0'); - return `${minutes}:${secs}`; - } - - // --- Theme Management --- - function applyTheme(theme) { - const isDark = theme === 'dark'; - document.documentElement.classList.toggle('dark', isDark); - if (themeSwitchThumb) { - themeSwitchThumb.classList.toggle('translate-x-6', isDark); - themeSwitchThumb.classList.toggle('bg-indigo-500', isDark); - themeSwitchThumb.classList.toggle('bg-white', !isDark); - } - if (wavesurfer) { - wavesurfer.setOptions({ - waveColor: isDark ? '#6366f1' : '#a5b4fc', - progressColor: isDark ? '#4f46e5' : '#6366f1', - cursorColor: isDark ? '#cbd5e1' : '#475569', - }); - } - localStorage.setItem('uiTheme', theme); - } - - if (themeToggleButton) { - themeToggleButton.addEventListener('click', () => { - const newTheme = document.documentElement.classList.contains('dark') ? 'light' : 'dark'; - applyTheme(newTheme); - debouncedSaveState(); - }); - } - - // --- UI State Persistence --- - async function saveCurrentUiState() { - const stateToSave = { - last_text: textArea ? 
textArea.value : '', - last_voice: currentVoice, - last_chunk_size: chunkSizeSlider ? parseInt(chunkSizeSlider.value, 10) : 120, - last_split_text_enabled: splitTextToggle ? splitTextToggle.checked : true, - hide_generation_warning: hideGenerationWarning, - theme: localStorage.getItem('uiTheme') || 'dark' - }; - try { - const response = await fetch(`${API_BASE_URL}/save_settings`, { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ ui_state: stateToSave }) - }); - if (!response.ok) { - const errorResult = await response.json(); - throw new Error(errorResult.detail || `Failed to save UI state (status ${response.status})`); - } - } catch (error) { - console.error("Error saving UI state via API:", error); - showNotification(`Error saving settings: ${error.message}. Some changes may not persist.`, 'error', 0); - } - } - - function debouncedSaveState() { - // Do not save anything until the entire UI has finished its initial setup. - if (!uiReady || !listenersAttached) { return; } - clearTimeout(saveStateTimeout); - saveStateTimeout = setTimeout(saveCurrentUiState, DEBOUNCE_DELAY_MS); - } - - // --- Initial Application Setup --- - function initializeApplication() { - const preferredTheme = localStorage.getItem('uiTheme') || currentUiState.theme || 'dark'; - applyTheme(preferredTheme); - const pageTitle = currentConfig?.ui?.title || "Kitten TTS Server"; - document.title = pageTitle; - if (appTitleLink) appTitleLink.textContent = pageTitle; - if (ttsFormHeader) ttsFormHeader.textContent = `Generate Speech`; - loadInitialUiState(); - populateVoices(); - populatePresets(); - displayServerConfiguration(); - if (languageSelectContainer && currentConfig?.ui?.show_language_select === false) { - languageSelectContainer.classList.add('hidden'); - } - const initialGenResult = currentConfig.initial_gen_result; - if (initialGenResult && initialGenResult.outputUrl) { - initializeWaveSurfer(initialGenResult.outputUrl, initialGenResult); - } - 
} - - async function fetchInitialData() { - try { - const response = await fetch(`${API_BASE_URL}/api/ui/initial-data`); - if (!response.ok) { - const errorText = await response.text(); - throw new Error(`Failed to fetch initial UI data: ${response.status} ${response.statusText}. Server response: ${errorText}`); - } - const data = await response.json(); - currentConfig = data.config || {}; - currentUiState = currentConfig.ui_state || {}; - appPresets = data.presets || []; - availableVoices = data.available_voices || KITTEN_TTS_VOICES; - hideGenerationWarning = currentUiState.hide_generation_warning || false; - currentVoice = currentUiState.last_voice || 'expr-voice-5-m'; - - // This now ONLY sets values. It does NOT attach state-saving listeners. - initializeApplication(); - - } catch (error) { - console.error("Error fetching initial data:", error); - showNotification(`Could not load essential application data: ${error.message}. Please try refreshing.`, 'error', 0); - if (Object.keys(currentConfig).length === 0) { - currentConfig = { ui: { title: "Kitten TTS Server (Error Mode)" }, generation_defaults: {}, ui_state: {} }; - currentUiState = currentConfig.ui_state; - availableVoices = KITTEN_TTS_VOICES; - } - initializeApplication(); // Attempt to init in a degraded state - } finally { - // --- PHASE 2: Attach listeners and enable UI readiness --- - setTimeout(() => { - attachStateSavingListeners(); - listenersAttached = true; - uiReady = true; - }, 50); - } - } - - function loadInitialUiState() { - if (textArea && currentUiState.last_text) { - textArea.value = currentUiState.last_text; - if (charCount) charCount.textContent = textArea.value.length; - } - - if (splitTextToggle) splitTextToggle.checked = currentUiState.last_split_text_enabled !== undefined ? 
currentUiState.last_split_text_enabled : true; - if (chunkSizeSlider && currentUiState.last_chunk_size !== undefined) chunkSizeSlider.value = currentUiState.last_chunk_size; - if (chunkSizeValue) chunkSizeValue.textContent = chunkSizeSlider ? chunkSizeSlider.value : '120'; - toggleChunkControlsVisibility(); - - const genDefaults = currentConfig.generation_defaults || {}; - if (speedSlider) speedSlider.value = genDefaults.speed !== undefined ? genDefaults.speed : 1.0; - if (speedValueDisplay) speedValueDisplay.textContent = speedSlider.value; - if (languageSelect) languageSelect.value = genDefaults.language || 'en'; - if (outputFormatSelect) outputFormatSelect.value = currentConfig?.audio_output?.format || 'mp3'; - if (hideGenerationWarningCheckbox) hideGenerationWarningCheckbox.checked = hideGenerationWarning; - - if (textArea && !textArea.value && appPresets && appPresets.length > 0) { - const defaultPreset = appPresets.find(p => p.name === "Standard Narration") || appPresets; - if (defaultPreset) applyPreset(defaultPreset, false); - } - } - - function attachStateSavingListeners() { - if (textArea) textArea.addEventListener('input', () => { if (charCount) charCount.textContent = textArea.value.length; debouncedSaveState(); }); - if (voiceSelect) voiceSelect.addEventListener('change', () => { currentVoice = voiceSelect.value; debouncedSaveState(); }); - if (splitTextToggle) splitTextToggle.addEventListener('change', () => { toggleChunkControlsVisibility(); debouncedSaveState(); }); - if (chunkSizeSlider) { - chunkSizeSlider.addEventListener('input', () => { if (chunkSizeValue) chunkSizeValue.textContent = chunkSizeSlider.value; }); - chunkSizeSlider.addEventListener('change', debouncedSaveState); - } - if (speedSlider) { - speedSlider.addEventListener('input', () => { - if (speedValueDisplay) speedValueDisplay.textContent = speedSlider.value; - }); - speedSlider.addEventListener('change', debouncedSaveState); - } - if (languageSelect) 
languageSelect.addEventListener('change', debouncedSaveState); - if (outputFormatSelect) outputFormatSelect.addEventListener('change', debouncedSaveState); - } - - // --- Dynamic UI Population --- - function populateVoices() { - if (!voiceSelect) return; - const currentSelectedValue = voiceSelect.value; - voiceSelect.innerHTML = ''; - - availableVoices.forEach(voice => { - const option = document.createElement('option'); - option.value = voice; - // Format display name - const displayName = voice.replace('expr-voice-', 'Voice ').replace('-m', ' (Male)').replace('-f', ' (Female)'); - option.textContent = displayName; - voiceSelect.appendChild(option); - }); - - const lastSelected = currentUiState.last_voice; - if (currentSelectedValue !== 'none' && availableVoices.includes(currentSelectedValue)) { - voiceSelect.value = currentSelectedValue; - currentVoice = currentSelectedValue; - } else if (lastSelected && availableVoices.includes(lastSelected)) { - voiceSelect.value = lastSelected; - currentVoice = lastSelected; - } else { - voiceSelect.value = availableVoices || 'expr-voice-5-m'; - currentVoice = voiceSelect.value; - } - } - - function populatePresets() { - if (!presetsContainer || !appPresets) return; - if (appPresets.length === 0) { - if (presetsPlaceholder) presetsPlaceholder.textContent = 'No presets available.'; - return; - } - if (presetsPlaceholder) presetsPlaceholder.remove(); - presetsContainer.innerHTML = ''; - appPresets.forEach((preset, index) => { - const button = document.createElement('button'); - button.type = 'button'; - button.id = `preset-btn-${index}`; - button.className = 'preset-button'; - button.title = `Load '${preset.name}' text and settings`; - button.textContent = preset.name; - button.addEventListener('click', () => applyPreset(preset)); - presetsContainer.appendChild(button); - }); - } - - function applyPreset(presetData, showNotif = true) { - if (!presetData) return; - if (textArea && presetData.text !== undefined) { - textArea.value 
= presetData.text; - if (charCount) charCount.textContent = textArea.value.length; - } - const genParams = presetData.params || presetData; - if (speedSlider && genParams.speed !== undefined) speedSlider.value = genParams.speed; - if (languageSelect && genParams.language !== undefined) languageSelect.value = genParams.language; - if (speedValueDisplay && speedSlider) speedValueDisplay.textContent = speedSlider.value; - - if (genParams.voice && voiceSelect) { - const voiceExists = Array.from(voiceSelect.options).some(opt => opt.value === genParams.voice); - if (voiceExists) { - voiceSelect.value = genParams.voice; - currentVoice = genParams.voice; - } - } - - if (showNotif) showNotification(`Preset "${presetData.name}" loaded.`, 'info', 3000); - debouncedSaveState(); - } - - function toggleChunkControlsVisibility() { - const isChecked = splitTextToggle ? splitTextToggle.checked : false; - if (chunkSizeControls) chunkSizeControls.classList.toggle('hidden', !isChecked); - if (chunkExplanation) chunkExplanation.classList.toggle('hidden', !isChecked); - } - if (splitTextToggle) toggleChunkControlsVisibility(); - - // --- Audio Player (WaveSurfer) --- - function initializeWaveSurfer(audioUrl, resultDetails = {}) { - if (wavesurfer) { - wavesurfer.unAll(); - wavesurfer.destroy(); - wavesurfer = null; - } - if (currentAudioBlobUrl) { - URL.revokeObjectURL(currentAudioBlobUrl); - currentAudioBlobUrl = null; - } - currentAudioBlobUrl = audioUrl; - - // Ensure the container is clean or re-created - audioPlayerContainer.innerHTML = ` -
-
-

Generated Audio

-
-
-
- - - - - - Download - -
-
- Voice: -- - Gen Time: --s - Duration: --:-- -
-
-
-
`; - - // Re-select elements after recreating them - const waveformDiv = audioPlayerContainer.querySelector('#waveform'); - const playBtn = audioPlayerContainer.querySelector('#play-btn'); - const downloadLink = audioPlayerContainer.querySelector('#download-link'); - const playerVoiceSpan = audioPlayerContainer.querySelector('#player-voice'); - const playerGenTimeSpan = audioPlayerContainer.querySelector('#player-gen-time'); - const audioDurationSpan = audioPlayerContainer.querySelector('#audio-duration'); - - const audioFilename = resultDetails.filename || (typeof audioUrl === 'string' ? audioUrl.split('/').pop() : 'kitten_tts_output.wav'); - if (downloadLink) { - downloadLink.href = audioUrl; - downloadLink.download = audioFilename; - const downloadTextSpan = downloadLink.querySelector('span'); - if (downloadTextSpan) { - downloadTextSpan.textContent = `Download ${audioFilename.split('.').pop().toUpperCase()}`; - } - } - if (playerVoiceSpan) { - const displayVoice = resultDetails.submittedVoice || currentVoice || '--'; - playerVoiceSpan.textContent = displayVoice.replace('expr-voice-', 'Voice ').replace('-m', ' (Male)').replace('-f', ' (Female)'); - } - if (playerGenTimeSpan) playerGenTimeSpan.textContent = resultDetails.genTime ? `${resultDetails.genTime}s` : '--s'; - - const playIconSVG = `Play`; - const pauseIconSVG = `Pause`; - const isDark = document.documentElement.classList.contains('dark'); - - wavesurfer = WaveSurfer.create({ - container: waveformDiv, waveColor: isDark ? '#6366f1' : '#a5b4fc', progressColor: isDark ? '#4f46e5' : '#6366f1', - cursorColor: isDark ? 
'#cbd5e1' : '#475569', barWidth: 3, barRadius: 3, cursorWidth: 1, height: 80, barGap: 2, - responsive: true, url: audioUrl, mediaControls: false, normalize: true, - }); - - wavesurfer.on('ready', () => { - const duration = wavesurfer.getDuration(); - if (audioDurationSpan) audioDurationSpan.textContent = formatTime(duration); - if (playBtn) { playBtn.disabled = false; playBtn.innerHTML = playIconSVG; } - if (downloadLink) { downloadLink.classList.remove('opacity-50', 'pointer-events-none'); downloadLink.setAttribute('aria-disabled', 'false'); } - }); - wavesurfer.on('play', () => { if (playBtn) playBtn.innerHTML = pauseIconSVG; }); - wavesurfer.on('pause', () => { if (playBtn) playBtn.innerHTML = playIconSVG; }); - wavesurfer.on('finish', () => { if (playBtn) playBtn.innerHTML = playIconSVG; wavesurfer.seekTo(0); }); - wavesurfer.on('error', (err) => { - console.error("WaveSurfer error:", err); - showNotification(`Error loading audio waveform: ${err.message || err}`, 'error'); - if (waveformDiv) waveformDiv.innerHTML = `

Could not load waveform.

`; - if (playBtn) playBtn.disabled = true; - }); - - if (playBtn) { - playBtn.onclick = () => { - if (wavesurfer) { - wavesurfer.playPause(); - } - }; - } - setTimeout(() => audioPlayerContainer.scrollIntoView({ behavior: 'smooth', block: 'nearest' }), 150); - } - - // --- TTS Generation Logic --- - function getTTSFormData() { - const jsonData = { - text: textArea.value, - voice: currentVoice, - speed: parseFloat(speedSlider.value), - language: languageSelect.value, - split_text: splitTextToggle.checked, - chunk_size: parseInt(chunkSizeSlider.value, 10), - output_format: outputFormatSelect.value || 'mp3' - }; - return jsonData; - } - - async function submitTTSRequest() { - isGenerating = true; - showLoadingOverlay(); - const startTime = performance.now(); - const jsonData = getTTSFormData(); - try { - const response = await fetch(`${API_BASE_URL}/tts`, { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify(jsonData) - }); - if (!response.ok) { - const errorResult = await response.json().catch(() => ({ detail: `HTTP error ${response.status}` })); - throw new Error(errorResult.detail || 'TTS generation failed.'); - } - const audioBlob = await response.blob(); - const endTime = performance.now(); - const genTime = ((endTime - startTime) / 1000).toFixed(2); - const contentDisposition = response.headers.get('Content-Disposition'); - const filenameFromServer = contentDisposition - ? 
contentDisposition.split('filename=')[1]?.replace(/"/g, '') - : 'kitten_tts_output.wav'; - const resultDetails = { - outputUrl: URL.createObjectURL(audioBlob), filename: filenameFromServer, genTime: genTime, - submittedVoice: jsonData.voice - }; - initializeWaveSurfer(resultDetails.outputUrl, resultDetails); - showNotification('Audio generated successfully!', 'success'); - } catch (error) { - console.error('TTS Generation Error:', error); - showNotification(error.message || 'An unknown error occurred during TTS generation.', 'error'); - } finally { - isGenerating = false; - hideLoadingOverlay(); - } - } - - // --- Attach main generation event to the button's CLICK --- - if (generateBtn) { - generateBtn.addEventListener('click', function (event) { - event.preventDefault(); - - if (isGenerating) { - showNotification("Generation is already in progress.", "warning"); - return; - } - const textContent = textArea.value.trim(); - if (!textContent) { - showNotification("Please enter some text to generate speech.", 'error'); - return; - } - if (!currentVoice || currentVoice === 'none') { - showNotification("Please select a voice.", 'error'); - return; - } - - // Check for the generation quality warning. 
- if (!hideGenerationWarning) { - showGenerationWarningModal(); - return; - } - - submitTTSRequest(); - }); - } - - // --- Modal Handling --- - function showGenerationWarningModal() { - if (generationWarningModal) { - generationWarningModal.style.display = 'flex'; - generationWarningModal.classList.remove('hidden', 'opacity-0'); - generationWarningModal.dataset.state = 'open'; - } - } - function hideGenerationWarningModal() { - if (generationWarningModal) { - generationWarningModal.classList.add('opacity-0'); - setTimeout(() => { - generationWarningModal.style.display = 'none'; - generationWarningModal.dataset.state = 'closed'; - }, 300); - } - } - if (generationWarningAcknowledgeBtn) generationWarningAcknowledgeBtn.addEventListener('click', () => { - if (hideGenerationWarningCheckbox && hideGenerationWarningCheckbox.checked) hideGenerationWarning = true; - hideGenerationWarningModal(); debouncedSaveState(); submitTTSRequest(); - }); - if (loadingCancelBtn) loadingCancelBtn.addEventListener('click', () => { - if (isGenerating) { isGenerating = false; hideLoadingOverlay(); showNotification("Generation UI cancelled by user.", "info"); } - }); - function showLoadingOverlay() { - if (loadingOverlay && generateBtn && loadingCancelBtn) { - loadingMessage.textContent = 'Generating audio...'; - loadingStatusText.textContent = 'Please wait. 
This may take some time.'; - loadingOverlay.style.display = 'flex'; - loadingOverlay.classList.remove('hidden', 'opacity-0'); loadingOverlay.dataset.state = 'open'; - generateBtn.disabled = true; loadingCancelBtn.disabled = false; - } - } - function hideLoadingOverlay() { - if (loadingOverlay && generateBtn) { - loadingOverlay.classList.add('opacity-0'); - setTimeout(() => { - loadingOverlay.style.display = 'none'; - loadingOverlay.dataset.state = 'closed'; - }, 300); - generateBtn.disabled = false; - } - } - - // --- Configuration Management --- - function displayServerConfiguration() { - if (!serverConfigForm || !currentConfig || Object.keys(currentConfig).length === 0) return; - const fieldsToDisplay = { - "server.host": currentConfig.server?.host, "server.port": currentConfig.server?.port, - "tts_engine.device": currentConfig.tts_engine?.device, "model.repo_id": currentConfig.model?.repo_id, - "paths.model_cache": currentConfig.paths?.model_cache, "paths.output": currentConfig.paths?.output, - "audio_output.format": currentConfig.audio_output?.format, "audio_output.sample_rate": currentConfig.audio_output?.sample_rate - }; - for (const name in fieldsToDisplay) { - const input = serverConfigForm.querySelector(`input[name="${name}"]`); - if (input) { - input.value = fieldsToDisplay[name] !== undefined ? 
fieldsToDisplay[name] : ''; - if (name.includes('.host') || name.includes('.port') || name.includes('.device') || name.includes('paths.')) input.readOnly = true; - else input.readOnly = false; - } - } - } - async function updateConfigStatus(button, statusElem, message, type = 'info', duration = 5000, enableButtonAfter = true) { - const statusClasses = { success: 'text-green-600 dark:text-green-400', error: 'text-red-600 dark:text-red-400', warning: 'text-yellow-600 dark:text-yellow-400', info: 'text-indigo-600 dark:text-indigo-400', processing: 'text-yellow-600 dark:text-yellow-400 animate-pulse' }; - const isProcessing = message.toLowerCase().includes('saving') || message.toLowerCase().includes('restarting') || message.toLowerCase().includes('resetting'); - const messageType = isProcessing ? 'processing' : type; - if (statusElem) { - statusElem.textContent = message; - statusElem.className = `text-xs ml-2 ${statusClasses[messageType] || statusClasses['info']}`; - statusElem.classList.remove('hidden'); - } - if (button) button.disabled = isProcessing || (type === 'error' && !enableButtonAfter) || (type === 'success' && !enableButtonAfter); - if (duration > 0) setTimeout(() => { if (statusElem) statusElem.classList.add('hidden'); if (button && enableButtonAfter) button.disabled = false; }, duration); - else if (button && enableButtonAfter && !isProcessing) button.disabled = false; - } - - if (saveConfigBtn && configStatus) { - saveConfigBtn.addEventListener('click', async () => { - const configDataToSave = {}; - const inputs = serverConfigForm.querySelectorAll('input[name]:not([readonly]), select[name]:not([readonly])'); - inputs.forEach(input => { - const keys = input.name.split('.'); let currentLevel = configDataToSave; - keys.forEach((key, index) => { - if (index === keys.length - 1) { - let value = input.value; - if (input.type === 'number') value = parseFloat(value) || 0; - else if (input.type === 'checkbox') value = input.checked; - currentLevel[key] = value; 
- } else { currentLevel[key] = currentLevel[key] || {}; currentLevel = currentLevel[key]; } - }); - }); - if (Object.keys(configDataToSave).length === 0) { showNotification("No editable configuration values to save.", "info"); return; } - updateConfigStatus(saveConfigBtn, configStatus, 'Saving configuration...', 'info', 0, false); - try { - const response = await fetch(`${API_BASE_URL}/save_settings`, { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify(configDataToSave) - }); - const result = await response.json(); - if (!response.ok) throw new Error(result.detail || 'Failed to save configuration'); - updateConfigStatus(saveConfigBtn, configStatus, result.message || 'Configuration saved.', 'success', 5000); - if (result.restart_needed && restartServerBtn) restartServerBtn.classList.remove('hidden'); - await fetchInitialData(); - showNotification("Configuration saved. Some changes may require a server restart if prompted.", "success"); - } catch (error) { - console.error('Error saving server config:', error); - updateConfigStatus(saveConfigBtn, configStatus, `Error: ${error.message}`, 'error', 0); - } - }); - } - - if (saveGenDefaultsBtn && genDefaultsStatus) { - saveGenDefaultsBtn.addEventListener('click', async () => { - const genParams = { - speed: parseFloat(speedSlider.value), - language: languageSelect.value - }; - updateConfigStatus(saveGenDefaultsBtn, genDefaultsStatus, 'Saving generation defaults...', 'info', 0, false); - try { - const response = await fetch(`${API_BASE_URL}/save_settings`, { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ generation_defaults: genParams }) - }); - const result = await response.json(); - if (!response.ok) throw new Error(result.detail || 'Failed to save generation defaults'); - updateConfigStatus(saveGenDefaultsBtn, genDefaultsStatus, result.message || 'Generation defaults saved.', 'success', 5000); - if 
(currentConfig.generation_defaults) Object.assign(currentConfig.generation_defaults, genParams); - } catch (error) { - console.error('Error saving generation defaults:', error); - updateConfigStatus(saveGenDefaultsBtn, genDefaultsStatus, `Error: ${error.message}`, 'error', 0); - } - }); - } - - if (resetSettingsBtn) { - resetSettingsBtn.addEventListener('click', async () => { - if (!confirm("Are you sure you want to reset ALL settings to their initial defaults? This will affect config.yaml and UI preferences. This action cannot be undone.")) return; - updateConfigStatus(resetSettingsBtn, configStatus, 'Resetting settings...', 'info', 0, false); - try { - const response = await fetch(`${API_BASE_URL}/reset_settings`, { - method: 'POST' - }); - if (!response.ok) { - const errorResult = await response.json().catch(() => ({ detail: 'Failed to reset settings on server.' })); - throw new Error(errorResult.detail); - } - const result = await response.json(); - updateConfigStatus(resetSettingsBtn, configStatus, result.message + " Reloading page...", 'success', 0, false); - setTimeout(() => window.location.reload(true), 2000); - } catch (error) { - console.error('Error resetting settings:', error); - updateConfigStatus(resetSettingsBtn, configStatus, `Reset Error: ${error.message}`, 'error', 0); - showNotification(`Error resetting settings: ${error.message}`, 'error'); - } - }); - } - - if (restartServerBtn) { - restartServerBtn.addEventListener('click', async () => { - if (!confirm("Are you sure you want to restart the server?")) return; - updateConfigStatus(restartServerBtn, configStatus, 'Attempting server restart...', 'processing', 0, false); - try { - const response = await fetch(`${API_BASE_URL}/restart_server`, { - method: 'POST' - }); - const result = await response.json(); - if (!response.ok) throw new Error(result.detail || 'Server responded with error on restart command'); - showNotification("Server restart initiated. 
Please wait a moment for the server to come back online, then refresh the page.", "info", 10000); - } catch (error) { - showNotification(`Server restart command failed: ${error.message}`, "error"); - updateConfigStatus(restartServerBtn, configStatus, `Restart failed.`, 'error', 5000, true); - } - }); - } - - await fetchInitialData(); +// ui/script.js +// Client-side JavaScript for the Kitten TTS Server web interface. +// Handles UI interactions, API communication, audio playback, and settings management. + +document.addEventListener('DOMContentLoaded', async function () { + // --- Global Flags & State --- + let uiReady = false; + let listenersAttached = false; + let isGenerating = false; + let wavesurfer = null; + let currentAudioBlobUrl = null; + let saveStateTimeout = null; + + let currentConfig = {}; + let currentUiState = {}; + let appPresets = []; + let availableVoices = []; + + let hideGenerationWarning = false; + let currentVoice = 'expr-voice-5-m'; + + const IS_LOCAL_FILE = window.location.protocol === 'file:'; + // If you always access the server via localhost + const API_BASE_URL = IS_LOCAL_FILE ? 'http://localhost:8005' : ''; + + const DEBOUNCE_DELAY_MS = 750; + + // KittenTTS available voices + const KITTEN_TTS_VOICES = [ + 'expr-voice-2-m', 'expr-voice-2-f', 'expr-voice-3-m', 'expr-voice-3-f', + 'expr-voice-4-m', 'expr-voice-4-f', 'expr-voice-5-m', 'expr-voice-5-f' + ]; + + // --- DOM Element Selectors --- + const appTitleLink = document.getElementById('app-title-link'); + const themeToggleButton = document.getElementById('theme-toggle-btn'); + const themeSwitchThumb = themeToggleButton ? 
themeToggleButton.querySelector('.theme-switch-thumb') : null; + const notificationArea = document.getElementById('notification-area'); + const ttsForm = document.getElementById('tts-form'); + const ttsFormHeader = document.getElementById('tts-form-header'); + const textArea = document.getElementById('text'); + const charCount = document.getElementById('char-count'); + const generateBtn = document.getElementById('generate-btn'); + const splitTextToggle = document.getElementById('split-text-toggle'); + const chunkSizeControls = document.getElementById('chunk-size-controls'); + const chunkSizeSlider = document.getElementById('chunk-size-slider'); + const chunkSizeValue = document.getElementById('chunk-size-value'); + const chunkExplanation = document.getElementById('chunk-explanation'); + const voiceSelect = document.getElementById('voice-select'); + const presetsContainer = document.getElementById('presets-container'); + const presetsPlaceholder = document.getElementById('presets-placeholder'); + const speedSlider = document.getElementById('speed'); + const speedValueDisplay = document.getElementById('speed-value'); + const languageSelectContainer = document.getElementById('language-select-container'); + const languageSelect = document.getElementById('language'); + const outputFormatSelect = document.getElementById('output-format'); + const saveGenDefaultsBtn = document.getElementById('save-gen-defaults-btn'); + const genDefaultsStatus = document.getElementById('gen-defaults-status'); + const serverConfigForm = document.getElementById('server-config-form'); + const saveConfigBtn = document.getElementById('save-config-btn'); + const restartServerBtn = document.getElementById('restart-server-btn'); + const configStatus = document.getElementById('config-status'); + const resetSettingsBtn = document.getElementById('reset-settings-btn'); + const audioPlayerContainer = document.getElementById('audio-player-container'); + const loadingOverlay = 
document.getElementById('loading-overlay'); + const loadingMessage = document.getElementById('loading-message'); + const loadingStatusText = document.getElementById('loading-status'); + const loadingCancelBtn = document.getElementById('loading-cancel-btn'); + const generationWarningModal = document.getElementById('generation-warning-modal'); + const generationWarningAcknowledgeBtn = document.getElementById('generation-warning-acknowledge'); + const hideGenerationWarningCheckbox = document.getElementById('hide-generation-warning-checkbox'); + + // --- Utility Functions --- + function showNotification(message, type = 'info', duration = 5000) { + if (!notificationArea) return null; + const icons = { + success: '', + error: '', + warning: '', + info: '' + }; + const typeClassMap = { success: 'notification-success', error: 'notification-error', warning: 'notification-warning', info: 'notification-info' }; + const notificationDiv = document.createElement('div'); + notificationDiv.className = `notification-base ${typeClassMap[type] || 'notification-info'}`; + notificationDiv.setAttribute('role', 'alert'); + // Create content wrapper + const contentWrapper = document.createElement('div'); + contentWrapper.className = 'flex items-start flex-grow'; + contentWrapper.innerHTML = `${icons[type] || icons['info']} ${message}`; + + // Create close button + const closeButton = document.createElement('button'); + closeButton.type = 'button'; + closeButton.className = 'ml-auto -mx-1.5 -my-1.5 bg-transparent rounded-lg p-1.5 inline-flex h-8 w-8 items-center justify-center text-current hover:bg-slate-200 dark:hover:bg-slate-700 focus:outline-none focus:ring-2 focus:ring-slate-400 flex-shrink-0'; + closeButton.innerHTML = 'Close'; + closeButton.onclick = () => { + notificationDiv.style.transition = 'opacity 0.3s ease, transform 0.3s ease'; + notificationDiv.style.opacity = '0'; + notificationDiv.style.transform = 'translateY(-20px)'; + setTimeout(() => notificationDiv.remove(), 300); + 
}; + + // Add both to notification + notificationDiv.appendChild(contentWrapper); + notificationDiv.appendChild(closeButton); + notificationArea.appendChild(notificationDiv); + if (duration > 0) setTimeout(() => closeButton.click(), duration); + return notificationDiv; + } + + function formatTime(seconds) { + const minutes = Math.floor(seconds / 60); + const secs = Math.floor(seconds % 60).toString().padStart(2, '0'); + return `${minutes}:${secs}`; + } + + // --- Theme Management --- + function applyTheme(theme) { + const isDark = theme === 'dark'; + document.documentElement.classList.toggle('dark', isDark); + if (themeSwitchThumb) { + themeSwitchThumb.classList.toggle('translate-x-6', isDark); + themeSwitchThumb.classList.toggle('bg-indigo-500', isDark); + themeSwitchThumb.classList.toggle('bg-white', !isDark); + } + if (wavesurfer) { + wavesurfer.setOptions({ + waveColor: isDark ? '#6366f1' : '#a5b4fc', + progressColor: isDark ? '#4f46e5' : '#6366f1', + cursorColor: isDark ? '#cbd5e1' : '#475569', + }); + } + localStorage.setItem('uiTheme', theme); + } + + if (themeToggleButton) { + themeToggleButton.addEventListener('click', () => { + const newTheme = document.documentElement.classList.contains('dark') ? 'light' : 'dark'; + applyTheme(newTheme); + debouncedSaveState(); + }); + } + + // --- UI State Persistence --- + async function saveCurrentUiState() { + const stateToSave = { + last_text: textArea ? textArea.value : '', + last_voice: currentVoice, + last_chunk_size: chunkSizeSlider ? parseInt(chunkSizeSlider.value, 10) : 120, + last_split_text_enabled: splitTextToggle ? 
splitTextToggle.checked : true, + hide_generation_warning: hideGenerationWarning, + theme: localStorage.getItem('uiTheme') || 'dark' + }; + try { + const response = await fetch(`${API_BASE_URL}/save_settings`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ ui_state: stateToSave }) + }); + if (!response.ok) { + const errorResult = await response.json(); + throw new Error(errorResult.detail || `Failed to save UI state (status ${response.status})`); + } + } catch (error) { + console.error("Error saving UI state via API:", error); + showNotification(`Error saving settings: ${error.message}. Some changes may not persist.`, 'error', 0); + } + } + + function debouncedSaveState() { + // Do not save anything until the entire UI has finished its initial setup. + if (!uiReady || !listenersAttached) { return; } + clearTimeout(saveStateTimeout); + saveStateTimeout = setTimeout(saveCurrentUiState, DEBOUNCE_DELAY_MS); + } + + // --- Initial Application Setup --- + function initializeApplication() { + const preferredTheme = localStorage.getItem('uiTheme') || currentUiState.theme || 'dark'; + applyTheme(preferredTheme); + const pageTitle = currentConfig?.ui?.title || "Kitten TTS Server"; + document.title = pageTitle; + if (appTitleLink) appTitleLink.textContent = pageTitle; + if (ttsFormHeader) ttsFormHeader.textContent = `Generate Speech`; + loadInitialUiState(); + populateVoices(); + populatePresets(); + displayServerConfiguration(); + if (languageSelectContainer && currentConfig?.ui?.show_language_select === false) { + languageSelectContainer.classList.add('hidden'); + } + const initialGenResult = currentConfig.initial_gen_result; + if (initialGenResult && initialGenResult.outputUrl) { + initializeWaveSurfer(initialGenResult.outputUrl, initialGenResult); + } + } + + async function fetchInitialData() { + try { + const response = await fetch(`${API_BASE_URL}/api/ui/initial-data`); + if (!response.ok) { + const errorText = await 
response.text(); + throw new Error(`Failed to fetch initial UI data: ${response.status} ${response.statusText}. Server response: ${errorText}`); + } + const data = await response.json(); + currentConfig = data.config || {}; + currentUiState = currentConfig.ui_state || {}; + appPresets = data.presets || []; + availableVoices = data.available_voices || KITTEN_TTS_VOICES; + hideGenerationWarning = currentUiState.hide_generation_warning || false; + currentVoice = currentUiState.last_voice || 'expr-voice-5-m'; + + // This now ONLY sets values. It does NOT attach state-saving listeners. + initializeApplication(); + + } catch (error) { + console.error("Error fetching initial data:", error); + showNotification(`Could not load essential application data: ${error.message}. Please try refreshing.`, 'error', 0); + if (Object.keys(currentConfig).length === 0) { + currentConfig = { ui: { title: "Kitten TTS Server (Error Mode)" }, generation_defaults: {}, ui_state: {} }; + currentUiState = currentConfig.ui_state; + availableVoices = KITTEN_TTS_VOICES; + } + initializeApplication(); // Attempt to init in a degraded state + } finally { + // --- PHASE 2: Attach listeners and enable UI readiness --- + setTimeout(() => { + attachStateSavingListeners(); + listenersAttached = true; + uiReady = true; + }, 50); + } + } + + function loadInitialUiState() { + if (textArea && currentUiState.last_text) { + textArea.value = currentUiState.last_text; + if (charCount) charCount.textContent = textArea.value.length; + } + + if (splitTextToggle) splitTextToggle.checked = currentUiState.last_split_text_enabled !== undefined ? currentUiState.last_split_text_enabled : true; + if (chunkSizeSlider && currentUiState.last_chunk_size !== undefined) chunkSizeSlider.value = currentUiState.last_chunk_size; + if (chunkSizeValue) chunkSizeValue.textContent = chunkSizeSlider ? 
chunkSizeSlider.value : '120'; + toggleChunkControlsVisibility(); + + const genDefaults = currentConfig.generation_defaults || {}; + if (speedSlider) speedSlider.value = genDefaults.speed !== undefined ? genDefaults.speed : 1.0; + if (speedValueDisplay) speedValueDisplay.textContent = speedSlider.value; + if (languageSelect) languageSelect.value = genDefaults.language || 'en'; + if (outputFormatSelect) outputFormatSelect.value = currentConfig?.audio_output?.format || 'mp3'; + if (hideGenerationWarningCheckbox) hideGenerationWarningCheckbox.checked = hideGenerationWarning; + + if (textArea && !textArea.value && appPresets && appPresets.length > 0) { + const defaultPreset = appPresets.find(p => p.name === "Standard Narration") || appPresets; + if (defaultPreset) applyPreset(defaultPreset, false); + } + } + + function attachStateSavingListeners() { + if (textArea) textArea.addEventListener('input', () => { if (charCount) charCount.textContent = textArea.value.length; debouncedSaveState(); }); + if (voiceSelect) voiceSelect.addEventListener('change', () => { currentVoice = voiceSelect.value; debouncedSaveState(); }); + if (splitTextToggle) splitTextToggle.addEventListener('change', () => { toggleChunkControlsVisibility(); debouncedSaveState(); }); + if (chunkSizeSlider) { + chunkSizeSlider.addEventListener('input', () => { if (chunkSizeValue) chunkSizeValue.textContent = chunkSizeSlider.value; }); + chunkSizeSlider.addEventListener('change', debouncedSaveState); + } + if (speedSlider) { + speedSlider.addEventListener('input', () => { + if (speedValueDisplay) speedValueDisplay.textContent = speedSlider.value; + }); + speedSlider.addEventListener('change', debouncedSaveState); + } + if (languageSelect) languageSelect.addEventListener('change', debouncedSaveState); + if (outputFormatSelect) outputFormatSelect.addEventListener('change', debouncedSaveState); + } + + // --- Dynamic UI Population --- + function populateVoices() { + if (!voiceSelect) return; + const 
currentSelectedValue = voiceSelect.value; + voiceSelect.innerHTML = ''; + + availableVoices.forEach(voice => { + const option = document.createElement('option'); + option.value = voice; + // Format display name + const displayName = voice.replace('expr-voice-', 'Voice ').replace('-m', ' (Male)').replace('-f', ' (Female)'); + option.textContent = displayName; + voiceSelect.appendChild(option); + }); + + const lastSelected = currentUiState.last_voice; + if (currentSelectedValue !== 'none' && availableVoices.includes(currentSelectedValue)) { + voiceSelect.value = currentSelectedValue; + currentVoice = currentSelectedValue; + } else if (lastSelected && availableVoices.includes(lastSelected)) { + voiceSelect.value = lastSelected; + currentVoice = lastSelected; + } else { + voiceSelect.value = availableVoices || 'expr-voice-5-m'; + currentVoice = voiceSelect.value; + } + } + + function populatePresets() { + if (!presetsContainer || !appPresets) return; + if (appPresets.length === 0) { + if (presetsPlaceholder) presetsPlaceholder.textContent = 'No presets available.'; + return; + } + if (presetsPlaceholder) presetsPlaceholder.remove(); + presetsContainer.innerHTML = ''; + appPresets.forEach((preset, index) => { + const button = document.createElement('button'); + button.type = 'button'; + button.id = `preset-btn-${index}`; + button.className = 'preset-button'; + button.title = `Load '${preset.name}' text and settings`; + button.textContent = preset.name; + button.addEventListener('click', () => applyPreset(preset)); + presetsContainer.appendChild(button); + }); + } + + function applyPreset(presetData, showNotif = true) { + if (!presetData) return; + if (textArea && presetData.text !== undefined) { + textArea.value = presetData.text; + if (charCount) charCount.textContent = textArea.value.length; + } + const genParams = presetData.params || presetData; + if (speedSlider && genParams.speed !== undefined) speedSlider.value = genParams.speed; + if (languageSelect && 
genParams.language !== undefined) languageSelect.value = genParams.language; + if (speedValueDisplay && speedSlider) speedValueDisplay.textContent = speedSlider.value; + + if (genParams.voice && voiceSelect) { + const voiceExists = Array.from(voiceSelect.options).some(opt => opt.value === genParams.voice); + if (voiceExists) { + voiceSelect.value = genParams.voice; + currentVoice = genParams.voice; + } + } + + if (showNotif) showNotification(`Preset "${presetData.name}" loaded.`, 'info', 3000); + debouncedSaveState(); + } + + function toggleChunkControlsVisibility() { + const isChecked = splitTextToggle ? splitTextToggle.checked : false; + if (chunkSizeControls) chunkSizeControls.classList.toggle('hidden', !isChecked); + if (chunkExplanation) chunkExplanation.classList.toggle('hidden', !isChecked); + } + if (splitTextToggle) toggleChunkControlsVisibility(); + + // --- Audio Player (WaveSurfer) --- + function initializeWaveSurfer(audioUrl, resultDetails = {}) { + if (wavesurfer) { + wavesurfer.unAll(); + wavesurfer.destroy(); + wavesurfer = null; + } + if (currentAudioBlobUrl) { + URL.revokeObjectURL(currentAudioBlobUrl); + currentAudioBlobUrl = null; + } + currentAudioBlobUrl = audioUrl; + + // Ensure the container is clean or re-created + audioPlayerContainer.innerHTML = ` +
+
+

Generated Audio

+
+
+
+ + + + + + Download + +
+
+ Voice: -- + Gen Time: --s + Duration: --:-- +
+
+
+
`; + + // Re-select elements after recreating them + const waveformDiv = audioPlayerContainer.querySelector('#waveform'); + const playBtn = audioPlayerContainer.querySelector('#play-btn'); + const downloadLink = audioPlayerContainer.querySelector('#download-link'); + const playerVoiceSpan = audioPlayerContainer.querySelector('#player-voice'); + const playerGenTimeSpan = audioPlayerContainer.querySelector('#player-gen-time'); + const audioDurationSpan = audioPlayerContainer.querySelector('#audio-duration'); + + const audioFilename = resultDetails.filename || (typeof audioUrl === 'string' ? audioUrl.split('/').pop() : 'kitten_tts_output.wav'); + if (downloadLink) { + downloadLink.href = audioUrl; + downloadLink.download = audioFilename; + const downloadTextSpan = downloadLink.querySelector('span'); + if (downloadTextSpan) { + downloadTextSpan.textContent = `Download ${audioFilename.split('.').pop().toUpperCase()}`; + } + } + if (playerVoiceSpan) { + const displayVoice = resultDetails.submittedVoice || currentVoice || '--'; + playerVoiceSpan.textContent = displayVoice.replace('expr-voice-', 'Voice ').replace('-m', ' (Male)').replace('-f', ' (Female)'); + } + if (playerGenTimeSpan) playerGenTimeSpan.textContent = resultDetails.genTime ? `${resultDetails.genTime}s` : '--s'; + + const playIconSVG = `Play`; + const pauseIconSVG = `Pause`; + const isDark = document.documentElement.classList.contains('dark'); + + wavesurfer = WaveSurfer.create({ + container: waveformDiv, waveColor: isDark ? '#6366f1' : '#a5b4fc', progressColor: isDark ? '#4f46e5' : '#6366f1', + cursorColor: isDark ? 
'#cbd5e1' : '#475569', barWidth: 3, barRadius: 3, cursorWidth: 1, height: 80, barGap: 2, + responsive: true, url: audioUrl, mediaControls: false, normalize: true, + }); + + wavesurfer.on('ready', () => { + const duration = wavesurfer.getDuration(); + if (audioDurationSpan) audioDurationSpan.textContent = formatTime(duration); + if (playBtn) { playBtn.disabled = false; playBtn.innerHTML = playIconSVG; } + if (downloadLink) { downloadLink.classList.remove('opacity-50', 'pointer-events-none'); downloadLink.setAttribute('aria-disabled', 'false'); } + }); + wavesurfer.on('play', () => { if (playBtn) playBtn.innerHTML = pauseIconSVG; }); + wavesurfer.on('pause', () => { if (playBtn) playBtn.innerHTML = playIconSVG; }); + wavesurfer.on('finish', () => { if (playBtn) playBtn.innerHTML = playIconSVG; wavesurfer.seekTo(0); }); + wavesurfer.on('error', (err) => { + console.error("WaveSurfer error:", err); + showNotification(`Error loading audio waveform: ${err.message || err}`, 'error'); + if (waveformDiv) waveformDiv.innerHTML = `

Could not load waveform.

`; + if (playBtn) playBtn.disabled = true; + }); + + if (playBtn) { + playBtn.onclick = () => { + if (wavesurfer) { + wavesurfer.playPause(); + } + }; + } + setTimeout(() => audioPlayerContainer.scrollIntoView({ behavior: 'smooth', block: 'nearest' }), 150); + } + + // --- TTS Generation Logic --- + function getTTSFormData() { + const jsonData = { + text: textArea.value, + voice: currentVoice, + speed: parseFloat(speedSlider.value), + language: languageSelect.value, + split_text: splitTextToggle.checked, + chunk_size: parseInt(chunkSizeSlider.value, 10), + output_format: outputFormatSelect.value || 'mp3' + }; + return jsonData; + } + + async function submitTTSRequest() { + isGenerating = true; + showLoadingOverlay(); + const startTime = performance.now(); + const jsonData = getTTSFormData(); + try { + const response = await fetch(`${API_BASE_URL}/tts`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(jsonData) + }); + if (!response.ok) { + const errorResult = await response.json().catch(() => ({ detail: `HTTP error ${response.status}` })); + throw new Error(errorResult.detail || 'TTS generation failed.'); + } + const audioBlob = await response.blob(); + const endTime = performance.now(); + const genTime = ((endTime - startTime) / 1000).toFixed(2); + const contentDisposition = response.headers.get('Content-Disposition'); + const filenameFromServer = contentDisposition + ? 
contentDisposition.split('filename=')[1]?.replace(/"/g, '') + : 'kitten_tts_output.wav'; + const resultDetails = { + outputUrl: URL.createObjectURL(audioBlob), filename: filenameFromServer, genTime: genTime, + submittedVoice: jsonData.voice + }; + initializeWaveSurfer(resultDetails.outputUrl, resultDetails); + showNotification('Audio generated successfully!', 'success'); + } catch (error) { + console.error('TTS Generation Error:', error); + showNotification(error.message || 'An unknown error occurred during TTS generation.', 'error'); + } finally { + isGenerating = false; + hideLoadingOverlay(); + } + } + + // --- Attach main generation event to the button's CLICK --- + if (generateBtn) { + generateBtn.addEventListener('click', function (event) { + event.preventDefault(); + + if (isGenerating) { + showNotification("Generation is already in progress.", "warning"); + return; + } + const textContent = textArea.value.trim(); + if (!textContent) { + showNotification("Please enter some text to generate speech.", 'error'); + return; + } + if (!currentVoice || currentVoice === 'none') { + showNotification("Please select a voice.", 'error'); + return; + } + + // Check for the generation quality warning. 
+ if (!hideGenerationWarning) { + showGenerationWarningModal(); + return; + } + + submitTTSRequest(); + }); + } + + // --- Modal Handling --- + function showGenerationWarningModal() { + if (generationWarningModal) { + generationWarningModal.style.display = 'flex'; + generationWarningModal.classList.remove('hidden', 'opacity-0'); + generationWarningModal.dataset.state = 'open'; + } + } + function hideGenerationWarningModal() { + if (generationWarningModal) { + generationWarningModal.classList.add('opacity-0'); + setTimeout(() => { + generationWarningModal.style.display = 'none'; + generationWarningModal.dataset.state = 'closed'; + }, 300); + } + } + if (generationWarningAcknowledgeBtn) generationWarningAcknowledgeBtn.addEventListener('click', () => { + if (hideGenerationWarningCheckbox && hideGenerationWarningCheckbox.checked) hideGenerationWarning = true; + hideGenerationWarningModal(); debouncedSaveState(); submitTTSRequest(); + }); + if (loadingCancelBtn) loadingCancelBtn.addEventListener('click', () => { + if (isGenerating) { isGenerating = false; hideLoadingOverlay(); showNotification("Generation UI cancelled by user.", "info"); } + }); + function showLoadingOverlay() { + if (loadingOverlay && generateBtn && loadingCancelBtn) { + loadingMessage.textContent = 'Generating audio...'; + loadingStatusText.textContent = 'Please wait. 
This may take some time.'; + loadingOverlay.style.display = 'flex'; + loadingOverlay.classList.remove('hidden', 'opacity-0'); loadingOverlay.dataset.state = 'open'; + generateBtn.disabled = true; loadingCancelBtn.disabled = false; + } + } + function hideLoadingOverlay() { + if (loadingOverlay && generateBtn) { + loadingOverlay.classList.add('opacity-0'); + setTimeout(() => { + loadingOverlay.style.display = 'none'; + loadingOverlay.dataset.state = 'closed'; + }, 300); + generateBtn.disabled = false; + } + } + + // --- Configuration Management --- + function displayServerConfiguration() { + if (!serverConfigForm || !currentConfig || Object.keys(currentConfig).length === 0) return; + const fieldsToDisplay = { + "server.host": currentConfig.server?.host, "server.port": currentConfig.server?.port, + "tts_engine.device": currentConfig.tts_engine?.device, "model.repo_id": currentConfig.model?.repo_id, + "paths.model_cache": currentConfig.paths?.model_cache, "paths.output": currentConfig.paths?.output, + "audio_output.format": currentConfig.audio_output?.format, "audio_output.sample_rate": currentConfig.audio_output?.sample_rate + }; + for (const name in fieldsToDisplay) { + const input = serverConfigForm.querySelector(`input[name="${name}"]`); + if (input) { + input.value = fieldsToDisplay[name] !== undefined ? 
fieldsToDisplay[name] : ''; + if (name.includes('.host') || name.includes('.port') || name.includes('.device') || name.includes('paths.')) input.readOnly = true; + else input.readOnly = false; + } + } + } + async function updateConfigStatus(button, statusElem, message, type = 'info', duration = 5000, enableButtonAfter = true) { + const statusClasses = { success: 'text-green-600 dark:text-green-400', error: 'text-red-600 dark:text-red-400', warning: 'text-yellow-600 dark:text-yellow-400', info: 'text-indigo-600 dark:text-indigo-400', processing: 'text-yellow-600 dark:text-yellow-400 animate-pulse' }; + const isProcessing = message.toLowerCase().includes('saving') || message.toLowerCase().includes('restarting') || message.toLowerCase().includes('resetting'); + const messageType = isProcessing ? 'processing' : type; + if (statusElem) { + statusElem.textContent = message; + statusElem.className = `text-xs ml-2 ${statusClasses[messageType] || statusClasses['info']}`; + statusElem.classList.remove('hidden'); + } + if (button) button.disabled = isProcessing || (type === 'error' && !enableButtonAfter) || (type === 'success' && !enableButtonAfter); + if (duration > 0) setTimeout(() => { if (statusElem) statusElem.classList.add('hidden'); if (button && enableButtonAfter) button.disabled = false; }, duration); + else if (button && enableButtonAfter && !isProcessing) button.disabled = false; + } + + if (saveConfigBtn && configStatus) { + saveConfigBtn.addEventListener('click', async () => { + const configDataToSave = {}; + const inputs = serverConfigForm.querySelectorAll('input[name]:not([readonly]), select[name]:not([readonly])'); + inputs.forEach(input => { + const keys = input.name.split('.'); let currentLevel = configDataToSave; + keys.forEach((key, index) => { + if (index === keys.length - 1) { + let value = input.value; + if (input.type === 'number') value = parseFloat(value) || 0; + else if (input.type === 'checkbox') value = input.checked; + currentLevel[key] = value; 
+ } else { currentLevel[key] = currentLevel[key] || {}; currentLevel = currentLevel[key]; } + }); + }); + if (Object.keys(configDataToSave).length === 0) { showNotification("No editable configuration values to save.", "info"); return; } + updateConfigStatus(saveConfigBtn, configStatus, 'Saving configuration...', 'info', 0, false); + try { + const response = await fetch(`${API_BASE_URL}/save_settings`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(configDataToSave) + }); + const result = await response.json(); + if (!response.ok) throw new Error(result.detail || 'Failed to save configuration'); + updateConfigStatus(saveConfigBtn, configStatus, result.message || 'Configuration saved.', 'success', 5000); + if (result.restart_needed && restartServerBtn) restartServerBtn.classList.remove('hidden'); + await fetchInitialData(); + showNotification("Configuration saved. Some changes may require a server restart if prompted.", "success"); + } catch (error) { + console.error('Error saving server config:', error); + updateConfigStatus(saveConfigBtn, configStatus, `Error: ${error.message}`, 'error', 0); + } + }); + } + + if (saveGenDefaultsBtn && genDefaultsStatus) { + saveGenDefaultsBtn.addEventListener('click', async () => { + const genParams = { + speed: parseFloat(speedSlider.value), + language: languageSelect.value + }; + updateConfigStatus(saveGenDefaultsBtn, genDefaultsStatus, 'Saving generation defaults...', 'info', 0, false); + try { + const response = await fetch(`${API_BASE_URL}/save_settings`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ generation_defaults: genParams }) + }); + const result = await response.json(); + if (!response.ok) throw new Error(result.detail || 'Failed to save generation defaults'); + updateConfigStatus(saveGenDefaultsBtn, genDefaultsStatus, result.message || 'Generation defaults saved.', 'success', 5000); + if 
(currentConfig.generation_defaults) Object.assign(currentConfig.generation_defaults, genParams); + } catch (error) { + console.error('Error saving generation defaults:', error); + updateConfigStatus(saveGenDefaultsBtn, genDefaultsStatus, `Error: ${error.message}`, 'error', 0); + } + }); + } + + if (resetSettingsBtn) { + resetSettingsBtn.addEventListener('click', async () => { + if (!confirm("Are you sure you want to reset ALL settings to their initial defaults? This will affect config.yaml and UI preferences. This action cannot be undone.")) return; + updateConfigStatus(resetSettingsBtn, configStatus, 'Resetting settings...', 'info', 0, false); + try { + const response = await fetch(`${API_BASE_URL}/reset_settings`, { + method: 'POST' + }); + if (!response.ok) { + const errorResult = await response.json().catch(() => ({ detail: 'Failed to reset settings on server.' })); + throw new Error(errorResult.detail); + } + const result = await response.json(); + updateConfigStatus(resetSettingsBtn, configStatus, result.message + " Reloading page...", 'success', 0, false); + setTimeout(() => window.location.reload(true), 2000); + } catch (error) { + console.error('Error resetting settings:', error); + updateConfigStatus(resetSettingsBtn, configStatus, `Reset Error: ${error.message}`, 'error', 0); + showNotification(`Error resetting settings: ${error.message}`, 'error'); + } + }); + } + + if (restartServerBtn) { + restartServerBtn.addEventListener('click', async () => { + if (!confirm("Are you sure you want to restart the server?")) return; + updateConfigStatus(restartServerBtn, configStatus, 'Attempting server restart...', 'processing', 0, false); + try { + const response = await fetch(`${API_BASE_URL}/restart_server`, { + method: 'POST' + }); + const result = await response.json(); + if (!response.ok) throw new Error(result.detail || 'Server responded with error on restart command'); + showNotification("Server restart initiated. 
Please wait a moment for the server to come back online, then refresh the page.", "info", 10000); + } catch (error) { + showNotification(`Server restart command failed: ${error.message}`, "error"); + updateConfigStatus(restartServerBtn, configStatus, `Restart failed.`, 'error', 5000, true); + } + }); + } + + await fetchInitialData(); }); \ No newline at end of file diff --git a/utils.py b/utils.py index cd7dc90..d757b5d 100644 --- a/utils.py +++ b/utils.py @@ -1,1144 +1,1144 @@ -# utils.py -# Utility functions for the TTS server application. -# This module includes functions for audio processing, text manipulation, -# file system operations, and performance monitoring. - -import os -import logging -import re -import time -import io -import uuid -from pathlib import Path -from typing import Optional, Tuple, Dict, Any, Set, List -from pydub import AudioSegment - -import numpy as np -import soundfile as sf -import torchaudio # For saving PyTorch tensors and potentially speed adjustment. -import torch - -# Configuration manager to get paths dynamically. -# Assumes config.py and its config_manager are in the same directory or accessible via PYTHONPATH. -from config import config_manager - -# Optional import for librosa (for audio resampling, e.g., Opus encoding and time stretching) -try: - import librosa - - LIBROSA_AVAILABLE = True - logger = logging.getLogger( - __name__ - ) # Initialize logger here if librosa is available - logger.info( - "Librosa library found and will be used for audio resampling and time stretching." - ) -except ImportError: - LIBROSA_AVAILABLE = False - logger = logging.getLogger(__name__) - logger.warning( - "Librosa library not found. Advanced audio resampling features (e.g., for Opus encoding) " - "and pitch-preserving speed adjustment will be limited. Speed adjustment will fall back to basic method if enabled." 
- ) - -# Optional import for Parselmouth (for unvoiced segment detection) -try: - import parselmouth - - PARSELMOUTH_AVAILABLE = True - logger.info( - "Parselmouth library found and will be used for unvoiced segment removal if enabled." - ) -except ImportError: - PARSELMOUTH_AVAILABLE = False - logger.warning( - "Parselmouth library not found. Unvoiced segment removal feature will be disabled." - ) - - -# --- Filename Sanitization --- -def sanitize_filename(filename: str) -> str: - """ - Removes potentially unsafe characters and path components from a filename - to make it safe for use in file paths. Replaces unsafe sequences with underscores. - - Args: - filename: The original filename string. - - Returns: - A sanitized filename string, ensuring it's not empty and reasonably short. - """ - if not filename: - # Generate a unique name if the input is empty. - return f"unnamed_file_{uuid.uuid4().hex[:8]}" - - # Remove directory separators and leading/trailing whitespace. - base_filename = Path(filename).name.strip() - if not base_filename: - return f"empty_basename_{uuid.uuid4().hex[:8]}" - - # Define a set of allowed characters (alphanumeric, underscore, hyphen, dot, space). - # Spaces will be replaced by underscores later. - safe_chars = set( - "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._- " - ) - sanitized_list = [] - last_char_was_underscore = False - - for char in base_filename: - if char in safe_chars: - # Replace spaces with underscores. - sanitized_list.append("_" if char == " " else char) - last_char_was_underscore = char == " " - elif not last_char_was_underscore: - # Replace any disallowed character sequence with a single underscore. - sanitized_list.append("_") - last_char_was_underscore = True - - sanitized = "".join(sanitized_list).strip("_") - - # Prevent names starting with multiple dots or consisting only of dots/underscores. 
- if not sanitized or sanitized.lstrip("._") == "": - return f"sanitized_file_{uuid.uuid4().hex[:8]}" - - # Limit filename length (e.g., 100 characters), preserving the extension. - max_len = 100 - if len(sanitized) > max_len: - name_part, ext_part = os.path.splitext(sanitized) - # Ensure extension is not overly long itself; common extensions are short. - ext_part = ext_part[:10] # Limit extension length just in case. - name_part = name_part[ - : max_len - len(ext_part) - 1 - ] # -1 for the dot if ext exists - sanitized = name_part + ext_part - logger.warning( - f"Original filename '{base_filename}' was truncated to '{sanitized}' due to length limits." - ) - - if not sanitized: # Should not happen with previous checks, but as a failsafe. - return f"final_fallback_name_{uuid.uuid4().hex[:8]}" - - return sanitized - - -# --- Constants for Text Processing --- -# Set of common abbreviations to help with sentence splitting. -ABBREVIATIONS: Set[str] = { - "mr.", - "mrs.", - "ms.", - "dr.", - "prof.", - "rev.", - "hon.", - "st.", - "etc.", - "e.g.", - "i.e.", - "vs.", - "approx.", - "apt.", - "dept.", - "fig.", - "gen.", - "gov.", - "inc.", - "jr.", - "sr.", - "ltd.", - "no.", - "p.", - "pp.", - "vol.", - "op.", - "cit.", - "ca.", - "cf.", - "ed.", - "esp.", - "et.", - "al.", - "ibid.", - "id.", - "inf.", - "sup.", - "viz.", - "sc.", - "fl.", - "d.", - "b.", - "r.", - "c.", - "v.", - "u.s.", - "u.k.", - "a.m.", - "p.m.", - "a.d.", - "b.c.", -} - -# Common titles that might appear without a period if cleaned by other means first. -TITLES_NO_PERIOD: Set[str] = { - "mr", - "mrs", - "ms", - "dr", - "prof", - "rev", - "hon", - "st", - "sgt", - "capt", - "lt", - "col", - "gen", -} - -# Regex patterns (pre-compiled for efficiency in text processing). -NUMBER_DOT_NUMBER_PATTERN = re.compile( - r"(? Optional[bytes]: - """ - Encodes a NumPy audio array into the specified format (Opus or WAV) in memory. - Can resample the audio to a target sample rate before encoding if specified. 
- - Args: - audio_array: NumPy array containing audio data (expected as float32, range [-1, 1]). - sample_rate: Sample rate of the input audio data. - output_format: Desired output format ('opus', 'wav' or 'mp3'). - target_sample_rate: Optional target sample rate to resample to before encoding. - - Returns: - Bytes object containing the encoded audio, or None if encoding fails. - """ - if audio_array is None or audio_array.size == 0: - logger.warning("encode_audio received empty or None audio array.") - return None - - # Ensure audio is float32 for consistent processing. - if audio_array.dtype != np.float32: - if np.issubdtype(audio_array.dtype, np.integer): - max_val = np.iinfo(audio_array.dtype).max - audio_array = audio_array.astype(np.float32) / max_val - else: # Fallback for other types, assuming they might be float64 or similar - audio_array = audio_array.astype(np.float32) - logger.debug(f"Converted audio array to float32 for encoding.") - - # Ensure audio is mono if it's (samples, 1) - if audio_array.ndim == 2 and audio_array.shape[1] == 1: - audio_array = audio_array.squeeze(axis=1) - logger.debug( - "Squeezed audio array from (samples, 1) to (samples,) for encoding." - ) - elif ( - audio_array.ndim > 1 - ): # Multi-channel not directly supported by simple encoding path, attempt to take first channel - logger.warning( - f"Multi-channel audio (shape: {audio_array.shape}) provided to encode_audio. Using only the first channel." - ) - audio_array = audio_array[:, 0] - - # Resample if target_sample_rate is provided and different from current sample_rate - if ( - target_sample_rate is not None - and target_sample_rate != sample_rate - and LIBROSA_AVAILABLE - ): - try: - logger.info( - f"Resampling audio from {sample_rate}Hz to {target_sample_rate}Hz using Librosa." 
- ) - audio_array = librosa.resample( - y=audio_array, orig_sr=sample_rate, target_sr=target_sample_rate - ) - sample_rate = ( - target_sample_rate # Update sample_rate for subsequent encoding - ) - except Exception as e_resample: - logger.error( - f"Error resampling audio to {target_sample_rate}Hz: {e_resample}. Proceeding with original sample rate {sample_rate}.", - exc_info=True, - ) - elif target_sample_rate is not None and target_sample_rate != sample_rate: - logger.warning( - f"Librosa not available. Cannot resample audio from {sample_rate}Hz to {target_sample_rate}Hz. " - f"Proceeding with original sample rate for encoding." - ) - - start_time = time.time() - output_buffer = io.BytesIO() - - try: - audio_to_write = audio_array - rate_to_write = sample_rate - - if output_format == "opus": - OPUS_SUPPORTED_RATES = {8000, 12000, 16000, 24000, 48000} - TARGET_OPUS_RATE = 48000 # Preferred Opus rate. - - if rate_to_write not in OPUS_SUPPORTED_RATES: - if LIBROSA_AVAILABLE: - logger.warning( - f"Current sample rate {rate_to_write}Hz not directly supported by Opus. " - f"Resampling to {TARGET_OPUS_RATE}Hz using Librosa for Opus encoding." - ) - audio_to_write = librosa.resample( - y=audio_array, orig_sr=rate_to_write, target_sr=TARGET_OPUS_RATE - ) - rate_to_write = TARGET_OPUS_RATE - else: - logger.error( - f"Librosa not available. Cannot resample audio from {rate_to_write}Hz for Opus encoding. " - f"Opus encoding may fail or produce poor quality." - ) - # Proceed with current rate, soundfile might handle it or fail. - sf.write( - output_buffer, - audio_to_write, - rate_to_write, - format="ogg", - subtype="opus", - ) - - elif output_format == "wav": - # WAV typically uses int16 for broader compatibility. - # Clip audio to [-1.0, 1.0] before converting to int16 to prevent overflow. 
- audio_clipped = np.clip(audio_array, -1.0, 1.0) - audio_int16 = (audio_clipped * 32767).astype(np.int16) - audio_to_write = audio_int16 # Use the int16 version for WAV - sf.write( - output_buffer, - audio_to_write, - rate_to_write, - format="wav", - subtype="pcm_16", - ) - - elif output_format == "mp3": - audio_clipped = np.clip(audio_array, -1.0, 1.0) - audio_int16 = (audio_clipped * 32767).astype(np.int16) - audio_segment = AudioSegment( - audio_int16.tobytes(), - frame_rate=sample_rate, - sample_width=2, - channels=1, - ) - audio_segment.export(output_buffer, format="mp3") - - else: - logger.error( - f"Unsupported output format requested for encoding: {output_format}" - ) - return None - - encoded_bytes = output_buffer.getvalue() - end_time = time.time() - logger.info( - f"Encoded {len(encoded_bytes)} bytes to '{output_format}' at {rate_to_write}Hz in {end_time - start_time:.3f} seconds." - ) - return encoded_bytes - - except ImportError as ie_sf: # Specifically for soundfile import issues - logger.critical( - f"The 'soundfile' library or its dependency (libsndfile) is not installed or found. " - f"Audio encoding/saving is not possible. Please install it. Error: {ie_sf}" - ) - return None - except Exception as e: - logger.error(f"Error encoding audio to '{output_format}': {e}", exc_info=True) - return None - - -def save_audio_to_file( - audio_array: np.ndarray, sample_rate: int, file_path_str: str -) -> bool: - """ - Saves a NumPy audio array to a WAV file. - - Args: - audio_array: NumPy array containing audio data (float32, range [-1, 1]). - sample_rate: Sample rate of the audio data. - file_path_str: String path to save the WAV file. - - Returns: - True if saving was successful, False otherwise. 
- """ - if audio_array is None or audio_array.size == 0: - logger.warning("save_audio_to_file received empty or None audio array.") - return False - - file_path = Path(file_path_str) - if file_path.suffix.lower() != ".wav": - logger.warning( - f"File path '{file_path_str}' does not end with .wav. Appending .wav extension." - ) - file_path = file_path.with_suffix(".wav") - - start_time = time.time() - try: - # Ensure output directory exists. - file_path.parent.mkdir(parents=True, exist_ok=True) - - # Prepare audio for WAV (int16, clipped). - if ( - audio_array.dtype != np.float32 - ): # Ensure float32 before potential scaling to int16 - if np.issubdtype(audio_array.dtype, np.integer): - max_val = np.iinfo(audio_array.dtype).max - audio_array = audio_array.astype(np.float32) / max_val - else: - audio_array = audio_array.astype(np.float32) - - audio_clipped = np.clip(audio_array, -1.0, 1.0) - audio_int16 = (audio_clipped * 32767).astype(np.int16) - - sf.write( - str(file_path), audio_int16, sample_rate, format="wav", subtype="pcm_16" - ) - end_time = time.time() - logger.info( - f"Saved WAV file to {file_path} in {end_time - start_time:.3f} seconds." - ) - return True - except ImportError: - logger.critical("SoundFile library not found. Cannot save audio.") - return False - except Exception as e: - logger.error(f"Error saving WAV file to {file_path}: {e}", exc_info=True) - return False - - -def save_audio_tensor_to_file( - audio_tensor: torch.Tensor, - sample_rate: int, - file_path_str: str, - output_format: str = "wav", -) -> bool: - """ - Saves a PyTorch audio tensor to a file using torchaudio. - - Args: - audio_tensor: PyTorch tensor containing audio data. - sample_rate: Sample rate of the audio data. - file_path_str: String path to save the audio file. - output_format: Desired output format (passed to torchaudio.save). - - Returns: - True if saving was successful, False otherwise. 
- """ - if audio_tensor is None or audio_tensor.numel() == 0: - logger.warning("save_audio_tensor_to_file received empty or None audio tensor.") - return False - - file_path = Path(file_path_str) - start_time = time.time() - try: - file_path.parent.mkdir(parents=True, exist_ok=True) - # torchaudio.save expects tensor on CPU. - audio_tensor_cpu = audio_tensor.cpu() - # Ensure tensor is 2D (channels, samples) for torchaudio.save. - if audio_tensor_cpu.ndim == 1: - audio_tensor_cpu = audio_tensor_cpu.unsqueeze(0) - - torchaudio.save( - str(file_path), audio_tensor_cpu, sample_rate, format=output_format - ) - end_time = time.time() - logger.info( - f"Saved audio tensor to {file_path} (format: {output_format}) in {end_time - start_time:.3f} seconds." - ) - return True - except Exception as e: - logger.error(f"Error saving audio tensor to {file_path}: {e}", exc_info=True) - return False - - -# --- Audio Manipulation Utilities --- -def apply_speed_factor( - audio_tensor: torch.Tensor, sample_rate: int, speed_factor: float -) -> Tuple[torch.Tensor, int]: - """ - Applies a speed factor to an audio tensor. - Uses librosa.effects.time_stretch if available for pitch preservation. - Falls back to simple resampling via torchaudio.transforms.Resample if librosa is not available, - which will alter pitch. - - Args: - audio_tensor: Input audio waveform (PyTorch tensor, expected mono). - sample_rate: Sample rate of the input audio. - speed_factor: Desired speed factor (e.g., 1.0 is normal, 1.5 is faster, 0.5 is slower). - - Returns: - A tuple of the speed-adjusted audio tensor and its sample rate (which remains unchanged). - Returns the original tensor and sample rate if speed_factor is 1.0 or if adjustment fails. - """ - if speed_factor == 1.0: - return audio_tensor, sample_rate - if speed_factor <= 0: - logger.warning( - f"Invalid speed_factor {speed_factor}. Must be positive. Returning original audio." 
- ) - return audio_tensor, sample_rate - - audio_tensor_cpu = audio_tensor.cpu() - # Ensure tensor is 1D mono for librosa and consistent handling - if audio_tensor_cpu.ndim == 2: - if audio_tensor_cpu.shape[0] == 1: - audio_tensor_cpu = audio_tensor_cpu.squeeze(0) - elif audio_tensor_cpu.shape[1] == 1: - audio_tensor_cpu = audio_tensor_cpu.squeeze(1) - else: # True stereo or multi-channel - logger.warning( - f"apply_speed_factor received multi-channel audio (shape {audio_tensor_cpu.shape}). Using first channel only." - ) - audio_tensor_cpu = audio_tensor_cpu[0, :] - - if audio_tensor_cpu.ndim != 1: - logger.error( - f"apply_speed_factor: audio_tensor_cpu is not 1D after processing (shape {audio_tensor_cpu.shape}). Returning original audio." - ) - return audio_tensor, sample_rate - - if LIBROSA_AVAILABLE: - try: - audio_np = audio_tensor_cpu.numpy() - # librosa.effects.time_stretch changes duration, not sample rate directly. - # The 'rate' parameter in time_stretch is equivalent to speed_factor. - stretched_audio_np = librosa.effects.time_stretch( - y=audio_np, rate=speed_factor - ) - speed_adjusted_tensor = torch.from_numpy(stretched_audio_np) - logger.info( - f"Applied speed factor {speed_factor} using librosa.effects.time_stretch. Original SR: {sample_rate}" - ) - return speed_adjusted_tensor, sample_rate # Sample rate is preserved - except Exception as e_librosa: - logger.error( - f"Failed to apply speed factor {speed_factor} using librosa: {e_librosa}. 
" - f"Falling back to basic resampling (pitch will change).", - exc_info=True, - ) - # Fallback to simple resampling (changes pitch) - try: - new_sample_rate_for_speedup = int(sample_rate / speed_factor) - resampler = torchaudio.transforms.Resample( - orig_freq=sample_rate, new_freq=new_sample_rate_for_speedup - ) - # Resample to new_sample_rate_for_speedup to change duration, then resample back to original SR - # This is effectively what sox 'speed' does, but 'tempo' is better (which librosa does) - # For simplicity in fallback, just resample and note pitch change - # To actually change speed without changing sample rate and preserving pitch using *only* torchaudio is more complex - # and typically involves phase vocoder or similar, which is beyond a simple fallback. - # The torchaudio.functional.pitch_shift and then torchaudio.functional.speed is one way, - # but librosa is simpler. - # Given the instruction "Fallback to original audio" if librosa not available or fails, we'll stick to that. - # Original plan: "If Librosa is not available, log a warning and return the original audio" - logger.warning( - f"Librosa failed for speed factor. Returning original audio as primary fallback." - ) - return audio_tensor, sample_rate - - except Exception as e_resample_fallback: - logger.error( - f"Fallback resampling for speed factor {speed_factor} also failed: {e_resample_fallback}. Returning original audio.", - exc_info=True, - ) - return audio_tensor, sample_rate - - else: # Librosa not available - logger.warning( - f"Librosa not available for pitch-preserving speed adjustment (factor: {speed_factor}). " - f"Returning original audio. Install librosa for this feature." 
- ) - return audio_tensor, sample_rate - - -def trim_lead_trail_silence( - audio_array: np.ndarray, - sample_rate: int, - silence_threshold_db: float = -40.0, - min_silence_duration_ms: int = 100, - padding_ms: int = 50, -) -> np.ndarray: - """ - Trims silence from the beginning and end of a NumPy audio array using a dB threshold. - - Args: - audio_array: NumPy array (float32) of the audio. - sample_rate: Sample rate of the audio. - silence_threshold_db: Silence threshold in dBFS. Segments below this are considered silent. - min_silence_duration_ms: Minimum duration of silence to be trimmed (ms). - padding_ms: Padding to leave at the start/end after trimming (ms). - - Returns: - Trimmed NumPy audio array. Returns original if no significant silence is found or on error. - """ - if audio_array is None or audio_array.size == 0: - return audio_array - - try: - if not LIBROSA_AVAILABLE: - logger.warning("Librosa not available, skipping silence trimming.") - return audio_array - - top_db_threshold = abs(silence_threshold_db) - - frame_length = 2048 - hop_length = 512 - - trimmed_audio, index = librosa.effects.trim( - y=audio_array, - top_db=top_db_threshold, - frame_length=frame_length, - hop_length=hop_length, - ) - - start_sample, end_sample = index[0], index[1] - - padding_samples = int((padding_ms / 1000.0) * sample_rate) - final_start = max(0, start_sample - padding_samples) - final_end = min(len(audio_array), end_sample + padding_samples) - - if final_end > final_start: # Ensure the slice is valid - # Check if significant trimming occurred - original_length = len(audio_array) - trimmed_length_with_padding = final_end - final_start - # Heuristic: if length changed by more than just padding, or if original silence was more than min_duration - # For simplicity, if librosa.effects.trim found *any* indices different from [0, original_length], - # it means some trimming potential was identified. 
- if index[0] > 0 or index[1] < original_length: - logger.debug( - f"Silence trimmed: original samples {original_length}, new effective samples {trimmed_length_with_padding} (indices before padding: {index})" - ) - return audio_array[final_start:final_end] - - logger.debug( - "No significant leading/trailing silence found to trim, or result would be empty." - ) - return audio_array - - except Exception as e: - logger.error(f"Error during silence trimming: {e}", exc_info=True) - return audio_array - - -def fix_internal_silence( - audio_array: np.ndarray, - sample_rate: int, - silence_threshold_db: float = -40.0, - min_silence_to_fix_ms: int = 700, - max_allowed_silence_ms: int = 300, -) -> np.ndarray: - """ - Reduces long internal silences in a NumPy audio array to a specified maximum duration. - Uses Librosa to split by silence. - - Args: - audio_array: NumPy array (float32) of the audio. - sample_rate: Sample rate of the audio. - silence_threshold_db: Silence threshold in dBFS. - min_silence_to_fix_ms: Minimum duration of an internal silence to be shortened (ms). - max_allowed_silence_ms: Target maximum duration for long silences (ms). - - Returns: - NumPy audio array with long internal silences shortened. Original if no fix needed or on error. 
- """ - if audio_array is None or audio_array.size == 0: - return audio_array - - try: - if not LIBROSA_AVAILABLE: - logger.warning("Librosa not available, skipping internal silence fixing.") - return audio_array - - top_db_threshold = abs(silence_threshold_db) - min_silence_len_samples = int((min_silence_to_fix_ms / 1000.0) * sample_rate) - max_silence_samples_to_keep = int( - (max_allowed_silence_ms / 1000.0) * sample_rate - ) - - non_silent_intervals = librosa.effects.split( - y=audio_array, - top_db=top_db_threshold, - frame_length=2048, # Can be tuned - hop_length=512, # Can be tuned - ) - - if len(non_silent_intervals) <= 1: - logger.debug("No significant internal silences found to fix.") - return audio_array - - fixed_audio_parts = [] - last_nonsilent_end = 0 - - for i, (start_sample, end_sample) in enumerate(non_silent_intervals): - silence_duration_samples = start_sample - last_nonsilent_end - if silence_duration_samples > 0: - if silence_duration_samples >= min_silence_len_samples: - silence_to_add = audio_array[ - last_nonsilent_end : last_nonsilent_end - + max_silence_samples_to_keep - ] - fixed_audio_parts.append(silence_to_add) - logger.debug( - f"Shortened internal silence from {silence_duration_samples} to {max_silence_samples_to_keep} samples." - ) - else: - fixed_audio_parts.append( - audio_array[last_nonsilent_end:start_sample] - ) - fixed_audio_parts.append(audio_array[start_sample:end_sample]) - last_nonsilent_end = end_sample - - # Handle potential silence after the very last non-silent segment - # This part is tricky as librosa.effects.split only gives non-silent parts. - # The trim_lead_trail_silence should handle overall trailing silence. - # This function focuses on *between* non-silent segments. 
- if last_nonsilent_end < len(audio_array): - trailing_segment = audio_array[last_nonsilent_end:] - # Check if this trailing segment is mostly silence and long enough to shorten - # For simplicity, we'll assume trim_lead_trail_silence handles the very end. - # Or, we could append it if it's short, or shorten it if it's long silence. - # To avoid over-complication here, let's just append what's left. - # The primary goal is internal silences. - # However, if the last "non_silent_interval" was short and followed by a long silence, - # that silence needs to be handled here too. - silence_duration_samples = len(audio_array) - last_nonsilent_end - if silence_duration_samples > 0: - if silence_duration_samples >= min_silence_len_samples: - fixed_audio_parts.append( - audio_array[ - last_nonsilent_end : last_nonsilent_end - + max_silence_samples_to_keep - ] - ) - logger.debug( - f"Shortened trailing silence from {silence_duration_samples} to {max_silence_samples_to_keep} samples." - ) - else: - fixed_audio_parts.append(trailing_segment) - - if not fixed_audio_parts: # Should not happen if non_silent_intervals > 1 - logger.warning( - "Internal silence fixing resulted in no audio parts; returning original." - ) - return audio_array - - return np.concatenate(fixed_audio_parts) - - except Exception as e: - logger.error(f"Error during internal silence fixing: {e}", exc_info=True) - return audio_array - - -def remove_long_unvoiced_segments( - audio_array: np.ndarray, - sample_rate: int, - min_unvoiced_duration_ms: int = 300, - pitch_floor: float = 75.0, - pitch_ceiling: float = 600.0, -) -> np.ndarray: - """ - Removes segments from a NumPy audio array that are unvoiced for longer than - the specified duration, using Parselmouth for pitch analysis. - - Args: - audio_array: NumPy array (float32) of the audio. - sample_rate: Sample rate of the audio. - min_unvoiced_duration_ms: Minimum duration (ms) of an unvoiced segment to be removed. 
- pitch_floor: Minimum pitch (Hz) to consider for voicing. - pitch_ceiling: Maximum pitch (Hz) to consider for voicing. - - Returns: - NumPy audio array with long unvoiced segments removed. Original if Parselmouth not available or on error. - """ - if not PARSELMOUTH_AVAILABLE: - logger.warning("Parselmouth not available, skipping unvoiced segment removal.") - return audio_array - if audio_array is None or audio_array.size == 0: - return audio_array - - try: - sound = parselmouth.Sound( - audio_array.astype(np.float64), sampling_frequency=sample_rate - ) - pitch = sound.to_pitch(pitch_floor=pitch_floor, pitch_ceiling=pitch_ceiling) - voiced_unvoiced = pitch.get_VoicedVoicelessUnvoiced() - - segments_to_keep = [] - current_segment_start_sample = 0 - min_unvoiced_samples = int((min_unvoiced_duration_ms / 1000.0) * sample_rate) - - for i in range(len(voiced_unvoiced.time_intervals)): - interval_start_time, interval_end_time, is_voiced_str = ( - voiced_unvoiced.time_intervals[i] - ) - is_voiced = is_voiced_str == "voiced" - - interval_start_sample = int(interval_start_time * sample_rate) - interval_end_sample = int(interval_end_time * sample_rate) - interval_duration_samples = interval_end_sample - interval_start_sample - - if is_voiced: - segments_to_keep.append( - audio_array[current_segment_start_sample:interval_end_sample] - ) - current_segment_start_sample = interval_end_sample - else: # Unvoiced segment - if interval_duration_samples < min_unvoiced_samples: - segments_to_keep.append( - audio_array[current_segment_start_sample:interval_end_sample] - ) - current_segment_start_sample = interval_end_sample - else: - logger.debug( - f"Removing long unvoiced segment from {interval_start_time:.2f}s to {interval_end_time:.2f}s." 
- ) - # Append the audio *before* this long unvoiced segment (if any) - if interval_start_sample > current_segment_start_sample: - segments_to_keep.append( - audio_array[ - current_segment_start_sample:interval_start_sample - ] - ) - current_segment_start_sample = interval_end_sample - - if current_segment_start_sample < len(audio_array): - segments_to_keep.append(audio_array[current_segment_start_sample:]) - - if not segments_to_keep: - logger.warning( - "Unvoiced segment removal resulted in empty audio; returning original." - ) - return audio_array - - return np.concatenate(segments_to_keep) - - except Exception as e: - logger.error(f"Error during unvoiced segment removal: {e}", exc_info=True) - return audio_array - - -# --- Text Processing Utilities --- -def _is_valid_sentence_end(text: str, period_index: int) -> bool: - """ - Checks if a period at a given index in the text is likely a valid sentence terminator, - rather than part of an abbreviation, number, or version string. - """ - word_start_before_period = period_index - 1 - scan_limit = max(0, period_index - 10) - while ( - word_start_before_period >= scan_limit - and not text[word_start_before_period].isspace() - ): - word_start_before_period -= 1 - word_before_period = text[word_start_before_period + 1 : period_index + 1].lower() - if word_before_period in ABBREVIATIONS: - return False - - context_start = max(0, period_index - 10) - context_end = min(len(text), period_index + 10) - context_segment = text[context_start:context_end] - relative_period_index_in_context = period_index - context_start - - for pattern in [NUMBER_DOT_NUMBER_PATTERN, VERSION_PATTERN]: - for match in pattern.finditer(context_segment): - if match.start() <= relative_period_index_in_context < match.end(): - is_last_char_of_numeric_match = ( - relative_period_index_in_context == match.end() - 1 - ) - is_followed_by_space_or_eos = ( - period_index + 1 == len(text) or text[period_index + 1].isspace() - ) - if not 
(is_last_char_of_numeric_match and is_followed_by_space_or_eos): - return False - return True - - -def _split_text_by_punctuation(text: str) -> List[str]: - """ - Splits text into sentences based on common punctuation marks (.!?), - while trying to avoid splitting on periods used in abbreviations or numbers. - """ - sentences: List[str] = [] - last_split_index = 0 - text_length = len(text) - - for match in POTENTIAL_END_PATTERN.finditer(text): - punctuation_char_index = match.start(1) - punctuation_char = text[punctuation_char_index] - slice_end_after_punctuation = match.start(1) + 1 + len(match.group(2) or "") - - if punctuation_char in ["!", "?"]: - current_sentence_text = text[ - last_split_index:slice_end_after_punctuation - ].strip() - if current_sentence_text: - sentences.append(current_sentence_text) - last_split_index = match.end() - continue - - if punctuation_char == ".": - if ( - punctuation_char_index > 0 and text[punctuation_char_index - 1] == "." - ) or ( - punctuation_char_index < text_length - 1 - and text[punctuation_char_index + 1] == "." - ): - continue - - if _is_valid_sentence_end(text, punctuation_char_index): - current_sentence_text = text[ - last_split_index:slice_end_after_punctuation - ].strip() - if current_sentence_text: - sentences.append(current_sentence_text) - last_split_index = match.end() - - remaining_text_segment = text[last_split_index:].strip() - if remaining_text_segment: - sentences.append(remaining_text_segment) - - sentences = [s for s in sentences if s] - if not sentences and text.strip(): - return [text.strip()] - return sentences - - -def split_into_sentences(text: str) -> List[str]: - """ - Splits a given text into sentences. Handles normalization of line breaks - and considers bullet points as potential sentence separators. - This is the primary entry point for sentence splitting. 
- """ - if not text or text.isspace(): - return [] - - text = text.replace("\r\n", "\n").replace("\r", "\n") - bullet_point_matches = list(BULLET_POINT_PATTERN.finditer(text)) - - if bullet_point_matches: - logger.debug("Bullet points detected in text; splitting by bullet items.") - processed_sentences: List[str] = [] - current_position = 0 - for i, bullet_match in enumerate(bullet_point_matches): - bullet_actual_start_index = bullet_match.start() - if i == 0 and bullet_actual_start_index > current_position: - pre_bullet_segment = text[ - current_position:bullet_actual_start_index - ].strip() - if pre_bullet_segment: - processed_sentences.extend( - s for s in _split_text_by_punctuation(pre_bullet_segment) if s - ) - - next_bullet_start_index = ( - bullet_point_matches[i + 1].start() - if i + 1 < len(bullet_point_matches) - else len(text) - ) - bullet_item_segment = text[ - bullet_actual_start_index:next_bullet_start_index - ].strip() - if bullet_item_segment: - processed_sentences.append(bullet_item_segment) - current_position = next_bullet_start_index - - if current_position < len(text): - post_bullet_segment = text[current_position:].strip() - if post_bullet_segment: - processed_sentences.extend( - s for s in _split_text_by_punctuation(post_bullet_segment) if s - ) - return [s for s in processed_sentences if s] - else: - logger.debug( - "No bullet points detected; using punctuation-based sentence splitting." - ) - return _split_text_by_punctuation(text) - - -def _preprocess_and_segment_text(full_text: str) -> List[Tuple[Optional[str], str]]: - """ - Internal helper to segment text by non-verbal cues (e.g., (laughs)) and then - further split those segments into sentences. - Assigns a placeholder "tag" (here, None or empty string) as this system is single-speaker. - The tuple structure (tag, sentence) is maintained for compatibility with chunking logic - that might expect it, even if the tag itself isn't used for speaker differentiation. 
- - Args: - full_text: The complete input text. - - Returns: - A list of tuples, where each tuple is (placeholder_tag, sentence_text). - """ - if not full_text or full_text.isspace(): - return [] - - placeholder_tag: Optional[str] = None - segmented_with_tags: List[Tuple[Optional[str], str]] = [] - parts_and_cues = NON_VERBAL_CUE_PATTERN.split(full_text) - - for part in parts_and_cues: - if not part or part.isspace(): - continue - if NON_VERBAL_CUE_PATTERN.fullmatch(part): - segmented_with_tags.append((placeholder_tag, part.strip())) - else: - sentences_from_part = split_into_sentences(part.strip()) - for sentence in sentences_from_part: - if sentence: - segmented_with_tags.append((placeholder_tag, sentence)) - - if not segmented_with_tags and full_text.strip(): - segmented_with_tags.append((placeholder_tag, full_text.strip())) - - logger.debug( - f"Preprocessed text into {len(segmented_with_tags)} segments/sentences." - ) - return segmented_with_tags - - -def chunk_text_by_sentences( - full_text: str, - chunk_size: int, -) -> List[str]: - """ - Chunks text into manageable pieces for TTS processing, respecting sentence boundaries - and a maximum chunk character length. Designed for single-speaker text, but maintains - a structure that can handle segments (like non-verbal cues) separately. - - Args: - full_text: The complete text to be chunked. - chunk_size: The desired maximum character length for each chunk. - Sentences longer than this will form their own chunk. - - Returns: - A list of text chunks, ready for TTS. 
- """ - if not full_text or full_text.isspace(): - return [] - if chunk_size <= 0: - chunk_size = float("inf") - - processed_segments = _preprocess_and_segment_text(full_text) - if not processed_segments: - return [] - - text_chunks: List[str] = [] - current_chunk_sentences: List[str] = [] - current_chunk_length = 0 - - for ( - _, - segment_text, - ) in processed_segments: - segment_len = len(segment_text) - - if not current_chunk_sentences: - current_chunk_sentences.append(segment_text) - current_chunk_length = segment_len - elif current_chunk_length + 1 + segment_len <= chunk_size: - current_chunk_sentences.append(segment_text) - current_chunk_length += 1 + segment_len - else: - if current_chunk_sentences: - text_chunks.append(" ".join(current_chunk_sentences)) - current_chunk_sentences = [segment_text] - current_chunk_length = segment_len - - if current_chunk_length > chunk_size and len(current_chunk_sentences) == 1: - logger.info( - f"A single segment (length {current_chunk_length}) exceeds chunk_size {chunk_size}. " - f"It will form its own chunk." - ) - text_chunks.append(" ".join(current_chunk_sentences)) - current_chunk_sentences = [] - current_chunk_length = 0 - - if current_chunk_sentences: - text_chunks.append(" ".join(current_chunk_sentences)) - - text_chunks = [chunk for chunk in text_chunks if chunk.strip()] - - if not text_chunks and full_text.strip(): - logger.warning( - "Text chunking resulted in zero chunks despite non-empty input. Returning full text as one chunk." - ) - return [full_text.strip()] - - logger.info(f"Text chunking complete. Generated {len(text_chunks)} chunk(s).") - return text_chunks - - -# --- Performance Monitoring Utility --- -class PerformanceMonitor: - """ - A simple helper class for recording and reporting elapsed time for different - stages of an operation. Useful for debugging performance bottlenecks. 
- """ - - def __init__( - self, enabled: bool = True, logger_instance: Optional[logging.Logger] = None - ): - self.enabled: bool = enabled - self.logger = ( - logger_instance - if logger_instance is not None - else logging.getLogger(__name__) - ) - self.start_time: float = 0.0 - self.events: List[Tuple[str, float]] = [] - if self.enabled: - self.start_time = time.monotonic() - self.events.append(("Monitoring Started", self.start_time)) - - def record(self, event_name: str): - if not self.enabled: - return - self.events.append((event_name, time.monotonic())) - - def report(self, log_level: int = logging.DEBUG) -> str: - if not self.enabled or not self.events: - return "Performance monitoring was disabled or no events recorded." - - report_lines = ["Performance Report:"] - last_event_time = self.events[0][1] - - for i in range(1, len(self.events)): - event_name, timestamp = self.events[i] - prev_event_name, _ = self.events[i - 1] - duration_since_last = timestamp - last_event_time - duration_since_start = timestamp - self.start_time - report_lines.append( - f" - Event: '{event_name}' (after '{prev_event_name}') " - f"took {duration_since_last:.4f}s. Total elapsed: {duration_since_start:.4f}s" - ) - last_event_time = timestamp - - total_duration = self.events[-1][1] - self.start_time - report_lines.append(f"Total Monitored Duration: {total_duration:.4f}s") - full_report_str = "\n".join(report_lines) - - if self.logger: - self.logger.log(log_level, full_report_str) - return full_report_str +# utils.py +# Utility functions for the TTS server application. +# This module includes functions for audio processing, text manipulation, +# file system operations, and performance monitoring. 
+ +import os +import logging +import re +import time +import io +import uuid +from pathlib import Path +from typing import Optional, Tuple, Dict, Any, Set, List +from pydub import AudioSegment + +import numpy as np +import soundfile as sf +import torchaudio # For saving PyTorch tensors and potentially speed adjustment. +import torch + +# Configuration manager to get paths dynamically. +# Assumes config.py and its config_manager are in the same directory or accessible via PYTHONPATH. +from config import config_manager + +# Optional import for librosa (for audio resampling, e.g., Opus encoding and time stretching) +try: + import librosa + + LIBROSA_AVAILABLE = True + logger = logging.getLogger( + __name__ + ) # Initialize logger here if librosa is available + logger.info( + "Librosa library found and will be used for audio resampling and time stretching." + ) +except ImportError: + LIBROSA_AVAILABLE = False + logger = logging.getLogger(__name__) + logger.warning( + "Librosa library not found. Advanced audio resampling features (e.g., for Opus encoding) " + "and pitch-preserving speed adjustment will be limited. Speed adjustment will fall back to basic method if enabled." + ) + +# Optional import for Parselmouth (for unvoiced segment detection) +try: + import parselmouth + + PARSELMOUTH_AVAILABLE = True + logger.info( + "Parselmouth library found and will be used for unvoiced segment removal if enabled." + ) +except ImportError: + PARSELMOUTH_AVAILABLE = False + logger.warning( + "Parselmouth library not found. Unvoiced segment removal feature will be disabled." + ) + + +# --- Filename Sanitization --- +def sanitize_filename(filename: str) -> str: + """ + Removes potentially unsafe characters and path components from a filename + to make it safe for use in file paths. Replaces unsafe sequences with underscores. + + Args: + filename: The original filename string. + + Returns: + A sanitized filename string, ensuring it's not empty and reasonably short. 
+ """ + if not filename: + # Generate a unique name if the input is empty. + return f"unnamed_file_{uuid.uuid4().hex[:8]}" + + # Remove directory separators and leading/trailing whitespace. + base_filename = Path(filename).name.strip() + if not base_filename: + return f"empty_basename_{uuid.uuid4().hex[:8]}" + + # Define a set of allowed characters (alphanumeric, underscore, hyphen, dot, space). + # Spaces will be replaced by underscores later. + safe_chars = set( + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._- " + ) + sanitized_list = [] + last_char_was_underscore = False + + for char in base_filename: + if char in safe_chars: + # Replace spaces with underscores. + sanitized_list.append("_" if char == " " else char) + last_char_was_underscore = char == " " + elif not last_char_was_underscore: + # Replace any disallowed character sequence with a single underscore. + sanitized_list.append("_") + last_char_was_underscore = True + + sanitized = "".join(sanitized_list).strip("_") + + # Prevent names starting with multiple dots or consisting only of dots/underscores. + if not sanitized or sanitized.lstrip("._") == "": + return f"sanitized_file_{uuid.uuid4().hex[:8]}" + + # Limit filename length (e.g., 100 characters), preserving the extension. + max_len = 100 + if len(sanitized) > max_len: + name_part, ext_part = os.path.splitext(sanitized) + # Ensure extension is not overly long itself; common extensions are short. + ext_part = ext_part[:10] # Limit extension length just in case. + name_part = name_part[ + : max_len - len(ext_part) - 1 + ] # -1 for the dot if ext exists + sanitized = name_part + ext_part + logger.warning( + f"Original filename '{base_filename}' was truncated to '{sanitized}' due to length limits." + ) + + if not sanitized: # Should not happen with previous checks, but as a failsafe. 
+        return f"final_fallback_name_{uuid.uuid4().hex[:8]}"
+
+    return sanitized
+
+
+# --- Constants for Text Processing ---
+# Set of common abbreviations to help with sentence splitting.
+ABBREVIATIONS: Set[str] = {
+    "mr.",
+    "mrs.",
+    "ms.",
+    "dr.",
+    "prof.",
+    "rev.",
+    "hon.",
+    "st.",
+    "etc.",
+    "e.g.",
+    "i.e.",
+    "vs.",
+    "approx.",
+    "apt.",
+    "dept.",
+    "fig.",
+    "gen.",
+    "gov.",
+    "inc.",
+    "jr.",
+    "sr.",
+    "ltd.",
+    "no.",
+    "p.",
+    "pp.",
+    "vol.",
+    "op.",
+    "cit.",
+    "ca.",
+    "cf.",
+    "ed.",
+    "esp.",
+    "et.",
+    "al.",
+    "ibid.",
+    "id.",
+    "inf.",
+    "sup.",
+    "viz.",
+    "sc.",
+    "fl.",
+    "d.",
+    "b.",
+    "r.",
+    "c.",
+    "v.",
+    "u.s.",
+    "u.k.",
+    "a.m.",
+    "p.m.",
+    "a.d.",
+    "b.c.",
+}
+
+# Common titles that might appear without a period if cleaned by other means first.
+TITLES_NO_PERIOD: Set[str] = {
+    "mr",
+    "mrs",
+    "ms",
+    "dr",
+    "prof",
+    "rev",
+    "hon",
+    "st",
+    "sgt",
+    "capt",
+    "lt",
+    "col",
+    "gen",
+}
+
+# Regex patterns (pre-compiled for efficiency in text processing).
+# NOTE(review): the pattern bodies below and the encode_audio signature were
+# garbled in this patch (all text between a '<' and the next '>' was stripped
+# during extraction). They have been reconstructed from their usage in the
+# sentence-splitting helpers and from the encode_audio docstring — verify
+# against the original utils.py before applying this patch.
+NUMBER_DOT_NUMBER_PATTERN = re.compile(
+    r"(?<!\d)\d+\.\d+(?!\d)"
+)
+# Version-like strings, e.g. "v1.2.3" or "2.0.1".
+VERSION_PATTERN = re.compile(r"\bv?\d+(?:\.\d+)+\b")
+# Sentence-ending punctuation (group 1) with optional trailing quotes/brackets
+# (group 2); callers rely on match.start(1) and match.group(2).
+POTENTIAL_END_PATTERN = re.compile(r"([.!?])([\"'\)\]]*)")
+# Bullet list items at the start of a line (-, *, or •).
+BULLET_POINT_PATTERN = re.compile(r"^[ \t]*[-*\u2022]\s+", re.MULTILINE)
+# Parenthesized non-verbal cues such as "(laughs)"; capturing group so that
+# re.split() keeps the cues in the output, and fullmatch() can identify them.
+NON_VERBAL_CUE_PATTERN = re.compile(r"(\([^)]*\))")
+
+
+# --- Audio Encoding Utility ---
+def encode_audio(
+    audio_array: np.ndarray,
+    sample_rate: int,
+    output_format: str = "opus",
+    target_sample_rate: Optional[int] = None,
+) -> Optional[bytes]:
+    """
+    Encodes a NumPy audio array into the specified format (Opus or WAV) in memory.
+    Can resample the audio to a target sample rate before encoding if specified.
+
+    Args:
+        audio_array: NumPy array containing audio data (expected as float32, range [-1, 1]).
+        sample_rate: Sample rate of the input audio data.
+        output_format: Desired output format ('opus', 'wav' or 'mp3').
+        target_sample_rate: Optional target sample rate to resample to before encoding.
+
+    Returns:
+        Bytes object containing the encoded audio, or None if encoding fails.
+    """
+    if audio_array is None or audio_array.size == 0:
+        logger.warning("encode_audio received empty or None audio array.")
+        return None
+
+    # Ensure audio is float32 for consistent processing. 
+ if audio_array.dtype != np.float32: + if np.issubdtype(audio_array.dtype, np.integer): + max_val = np.iinfo(audio_array.dtype).max + audio_array = audio_array.astype(np.float32) / max_val + else: # Fallback for other types, assuming they might be float64 or similar + audio_array = audio_array.astype(np.float32) + logger.debug(f"Converted audio array to float32 for encoding.") + + # Ensure audio is mono if it's (samples, 1) + if audio_array.ndim == 2 and audio_array.shape[1] == 1: + audio_array = audio_array.squeeze(axis=1) + logger.debug( + "Squeezed audio array from (samples, 1) to (samples,) for encoding." + ) + elif ( + audio_array.ndim > 1 + ): # Multi-channel not directly supported by simple encoding path, attempt to take first channel + logger.warning( + f"Multi-channel audio (shape: {audio_array.shape}) provided to encode_audio. Using only the first channel." + ) + audio_array = audio_array[:, 0] + + # Resample if target_sample_rate is provided and different from current sample_rate + if ( + target_sample_rate is not None + and target_sample_rate != sample_rate + and LIBROSA_AVAILABLE + ): + try: + logger.info( + f"Resampling audio from {sample_rate}Hz to {target_sample_rate}Hz using Librosa." + ) + audio_array = librosa.resample( + y=audio_array, orig_sr=sample_rate, target_sr=target_sample_rate + ) + sample_rate = ( + target_sample_rate # Update sample_rate for subsequent encoding + ) + except Exception as e_resample: + logger.error( + f"Error resampling audio to {target_sample_rate}Hz: {e_resample}. Proceeding with original sample rate {sample_rate}.", + exc_info=True, + ) + elif target_sample_rate is not None and target_sample_rate != sample_rate: + logger.warning( + f"Librosa not available. Cannot resample audio from {sample_rate}Hz to {target_sample_rate}Hz. " + f"Proceeding with original sample rate for encoding." 
+ ) + + start_time = time.time() + output_buffer = io.BytesIO() + + try: + audio_to_write = audio_array + rate_to_write = sample_rate + + if output_format == "opus": + OPUS_SUPPORTED_RATES = {8000, 12000, 16000, 24000, 48000} + TARGET_OPUS_RATE = 48000 # Preferred Opus rate. + + if rate_to_write not in OPUS_SUPPORTED_RATES: + if LIBROSA_AVAILABLE: + logger.warning( + f"Current sample rate {rate_to_write}Hz not directly supported by Opus. " + f"Resampling to {TARGET_OPUS_RATE}Hz using Librosa for Opus encoding." + ) + audio_to_write = librosa.resample( + y=audio_array, orig_sr=rate_to_write, target_sr=TARGET_OPUS_RATE + ) + rate_to_write = TARGET_OPUS_RATE + else: + logger.error( + f"Librosa not available. Cannot resample audio from {rate_to_write}Hz for Opus encoding. " + f"Opus encoding may fail or produce poor quality." + ) + # Proceed with current rate, soundfile might handle it or fail. + sf.write( + output_buffer, + audio_to_write, + rate_to_write, + format="ogg", + subtype="opus", + ) + + elif output_format == "wav": + # WAV typically uses int16 for broader compatibility. + # Clip audio to [-1.0, 1.0] before converting to int16 to prevent overflow. 
+ audio_clipped = np.clip(audio_array, -1.0, 1.0) + audio_int16 = (audio_clipped * 32767).astype(np.int16) + audio_to_write = audio_int16 # Use the int16 version for WAV + sf.write( + output_buffer, + audio_to_write, + rate_to_write, + format="wav", + subtype="pcm_16", + ) + + elif output_format == "mp3": + audio_clipped = np.clip(audio_array, -1.0, 1.0) + audio_int16 = (audio_clipped * 32767).astype(np.int16) + audio_segment = AudioSegment( + audio_int16.tobytes(), + frame_rate=sample_rate, + sample_width=2, + channels=1, + ) + audio_segment.export(output_buffer, format="mp3") + + else: + logger.error( + f"Unsupported output format requested for encoding: {output_format}" + ) + return None + + encoded_bytes = output_buffer.getvalue() + end_time = time.time() + logger.info( + f"Encoded {len(encoded_bytes)} bytes to '{output_format}' at {rate_to_write}Hz in {end_time - start_time:.3f} seconds." + ) + return encoded_bytes + + except ImportError as ie_sf: # Specifically for soundfile import issues + logger.critical( + f"The 'soundfile' library or its dependency (libsndfile) is not installed or found. " + f"Audio encoding/saving is not possible. Please install it. Error: {ie_sf}" + ) + return None + except Exception as e: + logger.error(f"Error encoding audio to '{output_format}': {e}", exc_info=True) + return None + + +def save_audio_to_file( + audio_array: np.ndarray, sample_rate: int, file_path_str: str +) -> bool: + """ + Saves a NumPy audio array to a WAV file. + + Args: + audio_array: NumPy array containing audio data (float32, range [-1, 1]). + sample_rate: Sample rate of the audio data. + file_path_str: String path to save the WAV file. + + Returns: + True if saving was successful, False otherwise. 
+ """ + if audio_array is None or audio_array.size == 0: + logger.warning("save_audio_to_file received empty or None audio array.") + return False + + file_path = Path(file_path_str) + if file_path.suffix.lower() != ".wav": + logger.warning( + f"File path '{file_path_str}' does not end with .wav. Appending .wav extension." + ) + file_path = file_path.with_suffix(".wav") + + start_time = time.time() + try: + # Ensure output directory exists. + file_path.parent.mkdir(parents=True, exist_ok=True) + + # Prepare audio for WAV (int16, clipped). + if ( + audio_array.dtype != np.float32 + ): # Ensure float32 before potential scaling to int16 + if np.issubdtype(audio_array.dtype, np.integer): + max_val = np.iinfo(audio_array.dtype).max + audio_array = audio_array.astype(np.float32) / max_val + else: + audio_array = audio_array.astype(np.float32) + + audio_clipped = np.clip(audio_array, -1.0, 1.0) + audio_int16 = (audio_clipped * 32767).astype(np.int16) + + sf.write( + str(file_path), audio_int16, sample_rate, format="wav", subtype="pcm_16" + ) + end_time = time.time() + logger.info( + f"Saved WAV file to {file_path} in {end_time - start_time:.3f} seconds." + ) + return True + except ImportError: + logger.critical("SoundFile library not found. Cannot save audio.") + return False + except Exception as e: + logger.error(f"Error saving WAV file to {file_path}: {e}", exc_info=True) + return False + + +def save_audio_tensor_to_file( + audio_tensor: torch.Tensor, + sample_rate: int, + file_path_str: str, + output_format: str = "wav", +) -> bool: + """ + Saves a PyTorch audio tensor to a file using torchaudio. + + Args: + audio_tensor: PyTorch tensor containing audio data. + sample_rate: Sample rate of the audio data. + file_path_str: String path to save the audio file. + output_format: Desired output format (passed to torchaudio.save). + + Returns: + True if saving was successful, False otherwise. 
+ """ + if audio_tensor is None or audio_tensor.numel() == 0: + logger.warning("save_audio_tensor_to_file received empty or None audio tensor.") + return False + + file_path = Path(file_path_str) + start_time = time.time() + try: + file_path.parent.mkdir(parents=True, exist_ok=True) + # torchaudio.save expects tensor on CPU. + audio_tensor_cpu = audio_tensor.cpu() + # Ensure tensor is 2D (channels, samples) for torchaudio.save. + if audio_tensor_cpu.ndim == 1: + audio_tensor_cpu = audio_tensor_cpu.unsqueeze(0) + + torchaudio.save( + str(file_path), audio_tensor_cpu, sample_rate, format=output_format + ) + end_time = time.time() + logger.info( + f"Saved audio tensor to {file_path} (format: {output_format}) in {end_time - start_time:.3f} seconds." + ) + return True + except Exception as e: + logger.error(f"Error saving audio tensor to {file_path}: {e}", exc_info=True) + return False + + +# --- Audio Manipulation Utilities --- +def apply_speed_factor( + audio_tensor: torch.Tensor, sample_rate: int, speed_factor: float +) -> Tuple[torch.Tensor, int]: + """ + Applies a speed factor to an audio tensor. + Uses librosa.effects.time_stretch if available for pitch preservation. + Falls back to simple resampling via torchaudio.transforms.Resample if librosa is not available, + which will alter pitch. + + Args: + audio_tensor: Input audio waveform (PyTorch tensor, expected mono). + sample_rate: Sample rate of the input audio. + speed_factor: Desired speed factor (e.g., 1.0 is normal, 1.5 is faster, 0.5 is slower). + + Returns: + A tuple of the speed-adjusted audio tensor and its sample rate (which remains unchanged). + Returns the original tensor and sample rate if speed_factor is 1.0 or if adjustment fails. + """ + if speed_factor == 1.0: + return audio_tensor, sample_rate + if speed_factor <= 0: + logger.warning( + f"Invalid speed_factor {speed_factor}. Must be positive. Returning original audio." 
+ ) + return audio_tensor, sample_rate + + audio_tensor_cpu = audio_tensor.cpu() + # Ensure tensor is 1D mono for librosa and consistent handling + if audio_tensor_cpu.ndim == 2: + if audio_tensor_cpu.shape[0] == 1: + audio_tensor_cpu = audio_tensor_cpu.squeeze(0) + elif audio_tensor_cpu.shape[1] == 1: + audio_tensor_cpu = audio_tensor_cpu.squeeze(1) + else: # True stereo or multi-channel + logger.warning( + f"apply_speed_factor received multi-channel audio (shape {audio_tensor_cpu.shape}). Using first channel only." + ) + audio_tensor_cpu = audio_tensor_cpu[0, :] + + if audio_tensor_cpu.ndim != 1: + logger.error( + f"apply_speed_factor: audio_tensor_cpu is not 1D after processing (shape {audio_tensor_cpu.shape}). Returning original audio." + ) + return audio_tensor, sample_rate + + if LIBROSA_AVAILABLE: + try: + audio_np = audio_tensor_cpu.numpy() + # librosa.effects.time_stretch changes duration, not sample rate directly. + # The 'rate' parameter in time_stretch is equivalent to speed_factor. + stretched_audio_np = librosa.effects.time_stretch( + y=audio_np, rate=speed_factor + ) + speed_adjusted_tensor = torch.from_numpy(stretched_audio_np) + logger.info( + f"Applied speed factor {speed_factor} using librosa.effects.time_stretch. Original SR: {sample_rate}" + ) + return speed_adjusted_tensor, sample_rate # Sample rate is preserved + except Exception as e_librosa: + logger.error( + f"Failed to apply speed factor {speed_factor} using librosa: {e_librosa}. 
" + f"Falling back to basic resampling (pitch will change).", + exc_info=True, + ) + # Fallback to simple resampling (changes pitch) + try: + new_sample_rate_for_speedup = int(sample_rate / speed_factor) + resampler = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=new_sample_rate_for_speedup + ) + # Resample to new_sample_rate_for_speedup to change duration, then resample back to original SR + # This is effectively what sox 'speed' does, but 'tempo' is better (which librosa does) + # For simplicity in fallback, just resample and note pitch change + # To actually change speed without changing sample rate and preserving pitch using *only* torchaudio is more complex + # and typically involves phase vocoder or similar, which is beyond a simple fallback. + # The torchaudio.functional.pitch_shift and then torchaudio.functional.speed is one way, + # but librosa is simpler. + # Given the instruction "Fallback to original audio" if librosa not available or fails, we'll stick to that. + # Original plan: "If Librosa is not available, log a warning and return the original audio" + logger.warning( + f"Librosa failed for speed factor. Returning original audio as primary fallback." + ) + return audio_tensor, sample_rate + + except Exception as e_resample_fallback: + logger.error( + f"Fallback resampling for speed factor {speed_factor} also failed: {e_resample_fallback}. Returning original audio.", + exc_info=True, + ) + return audio_tensor, sample_rate + + else: # Librosa not available + logger.warning( + f"Librosa not available for pitch-preserving speed adjustment (factor: {speed_factor}). " + f"Returning original audio. Install librosa for this feature." 
+ ) + return audio_tensor, sample_rate + + +def trim_lead_trail_silence( + audio_array: np.ndarray, + sample_rate: int, + silence_threshold_db: float = -40.0, + min_silence_duration_ms: int = 100, + padding_ms: int = 50, +) -> np.ndarray: + """ + Trims silence from the beginning and end of a NumPy audio array using a dB threshold. + + Args: + audio_array: NumPy array (float32) of the audio. + sample_rate: Sample rate of the audio. + silence_threshold_db: Silence threshold in dBFS. Segments below this are considered silent. + min_silence_duration_ms: Minimum duration of silence to be trimmed (ms). + padding_ms: Padding to leave at the start/end after trimming (ms). + + Returns: + Trimmed NumPy audio array. Returns original if no significant silence is found or on error. + """ + if audio_array is None or audio_array.size == 0: + return audio_array + + try: + if not LIBROSA_AVAILABLE: + logger.warning("Librosa not available, skipping silence trimming.") + return audio_array + + top_db_threshold = abs(silence_threshold_db) + + frame_length = 2048 + hop_length = 512 + + trimmed_audio, index = librosa.effects.trim( + y=audio_array, + top_db=top_db_threshold, + frame_length=frame_length, + hop_length=hop_length, + ) + + start_sample, end_sample = index[0], index[1] + + padding_samples = int((padding_ms / 1000.0) * sample_rate) + final_start = max(0, start_sample - padding_samples) + final_end = min(len(audio_array), end_sample + padding_samples) + + if final_end > final_start: # Ensure the slice is valid + # Check if significant trimming occurred + original_length = len(audio_array) + trimmed_length_with_padding = final_end - final_start + # Heuristic: if length changed by more than just padding, or if original silence was more than min_duration + # For simplicity, if librosa.effects.trim found *any* indices different from [0, original_length], + # it means some trimming potential was identified. 
+ if index[0] > 0 or index[1] < original_length: + logger.debug( + f"Silence trimmed: original samples {original_length}, new effective samples {trimmed_length_with_padding} (indices before padding: {index})" + ) + return audio_array[final_start:final_end] + + logger.debug( + "No significant leading/trailing silence found to trim, or result would be empty." + ) + return audio_array + + except Exception as e: + logger.error(f"Error during silence trimming: {e}", exc_info=True) + return audio_array + + +def fix_internal_silence( + audio_array: np.ndarray, + sample_rate: int, + silence_threshold_db: float = -40.0, + min_silence_to_fix_ms: int = 700, + max_allowed_silence_ms: int = 300, +) -> np.ndarray: + """ + Reduces long internal silences in a NumPy audio array to a specified maximum duration. + Uses Librosa to split by silence. + + Args: + audio_array: NumPy array (float32) of the audio. + sample_rate: Sample rate of the audio. + silence_threshold_db: Silence threshold in dBFS. + min_silence_to_fix_ms: Minimum duration of an internal silence to be shortened (ms). + max_allowed_silence_ms: Target maximum duration for long silences (ms). + + Returns: + NumPy audio array with long internal silences shortened. Original if no fix needed or on error. 
+ """ + if audio_array is None or audio_array.size == 0: + return audio_array + + try: + if not LIBROSA_AVAILABLE: + logger.warning("Librosa not available, skipping internal silence fixing.") + return audio_array + + top_db_threshold = abs(silence_threshold_db) + min_silence_len_samples = int((min_silence_to_fix_ms / 1000.0) * sample_rate) + max_silence_samples_to_keep = int( + (max_allowed_silence_ms / 1000.0) * sample_rate + ) + + non_silent_intervals = librosa.effects.split( + y=audio_array, + top_db=top_db_threshold, + frame_length=2048, # Can be tuned + hop_length=512, # Can be tuned + ) + + if len(non_silent_intervals) <= 1: + logger.debug("No significant internal silences found to fix.") + return audio_array + + fixed_audio_parts = [] + last_nonsilent_end = 0 + + for i, (start_sample, end_sample) in enumerate(non_silent_intervals): + silence_duration_samples = start_sample - last_nonsilent_end + if silence_duration_samples > 0: + if silence_duration_samples >= min_silence_len_samples: + silence_to_add = audio_array[ + last_nonsilent_end : last_nonsilent_end + + max_silence_samples_to_keep + ] + fixed_audio_parts.append(silence_to_add) + logger.debug( + f"Shortened internal silence from {silence_duration_samples} to {max_silence_samples_to_keep} samples." + ) + else: + fixed_audio_parts.append( + audio_array[last_nonsilent_end:start_sample] + ) + fixed_audio_parts.append(audio_array[start_sample:end_sample]) + last_nonsilent_end = end_sample + + # Handle potential silence after the very last non-silent segment + # This part is tricky as librosa.effects.split only gives non-silent parts. + # The trim_lead_trail_silence should handle overall trailing silence. + # This function focuses on *between* non-silent segments. 
+ if last_nonsilent_end < len(audio_array): + trailing_segment = audio_array[last_nonsilent_end:] + # Check if this trailing segment is mostly silence and long enough to shorten + # For simplicity, we'll assume trim_lead_trail_silence handles the very end. + # Or, we could append it if it's short, or shorten it if it's long silence. + # To avoid over-complication here, let's just append what's left. + # The primary goal is internal silences. + # However, if the last "non_silent_interval" was short and followed by a long silence, + # that silence needs to be handled here too. + silence_duration_samples = len(audio_array) - last_nonsilent_end + if silence_duration_samples > 0: + if silence_duration_samples >= min_silence_len_samples: + fixed_audio_parts.append( + audio_array[ + last_nonsilent_end : last_nonsilent_end + + max_silence_samples_to_keep + ] + ) + logger.debug( + f"Shortened trailing silence from {silence_duration_samples} to {max_silence_samples_to_keep} samples." + ) + else: + fixed_audio_parts.append(trailing_segment) + + if not fixed_audio_parts: # Should not happen if non_silent_intervals > 1 + logger.warning( + "Internal silence fixing resulted in no audio parts; returning original." + ) + return audio_array + + return np.concatenate(fixed_audio_parts) + + except Exception as e: + logger.error(f"Error during internal silence fixing: {e}", exc_info=True) + return audio_array + + +def remove_long_unvoiced_segments( + audio_array: np.ndarray, + sample_rate: int, + min_unvoiced_duration_ms: int = 300, + pitch_floor: float = 75.0, + pitch_ceiling: float = 600.0, +) -> np.ndarray: + """ + Removes segments from a NumPy audio array that are unvoiced for longer than + the specified duration, using Parselmouth for pitch analysis. + + Args: + audio_array: NumPy array (float32) of the audio. + sample_rate: Sample rate of the audio. + min_unvoiced_duration_ms: Minimum duration (ms) of an unvoiced segment to be removed. 
+ pitch_floor: Minimum pitch (Hz) to consider for voicing. + pitch_ceiling: Maximum pitch (Hz) to consider for voicing. + + Returns: + NumPy audio array with long unvoiced segments removed. Original if Parselmouth not available or on error. + """ + if not PARSELMOUTH_AVAILABLE: + logger.warning("Parselmouth not available, skipping unvoiced segment removal.") + return audio_array + if audio_array is None or audio_array.size == 0: + return audio_array + + try: + sound = parselmouth.Sound( + audio_array.astype(np.float64), sampling_frequency=sample_rate + ) + pitch = sound.to_pitch(pitch_floor=pitch_floor, pitch_ceiling=pitch_ceiling) + voiced_unvoiced = pitch.get_VoicedVoicelessUnvoiced() + + segments_to_keep = [] + current_segment_start_sample = 0 + min_unvoiced_samples = int((min_unvoiced_duration_ms / 1000.0) * sample_rate) + + for i in range(len(voiced_unvoiced.time_intervals)): + interval_start_time, interval_end_time, is_voiced_str = ( + voiced_unvoiced.time_intervals[i] + ) + is_voiced = is_voiced_str == "voiced" + + interval_start_sample = int(interval_start_time * sample_rate) + interval_end_sample = int(interval_end_time * sample_rate) + interval_duration_samples = interval_end_sample - interval_start_sample + + if is_voiced: + segments_to_keep.append( + audio_array[current_segment_start_sample:interval_end_sample] + ) + current_segment_start_sample = interval_end_sample + else: # Unvoiced segment + if interval_duration_samples < min_unvoiced_samples: + segments_to_keep.append( + audio_array[current_segment_start_sample:interval_end_sample] + ) + current_segment_start_sample = interval_end_sample + else: + logger.debug( + f"Removing long unvoiced segment from {interval_start_time:.2f}s to {interval_end_time:.2f}s." 
+ ) + # Append the audio *before* this long unvoiced segment (if any) + if interval_start_sample > current_segment_start_sample: + segments_to_keep.append( + audio_array[ + current_segment_start_sample:interval_start_sample + ] + ) + current_segment_start_sample = interval_end_sample + + if current_segment_start_sample < len(audio_array): + segments_to_keep.append(audio_array[current_segment_start_sample:]) + + if not segments_to_keep: + logger.warning( + "Unvoiced segment removal resulted in empty audio; returning original." + ) + return audio_array + + return np.concatenate(segments_to_keep) + + except Exception as e: + logger.error(f"Error during unvoiced segment removal: {e}", exc_info=True) + return audio_array + + +# --- Text Processing Utilities --- +def _is_valid_sentence_end(text: str, period_index: int) -> bool: + """ + Checks if a period at a given index in the text is likely a valid sentence terminator, + rather than part of an abbreviation, number, or version string. + """ + word_start_before_period = period_index - 1 + scan_limit = max(0, period_index - 10) + while ( + word_start_before_period >= scan_limit + and not text[word_start_before_period].isspace() + ): + word_start_before_period -= 1 + word_before_period = text[word_start_before_period + 1 : period_index + 1].lower() + if word_before_period in ABBREVIATIONS: + return False + + context_start = max(0, period_index - 10) + context_end = min(len(text), period_index + 10) + context_segment = text[context_start:context_end] + relative_period_index_in_context = period_index - context_start + + for pattern in [NUMBER_DOT_NUMBER_PATTERN, VERSION_PATTERN]: + for match in pattern.finditer(context_segment): + if match.start() <= relative_period_index_in_context < match.end(): + is_last_char_of_numeric_match = ( + relative_period_index_in_context == match.end() - 1 + ) + is_followed_by_space_or_eos = ( + period_index + 1 == len(text) or text[period_index + 1].isspace() + ) + if not 
(is_last_char_of_numeric_match and is_followed_by_space_or_eos): + return False + return True + + +def _split_text_by_punctuation(text: str) -> List[str]: + """ + Splits text into sentences based on common punctuation marks (.!?), + while trying to avoid splitting on periods used in abbreviations or numbers. + """ + sentences: List[str] = [] + last_split_index = 0 + text_length = len(text) + + for match in POTENTIAL_END_PATTERN.finditer(text): + punctuation_char_index = match.start(1) + punctuation_char = text[punctuation_char_index] + slice_end_after_punctuation = match.start(1) + 1 + len(match.group(2) or "") + + if punctuation_char in ["!", "?"]: + current_sentence_text = text[ + last_split_index:slice_end_after_punctuation + ].strip() + if current_sentence_text: + sentences.append(current_sentence_text) + last_split_index = match.end() + continue + + if punctuation_char == ".": + if ( + punctuation_char_index > 0 and text[punctuation_char_index - 1] == "." + ) or ( + punctuation_char_index < text_length - 1 + and text[punctuation_char_index + 1] == "." + ): + continue + + if _is_valid_sentence_end(text, punctuation_char_index): + current_sentence_text = text[ + last_split_index:slice_end_after_punctuation + ].strip() + if current_sentence_text: + sentences.append(current_sentence_text) + last_split_index = match.end() + + remaining_text_segment = text[last_split_index:].strip() + if remaining_text_segment: + sentences.append(remaining_text_segment) + + sentences = [s for s in sentences if s] + if not sentences and text.strip(): + return [text.strip()] + return sentences + + +def split_into_sentences(text: str) -> List[str]: + """ + Splits a given text into sentences. Handles normalization of line breaks + and considers bullet points as potential sentence separators. + This is the primary entry point for sentence splitting. 
+ """ + if not text or text.isspace(): + return [] + + text = text.replace("\r\n", "\n").replace("\r", "\n") + bullet_point_matches = list(BULLET_POINT_PATTERN.finditer(text)) + + if bullet_point_matches: + logger.debug("Bullet points detected in text; splitting by bullet items.") + processed_sentences: List[str] = [] + current_position = 0 + for i, bullet_match in enumerate(bullet_point_matches): + bullet_actual_start_index = bullet_match.start() + if i == 0 and bullet_actual_start_index > current_position: + pre_bullet_segment = text[ + current_position:bullet_actual_start_index + ].strip() + if pre_bullet_segment: + processed_sentences.extend( + s for s in _split_text_by_punctuation(pre_bullet_segment) if s + ) + + next_bullet_start_index = ( + bullet_point_matches[i + 1].start() + if i + 1 < len(bullet_point_matches) + else len(text) + ) + bullet_item_segment = text[ + bullet_actual_start_index:next_bullet_start_index + ].strip() + if bullet_item_segment: + processed_sentences.append(bullet_item_segment) + current_position = next_bullet_start_index + + if current_position < len(text): + post_bullet_segment = text[current_position:].strip() + if post_bullet_segment: + processed_sentences.extend( + s for s in _split_text_by_punctuation(post_bullet_segment) if s + ) + return [s for s in processed_sentences if s] + else: + logger.debug( + "No bullet points detected; using punctuation-based sentence splitting." + ) + return _split_text_by_punctuation(text) + + +def _preprocess_and_segment_text(full_text: str) -> List[Tuple[Optional[str], str]]: + """ + Internal helper to segment text by non-verbal cues (e.g., (laughs)) and then + further split those segments into sentences. + Assigns a placeholder "tag" (here, None or empty string) as this system is single-speaker. + The tuple structure (tag, sentence) is maintained for compatibility with chunking logic + that might expect it, even if the tag itself isn't used for speaker differentiation. 
+ + Args: + full_text: The complete input text. + + Returns: + A list of tuples, where each tuple is (placeholder_tag, sentence_text). + """ + if not full_text or full_text.isspace(): + return [] + + placeholder_tag: Optional[str] = None + segmented_with_tags: List[Tuple[Optional[str], str]] = [] + parts_and_cues = NON_VERBAL_CUE_PATTERN.split(full_text) + + for part in parts_and_cues: + if not part or part.isspace(): + continue + if NON_VERBAL_CUE_PATTERN.fullmatch(part): + segmented_with_tags.append((placeholder_tag, part.strip())) + else: + sentences_from_part = split_into_sentences(part.strip()) + for sentence in sentences_from_part: + if sentence: + segmented_with_tags.append((placeholder_tag, sentence)) + + if not segmented_with_tags and full_text.strip(): + segmented_with_tags.append((placeholder_tag, full_text.strip())) + + logger.debug( + f"Preprocessed text into {len(segmented_with_tags)} segments/sentences." + ) + return segmented_with_tags + + +def chunk_text_by_sentences( + full_text: str, + chunk_size: int, +) -> List[str]: + """ + Chunks text into manageable pieces for TTS processing, respecting sentence boundaries + and a maximum chunk character length. Designed for single-speaker text, but maintains + a structure that can handle segments (like non-verbal cues) separately. + + Args: + full_text: The complete text to be chunked. + chunk_size: The desired maximum character length for each chunk. + Sentences longer than this will form their own chunk. + + Returns: + A list of text chunks, ready for TTS. 
+ """ + if not full_text or full_text.isspace(): + return [] + if chunk_size <= 0: + chunk_size = float("inf") + + processed_segments = _preprocess_and_segment_text(full_text) + if not processed_segments: + return [] + + text_chunks: List[str] = [] + current_chunk_sentences: List[str] = [] + current_chunk_length = 0 + + for ( + _, + segment_text, + ) in processed_segments: + segment_len = len(segment_text) + + if not current_chunk_sentences: + current_chunk_sentences.append(segment_text) + current_chunk_length = segment_len + elif current_chunk_length + 1 + segment_len <= chunk_size: + current_chunk_sentences.append(segment_text) + current_chunk_length += 1 + segment_len + else: + if current_chunk_sentences: + text_chunks.append(" ".join(current_chunk_sentences)) + current_chunk_sentences = [segment_text] + current_chunk_length = segment_len + + if current_chunk_length > chunk_size and len(current_chunk_sentences) == 1: + logger.info( + f"A single segment (length {current_chunk_length}) exceeds chunk_size {chunk_size}. " + f"It will form its own chunk." + ) + text_chunks.append(" ".join(current_chunk_sentences)) + current_chunk_sentences = [] + current_chunk_length = 0 + + if current_chunk_sentences: + text_chunks.append(" ".join(current_chunk_sentences)) + + text_chunks = [chunk for chunk in text_chunks if chunk.strip()] + + if not text_chunks and full_text.strip(): + logger.warning( + "Text chunking resulted in zero chunks despite non-empty input. Returning full text as one chunk." + ) + return [full_text.strip()] + + logger.info(f"Text chunking complete. Generated {len(text_chunks)} chunk(s).") + return text_chunks + + +# --- Performance Monitoring Utility --- +class PerformanceMonitor: + """ + A simple helper class for recording and reporting elapsed time for different + stages of an operation. Useful for debugging performance bottlenecks. 
+ """ + + def __init__( + self, enabled: bool = True, logger_instance: Optional[logging.Logger] = None + ): + self.enabled: bool = enabled + self.logger = ( + logger_instance + if logger_instance is not None + else logging.getLogger(__name__) + ) + self.start_time: float = 0.0 + self.events: List[Tuple[str, float]] = [] + if self.enabled: + self.start_time = time.monotonic() + self.events.append(("Monitoring Started", self.start_time)) + + def record(self, event_name: str): + if not self.enabled: + return + self.events.append((event_name, time.monotonic())) + + def report(self, log_level: int = logging.DEBUG) -> str: + if not self.enabled or not self.events: + return "Performance monitoring was disabled or no events recorded." + + report_lines = ["Performance Report:"] + last_event_time = self.events[0][1] + + for i in range(1, len(self.events)): + event_name, timestamp = self.events[i] + prev_event_name, _ = self.events[i - 1] + duration_since_last = timestamp - last_event_time + duration_since_start = timestamp - self.start_time + report_lines.append( + f" - Event: '{event_name}' (after '{prev_event_name}') " + f"took {duration_since_last:.4f}s. 
Total elapsed: {duration_since_start:.4f}s" + ) + last_event_time = timestamp + + total_duration = self.events[-1][1] - self.start_time + report_lines.append(f"Total Monitored Duration: {total_duration:.4f}s") + full_report_str = "\n".join(report_lines) + + if self.logger: + self.logger.log(log_level, full_report_str) + return full_report_str From e8d95f9c48b82c1bbeae86714d3ceabfa0339f12 Mon Sep 17 00:00:00 2001 From: Benjamin Kobjolke Date: Wed, 18 Mar 2026 07:56:38 +0100 Subject: [PATCH 4/4] FEATURE (setup): add GPU installation script and Windows bat files - add install_gpu.bat to create venv and install CUDA dependencies - add activate_environment.bat and start.bat for Windows usage - remove torch, torchaudio, onnxruntime-gpu from requirements-nvidia.txt as they are installed separately by install_gpu.bat with correct CUDA index --- activate_environment.bat | 1 + install_gpu.bat | 41 ++++++++++++++++++++++++++++++++++++++++ start.bat | 5 +++++ 3 files changed, 47 insertions(+) create mode 100644 activate_environment.bat create mode 100644 install_gpu.bat create mode 100644 start.bat diff --git a/activate_environment.bat b/activate_environment.bat new file mode 100644 index 0000000..6a8e4a5 --- /dev/null +++ b/activate_environment.bat @@ -0,0 +1 @@ +%~dp0venv\Scripts\activate.bat diff --git a/install_gpu.bat b/install_gpu.bat new file mode 100644 index 0000000..4ff7c75 --- /dev/null +++ b/install_gpu.bat @@ -0,0 +1,41 @@ +@echo off +title Kitten-TTS GPU Installation + +echo Creating virtual environment with Python 3.11... +uv venv --python 3.11 --seed venv +if errorlevel 1 ( + echo Failed to create virtual environment. + pause + exit /b 1 +) + +echo. +echo Step 1: Installing GPU-enabled ONNX Runtime... +venv\Scripts\python.exe -m pip install onnxruntime-gpu +if errorlevel 1 ( + echo Failed to install onnxruntime-gpu. + pause + exit /b 1 +) + +echo. +echo Step 2: Installing PyTorch with CUDA 12.1 support... 
+venv\Scripts\python.exe -m pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu121 +if errorlevel 1 ( + echo Failed to install PyTorch. + pause + exit /b 1 +) + +echo. +echo Step 3: Installing remaining dependencies... +venv\Scripts\python.exe -m pip install -r requirements-nvidia.txt +if errorlevel 1 ( + echo Failed to install requirements. + pause + exit /b 1 +) + +echo. +echo Installation complete! +pause diff --git a/start.bat b/start.bat new file mode 100644 index 0000000..4cfe5a0 --- /dev/null +++ b/start.bat @@ -0,0 +1,5 @@ +@echo off +title Kitten-TTS Server +call activate_environment.bat + +python server.py \ No newline at end of file