diff --git a/R/glex.R b/R/glex.R index 734ecbc..f9bc961 100644 --- a/R/glex.R +++ b/R/glex.R @@ -64,6 +64,17 @@ glex.rpf <- function(object, x, max_interaction = NULL, features = NULL, ...) { predictors = features ) # class(ret) <- c("glex", "rpf_components", class(ret)) + ret$m <- cbind(ret$intercept, ret$m) + colnames(ret$m)[1] <- "intercept" + if ( + object$mode == "regression" && + is.numeric(max_interaction) && + max_interaction < object$params$max_interaction + ) { + preds <- predict(object, x) + ret$m[, "rest"] <- preds$.pred - rowSums(ret$m) + } + ret } @@ -97,7 +108,9 @@ glex.xgb.Booster <- function(object, x, max_interaction = NULL, features = NULL, if (!requireNamespace("xgboost", quietly = TRUE)) { stop("xgboost needs to be installed: install.packages(\"xgboost\")") } - + if (!is.matrix(x)) { + x <- as.matrix(x) + } # If max_interaction is not specified, we set it to the max_depth param of the xgb model. # If max_depth is not defined in xgb, we assume its default of 6. xgb_max_depth <- ifelse(is.null(object$params$max_depth), 6L, object$params$max_depth) @@ -115,7 +128,17 @@ glex.xgb.Booster <- function(object, x, max_interaction = NULL, features = NULL, # Calculate components res <- calc_components(trees, x, max_interaction, features, probFunction) res$intercept <- res$intercept + 0.5 - + res$m[, "intercept"] <- res$intercept + + if ( + "objective" %in% names(object$params) && + startsWith(object$params$objective, "reg:") && + is.numeric(max_interaction) && + max_interaction < xgb_max_depth + ) { + preds <- predict(object, x) + res$m[, "rest"] <- preds - rowSums(res$m) + } # Return components res } @@ -163,7 +186,9 @@ glex.ranger <- function(object, x, max_interaction = NULL, features = NULL, prob if (is.null(object$forest$num.samples.nodes)) { stop("ranger needs to be called with node.stats=TRUE for glex.") } - + if (!is.matrix(x)) { + x <- as.matrix(x) + } # If max_interaction is not specified, we set it to the max.depth param of the ranger model. # If max.depth is not defined in ranger, we assume 6 as in xgboost. rf_max_depth <- ifelse((is.null(object$max.depth) || object$max.depth == 0), 6L, object$max.depth) @@ -194,6 +219,14 @@ glex.ranger <- function(object, x, max_interaction = NULL, features = NULL, prob res$m <- res$m / object$num.trees res$intercept <- res$intercept / object$num.trees + if ( + object$treetype == "Regression" && + is.numeric(max_interaction) && + max_interaction < rf_max_depth + ) { + preds <- predict(object, x) + res$m[, "rest"] <- preds$predictions - rowSums(res$m) + } # Return components res } @@ -403,7 +436,7 @@ calc_components <- function(trees, x, max_interaction, features, probFunction = } else { m_all <- foreach(j = idx, .combine = "+") %do% tree_fun(j) } - + colnames(m_all)[1] <- "intercept" d <- get_degree(colnames(m_all)) # Overall feature effect is sum of all elements where feature is involved @@ -423,7 +456,7 @@ calc_components <- function(trees, x, max_interaction, features, probFunction = # Return shap values, decomposition and intercept ret <- list( shap = data.table::setDT(as.data.frame(shap)), - m = data.table::setDT(as.data.frame(m_all[, -1])), + m = data.table::setDT(as.data.frame(m_all)), intercept = unique(m_all[, 1]), x = data.table::setDT(as.data.frame(x)) ) diff --git a/R/glex_vi.R b/R/glex_vi.R index aabfc83..53b87f5 100644 --- a/R/glex_vi.R +++ b/R/glex_vi.R @@ -28,38 +28,40 @@ #' set.seed(1) #' # Random Planted Forest ----- #' if (requireNamespace("randomPlantedForest", quietly = TRUE)) { -#' library(randomPlantedForest) +#' library(randomPlantedForest) #' -#' rp <- rpf(mpg ~ ., data = mtcars[1:26, ], max_interaction = 3) +#' rp <- rpf(mpg ~ ., data = mtcars[1:26, ], max_interaction = 3) #' -#' glex_rpf <- glex(rp, mtcars[27:32, ]) +#' glex_rpf <- glex(rp, mtcars[27:32, ]) #' -#' # All terms -#' vi_rpf <- glex_vi(glex_rpf) +#' # All terms +#' vi_rpf <- glex_vi(glex_rpf) #' -#' library(ggplot2) -#' # Filter to contributions greater 0.05 on the scale of the target -#' autoplot(vi_rpf, threshold = 0.05) -#' # Summarize by degree of interaction -#' autoplot(vi_rpf, by_degree = TRUE) -#' # Filter by relative contributions greater 0.1% -#' autoplot(vi_rpf, scale = "relative", threshold = 0.001) +#' library(ggplot2) +#' # Filter to contributions greater 0.05 on the scale of the target +#' autoplot(vi_rpf, threshold = 0.05) +#' # Summarize by degree of interaction +#' autoplot(vi_rpf, by_degree = TRUE) +#' # Filter by relative contributions greater 0.1% +#' autoplot(vi_rpf, scale = "relative", threshold = 0.001) #' } #' #' # xgboost ----- #' if (requireNamespace("xgboost", quietly = TRUE)) { -#' library(xgboost) -#' x <- as.matrix(mtcars[, -1]) -#' y <- mtcars$mpg -#' xg <- xgboost(data = x[1:26, ], label = y[1:26], -#' params = list(max_depth = 4, eta = .1), -#' nrounds = 10, verbose = 0) -#' glex_xgb <- glex(xg, x[27:32, ]) -#' vi_xgb <- glex_vi(glex_xgb) +#' library(xgboost) +#' x <- as.matrix(mtcars[, -1]) +#' y <- mtcars$mpg +#' xg <- xgboost( +#' data = x[1:26, ], label = y[1:26], +#' params = list(max_depth = 4, eta = .1), +#' nrounds = 10, verbose = 0 +#' ) +#' glex_xgb <- glex(xg, x[27:32, ]) +#' vi_xgb <- glex_vi(glex_xgb) #' -#' library(ggplot2) -#' autoplot(vi_xgb) -#' autoplot(vi_xgb, by_degree = TRUE) +#' library(ggplot2) +#' autoplot(vi_xgb) +#' autoplot(vi_xgb, by_degree = TRUE) #' } glex_vi <- function(object, ...) { checkmate::assert_class(object, classes = "glex") @@ -67,7 +69,7 @@ glex_vi <- function(object, ...) { # FIXME: data.table NSE warnings term <- degree <- m <- m_rel <- NULL - m_long <- melt_m(object$m, object$target_levels) + m_long <- melt_m(object$m[, -c("intercept")], object$target_levels) if (!is.null(object$target_levels)) { vars_by <- c("term", "class") diff --git a/R/print.R b/R/print.R index fc95907..2a336cc 100644 --- a/R/print.R +++ b/R/print.R @@ -11,21 +11,25 @@ #' @examples #' # Random Planted Forest ----- #' if (requireNamespace("randomPlantedForest", quietly = TRUE)) { -#' library(randomPlantedForest) -#' rp <- rpf(mpg ~ hp + wt + drat, data = mtcars[1:26, ], max_interaction = 2) +#' library(randomPlantedForest) +#' rp <- rpf(mpg ~ hp + wt + drat, data = mtcars[1:26, ], max_interaction = 2) #' -#' glex(rp, mtcars[27:32, ]) +#' glex(rp, mtcars[27:32, ]) #' } print.glex <- function(x, ...) { - n <- nrow(x$x) - n_m <- ncol(x$m) + n_m <- ncol(x$m) - 1 max_deg <- max(get_degree(names(x$m))) - max_deg_lab <- switch(as.character(max_deg), "1" = "degree", "degrees") + max_deg_lab <- switch(as.character(max_deg), + "1" = "degree", + "degrees" + ) cat("glex object of subclass", class(x)[[2]], "\n") - cat("Explaining predictions of", n, "observations with", n_m, - "terms of up to", max_deg, max_deg_lab) + cat( + "Explaining predictions of", n, "observations with", n_m, + "terms of up to", max_deg, max_deg_lab + ) cat("\n\n") str(x, list.len = 5) } diff --git a/R/utils-components-reshaping.R b/R/utils-components-reshaping.R index 8d52fef..8995ac1 100644 --- a/R/utils-components-reshaping.R +++ b/R/utils-components-reshaping.R @@ -27,8 +27,10 @@ melt_m <- function(m, levels = NULL) { term <- NULL # We need an id.var for melt to avoid a warning, but don't want to modify m permanently m[, ".id" := .I] - m_long <- data.table::melt(m, id.vars = ".id", value.name = "m", - variable.name = "term", variable.factor = FALSE) + m_long <- data.table::melt(m, + id.vars = ".id", value.name = "m", + variable.name = "term", variable.factor = FALSE + ) # clean up that temporary id column again while modifying by reference m[, ".id" := NULL] @@ -45,7 +47,7 @@ reshape_m_multiclass <- function(object) { checkmate::assert_character(object$target_levels, min.len = 2) mlong <- melt_m(object$m, object$target_levels) - data.table::dcast(mlong, .id + class ~ term, value.var = "m") + data.table::dcast(mlong, .id + class ~ term, value.var = "m") } #' Helper to split multiclass terms of format __class: @@ -62,4 +64,3 @@ split_names <- function(mn, split_string = "__class:", target_index = 2) { unlist(strsplit(x, split = split_string, fixed = TRUE))[target_index] }, character(1), USE.NAMES = FALSE) } - diff --git a/README.Rmd b/README.Rmd index ee044a9..7fd28c0 100644 --- a/README.Rmd +++ b/README.Rmd @@ -85,10 +85,11 @@ rp <- rpf(mpg ~ ., data = mtcars[1:26, ], max_interaction = 3) x <- as.matrix(mtcars[, -1]) y <- mtcars$mpg -xg <- xgboost(data = x[1:26, ], label = y[1:26], - params = list(max_depth = 3, eta = .1), - nrounds = 30, verbose = 0) - +xg <- xgboost( + data = x[1:26, ], label = y[1:26], + params = list(max_depth = 3, eta = .1), + nrounds = 30, verbose = 0 +) ``` Using the model objects and a dataset to explain (such as a test set in this case), we can create `glex` objects @@ -106,8 +107,8 @@ Both `m` and `shap` satisfy the property that their sums (per observation) toget ```{r comp-sum} # Calculating sum of components and sum of SHAP values -sum_m_rpf <- rowSums(glex_rpf$m) + glex_rpf$intercept -sum_m_xgb <- rowSums(glex_xgb$m) + glex_xgb$intercept +sum_m_rpf <- rowSums(glex_rpf$m) +sum_m_xgb <- rowSums(glex_xgb$m) sum_shap_xgb <- rowSums(glex_xgb$shap) + glex_xgb$intercept # Model predictions @@ -137,10 +138,10 @@ vi_xgb[1:5, c("degree", "term", "m")] The output additionally contains the degree of interaction, which can be used for filtering and aggregating. Here we filter for terms with contributions above a `threshold` of `0.05` to get a more compact plot, with terms below the threshold aggregated into one labelled "Remaining terms": ```{r glex_vi-plot, fig.width=12} -p_vi1 <- autoplot(vi_rpf, threshold = .05) + +p_vi1 <- autoplot(vi_rpf, threshold = .05) + labs(title = NULL, subtitle = "RPF") -p_vi2 <- autoplot(vi_xgb, threshold = .05) + +p_vi2 <- autoplot(vi_xgb, threshold = .05) + labs(title = NULL, subtitle = "XGBoost") p_vi1 + p_vi2 + @@ -150,14 +151,14 @@ p_vi1 + p_vi2 + We can also sum values within each degree of interaction for a more aggregated view, which can be useful as it allows us to judge interactions above a certain degree to not be particularly relevant for a given model. ```{r glex_vi-plot-by-degree, fig.height=5} -p_vi1 <- autoplot(vi_rpf, by_degree = TRUE) + +p_vi1 <- autoplot(vi_rpf, by_degree = TRUE) + labs(title = NULL, subtitle = "RPF") -p_vi2 <- autoplot(vi_xgb, by_degree = TRUE) + +p_vi2 <- autoplot(vi_xgb, by_degree = TRUE) + labs(title = NULL, subtitle = "XGBoost") p_vi1 + p_vi2 + - plot_annotation(title = "Variable importance scores by degree") + plot_annotation(title = "Variable importance scores by degree") ``` ### Feature Effects @@ -168,13 +169,13 @@ We can also plot prediction components against observed feature values, which ad p1 <- autoplot(glex_rpf, "hp") + labs(subtitle = "RPF") p2 <- autoplot(glex_xgb, "hp") + labs(subtitle = "XGBoost") -p1 + p2 + +p1 + p2 + plot_annotation(title = "Main effect for 'hp'") p1 <- autoplot(glex_rpf, c("hp", "wt")) + labs(subtitle = "RPF") p2 <- autoplot(glex_xgb, c("hp", "wt")) + labs(subtitle = "XGBoost") -p1 + p2 + +p1 + p2 + plot_annotation(title = "Two-way effects for 'hp' and 'wt'") ``` @@ -192,9 +193,9 @@ Finally, we can explore the prediction for a single observation by displaying it For compactness, we only plot one feature and collapse all interaction terms above the second degree into one as their combined effect is very small. ```{r glex_explain} -p1 <- glex_explain(glex_rpf, id = 2, predictors = "hp", max_interaction = 2) + +p1 <- glex_explain(glex_rpf, id = 2, predictors = "hp", max_interaction = 2) + labs(tag = "RPF") -p2 <- glex_explain(glex_xgb, id = 2, predictors = "hp", max_interaction = 2) + +p2 <- glex_explain(glex_xgb, id = 2, predictors = "hp", max_interaction = 2) + labs(tag = "XGBoost") p1 + p2 & theme(plot.tag.position = "bottom") diff --git a/README.md b/README.md index 1a39b5d..688a80d 100644 --- a/README.md +++ b/README.md @@ -105,8 +105,8 @@ prediction for each observation: ``` r # Calculating sum of components and sum of SHAP values -sum_m_rpf <- rowSums(glex_rpf$m) + glex_rpf$intercept -sum_m_xgb <- rowSums(glex_xgb$m) + glex_xgb$intercept +sum_m_rpf <- rowSums(glex_rpf$m) +sum_m_xgb <- rowSums(glex_xgb$m) sum_shap_xgb <- rowSums(glex_xgb$shap) + glex_xgb$intercept # Model predictions diff --git a/tests/testthat/_snaps/print-glex.md b/tests/testthat/_snaps/print-glex.md index 1f00a77..65e1a22 100644 --- a/tests/testthat/_snaps/print-glex.md +++ b/tests/testthat/_snaps/print-glex.md @@ -3,16 +3,17 @@ Code out Output - [1] "glex object of subclass rpf_components " - [2] "Explaining predictions of 32 observations with 1 terms of up to 1 degree" - [3] "" - [4] "List of 3" - [5] " $ m :Classes 'data.table' and 'data.frame':\t32 obs. of 1 variable:" - [6] " ..$ cyl: num [1:32] -0.527 -0.527 7.084 -0.527 -5.302 ..." - [7] " ..- attr(*, \".internal.selfref\")= " - [8] " $ intercept: num 20.2" - [9] " $ x :Classes 'data.table' and 'data.frame':\t32 obs. of 1 variable:" - [10] " ..$ cyl: num [1:32] 6 6 4 6 8 6 8 4 4 6 ..." - [11] " ..- attr(*, \".internal.selfref\")= " - [12] " - attr(*, \"class\")= chr [1:3] \"glex\" \"rpf_components\" \"list\"" + [1] "glex object of subclass rpf_components " + [2] "Explaining predictions of 32 observations with 1 terms of up to 1 degree" + [3] "" + [4] "List of 3" + [5] " $ m :Classes 'data.table' and 'data.frame':\t32 obs. of 2 variables:" + [6] " ..$ intercept: num [1:32] 20.2 20.2 20.2 20.2 20.2 ..." + [7] " ..$ cyl : num [1:32] -0.527 -0.527 7.084 -0.527 -5.302 ..." + [8] " ..- attr(*, \".internal.selfref\")= " + [9] " $ intercept: num 20.2" + [10] " $ x :Classes 'data.table' and 'data.frame':\t32 obs. of 1 variable:" + [11] " ..$ cyl: num [1:32] 6 6 4 6 8 6 8 4 4 6 ..." + [12] " ..- attr(*, \".internal.selfref\")= " + [13] " - attr(*, \"class\")= chr [1:3] \"glex\" \"rpf_components\" \"list\"" diff --git a/tests/testthat/test-fastpd-equals-empirical.R b/tests/testthat/test-fastpd-equals-empirical.R index d708130..7adff17 100644 --- a/tests/testthat/test-fastpd-equals-empirical.R +++ b/tests/testthat/test-fastpd-equals-empirical.R @@ -15,7 +15,6 @@ test_that("FastPD equals empirical leaf weighting", { empirical_leaf_weighting <- glex(rf, x, probFunction = "empirical") expect_equal(fastpd$m, empirical_leaf_weighting$m) - expect_equal(fastpd$intercept, empirical_leaf_weighting$intercept) # Check intercept }) test_that("FastPD equals empirical leaf weighting for lower interactions", { diff --git a/tests/testthat/test-glex-xgboost.R b/tests/testthat/test-glex-xgboost.R index 1392e5e..3c7307a 100644 --- a/tests/testthat/test-glex-xgboost.R +++ b/tests/testthat/test-glex-xgboost.R @@ -34,7 +34,7 @@ test_that("features argument only calculates for given feautures", { x <- as.matrix(mtcars[, -1]) xg <- xgboost(x, mtcars$mpg, nrounds = 10, verbose = 0) glexb <- glex(xg, x, features = c("cyl", "disp")) - expect_equal(colnames(glexb$m), c("cyl", "disp", "cyl:disp")) + expect_equal(colnames(glexb$m), c("intercept", "cyl", "disp", "cyl:disp")) }) test_that("features argument results in same values as without", { @@ -50,7 +50,7 @@ test_that("features argument works together with max_interaction", { x <- as.matrix(mtcars[, -1]) xg <- xgboost(x, mtcars$mpg, nrounds = 10, verbose = 0) glexb <- glex(xg, x, features = c("cyl", "disp"), max_interaction = 1) - expect_equal(colnames(glexb$m), c("cyl", "disp")) + expect_equal(colnames(glexb$m), c("intercept", "cyl", "disp")) }) test_that("Prediction is approx. same as sum of decomposition + intercept, xgboost", { @@ -59,7 +59,7 @@ test_that("Prediction is approx. same as sum of decomposition + intercept, xgboo pred_train <- predict(xg, x) res_train <- glex(xg, x) - expect_equal(res_train$intercept + rowSums(res_train$m), + expect_equal(rowSums(res_train$m), pred_train, tolerance = 1e-5 ) diff --git a/tests/testthat/test-mtcars-ranger.R b/tests/testthat/test-mtcars-ranger.R index b6dc3d4..7bc7536 100644 --- a/tests/testthat/test-mtcars-ranger.R +++ b/tests/testthat/test-mtcars-ranger.R @@ -5,9 +5,11 @@ x_test <- as.matrix(mtcars[27:32, -1]) y_train <- mtcars$mpg[1:26] y_test <- mtcars$mpg[27:32] -# xgboost -rf <- ranger(x = x_train, y = y_train, node.stats = TRUE, - num.trees = 5, max.depth = 4) +# ranger +rf <- ranger( + x = x_train, y = y_train, node.stats = TRUE, + num.trees = 5, max.depth = 4 +) pred_train <- predict(rf, x_train)$predictions pred_test <- predict(rf, x_test)$predictions @@ -17,24 +19,28 @@ res_test <- glex(rf, x_test, max_interaction = 4) test_that("Prediction is approx. same as sum of shap + intercept, training data", { expect_equal(res_train$intercept + rowSums(res_train$shap), - pred_train, - tolerance = 1e-5) + pred_train, + tolerance = 1e-5 + ) }) test_that("Prediction is approx. same as sum of shap + intercept, test data", { expect_equal(res_test$intercept + rowSums(res_test$shap), - pred_test, - tolerance = 1e-5) + pred_test, + tolerance = 1e-5 + ) }) test_that("Prediction is approx. same as sum of decomposition + intercept, training data", { - expect_equal(res_train$intercept + rowSums(res_train$m), - pred_train, - tolerance = 1e-5) + expect_equal(rowSums(res_train$m), + pred_train, + tolerance = 1e-5 + ) }) test_that("Prediction is approx. same as sum of decomposition + intercept, test data", { - expect_equal(res_test$intercept + rowSums(res_test$m), - pred_test, - tolerance = 1e-5) + expect_equal(rowSums(res_test$m), + pred_test, + tolerance = 1e-5 + ) }) diff --git a/tests/testthat/test-mtcars.R b/tests/testthat/test-mtcars.R index 948d48d..d0c3db2 100644 --- a/tests/testthat/test-mtcars.R +++ b/tests/testthat/test-mtcars.R @@ -14,24 +14,28 @@ res_test <- glex(xg, x_test) test_that("Prediction is approx. same as sum of shap + intercept, training data", { expect_equal(res_train$intercept + rowSums(res_train$shap), - pred_train, - tolerance = 1e-5) + pred_train, + tolerance = 1e-5 + ) }) test_that("Prediction is approx. same as sum of shap + intercept, test data", { expect_equal(res_test$intercept + rowSums(res_test$shap), - pred_test, - tolerance = 1e-5) + pred_test, + tolerance = 1e-5 + ) }) test_that("Prediction is approx. same as sum of decomposition + intercept, training data", { - expect_equal(res_train$intercept + rowSums(res_train$m), - pred_train, - tolerance = 1e-5) + expect_equal(rowSums(res_train$m), + pred_train, + tolerance = 1e-5 + ) }) test_that("Prediction is approx. same as sum of decomposition + intercept, test data", { - expect_equal(res_test$intercept + rowSums(res_test$m), - pred_test, - tolerance = 1e-5) + expect_equal(rowSums(res_test$m), + pred_test, + tolerance = 1e-5 + ) }) diff --git a/tests/testthat/test-sim.R b/tests/testthat/test-sim.R index de110fd..a7ee4c7 100644 --- a/tests/testthat/test-sim.R +++ b/tests/testthat/test-sim.R @@ -3,9 +3,11 @@ n <- 100 p <- 2 beta <- c(1, 1) beta0 <- 0 -x <- matrix(rnorm(n = n * p), ncol = p, - dimnames = list(NULL, paste0('x', seq_len(p)))) -lp <- x %*% beta + beta0 + 2*x[, 1] * x[, 2] +x <- matrix(rnorm(n = n * p), + ncol = p, + dimnames = list(NULL, paste0("x", seq_len(p))) +) +lp <- x %*% beta + beta0 + 2 * x[, 1] * x[, 2] y <- lp + rnorm(n) x_train <- x[1:50, ] @@ -37,50 +39,58 @@ resbin_test <- glex(xgbin, x_test) test_that("regr: Prediction is approx. same as sum of shap + intercept, training data", { expect_equal(res_train$intercept + rowSums(res_train$shap), - pred_train, - tolerance = 1e-5) + pred_train, + tolerance = 1e-5 + ) }) test_that("binary: Prediction is approx. same as sum of shap + intercept, training data", { expect_equal(resbin_train$intercept + rowSums(resbin_train$shap), - predbin_train, - tolerance = 1e-5) + predbin_train, + tolerance = 1e-5 + ) }) test_that("regr: Prediction is approx. same as sum of shap + intercept, test data", { expect_equal(res_test$intercept + rowSums(res_test$shap), - pred_test, - tolerance = 1e-5) + pred_test, + tolerance = 1e-5 + ) }) test_that("binary: Prediction is approx. same as sum of shap + intercept, test data", { expect_equal(resbin_test$intercept + rowSums(resbin_test$shap), - predbin_test, - tolerance = 1e-5) + predbin_test, + tolerance = 1e-5 + ) }) test_that("regr: Prediction is approx. same as sum of decomposition + intercept, training data", { - expect_equal(res_train$intercept + rowSums(res_train$m), - pred_train, - tolerance = 1e-5) + expect_equal(rowSums(res_train$m), + pred_train, + tolerance = 1e-5 + ) }) test_that("classif: Prediction is approx. same as sum of decomposition + intercept, training data", { - expect_equal(resbin_train$intercept + rowSums(resbin_train$m), - predbin_train, - tolerance = 1e-5) + expect_equal(rowSums(resbin_train$m), + predbin_train, + tolerance = 1e-5 + ) }) test_that("regr: Prediction is approx. same as sum of decomposition + intercept, test data", { - expect_equal(res_test$intercept + rowSums(res_test$m), - pred_test, - tolerance = 1e-5) + expect_equal(rowSums(res_test$m), + pred_test, + tolerance = 1e-5 + ) }) test_that("binary: Prediction is approx. same as sum of decomposition + intercept, test data", { - expect_equal(resbin_test$intercept + rowSums(resbin_test$m), - predbin_test, - tolerance = 1e-5) + expect_equal(rowSums(resbin_test$m), + predbin_test, + tolerance = 1e-5 + ) }) test_that("Shap is computed correctly with overlapping colnames", { @@ -88,10 +98,12 @@ test_that("Shap is computed correctly with overlapping colnames", { p <- 2 beta <- c(1, 1) beta0 <- 0 - x <- matrix(rnorm(n = n * p), ncol = p, - dimnames = list(NULL, paste0('x', seq_len(p)))) + x <- matrix(rnorm(n = n * p), + ncol = p, + dimnames = list(NULL, paste0("x", seq_len(p))) + ) - lp <- x %*% beta + beta0 + 2*x[, 1] * x[, 2] + lp <- x %*% beta + beta0 + 2 * x[, 1] * x[, 2] y <- lp + rnorm(n) # xgboost @@ -113,5 +125,4 @@ test_that("Shap is computed correctly with overlapping colnames", { # We only check values for equality, colnames will of course differ expect_equal(unname(overlapping_names), unname(unique_names)) - })