Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 62 additions & 48 deletions src/fpcm_detector.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
"""
fpcm_detector.py
----------------
Minimal dependency‑light implementation of the Fast Parametric Curve
Matching (FPCM) (Kleeva et al., 2022) for MNE‑Python Raw objects.

Authors : Daria Kleeva, Alexei Ossadtchi
Email: dkleeva@gmail.com
"""
from __future__ import annotations
import warnings
warnings.filterwarnings("ignore")
import numpy as np
from typing import Dict, List, Tuple
import mne


from tqdm import tqdm

# =============================================================================
# CONSTANTS
# =============================================================================
PEAK_HW_MS = 35.0
WAVE_HW_MS = 60.0
WAVE_POWER = 3.0
BCK_COEFF = 3.0
ERR_PEAK_TH = 0.3
ERR_WAVE_TH = 0.7
HIT_THRESHOLD = 4

SLOPES_HEIGHTS_RATIO = 0.3
SLOPE_TOLERANCE = 2
APEX_WAVE_RATIO1 = 2
APEX_WAVE_RATIO2 = 5

SEPARATION_INTERVAL = 40 #ms
# --------------------------------------------------------------------
# 1. Spline model
# --------------------------------------------------------------------
Expand Down Expand Up @@ -52,7 +60,7 @@ def _build_model(
t2 = np.arange(-wave_hw_s, wave_hw_s + 1, 1, dtype=float)

# Segment design matrices
A1 = np.c_[t1_l, np.ones_like(t1_l)] # left slope
A1 = np.c_[t1_l, np.ones_like(t1_l)] # left slope
A2 = np.c_[t1_r, np.ones_like(t1_r)] # right slope
A3 = np.c_[np.abs(t2) ** wave_power, np.ones_like(t2)] # slow wave

Expand Down Expand Up @@ -127,8 +135,7 @@ def _convolve_filters(X: np.ndarray, piBp: np.ndarray) -> np.ndarray:
- C[5]: intercept of the slow wave segment
"""
flipped_model = np.fliplr(piBp)
C = np.vstack([np.convolve(f, X) for f in flipped_model])
C = C[:,:len(X)]
C = np.vstack([np.convolve(f, X, mode='same') for f in flipped_model])
return C
Comment thread
Poncharm marked this conversation as resolved.

# --------------------------------------------------------------------
Expand All @@ -140,8 +147,10 @@ def _apply_predicates(
bkg_left: np.ndarray,
bkg_right: np.ndarray,
flat_peak: np.ndarray,
bkg_coeff: float,
slope_tol: float = 2.0
bkg_coeff: float = BCK_COEFF,
slope_tol: float = SLOPE_TOLERANCE,
apex_wave_ratio1: float = APEX_WAVE_RATIO1,
apex_wave_ratio2: float = APEX_WAVE_RATIO2,
) -> np.ndarray:
"""
Apply a set of logical rules to identify spike-like waveforms.
Expand Down Expand Up @@ -175,7 +184,7 @@ def _apply_predicates(
spike_above = (HC > 0) & (C[4,:] > 0)
wave_below = (C[5,:]<0)

ratio_ok_neg = (np.abs(HC) > 2*np.abs(C[5,:])) & (np.abs(HC) < 5*np.abs(C[5,:]))
ratio_ok_neg = (np.abs(HC) > apex_wave_ratio1*np.abs(C[5,:])) & (np.abs(HC) < apex_wave_ratio2*np.abs(C[5,:]))
ratio_ok_pos = ratio_ok_neg.copy()

slopes_equal = np.abs(np.abs(C[0,:]/C[2,:]) - 1) < slope_tol
Expand All @@ -191,13 +200,14 @@ def _apply_predicates(
def detect_spikes_fpcm(
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (code-quality): Low code quality found in detect_spikes_fpcm - 21% (low-code-quality)


ExplanationThe quality score for this function is below the quality threshold of 25%.
This score is a combination of the method length, cognitive complexity and working memory.

How can you solve this?

It might be worth refactoring this function to make it shorter and more readable.

  • Reduce the function length by extracting pieces of functionality out into
    their own functions. This is the most important thing you can do - ideally a
    function should be less than 10 lines.
  • Reduce nesting, perhaps by introducing guard clauses to return early.
  • Ensure that variables are tightly scoped, so that code using related concepts
    sits together within the function rather than being scattered.

raw: mne.io.BaseRaw,
*,
peak_hw_ms: float = 35.0,
wave_hw_ms: float = 60.0,
wave_power: int = 3,
bkg_coeff: float = 3.0,
err_peak_th: float = 0.3,
err_wave_th: float = 0.7,
hit_threshold: int = 4,
peak_hw_ms: float = PEAK_HW_MS,
wave_hw_ms: float = WAVE_HW_MS,
wave_power: int = WAVE_POWER,
bkg_coeff: float = BCK_COEFF,
err_peak_th: float = ERR_PEAK_TH,
err_wave_th: float = ERR_WAVE_TH,
hit_threshold: int = HIT_THRESHOLD,
slopes_heights_ratio: float = SLOPES_HEIGHTS_RATIO,
) -> Dict[str, object]:
"""
Run FPCM spike detection on an MNE‑Raw object.
Expand Down Expand Up @@ -249,7 +259,7 @@ def detect_spikes_fpcm(

box = np.ones(T) / T

for ch in range(n_ch):
for ch in tqdm(range(n_ch)):
x = data[ch,:]
C = _convolve_filters(x, piBp) # (6, n_times)
coeffs.append(C)
Expand All @@ -259,49 +269,54 @@ def detect_spikes_fpcm(
a0 = Bp[0,:] @ C
b0 = Bp[peak_hw_s,:] @ C
c0 = Bp[2*peak_hw_s,:] @ C
flat = np.abs((a0 - b0) / (c0 - b0)) - 1 < 0.3

inst_power = np.convolve(box, x**2)
inst_power = inst_power[:len(x)]
bkg_left = np.sqrt(np.concatenate([np.zeros(T), inst_power[:-T]]))
bkg_right = np.sqrt(np.concatenate([inst_power[T:], np.zeros(T)]))

flat = np.abs((a0 - b0) / (c0 - b0)) - 1 < slopes_heights_ratio

inst_power = np.convolve(box, x**2, mode='same')
bkg_left = np.sqrt(np.roll(inst_power, shift=T))
bkg_left[:T//2] = 0
bkg_right = np.sqrt(np.roll(inst_power, shift=-T))
bkg_right[-T//2:] = 0

morph_ok = _apply_predicates(C, hc, bkg_left, bkg_right, flat, bkg_coeff)

cand_idx = np.where(morph_ok)[0]
for t in cand_idx:
if t < T:
continue
rng = slice(t-T+1, t+1)
synth = Bp @ C[:, t]
if T % 2 == 0:
st, en = t - T // 2, t + T // 2
else:
st, en = t - T // 2, t + T // 2 + 1

synth = Bp @ C[:, t]
synthdic[t] = synth

err_peak[ch, t] = np.linalg.norm(x[rng][:2*peak_hw_s] - synth[:2*peak_hw_s]) \
/ np.linalg.norm(x[rng][:2*peak_hw_s])
err_wave[ch, t] = np.linalg.norm(x[rng][2*peak_hw_s:] - synth[2*peak_hw_s:]) \
/ np.linalg.norm(x[rng][2*peak_hw_s:])
if st >= 0 and en <= len(x):
rng = slice(st, en)
err_peak[ch, t] = np.linalg.norm(x[rng][:2*peak_hw_s] - synth[:2*peak_hw_s]) \
/ np.linalg.norm(x[rng][:2*peak_hw_s])
err_wave[ch, t] = np.linalg.norm(x[rng][2*peak_hw_s:] - synth[2*peak_hw_s:]) \
/ np.linalg.norm(x[rng][2*peak_hw_s:])
else:
err_peak[ch, t] = 1
err_wave[ch, t] = 1

synth_all.append(synthdic)
good = (err_peak[ch] < err_peak_th) & (err_wave[ch] < err_wave_th)
hit_mask[ch] = morph_ok & good


acc_hits = hit_mask.sum(0)
spike_centres = np.where(acc_hits >= hit_threshold)[0]

min_separation_samp = int(round(0.040 * sfreq))
min_separation_samp = int(round(SEPARATION_INTERVAL / 1000 * sfreq))
suppressed = []
last_t = -np.inf
for t in sorted(spike_centres):
if t - last_t >= min_separation_samp:
suppressed.append(t)
last_t = t

# unique_peaks = np.array(sorted(set(spike_centres)))
unique_peaks = np.array(suppressed)
hits_selection=hit_mask[:,unique_peaks]
unique_peaks = unique_peaks-T+peak_hw_s

hits_selection = hit_mask[:,unique_peaks]
unique_peaks = unique_peaks-T//2+peak_hw_s

events = np.column_stack([
unique_peaks + raw.first_samp,
Expand All @@ -320,4 +335,3 @@ def detect_spikes_fpcm(
peak_hw_s = peak_hw_s,
wave_hw_s = wave_hw_s
)