Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 57 additions & 22 deletions src/contraqctor/qc/harp/lickety_split.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from dataclasses import asdict, dataclass

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
Expand All @@ -8,6 +10,20 @@
from .harp_device import HarpDeviceTypeTestSuite


@dataclass(frozen=True)
class _LickDurationMetrics:
    """Immutable summary statistics for a lick-duration analysis.

    Produced by ``test_lick_duration`` and serialized with
    ``dataclasses.asdict`` for the test result payload/context. Duration
    fields are in seconds. Fields are ``None`` when no lick durations were
    available to compute them.
    """
    # Mean lick duration in seconds; None when there were no licks.
    mean: float | None
    # Standard deviation of lick durations in seconds; None when there were no licks.
    std: float | None
    # Fraction of licks outside the accepted duration limits (0.0 when no licks).
    percent_violations: float
    # Total number of lick durations measured.
    total_licks: int
    # Count of licks outside the limits (short + long violations).
    num_violations: int
    # Longest observed lick duration in seconds; None when there were no licks.
    max: float | None
    # Shortest observed lick duration in seconds; None when there were no licks.
    min: float | None
    # Count of licks longer than the upper duration limit.
    num_long_violations: int
    # Count of licks shorter than the lower duration limit.
    num_short_violations: int


class HarpLicketySplitTestSuite(HarpDeviceTypeTestSuite):
"""Test suite for Harp Lickety Split devices.

Expand Down Expand Up @@ -121,7 +137,25 @@ def test_lick_duration(self):
"""Tests for licks that are shorter than the expected duration."""
limits = (0.015, 1) # in seconds
lick = self._get_distinct_from_channel(self.data, self._target_channel)
first_lick_onset = lick[lick == 1].index[0]
lick_onsets = lick[lick == 1]
if len(lick_onsets) == 0:
metrics = _LickDurationMetrics(
mean=None,
std=None,
percent_violations=0.0,
total_licks=0,
num_violations=0,
max=None,
min=None,
num_long_violations=0,
num_short_violations=0,
)
metrics_dict = asdict(metrics)
return self.fail_test(
metrics_dict, "No lick onsets detected; unable to evaluate lick duration.", context=metrics_dict
)

first_lick_onset = lick_onsets.index[0]
lick = lick[first_lick_onset:]
lick_durations = lick[lick == 0].index - lick[lick == 1].index

Expand All @@ -141,26 +175,27 @@ def test_lick_duration(self):
long.set_ylabel("Count")

fig.tight_layout()
context = ContextExportableObj.as_context(fig)

metrics = {}
metrics["mean"] = np.mean(lick_durations) if len(lick_durations) > 0 else None
metrics["std"] = np.std(lick_durations) if len(lick_durations) > 0 else None
metrics["percent_violations"] = (
np.sum((lick_durations < limits[0]) | (lick_durations > limits[1])) / len(lick_durations)
if len(lick_durations) > 0
else 0.0
metrics = _LickDurationMetrics(
mean=float(np.mean(lick_durations)) if len(lick_durations) > 0 else None,
std=float(np.std(lick_durations)) if len(lick_durations) > 0 else None,
percent_violations=(
float(np.sum((lick_durations < limits[0]) | (lick_durations > limits[1])) / len(lick_durations))
if len(lick_durations) > 0
else 0.0
),
total_licks=len(lick_durations),
num_violations=int(np.sum((lick_durations < limits[0]) | (lick_durations > limits[1]))),
max=float(np.max(lick_durations)) if len(lick_durations) > 0 else None,
min=float(np.min(lick_durations)) if len(lick_durations) > 0 else None,
num_long_violations=int(np.sum(lick_durations > limits[1])),
num_short_violations=int(np.sum(lick_durations < limits[0])),
)
metrics["total_licks"] = len(lick_durations)
metrics["num_violations"] = int(np.sum((lick_durations < limits[0]) | (lick_durations > limits[1])))
metrics["max"] = np.max(lick_durations) if len(lick_durations) > 0 else None
metrics["min"] = np.min(lick_durations) if len(lick_durations) > 0 else None
metrics["num_long_violations"] = int(np.sum(lick_durations > limits[1]))
metrics["num_short_violations"] = int(np.sum(lick_durations < limits[0]))
context.update(metrics)

if metrics["num_long_violations"] > 0:
return self.warn_test(metrics, "Long lick duration violations detected.", context=context)
if metrics["percent_violations"] > 0.05:
return self.warn_test(metrics, "High number of lick duration violations (>5%).", context=context)
return self.pass_test(metrics, "Lick duration distribution within expected range.", context=context)
metrics_dict = asdict(metrics)
context = {**ContextExportableObj.as_context(fig), **metrics_dict}

if metrics.num_long_violations > 0:
return self.warn_test(metrics_dict, "Long lick duration violations detected.", context=context)
if metrics.percent_violations > 0.05:
return self.warn_test(metrics_dict, "High number of lick duration violations (>5%).", context=context)
return self.pass_test(metrics_dict, "Lick duration distribution within expected range.", context=context)
7 changes: 7 additions & 0 deletions tests/test_qc/harp/test_lickety_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,10 @@ def test_lick_duration(self, mock_lickety_split_device, mock_lickety_split_devic
assert result.status == Status.WARNING
assert result.message is not None
assert result.context is not None

def test_lick_duration_no_licks(self, mock_lickety_split_device_low_rate):
    """The suite reports a failure (with message and context) when no licks are present."""
    result = HarpLicketySplitTestSuite(mock_lickety_split_device_low_rate).test_lick_duration()
    # A device stream with no lick onsets must fail rather than crash.
    assert result.status == Status.FAILED
    assert result.message is not None
    assert result.context is not None
Loading
Loading