diff --git a/src/osekit/core_api/spectro_dataset.py b/src/osekit/core_api/spectro_dataset.py index 8e4c51e3..796da576 100644 --- a/src/osekit/core_api/spectro_dataset.py +++ b/src/osekit/core_api/spectro_dataset.py @@ -376,7 +376,14 @@ def link_audio_dataset( first: int = 0, last: int | None = None, ) -> None: - """Link the ``SpectroData`` of the ``SpectroDataset`` to the ``AudioData`` of the ``AudioDataset``. + """Link the ``SpectroDataset`` to the ``AudioDataset``. + + The ``SpectroData`` of the ``SpectroDataset`` will be linked to + the ``AudioData`` of the ``AudioDataset``. + + There should be in the ``AudioDataset`` an ``AudioData`` that + have the same ``begin`` and ``end`` than each of the ``SpectroData`` + of the ``SpectroDataset``. Parameters ---------- @@ -388,23 +395,16 @@ def link_audio_dataset( Index of the last ``SpectroData`` and ``AudioData`` to link. """ - if len(audio_dataset.data) != len(self.data): - msg = ( - "The audio dataset doesn't contain the same number of data" - " as the spectro dataset." - ) - raise ValueError(msg) - last = len(self.data) if last is None else last - for sd, ad in list( - zip( - sorted(self.data, key=lambda d: (d.begin, d.end)), - sorted(audio_dataset.data, key=lambda d: (d.begin, d.end)), - strict=False, - ), - )[first:last]: - sd.link_audio_data(ad) + ad_dict = {(ad.begin, ad.end): ad for ad in audio_dataset.data} + + for sd in self.data[first:last]: + key = (sd.begin, sd.end) + if key not in ad_dict: + msg = f"No AudioData found for SpectroData {sd}" + raise ValueError(msg) + sd.link_audio_data(ad_dict[key]) def update_json_audio_data(self, first: int, last: int) -> None: """Update the serialized ``json`` file with the spectro data from first to last. diff --git a/tests/test_spectro.py b/tests/test_spectro.py index fc23e14c..6fe76ae0 100644 --- a/tests/test_spectro.py +++ b/tests/test_spectro.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib import gc from contextlib import nullcontext from pathlib import Path @@ -727,163 +728,328 @@ def test_link_audio_data( @pytest.mark.parametrize( ( - "audio_files", - "ads1_data_duration", - "ads2_data_duration", - "ads2_sample_rate", + "audio_data_params", + "spectro_data_params", "start_index", "stop_index", + "expected_relinked_data_idxs", "expected_exception", ), [ pytest.param( - { - "duration": 1, - "sample_rate": 1_024, - "nb_files": 1, - "date_begin": pd.Timestamp("2024-01-01 12:00:00"), - }, - Timedelta(seconds=0.1), - Timedelta(seconds=0.1), - 1_024, + [ + ( + Timestamp("1994-02-27 00:00:00"), + Timestamp("1994-02-27 00:00:01"), + 100.0, + ), + ( + Timestamp("1994-02-27 00:00:01"), + Timestamp("1994-02-27 00:00:02"), + 100.0, + ), + ], + [ + ( + Timestamp("1994-02-27 00:00:00"), + Timestamp("1994-02-27 00:00:01"), + 100.0, + ), + ( + Timestamp("1994-02-27 00:00:01"), + Timestamp("1994-02-27 00:00:02"), + 100.0, + ), + ], None, None, + [0, 1], nullcontext(), id="default_indexes_is_full_dataset", ), pytest.param( - { - "duration": 1, - "sample_rate": 1_024, - "nb_files": 1, - "date_begin": pd.Timestamp("2024-01-01 12:00:00"), - }, - Timedelta(seconds=0.1), - Timedelta(seconds=0.1), - 1_024, + [ + ( + Timestamp("1994-02-27 00:00:00"), + Timestamp("1994-02-27 00:00:01"), + 100.0, + ), + ( + Timestamp("1994-02-27 00:00:01"), + Timestamp("1994-02-27 00:00:02"), + 100.0, + ), + ( + Timestamp("1994-02-27 00:00:02"), + Timestamp("1994-02-27 00:00:03"), + 100.0, + ), + ], + [ + ( + Timestamp("1994-02-27 00:00:00"), + Timestamp("1994-02-27 00:00:01"), + 100.0, + ), + ( + Timestamp("1994-02-27 00:00:01"), + Timestamp("1994-02-27 00:00:02"), + 100.0, + ), + ( + Timestamp("1994-02-27 00:00:02"), + Timestamp("1994-02-27 00:00:03"), + 100.0, + ), + ], + 1, 2, - 6, + [1], nullcontext(), id="link_a_part_of_the_data", ), pytest.param( - { - "duration": 1, - "sample_rate": 1_024, - "nb_files": 1, - "date_begin": pd.Timestamp("2024-01-01 12:00:00"), - }, - Timedelta(seconds=0.1), - Timedelta(seconds=0.1), - 2_048, + [ + ( + Timestamp("1994-02-27 00:00:00"), + Timestamp("1994-02-27 00:00:01"), + 100.0, + ), + ], + [ + ( + Timestamp("1994-02-27 00:00:00"), + Timestamp("1994-02-27 00:00:01"), + 150.0, + ), + ], None, None, + [], pytest.raises( ValueError, - match="The sample rate of the audio data doesn't match.", + match=r"The sample rate of the audio data doesn't match.", ), id="different_sample_rate", ), pytest.param( - { - "duration": 1, - "sample_rate": 1_024, - "nb_files": 1, - "date_begin": pd.Timestamp("2024-01-01 12:00:00"), - }, - Timedelta(seconds=0.1), - Timedelta(seconds=0.5), - 1_024, + [ + ( + Timestamp("1994-02-27 00:00:00"), + Timestamp("1994-02-27 00:00:01"), + 100.0, + ), + ( + Timestamp("1994-02-27 00:00:01"), + Timestamp("1994-02-27 00:00:02"), + 100.0, + ), + ], + [ + ( + Timestamp("1994-02-27 00:00:01"), + Timestamp("1994-02-27 00:00:02"), + 100.0, + ), + ], None, None, + [0], + nullcontext(), + id="fewer_spectro_data_should_be_ok", + ), + pytest.param( + [ + ( + Timestamp("1994-02-27 00:00:00"), + Timestamp("1994-02-27 00:00:01"), + 100.0, + ), + ( + Timestamp("1994-02-27 00:00:02"), + Timestamp("1994-02-27 00:00:03"), + 100.0, + ), + ], + [ + ( + Timestamp("1994-02-27 00:00:01"), + Timestamp("1994-02-27 00:00:02"), + 100.0, + ), + ], + None, + None, + [], pytest.raises( ValueError, - match="The audio dataset doesn't contain the same number of data as the" - " spectro dataset.", + match=rf"No AudioData found for SpectroData " + rf"{ + Timestamp('1994-02-27 00:00:01').strftime( + TIMESTAMP_FORMAT_EXPORTED_FILES_UNLOCALIZED + ) + }", ), - id="different_number_of_data", + id="not_found_spectrodata_should_raise", ), pytest.param( - { - "duration": 1, - "sample_rate": 1_024, - "nb_files": 1, - "date_begin": pd.Timestamp("2024-01-01 12:00:00"), - }, - Timedelta(seconds=0.1), - Timedelta(seconds=0.101), - 1_024, + [ + ( + Timestamp("1994-02-27 00:00:00"), + Timestamp("1994-02-27 00:00:01"), + 100.0, + ), + ( + Timestamp("1994-02-27 00:00:01"), + Timestamp("1994-02-27 00:00:03"), + 100.0, + ), + ], + [ + ( + Timestamp("1994-02-27 00:00:01"), + Timestamp("1994-02-27 00:00:02"), + 100.0, + ), + ], None, None, - pytest.raises(ValueError, match="The end of the audio data doesn't match."), - id="different_end_of_first_data", + [], + pytest.raises( + ValueError, + match=rf"No AudioData found for SpectroData " + rf"{ + Timestamp('1994-02-27 00:00:01').strftime( + TIMESTAMP_FORMAT_EXPORTED_FILES_UNLOCALIZED + ) + }", + ), + id="found_begin_but_not_end_should_raise", + ), + pytest.param( + [ + ( + Timestamp("1994-02-27 00:00:00"), + Timestamp("1994-02-27 00:00:01"), + 100.0, + ), + ( + Timestamp("1994-02-27 00:00:01"), + Timestamp("1994-02-27 00:00:02"), + 100.0, + ), + ], + [ + ( + Timestamp("1994-02-27 00:00:00"), + Timestamp("1994-02-27 00:00:02"), + 100.0, + ), + ], + None, + None, + [], + pytest.raises( + ValueError, + match=rf"No AudioData found for SpectroData " + rf"{ + Timestamp('1994-02-27 00:00:00').strftime( + TIMESTAMP_FORMAT_EXPORTED_FILES_UNLOCALIZED + ) + }", + ), + id="found_end_but_not_begin_should_raise", + ), + pytest.param( + [ + ( + Timestamp("1994-02-27 00:00:00"), + Timestamp("1994-02-27 00:00:01"), + 100.0, + ), + ], + [ + ( + Timestamp("1994-02-27 00:00:00"), + Timestamp("1994-02-27 00:00:01"), + 100.0, + ), + ( + Timestamp("1994-02-27 00:00:01"), + Timestamp("1994-02-27 00:00:02"), + 100.0, + ), + ], + None, + 1, # Excludes the sd that doesn't have an ad counterpart in ads + [0], + nullcontext(), + id="missing_audiodata_counterparts_of_excluded_spectrodata_shouldnt_raise", ), ], - indirect=["audio_files"], ) def test_link_audio_dataset( - audio_files: pytest.fixture, - tmp_path: pytest.fixture, - ads1_data_duration: Timedelta, - ads2_data_duration: Timedelta, - ads2_sample_rate: float, + patch_audio_data: None, + audio_data_params: list[ + tuple[Timestamp, Timestamp, float] + ], # begin, end, sample_rate + spectro_data_params: list[ + tuple[Timestamp, Timestamp, float] + ], # begin, end, sample_rate start_index: int, stop_index: int, - expected_exception: type[Exception], + expected_relinked_data_idxs: list[int], + expected_exception: contextlib.AbstractContextManager, ) -> None: - ads1 = AudioDataset.from_folder( - tmp_path, - strptime_format=TIMESTAMP_FORMAT_EXPORTED_FILES_UNLOCALIZED, - data_duration=ads1_data_duration, + ads_origin = AudioDataset( + [ + AudioData( + begin=sd_params[0], + end=sd_params[1], + sample_rate=sd_params[2], + mocked_value=[1.0], + ) + for sd_params in spectro_data_params + ], ) - ads2 = AudioDataset.from_folder( - tmp_path, - strptime_format=TIMESTAMP_FORMAT_EXPORTED_FILES_UNLOCALIZED, - data_duration=ads2_data_duration, + + ads_dest = AudioDataset( + [ + AudioData( + begin=ad_params[0], + end=ad_params[1], + sample_rate=ad_params[2], + mocked_value=[1.0], + ) + for ad_params in audio_data_params + ], ) - ads2.sample_rate = ads2_sample_rate sds = SpectroDataset.from_audio_dataset( - ads1, - fft=ShortTimeFFT(hamming(128), 128, ads1.sample_rate), + audio_dataset=ads_origin, + fft=ShortTimeFFT(hamming(16), 16, ads_origin.sample_rate), ) - with expected_exception as e: - assert sds.link_audio_dataset(ads2, first=start_index, last=stop_index) == e + origin_ids = {id(ad) for ad in ads_origin.data} + dest_ids = {id(ad) for ad in ads_dest.data} - if type(expected_exception) is not nullcontext: - return + assert all(id(sd.audio_data) in origin_ids for sd in sds.data) + assert not any(id(sd.audio_data) in dest_ids for sd in sds.data) - start_index = 0 if start_index is None else start_index - stop_index = len(ads1.data) if stop_index is None else stop_index - - for idx, sd in enumerate(sds.data): - if idx in range(start_index, stop_index): - assert sd.audio_data is not ads1.data[idx] - assert sd.audio_data is ads2.data[idx] - else: - assert sd.audio_data is ads1.data[idx] - assert sd.audio_data is not ads2.data[idx] - - # linking should fail if the length of the audio datasets differ: - ads_err = AudioDataset( - [*ads2.data, ads2.data[0]], - ) # Adding one data to the destination ads - with pytest.raises( - ValueError, - match=r"The audio dataset doesn't contain the same number of data as the " - "spectro dataset.", - ): - sds.link_audio_dataset(ads_err) + with expected_exception: + sds.link_audio_dataset( + audio_dataset=ads_dest, + first=start_index, + last=stop_index, + ) - # linking should fail if any of the data can't be linked - ads_err = AudioDataset(ads1.data) - ads1.data[-1].sample_rate = ads2_sample_rate * 0.5 - with pytest.raises( - ValueError, - match=r"The sample rate of the audio data doesn't match.", - ): - sds.link_audio_dataset(ads_err) + relinked_sd = [sds.data[idx] for idx in expected_relinked_data_idxs] + not_relinked_sd = [sd for sd in sds.data if sd not in relinked_sd] + + assert all(id(sd.audio_data) in origin_ids for sd in not_relinked_sd) + assert not any(id(sd.audio_data) in dest_ids for sd in not_relinked_sd) + assert all(id(sd.audio_data) in dest_ids for sd in relinked_sd) + assert not any(id(sd.audio_data) in origin_ids for sd in relinked_sd) @pytest.mark.parametrize(