@@ -131,7 +131,7 @@ def get_codec_silence_frame_last_one(self):
131131 audio , audio_len = self .pad_audio_to_factor (audio , audio_len , self .target_samples_per_frame )
132132
133133 with ensures_target_precision (self .audio_codec_run_dtype ), torch .no_grad ():
134- sil_codes , sil_codes_lens = self .audio_codec .encode (audio .unsqueeze (1 ), audio_len )
134+ sil_codes , sil_codes_lens = self .audio_codec .encode (audio .unsqueeze (1 ). to ( self . audio_codec_run_dtype ) , audio_len )
135135 return sil_codes [0 , - 1 ]
136136
137137 def get_codec_silence_frame (self ):
@@ -142,7 +142,7 @@ def get_codec_silence_frame(self):
142142 audio , audio_len = self .pad_audio_to_factor (audio , audio_len , self .target_samples_per_frame )
143143
144144 with ensures_target_precision (self .audio_codec_run_dtype ), torch .no_grad ():
145- sil_codes , _ = self .audio_codec .encode (audio .unsqueeze (1 ), audio_len ) # [1, T, C]
145+ sil_codes , _ = self .audio_codec .encode (audio .unsqueeze (1 ). to ( self . audio_codec_run_dtype ) , audio_len ) # [1, T, C]
146146 sil_codes = sil_codes [0 ] # [T, C]
147147
148148 # Convert each frame (C tokens) into a tuple
@@ -328,7 +328,7 @@ def prepare_inputs(self, batch: dict):
328328 target_audio , target_audio_lens , self .target_samples_per_frame , 1
329329 )
330330 with ensures_target_precision (self .audio_codec_run_dtype ), torch .no_grad ():
331- target_codes , target_codes_lens = self .audio_codec .encode (target_audio .unsqueeze (1 ), target_audio_lens )
331+ target_codes , target_codes_lens = self .audio_codec .encode (target_audio .unsqueeze (1 ). to ( self . audio_codec_run_dtype ) , target_audio_lens )
332332
333333 with fp32_precision ():
334334 target_len = target_codes .shape [1 ]
@@ -1013,7 +1013,7 @@ def set_init_inputs(self, speaker_audio=None, speaker_audio_lens=None, system_pr
10131013 [target_audio .size (- 1 )] * target_audio .size (0 ), dtype = torch .long , device = self .device
10141014 )
10151015 with ensures_target_precision (self .audio_codec_run_dtype ), torch .no_grad ():
1016- code , _ = self .audio_codec .encode (target_audio .unsqueeze (1 ), target_audio_len )
1016+ code , _ = self .audio_codec .encode (target_audio .unsqueeze (1 ). to ( self . audio_codec_run_dtype ) , target_audio_len )
10171017
10181018 # get context hidden
10191019 if self .cfg .tts_config .context_hidden_size is not None :
@@ -1683,7 +1683,7 @@ def setup_audio_codec(model):
16831683 p .requires_grad = False
16841684
16851685 model .audio_codec .eval ()
1686- model .audio_codec .to (model .device ) # force codec to run in the same device as the main model
1686+ model .audio_codec .to (model .device ) # force codec to run in the same device as the main model
16871687
16881688 assert callable (model .tts_model .set_rvq_embs )
16891689
0 commit comments