@@ -52,7 +52,7 @@ impl Default for PipelineConfig {
5252 default_lang : Lang :: Ru ,
5353 crossfade_ms : DEFAULT_CROSSFADE_MS ,
5454 chunk_tokens : 10 ,
55- max_seq_len : 4096 , // Max sequence length for audio generation
55+ max_seq_len : 400 , // Max sequence length for audio generation (~33 seconds at 12Hz)
5656 default_speaker : None ,
5757 is_custom_voice : false ,
5858 }
@@ -749,50 +749,30 @@ impl TtsPipeline {
749749 info ! ( "min_new_tokens={} (matching Python SDK)" , min_tokens) ;
750750
751751 // ========== COMPUTE trailing_text_hidden ==========
752- // Python SDK (modeling_qwen3_tts.py:2230-2232):
753- // trailing_text_hidden = torch.cat((self.talker.text_projection(
754- // self.talker.get_text_embeddings()(input_id[:, 4:-5])
755- // ), tts_eos_embed), dim=1)
752+ // In non_streaming_mode (which we use), ALL text tokens are already in prefill.
753+ // Python SDK (modeling_qwen3_tts.py:2227):
754+ // trailing_text_hidden = tts_pad_embed
756755 //
757- // This is: text_projection(text_tokens[1:]) concatenated with tts_eos_embed
758- // The first text token goes into prefill, remaining are for trailing conditioning
756+ // This is just a single tts_pad embedding that gets used for ALL generation steps.
757+ // The text conditioning comes from the prefill, not from trailing_text_hidden.
759758 //
760- // For text "Hello world" with tokens [15339, 1917]:
761- // - text_tokens[0] = 15339 goes into prefill (combined with codec embeddings)
762- // - text_tokens[1:] = [1917] + tts_eos becomes trailing_text_hidden
763- //
764- // During generation step i:
765- // - if i < len(trailing_text_hidden): use trailing_text_hidden[i]
766- // - else: use tts_pad_embed
767- // trailing_text_hidden: text conditioning for generation steps
768- // Each step uses trailing_text_hidden[step] until exhausted, then uses tts_pad_embed
769- let trailing_text_hidden = if text_tokens. len ( ) > 1 {
770- // Build trailing tokens: text_tokens[1:] + tts_eos
771- let mut trailing_tokens: Vec < u32 > = text_tokens[ 1 ..] . to_vec ( ) ;
772- trailing_tokens. push ( st. tts_eos_token_id ) ;
773-
774- let trailing_tensor = Tensor :: new ( trailing_tokens. as_slice ( ) , & self . device )
775- . map_err ( |e| TtsError :: inference ( format ! ( "tensor creation failed: {e}" ) ) ) ?
776- . unsqueeze ( 0 )
777- . map_err ( |e| TtsError :: inference ( format ! ( "unsqueeze failed: {e}" ) ) ) ?;
778-
779- // Get text embeddings with projection
780- let trailing_embed = model
781- . get_text_embedding ( & trailing_tensor)
782- . map_err ( |e| TtsError :: inference ( format ! ( "trailing text embed failed: {e}" ) ) ) ?;
783-
784- info ! (
785- "Built trailing_text_hidden from {} tokens (text[1:] + eos), shape: {:?}" ,
786- trailing_tokens. len( ) ,
787- trailing_embed. dims( )
788- ) ;
759+ // IMPORTANT: In non_streaming_mode, trailing_text_hidden is NOT the remaining text!
760+ // It's just tts_pad because all text is already encoded in the prefill.
761+ let tts_pad_tensor = Tensor :: new ( & [ st. tts_pad_token_id ] , & self . device )
762+ . map_err ( |e| TtsError :: inference ( format ! ( "tensor creation failed: {e}" ) ) ) ?
763+ . unsqueeze ( 0 )
764+ . map_err ( |e| TtsError :: inference ( format ! ( "unsqueeze failed: {e}" ) ) ) ?;
789765
790- Some ( trailing_embed)
791- } else {
792- // Only one text token, no trailing hidden needed
793- info ! ( "Single text token, no trailing_text_hidden" ) ;
794- None
795- } ;
766+ let tts_pad_embed = model
767+ . get_text_embedding ( & tts_pad_tensor)
768+ . map_err ( |e| TtsError :: inference ( format ! ( "tts_pad embed failed: {e}" ) ) ) ?;
769+
770+ info ! (
771+ "non_streaming_mode: trailing_text_hidden = tts_pad_embed only, shape: {:?}" ,
772+ tts_pad_embed. dims( )
773+ ) ;
774+
775+ let trailing_text_hidden = Some ( tts_pad_embed) ;
796776
797777 // Use new method with CodePredictor if available
798778 // This correctly sums all 16 codebook embeddings at each generation step
@@ -1247,37 +1227,34 @@ impl TtsPipeline {
12471227 "Acoustic tokens generated"
12481228 ) ;
12491229
1250- // 4. Filter out special tokens before decoding
1251- // Codec special tokens (2148-2157) are control tokens, not audio data:
1252- // - 2148: codec_pad_id
1253- // - 2149: codec_bos_id
1254- // - 2150: codec_eos_id
1255- // - 2154: codec_think_id
1256- // - 2155: codec_nothink_id
1257- // - 2156: codec_think_bos_id
1258- // - 2157: codec_think_eos_id
1259- // Audio tokens are in range 0-2047 (codec vocab_size)
1260- let filtered_tokens: Vec < u32 > = acoustic_tokens
1261- . iter ( )
1262- . filter ( |& & t| t < 2048 ) // valid audio tokens
1263- . copied ( )
1264- . collect ( ) ;
1230+ // 4. acoustic_tokens is in interleaved format: [c0_f0, c1_f0, ..., c15_f0, c0_f1, ...]
1231+ // Convert to multi-codebook format: Vec<Vec<u32>> where each inner Vec is one codebook
1232+ const NUM_CODEBOOKS : usize = 16 ;
1233+ let num_frames = acoustic_tokens. len ( ) / NUM_CODEBOOKS ;
12651234
1266- if filtered_tokens. is_empty ( ) {
1267- return Err ( TtsError :: inference (
1268- "no valid audio tokens generated (all were special tokens)" . to_string ( ) ,
1269- ) ) ;
1235+ if num_frames == 0 {
1236+ return Err ( TtsError :: inference ( "no audio frames generated" . to_string ( ) ) ) ;
1237+ }
1238+
1239+ // Reshape interleaved to [num_codebooks][num_frames]
1240+ let mut multi_tokens: Vec < Vec < u32 > > = vec ! [ Vec :: with_capacity( num_frames) ; NUM_CODEBOOKS ] ;
1241+ for frame_idx in 0 ..num_frames {
1242+ for cb_idx in 0 ..NUM_CODEBOOKS {
1243+ let token = acoustic_tokens[ frame_idx * NUM_CODEBOOKS + cb_idx] ;
1244+ // Replace out-of-range codec special tokens (>= 2048) with 0; valid audio tokens are 0-2047
1245+ let clamped = if token >= 2048 { 0 } else { token } ;
1246+ multi_tokens[ cb_idx] . push ( clamped) ;
1247+ }
12701248 }
12711249
12721250 debug ! (
1273- original = acoustic_tokens. len( ) ,
1274- filtered = filtered_tokens. len( ) ,
1275- removed = acoustic_tokens. len( ) - filtered_tokens. len( ) ,
1276- "Filtered special tokens"
1251+ num_frames = num_frames,
1252+ codebooks = NUM_CODEBOOKS ,
1253+ "Converted interleaved to multi-codebook format"
12771254 ) ;
12781255
1279- // 5. Decode to audio
1280- let audio = self . decode_audio ( & filtered_tokens ) ?;
1256+ // 5. Decode using multi-codebook decoder
1257+ let audio = self . codec . decode_multi ( & multi_tokens ) ?;
12811258 debug ! (
12821259 samples = audio. num_samples( ) ,
12831260 duration_ms = audio. duration_ms( ) ,
0 commit comments