From 32ab696e32d0f867c07359d9a578ac5d91b1e17c Mon Sep 17 00:00:00 2001 From: bbuchsbaum Date: Sun, 18 May 2025 07:17:37 -0400 Subject: [PATCH] test: add reliability weight check --- tests/testthat/test-contrast_rsa_model.R | 86 ++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/tests/testthat/test-contrast_rsa_model.R b/tests/testthat/test-contrast_rsa_model.R index 0dfb6df8..c9a776d0 100644 --- a/tests/testthat/test-contrast_rsa_model.R +++ b/tests/testthat/test-contrast_rsa_model.R @@ -750,3 +750,89 @@ test_that("contrast_rsa_model output metrics are internally consistent", { beta_delta_rel <- run_metric("beta_delta_reliable", reliability = TRUE) expect_equal(beta_delta_rel, beta_delta_vec * rho_const, tolerance = 1e-6) }) + +test_that("beta_delta_reliable uses reliability weights from fold deltas", { + dset <- mock_mvpa_dataset_train(n_samples = 8, n_cond = 4, n_blocks = 2, n_voxels = 1) + colnames(dset$train_data) <- "V1" + mvpa_des <- dset$design + + C_custom <- matrix(c( + 1, 0, + 0, 1, + 0, 0, + 0, 0 + ), nrow = 4, ncol = 2, byrow = TRUE, + dimnames = list(levels(mvpa_des$Y), c("Con1", "Con2"))) + + ms_des <- msreve_design(mvpa_des, C_custom) + + model_spec <- contrast_rsa_model( + dataset = dset, + design = ms_des, + regression_type = "pearson", + output_metric = c("beta_delta_reliable"), + calc_reliability = TRUE, + check_collinearity = FALSE + ) + + cv_spec <- mock_cv_spec_s3(mvpa_des) + sl_data <- dset$train_data + sl_info <- list(center_local_id = 1, center_global_id = 1, radius = 0, n_voxels = 1) + + fold_estimates_mock <- array(0, dim = c(4, 1, 2), + dimnames = list(levels(mvpa_des$Y), "V1", c("Fold1", "Fold2"))) + fold_estimates_mock[, 1, 1] <- c(1, 0, 0, 0) + fold_estimates_mock[, 1, 2] <- c(2, 0, 0, 0) + mean_estimate_mock <- apply(fold_estimates_mock, c(1, 2), mean) + + Delta1 <- t(fold_estimates_mock[, , 1]) %*% C_custom + Delta2 <- t(fold_estimates_mock[, , 2]) %*% C_custom + deltas <- rbind(Delta1[1, ], Delta2[1, ]) + mean_delta <- c(0, 0) + M2_delta <- c(0, 0) + valid_folds <- 0 + for (i in seq_len(nrow(deltas))) { + delta_fold <- deltas[i, ] + if (anyNA(delta_fold)) next + valid_folds <- valid_folds + 1 + delta_diff <- delta_fold - mean_delta + mean_delta <- mean_delta + delta_diff / valid_folds + M2_delta <- M2_delta + delta_diff * (delta_fold - mean_delta) + } + rho_expected <- rep(1, ncol(C_custom)) + if (valid_folds > 1) { + var_delta <- M2_delta / (valid_folds - 1) + sigma2_noise_param <- (valid_folds - 1) * var_delta + denom <- var_delta + sigma2_noise_param + rho_expected <- ifelse(denom < 1e-10, 1, sigma2_noise_param / denom) + rho_expected[is.na(rho_expected)] <- 0 + } else if (valid_folds == 1) { + rho_expected[M2_delta == 0] <- 1 + rho_expected[M2_delta != 0] <- 0 + } else { + rho_expected <- rep(0, ncol(C_custom)) + } + + beta_mock <- c(2, 1) + + result <- with_mocked_bindings( + compute_crossvalidated_means_sl = function(...) { + list(mean_estimate = mean_estimate_mock, fold_estimates = fold_estimates_mock) + }, + run_cor = function(dvec, obj) { + setNames(beta_mock, colnames(obj$design$model_mat)) + }, + .package = "rMVPA", + { + train_model.contrast_rsa_model(model_spec, sl_data, sl_info, cv_spec) + } + ) + + Delta_sl <- t(mean_estimate_mock) %*% C_custom + delta_vc_sl <- Delta_sl[1, ] + beta_delta_expected <- beta_mock * delta_vc_sl + + expect_equal(result$beta_delta_reliable, + beta_delta_expected * rho_expected, + tolerance = 1e-6) +})