Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added test data/multi_cuvette_test_data.KD
Binary file not shown.
Binary file added test data/multi_cuvette_test_data_corrupted.KD
Binary file not shown.
54 changes: 54 additions & 0 deletions tests/test_import_kd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Tests for import_kd module."""

import warnings
from pathlib import Path

import pytest

from uv_pro.io.import_kd import KDFile


TEST_DATA_DIR = Path(__file__).parent.parent / "test data"


class TestKDFileCorruptionDetection:
    """Test detection of corrupted .KD files with non-monotonic time values."""

    @staticmethod
    def _corruption_warnings(file_path: Path) -> list[warnings.WarningMessage]:
        """Parse *file_path* with KDFile and return UserWarnings mentioning corruption.

        Shared helper so both tests capture and filter warnings identically.
        """
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            KDFile(file_path)  # result intentionally discarded; only warnings matter

        # Keep only UserWarnings whose message mentions corruption
        return [
            warning
            for warning in caught
            if issubclass(warning.category, UserWarning)
            and "corrupted" in str(warning.message).lower()
        ]

    def test_valid_multi_cuvette_file_no_warning(self):
        """Test that valid multi-cuvette file does not generate a warning."""
        file_path = TEST_DATA_DIR / "multi_cuvette_test_data.KD"

        found = self._corruption_warnings(file_path)

        assert len(found) == 0, (
            f"Expected no corruption warnings for valid file, "
            f"but got: {[str(w.message) for w in found]}"
        )

    def test_corrupted_multi_cuvette_file_generates_warning(self):
        """Test that corrupted multi-cuvette file generates a warning."""
        file_path = TEST_DATA_DIR / "multi_cuvette_test_data_corrupted.KD"

        found = self._corruption_warnings(file_path)

        assert len(found) >= 1, (
            "Expected at least one corruption warning for corrupted file"
        )
140 changes: 137 additions & 3 deletions uv_pro/io/import_kd.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

import struct
import warnings
from pathlib import Path

import pandas as pd
Expand Down Expand Up @@ -39,6 +40,8 @@ class KDFile:
The time values that each spectrum was captured.
cycle_time : int or None
The cycle time value (in seconds) for the experiment.
samples_cell : :class:`pandas.Series` or None
The cell/cuvette identifier for each spectrum.
"""

absorbance_data_header = {
Expand All @@ -55,6 +58,10 @@ class KDFile:
b'\x65\x00\x77\x00\x00\x00\x00\x00',
'spacing': 24,
}
samples_cell_header = {
'header': b'\x52\x00\x65\x00\x67\x00\x4e\x00\x61\x00\x6d\x00\x65\x00',
'spacing': 18,
}

def __init__(
self, path: Path, spectrometer_range: tuple[int, int] = (190, 1100)
Expand Down Expand Up @@ -82,7 +89,7 @@ def __init__(
)
self.spectrum_bytes_length = self._get_spectrum_bytes_length()
self.file_bytes = self._read_binary()
self.spectra, self.spectra_times, self.cycle_time = self.parse_kd()
self.spectra, self.spectra_times, self.cycle_time, self.samples_cell = self.parse_kd()

def _get_spectrum_bytes_length(self) -> int:
"""8 hex chars per wavelength."""
Expand All @@ -97,7 +104,7 @@ def _read_binary(self) -> bytes:

return file_bytes

def parse_kd(self) -> tuple[pd.DataFrame, pd.Series, int, pd.Series | None]:
    """
    Parse a .KD file and extract data.

    Returns
    -------
    spectra : :class:`pandas.DataFrame`
        The extracted spectra (wavelengths as index, times as columns).
    spectra_times : :class:`pandas.Series`
        The time values that each spectrum was captured.
    cycle_time : int
        The cycle time (in seconds) for the UV-vis experiment.
    samples_cell : :class:`pandas.Series` or None
        The cell/cuvette identifier for each spectrum (for multi-cuvette files).
    """
    cycle_time = self._handle_cycletime()
    spectra_times = self._handle_spectratimes()
    spectra = self._handle_spectra(spectra_times)
    samples_cell = self._handle_samples_cell()

    # Validate and fix corrupt timepoints (non-monotonic times within a cuvette
    # indicate file corruption; the bad spectra are dropped with a warning).
    spectra, spectra_times, samples_cell = self._validate_and_fix_data(
        spectra, spectra_times, samples_cell
    )

    return spectra, spectra_times, cycle_time, samples_cell

def _handle_spectra(self, spectra_times: pd.Series) -> pd.DataFrame:
def _spectra_dataframe(spectra: list, spectra_times: pd.Series) -> pd.DataFrame:
Expand All @@ -142,6 +158,93 @@ def _handle_spectratimes(self) -> pd.Series:

raise Exception('Error parsing file. No spectra times found.')

def _validate_and_fix_data(
    self,
    spectra: pd.DataFrame,
    spectra_times: pd.Series,
    samples_cell: pd.Series | None,
) -> tuple[pd.DataFrame, pd.Series, pd.Series | None]:
    """
    Validate that time values are monotonically increasing within each cuvette.

    If non-increasing time values are found (indicating file corruption),
    issue a warning and remove the corrupt timepoints and their spectra.

    Parameters
    ----------
    spectra : pd.DataFrame
        The spectra data with time as columns.
    spectra_times : pd.Series
        The time values for each spectrum.
    samples_cell : pd.Series or None
        The cell/cuvette identifier for each spectrum.

    Returns
    -------
    tuple[pd.DataFrame, pd.Series, pd.Series | None]
        The validated/fixed spectra, times, and samples_cell data.
    """
    # Build a working dataframe by transposing spectra (rows become spectra, columns become wavelengths)
    work_df = spectra.T.reset_index()
    # NOTE(review): assumes the transposed index column is named 'Time (s)'
    # (i.e. spectra.columns was built from the spectra_times Series) — if the
    # name differs, rename is a silent no-op and 'Time_s' lookups below fail.
    work_df.rename(columns={'Time (s)': 'Time_s'}, inplace=True)
    # Single-cuvette files have no cell metadata; treat all rows as one cell.
    cell_values = samples_cell if samples_cell is not None else pd.Series(['SAMPLES_CELL_1'] * len(work_df))
    # NOTE(review): assumes samples_cell has exactly one entry per spectrum —
    # a length mismatch would raise here; confirm against _handle_samples_cell.
    work_df.insert(0, 'sample', cell_values.values)

    def find_valid_rows(group: pd.DataFrame) -> pd.DataFrame:
        """Find rows where times are monotonically increasing."""
        times = group['Time_s'].values
        valid_mask = [True] * len(times)

        # Find reset points where time decreases
        for i in range(1, len(times)):
            if times[i] < times[i - 1]:
                # Mark all previous times >= current time as invalid
                # (keeps the post-reset run, drops the overlapping earlier rows)
                for j in range(i):
                    if times[j] >= times[i]:
                        valid_mask[j] = False

        return group[valid_mask]

    # Apply validation per cell group - get valid indices
    # NOTE(review): group_keys=False preserves the original row labels, so
    # .index is directly comparable with work_df.index below.
    # include_groups=False requires pandas >= 2.2 — confirm the project's pin.
    valid_indices = work_df.groupby('sample', group_keys=False).apply(
        find_valid_rows, include_groups=False
    ).index

    # Check if any data was removed
    removed_count = len(work_df) - len(valid_indices)
    if removed_count > 0:
        removed_indices = set(work_df.index) - set(valid_indices)
        removed_times = work_df.loc[list(removed_indices), 'Time_s'].tolist()

        warnings.warn(
            f"Potentially corrupted .KD file detected: {self.path}. "
            f"Time values are not monotonically increasing.",
            UserWarning
        )
        warnings.warn(
            f"Removed {removed_count} corrupt timepoint(s) "
            f"at indices {sorted(removed_indices)} with time values {removed_times}. "
            f"These timepoints and their corresponding spectra have been excluded.",
            UserWarning
        )

        # Filter work_df using valid indices
        clean_df = work_df.loc[valid_indices]

        # Rebuild the clean data
        clean_spectra_times = pd.Series(clean_df['Time_s'].values, name='Time (s)')
        clean_samples_cell = pd.Series(clean_df['sample'].values, name='Cell') if samples_cell is not None else None

        # Rebuild spectra dataframe (transpose back)
        wavelength_cols = [col for col in clean_df.columns if col not in ['sample', 'Time_s']]
        clean_spectra = clean_df[wavelength_cols].T
        # NOTE(review): assumes the surviving wavelength columns line up 1:1
        # with self.wavelength_range — TODO confirm against _handle_spectra.
        clean_spectra.index = pd.Index(self.wavelength_range, name='Wavelength (nm)')
        clean_spectra.columns = clean_spectra_times

        return clean_spectra, clean_spectra_times, clean_samples_cell

    # No corruption found: return the inputs untouched.
    return spectra, spectra_times, samples_cell

def _handle_cycletime(self) -> int | None:
try:
return int(
Expand All @@ -150,6 +253,29 @@ def _handle_cycletime(self) -> int | None:
except TypeError:
return None

def _handle_samples_cell(self) -> pd.Series | None:
    """
    Extract the cell/cuvette name for each spectrum, if present.

    Scans the file for the RegName header and reads consecutive
    'SAMPLES_CELL*' entries until a non-matching or undecodable one
    is reached.

    Returns
    -------
    pd.Series or None
        A Series named 'Cell' with one entry per spectrum, or None
        when the file carries no cell metadata.
    """
    marker = KDFile.samples_cell_header['header']
    marker_pos = self.file_bytes.find(marker)
    if marker_pos == -1:
        # Single-cuvette files have no RegName block.
        return None

    names: list[str] = []
    cursor = marker_pos + KDFile.samples_cell_header['spacing']

    # Read entries until one fails to decode or is not a cell name.
    while (entry := self._parse_samples_cell(cursor)) is not None and entry.startswith('SAMPLES_CELL'):
        names.append(entry)
        cursor += 30  # 2-byte prefix + 28-byte cell name (14 chars * 2)

    return pd.Series(names, name='Cell') if names else None

def _extract_data(self, header: dict, parse_func: callable) -> list:
data_list = []
position = 0
Expand Down Expand Up @@ -181,3 +307,11 @@ def _parse_spectratimes(self, data_start: int) -> float:

def _parse_cycletime(self, data_start: int) -> int:
    """Read a little-endian double at *data_start* and truncate it to an int."""
    (raw_seconds,) = struct.unpack_from('<d', self.file_bytes, data_start)
    return int(raw_seconds)

def _parse_samples_cell(self, data_start: int) -> str | None:
    """
    Decode one UTF-16-LE cell name (14 chars / 28 bytes) at *data_start*.

    Parameters
    ----------
    data_start : int
        Byte offset of the cell-name field in the file.

    Returns
    -------
    str or None
        The decoded name with trailing NULs stripped, or None when the
        bytes do not decode as UTF-16-LE (e.g. truncated at end of file,
        which leaves an odd number of bytes).
    """
    data_end = data_start + 28  # 14 chars * 2 bytes per char
    # bytes slicing never raises IndexError — it truncates silently — so the
    # original `except (IndexError, ...)` branch was unreachable; only the
    # decode can fail.
    cell_name_bytes = self.file_bytes[data_start:data_end]
    try:
        return cell_name_bytes.decode('utf-16-le').rstrip('\x00')
    except UnicodeDecodeError:
        return None