A LoRA / QLoRA finetuning pipeline for MLX TTS models on Apple Silicon. Supports Qwen3-TTS, CSM, and PersonaPlex, with two distinct training strategies.
Pipeline 1: Language Adaptation

Use this when: you want the model to speak a new language or accent (e.g. Hindi). Voice identity still comes from a reference audio clip at inference: you are teaching the language, not baking in a specific voice.
python scripts/train.py --config configs/qwen3_tts_hindi.yaml

How it works:
- Applies LoRA adapters to all attention + MLP layers inside the `talker` transformer
- Loss: `main_codec_loss + 0.3 × sub_talker_loss` (teacher-forced codec prediction)
- Does not use `ref_audio` during training, because speaker identity is irrelevant
- After training: load the adapter at inference with any ref audio to clone any voice in the new language
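The weighted loss above can be illustrated with a small, dependency-free sketch. This is not the repo's `qwen3_tts_loss`; the helper names and the toy logits are assumptions made for illustration, and only the `0.3` weight comes from the formula above.

```python
import math

SUB_TALKER_WEIGHT = 0.3  # weight taken from the loss formula above


def cross_entropy(logits, target):
    """Cross-entropy at one position: -log softmax(logits)[target]."""
    peak = max(logits)  # subtract the max for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_sum_exp - logits[target]


def mean_cross_entropy(all_logits, targets):
    """Average cross-entropy over a teacher-forced sequence."""
    losses = [cross_entropy(row, t) for row, t in zip(all_logits, targets)]
    return sum(losses) / len(losses)


def combined_loss(main_logits, main_targets, sub_logits, sub_targets):
    """main_codec_loss + 0.3 * sub_talker_loss (hypothetical helper)."""
    main_codec_loss = mean_cross_entropy(main_logits, main_targets)
    sub_talker_loss = mean_cross_entropy(sub_logits, sub_targets)
    return main_codec_loss + SUB_TALKER_WEIGHT * sub_talker_loss


# Toy example: 2 positions, a vocabulary of 3 codec IDs.
main_logits = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]]
sub_logits = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
loss = combined_loss(main_logits, [0, 1], sub_logits, [2, 2])
print(f"combined loss: {loss:.4f}")
```

Because the sub-talker term is scaled by 0.3, errors in the first-codebook prediction dominate the gradient, while the remaining codebooks act as a weaker auxiliary signal.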
Good for: Hindi, regional languages, new accents, domain-specific speech styles
Pipeline 2: Speaker Voice Cloning

Use this when: you want to permanently bake a specific person's voice into the model so it speaks as that person without needing a reference clip at inference.
# Step 1 — train
python scripts/train.py --config configs/qwen3_tts_speaker.yaml
# Step 2 — bake the speaker identity into codec_embedding[3000]
python scripts/bake_speaker_embedding.py \
--config configs/qwen3_tts_speaker.yaml \
--checkpoint checkpoints/qwen3-speaker/checkpoint-final \
--output checkpoints/qwen3-speaker/custom_voice_model

How it works (mirrors official `sft_12hz.py`):
- Every training sample must have a `ref_audio` field pointing to a clip of the target speaker
- During each forward pass, `model.speaker_encoder` extracts a speaker embedding from the ref audio mel spectrogram
- The speaker embedding is injected as a single positional token into the codec prefix (between the think/lang section and `[pad, bos]`), matching the official `sft_12hz.py` positional injection approach
- LoRA adapters learn to generate codec tokens conditioned on that speaker identity
- After training, `bake_speaker_embedding.py` averages the speaker embeddings across the training set and writes the mean into `talker.model.codec_embedding.weight[3000]` (the reserved custom-voice slot), then patches `config.json` with `tts_model_type: custom_voice`
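The final baking step described above amounts to averaging vectors and overwriting one row of an embedding table. A minimal sketch using plain Python lists follows; the function names, the toy 4-dimensional embeddings, and the in-memory `config` dict are assumptions for illustration, not the actual `bake_speaker_embedding.py` code.

```python
RESERVED_SLOT = 3000  # row index named in the text above


def mean_embedding(embeddings):
    """Element-wise mean of equally sized vectors."""
    dim = len(embeddings[0])
    count = len(embeddings)
    return [sum(vec[i] for vec in embeddings) / count for i in range(dim)]


def bake_into_table(codec_embedding, speaker_embeddings, slot=RESERVED_SLOT):
    """Overwrite one row of the embedding table with the mean speaker vector."""
    mean = mean_embedding(speaker_embeddings)
    if len(mean) != len(codec_embedding[slot]):
        raise ValueError("speaker embedding width must match the table width")
    codec_embedding[slot] = mean
    return codec_embedding


# Toy table: 3001 rows x 4 columns (real tables are far wider).
table = [[0.0] * 4 for _ in range(RESERVED_SLOT + 1)]
per_clip = [[1.0, 2.0, 3.0, 4.0], [3.0, 2.0, 1.0, 0.0]]
bake_into_table(table, per_clip)

# Stand-in for patching config.json after the row is written.
config = {"tts_model_type": "base"}  # hypothetical starting value
config["tts_model_type"] = "custom_voice"

print(table[RESERVED_SLOT])  # [2.0, 2.0, 2.0, 2.0]
```

Averaging over many clips smooths out per-clip variation (noise, prosody), which is presumably why the mean rather than a single clip's embedding is stored.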
Good for: cloning a single specific voice, podcast/audiobook voice replication, personal TTS
| | Language Adaptation | Speaker Voice Cloning |
|---|---|---|
| Config | `qwen3_tts_hindi.yaml` | `qwen3_tts_speaker.yaml` |
| `model_type` | `qwen3_tts` | `qwen3_tts_speaker` |
| `ref_audio` in data | Not required | Required for every sample |
| Epochs | 10–15 | 3 (more risks overfitting) |
| LoRA rank | 8 | 16 |
| Effective batch | 32 | 16 |
| Post-training step | None | `bake_speaker_embedding.py` |
| At inference | Needs ref audio to pick a voice | Works without ref audio |
You can combine both: train with the language config first, then run the speaker config on top using the language-adapted checkpoint as the starting point.
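Since the speaker pipeline requires `ref_audio` on every sample while the language pipeline does not, it can help to check a JSONL manifest before training. The sketch below is a hypothetical helper, not part of the repo; it assumes each line is a JSON object with `audio` and `text` fields, plus `ref_audio` for the speaker pipeline.

```python
import json


def check_manifest(lines, require_ref_audio):
    """Return (problems, ref_clips) for the lines of a JSONL manifest.

    problems is a list of (line_number, message) tuples; ref_clips is the
    set of distinct ref_audio paths seen.
    """
    problems = []
    ref_clips = set()
    for number, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # ignore blank lines
        record = json.loads(line)
        for key in ("audio", "text"):
            if key not in record:
                problems.append((number, f"missing '{key}'"))
        if require_ref_audio:
            if "ref_audio" not in record:
                problems.append((number, "missing 'ref_audio'"))
            else:
                ref_clips.add(record["ref_audio"])
    return problems, ref_clips


sample = [
    '{"audio": "data/speaker/clip1.wav", "text": "Hello world", '
    '"ref_audio": "data/speaker/ref.wav"}',
    '{"audio": "data/speaker/clip2.wav", "text": "How are you"}',
]
problems, ref_clips = check_manifest(sample, require_ref_audio=True)
print(problems)   # [(2, "missing 'ref_audio'")]
print(ref_clips)  # {'data/speaker/ref.wav'}
```

For a real file, pass `open(path, encoding="utf-8")` as `lines`. Inspecting `ref_clips` also makes it easy to spot a manifest that accidentally mixes reference clips from different speakers.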
# Install dependencies
pip install mlx-audio soundfile scipy datasets transformers gradio pyyaml safetensors
# Verify pipeline works (no data needed)
python scripts/train.py --config configs/qwen3_tts_hindi.yaml --smoke-test

# 1. Download Hindi dataset
python scripts/prepare_hindi_dataset.py --source hf --output data/hindi
# 2. Pre-tokenize audio to codec IDs (run once — saves .codec.npy files)
python scripts/preprocess_dataset.py --input data/hindi/train.jsonl data/hindi/val.jsonl
# 3. Train
python scripts/train.py --config configs/qwen3_tts_hindi.yaml
# 4. Demo — compare before/after with a reference voice
python scripts/demo.py --adapter checkpoints/qwen3-hindi/checkpoint-best

Your JSONL must have `ref_audio` on every line, all pointing to the same target speaker:

{"audio": "data/speaker/clip1.wav", "text": "Hello world", "ref_audio": "data/speaker/ref.wav"}
{"audio": "data/speaker/clip2.wav", "text": "How are you", "ref_audio": "data/speaker/ref.wav"}

# 1. Pre-tokenize audio
python scripts/preprocess_dataset.py \
--input data/speaker/train.jsonl data/speaker/val.jsonl
# 2. Train (3 epochs recommended)
python scripts/train.py --config configs/qwen3_tts_speaker.yaml
# 3. Bake speaker embedding into the model
python scripts/bake_speaker_embedding.py \
--config configs/qwen3_tts_speaker.yaml \
--checkpoint checkpoints/qwen3-speaker/checkpoint-final \
--output checkpoints/qwen3-speaker/custom_voice_model
# 4. Use the baked model at inference (no ref_audio needed)
# The output dir contains adapters.safetensors + speaker_embedding.npy

mlx-audio-train/
│
├── scripts/
│ ├── train.py # Main finetuning entry point
│ ├── bake_speaker_embedding.py # Post-training: write speaker → codec_embedding[3000]
│ ├── preprocess_dataset.py # Pre-tokenize audio → .codec.npy files
│ ├── prepare_hindi_dataset.py # Download & format Hindi TTS data
│ └── demo.py # Gradio voice-cloning demo
│
├── configs/
│ ├── qwen3_tts_hindi.yaml # Pipeline 1: language adaptation
│ ├── qwen3_tts_speaker.yaml # Pipeline 2: speaker voice cloning
│ └── qwen3_tts_multilingual.yaml # Pipeline 3: multilingual (hi/ta/te/kn/mr)
│
├── train/
│ ├── lora.py # LoRA / QLoRA layer implementations
│ ├── trainer.py # Training loop (AdamW, grad accum, checkpointing)
│ └── losses/
│ └── codec_loss.py # qwen3_tts_loss, qwen3_tts_speaker_loss, csm_loss
│
├── data/
│ ├── audio_utils.py # Audio I/O, resampling, mel spectrogram
│ ├── base_dataset.py # JSONL dataset + BatchIterator (prefetch, length-sort)
│ └── processors/
│ ├── qwen3_tts.py # audio→codec tokens, text→IDs, ref_mel extraction
│ └── csm.py # CSM: Mimi RVQ codes + LLaMA tokenizer
│
└── checkpoints/ # Saved LoRA adapters
python scripts/train.py --config CONFIG [OPTIONS]

Options:
  --smoke-test     Run 5 steps with dummy data (no dataset needed)
  --resume PATH    Resume from a checkpoint directory
  --lora-rank N    Override LoRA rank
  --lr FLOAT       Override learning rate
  --epochs N       Override num_epochs
  --max-steps N    Stop after N optimizer steps

python scripts/bake_speaker_embedding.py \
--config CONFIG_YAML \
--checkpoint CKPT_DIR \
--output OUTPUT_DIR \
[--slot N] \          # codec_embedding row to use (default 3000)
[--fuse-lora]         # merge LoRA into base weights

`preprocess_dataset.py` pre-tokenizes audio to `.codec.npy` files once, so training skips the speech tokenizer entirely (major speedup + avoids OOM).
python scripts/preprocess_dataset.py --input data/hindi/train.jsonl [data/hindi/val.jsonl ...]

`train/lora.py`:
- `LoRALinear`: wraps `nn.Linear` (full-precision base)
- `QLoRALinear`: wraps `nn.QuantizedLinear` (quantized base + bf16 LoRA delta)
- `apply_lora(model, config)`: recursive in-place patching; scoped to `talker` for Qwen3-TTS
- `get_trainable_params(model)`: returns only `lora_a` / `lora_b` tensors (avoids 40+ GB gradient memory)
- `save_adapters` / `load_adapters`: tiny safetensors checkpoint (23 MB for rank-8, 45 MB for rank-16)
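To make the LoRA idea behind these classes concrete, here is a dependency-free sketch of a LoRA-wrapped linear layer. It is not the repo's `LoRALinear`: the class name, the initialisation, and the `alpha / rank` scaling are common LoRA conventions assumed here, and only the `lora_a` / `lora_b` names come from the list above.

```python
import random


def matvec(matrix, vector):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]


class LoRALinearSketch:
    """y = W x + (alpha / rank) * B (A x), with W frozen (hypothetical class)."""

    def __init__(self, weight, rank, alpha, seed=0):
        rng = random.Random(seed)
        out_dim, in_dim = len(weight), len(weight[0])
        self.weight = weight          # frozen base weight, shape (out, in)
        self.scale = alpha / rank
        # A starts small and random, B starts at zero, so the adapter is a
        # no-op until training updates B.
        self.lora_a = [[rng.gauss(0.0, 0.01) for _ in range(in_dim)]
                       for _ in range(rank)]
        self.lora_b = [[0.0] * rank for _ in range(out_dim)]

    def __call__(self, x):
        base = matvec(self.weight, x)
        delta = matvec(self.lora_b, matvec(self.lora_a, x))
        return [b + self.scale * d for b, d in zip(base, delta)]

    def trainable_params(self):
        """Only the low-rank factors receive gradients."""
        return {"lora_a": self.lora_a, "lora_b": self.lora_b}


W = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]]
layer = LoRALinearSketch(W, rank=2, alpha=4.0)
x = [1.0, 2.0, 3.0, 4.0]
print(layer(x))  # identical to W x while lora_b is all zeros
```

For an `(out, in)` weight, the adapter adds `rank * (in + out)` parameters. For the large projection matrices in a transformer this is a small fraction of `in * out`, which is why adapter checkpoints stay in the tens of megabytes.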
`train/trainer.py`:
- AdamW + gradient accumulation + gradient norm clipping
- Cosine / linear / constant LR schedule with warmup
- Per-step and per-epoch checkpointing (LoRA adapters only)
- Batched gradient norm computation (single GPU sync per optimizer step)
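Two of the trainer features listed above, the warmup LR schedule and gradient norm clipping, can be sketched in a few lines. The function names, the linear warmup shape, and the example step counts are assumptions for illustration; the repo's trainer may differ in detail.

```python
import math


def lr_at_step(step, base_lr, warmup_steps, total_steps, schedule="cosine"):
    """Learning rate for a 0-indexed optimizer step (hypothetical helper)."""
    if warmup_steps > 0 and step < warmup_steps:
        # Linear warmup from base_lr / warmup_steps up to base_lr.
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)
    if schedule == "cosine":
        return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
    if schedule == "linear":
        return base_lr * (1.0 - progress)
    if schedule == "constant":
        return base_lr
    raise ValueError(f"unknown schedule: {schedule}")


def clip_by_global_norm(grads, max_norm):
    """Scale all gradients so their joint L2 norm is at most max_norm.

    The norm is computed once over every tensor, which mirrors the
    'single sync per optimizer step' idea from the list above.
    """
    total = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    scale = min(1.0, max_norm / (total + 1e-6))
    return [[g * scale for g in tensor] for tensor in grads], total


for step in (0, 9, 10, 55, 100):
    print(step, lr_at_step(step, 2e-5, warmup_steps=10, total_steps=100))

clipped, norm_before = clip_by_global_norm([[3.0], [4.0]], max_norm=1.0)
print(norm_before, clipped)
```

With gradient accumulation, both functions would be called once per optimizer step, after the micro-batch gradients have been summed and averaged.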
| Function | Pipeline | Description |
|---|---|---|
| `qwen3_tts_loss` | Language Adaptation | `main_loss + 0.3 × sub_talker_loss`; no speaker conditioning |
| `qwen3_tts_speaker_loss` | Speaker Cloning | Same loss + speaker embedding injected from `ref_mel`; all 16 codec levels as input context |
| `csm_loss` | CSM | Cross-entropy on first codebook head |
`data/audio_utils.py` has no PyTorch dependency; it uses `soundfile` + `scipy.signal`.
- `load_audio(path, target_sr)`: load + resample any audio file
- `normalize_loudness(audio)`: RMS normalization to −23 dBFS
- `trim_silence(audio)`: energy-based silence trimming
- `mel_spectrogram(audio, sr)`: log-mel spectrogram matching Qwen3-TTS speaker encoder params (`n_fft=1024`, `n_mels=128`, `hop=256`)
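The loudness normalization and silence trimming described above can be sketched without any audio library. The functions below are hypothetical stand-ins, not the repo's `normalize_loudness` / `trim_silence`; the frame size and threshold are arbitrary values chosen for the toy example.

```python
import math

TARGET_DBFS = -23.0  # target level named in the list above


def rms(samples):
    """Root-mean-square level of a sequence of float samples."""
    return math.sqrt(sum(s * s for s in samples) / len(samples))


def rms_normalize(samples, target_dbfs=TARGET_DBFS):
    """Scale samples so their RMS level equals target_dbfs (0 dBFS = 1.0)."""
    current = rms(samples)
    if current == 0.0:
        return list(samples)  # pure silence: nothing to scale
    gain = (10.0 ** (target_dbfs / 20.0)) / current
    return [s * gain for s in samples]


def trim_by_energy(samples, frame=4, threshold=1e-3):
    """Drop leading and trailing frames whose RMS is below threshold."""
    frames = [samples[i:i + frame] for i in range(0, len(samples), frame)]
    loud = [i for i, f in enumerate(frames) if rms(f) >= threshold]
    if not loud:
        return []
    return samples[loud[0] * frame:(loud[-1] + 1) * frame]


tone = [0.5 * math.sin(2 * math.pi * 440 * n / 16000) for n in range(1600)]
levelled = rms_normalize(tone)
print(f"{20 * math.log10(rms(levelled)):.2f} dBFS")

padded = [0.0] * 8 + [0.5] * 4 + [0.0] * 4
print(trim_by_energy(padded))  # [0.5, 0.5, 0.5, 0.5]
```

Trimming before normalizing is usually the safer order: long stretches of silence lower the measured RMS and would otherwise cause the speech itself to be over-amplified.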
`data/base_dataset.py` (`BatchIterator`):
- `sort_by_length=True`: groups similar-length sequences to minimise padding waste
- `prefetch=2`: loads the next N batches (here 2) in a background thread, overlapping CPU I/O with GPU compute
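Both `BatchIterator` options can be illustrated with the standard library alone. This is a sketch under assumptions: the function names are invented, and the real iterator may shuffle batches or handle worker errors differently.

```python
import queue
import threading


def length_sorted_batches(lengths, batch_size):
    """Group sample indices so each batch holds similar-length sequences."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]


def padding_waste(lengths, batches):
    """Total padding needed when each batch is padded to its longest item."""
    return sum(max(lengths[i] for i in batch) - lengths[i]
               for batch in batches for i in batch)


def prefetch(iterable, depth=2):
    """Yield items while a background thread loads up to `depth` ahead."""
    buffer = queue.Queue(maxsize=depth)
    done = object()  # sentinel marking the end of the stream

    def worker():
        for item in iterable:
            buffer.put(item)  # blocks once `depth` items are waiting
        buffer.put(done)

    threading.Thread(target=worker, daemon=True).start()
    while True:
        item = buffer.get()
        if item is done:
            return
        yield item


lengths = [90, 10, 85, 12, 50, 48]
naive = [[0, 1], [2, 3], [4, 5]]
sorted_batches = length_sorted_batches(lengths, batch_size=2)
print(padding_waste(lengths, naive), padding_waste(lengths, sorted_batches))
print(list(prefetch(iter(sorted_batches), depth=2)))
```

In this toy case, sorting cuts padding from 155 to 9 frames. The bounded queue is what keeps memory use flat: the loader thread pauses as soon as `depth` batches are waiting.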
Key settings in `configs/qwen3_tts_hindi.yaml`:

| Key | Value | Notes |
|---|---|---|
| `model.model_type` | `qwen3_tts` | |
| `lora.rank` | `8` | |
| `trainer.num_epochs` | `15` | |
| `trainer.batch_size` | `2` | limits Metal memory pressure on M-series chips |
| `trainer.grad_accumulation` | `16` | effective batch = 32 |
| `trainer.learning_rate` | `2e-5` | |
| `trainer.label_smoothing` | `0.1` | helps with character diversity |
| `processor.include_ref_mel` | `false` | not needed for language training |
Key settings in `configs/qwen3_tts_speaker.yaml`:

| Key | Value | Notes |
|---|---|---|
| `model.model_type` | `qwen3_tts_speaker` | enables speaker loss |
| `lora.rank` | `16` | more capacity for speaker detail |
| `trainer.num_epochs` | `3` | more epochs → overfitting risk |
| `trainer.grad_accumulation` | `8` | effective batch = 16 |
| `trainer.learning_rate` | `2e-5` | matches official |
| `trainer.label_smoothing` | `0.0` | sharp speaker predictions |
| `processor.include_ref_mel` | `true` | enables mel → `speaker_encoder` path |
| `processor.speaker_name` | `custom_speaker` | written into `config.json` by bake script |
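The `label_smoothing` values in the settings above (0.1 for language adaptation, 0.0 for speaker cloning) change the target distribution used in the cross-entropy. The sketch below uses one common formulation, where the smoothing mass is spread uniformly over all classes; whether the repo uses exactly this variant is an assumption.

```python
import math


def smoothed_cross_entropy(logits, target, smoothing):
    """Cross-entropy against a label-smoothed target distribution.

    The target puts (1 - smoothing) on the true class and spreads
    `smoothing` uniformly over all classes (including the true one).
    """
    peak = max(logits)
    log_z = peak + math.log(sum(math.exp(x - peak) for x in logits))
    log_probs = [x - log_z for x in logits]
    classes = len(logits)
    loss = 0.0
    for index, log_prob in enumerate(log_probs):
        weight = smoothing / classes
        if index == target:
            weight += 1.0 - smoothing
        loss -= weight * log_prob
    return loss


confident = [4.0, 0.0, 0.0]  # model is already sure about class 0
sharp = smoothed_cross_entropy(confident, target=0, smoothing=0.0)
soft = smoothed_cross_entropy(confident, target=0, smoothing=0.1)
print(f"smoothing=0.0: {sharp:.4f}  smoothing=0.1: {soft:.4f}")
```

With smoothing enabled, a very confident prediction is still penalised, which discourages collapse onto a few tokens. With `0.0` the model is free to become as sharp as the data allows, which matches the "sharp speaker predictions" note above.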
| Model | `model_type` | Status |
|---|---|---|
| Qwen3-TTS 0.6B Base 8bit | `qwen3_tts` / `qwen3_tts_speaker` | Working |
| Qwen3-TTS 1.7B Base 8bit | `qwen3_tts` / `qwen3_tts_speaker` | Same pipeline, change `model_id` |
| PersonaPlex 7B | `personaplex` | Working |
| CSM (Sesame) | `csm` | Processor + loss implemented |
| Kokoro | - | Planned |
| Chatterbox | - | Planned |
PersonaPlex is a 7B full-duplex speech model (simultaneous ASR + TTS). This repo supports LoRA finetuning it for Hindi language adaptation.
# 1. Prepare paired Hindi data (manifest.json + tokens/*.npz)
python scripts/prepare_personaplex_dataset.py
# 2. Train
python scripts/train.py --config configs/personaplex_hindi.yaml
# 3. Run full-duplex inference (in the personaplex-mlx repo)
python -m personaplex_mlx.local_web \
--adapter-weights /path/to/checkpoints/personaplex-hindi/checkpoint-best/adapters.npz

| Component | Trainable | Size | Notes |
|---|---|---|---|
| Transformer attention (`in_proj`, `out_proj`) | LoRA A/B only | ~29 MB | rank-16 |
| Depformer bridge (`linear_in`) | LoRA A/B only | small | bridges transformer→depformer; must match training at inference |
| Audio embeddings (`audio_embs`) | Full weights | ~268 MB | 16 codebooks × 2049 tokens |
| Vocab projection (`text_linear`) | Frozen | 0 MB | see note below |
| Output norm (`out_norm`) | Full weights | tiny | |
The text_linear layer is a (32000 × 4096) = 131M parameter vocabulary projection. Training it on a small dataset (~1K samples) causes severe overfitting — the model memorises training text patterns and produces repetitive garbage (commas, byte tokens) at inference.
Always set freeze_text_linear: true in the config unless you have 10K+ samples:
lora:
  rank: 16
  alpha: 16.0
  dropout: 0.0
  train_depformer: false
  freeze_text_linear: true  # freeze 131M vocab projection; prevents overfitting on small datasets

With `freeze_text_linear: true`, adapter size drops from ~560 MB to ~297 MB and inference quality is significantly better on small datasets.
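The adapter sizes quoted above can be roughly reproduced with back-of-the-envelope arithmetic. Two things here are assumptions rather than facts from the repo: that weights are saved at 2 bytes per parameter (bf16), and that `audio_embs` has the same 4096 width as `text_linear`. The small `linear_in` and `out_norm` tensors are ignored.

```python
BYTES_PER_PARAM = 2   # assumption: tensors saved as bf16
HIDDEN = 4096         # from the text_linear shape; assumed for audio_embs too
LORA_MB = 29          # attention LoRA size quoted in the component table


def megabytes(params):
    """Size in decimal megabytes at BYTES_PER_PARAM bytes per parameter."""
    return params * BYTES_PER_PARAM / 1e6


audio_embs_params = 16 * 2049 * HIDDEN   # 16 codebooks x 2049 tokens
text_linear_params = 32000 * HIDDEN      # vocabulary projection

frozen_total = LORA_MB + megabytes(audio_embs_params)
unfrozen_total = frozen_total + megabytes(text_linear_params)

print(f"text_linear params: {text_linear_params:,}")
print(f"audio_embs:         {megabytes(audio_embs_params):.1f} MB")
print(f"adapter, text_linear frozen:  {frozen_total:.1f} MB")
print(f"adapter, text_linear trained: {unfrozen_total:.1f} MB")
```

Under these assumptions the two totals land within about 1 MB of the ~297 MB and ~560 MB figures, which suggests the difference between them is almost entirely the `text_linear` weights.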
| Samples | Expected quality |
|---|---|
| ~1K | Overfits; fails on unseen Hindi phonemes |
| ~3–5K | Reasonable Hindi phoneme coverage |
| ~10K+ | Solid generalisation |
Use scripts/download_indicvoices.py or scripts/download_common_voice.py to pull more Hindi data.
The official sft_12hz.py is full fine-tuning on CUDA with PyTorch + HuggingFace Accelerate. This repo is LoRA / QLoRA on Apple Silicon with MLX.
| | Official | This repo |
|---|---|---|
| Framework | PyTorch + Accelerate | MLX |
| Training mode | Full fine-tune (all weights) | LoRA (1–2% of weights) |
| Precision | bfloat16 | 8-bit quantised base + bf16 LoRA delta |
| Effective batch | 8 (bs=2 × accum=4) | 16–32 |
| Epochs | 3 | 3 (speaker) / 10–15 (language) |
| Speaker injection | Position 6 in dual-channel format | Positional token in codec prefix (matches official) |
| All 16 codec levels | Yes (additive input embeddings) | No (first codebook only as input; `code_predictor` is an auxiliary loss head) |
| Post-training step | Bake speaker to `codec_embedding[3000]` | `bake_speaker_embedding.py` (same) |
| Checkpoint size | Full model (~1.2 GB) | Adapters only (~46 MB) |