diff --git a/nemo/collections/audio/models/audio_to_audio.py b/nemo/collections/audio/models/audio_to_audio.py index 28109f27b7f2..6cf875677b9a 100644 --- a/nemo/collections/audio/models/audio_to_audio.py +++ b/nemo/collections/audio/models/audio_to_audio.py @@ -18,6 +18,7 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional, Union +import einops import hydra import librosa import soundfile as sf @@ -50,7 +51,9 @@ class AudioToAudioModel(ModelPT, ABC): def __init__(self, cfg: DictConfig, trainer: Trainer = None): super().__init__(cfg=cfg, trainer=trainer) + self.sample_rate = self._cfg.sample_rate self._setup_loss() + self.setup_optimization_flags() def _setup_loss(self): """Setup loss for this model.""" @@ -130,6 +133,55 @@ def _setup_metrics(self, tag: str = 'val'): 'Setup metrics for %s, dataloader %d: %s', tag, dataloader_idx, ', '.join(metrics_dataloader_idx) ) + def _parse_batch(self, batch): + """Parse a batch into input signal, target signal, and input length. + + Handles both dict-style (lhotse) and tuple-style (AudioToTargetDataset) + batches, and ensures signals are in multi-channel format (B, C, T). + + Returns: + Tuple of (input_signal, target_signal, input_length). + """ + if isinstance(batch, dict): + # Lhotse dataloaders produce dict batches + input_signal = batch['input_signal'] + input_length = batch['input_length'] + target_signal = batch['target_signal'] + else: + # Standard audio datasets produce tuple batches + input_signal, input_length, target_signal, _ = batch + + if input_signal.ndim == 2: + input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') + if target_signal.ndim == 2: + target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + + return input_signal, target_signal, input_length + + @abstractmethod + def _compute_train_loss(self, input_signal, target_signal, input_length): + """Compute training loss from parsed batch signals. + + Args: + input_signal: input audio tensor (B, C, T) + target_signal: target audio tensor (B, C, T) + input_length: length of each example in the batch (B,) + + Returns: + Scalar loss tensor. + """ + pass + + def training_step(self, batch, batch_idx): + input_signal, target_signal, input_length = self._parse_batch(batch) + loss = self._compute_train_loss(input_signal, target_signal, input_length) + + self.log('train_loss', loss) + self.log('learning_rate', self._optimizer.param_groups[0]['lr']) + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + return loss + @abstractmethod def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): pass @@ -313,6 +365,23 @@ def _setup_process_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoade temporary_dataloader = self._setup_dataloader_from_config(config=DictConfig(dl_config)) return temporary_dataloader + def _normalize(self, signal: torch.Tensor): + """Normalize signal so its peak amplitude is 1. + + Args: + signal: tensor with shape (B, C, T) + + Returns: + Tuple of (normalized_signal, norm_scale). Pass norm_scale to + _denormalize to restore the original scale. + """ + norm_scale = torch.amax(signal.abs(), dim=(-1, -2), keepdim=True) + return signal / (norm_scale + self.eps), norm_scale + + def _denormalize(self, signal: torch.Tensor, norm_scale: torch.Tensor) -> torch.Tensor: + """Restore original scale after _normalize.""" + return signal * (norm_scale + self.eps) + @staticmethod def match_batch_length(input: torch.Tensor, batch_length: int) -> torch.Tensor: """Trim or pad the output to match the batch length. @@ -467,11 +536,10 @@ def list_available_models(cls) -> 'List[PretrainedModelInfo]': return list_of_models def setup_optimization_flags(self): - """ - Utility method that must be explicitly called by the subclass in order to support optional optimization flags. - This method is the only valid place to access self.cfg prior to DDP training occurs. + """Setup optional optimization flags from the model config. - The subclass may chose not to support this method, therefore all variables here must be checked via hasattr() + Called automatically during __init__. This is the only valid place + to access self.cfg prior to DDP training. """ # Skip update if nan/inf grads appear on any rank. self._skip_nan_grad = False diff --git a/nemo/collections/audio/models/enhancement.py b/nemo/collections/audio/models/enhancement.py index 02c449a5b2b8..ebd49b36d23f 100644 --- a/nemo/collections/audio/models/enhancement.py +++ b/nemo/collections/audio/models/enhancement.py @@ -45,24 +45,17 @@ class EncMaskDecAudioToAudioModel(AudioToAudioModel): """ def __init__(self, cfg: DictConfig, trainer: Trainer = None): - # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable - # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0 - self.world_size = 1 - if trainer is not None: - self.world_size = trainer.world_size - super().__init__(cfg=cfg, trainer=trainer) - self.sample_rate = self._cfg.sample_rate # Setup processing modules - self.encoder = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.encoder) - self.mask_estimator = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mask_estimator) - self.mask_processor = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mask_processor) - self.decoder = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.decoder) + self.encoder = self.from_config_dict(self._cfg.encoder) + self.mask_estimator = self.from_config_dict(self._cfg.mask_estimator) + self.mask_processor = self.from_config_dict(self._cfg.mask_processor) + self.decoder = self.from_config_dict(self._cfg.decoder) if 'mixture_consistency' in self._cfg: logging.debug('Using mixture consistency') - self.mixture_consistency = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mixture_consistency) + self.mixture_consistency = self.from_config_dict(self._cfg.mixture_consistency) else: logging.debug('Mixture consistency not used') self.mixture_consistency = None @@ -70,14 +63,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Setup augmentation if hasattr(self.cfg, 'channel_augment') and self.cfg.channel_augment is not None: logging.debug('Using channel augmentation') - self.channel_augmentation = EncMaskDecAudioToAudioModel.from_config_dict(self.cfg.channel_augment) + self.channel_augmentation = self.from_config_dict(self.cfg.channel_augment) else: logging.debug('Channel augmentation not used') self.channel_augmentation = None - # Setup optional Optimization flags - self.setup_optimization_flags() - @property def input_types(self) -> Dict[str, NeuralType]: return { @@ -133,56 +123,16 @@ def forward(self, input_signal, input_length=None): processed = self.match_batch_length(input=processed, batch_length=batch_length) return processed, processed_length - # PTL-specific methods - def training_step(self, batch, batch_idx): - - if isinstance(batch, dict): - # lhotse batches are dictionaries - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch['target_signal'] - else: - input_signal, input_length, target_signal, _ = batch - - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') - + def _compute_train_loss(self, input_signal, target_signal, input_length): # Apply channel augmentation if self.training and self.channel_augmentation is not None: input_signal = self.channel_augmentation(input=input_signal) - # Process input processed_signal, _ = self.forward(input_signal=input_signal, input_length=input_length) - - # Calculate the loss - loss = self.loss(estimate=processed_signal, target=target_signal, input_length=input_length) - - # Logs - self.log('train_loss', loss) - self.log('learning_rate', self._optimizer.param_groups[0]['lr']) - self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) - - # Return loss - return loss + return self.loss(estimate=processed_signal, target=target_signal, input_length=input_length) def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): - - if isinstance(batch, dict): - # lhotse batches are dictionaries - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch['target_signal'] - else: - input_signal, input_length, target_signal, _ = batch - - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + input_signal, target_signal, input_length = self._parse_batch(batch) # Process input processed_signal, _ = self.forward(input_signal=input_signal, input_length=input_length) @@ -222,7 +172,6 @@ class PredictiveAudioToAudioModel(AudioToAudioModel): def __init__(self, cfg: DictConfig, trainer: Trainer = None): super().__init__(cfg=cfg, trainer=trainer) - self.sample_rate = self._cfg.sample_rate # Setup processing modules self.encoder = self.from_config_dict(self._cfg.encoder) @@ -237,9 +186,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Term added to the denominator to improve numerical stability self.eps = self._cfg.get('eps', 1e-8) - # Setup optional Optimization flags - self.setup_optimization_flags() - logging.debug('Initialized %s', self.__class__.__name__) logging.debug('\tnormalize_input: %s', self.normalize_input) logging.debug('\teps: %s', self.eps) @@ -272,10 +218,7 @@ def forward(self, input_signal, input_length=None): batch_length = input_signal.size(-1) if self.normalize_input: - # max for each example in the batch - norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) - # scale input signal - input_signal = input_signal / (norm_scale + self.eps) + input_signal, norm_scale = self._normalize(input_signal) # Encoder encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length) @@ -287,63 +230,23 @@ def forward(self, input_signal, input_length=None): output, output_length = self.decoder(input=estimated, input_length=estimated_length) if self.normalize_input: - # rescale to the original scale - output = output * norm_scale + output = self._denormalize(output, norm_scale) # Trim or pad the estimated signal to match input length output = self.match_batch_length(input=output, batch_length=batch_length) return output, output_length - # PTL-specific methods - def training_step(self, batch, batch_idx): - - if isinstance(batch, dict): - # lhotse batches are dictionaries - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch['target_signal'] - else: - input_signal, input_length, target_signal, _ = batch - - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') - - # Estimate the signal + def _compute_train_loss(self, input_signal, target_signal, input_length): output_signal, _ = self.forward(input_signal=input_signal, input_length=input_length) - - # Calculate the loss - loss = self.loss(estimate=output_signal, target=target_signal, input_length=input_length) - - # Logs - self.log('train_loss', loss) - self.log('learning_rate', self._optimizer.param_groups[0]['lr']) - self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) - - return loss + return self.loss(estimate=output_signal, target=target_signal, input_length=input_length) def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): - - if isinstance(batch, dict): - # lhotse batches are dictionaries - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch['target_signal'] - else: - input_signal, input_length, target_signal, _ = batch - - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + input_signal, target_signal, input_length = self._parse_batch(batch) # Estimate the signal output_signal, _ = self.forward(input_signal=input_signal, input_length=input_length) - # Prepare output + # Calculate the loss loss = self.loss(estimate=output_signal, target=target_signal, input_length=input_length) # Update metrics @@ -372,7 +275,6 @@ class ScoreBasedGenerativeAudioToAudioModel(AudioToAudioModel): def __init__(self, cfg: DictConfig, trainer: Trainer = None): super().__init__(cfg=cfg, trainer=trainer) - self.sample_rate = self._cfg.sample_rate # Setup processing modules self.encoder = self.from_config_dict(self._cfg.encoder) @@ -407,9 +309,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Term added to the denominator to improve numerical stability self.eps = self._cfg.get('eps', 1e-8) - # Setup optional Optimization flags - self.setup_optimization_flags() - logging.debug('Initialized %s', self.__class__.__name__) logging.debug('\tnormalize_input: %s', self.normalize_input) logging.debug('\teps: %s', self.eps) @@ -450,10 +349,7 @@ def forward(self, input_signal, input_length=None): batch_length = input_signal.size(-1) if self.normalize_input: - # max for each example in the batch - norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) - # scale input signal - input_signal = input_signal / (norm_scale + self.eps) + input_signal, norm_scale = self._normalize(input_signal) # Encoder encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length) @@ -467,8 +363,7 @@ def forward(self, input_signal, input_length=None): output, output_length = self.decoder(input=generated, input_length=generated_length) if self.normalize_input: - # rescale to the original scale - output = output * norm_scale + output = self._denormalize(output, norm_scale) # Trim or pad the estimated signal to match input length output = self.match_batch_length(input=output, batch_length=batch_length) @@ -493,11 +388,7 @@ def _step(self, target_signal, input_signal, input_length=None): batch_size = target_signal.size(0) if self.normalize_input: - # max for each example in the batch - norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) - # scale input signal - input_signal = input_signal / (norm_scale + self.eps) - # scale the target signal + input_signal, norm_scale = self._normalize(input_signal) target_signal = target_signal / (norm_scale + self.eps) # Apply encoder to both target and the input @@ -535,48 +426,11 @@ def _step(self, target_signal, input_signal, input_length=None): return loss - # PTL-specific methods - def training_step(self, batch, batch_idx): - - if isinstance(batch, dict): - # lhotse batches are dictionaries - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch['target_signal'] - else: - input_signal, input_length, target_signal, _ = batch - - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') - - # Calculate the loss - loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) - - # Logs - self.log('train_loss', loss) - self.log('learning_rate', self._optimizer.param_groups[0]['lr']) - self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) - - return loss + def _compute_train_loss(self, input_signal, target_signal, input_length): + return self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): - - if isinstance(batch, dict): - # lhotse batches are dictionaries - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch['target_signal'] - else: - input_signal, input_length, target_signal, _ = batch - - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + input_signal, target_signal, input_length = self._parse_batch(batch) # Calculate loss loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) @@ -634,7 +488,6 @@ class FlowMatchingAudioToAudioModel(AudioToAudioModel): def __init__(self, cfg: DictConfig, trainer: Trainer = None): super().__init__(cfg=cfg, trainer=trainer) - self.sample_rate = self._cfg.sample_rate # Setup processing modules self.encoder = self.from_config_dict(self._cfg.encoder) @@ -677,9 +530,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Regularization self.eps = self._cfg.get('eps', 1e-8) - # Setup optional Optimization flags - self.setup_optimization_flags() - logging.debug('Initialized %s', self.__class__.__name__) logging.debug('\tdoing SSL-pretraining: %s', (self.ssl_pretrain_masking is not None)) logging.debug('\tp_cond: %s', self.p_cond) @@ -765,10 +615,7 @@ def forward_internal(self, input_signal, input_length=None, enable_ssl_masking=F batch_length = input_signal.size(-1) if self.normalize_input: - # max for each example in the batch - norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) - # scale input signal - input_signal = input_signal / (norm_scale + self.eps) + input_signal, norm_scale = self._normalize(input_signal) # Encoder encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length) @@ -793,8 +640,7 @@ def forward_internal(self, input_signal, input_length=None, enable_ssl_masking=F output, output_length = self.decoder(input=generated, input_length=generated_length) if self.normalize_input: - # rescale to the original scale - output = output * norm_scale + output = self._denormalize(output, norm_scale) # Trim or pad the estimated signal to match input length output = self.match_batch_length(input=output, batch_length=batch_length) @@ -815,11 +661,7 @@ def _step(self, target_signal, input_signal, input_length=None): batch_size = target_signal.size(0) if self.normalize_input: - # max for each example in the batch - norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) - # scale input signal - input_signal = input_signal / (norm_scale + self.eps) - # scale the target signal + input_signal, norm_scale = self._normalize(input_signal) target_signal = target_signal / (norm_scale + self.eps) # Apply encoder to both target and the input @@ -857,47 +699,27 @@ def _step(self, target_signal, input_signal, input_length=None): return self.loss(estimate=estimate, target=loss_target, input_length=input_enc_len) - # PTL-specific methods - def training_step(self, batch, batch_idx): + def _parse_batch(self, batch): + """Override to allow missing target_signal for SSL pretraining.""" if isinstance(batch, dict): - # lhotse batches are dictionaries input_signal = batch['input_signal'] input_length = batch['input_length'] target_signal = batch.get('target_signal', input_signal.clone()) else: input_signal, input_length, target_signal, _ = batch - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, "B T -> B 1 T") + input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, "B T -> B 1 T") + target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') - # Calculate the loss - loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) + return input_signal, target_signal, input_length - # Logs - self.log('train_loss', loss) - self.log('learning_rate', self._optimizer.param_groups[0]['lr']) - self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) - - return loss + def _compute_train_loss(self, input_signal, target_signal, input_length): + return self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): - - if isinstance(batch, dict): - # lhotse batches are dictionaries - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch.get('target_signal', input_signal.clone()) - else: - input_signal, input_length, target_signal, _ = batch - - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + input_signal, target_signal, input_length = self._parse_batch(batch) # Calculate loss loss = self._step( @@ -961,7 +783,6 @@ class SchroedingerBridgeAudioToAudioModel(AudioToAudioModel): def __init__(self, cfg: DictConfig, trainer: Trainer = None): super().__init__(cfg=cfg, trainer=trainer) - self.sample_rate = self._cfg.sample_rate # Setup processing modules self.encoder = self.from_config_dict(self._cfg.encoder) @@ -1016,9 +837,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Term added to the denominator to improve numerical stability self.eps = self._cfg.get('eps', 1e-8) - # Setup optional optimization flags - self.setup_optimization_flags() - logging.debug('Initialized %s', self.__class__.__name__) logging.debug('\testimator_output: %s', self.estimator_output) logging.debug('\tnormalize_input: %s', self.normalize_input) @@ -1067,10 +885,7 @@ def forward(self, input_signal, input_length=None): batch_length = input_signal.size(-1) if self.normalize_input: - # max for each example in the batch - norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) - # scale input signal - input_signal = input_signal / (norm_scale + self.eps) + input_signal, norm_scale = self._normalize(input_signal) # Encoder encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length) @@ -1084,8 +899,7 @@ def forward(self, input_signal, input_length=None): output, output_length = self.decoder(input=generated, input_length=generated_length) if self.normalize_input: - # rescale to the original scale - output = output * norm_scale + output = self._denormalize(output, norm_scale) # Trim or pad the estimated signal to match input length output = self.match_batch_length(input=output, batch_length=batch_length) @@ -1100,8 +914,6 @@ def forward(self, input_signal, input_length=None): }, output_types={ "loss": NeuralType(None, LossType()), - "loss_encoded": NeuralType(None, LossType()), - "loss_time": NeuralType(None, LossType()), }, ) def _step(self, target_signal, input_signal, input_length=None): @@ -1111,11 +923,7 @@ def _step(self, target_signal, input_signal, input_length=None): batch_size = target_signal.size(0) if self.normalize_input: - # max for each example in the batch - norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) - # scale input signal - input_signal = input_signal / (norm_scale + self.eps) - # scale the target signal + input_signal, norm_scale = self._normalize(input_signal) target_signal = target_signal / (norm_scale + self.eps) # Apply encoder to both target and the input @@ -1164,7 +972,6 @@ def _step(self, target_signal, input_signal, input_length=None): if self.loss is not None: # Single loss in the encoded domain loss = self.loss(estimate=estimate, target=target_enc, input_length=estimate_len) - loss_encoded = loss_time = None else: # Weighted loss between encoded and time domain loss = 0.0 @@ -1173,10 +980,10 @@ def _step(self, target_signal, input_signal, input_length=None): if self.loss_encoded is not None: # Loss between the estimate and the target in the encoded domain loss_encoded = self.loss_encoded(estimate=estimate, target=target_enc, input_length=estimate_len) + if self.training: + self.log('train_loss_encoded', loss_encoded) # Weighting loss += self.loss_encoded_weight * loss_encoded - else: - loss_encoded = None # Loss in the time domain if self.loss_time is not None: @@ -1193,68 +1000,23 @@ def _step(self, target_signal, input_signal, input_length=None): loss_time = self.loss_time( estimate=estimate_signal, target=target_signal, input_length=input_length ) + if self.training: + self.log('train_loss_time', loss_time) # Weighting loss += self.loss_time_weight * loss_time - else: - loss_time = None else: raise NotImplementedError(f'Output type {self.estimator_output} is not implemented') - return loss, loss_encoded, loss_time - - # PTL-specific methods - def training_step(self, batch, batch_idx): - - if isinstance(batch, dict): - # lhotse batches are dictionaries - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch['target_signal'] - else: - input_signal, input_length, target_signal, _ = batch - - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') - - # Calculate the loss - loss, loss_encoded, loss_time = self._step( - target_signal=target_signal, input_signal=input_signal, input_length=input_length - ) - - # Logs - self.log('train_loss', loss) - self.log('learning_rate', self._optimizer.param_groups[0]['lr']) - self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) - - if loss_encoded is not None: - self.log('train_loss_encoded', loss_encoded) - - if loss_time is not None: - self.log('train_loss_time', loss_time) - return loss - def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): + def _compute_train_loss(self, input_signal, target_signal, input_length): + return self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) - if isinstance(batch, dict): - # lhotse batches are dictionaries - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch['target_signal'] - else: - input_signal, input_length, target_signal, _ = batch - - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): + input_signal, target_signal, input_length = self._parse_batch(batch) # Calculate loss - loss, *_ = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) + loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) # Update metrics update_metrics = False diff --git a/nemo/collections/audio/models/maxine/bnr.py b/nemo/collections/audio/models/maxine/bnr.py index 6cff8ead9c87..bca29de65cc7 100644 --- a/nemo/collections/audio/models/maxine/bnr.py +++ b/nemo/collections/audio/models/maxine/bnr.py @@ -26,7 +26,6 @@ from typing import Dict, Optional -import einops import lightning.pytorch as plt import torch import torch.nn as nn @@ -157,15 +156,7 @@ class BNR2(AudioToAudioModel): """Implementation of the BNR 2 model""" def __init__(self, cfg: DictConfig, trainer: Trainer = None): - self.world_size = 1 - if trainer is not None: - self.world_size = trainer.world_size - super().__init__(cfg=cfg, trainer=trainer) - self.sample_rate = self._cfg.sample_rate - - # Setup optional Optimization flags - self.setup_optimization_flags() self.seasr = _Seasr(self.sample_rate) if ( @@ -228,41 +219,12 @@ def forward(self, input_signal): return output - def training_step(self, batch, batch_idx): - if isinstance(batch, dict): - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch['target_signal'] - else: - input_signal, input_length, target_signal, _ = batch - - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') - + def _compute_train_loss(self, input_signal, target_signal, input_length): predicted_audio = self.forward(input_signal=input_signal) - - loss = self.loss(target=target_signal, estimate=predicted_audio, input_length=input_length) - - self.log('train_loss', loss) - self.log('learning_rate', self._optimizer.param_groups[0]['lr']) - self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) - - return loss + return self.loss(target=target_signal, estimate=predicted_audio, input_length=input_length) def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): - if isinstance(batch, dict): - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch['target_signal'] - else: - input_signal, input_length, target_signal, _ = batch - - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + input_signal, target_signal, input_length = self._parse_batch(batch) # Process input processed_signal = self(input_signal=input_signal) diff --git a/tests/collections/audio/test_audio_models_schroedinger_bridge.py b/tests/collections/audio/test_audio_models_schroedinger_bridge.py index 6d4a228092b3..fe61bf7bf5bf 100644 --- a/tests/collections/audio/test_audio_models_schroedinger_bridge.py +++ b/tests/collections/audio/test_audio_models_schroedinger_bridge.py @@ -265,23 +265,14 @@ def test_forward_infer(self, schroedinger_bridge_model_ncsn, batch_size, sample_ def test_training_step(self, schroedinger_bridge_model_ncsn_with_trainer_and_mock_dataset): model, _ = schroedinger_bridge_model_ncsn_with_trainer_and_mock_dataset model = model.train() + # _step calls self.log for component losses, which requires a Lightning loop context. + # Disable logging since we're calling _step directly outside the training loop. + model.log = lambda *args, **kwargs: None for batch in itertools.islice(model._train_dl, 2): - # start boilerplate from SchroedingerBridgeAudioToAudioModel.training_step - if isinstance(batch, dict): - # lhotse batches are dictionaries - input_signal = batch['input_signal'] - input_length = batch['input_length'] - target_signal = batch.get('target_signal', input_signal) - else: - input_signal, input_length, target_signal, _ = batch - if input_signal.ndim == 2: - input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') - if target_signal.ndim == 2: - target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') - # end boilerplate - - loss, _, _ = model._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) + input_signal, target_signal, input_length = model._parse_batch(batch) + + loss = model._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) loss.backward() def test_model_training(self, schroedinger_bridge_model_ncsn_with_trainer_and_mock_dataset):