Skip to content

Commit 42da55b

Browse files
committed
Fix model input data type
Signed-off-by: Edresson Casanova <edresson1@gmail.com>
1 parent e0af62d commit 42da55b

3 files changed

Lines changed: 8 additions & 8 deletions

File tree

nemo/collections/common/data/lhotse/cutset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1105,7 +1105,7 @@ def convert_cut_fn(cut: Cut) -> Cut:
11051105
assert new_cut.recording is old_target_audio, f"{new_cut.id}: recording object not swapped"
11061106
assert new_cut.target_audio is old_recording, f"{new_cut.id}: target_audio object not swapped"
11071107

1108-
new_cut.formatter = "s2s_duplex_reverse_role"
1108+
new_cut.task = "s2s_duplex_reverse_role"
11091109
return new_cut
11101110

11111111
cuts = cuts.map(convert_cut_fn)

nemo/collections/speechlm2/models/duplex_ear_tts.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def get_codec_silence_frame_last_one(self):
131131
audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.target_samples_per_frame)
132132

133133
with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad():
134-
sil_codes, sil_codes_lens = self.audio_codec.encode(audio.unsqueeze(1), audio_len)
134+
sil_codes, sil_codes_lens = self.audio_codec.encode(audio.unsqueeze(1).to(self.audio_codec_run_dtype), audio_len)
135135
return sil_codes[0, -1]
136136

137137
def get_codec_silence_frame(self):
@@ -142,7 +142,7 @@ def get_codec_silence_frame(self):
142142
audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.target_samples_per_frame)
143143

144144
with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad():
145-
sil_codes, _ = self.audio_codec.encode(audio.unsqueeze(1), audio_len) # [1, T, C]
145+
sil_codes, _ = self.audio_codec.encode(audio.unsqueeze(1).to(self.audio_codec_run_dtype), audio_len) # [1, T, C]
146146
sil_codes = sil_codes[0] # [T, C]
147147

148148
# Convert each frame (C tokens) into a tuple
@@ -328,7 +328,7 @@ def prepare_inputs(self, batch: dict):
328328
target_audio, target_audio_lens, self.target_samples_per_frame, 1
329329
)
330330
with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad():
331-
target_codes, target_codes_lens = self.audio_codec.encode(target_audio.unsqueeze(1), target_audio_lens)
331+
target_codes, target_codes_lens = self.audio_codec.encode(target_audio.unsqueeze(1).to(self.audio_codec_run_dtype), target_audio_lens)
332332

333333
with fp32_precision():
334334
target_len = target_codes.shape[1]
@@ -1013,7 +1013,7 @@ def set_init_inputs(self, speaker_audio=None, speaker_audio_lens=None, system_pr
10131013
[target_audio.size(-1)] * target_audio.size(0), dtype=torch.long, device=self.device
10141014
)
10151015
with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad():
1016-
code, _ = self.audio_codec.encode(target_audio.unsqueeze(1), target_audio_len)
1016+
code, _ = self.audio_codec.encode(target_audio.unsqueeze(1).to(self.audio_codec_run_dtype), target_audio_len)
10171017

10181018
# get context hidden
10191019
if self.cfg.tts_config.context_hidden_size is not None:
@@ -1683,7 +1683,7 @@ def setup_audio_codec(model):
16831683
p.requires_grad = False
16841684

16851685
model.audio_codec.eval()
1686-
model.audio_codec.to(model.device) # force codec to run in the same device as the main model
1686+
model.audio_codec.to(model.device) # force codec to run in the same device as the main model
16871687

16881688
assert callable(model.tts_model.set_rvq_embs)
16891689

nemo/collections/speechlm2/modules/ear_tts_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,7 +1149,7 @@ def depthsum_embedding(self, code: Tensor) -> Tensor:
11491149
_, v, h = self.rvq_embs.size()
11501150
device = code.device
11511151

1152-
ret = torch.zeros((b, t, h), device=device)
1152+
ret = torch.zeros((b, t, h), device=device, dtype=self.rvq_embs.dtype)
11531153
embs = F.pad(self.rvq_embs, [0, 0, 0, 1])
11541154
for i in range(d):
11551155
emb = embs[i]
@@ -1203,7 +1203,7 @@ def _prepare_conditioning(
12031203
asr_speech_tokens_emb: Tensor | None,
12041204
) -> Tensor:
12051205
"""Computes the final conditioning tensor by combining all sources."""
1206-
cond = torch.zeros((1, 1, self.hidden_size), device=uncond_dec_flag.device)
1206+
cond = torch.zeros((1, 1, self.hidden_size), device=uncond_dec_flag.device, dtype=self.rvq_embs.dtype)
12071207

12081208
if self.embed_context is not None and context_hidden_state is not None:
12091209
cond = cond + self.embed_context(context_hidden_state)

0 commit comments

Comments (0)