-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Reduce code duplication in audio collection + some small fixes #15587
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4779d8d
14b2824
31b411a
cb18386
4dd2376
51995ac
f19c102
80167fe
2fc475e
9ff0f91
e6dc162
c967065
e39f5db
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |
| from abc import ABC, abstractmethod | ||
| from typing import Dict, List, Optional, Union | ||
|
|
||
| import einops | ||
| import hydra | ||
| import librosa | ||
| import soundfile as sf | ||
|
|
@@ -50,7 +51,9 @@ class AudioToAudioModel(ModelPT, ABC): | |
| def __init__(self, cfg: DictConfig, trainer: Trainer = None): | ||
| super().__init__(cfg=cfg, trainer=trainer) | ||
|
|
||
| self.sample_rate = self._cfg.sample_rate | ||
| self._setup_loss() | ||
| self.setup_optimization_flags() | ||
|
|
||
| def _setup_loss(self): | ||
| """Setup loss for this model.""" | ||
|
|
@@ -130,6 +133,55 @@ def _setup_metrics(self, tag: str = 'val'): | |
| 'Setup metrics for %s, dataloader %d: %s', tag, dataloader_idx, ', '.join(metrics_dataloader_idx) | ||
| ) | ||
|
|
||
| def _parse_batch(self, batch): | ||
| """Parse a batch into input signal, target signal, and input length. | ||
|
|
||
| Handles both dict-style (lhotse) and tuple-style (AudioToTargetDataset) | ||
| batches, and ensures signals are in multi-channel format (B, C, T). | ||
|
|
||
| Returns: | ||
| Tuple of (input_signal, target_signal, input_length). | ||
| """ | ||
| if isinstance(batch, dict): | ||
| # Lhotse dataloaders produce dict batches | ||
| input_signal = batch['input_signal'] | ||
| input_length = batch['input_length'] | ||
| target_signal = batch['target_signal'] | ||
| else: | ||
| # Standard audio datasets produce tuple batches | ||
| input_signal, input_length, target_signal, _ = batch | ||
|
|
||
| if input_signal.ndim == 2: | ||
| input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') | ||
| if target_signal.ndim == 2: | ||
| target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') | ||
|
|
||
| return input_signal, target_signal, input_length | ||
|
|
||
| @abstractmethod | ||
| def _compute_train_loss(self, input_signal, target_signal, input_length): | ||
| """Compute training loss from parsed batch signals. | ||
|
|
||
| Args: | ||
| input_signal: input audio tensor (B, C, T) | ||
| target_signal: target audio tensor (B, C, T) | ||
| input_length: length of each example in the batch (B,) | ||
|
|
||
| Returns: | ||
| Scalar loss tensor. | ||
| """ | ||
| pass | ||
|
|
||
| def training_step(self, batch, batch_idx): | ||
| input_signal, target_signal, input_length = self._parse_batch(batch) | ||
| loss = self._compute_train_loss(input_signal, target_signal, input_length) | ||
|
|
||
| self.log('train_loss', loss) | ||
| self.log('learning_rate', self._optimizer.param_groups[0]['lr']) | ||
| self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) | ||
|
|
||
| return loss | ||
|
|
||
| @abstractmethod | ||
| def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): | ||
| pass | ||
|
|
@@ -313,6 +365,23 @@ def _setup_process_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoade | |
| temporary_dataloader = self._setup_dataloader_from_config(config=DictConfig(dl_config)) | ||
| return temporary_dataloader | ||
|
|
||
| def _normalize(self, signal: torch.Tensor): | ||
| """Normalize signal so its peak amplitude is 1. | ||
|
|
||
| Args: | ||
| signal: tensor with shape (B, C, T) | ||
|
|
||
| Returns: | ||
| Tuple of (normalized_signal, norm_scale). Pass norm_scale to | ||
| _denormalize to restore the original scale. | ||
| """ | ||
| norm_scale = torch.amax(signal.abs(), dim=(-1, -2), keepdim=True) | ||
| return signal / (norm_scale + self.eps), norm_scale | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should the returned norm_scale include
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good catch. I'd prefer an explicit add in the reverse operation instead of baking it into |
||
|
|
||
| def _denormalize(self, signal: torch.Tensor, norm_scale: torch.Tensor) -> torch.Tensor: | ||
| """Restore original scale after _normalize.""" | ||
| return signal * (norm_scale + self.eps) | ||
|
|
||
| @staticmethod | ||
| def match_batch_length(input: torch.Tensor, batch_length: int) -> torch.Tensor: | ||
| """Trim or pad the output to match the batch length. | ||
|
|
@@ -467,11 +536,10 @@ def list_available_models(cls) -> 'List[PretrainedModelInfo]': | |
| return list_of_models | ||
|
|
||
| def setup_optimization_flags(self): | ||
| """ | ||
| Utility method that must be explicitly called by the subclass in order to support optional optimization flags. | ||
| This method is the only valid place to access self.cfg prior to DDP training occurs. | ||
| """Setup optional optimization flags from the model config. | ||
|
|
||
| The subclass may chose not to support this method, therefore all variables here must be checked via hasattr() | ||
| Called automatically during __init__. This is the only valid place | ||
| to access self.cfg prior to DDP training. | ||
| """ | ||
| # Skip update if nan/inf grads appear on any rank. | ||
| self._skip_nan_grad = False | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.