Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/doc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ jobs:
- uses: astral-sh/setup-uv@v7
- run: uv pip install --quiet --system .[doc]
- run: gedai sys-info --developer
- name: Download data
run: |
python -c 'import mne; mne.datasets.fetch_fsaverage()'
- run: make -C doc html
- name: Prune sphinx environment
run: rm -R ./doc/_build/html/.doctrees
Expand Down
6 changes: 6 additions & 0 deletions .github/workflows/pytest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ jobs:
- uses: astral-sh/setup-uv@v7
- run: uv pip install --quiet --system .[test]
- run: gedai sys-info --developer
- name: Download testing data
run: |
python -c 'import mne; mne.datasets.testing.data_path(verbose=True)'
- run: pytest gedai --cov=gedai --cov-report=xml --cov-config=pyproject.toml
- uses: codecov/codecov-action@v5
with:
Expand Down Expand Up @@ -63,6 +66,9 @@ jobs:
uv pip install --quiet --system .[test]
uv pip install --quiet --system --upgrade --prerelease allow --only-binary :all: -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple numpy
- run: gedai sys-info --developer
- name: Download testing data
run: |
python -c 'import mne; mne.datasets.testing.data_path(verbose=True)'
- run: pytest gedai --cov=gedai --cov-report=xml --cov-config=pyproject.toml
- uses: codecov/codecov-action@v5
with:
Expand Down
78 changes: 70 additions & 8 deletions gedai/conftest.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,83 @@
from __future__ import annotations
import os
import warnings

from typing import TYPE_CHECKING
import pytest
from mne import set_log_level as set_log_level_mne

from .utils.logs import logger
from gedai import set_log_level
from gedai.utils.logs import logger

if TYPE_CHECKING:
import pytest


def pytest_configure(config: pytest.Config) -> None:
def pytest_configure(config):
"""Configure pytest options."""
config.addinivalue_line("usefixtures", "matplotlib_config")

warnings_lines = r"""
error::
# We use matplotlib agg backend to avoid any window to pop up during tests.
ignore:Matplotlib is currently using agg:UserWarning
# Pytest internals
ignore:Use setlocale.*instead:DeprecationWarning
ignore:datetime\.datetime\.utcnow.*is deprecated.*:DeprecationWarning
ignore:datetime\.datetime\.utcfromtimestamp.*is deprecated.*:DeprecationWarning
# Joblib
ignore:ast\.Num is deprecated.*:DeprecationWarning
ignore:Attribute n is deprecated.*:DeprecationWarning
# MNE
ignore:Python 3.14 will, by default, filter extracted tar.*:DeprecationWarning
"""
for warning_line in warnings_lines.split("\n"):
warning_line = warning_line.strip()
if warning_line and not warning_line.startswith("#"):
config.addinivalue_line("filterwarnings", warning_line)
# setup logging

logger.propagate = True
set_log_level_mne("WARNING")
set_log_level("WARNING")


@pytest.fixture(scope="session")
def matplotlib_config():
"""Configure matplotlib for viz tests."""
import matplotlib
from matplotlib import cbook

# Allow for easy interactive debugging with a call like:
#
# $ PYCROSTATES_MPL_TESTING_BACKEND=Qt5Agg pytest mne/viz/tests/test_raw.py -k annotation -x --pdb # noqa: E501
#
try:
want = os.environ["PYCROSTATES_MPL_TESTING_BACKEND"]
except KeyError:
want = "agg" # don't pop up windows
with warnings.catch_warnings(record=True): # ignore warning
warnings.filterwarnings("ignore")
matplotlib.use(want, force=True)
import matplotlib.pyplot as plt

assert plt.get_backend() == want
# overwrite some params that can horribly slow down tests that
# users might have changed locally (but should not otherwise affect
# functionality)
plt.ioff()
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.raise_window"] = False

# Make sure that we always reraise exceptions in handlers
orig = cbook.CallbackRegistry

class CallbackRegistryReraise(orig):
def __init__(self, exception_handler=None, signals=None):
super().__init__(exception_handler)

cbook.CallbackRegistry = CallbackRegistryReraise


@pytest.fixture(autouse=True)
def close_all():
"""Close all matplotlib plots, regardless of test status."""
# This adds < 1 µS in local testing
import matplotlib.pyplot as plt

yield
plt.close("all")
Binary file not shown.
Binary file added gedai/data/fsavLEADFIELD_4_GEDAI-cov.fif
Binary file not shown.
Binary file removed gedai/data/fsavLEADFIELD_4_GEDAI.mat
Binary file not shown.
201 changes: 102 additions & 99 deletions gedai/gedai/covariances.py
Original file line number Diff line number Diff line change
@@ -1,109 +1,112 @@
import h5py
import os

import mne
import numpy as np
import sklearn.metrics

from ..utils._checks import _check_type


def _ensure_cov(reference_cov):
_check_type(reference_cov, (str, mne.Covariance), "reference_cov")
if isinstance(reference_cov, str):
if reference_cov == "leadfield":
reference_cov = mne.read_cov(
os.path.join(
os.path.dirname(__file__), "../data/fsavLEADFIELD_4_GEDAI-cov.fif"
)
)
else:
raise ValueError(
"Reference covariance must be 'leadfield'"
f"got '{reference_cov}' instead."
)
return reference_cov


def _pick_cov(cov, ch_names):
cov_ch_names = cov.ch_names

picks_cov = []
picks_ch_names = []
for cov_name in cov_ch_names:
for ch_name in ch_names:
if ch_name.lower() == cov_name.lower():
picks_cov.append(cov_name)
picks_ch_names.append(ch_name)
break
if len(picks_cov) == 0:
raise ValueError(
"No matching channel names found between inst and cov.\n"
f"Available channels in covariance are {cov_ch_names}.\n"
f"but instance has channels {ch_names}."
)
if len(picks_cov) < len(ch_names):
raise ValueError(
"Only a subset of channels in the instance are present"
" in the covariance.\n"
f"Use inst.pick_channels({picks_ch_names}) to select only the channels"
f" that are in the covariance or provide a covariance that contains"
f" all channels in the instance."
)
cov = cov.copy().pick_channels(picks_cov)
# Update the channel names in the covariance to match those in the instance
cov.update(names=ch_names)
return cov


def _compute_distance_cov(raw):
ch_positions = [raw.info["chs"][i]["loc"][:3] for i in range(raw.info["nchan"])]
ch_distance_matrix = sklearn.metrics.pairwise_distances(
ch_positions, metric="euclidean"
def compute_covariance_from_forward(forward):
"""Compute covariance matrix from the leadfield of a forward solution.

Parameters
----------
forward : instance of mne.Forward
The forward solution from which to compute the covariance matrix.

Returns
-------
cov : instance of mne.Covariance
The computed covariance matrix.
"""
_check_type(forward, (mne.Forward,), "forward")

data = forward["sol"]["data"] @ forward["sol"]["data"].T
ch_names = forward["info"]["ch_names"]
bads = forward["info"]["bads"]
nfree = len(ch_names) # TODO: fix
cov = mne.Covariance(
data, names=ch_names, bads=bads, projs=[], nfree=nfree, verbose=None
)
cov = 1 - ch_distance_matrix
return cov


def _compute_refcov(inst, mat):
inst_ch_names = inst.info["ch_names"]

with h5py.File(mat, "r") as f:
leadfield_data = f["leadfield4GEDAI"]
# ch_names
leadfield_channel_data = leadfield_data["electrodes"]
leadfield_ch_names = [
f[ref[0]][()].tobytes().decode("utf-16le").lower()
for ref in leadfield_channel_data["Name"]
]
# leadfield matrix
leadfield_gain_matrix = leadfield_data["gram_matrix_avref"]
leadfield_gain_matrix = np.array(leadfield_gain_matrix).T

# Two-pass matching: exact first, then substring
ch_indices = []
ch_names = []
matched_inst_indices = set()
match_types = [] # Track match quality for logging

# Pass 1: Exact matching (case-insensitive)
for inst_idx, inst_ch_name in enumerate(inst_ch_names):
for leadfield_ch_index, leadfield_ch_name in enumerate(leadfield_ch_names):
if inst_ch_name.lower() == leadfield_ch_name.lower():
ch_indices.append(leadfield_ch_index)
ch_names.append(leadfield_ch_name)
matched_inst_indices.add(inst_idx)
match_types.append("exact")
break # Move to next inst channel after finding exact match

# Pass 2: Substring matching for unmatched channels
for inst_idx, inst_ch_name in enumerate(inst_ch_names):
if inst_idx in matched_inst_indices:
continue # Already matched exactly

inst_lower = inst_ch_name.lower()
best_match = None
best_match_length = 0

for leadfield_ch_index, leadfield_ch_name in enumerate(leadfield_ch_names):
leadfield_lower = leadfield_ch_name.lower()

# Check if leadfield name is substring of inst name
# or inst name is substring of leadfield name
if leadfield_lower in inst_lower or inst_lower in leadfield_lower:
# Prefer longer matches to avoid false positives
match_length = min(len(leadfield_lower), len(inst_lower))
if match_length > best_match_length:
best_match = leadfield_ch_index
best_match_length = match_length

if best_match is not None:
ch_indices.append(best_match)
ch_names.append(leadfield_ch_names[best_match])
matched_inst_indices.add(inst_idx)
match_types.append("substring")

# Validation and warnings
n_inst_channels = len(inst_ch_names)
n_matched = len(ch_indices)

if n_matched == 0:
raise ValueError(
f"No electrode matches found between data and leadfield "
f"template.\n"
f"Your channels: {inst_ch_names[:10]}\n"
f"Leadfield channels: {leadfield_ch_names[:10]}\n"
f"Please check that your electrode names follow standard "
f"conventions (e.g., Fp1, Fp2, F3, F4)."
)
def compute_covariance_from_channel_positions(info):
"""Compute covariance matrix from channel positions.

# Always warn if any channels didn't match
if n_matched < n_inst_channels:
import warnings

unmatched = [
inst_ch_names[i]
for i in range(n_inst_channels)
if i not in matched_inst_indices
]
n_exact = match_types.count("exact")
n_substring = match_types.count("substring")

warnings.warn(
f"Electrode matching: {n_matched}/{n_inst_channels} channels "
f"matched ({n_exact} exact, {n_substring} substring). "
f"Unmatched channels ({len(unmatched)}): "
f"{unmatched}",
UserWarning,
stacklevel=2,
)
Parameters
----------
info : instance of mne.Info
The info structure containing channel information.

refCOV = leadfield_gain_matrix[np.ix_(ch_indices, ch_indices)]
return (refCOV, ch_names)
Returns
-------
cov : instance of mne.Covariance
The computed covariance matrix.
"""
ch_positions = [info["chs"][i]["loc"][:3] for i in range(info["nchan"])]
ch_distance_matrix = sklearn.metrics.pairwise_distances(
ch_positions, metric="euclidean"
)
nonzero = ch_distance_matrix[ch_distance_matrix > 0]
ell = np.median(nonzero) if nonzero.size else 1.0
sigma2 = 1.0
eps = 1e-6

data = sigma2 * np.exp(-(ch_distance_matrix**2) / (2 * ell**2))
data += eps * np.eye(data.shape[0])

ch_names = info["ch_names"]
bads = info["bads"]
nfree = len(ch_names) # TODO: fix
cov = mne.Covariance(data, ch_names, bads, nfree=nfree, projs=[], verbose=None)
return cov
Loading
Loading