From 4779d8d351c41265bcf350b520b413464e702777 Mon Sep 17 00:00:00 2001 From: Roman Korostik Date: Tue, 7 Apr 2026 05:20:50 -0700 Subject: [PATCH 01/13] Simplify SchroedingerBridge _step to return scalar loss Move component loss logging (train_loss_encoded, train_loss_time) into _step itself, so it returns a plain scalar like all other models. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Roman Korostik --- nemo/collections/audio/models/enhancement.py | 23 +++++--------------- 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/nemo/collections/audio/models/enhancement.py b/nemo/collections/audio/models/enhancement.py index 02c449a5b2b8..b4d031858db5 100644 --- a/nemo/collections/audio/models/enhancement.py +++ b/nemo/collections/audio/models/enhancement.py @@ -1100,8 +1100,6 @@ def forward(self, input_signal, input_length=None): }, output_types={ "loss": NeuralType(None, LossType()), - "loss_encoded": NeuralType(None, LossType()), - "loss_time": NeuralType(None, LossType()), }, ) def _step(self, target_signal, input_signal, input_length=None): @@ -1164,7 +1162,6 @@ def _step(self, target_signal, input_signal, input_length=None): if self.loss is not None: # Single loss in the encoded domain loss = self.loss(estimate=estimate, target=target_enc, input_length=estimate_len) - loss_encoded = loss_time = None else: # Weighted loss between encoded and time domain loss = 0.0 @@ -1173,10 +1170,9 @@ def _step(self, target_signal, input_signal, input_length=None): if self.loss_encoded is not None: # Loss between the estimate and the target in the encoded domain loss_encoded = self.loss_encoded(estimate=estimate, target=target_enc, input_length=estimate_len) + self.log('train_loss_encoded', loss_encoded) # Weighting loss += self.loss_encoded_weight * loss_encoded - else: - loss_encoded = None # Loss in the time domain if self.loss_time is not None: @@ -1193,14 +1189,13 @@ def _step(self, target_signal, input_signal, input_length=None): loss_time = self.loss_time( estimate=estimate_signal, target=target_signal, input_length=input_length ) + self.log('train_loss_time', loss_time) # Weighting loss += self.loss_time_weight * loss_time - else: - loss_time = None else: raise NotImplementedError(f'Output type {self.estimator_output} is not implemented') - return loss, loss_encoded, loss_time + return loss # PTL-specific methods def training_step(self, batch, batch_idx): @@ -1220,21 +1215,13 @@ def training_step(self, batch, batch_idx): target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') # Calculate the loss - loss, loss_encoded, loss_time = self._step( - target_signal=target_signal, input_signal=input_signal, input_length=input_length - ) + loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) # Logs self.log('train_loss', loss) self.log('learning_rate', self._optimizer.param_groups[0]['lr']) self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) - if loss_encoded is not None: - self.log('train_loss_encoded', loss_encoded) - - if loss_time is not None: - self.log('train_loss_time', loss_time) - return loss def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): @@ -1254,7 +1241,7 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') # Calculate loss - loss, *_ = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) + loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) # Update metrics update_metrics = False From 14b282443d80b7258d4d6bbceff6c67e03123f61 Mon Sep 17 00:00:00 2001 From: Roman Korostik Date: Tue, 7 Apr 2026 05:31:18 -0700 Subject: [PATCH 02/13] Extract _parse_batch helper into AudioToAudioModel base class Replace duplicated batch parsing and 2D-to-3D reshape logic across all 6 audio model subclasses with a single _parse_batch method on the base class. FlowMatchingAudioToAudioModel overrides it to allow missing target_signal for SSL pretraining. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Roman Korostik --- .../audio/models/audio_to_audio.py | 26 +++ nemo/collections/audio/models/enhancement.py | 153 +++--------------- nemo/collections/audio/models/maxine/bnr.py | 25 +-- 3 files changed, 48 insertions(+), 156 deletions(-) diff --git a/nemo/collections/audio/models/audio_to_audio.py b/nemo/collections/audio/models/audio_to_audio.py index 28109f27b7f2..a6d1a806c422 100644 --- a/nemo/collections/audio/models/audio_to_audio.py +++ b/nemo/collections/audio/models/audio_to_audio.py @@ -18,6 +18,7 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional, Union +import einops import hydra import librosa import soundfile as sf @@ -130,6 +131,31 @@ def _setup_metrics(self, tag: str = 'val'): 'Setup metrics for %s, dataloader %d: %s', tag, dataloader_idx, ', '.join(metrics_dataloader_idx) ) + def _parse_batch(self, batch): + """Parse a batch into input signal, target signal, and input length. + + Handles both dict-style (lhotse) and tuple-style (AudioToTargetDataset) + batches, and ensures signals are in multi-channel format (B, C, T). + + Returns: + Tuple of (input_signal, target_signal, input_length). + """ + if isinstance(batch, dict): + # Lhotse dataloaders produce dict batches + input_signal = batch['input_signal'] + input_length = batch['input_length'] + target_signal = batch['target_signal'] + else: + # Standard audio datasets produce tuple batches + input_signal, input_length, target_signal, _ = batch + + if input_signal.ndim == 2: + input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') + if target_signal.ndim == 2: + target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + + return input_signal, target_signal, input_length + @abstractmethod def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): pass diff --git a/nemo/collections/audio/models/enhancement.py b/nemo/collections/audio/models/enhancement.py index b4d031858db5..254d3e8c0ea1 100644 --- a/nemo/collections/audio/models/enhancement.py +++ b/nemo/collections/audio/models/enhancement.py @@ -135,20 +135,7 @@ def forward(self, input_signal, input_length=None): # PTL-specific methods def training_step(self, batch, batch_idx): - - if isinstance(batch, dict): - # lhotse batches are dictionaries - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch['target_signal'] - else: - input_signal, input_length, target_signal, _ = batch - - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + input_signal, target_signal, input_length = self._parse_batch(batch) # Apply channel augmentation if self.training and self.channel_augmentation is not None: @@ -169,20 +156,7 @@ def training_step(self, batch, batch_idx): return loss def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): - - if isinstance(batch, dict): - # lhotse batches are dictionaries - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch['target_signal'] - else: - input_signal, input_length, target_signal, _ = batch - - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + input_signal, target_signal, input_length = self._parse_batch(batch) # Process input processed_signal, _ = self.forward(input_signal=input_signal, input_length=input_length) @@ -296,20 +270,7 @@ def forward(self, input_signal, input_length=None): # PTL-specific methods def training_step(self, batch, batch_idx): - - if isinstance(batch, dict): - # lhotse batches are dictionaries - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch['target_signal'] - else: - input_signal, input_length, target_signal, _ = batch - - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + input_signal, target_signal, input_length = self._parse_batch(batch) # Estimate the signal output_signal, _ = self.forward(input_signal=input_signal, input_length=input_length) @@ -325,25 +286,12 @@ def training_step(self, batch, batch_idx): return loss def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): - - if isinstance(batch, dict): - # lhotse batches are dictionaries - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch['target_signal'] - else: - input_signal, input_length, target_signal, _ = batch - - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + input_signal, target_signal, input_length = self._parse_batch(batch) # Estimate the signal output_signal, _ = self.forward(input_signal=input_signal, input_length=input_length) - # Prepare output + # Calculate the loss loss = self.loss(estimate=output_signal, target=target_signal, input_length=input_length) # Update metrics @@ -537,20 +485,7 @@ def _step(self, target_signal, input_signal, input_length=None): # PTL-specific methods def training_step(self, batch, batch_idx): - - if isinstance(batch, dict): - # lhotse batches are dictionaries - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch['target_signal'] - else: - input_signal, input_length, target_signal, _ = batch - - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + input_signal, target_signal, input_length = self._parse_batch(batch) # Calculate the loss loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) @@ -563,20 +498,7 @@ def training_step(self, batch, batch_idx): return loss def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): - - if isinstance(batch, dict): - # lhotse batches are dictionaries - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch['target_signal'] - else: - input_signal, input_length, target_signal, _ = batch - - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + input_signal, target_signal, input_length = self._parse_batch(batch) # Calculate loss loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) @@ -857,21 +779,25 @@ def _step(self, target_signal, input_signal, input_length=None): return self.loss(estimate=estimate, target=loss_target, input_length=input_enc_len) - # PTL-specific methods - def training_step(self, batch, batch_idx): + def _parse_batch(self, batch): + """Override to allow missing target_signal for SSL pretraining.""" if isinstance(batch, dict): - # lhotse batches are dictionaries input_signal = batch['input_signal'] input_length = batch['input_length'] target_signal = batch.get('target_signal', input_signal.clone()) else: input_signal, input_length, target_signal, _ = batch - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, "B T -> B 1 T") + input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, "B T -> B 1 T") + target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + + return input_signal, target_signal, input_length + + # PTL-specific methods + def training_step(self, batch, batch_idx): + input_signal, target_signal, input_length = self._parse_batch(batch) # Calculate the loss loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) @@ -884,20 +810,7 @@ def training_step(self, batch, batch_idx): return loss def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): - - if isinstance(batch, dict): - # lhotse batches are dictionaries - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch.get('target_signal', input_signal.clone()) - else: - input_signal, input_length, target_signal, _ = batch - - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + input_signal, target_signal, input_length = self._parse_batch(batch) # Calculate loss loss = self._step( @@ -1199,20 +1112,7 @@ def _step(self, target_signal, input_signal, input_length=None): # PTL-specific methods def training_step(self, batch, batch_idx): - - if isinstance(batch, dict): - # lhotse batches are dictionaries - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch['target_signal'] - else: - input_signal, input_length, target_signal, _ = batch - - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + input_signal, target_signal, input_length = self._parse_batch(batch) # Calculate the loss loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) @@ -1225,20 +1125,7 @@ def training_step(self, batch, batch_idx): return loss def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): - - if isinstance(batch, dict): - # lhotse batches are dictionaries - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch['target_signal'] - else: - input_signal, input_length, target_signal, _ = batch - - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + input_signal, target_signal, input_length = self._parse_batch(batch) # Calculate loss loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) diff --git a/nemo/collections/audio/models/maxine/bnr.py b/nemo/collections/audio/models/maxine/bnr.py index 6cff8ead9c87..f0ba346448ac 100644 --- a/nemo/collections/audio/models/maxine/bnr.py +++ b/nemo/collections/audio/models/maxine/bnr.py @@ -26,7 +26,6 @@ from typing import Dict, Optional -import einops import lightning.pytorch as plt import torch import torch.nn as nn @@ -229,17 +228,7 @@ def forward(self, input_signal): return output def training_step(self, batch, batch_idx): - if isinstance(batch, dict): - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch['target_signal'] - else: - input_signal, input_length, target_signal, _ = batch - - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + input_signal, target_signal, input_length = self._parse_batch(batch) predicted_audio = self.forward(input_signal=input_signal) @@ -252,17 +241,7 @@ def training_step(self, batch, batch_idx): return loss def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): - if isinstance(batch, dict): - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch['target_signal'] - else: - input_signal, input_length, target_signal, _ = batch - - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + input_signal, target_signal, input_length = self._parse_batch(batch) # Process input processed_signal = self(input_signal=input_signal) From 31b411a616b51457054c2099995429c8ae7f5361 Mon Sep 17 00:00:00 2001 From: Roman Korostik Date: Tue, 7 Apr 2026 05:33:55 -0700 Subject: [PATCH 03/13] Move training_step into AudioToAudioModel base class Add abstract _compute_train_loss method that each subclass implements with its model-specific loss computation. The base class training_step handles batch parsing, logging, and return. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Roman Korostik --- .../audio/models/audio_to_audio.py | 24 ++++++ nemo/collections/audio/models/enhancement.py | 80 +++---------------- nemo/collections/audio/models/maxine/bnr.py | 13 +-- 3 files changed, 36 insertions(+), 81 deletions(-) diff --git a/nemo/collections/audio/models/audio_to_audio.py b/nemo/collections/audio/models/audio_to_audio.py index a6d1a806c422..7981959db2d2 100644 --- a/nemo/collections/audio/models/audio_to_audio.py +++ b/nemo/collections/audio/models/audio_to_audio.py @@ -156,6 +156,30 @@ def _parse_batch(self, batch): return input_signal, target_signal, input_length + @abstractmethod + def _compute_train_loss(self, input_signal, target_signal, input_length): + """Compute training loss from parsed batch signals. + + Args: + input_signal: input audio tensor (B, C, T) + target_signal: target audio tensor (B, C, T) + input_length: length of each example in the batch (B,) + + Returns: + Scalar loss tensor. + """ + ... + + def training_step(self, batch, batch_idx): + input_signal, target_signal, input_length = self._parse_batch(batch) + loss = self._compute_train_loss(input_signal, target_signal, input_length) + + self.log('train_loss', loss) + self.log('learning_rate', self._optimizer.param_groups[0]['lr']) + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + return loss + @abstractmethod def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): pass diff --git a/nemo/collections/audio/models/enhancement.py b/nemo/collections/audio/models/enhancement.py index 254d3e8c0ea1..ff412a5cd571 100644 --- a/nemo/collections/audio/models/enhancement.py +++ b/nemo/collections/audio/models/enhancement.py @@ -133,27 +133,13 @@ def forward(self, input_signal, input_length=None): processed = self.match_batch_length(input=processed, batch_length=batch_length) return processed, processed_length - # PTL-specific methods - def training_step(self, batch, batch_idx): - input_signal, target_signal, input_length = self._parse_batch(batch) - + def _compute_train_loss(self, input_signal, target_signal, input_length): # Apply channel augmentation if self.training and self.channel_augmentation is not None: input_signal = self.channel_augmentation(input=input_signal) - # Process input processed_signal, _ = self.forward(input_signal=input_signal, input_length=input_length) - - # Calculate the loss - loss = self.loss(estimate=processed_signal, target=target_signal, input_length=input_length) - - # Logs - self.log('train_loss', loss) - self.log('learning_rate', self._optimizer.param_groups[0]['lr']) - self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) - - # Return loss - return loss + return self.loss(estimate=processed_signal, target=target_signal, input_length=input_length) def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): input_signal, target_signal, input_length = self._parse_batch(batch) @@ -268,22 +254,9 @@ def forward(self, input_signal, input_length=None): output = self.match_batch_length(input=output, batch_length=batch_length) return output, output_length - # PTL-specific methods - def training_step(self, batch, batch_idx): - input_signal, target_signal, input_length = self._parse_batch(batch) - - # Estimate the signal + def _compute_train_loss(self, input_signal, target_signal, input_length): output_signal, _ = self.forward(input_signal=input_signal, input_length=input_length) - - # Calculate the loss - loss = self.loss(estimate=output_signal, target=target_signal, input_length=input_length) - - # Logs - self.log('train_loss', loss) - self.log('learning_rate', self._optimizer.param_groups[0]['lr']) - self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) - - return loss + return self.loss(estimate=output_signal, target=target_signal, input_length=input_length) def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): input_signal, target_signal, input_length = self._parse_batch(batch) @@ -483,19 +456,8 @@ def _step(self, target_signal, input_signal, input_length=None): return loss - # PTL-specific methods - def training_step(self, batch, batch_idx): - input_signal, target_signal, input_length = self._parse_batch(batch) - - # Calculate the loss - loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) - - # Logs - self.log('train_loss', loss) - self.log('learning_rate', self._optimizer.param_groups[0]['lr']) - self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) - - return loss + def _compute_train_loss(self, input_signal, target_signal, input_length): + return self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): input_signal, target_signal, input_length = self._parse_batch(batch) @@ -795,19 +757,8 @@ def _parse_batch(self, batch): return input_signal, target_signal, input_length - # PTL-specific methods - def training_step(self, batch, batch_idx): - input_signal, target_signal, input_length = self._parse_batch(batch) - - # Calculate the loss - loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) - - # Logs - self.log('train_loss', loss) - self.log('learning_rate', self._optimizer.param_groups[0]['lr']) - self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) - - return loss + def _compute_train_loss(self, input_signal, target_signal, input_length): + return self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): input_signal, target_signal, input_length = self._parse_batch(batch) @@ -1110,19 +1061,8 @@ def _step(self, target_signal, input_signal, input_length=None): return loss - # PTL-specific methods - def training_step(self, batch, batch_idx): - input_signal, target_signal, input_length = self._parse_batch(batch) - - # Calculate the loss - loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) - - # Logs - self.log('train_loss', loss) - self.log('learning_rate', self._optimizer.param_groups[0]['lr']) - self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) - - return loss + def _compute_train_loss(self, input_signal, target_signal, input_length): + return self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): input_signal, target_signal, input_length = self._parse_batch(batch) diff --git a/nemo/collections/audio/models/maxine/bnr.py b/nemo/collections/audio/models/maxine/bnr.py index f0ba346448ac..c63c95fd7e6f 100644 --- a/nemo/collections/audio/models/maxine/bnr.py +++ b/nemo/collections/audio/models/maxine/bnr.py @@ -227,18 +227,9 @@ def forward(self, input_signal): return output - def training_step(self, batch, batch_idx): - input_signal, target_signal, input_length = self._parse_batch(batch) - + def _compute_train_loss(self, input_signal, target_signal, input_length): predicted_audio = self.forward(input_signal=input_signal) - - loss = self.loss(target=target_signal, estimate=predicted_audio, input_length=input_length) - - self.log('train_loss', loss) - self.log('learning_rate', self._optimizer.param_groups[0]['lr']) - self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) - - return loss + return self.loss(target=target_signal, estimate=predicted_audio, input_length=input_length) def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): input_signal, target_signal, input_length = self._parse_batch(batch) From cb1838679f4f7bb3c1195a6204b5f9cc881dff37 Mon Sep 17 00:00:00 2001 From: Roman Korostik Date: Tue, 7 Apr 2026 05:47:34 -0700 Subject: [PATCH 04/13] Guard SB component loss logging with self.training check _step is called from both training and evaluation. The train_loss_encoded and train_loss_time logs should only fire during training. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Roman Korostik --- nemo/collections/audio/models/enhancement.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/nemo/collections/audio/models/enhancement.py b/nemo/collections/audio/models/enhancement.py index ff412a5cd571..06badf25575a 100644 --- a/nemo/collections/audio/models/enhancement.py +++ b/nemo/collections/audio/models/enhancement.py @@ -1034,7 +1034,8 @@ def _step(self, target_signal, input_signal, input_length=None): if self.loss_encoded is not None: # Loss between the estimate and the target in the encoded domain loss_encoded = self.loss_encoded(estimate=estimate, target=target_enc, input_length=estimate_len) - self.log('train_loss_encoded', loss_encoded) + if self.training: + self.log('train_loss_encoded', loss_encoded) # Weighting loss += self.loss_encoded_weight * loss_encoded @@ -1053,7 +1054,8 @@ def _step(self, target_signal, input_signal, input_length=None): loss_time = self.loss_time( estimate=estimate_signal, target=target_signal, input_length=input_length ) - self.log('train_loss_time', loss_time) + if self.training: + self.log('train_loss_time', loss_time) # Weighting loss += self.loss_time_weight * loss_time else: From 4dd23761b62e49d67e16cafc417801f3ecf04a5b Mon Sep 17 00:00:00 2001 From: Roman Korostik Date: Tue, 7 Apr 2026 05:51:35 -0700 Subject: [PATCH 05/13] Move sample_rate and setup_optimization_flags to base AudioToAudioModel.__init__ Both are set identically by all 6 subclasses. setup_optimization_flags only reads self._cfg, so it is safe to call before subclass-specific module initialization. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Roman Korostik --- .../audio/models/audio_to_audio.py | 2 ++ nemo/collections/audio/models/enhancement.py | 20 ------------------- nemo/collections/audio/models/maxine/bnr.py | 4 ---- 3 files changed, 2 insertions(+), 24 deletions(-) diff --git a/nemo/collections/audio/models/audio_to_audio.py b/nemo/collections/audio/models/audio_to_audio.py index 7981959db2d2..0ebc223eb743 100644 --- a/nemo/collections/audio/models/audio_to_audio.py +++ b/nemo/collections/audio/models/audio_to_audio.py @@ -51,7 +51,9 @@ class AudioToAudioModel(ModelPT, ABC): def __init__(self, cfg: DictConfig, trainer: Trainer = None): super().__init__(cfg=cfg, trainer=trainer) + self.sample_rate = self._cfg.sample_rate self._setup_loss() + self.setup_optimization_flags() def _setup_loss(self): """Setup loss for this model.""" diff --git a/nemo/collections/audio/models/enhancement.py b/nemo/collections/audio/models/enhancement.py index 06badf25575a..f11a156ccd2a 100644 --- a/nemo/collections/audio/models/enhancement.py +++ b/nemo/collections/audio/models/enhancement.py @@ -52,7 +52,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.world_size = trainer.world_size super().__init__(cfg=cfg, trainer=trainer) - self.sample_rate = self._cfg.sample_rate # Setup processing modules self.encoder = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.encoder) @@ -75,9 +74,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): logging.debug('Channel augmentation not used') self.channel_augmentation = None - # Setup optional Optimization flags - self.setup_optimization_flags() - @property def input_types(self) -> Dict[str, NeuralType]: return { @@ -182,7 +178,6 @@ class PredictiveAudioToAudioModel(AudioToAudioModel): def __init__(self, cfg: DictConfig, trainer: Trainer = None): super().__init__(cfg=cfg, trainer=trainer) - self.sample_rate = self._cfg.sample_rate # Setup processing modules self.encoder = self.from_config_dict(self._cfg.encoder) @@ -197,9 +192,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Term added to the denominator to improve numerical stability self.eps = self._cfg.get('eps', 1e-8) - # Setup optional Optimization flags - self.setup_optimization_flags() - logging.debug('Initialized %s', self.__class__.__name__) logging.debug('\tnormalize_input: %s', self.normalize_input) logging.debug('\teps: %s', self.eps) @@ -293,7 +285,6 @@ class ScoreBasedGenerativeAudioToAudioModel(AudioToAudioModel): def __init__(self, cfg: DictConfig, trainer: Trainer = None): super().__init__(cfg=cfg, trainer=trainer) - self.sample_rate = self._cfg.sample_rate # Setup processing modules self.encoder = self.from_config_dict(self._cfg.encoder) @@ -328,9 +319,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Term added to the denominator to improve numerical stability self.eps = self._cfg.get('eps', 1e-8) - # Setup optional Optimization flags - self.setup_optimization_flags() - logging.debug('Initialized %s', self.__class__.__name__) logging.debug('\tnormalize_input: %s', self.normalize_input) logging.debug('\teps: %s', self.eps) @@ -518,7 +506,6 @@ class FlowMatchingAudioToAudioModel(AudioToAudioModel): def __init__(self, cfg: DictConfig, trainer: Trainer = None): super().__init__(cfg=cfg, trainer=trainer) - self.sample_rate = self._cfg.sample_rate # Setup processing modules self.encoder = self.from_config_dict(self._cfg.encoder) @@ -561,9 +548,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Regularization self.eps = self._cfg.get('eps', 1e-8) - # Setup optional Optimization flags - self.setup_optimization_flags() - logging.debug('Initialized %s', self.__class__.__name__) logging.debug('\tdoing SSL-pretraining: %s', (self.ssl_pretrain_masking is not None)) logging.debug('\tp_cond: %s', self.p_cond) @@ -825,7 +809,6 @@ class SchroedingerBridgeAudioToAudioModel(AudioToAudioModel): def __init__(self, cfg: DictConfig, trainer: Trainer = None): super().__init__(cfg=cfg, trainer=trainer) - self.sample_rate = self._cfg.sample_rate # Setup processing modules self.encoder = self.from_config_dict(self._cfg.encoder) @@ -880,9 +863,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Term added to the denominator to improve numerical stability self.eps = self._cfg.get('eps', 1e-8) - # Setup optional optimization flags - self.setup_optimization_flags() - logging.debug('Initialized %s', self.__class__.__name__) logging.debug('\testimator_output: %s', self.estimator_output) logging.debug('\tnormalize_input: %s', self.normalize_input) diff --git a/nemo/collections/audio/models/maxine/bnr.py b/nemo/collections/audio/models/maxine/bnr.py index c63c95fd7e6f..7e869321caba 100644 --- a/nemo/collections/audio/models/maxine/bnr.py +++ b/nemo/collections/audio/models/maxine/bnr.py @@ -161,10 +161,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.world_size = trainer.world_size super().__init__(cfg=cfg, trainer=trainer) - self.sample_rate = self._cfg.sample_rate - - # Setup optional Optimization flags - self.setup_optimization_flags() self.seasr = _Seasr(self.sample_rate) if ( From 51995ac2973eb2550392899df9030351172eb1d7 Mon Sep 17 00:00:00 2001 From: Roman Korostik Date: Tue, 7 Apr 2026 06:00:31 -0700 Subject: [PATCH 06/13] Remove redundant world_size init from EncMaskDec and BNR2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ModelPT.__init__ calls set_trainer → set_world_size before any data loader setup, so the pre-super assignment is always overwritten before it can be read. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Roman Korostik --- nemo/collections/audio/models/enhancement.py | 6 ------ nemo/collections/audio/models/maxine/bnr.py | 4 ---- 2 files changed, 10 deletions(-) diff --git a/nemo/collections/audio/models/enhancement.py b/nemo/collections/audio/models/enhancement.py index f11a156ccd2a..1451cd12739f 100644 --- a/nemo/collections/audio/models/enhancement.py +++ b/nemo/collections/audio/models/enhancement.py @@ -45,12 +45,6 @@ class EncMaskDecAudioToAudioModel(AudioToAudioModel): """ def __init__(self, cfg: DictConfig, trainer: Trainer = None): - # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable - # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0 - self.world_size = 1 - if trainer is not None: - self.world_size = trainer.world_size - super().__init__(cfg=cfg, trainer=trainer) # Setup processing modules diff --git a/nemo/collections/audio/models/maxine/bnr.py b/nemo/collections/audio/models/maxine/bnr.py index 7e869321caba..bca29de65cc7 100644 --- a/nemo/collections/audio/models/maxine/bnr.py +++ b/nemo/collections/audio/models/maxine/bnr.py @@ -156,10 +156,6 @@ class BNR2(AudioToAudioModel): """Implementation of the BNR 2 model""" def __init__(self, cfg: DictConfig, trainer: Trainer = None): - self.world_size = 1 - if trainer is not None: - self.world_size = trainer.world_size - super().__init__(cfg=cfg, trainer=trainer) self.seasr = _Seasr(self.sample_rate) From f19c10241c36e9ea5cf6ea7703135ec3ff75d303 Mon Sep 17 00:00:00 2001 From: Roman Korostik Date: Tue, 7 Apr 2026 06:00:48 -0700 Subject: [PATCH 07/13] Use self.from_config_dict in EncMaskDecAudioToAudioModel Consistent with all other audio model subclasses which use self.from_config_dict rather than the concrete class name. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Roman Korostik --- nemo/collections/audio/models/enhancement.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/nemo/collections/audio/models/enhancement.py b/nemo/collections/audio/models/enhancement.py index 1451cd12739f..6f2b96083872 100644 --- a/nemo/collections/audio/models/enhancement.py +++ b/nemo/collections/audio/models/enhancement.py @@ -48,14 +48,14 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): super().__init__(cfg=cfg, trainer=trainer) # Setup processing modules - self.encoder = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.encoder) - self.mask_estimator = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mask_estimator) - self.mask_processor = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mask_processor) - self.decoder = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.decoder) + self.encoder = self.from_config_dict(self._cfg.encoder) + self.mask_estimator = self.from_config_dict(self._cfg.mask_estimator) + self.mask_processor = self.from_config_dict(self._cfg.mask_processor) + self.decoder = self.from_config_dict(self._cfg.decoder) if 'mixture_consistency' in self._cfg: logging.debug('Using mixture consistency') - self.mixture_consistency = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mixture_consistency) + self.mixture_consistency = self.from_config_dict(self._cfg.mixture_consistency) else: logging.debug('Mixture consistency not used') self.mixture_consistency = None @@ -63,7 +63,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Setup augmentation if hasattr(self.cfg, 'channel_augment') and self.cfg.channel_augment is not None: logging.debug('Using channel augmentation') - self.channel_augmentation = EncMaskDecAudioToAudioModel.from_config_dict(self.cfg.channel_augment) + self.channel_augmentation = self.from_config_dict(self.cfg.channel_augment) else: logging.debug('Channel augmentation not used') self.channel_augmentation = None From 80167fe0ea90b95d3c9d31054c3ea5b64e063eb8 Mon Sep 17 00:00:00 2001 From: Roman Korostik Date: Tue, 7 Apr 2026 06:01:03 -0700 Subject: [PATCH 08/13] Update setup_optimization_flags docstring Now called from base __init__, no longer requires explicit subclass call. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Roman Korostik --- nemo/collections/audio/models/audio_to_audio.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/nemo/collections/audio/models/audio_to_audio.py b/nemo/collections/audio/models/audio_to_audio.py index 0ebc223eb743..3164d80b3ca1 100644 --- a/nemo/collections/audio/models/audio_to_audio.py +++ b/nemo/collections/audio/models/audio_to_audio.py @@ -519,11 +519,10 @@ def list_available_models(cls) -> 'List[PretrainedModelInfo]': return list_of_models def setup_optimization_flags(self): - """ - Utility method that must be explicitly called by the subclass in order to support optional optimization flags. - This method is the only valid place to access self.cfg prior to DDP training occurs. + """Setup optional optimization flags from the model config. - The subclass may chose not to support this method, therefore all variables here must be checked via hasattr() + Called automatically during __init__. This is the only valid place + to access self.cfg prior to DDP training. """ # Skip update if nan/inf grads appear on any rank. self._skip_nan_grad = False From 2fc475ea6067e9960770d82301eee9087fe9882c Mon Sep 17 00:00:00 2001 From: Roman Korostik Date: Tue, 7 Apr 2026 06:03:52 -0700 Subject: [PATCH 09/13] Extract _normalize/_denormalize helpers into base class Replace repeated normalize/denormalize boilerplate across 4 forward() and 3 _step() methods with calls to shared helpers on AudioToAudioModel. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Roman Korostik --- .../audio/models/audio_to_audio.py | 18 +++++++ nemo/collections/audio/models/enhancement.py | 50 ++++--------------- 2 files changed, 29 insertions(+), 39 deletions(-) diff --git a/nemo/collections/audio/models/audio_to_audio.py b/nemo/collections/audio/models/audio_to_audio.py index 3164d80b3ca1..599eb1b30ab8 100644 --- a/nemo/collections/audio/models/audio_to_audio.py +++ b/nemo/collections/audio/models/audio_to_audio.py @@ -365,6 +365,24 @@ def _setup_process_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoade temporary_dataloader = self._setup_dataloader_from_config(config=DictConfig(dl_config)) return temporary_dataloader + def _normalize(self, signal: torch.Tensor) -> tuple: + """Normalize signal so its peak amplitude is 1. + + Args: + signal: tensor with shape (B, C, T) + + Returns: + Tuple of (normalized_signal, norm_scale). Pass norm_scale to + _denormalize to restore the original scale. + """ + norm_scale = torch.amax(signal.abs(), dim=(-1, -2), keepdim=True) + return signal / (norm_scale + self.eps), norm_scale + + @staticmethod + def _denormalize(signal: torch.Tensor, norm_scale: torch.Tensor) -> torch.Tensor: + """Restore original scale after _normalize.""" + return signal * norm_scale + @staticmethod def match_batch_length(input: torch.Tensor, batch_length: int) -> torch.Tensor: """Trim or pad the output to match the batch length. diff --git a/nemo/collections/audio/models/enhancement.py b/nemo/collections/audio/models/enhancement.py index 6f2b96083872..ebd49b36d23f 100644 --- a/nemo/collections/audio/models/enhancement.py +++ b/nemo/collections/audio/models/enhancement.py @@ -218,10 +218,7 @@ def forward(self, input_signal, input_length=None): batch_length = input_signal.size(-1) if self.normalize_input: - # max for each example in the batch - norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) - # scale input signal - input_signal = input_signal / (norm_scale + self.eps) + input_signal, norm_scale = self._normalize(input_signal) # Encoder encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length) @@ -233,8 +230,7 @@ def forward(self, input_signal, input_length=None): output, output_length = self.decoder(input=estimated, input_length=estimated_length) if self.normalize_input: - # rescale to the original scale - output = output * norm_scale + output = self._denormalize(output, norm_scale) # Trim or pad the estimated signal to match input length output = self.match_batch_length(input=output, batch_length=batch_length) @@ -353,10 +349,7 @@ def forward(self, input_signal, input_length=None): batch_length = input_signal.size(-1) if self.normalize_input: - # max for each example in the batch - norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) - # scale input signal - input_signal = input_signal / (norm_scale + self.eps) + input_signal, norm_scale = self._normalize(input_signal) # Encoder encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length) @@ -370,8 +363,7 @@ def forward(self, input_signal, input_length=None): output, output_length = self.decoder(input=generated, input_length=generated_length) if self.normalize_input: - # rescale to the original scale - output = output * norm_scale + output = self._denormalize(output, norm_scale) # Trim or pad the estimated signal to match input length output = self.match_batch_length(input=output, batch_length=batch_length) @@ -396,11 +388,7 @@ def _step(self, target_signal, input_signal, input_length=None): batch_size = target_signal.size(0) if self.normalize_input: - # max for each example in the batch - norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) - # scale input signal - input_signal = input_signal / (norm_scale + self.eps) - # scale the target signal + input_signal, norm_scale = self._normalize(input_signal) target_signal = target_signal / (norm_scale + self.eps) # Apply encoder to both target and the input @@ -627,10 +615,7 @@ def forward_internal(self, input_signal, input_length=None, enable_ssl_masking=F batch_length = input_signal.size(-1) if self.normalize_input: - # max for each example in the batch - norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) - # scale input signal - input_signal = input_signal / (norm_scale + self.eps) + input_signal, norm_scale = self._normalize(input_signal) # Encoder encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length) @@ -655,8 +640,7 @@ def forward_internal(self, input_signal, input_length=None, enable_ssl_masking=F output, output_length = self.decoder(input=generated, input_length=generated_length) if self.normalize_input: - # rescale to the original scale - output = output * norm_scale + output = self._denormalize(output, norm_scale) # Trim or pad the estimated signal to match input length output = self.match_batch_length(input=output, batch_length=batch_length) @@ -677,11 +661,7 @@ def _step(self, target_signal, input_signal, input_length=None): batch_size = target_signal.size(0) if self.normalize_input: - # max for each example in the batch - norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) - # scale input signal - input_signal = input_signal / (norm_scale + self.eps) - # scale the target signal + input_signal, norm_scale = self._normalize(input_signal) target_signal = target_signal / (norm_scale + self.eps) # Apply encoder to both target and the input @@ -905,10 +885,7 @@ def forward(self, input_signal, input_length=None): batch_length = input_signal.size(-1) if self.normalize_input: - # max for each example in the batch - norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) - # scale input signal - input_signal = input_signal / (norm_scale + self.eps) + input_signal, norm_scale = self._normalize(input_signal) # Encoder encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length) @@ -922,8 +899,7 @@ def forward(self, input_signal, input_length=None): output, output_length = self.decoder(input=generated, input_length=generated_length) if self.normalize_input: - # rescale to the original scale - output = output * norm_scale + output = self._denormalize(output, norm_scale) # Trim or pad the estimated signal to match input length output = self.match_batch_length(input=output, batch_length=batch_length) @@ -947,11 +923,7 @@ def _step(self, target_signal, input_signal, input_length=None): batch_size = target_signal.size(0) if self.normalize_input: - # max for each example in the batch - norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) - # scale input signal - input_signal = input_signal / (norm_scale + self.eps) - # scale the target signal + input_signal, norm_scale = self._normalize(input_signal) target_signal = target_signal / (norm_scale + self.eps) # Apply encoder to both target and the input From 9ff0f919d38ca2c98043afba1cccf8e0834ec7c8 Mon Sep 17 00:00:00 2001 From: Roman Korostik Date: Tue, 7 Apr 2026 06:08:02 -0700 Subject: [PATCH 10/13] Remove misleading -> tuple annotation from _normalize Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Roman Korostik --- nemo/collections/audio/models/audio_to_audio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/audio/models/audio_to_audio.py b/nemo/collections/audio/models/audio_to_audio.py index 599eb1b30ab8..f96635446e81 100644 --- a/nemo/collections/audio/models/audio_to_audio.py +++ b/nemo/collections/audio/models/audio_to_audio.py @@ -365,7 +365,7 @@ def _setup_process_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoade temporary_dataloader = self._setup_dataloader_from_config(config=DictConfig(dl_config)) return temporary_dataloader - def _normalize(self, signal: torch.Tensor) -> tuple: + def _normalize(self, signal: torch.Tensor): """Normalize signal so its peak amplitude is 1. Args: From e6dc162ed4a6c1ad7915fb917f42a004cd024bdc Mon Sep 17 00:00:00 2001 From: Roman Korostik Date: Tue, 7 Apr 2026 06:19:39 -0700 Subject: [PATCH 11/13] Fix CodeQL warning: use pass instead of ... in abstract method Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Roman Korostik --- nemo/collections/audio/models/audio_to_audio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/audio/models/audio_to_audio.py b/nemo/collections/audio/models/audio_to_audio.py index f96635446e81..a0f599627f2c 100644 --- a/nemo/collections/audio/models/audio_to_audio.py +++ b/nemo/collections/audio/models/audio_to_audio.py @@ -170,7 +170,7 @@ def _compute_train_loss(self, input_signal, target_signal, input_length): Returns: Scalar loss tensor. """ - ... + pass def training_step(self, batch, batch_idx): input_signal, target_signal, input_length = self._parse_batch(batch) From c967065a062fcb607d0c4b4bd525f4ae432ad880 Mon Sep 17 00:00:00 2001 From: Roman Korostik Date: Tue, 7 Apr 2026 06:37:08 -0700 Subject: [PATCH 12/13] Fix SB test that calls _step outside Lightning training loop The test calls _step directly, which now logs component losses via self.log. Disable logging in this test since there is no active Lightning loop context. Also update to use _parse_batch and the scalar return from _step. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Roman Korostik --- .../test_audio_models_schroedinger_bridge.py | 21 ++++++------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/tests/collections/audio/test_audio_models_schroedinger_bridge.py b/tests/collections/audio/test_audio_models_schroedinger_bridge.py index 6d4a228092b3..fe61bf7bf5bf 100644 --- a/tests/collections/audio/test_audio_models_schroedinger_bridge.py +++ b/tests/collections/audio/test_audio_models_schroedinger_bridge.py @@ -265,23 +265,14 @@ def test_forward_infer(self, schroedinger_bridge_model_ncsn, batch_size, sample_ def test_training_step(self, schroedinger_bridge_model_ncsn_with_trainer_and_mock_dataset): model, _ = schroedinger_bridge_model_ncsn_with_trainer_and_mock_dataset model = model.train() + # _step calls self.log for component losses, which requires a Lightning loop context. + # Disable logging since we're calling _step directly outside the training loop. + model.log = lambda *args, **kwargs: None for batch in itertools.islice(model._train_dl, 2): - # start boilerplate from SchroedingerBridgeAudioToAudioModel.training_step - if isinstance(batch, dict): - # lhotse batches are dictionaries - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch.get('target_signal', input_signal) - else: - input_signal, input_length, target_signal, _ = batch - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') - # end boilerplate - - loss, _, _ = model._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) + input_signal, target_signal, input_length = model._parse_batch(batch) + + loss = model._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) loss.backward() def test_model_training(self, schroedinger_bridge_model_ncsn_with_trainer_and_mock_dataset): From e39f5db8e333836ef490d210cc2d66b906bec11e Mon Sep 17 00:00:00 2001 From: Roman Korostik Date: Tue, 7 Apr 2026 07:48:30 -0700 Subject: [PATCH 13/13] Fix _denormalize to be proper inverse of _normalize _normalize divides by (norm_scale + eps), so _denormalize should multiply by (norm_scale + eps) to recover the original signal. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Roman Korostik --- nemo/collections/audio/models/audio_to_audio.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/nemo/collections/audio/models/audio_to_audio.py b/nemo/collections/audio/models/audio_to_audio.py index a0f599627f2c..6cf875677b9a 100644 --- a/nemo/collections/audio/models/audio_to_audio.py +++ b/nemo/collections/audio/models/audio_to_audio.py @@ -378,10 +378,9 @@ def _normalize(self, signal: torch.Tensor): norm_scale = torch.amax(signal.abs(), dim=(-1, -2), keepdim=True) return signal / (norm_scale + self.eps), norm_scale - @staticmethod - def _denormalize(signal: torch.Tensor, norm_scale: torch.Tensor) -> torch.Tensor: + def _denormalize(self, signal: torch.Tensor, norm_scale: torch.Tensor) -> torch.Tensor: """Restore original scale after _normalize.""" - return signal * norm_scale + return signal * (norm_scale + self.eps) @staticmethod def match_batch_length(input: torch.Tensor, batch_length: int) -> torch.Tensor: