From bee1da6945a496d12e28e394f4cc3e52aebaef5a Mon Sep 17 00:00:00 2001 From: stanley1208 Date: Sat, 28 Mar 2026 22:50:59 -0700 Subject: [PATCH] fix: clamp shift perturbation for short audio instead of silently skipping Signed-off-by: stanley1208 Made-with: Cursor --- .../asr/parts/preprocessing/perturb.py | 8 ++-- .../asr/test_preprocessing_segment.py | 47 ++++++++++++++++++- 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/nemo/collections/asr/parts/preprocessing/perturb.py b/nemo/collections/asr/parts/preprocessing/perturb.py index 6f1704d07f02..cbd5e3b86bef 100644 --- a/nemo/collections/asr/parts/preprocessing/perturb.py +++ b/nemo/collections/asr/parts/preprocessing/perturb.py @@ -438,10 +438,12 @@ def __init__(self, min_shift_ms=-5.0, max_shift_ms=5.0, rng=None): def perturb(self, data): shift_ms = random.uniform(self._min_shift_ms, self._max_shift_ms) - if abs(shift_ms) / 1000 > data.duration: - # TODO: do something smarter than just ignore this condition - return + max_shift_ms = data.duration * 1000 + if abs(shift_ms) > max_shift_ms: + shift_ms = max(-max_shift_ms, min(shift_ms, max_shift_ms)) shift_samples = int(shift_ms * data.sample_rate // 1000) + if shift_samples == 0: + return # logging.debug("shift: %s", shift_samples) if shift_samples < 0: data._samples[-shift_samples:] = data._samples[:shift_samples] diff --git a/tests/collections/asr/test_preprocessing_segment.py b/tests/collections/asr/test_preprocessing_segment.py index 9f6144bad017..ca8d668836e7 100644 --- a/tests/collections/asr/test_preprocessing_segment.py +++ b/tests/collections/asr/test_preprocessing_segment.py @@ -22,7 +22,7 @@ import pytest import soundfile as sf -from nemo.collections.asr.parts.preprocessing.perturb import NoisePerturbation, SilencePerturbation +from nemo.collections.asr.parts.preprocessing.perturb import NoisePerturbation, ShiftPerturbation, SilencePerturbation from nemo.collections.asr.parts.preprocessing.segment import AudioSegment, select_channels @@ 
class TestShiftPerturbation:
    """Tests for ``ShiftPerturbation``: length preservation, clamping on short audio, zero-shift no-op."""

    sample_rate = 16000

    def _make_audio_segment(self, duration_sec=1.0):
        """Create a simple AudioSegment with a 440 Hz sine wave for testing.

        Args:
            duration_sec: Length of the generated signal in seconds.

        Returns:
            An ``AudioSegment`` built from ``int(duration_sec * self.sample_rate)`` samples
            at ``self.sample_rate``.
        """
        num_samples = int(duration_sec * self.sample_rate)
        # endpoint=False keeps the spacing at duration_sec / num_samples; with the default
        # endpoint=True the spacing would be duration_sec / (num_samples - 1) and the tone
        # would be slightly off 440 Hz.
        t = np.linspace(0, duration_sec, num_samples, endpoint=False, dtype=np.float32)
        samples = np.sin(2 * np.pi * 440 * t)
        return AudioSegment(samples=samples, sample_rate=self.sample_rate)

    # Both ranges exclude zero, so the drawn shift is at least 1 ms (16 samples at 16 kHz)
    # in a known direction and can never round down to a no-op.
    @pytest.mark.parametrize("min_shift_ms, max_shift_ms", [(-5.0, -1.0), (1.0, 5.0)])
    def test_shift_perturbation_normal(self, min_shift_ms, max_shift_ms):
        """Shift perturbation modifies audio when shift is within duration."""
        perturb = ShiftPerturbation(min_shift_ms=min_shift_ms, max_shift_ms=max_shift_ms)
        segment = self._make_audio_segment(duration_sec=1.0)
        original = segment.samples.copy()
        perturb.perturb(segment)
        assert segment.samples.shape == original.shape, "Audio length should not change"
        assert not np.array_equal(segment.samples, original), "A non-zero shift should modify the audio"

    @pytest.mark.parametrize("min_shift_ms, max_shift_ms", [(-20.0, -10.0), (10.0, 20.0)])
    def test_shift_perturbation_short_audio_not_skipped(self, min_shift_ms, max_shift_ms):
        """Shift perturbation should clamp and apply shift for short audio, not silently skip."""
        perturb = ShiftPerturbation(min_shift_ms=min_shift_ms, max_shift_ms=max_shift_ms)
        duration_sec = 0.005  # 5 ms, shorter than every shift in the configured ranges
        segment = self._make_audio_segment(duration_sec=duration_sec)
        original = segment.samples.copy()
        perturb.perturb(segment)
        assert segment.samples.shape == original.shape, "Audio length should not change"
        # A shape-only check also passes when perturb() returns early without touching the
        # samples, so require an actual change to pin the "not skipped" behaviour.
        assert not np.array_equal(
            segment.samples, original
        ), "Clamped shift should still modify short audio instead of being skipped"

    @pytest.mark.parametrize("duration_sec", [0.001, 0.01, 0.1, 1.0])
    def test_shift_perturbation_preserves_length(self, duration_sec):
        """Audio length must be preserved regardless of duration."""
        perturb = ShiftPerturbation(min_shift_ms=-50.0, max_shift_ms=50.0)
        segment = self._make_audio_segment(duration_sec=duration_sec)
        original_len = len(segment.samples)
        perturb.perturb(segment)
        assert len(segment.samples) == original_len, "Shift perturbation must preserve audio length"

    def test_shift_perturbation_zero_shift(self):
        """When min and max shift are both 0, audio should be unchanged."""
        perturb = ShiftPerturbation(min_shift_ms=0.0, max_shift_ms=0.0)
        segment = self._make_audio_segment(duration_sec=0.5)
        original = segment.samples.copy()
        perturb.perturb(segment)
        np.testing.assert_array_equal(segment.samples, original, "Zero shift should not modify audio")