diff --git a/test data/multi_cuvette_test_data.KD b/test data/multi_cuvette_test_data.KD new file mode 100644 index 0000000..349efee Binary files /dev/null and b/test data/multi_cuvette_test_data.KD differ diff --git a/test data/multi_cuvette_test_data_corrupted.KD b/test data/multi_cuvette_test_data_corrupted.KD new file mode 100644 index 0000000..e44f6b3 Binary files /dev/null and b/test data/multi_cuvette_test_data_corrupted.KD differ diff --git a/tests/test_import_kd.py b/tests/test_import_kd.py new file mode 100644 index 0000000..51e59b9 --- /dev/null +++ b/tests/test_import_kd.py @@ -0,0 +1,54 @@ +"""Tests for import_kd module.""" + +import warnings +from pathlib import Path + +import pytest + +from uv_pro.io.import_kd import KDFile + + +TEST_DATA_DIR = Path(__file__).parent.parent / "test data" + + +class TestKDFileCorruptionDetection: + """Test detection of corrupted .KD files with non-monotonic time values.""" + + def test_valid_multi_cuvette_file_no_warning(self): + """Test that valid multi-cuvette file does not generate a warning.""" + file_path = TEST_DATA_DIR / "multi_cuvette_test_data.KD" + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + kd = KDFile(file_path) + + # Filter for UserWarnings about corruption + corruption_warnings = [ + warning for warning in w + if issubclass(warning.category, UserWarning) + and "corrupted" in str(warning.message).lower() + ] + + assert len(corruption_warnings) == 0, ( + f"Expected no corruption warnings for valid file, " + f"but got: {[str(w.message) for w in corruption_warnings]}" + ) + + def test_corrupted_multi_cuvette_file_generates_warning(self): + """Test that corrupted multi-cuvette file generates a warning.""" + file_path = TEST_DATA_DIR / "multi_cuvette_test_data_corrupted.KD" + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + kd = KDFile(file_path) + + # Filter for UserWarnings about corruption + corruption_warnings = [ + warning for warning in w + if issubclass(warning.category, UserWarning) + and "corrupted" in str(warning.message).lower() + ] + + assert len(corruption_warnings) >= 1, ( + "Expected at least one corruption warning for corrupted file" + ) diff --git a/uv_pro/io/import_kd.py b/uv_pro/io/import_kd.py index 9eb5c24..cfac619 100644 --- a/uv_pro/io/import_kd.py +++ b/uv_pro/io/import_kd.py @@ -8,6 +8,7 @@ """ import struct +import warnings from pathlib import Path import pandas as pd @@ -39,6 +40,8 @@ class KDFile: The time values that each spectrum was captured. cycle_time : int or None The cycle time value (in seconds) for the experiment. + samples_cell : :class:`pandas.Series` or None + The cell/cuvette identifier for each spectrum. """ absorbance_data_header = { @@ -55,6 +58,10 @@ class KDFile: b'\x65\x00\x77\x00\x00\x00\x00\x00', 'spacing': 24, } + samples_cell_header = { + 'header': b'\x52\x00\x65\x00\x67\x00\x4e\x00\x61\x00\x6d\x00\x65\x00', + 'spacing': 18, + } def __init__( self, path: Path, spectrometer_range: tuple[int, int] = (190, 1100) @@ -82,7 +89,7 @@ def __init__( ) self.spectrum_bytes_length = self._get_spectrum_bytes_length() self.file_bytes = self._read_binary() - self.spectra, self.spectra_times, self.cycle_time = self.parse_kd() + self.spectra, self.spectra_times, self.cycle_time, self.samples_cell = self.parse_kd() def _get_spectrum_bytes_length(self) -> int: """8 hex chars per wavelength.""" @@ -97,7 +104,7 @@ def _read_binary(self) -> bytes: return file_bytes - def parse_kd(self) -> tuple[pd.DataFrame, pd.Series, int]: + def parse_kd(self) -> tuple[pd.DataFrame, pd.Series, int, pd.Series | None]: """ Parse a .KD file and extract data. @@ -114,11 +121,20 @@ def parse_kd(self) -> tuple[pd.DataFrame, pd.Series, int]: The time values that each spectrum was captured. cycle_time : int The cycle time (in seconds) for the UV-vis experiment. + samples_cell : :class:`pandas.Series` or None + The cell/cuvette identifier for each spectrum (for multi-cuvette files). """ cycle_time = self._handle_cycletime() spectra_times = self._handle_spectratimes() spectra = self._handle_spectra(spectra_times) - return spectra, spectra_times, cycle_time + samples_cell = self._handle_samples_cell() + + # Validate and fix corrupt timepoints + spectra, spectra_times, samples_cell = self._validate_and_fix_data( + spectra, spectra_times, samples_cell + ) + + return spectra, spectra_times, cycle_time, samples_cell def _handle_spectra(self, spectra_times: pd.Series) -> pd.DataFrame: def _spectra_dataframe(spectra: list, spectra_times: pd.Series) -> pd.DataFrame: @@ -142,6 +158,93 @@ def _handle_spectratimes(self) -> pd.Series: raise Exception('Error parsing file. No spectra times found.') + def _validate_and_fix_data( + self, + spectra: pd.DataFrame, + spectra_times: pd.Series, + samples_cell: pd.Series | None, + ) -> tuple[pd.DataFrame, pd.Series, pd.Series | None]: + """ + Validate that time values are monotonically increasing within each cuvette. + + If non-increasing time values are found (indicating file corruption), + issue a warning and remove the corrupt timepoints and their spectra. + + Parameters + ---------- + spectra : pd.DataFrame + The spectra data with time as columns. + spectra_times : pd.Series + The time values for each spectrum. + samples_cell : pd.Series or None + The cell/cuvette identifier for each spectrum. + + Returns + ------- + tuple[pd.DataFrame, pd.Series, pd.Series | None] + The validated/fixed spectra, times, and samples_cell data. + """ + # Build a working dataframe by transposing spectra (rows become spectra, columns become wavelengths) + work_df = spectra.T.reset_index() + work_df.rename(columns={'Time (s)': 'Time_s'}, inplace=True) + cell_values = samples_cell if samples_cell is not None else pd.Series(['SAMPLES_CELL_1'] * len(work_df)) + work_df.insert(0, 'sample', cell_values.values) + + def find_valid_rows(group: pd.DataFrame) -> pd.DataFrame: + """Find rows where times are monotonically increasing.""" + times = group['Time_s'].values + valid_mask = [True] * len(times) + + # Find reset points where time decreases + for i in range(1, len(times)): + if times[i] < times[i - 1]: + # Mark all previous times >= current time as invalid + for j in range(i): + if times[j] >= times[i]: + valid_mask[j] = False + + return group[valid_mask] + + # Apply validation per cell group - get valid indices + valid_indices = work_df.groupby('sample', group_keys=False).apply( + find_valid_rows, include_groups=False + ).index + + # Check if any data was removed + removed_count = len(work_df) - len(valid_indices) + if removed_count > 0: + removed_indices = set(work_df.index) - set(valid_indices) + removed_times = work_df.loc[list(removed_indices), 'Time_s'].tolist() + + warnings.warn( + f"Potentially corrupted .KD file detected: {self.path}. " + f"Time values are not monotonically increasing.", + UserWarning + ) + warnings.warn( + f"Removed {removed_count} corrupt timepoint(s) " + f"at indices {sorted(removed_indices)} with time values {removed_times}. " + f"These timepoints and their corresponding spectra have been excluded.", + UserWarning + ) + + # Filter work_df using valid indices + clean_df = work_df.loc[valid_indices] + + # Rebuild the clean data + clean_spectra_times = pd.Series(clean_df['Time_s'].values, name='Time (s)') + clean_samples_cell = pd.Series(clean_df['sample'].values, name='Cell') if samples_cell is not None else None + + # Rebuild spectra dataframe (transpose back) + wavelength_cols = [col for col in clean_df.columns if col not in ['sample', 'Time_s']] + clean_spectra = clean_df[wavelength_cols].T + clean_spectra.index = pd.Index(self.wavelength_range, name='Wavelength (nm)') + clean_spectra.columns = clean_spectra_times + + return clean_spectra, clean_spectra_times, clean_samples_cell + + return spectra, spectra_times, samples_cell + def _handle_cycletime(self) -> int | None: try: return int( @@ -150,6 +253,29 @@ def _handle_cycletime(self) -> int | None: except TypeError: return None + def _handle_samples_cell(self) -> pd.Series | None: + data_header = KDFile.samples_cell_header['header'] + spacing = KDFile.samples_cell_header['spacing'] + + header_idx = self.file_bytes.find(data_header) + if header_idx == -1: + return None + + samples_cell = [] + position = header_idx + spacing + + while True: + cell_name = self._parse_samples_cell(position) + if cell_name is None or not cell_name.startswith('SAMPLES_CELL'): + break + samples_cell.append(cell_name) + position += 30 # 2-byte prefix + 28-byte cell name (14 chars * 2) + + if samples_cell: + return pd.Series(samples_cell, name='Cell') + + return None + def _extract_data(self, header: dict, parse_func: callable) -> list: data_list = [] position = 0 @@ -181,3 +307,11 @@ def _parse_spectratimes(self, data_start: int) -> float: def _parse_cycletime(self, data_start: int) -> int: return int(struct.unpack_from(' str | None: + try: + data_end = data_start + 28 # 14 chars * 2 bytes per char + cell_name_bytes = self.file_bytes[data_start:data_end] + return cell_name_bytes.decode('utf-16-le').rstrip('\x00') + except (IndexError, UnicodeDecodeError): + return None