Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 72 additions & 4 deletions nemo/collections/audio/models/audio_to_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union

import einops
import hydra
import librosa
import soundfile as sf
Expand Down Expand Up @@ -50,7 +51,9 @@ class AudioToAudioModel(ModelPT, ABC):
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
    """Initialize the model, its loss, and optional optimization flags.

    Args:
        cfg: model configuration; must provide ``sample_rate``.
        trainer: optional PyTorch Lightning Trainer instance.
    """
    super().__init__(cfg=cfg, trainer=trainer)

    # Audio sample rate taken from the model config (available after super().__init__).
    self.sample_rate = self._cfg.sample_rate
    self._setup_loss()
    self.setup_optimization_flags()

def _setup_loss(self):
"""Setup loss for this model."""
Expand Down Expand Up @@ -130,6 +133,55 @@ def _setup_metrics(self, tag: str = 'val'):
'Setup metrics for %s, dataloader %d: %s', tag, dataloader_idx, ', '.join(metrics_dataloader_idx)
)

def _parse_batch(self, batch):
"""Parse a batch into input signal, target signal, and input length.

Handles both dict-style (lhotse) and tuple-style (AudioToTargetDataset)
batches, and ensures signals are in multi-channel format (B, C, T).

Returns:
Tuple of (input_signal, target_signal, input_length).
"""
if isinstance(batch, dict):
# Lhotse dataloaders produce dict batches
input_signal = batch['input_signal']
input_length = batch['input_length']
target_signal = batch['target_signal']
else:
# Standard audio datasets produce tuple batches
input_signal, input_length, target_signal, _ = batch

if input_signal.ndim == 2:
input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
if target_signal.ndim == 2:
target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')

return input_signal, target_signal, input_length

@abstractmethod
def _compute_train_loss(self, input_signal, target_signal, input_length):
    """Compute training loss from parsed batch signals.

    Called by training_step after _parse_batch; subclasses implement
    the forward pass and loss computation.

    Args:
        input_signal: input audio tensor (B, C, T)
        target_signal: target audio tensor (B, C, T)
        input_length: length of each example in the batch (B,)

    Returns:
        Scalar loss tensor.
    """
    pass

def training_step(self, batch, batch_idx):
    """Run a single training step: parse the batch, compute and log the loss.

    Args:
        batch: dict- or tuple-style batch from the training dataloader.
        batch_idx: index of the batch within the epoch (required by Lightning).

    Returns:
        Scalar training loss tensor.
    """
    parsed = self._parse_batch(batch)
    train_loss = self._compute_train_loss(*parsed)

    # Step-level diagnostics for the experiment logger.
    self.log('train_loss', train_loss)
    self.log('learning_rate', self._optimizer.param_groups[0]['lr'])
    self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))

    return train_loss

@abstractmethod
def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):
    """Run a single evaluation step; must be implemented by subclasses.

    Args:
        batch: batch produced by the evaluation dataloader.
        batch_idx: index of the batch within the dataloader.
        dataloader_idx: index of the dataloader, for multi-dataloader setups.
        tag: evaluation tag selecting the metric group (e.g. 'val').
    """
    pass
Expand Down Expand Up @@ -313,6 +365,23 @@ def _setup_process_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoade
temporary_dataloader = self._setup_dataloader_from_config(config=DictConfig(dl_config))
return temporary_dataloader

def _normalize(self, signal: torch.Tensor):
"""Normalize signal so its peak amplitude is 1.

Args:
signal: tensor with shape (B, C, T)

Returns:
Tuple of (normalized_signal, norm_scale). Pass norm_scale to
_denormalize to restore the original scale.
"""
norm_scale = torch.amax(signal.abs(), dim=(-1, -2), keepdim=True)
return signal / (norm_scale + self.eps), norm_scale
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the returned norm_scale include + self.eps for identical reverse operation?

Copy link
Copy Markdown
Collaborator Author

@racoiaws racoiaws Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch

I'd prefer an explicit add in the reverse operation instead of baking it into norm_scale


def _denormalize(self, signal: torch.Tensor, norm_scale: torch.Tensor) -> torch.Tensor:
"""Restore original scale after _normalize."""
return signal * (norm_scale + self.eps)

@staticmethod
def match_batch_length(input: torch.Tensor, batch_length: int) -> torch.Tensor:
"""Trim or pad the output to match the batch length.
Expand Down Expand Up @@ -467,11 +536,10 @@ def list_available_models(cls) -> 'List[PretrainedModelInfo]':
return list_of_models

def setup_optimization_flags(self):
"""
Utility method that must be explicitly called by the subclass in order to support optional optimization flags.
This method is the only valid place to access self.cfg prior to DDP training occurs.
"""Setup optional optimization flags from the model config.

The subclass may chose not to support this method, therefore all variables here must be checked via hasattr()
Called automatically during __init__. This is the only valid place
to access self.cfg prior to DDP training.
"""
# Skip update if nan/inf grads appear on any rank.
self._skip_nan_grad = False
Expand Down
Loading
Loading