Implement formatter to ensure Nemotron VoiceChat speech decoder reproducibility, speed up training and support half precision inference#15583
Conversation
@data_type_parser(["s2s_duplex_reverse_role"])
def read_s2s_duplex_reverse_role(config) -> Tuple[CutSet, bool]:
I added 3 new unit tests covering all 3 duplex formatters (the duplex formatters were not covered by unit tests before). Let me know if you think that is good enough.
def convert_cut_fn(cut: Cut) -> Cut:
    """Convert a single cut by swapping supervisions and audio streams."""
    new_cut = fastcopy(cut)
Use copy.copy() or deepcopy() instead; fastcopy is a very shallow copy, so when you modify the supervisions later, they will be modified on the original object too.
Alternatively, you can keep using fastcopy() if you construct a new list of supervisions first, i.e. fastcopy(cut, supervisions=[...]).
Done. I also updated the magpietts/tts data formatter.
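The shallow-copy pitfall discussed above can be sketched with plain dataclasses. Lhotse's fastcopy behaves much like dataclasses.replace, so the new object shares its supervisions list with the original; the Cut and Supervision classes below are simplified stand-ins, not lhotse's actual types:

```python
# Minimal sketch of the shallow-copy pitfall (stand-in classes, not lhotse's).
from dataclasses import dataclass, field, replace
from typing import List


@dataclass
class Supervision:
    speaker: str


@dataclass
class Cut:
    supervisions: List[Supervision] = field(default_factory=list)


def fastcopy(obj, **kwargs):
    # lhotse.utils.fastcopy is essentially a dataclasses.replace-style helper.
    return replace(obj, **kwargs)


cut = Cut(supervisions=[Supervision("user"), Supervision("agent")])

# BAD: the shallow copy shares the supervisions list with the original,
# so mutating a supervision corrupts the original cut as well.
bad = fastcopy(cut)
bad.supervisions[0].speaker = "agent"
assert cut.supervisions[0].speaker == "agent"  # original was mutated!

# GOOD: build a fresh list of new Supervision objects, then pass it in.
cut = Cut(supervisions=[Supervision("user"), Supervision("agent")])
swap = {"user": "agent", "agent": "user"}
swapped = fastcopy(
    cut,
    supervisions=[replace(s, speaker=swap[s.speaker]) for s in cut.supervisions],
)
assert cut.supervisions[0].speaker == "user"  # original untouched
assert swapped.supervisions[0].speaker == "agent"
```

The second pattern is what the review suggests: fastcopy stays cheap because only the supervisions that actually change are re-created.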
Signed-off-by: Edresson <Edresson@users.noreply.github.com>
Signed-off-by: Edresson Casanova <edresson1@gmail.com>
| ) | ||
| if cfg.get("keep_codec_original_dtype", True): | ||
| model.tts_model.to(dtype=target_dtype) | ||
| model.on_train_epoch_start() # ensures that codec is in the right precision |
I don't like training details leaking into inference; could we create a method model.setup_precision() on NemotronVoicechat?
I will move the ensures_codec_target_dtype logic that is currently called inside on_train_epoch_start() into a method on the Duplex EARTTS class, so that we can call it directly without calling on_train_epoch_start().
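The refactor described above can be sketched as follows. The class, attribute, and dtype handling here are illustrative placeholders, not the actual NeMo API; the point is only that the training hook delegates to a plain method that inference code can call directly:

```python
# Hypothetical sketch: precision setup becomes a directly callable method,
# and the training hook merely delegates to it. Names are illustrative.
class DuplexEARTTS:
    def __init__(self, codec_dtype="float32", target_dtype="bfloat16"):
        self.codec_dtype = codec_dtype
        self.target_dtype = target_dtype

    def ensures_codec_target_dtype(self):
        """Cast the codec to the configured target dtype (idempotent)."""
        if self.codec_dtype != self.target_dtype:
            self.codec_dtype = self.target_dtype

    def on_train_epoch_start(self):
        # The training hook still works, but now just delegates.
        self.ensures_codec_target_dtype()


# Inference code calls the method directly, with no training hook involved:
model = DuplexEARTTS()
model.ensures_codec_target_dtype()
print(model.codec_dtype)  # -> bfloat16
```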
out_path,
wav,
samplerate=model.target_sample_rate,

if cfg.get("debug_dtype", False) and batch_id == 0:
Can debug logic be moved to a separate function and invoked here for better readability / to avoid inflating the inference loop size?
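One way to apply this suggestion is to pull the dtype logging into a helper with an early-out guard, so the inference loop shrinks to a single call. The helper name, cfg keys, and tensor names below are illustrative placeholders:

```python
# Hypothetical helper: keeps the inference loop slim by moving the
# debug logic out, per the review suggestion above.
def maybe_debug_dtypes(cfg, batch_id, tensors):
    """Log tensor dtypes for the first batch only, when debugging is enabled."""
    if not cfg.get("debug_dtype", False) or batch_id != 0:
        return
    for name, t in tensors.items():
        print(f"[debug_dtype] {name}: dtype={t.dtype}")


# Call site inside the inference loop becomes one line, e.g.:
# maybe_debug_dtypes(cfg, batch_id, {"wav": wav, "codes": codes})
```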
pzelasko left a comment
Thanks, minor comments left. Good work!
…sion Signed-off-by: Edresson Casanova <edresson1@gmail.com>
Signed-off-by: Edresson <Edresson@users.noreply.github.com>
Important
The Update branch button must only be pressed on very rare occasions. An outdated branch never blocks the merge of a PR.
Please reach out to the automation team before pressing that button.
What does this PR do?
This PR improves the Nemotron VoiceChat speech decoder's reproducibility and training performance through the following changes:
- Adds data formatter: Implements the s2s_duplex_reverse_role formatter to reliably swap speaker roles and audio streams.
- Fixes device mismatch: Forces the RVQ audio codec to instantiate on the main model's device during setup, resolving CPU/GPU DDP synchronization crashes in full precision.
- Accelerates training: Changes the default DDP strategy to find_unused_parameters: false to reduce overhead and speed up the training loop.
- Supports half-precision inference: Updates the evaluation script to support half-precision inference and switches it to a torch DataLoader for simplicity.
Collection: SpeechLM2