
Implement formatter to ensure Nemotron VoiceChat speech decoder reproducibility, speed up training and support half precision inference#15583

Open
Edresson wants to merge 18 commits into NVIDIA-NeMo:main from Edresson:main_april

Conversation

@Edresson
Collaborator

@Edresson Edresson commented Apr 6, 2026

Important

The Update branch button must only be pressed on very rare occasions.
An outdated branch is never blocking the merge of a PR.
Please reach out to the automation team before pressing that button.

What does this PR do ?

This PR improves the Nemotron VoiceChat speech decoder's reproducibility and training performance through the following changes:

  • Adds a data formatter: implements the s2s_duplex_reverse_role formatter to reliably swap speaker roles and audio streams.

  • Fixes a device mismatch: forces the RVQ audio codec to instantiate on the main model's device during setup, resolving CPU/GPU DDP synchronization crashes when running in full precision.

  • Accelerates training: changes the default DDP strategy to find_unused_parameters: false to reduce overhead and speed up the training loop.

  • Supports half-precision inference: changes the evaluation script to support half-precision inference and updates it to use a torch DataLoader for simplicity.
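The DDP change can be sketched as a trainer config fragment; the key names and the Hydra-style `_target_` path below are assumptions for illustration, not the PR's actual config:

```yaml
# Hypothetical config sketch: with find_unused_parameters: false, DDP skips
# the per-iteration traversal that searches for parameters which received no
# gradient, which reduces overhead when every parameter is used each step.
trainer:
  strategy:
    _target_: lightning.pytorch.strategies.DDPStrategy
    find_unused_parameters: false
```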

Collection: SpeechLM2



@data_type_parser(["s2s_duplex_reverse_role"])
def read_s2s_duplex_reverse_role(config) -> Tuple[CutSet, bool]:
Collaborator

Add unit test coverage

Collaborator Author

I added 3 new unit tests covering all 3 duplex formatters (none of the duplex formatters were previously covered by unit tests). Let me know if you think that is good enough.


def convert_cut_fn(cut: Cut) -> Cut:
    """Convert a single cut by swapping supervisions and audio streams."""
    new_cut = fastcopy(cut)
Collaborator

Use copy.copy() or deepcopy() instead; fastcopy is a very shallow copy, so when you modify the supervisions later, they will be modified on the original object too.

You can keep using fastcopy() if you construct a new list of supervisions: fastcopy(cut, supervisions=[...])
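The shallow-copy pitfall is easy to demonstrate with a toy stand-in for a cut; `dataclasses.replace` behaves like Lhotse's `fastcopy` here, and the `Cut` class below is a minimal sketch, not the real Lhotse type:

```python
from dataclasses import dataclass, field, replace

@dataclass
class Cut:
    # Toy stand-in for a Lhotse Cut: a mutable list of supervision dicts.
    supervisions: list = field(default_factory=list)

original = Cut(supervisions=[{"speaker": "user"}])

# Shallow copy: the copy shares the SAME supervisions list as the original.
shallow = replace(original)
shallow.supervisions[0]["speaker"] = "agent"
assert original.supervisions[0]["speaker"] == "agent"  # original was mutated!

# Safe pattern: construct a new list of supervisions for the copy.
original = Cut(supervisions=[{"speaker": "user"}])
safe = replace(
    original,
    supervisions=[{**s, "speaker": "agent"} for s in original.supervisions],
)
assert original.supervisions[0]["speaker"] == "user"  # original untouched
```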

Collaborator Author

@Edresson Edresson Apr 8, 2026

Done. I also updated the magpietts/tts data formatter.
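The corrected pattern can be sketched with plain-dict stand-ins for cuts; the dict layout and the `ROLE_SWAP` mapping are illustrative assumptions, not the real Lhotse/NeMo types:

```python
# Hypothetical sketch of a role-swap conversion that never mutates its input.
ROLE_SWAP = {"user": "agent", "agent": "user"}

def convert_cut(cut: dict) -> dict:
    """Return a NEW cut with speaker roles and audio streams swapped."""
    # Build a fresh supervisions list instead of mutating shared objects.
    new_supervisions = [
        {**s, "speaker": ROLE_SWAP.get(s["speaker"], s["speaker"])}
        for s in cut["supervisions"]
    ]
    # Swap the two audio streams.
    swapped_audio = {"user": cut["audio"]["agent"], "agent": cut["audio"]["user"]}
    return {**cut, "supervisions": new_supervisions, "audio": swapped_audio}
```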

)
if cfg.get("keep_codec_original_dtype", True):
    model.tts_model.to(dtype=target_dtype)
model.on_train_epoch_start()  # ensures that codec is in the right precision
Collaborator

I don't like training details leaking into inference; could we create a method model.setup_precision() on NemotronVoicechat?

Collaborator Author

I will move ensures_codec_target_dtype, which is called inside on_train_epoch_start(), to a method on the Duplex EARTTS class; that way we can call it directly without calling on_train_epoch_start().
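A minimal sketch of that refactor, using a codec stub instead of the real RVQ codec; apart from the ensures_codec_target_dtype and on_train_epoch_start names mentioned above, every name and the dtype semantics here are assumptions:

```python
class CodecStub:
    """Stand-in for the RVQ audio codec; records its dtype as a string."""
    def __init__(self, dtype="float32"):
        self.dtype = dtype
    def to(self, dtype):
        self.dtype = dtype
        return self

class DuplexEARTTS:
    """Sketch: precision setup is its own method, so inference code can call
    it directly instead of going through the training hook."""
    def __init__(self, codec, target_dtype="bfloat16"):
        self.codec = codec
        self.target_dtype = target_dtype

    def ensures_codec_target_dtype(self):
        # Cast the codec to the configured target dtype (assumed semantics).
        if self.codec.dtype != self.target_dtype:
            self.codec.to(self.target_dtype)

    def on_train_epoch_start(self):
        # The training hook now just delegates to the shared method.
        self.ensures_codec_target_dtype()
```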

Collaborator Author

Done

out_path,
wav,
samplerate=model.target_sample_rate,
if cfg.get("debug_dtype", False) and batch_id == 0:
Collaborator

Can the debug logic be moved to a separate function and invoked here, for better readability and to avoid inflating the inference loop?
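One way to keep the loop slim is a small helper that gathers dtypes in one call; the helper name and the dict layout are assumptions for illustration:

```python
def collect_dtype_debug(named_values: dict) -> dict:
    """Hypothetical helper: map each name to its value's dtype (or type name),
    so the inference loop needs only one call when debug_dtype is enabled."""
    return {
        name: str(getattr(value, "dtype", type(value).__name__))
        for name, value in named_values.items()
    }

# Sketch of the call site inside the loop:
# if cfg.get("debug_dtype", False) and batch_id == 0:
#     print(collect_dtype_debug({"wav": wav, "tokens": tokens}))
```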

Collaborator Author

Done

Collaborator

@pzelasko pzelasko left a comment

Thanks, minor comments left, good work

