Skip to content
Open

Dev #13

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
3 changes: 3 additions & 0 deletions results/plot_movie_concepts.html
Git LFS file not shown
Git LFS file not shown
7 changes: 4 additions & 3 deletions src/brain_decoding/config/save_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@

config.data.result_path = str(RESULT_PATH)
config.data.spike_path = str(DATA_PATH)
config.data.lfp_path = "undefined"
config.data.lfp_path = None
config.data.lfp_data_mode = "sf2000-bipolar-region-clean"
config.data.spike_data_mode = "notch CAR-quant-neg"
config.data.spike_data_mode_inference = "notch CAR-quant-neg"
config.data.spike_data_mode = "notch"
config.data.spike_data_mode_inference = "notch"
config.data.spike_data_sd = [3.5]
config.data.spike_data_sd_inference = 3.5
config.data.model_aggregate_type = "sum"
Expand All @@ -68,4 +68,5 @@
config.data.movie_sampling_rate = 30
config.data.filter_low_occurrence_samples = True

# TO DO: fix pydantic export json error.
# config.export_config(CONFIG_FILE_PATH)
2 changes: 1 addition & 1 deletion src/brain_decoding/hp_search2.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def pipeline(config):
patient_list = ["566"]
sd_list = [[4, 3.5]]
# data_list = ['notch CAR4.5', 'notch CAR3.5', 'notch CAR4.5', 'notch CAR4', 'notch CAR3.5', 'notch CAR3.5']
data_list = ["notch CAR-quant-neg"]
data_list = ["notch"]
for patient, sd, dd in zip(patient_list, sd_list, data_list):
for data_type in ["clusterless"]:
root_path = os.path.dirname(os.path.abspath(__file__))
Expand Down
6 changes: 2 additions & 4 deletions src/brain_decoding/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from brain_decoding.config.save_config import config
from brain_decoding.param.base_param import device
from brain_decoding.trainer import Trainer
from brain_decoding.utils.analysis import concept_frequency
from brain_decoding.utils.initializer import initialize_dataloaders, initialize_evaluator, initialize_model

# torch.autograd.set_detect_anomaly(True)
Expand All @@ -29,6 +28,7 @@
def set_config(
config_file: Union[str, Path, PipelineConfig],
patient_id: int,
experiment_name: str,
train_phases: Union[List[str], str],
test_phases: Union[List[str], str],
spike_data_sd: Union[List[float], float, None] = None,
Expand All @@ -54,9 +54,7 @@ def set_config(
config = PipelineConfig.read_config(config_file)

config.experiment["patient"] = patient_id
# config.experiment.name = "8concepts"
config.experiment.name = "twilight_merged"
# config.experiment.name = "twilight_vs_24"
config.experiment.name = experiment_name

config.experiment.train_phases = train_phases
config.experiment.ensure_list("train_phases")
Expand Down
12 changes: 6 additions & 6 deletions src/brain_decoding/main_augment2.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ def pipeline(config):
sd_list = [[3.5, 4], [3.5, 4], [3.5, 4], [3.5, 4]] # , [4, 3, 5], [4, 3, 5]]
# data_list = ['notch CAR4.5', 'notch CAR3.5', 'notch CAR4.5', 'notch CAR4', 'notch CAR3.5', 'notch CAR3.5']
data_list = [
"notch CAR-quant-neg",
"notch CAR-quant-neg",
"notch CAR-quant-neg",
"notch CAR-quant-neg",
"notch CAR-quant-neg",
"notch CAR-quant-neg",
"notch",
"notch",
"notch",
"notch",
"notch",
"notch",
]
early_stop = [100, 100, 100, 50, 50, 75]
for patient, sd, dd in zip(patient_list, sd_list, data_list):
Expand Down
2 changes: 1 addition & 1 deletion src/brain_decoding/param/param_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
"i728": 500,
}
# LABELS = ['LosAngeles', 'BombAttacks', 'Whitehouse', 'CIA/FBI', 'Hostage', 'Handcuff', 'Jack', 'Chloe', 'Bill', 'A. Fayed', 'A. Amar', 'President']
LABELS = [
MOVIE24_LABELS = [
"WhiteHouse",
"CIA",
"Hostage",
Expand Down
12 changes: 6 additions & 6 deletions src/brain_decoding/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,12 +1141,12 @@ def check_avg_score(config, phase="recall1"):
sd_list = [3.5, 3.5, 3.5, 3.5]
# data_list = ['notch CAR4.5', 'notch CAR3.5', 'notch CAR4.5', 'notch CAR4', 'notch CAR3.5', 'notch CAR3.5']
data_list = [
"notch CAR-quant-neg",
"notch CAR-quant-neg",
"notch CAR-quant-neg",
"notch CAR-quant-neg",
"notch CAR-quant-neg",
"notch CAR-quant-neg",
"notch",
"notch",
"notch",
"notch",
"notch",
"notch",
]
early_stop = [100, 100, 100, 50, 50, 75]
for p, sd, dd in zip(patient_list, sd_list, data_list):
Expand Down
65 changes: 43 additions & 22 deletions src/brain_decoding/utils/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
import warnings
from collections import defaultdict
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import networkx as nx
Expand Down Expand Up @@ -163,7 +163,7 @@ def prediction_curve(
axes[i].set_ylim([y_min, y_max])
axes[i].set_title(labels[i], fontsize=14)

if sleep_score:
if sleep_score is not None:
# Assign a unique color for each unique sleep stage
unique_stages = sleep_score["Score"].unique()
stage_colors = sb.color_palette("Set2", len(unique_stages))
Expand Down Expand Up @@ -196,18 +196,23 @@ def prediction_curve(


def stage_box_plot(
predictions: np.ndarray, sleep_score: pd.DataFrame, labels: List[str], save_figure_name: str
predictions: np.ndarray,
sleep_score: pd.DataFrame,
labels: List[str],
save_figure_name: str,
prediction_thresh: Optional[float],
) -> None:
"""
Plot violin plots with swarms overlaid for each sleep stage, with a separate subplot for each label.
Plot box plots with swarms overlaid for each sleep stage, with a separate subplot for each label.
Limit the number of swarm points per stage for performance improvement and add stage length to the label.

Parameters:
- predictions (np.ndarray): n by m array of predictions.
- sleep_score (pd.DataFrame): n by 2 DataFrame with sleep stage (column 0) and start index (column 1).
- labels (List[str]): List of labels for each prediction column.
- save_figure_name (str): The file path to save the plot.
- sampling_rate (int): The sampling rate of the data (default is 4 Hz).
- prediction_thresh Optional[float]: select concepts (column of predictions) with mean prediction larger than the
threshold.

Returns:
- None: The function saves the figure with subplots to the specified output file.
Expand All @@ -217,6 +222,10 @@ def stage_box_plot(
warnings.warn("sleep_score is None!")
return

if prediction_thresh:
predictions = predictions[:, np.nanmean(predictions, axis=0) > prediction_thresh]
labels = labels[np.nanmean(predictions, axis=0) > prediction_thresh]

n_samples, n_labels = predictions.shape

# Create subplots for each label (column of predictions)
Expand Down Expand Up @@ -387,13 +396,13 @@ def correlation_heatmap(data: np.ndarray, column_labels: List[str], save_figure_
def correlation_heatmap_by_stage(
predictions: np.ndarray[float], labels: List[str], sleep_score: pd.DataFrame, result_path: str
) -> None:
for i, (stage_label, start_index, end_index) in enumerate(
for i, (i_stage, stage_label, start_index, end_index) in enumerate(
sleep_stage_iterator(sleep_score, predictions.shape[0], SLEEP_STAGE_THRESH)
):
predictions_stage = predictions[start_index:end_index, :]
stage_label = stage_label.replace("/", "")
file_extension = os.path.splitext(result_path)[1]
figure_name = result_path.replace(file_extension, f"_{i}_{stage_label}{file_extension}")
figure_name = result_path.replace(file_extension, f"_{i_stage}_{stage_label}{file_extension}")
correlation_heatmap(predictions_stage, labels, figure_name)


Expand All @@ -416,7 +425,7 @@ def multi_facet_correlation_heatmap(
warnings.warn("sleep score is None")
return

sleep_stages, correlation_matrix = get_correlation_matrix_by_stage(predictions, sleep_score)
sleep_stages, correlation_matrix, sleep_stages_index = get_correlation_matrix_by_stage(predictions, sleep_score)

num_label, _, num_stage = correlation_matrix.shape
if len(sleep_stages) != num_stage:
Expand All @@ -430,7 +439,8 @@ def multi_facet_correlation_heatmap(

# Organize plots in a grid with 5 heatmaps per row with additional histgram
num_rows = num_stage // 5 + 1 # Calculate the number of rows needed (5 per row)
fig, axes = plt.subplots(num_rows, 5, figsize=(25, 5 * num_rows), constrained_layout=True)
# fig, axes = plt.subplots(num_rows, 5, figsize=(25, 5 * num_rows), constrained_layout=True)
fig, axes = plt.subplots(num_rows, 5, figsize=(25, 5 * num_rows))

# Flatten axes for easier iteration if there are multiple rows
if num_rows > 1:
Expand Down Expand Up @@ -458,6 +468,7 @@ def multi_facet_correlation_heatmap(
xticklabels=labels,
yticklabels=labels,
ax=ax,
square=True,
)

# Annotate only the lower triangle
Expand All @@ -472,7 +483,11 @@ def multi_facet_correlation_heatmap(
color="black" if correlation_matrix[i, j, stage] < 0 else "white",
)

ax.set_title(f"Stage {stage + 1} - {sleep_stages[stage]}")
ax.set_title(
f"Stage {sleep_stages_index[stage]}: {sleep_stages[stage]}", fontdict={"fontsize": 13, "weight": 500}
)
ax.tick_params(axis="x", rotation=45, labelsize=12)
ax.tick_params(axis="y", rotation=45, labelsize=12)

# Collect correlation values by unique label
for i in range(num_label):
Expand All @@ -491,24 +506,28 @@ def multi_facet_correlation_heatmap(
edgecolor=None,
)

ax_hist.legend(title="")
ax_hist.set_title(f"Histogram of Correlations")
ax_hist.set_aspect(0.35)
ax_hist.legend(title="", frameon=False, handletextpad=-6, handlelength=6.5, loc="best", fontsize=10)
ax_hist.set_title(f"Histogram of Correlations", fontdict={"fontsize": 13, "weight": 500})
ax_hist.set_xlabel("Correlation Value")
ax_hist.set_ylabel("Density")

# Remove any empty subplots
for idx in range(num_stage + 1, num_rows * 5):
fig.delaxes(axes[idx])

plt.subplots_adjust(wspace=0.05, hspace=0.05)

plt.tight_layout()
plt.savefig(result_path, bbox_inches="tight")
# plt.savefig(result_path, bbox_inches="tight")
plt.savefig(result_path)
plt.show()


def get_correlation_matrix_by_stage(
predictions: np.ndarray[float],
sleep_score: pd.DataFrame,
) -> Tuple[List[str], np.ndarray[float, Any]]:
) -> Tuple[List[str], np.ndarray[float, Any], List[int]]:
"""
Calculate correlation matrices for different sleep stages.

Expand All @@ -517,21 +536,23 @@ def get_correlation_matrix_by_stage(
sleep_score (pd.DataFrame): DataFrame containing sleep stage labels and start/end indices.

Returns:
Tuple[List[str], np.ndarray]: A list of stage labels and a 3D numpy array of correlation matrices.
Tuple[List[str], np.ndarray, List[int]]: A list of stage labels and a 3D numpy array of correlation matrices.
"""

correlation_matrices = []
sleep_stages = []
for i, (stage_label, start_index, end_index) in enumerate(
sleep_stages_index = []
for i, (i_stage, stage_label, start_index, end_index) in enumerate(
sleep_stage_iterator(sleep_score, predictions.shape[0], SLEEP_STAGE_THRESH)
):
predictions_stage = predictions[start_index:end_index, :]
sleep_stages.append(stage_label.replace("/", "-"))
sleep_stages_index.append(i_stage)
corr_matrix = np.corrcoef(predictions_stage, rowvar=False)
correlation_matrices.append(corr_matrix[:, :, np.newaxis])

correlation_matrix = np.concatenate(correlation_matrices, axis=2)
return sleep_stages, correlation_matrix
return sleep_stages, correlation_matrix, sleep_stages_index


def prediction_heatmap(predictions: np.ndarray[float], events: Events, title: str, file_path: str):
Expand Down Expand Up @@ -587,6 +608,7 @@ def smooth_columns(data: np.ndarray[float], window_size: int = 5) -> np.ndarray[
def combine_continuous_scores(df: pd.DataFrame) -> pd.DataFrame:
"""
Combine rows with continuous same values in the 'Score' column and keep the first value in the 'start_index' column.
This is used to combine continuous sleep stages in sleep score.

Parameters:
- df (pd.DataFrame): A DataFrame with 'Score' and 'start_index' columns.
Expand Down Expand Up @@ -617,7 +639,7 @@ def sleep_stage_iterator(sleep_score: pd.DataFrame, last_index: int, duration_th
next_start_index = last_index

if next_start_index - start_index > duration_thresh * PREDICTION_FS:
yield stage_label, start_index, next_start_index
yield i, stage_label, start_index, next_start_index


def prediction_iterator(
Expand All @@ -640,21 +662,20 @@ def prediction_iterator(
Iterator[Tuple[str, np.ndarray]]: An iterator that yields a tuple with a label and the corresponding
slice of the prediction array.
"""
for i, (label, start_index, end_index) in enumerate(
for i, (i_stage, label, start_index, end_index) in enumerate(
sleep_stage_iterator(sleep_score, len(prediction), length_thresh)
):
stage_data = prediction[start_index:end_index]
stage_data = stage_data[stage_data > value_thresh] # Filter values greater than 0.5
# Calculate stage length (duration in seconds)
stage_length = (end_index - start_index) / PREDICTION_FS
stage_label = f"Stage: {i} ({stage_length:.1f} sec)"
stage_label = f"Stage: {i_stage} ({stage_length:.1f} sec)"

# Overwrite combined_df each time to save memory
yield {
"Stage": [stage_label] * len(stage_data),
"Value(>.5)": stage_data,
"Label": [label] * len(stage_data),
"Stage Label": [sleep_score.iloc[i]["Score"]] * len(stage_data),
"Stage Label": [sleep_score.iloc[i_stage]["Score"]] * len(stage_data),
}


Expand Down
4 changes: 3 additions & 1 deletion src/brain_decoding/utils/check_free_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from scipy.stats import f, gmean, mannwhitneyu, multivariate_normal, ttest_1samp, ttest_ind, ttest_rel, wilcoxon
from sklearn.mixture import GaussianMixture

from brain_decoding.param.param_data import LABELS
from brain_decoding.param.param_data import MOVIE24_LABELS

LABELS = MOVIE24_LABELS


def hl_envelopes_idx(s, dmin=1, dmax=1, split=False):
Expand Down
4 changes: 3 additions & 1 deletion src/brain_decoding/utils/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from brain_decoding.config.config import PipelineConfig
from brain_decoding.config.file_path import PATIENTS_FILE_PATH, SURROGATE_FILE_PATH
from brain_decoding.dataloader.patients import Experiment, load_patients
from brain_decoding.param.param_data import LABELS
from brain_decoding.param.param_data import MOVIE24_LABELS
from brain_decoding.utils.check_free_recall import (
find_area_above_threshold_yyding,
find_target_activation_indices,
Expand All @@ -23,6 +23,8 @@
ttest_rel,
)

LABELS = MOVIE24_LABELS


class Permutate:
def __init__(self, config: PipelineConfig, phase: Union[str, List[str]], epoch, phase_length=Dict[str, float]):
Expand Down
12 changes: 6 additions & 6 deletions src/scripts/pipeline_decode_sleep.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
sd_list = [4, 4, 3.5, 4, 4, 3.5]
# data_list = ['notch CAR4.5', 'notch CAR3.5', 'notch CAR4.5', 'notch CAR4', 'notch CAR3.5', 'notch CAR3.5']
data_list = [
"notch CAR-quant-neg",
"notch CAR-quant-neg",
"notch CAR-quant-neg",
"notch CAR-quant-neg",
"notch CAR-quant-neg",
"notch CAR-quant-neg",
"notch",
"notch",
"notch",
"notch",
"notch",
"notch",
]
early_stop = [100, 100, 100, 50, 50, 75]

Expand Down
Loading