Skip to content

Commit 94f451e

Browse files
committed
fix: исправить декодирование multi-codebook и оптимизировать память
- Исправить trailing_text_hidden для non_streaming_mode (использовать tts_pad)
- Конвертировать interleaved токены в multi-codebook формат для декодера
- Хранить только последний hidden state вместо всех (экономия памяти)
- Добавить fade-out 100ms для устранения шума в конце аудио
- Уменьшить max_seq_len до 400 (~33 сек при 12Hz)
- Убрать задержку при скрытии overlay инициализации
- Увеличить высоту окна приложения
1 parent dbcd2c8 commit 94f451e

5 files changed

Lines changed: 74 additions & 99 deletions

File tree

crates/acoustic-model/src/model.rs

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,8 @@ impl Model {
565565
let mut sampler = Sampler::new(sampling_config.clone());
566566
let mut generated_zeroth: Vec<u32> = Vec::with_capacity(max_new_tokens);
567567
let mut all_frames: Vec<Vec<u32>> = Vec::with_capacity(max_new_tokens);
568-
let mut all_hidden_states: Vec<Tensor> = Vec::with_capacity(max_new_tokens);
568+
// Note: We only keep the last hidden state to save memory
569+
let mut last_hidden_state: Option<Tensor> = None;
569570

570571
// Suppress tokens: special tokens (2048-3071) except EOS
571572
let suppress_start = 2048u32;
@@ -626,7 +627,7 @@ impl Model {
626627
}
627628

628629
generated_zeroth.push(current_zeroth_token);
629-
all_hidden_states.push(hidden_states.i((.., seq_len - 1..seq_len, ..))?.clone());
630+
last_hidden_state = Some(hidden_states.i((.., seq_len - 1..seq_len, ..))?.clone());
630631

631632
// Predict residual codebooks for first token using CodePredictor
632633
let first_hidden = hidden_states.i((.., seq_len - 1..seq_len, ..))?;
@@ -735,14 +736,17 @@ impl Model {
735736
// Sample next zeroth token
736737
let next_zeroth = sampler.sample(&logits_vec);
737738

738-
// Log progress
739-
if step % 50 == 0 || next_zeroth >= 2048 {
739+
// Log progress every 10 steps or on special events
740+
let is_eos = Some(next_zeroth) == eos_token_id;
741+
if step % 10 == 0 || next_zeroth >= 2048 || is_eos {
740742
info!(
741-
"Step {}: zeroth={}, eos={:?}, is_eos={}, generated={}",
743+
"Step {}/{}: zeroth={}, eos={:?}, is_eos={}, min_reached={}, generated={}",
742744
step,
745+
max_new_tokens,
743746
next_zeroth,
744747
eos_token_id,
745-
Some(next_zeroth) == eos_token_id,
748+
is_eos,
749+
generated_zeroth.len() >= min_new_tokens,
746750
generated_zeroth.len()
747751
);
748752
}
@@ -758,8 +762,8 @@ impl Model {
758762
break;
759763
}
760764

761-
// Store hidden state
762-
all_hidden_states.push(hidden_states.clone());
765+
// Keep only last hidden state to save memory
766+
last_hidden_state = Some(hidden_states.clone());
763767

764768
generated_zeroth.push(next_zeroth);
765769

@@ -792,14 +796,13 @@ impl Model {
792796
all_frames.len()
793797
);
794798

795-
// Concatenate hidden states
796-
let concatenated_hidden = if all_hidden_states.is_empty() {
797-
Tensor::zeros((1, 0, self.config.hidden_size), DType::F32, &self.device)?
798-
} else {
799-
Tensor::cat(&all_hidden_states, 1)?
800-
};
799+
// Return last hidden state only (or empty tensor)
800+
let final_hidden = last_hidden_state.unwrap_or_else(|| {
801+
Tensor::zeros((1, 1, self.config.hidden_size), DType::F32, &self.device)
802+
.expect("Failed to create empty hidden state")
803+
});
801804

802-
Ok((generated_zeroth, all_frames, concatenated_hidden))
805+
Ok((generated_zeroth, all_frames, final_hidden))
803806
}
804807

805808
/// Suppress special tokens in logits (set to -inf) except for EOS token.

crates/runtime/src/pipeline.rs

Lines changed: 44 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ impl Default for PipelineConfig {
5252
default_lang: Lang::Ru,
5353
crossfade_ms: DEFAULT_CROSSFADE_MS,
5454
chunk_tokens: 10,
55-
max_seq_len: 4096, // Max sequence length for audio generation
55+
max_seq_len: 400, // Max sequence length for audio generation (~33 seconds at 12Hz)
5656
default_speaker: None,
5757
is_custom_voice: false,
5858
}
@@ -749,50 +749,30 @@ impl TtsPipeline {
749749
info!("min_new_tokens={} (matching Python SDK)", min_tokens);
750750

751751
// ========== COMPUTE trailing_text_hidden ==========
752-
// Python SDK (modeling_qwen3_tts.py:2230-2232):
753-
// trailing_text_hidden = torch.cat((self.talker.text_projection(
754-
// self.talker.get_text_embeddings()(input_id[:, 4:-5])
755-
// ), tts_eos_embed), dim=1)
752+
// In non_streaming_mode (which we use), ALL text tokens are already in prefill.
753+
// Python SDK (modeling_qwen3_tts.py:2227):
754+
// trailing_text_hidden = tts_pad_embed
756755
//
757-
// This is: text_projection(text_tokens[1:]) concatenated with tts_eos_embed
758-
// The first text token goes into prefill, remaining are for trailing conditioning
756+
// This is just a single tts_pad embedding that gets used for ALL generation steps.
757+
// The text conditioning comes from the prefill, not from trailing_text_hidden.
759758
//
760-
// For text "Hello world" with tokens [15339, 1917]:
761-
// - text_tokens[0] = 15339 goes into prefill (combined with codec embeddings)
762-
// - text_tokens[1:] = [1917] + tts_eos becomes trailing_text_hidden
763-
//
764-
// During generation step i:
765-
// - if i < len(trailing_text_hidden): use trailing_text_hidden[i]
766-
// - else: use tts_pad_embed
767-
// trailing_text_hidden: text conditioning for generation steps
768-
// Each step uses trailing_text_hidden[step] until exhausted, then uses tts_pad_embed
769-
let trailing_text_hidden = if text_tokens.len() > 1 {
770-
// Build trailing tokens: text_tokens[1:] + tts_eos
771-
let mut trailing_tokens: Vec<u32> = text_tokens[1..].to_vec();
772-
trailing_tokens.push(st.tts_eos_token_id);
773-
774-
let trailing_tensor = Tensor::new(trailing_tokens.as_slice(), &self.device)
775-
.map_err(|e| TtsError::inference(format!("tensor creation failed: {e}")))?
776-
.unsqueeze(0)
777-
.map_err(|e| TtsError::inference(format!("unsqueeze failed: {e}")))?;
778-
779-
// Get text embeddings with projection
780-
let trailing_embed = model
781-
.get_text_embedding(&trailing_tensor)
782-
.map_err(|e| TtsError::inference(format!("trailing text embed failed: {e}")))?;
783-
784-
info!(
785-
"Built trailing_text_hidden from {} tokens (text[1:] + eos), shape: {:?}",
786-
trailing_tokens.len(),
787-
trailing_embed.dims()
788-
);
759+
// IMPORTANT: In non_streaming_mode, trailing_text_hidden is NOT the remaining text!
760+
// It's just tts_pad because all text is already encoded in the prefill.
761+
let tts_pad_tensor = Tensor::new(&[st.tts_pad_token_id], &self.device)
762+
.map_err(|e| TtsError::inference(format!("tensor creation failed: {e}")))?
763+
.unsqueeze(0)
764+
.map_err(|e| TtsError::inference(format!("unsqueeze failed: {e}")))?;
789765

790-
Some(trailing_embed)
791-
} else {
792-
// Only one text token, no trailing hidden needed
793-
info!("Single text token, no trailing_text_hidden");
794-
None
795-
};
766+
let tts_pad_embed = model
767+
.get_text_embedding(&tts_pad_tensor)
768+
.map_err(|e| TtsError::inference(format!("tts_pad embed failed: {e}")))?;
769+
770+
info!(
771+
"non_streaming_mode: trailing_text_hidden = tts_pad_embed only, shape: {:?}",
772+
tts_pad_embed.dims()
773+
);
774+
775+
let trailing_text_hidden = Some(tts_pad_embed);
796776

797777
// Use new method with CodePredictor if available
798778
// This correctly sums all 16 codebook embeddings at each generation step
@@ -1247,37 +1227,34 @@ impl TtsPipeline {
12471227
"Acoustic tokens generated"
12481228
);
12491229

1250-
// 4. Filter out special tokens before decoding
1251-
// Codec special tokens (2148-2157) are control tokens, not audio data:
1252-
// - 2148: codec_pad_id
1253-
// - 2149: codec_bos_id
1254-
// - 2150: codec_eos_id
1255-
// - 2154: codec_think_id
1256-
// - 2155: codec_nothink_id
1257-
// - 2156: codec_think_bos_id
1258-
// - 2157: codec_think_eos_id
1259-
// Audio tokens are in range 0-2047 (codec vocab_size)
1260-
let filtered_tokens: Vec<u32> = acoustic_tokens
1261-
.iter()
1262-
.filter(|&&t| t < 2048) // valid audio tokens
1263-
.copied()
1264-
.collect();
1230+
// 4. acoustic_tokens is in interleaved format: [c0_f0, c1_f0, ..., c15_f0, c0_f1, ...]
1231+
// Convert to multi-codebook format: Vec<Vec<u32>> where each inner Vec is one codebook
1232+
const NUM_CODEBOOKS: usize = 16;
1233+
let num_frames = acoustic_tokens.len() / NUM_CODEBOOKS;
12651234

1266-
if filtered_tokens.is_empty() {
1267-
return Err(TtsError::inference(
1268-
"no valid audio tokens generated (all were special tokens)".to_string(),
1269-
));
1235+
if num_frames == 0 {
1236+
return Err(TtsError::inference("no audio frames generated".to_string()));
1237+
}
1238+
1239+
// Reshape interleaved to [num_codebooks][num_frames]
1240+
let mut multi_tokens: Vec<Vec<u32>> = vec![Vec::with_capacity(num_frames); NUM_CODEBOOKS];
1241+
for frame_idx in 0..num_frames {
1242+
for cb_idx in 0..NUM_CODEBOOKS {
1243+
let token = acoustic_tokens[frame_idx * NUM_CODEBOOKS + cb_idx];
1244+
// Clamp tokens to valid range (0-2047)
1245+
let clamped = if token >= 2048 { 0 } else { token };
1246+
multi_tokens[cb_idx].push(clamped);
1247+
}
12701248
}
12711249

12721250
debug!(
1273-
original = acoustic_tokens.len(),
1274-
filtered = filtered_tokens.len(),
1275-
removed = acoustic_tokens.len() - filtered_tokens.len(),
1276-
"Filtered special tokens"
1251+
num_frames = num_frames,
1252+
codebooks = NUM_CODEBOOKS,
1253+
"Converted interleaved to multi-codebook format"
12771254
);
12781255

1279-
// 5. Decode to audio
1280-
let audio = self.decode_audio(&filtered_tokens)?;
1256+
// 5. Decode using multi-codebook decoder
1257+
let audio = self.codec.decode_multi(&multi_tokens)?;
12811258
debug!(
12821259
samples = audio.num_samples(),
12831260
duration_ms = audio.duration_ms(),

crates/tts-app/src/tts.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use tauri::State;
1111
use tokio::sync::Mutex;
1212
use tracing::{error, info};
1313

14-
use audio_codec_12hz::apply_fade_in;
14+
use audio_codec_12hz::{apply_fade_in, apply_fade_out};
1515
use runtime::TtsPipeline;
1616
use tts_core::Lang;
1717

@@ -200,10 +200,11 @@ pub async fn speak(
200200
let sample_rate = audio.sample_rate;
201201
let num_samples = audio.num_samples();
202202

203-
// Apply fade-in to remove artifacts
203+
// Apply fade-in/fade-out to remove artifacts at beginning and end
204204
// audio.pcm is Arc<[f32]>, need to clone to mutable Vec
205205
let mut samples: Vec<f32> = audio.pcm.to_vec();
206-
apply_fade_in(&mut samples, 50.0, sample_rate);
206+
apply_fade_in(&mut samples, 20.0, sample_rate); // 20ms fade-in
207+
apply_fade_out(&mut samples, 100.0, sample_rate); // 100ms fade-out (longer to remove end noise)
207208

208209
// Create WAV in memory
209210
let wav_data = create_wav_buffer(&samples, sample_rate);

crates/tts-app/tauri.conf.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
{
1313
"title": "Qwen3-TTS",
1414
"width": 600,
15-
"height": 700,
15+
"height": 900,
1616
"resizable": true,
1717
"center": true,
1818
"minWidth": 400,
19-
"minHeight": 500
19+
"minHeight": 600
2020
}
2121
],
2222
"security": {

crates/tts-app/ui/index.html

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -450,19 +450,13 @@ <h1>Qwen3-TTS</h1>
450450
console.log('Setting isInitialized = true');
451451
isInitialized = true;
452452

453-
// Show success message briefly
454-
showInitStatus('Model loaded! Starting...', 'success');
453+
// Hide overlay immediately
454+
console.log('Hiding initOverlay immediately');
455+
initOverlay.style.display = 'none';
456+
initOverlay.classList.add('hidden');
455457

456-
// Small delay to let user see success, then hide overlay
457-
setTimeout(() => {
458-
console.log('Hiding initOverlay, element:', initOverlay);
459-
initOverlay.style.display = 'none'; // Direct style as backup
460-
initOverlay.classList.add('hidden');
461-
console.log('initOverlay classes after:', initOverlay.className);
462-
463-
speakBtn.disabled = false;
464-
console.log('Init complete!');
465-
}, 500);
458+
speakBtn.disabled = false;
459+
console.log('Init complete! Main UI should be visible now.');
466460

467461
} catch (e) {
468462
console.error('initTTS error:', e);

0 commit comments

Comments (0)