From 7b20bc3d291fe7c7f8b83b292efe11170cacd8be Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Mon, 8 Jan 2024 09:29:23 +0100 Subject: [PATCH 1/6] create subsets based on supplied features --- R/glex.R | 19 +++++++++++++------ man/glex.Rd | 2 +- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/R/glex.R b/R/glex.R index 4b437d6..8a9fd61 100644 --- a/R/glex.R +++ b/R/glex.R @@ -91,7 +91,7 @@ glex.rpf <- function(object, x, max_interaction = NULL, ...) { #' glex(xg, x[27:32, ]) #' } #' } -glex.xgb.Booster <- function(object, x, max_interaction = NULL, ...) { +glex.xgb.Booster <- function(object, x, max_interaction = NULL, features = NULL, ...) { if (!requireNamespace("xgboost", quietly = TRUE)) { stop("xgboost needs to be installed: install.packages(\"xgboost\")") @@ -126,11 +126,18 @@ glex.xgb.Booster <- function(object, x, max_interaction = NULL, ...) { # Convert features to numerics (leaf = 0) trees[, Feature_num := as.numeric(factor(Feature, levels = c("Leaf", colnames(x)))) - 1] - - # All subsets S (that appear in any of the trees) - all_S <- unique(do.call(c,lapply(0:max(trees$Tree), function(tree) { - subsets(trees[Tree == tree & Feature_num > 0, sort(unique(as.integer(Feature_num)))]) - }))) + + if (is.null(features)) { + # All subsets S (that appear in any of the trees) + all_S <- unique(do.call(c,lapply(0:max(trees$Tree), function(tree) { + subsets(trees[Tree == tree & Feature_num > 0, sort(unique(as.integer(Feature_num)))]) + }))) + } else { + # All subsets with supplied features + # TODO: Check colnames + features_num <- as.numeric(factor(features, levels = c("Leaf", colnames(x)))) - 1 + all_S <- subsets(sort(unique(as.integer(features_num)))) + } # Keep only those with not more than max_interaction involved features d <- lengths(all_S) diff --git a/man/glex.Rd b/man/glex.Rd index 2021586..82746cd 100644 --- a/man/glex.Rd +++ b/man/glex.Rd @@ -10,7 +10,7 @@ glex(object, x, max_interaction = NULL, ...) \method{glex}{rpf}(object, x, max_interaction = NULL, ...) -\method{glex}{xgb.Booster}(object, x, max_interaction = NULL, ...) +\method{glex}{xgb.Booster}(object, x, max_interaction = NULL, features = NULL, ...) } \arguments{ \item{object}{Model to be explained, either of class \code{xgb.Booster} or \code{rpf}.} From b4f6a08f18bdd324eae6ab5ba2b9c532eac3e1d4 Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Thu, 11 Jan 2024 20:57:49 +0100 Subject: [PATCH 2/6] add check for feature names --- R/glex.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R/glex.R b/R/glex.R index 8a9fd61..0c46f8f 100644 --- a/R/glex.R +++ b/R/glex.R @@ -134,7 +134,9 @@ glex.xgb.Booster <- function(object, x, max_interaction = NULL, features = NULL, }))) } else { # All subsets with supplied features - # TODO: Check colnames + if (!all(features %in% colnames(x))) { + stop("All selected features have to be column names of x.") + } features_num <- as.numeric(factor(features, levels = c("Leaf", colnames(x)))) - 1 all_S <- subsets(sort(unique(as.integer(features_num)))) } From a9b402924531166dea19c0873c4cb3b656ecc7b2 Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Thu, 11 Jan 2024 21:24:07 +0100 Subject: [PATCH 3/6] call contribute only for selected features --- R/glex.R | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/R/glex.R b/R/glex.R index 0c46f8f..b13f8cc 100644 --- a/R/glex.R +++ b/R/glex.R @@ -160,14 +160,12 @@ glex.xgb.Booster <- function(object, x, max_interaction = NULL, features = NULL, # Init m matrix m_all <- matrix(0, nrow = nrow(x), ncol = length(all_S)) - #browser() colnames(m_all) <- vapply(all_S, function(s) { paste(sort(colnames(x)[s]), collapse = ":") }, FUN.VALUE = character(1)) - # Calculate contribution, use only subsets with not more than max_interaction involved features - d <- lengths(U) - for (S in U[d <= max_interaction]) { + # Calculate contribution, use only selected features and subsets with not more than max_interaction involved features + for (S in intersect(U, all_S)) { colname <- paste(sort(colnames(x)[S]), collapse = ":") if (nchar(colname) == 0) { colnum <- 1 From f9919691ffa2ef0cec3c12e9c71f68662a0f9b10 Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Thu, 11 Jan 2024 21:43:21 +0100 Subject: [PATCH 4/6] fix numeric/integer comparison --- R/glex.R | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/R/glex.R b/R/glex.R index b13f8cc..8ee2a01 100644 --- a/R/glex.R +++ b/R/glex.R @@ -125,20 +125,20 @@ glex.xgb.Booster <- function(object, x, max_interaction = NULL, features = NULL, } # Convert features to numerics (leaf = 0) - trees[, Feature_num := as.numeric(factor(Feature, levels = c("Leaf", colnames(x)))) - 1] + trees[, Feature_num := as.integer(factor(Feature, levels = c("Leaf", colnames(x)))) - 1L] if (is.null(features)) { # All subsets S (that appear in any of the trees) all_S <- unique(do.call(c,lapply(0:max(trees$Tree), function(tree) { - subsets(trees[Tree == tree & Feature_num > 0, sort(unique(as.integer(Feature_num)))]) + subsets(trees[Tree == tree & Feature_num > 0, sort(unique(Feature_num))]) }))) } else { # All subsets with supplied features if (!all(features %in% colnames(x))) { stop("All selected features have to be column names of x.") } - features_num <- as.numeric(factor(features, levels = c("Leaf", colnames(x)))) - 1 - all_S <- subsets(sort(unique(as.integer(features_num)))) + features_num <- as.integer(factor(features, levels = c("Leaf", colnames(x)))) - 1L + all_S <- subsets(sort(unique(features_num))) } # Keep only those with not more than max_interaction involved features @@ -150,7 +150,7 @@ glex.xgb.Booster <- function(object, x, max_interaction = NULL, features = NULL, # Calculate matrix tree_info <- trees[Tree == tree, ] - T <- setdiff(tree_info[, sort(unique(Feature_num))], 0) + T <- setdiff(tree_info[, sort(unique(Feature_num))], 0L) U <- subsets(T) mat <- recurse(x, tree_info$Feature_num, tree_info$Split, tree_info$Yes, tree_info$No, tree_info$Quality, tree_info$Cover, U, 0) From 7133f0424af105cdae46ee3a26bcdb93ceec07e2 Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Thu, 11 Jan 2024 22:00:21 +0100 Subject: [PATCH 5/6] also add features argument for RPF and add docs --- R/glex.R | 7 ++++--- man/glex.Rd | 6 ++++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/R/glex.R b/R/glex.R index 8ee2a01..44fccfc 100644 --- a/R/glex.R +++ b/R/glex.R @@ -15,6 +15,7 @@ #' Defaults to using all possible interactions available in the model.\cr #' For [`xgboost`][xgboost::xgb.train], this defaults to the `max_depth` parameter of the model fit.\cr #' If not set in `xgboost`, the default value of `6` is assumed. +#' @param features Vector of column names in x to calculate components for. If \code{NULL}, all features are used. #' @param ... Further arguments passed to methods. #' #' @return Decomposition of the regression or classification function. @@ -26,7 +27,7 @@ #' with `:` separating interaction terms as one would specify in a [`formula`] interface. #' * `intercept`: Intercept term, the expected value of the prediction. #' @export -glex <- function(object, x, max_interaction = NULL, ...) { +glex <- function(object, x, max_interaction = NULL, features = NULL, ...) { UseMethod("glex") } @@ -51,7 +52,7 @@ glex.default <- function(object, ...) { #' glex_rpf <- glex(rp, mtcars[27:32, ]) #' str(glex_rpf, list.len = 5) #' } -glex.rpf <- function(object, x, max_interaction = NULL, ...) { +glex.rpf <- function(object, x, max_interaction = NULL, features = NULL, ...) { if (!requireNamespace("randomPlantedForest", quietly = TRUE)) { stop(paste0("randomPlantedForest needs to be installed: ", "remotes::install_github(\"PlantedML/randomPlantedForest\")")) @@ -59,7 +60,7 @@ glex.rpf <- function(object, x, max_interaction = NULL, ...) { ret <- randomPlantedForest::predict_components( object = object, new_data = x, max_interaction = max_interaction, - predictors = NULL + predictors = features ) # class(ret) <- c("glex", "rpf_components", class(ret)) ret diff --git a/man/glex.Rd b/man/glex.Rd index 82746cd..a1def79 100644 --- a/man/glex.Rd +++ b/man/glex.Rd @@ -6,9 +6,9 @@ \alias{glex.xgb.Booster} \title{Global explanations for tree-based models.} \usage{ -glex(object, x, max_interaction = NULL, ...) +glex(object, x, max_interaction = NULL, features = NULL, ...) -\method{glex}{rpf}(object, x, max_interaction = NULL, ...) +\method{glex}{rpf}(object, x, max_interaction = NULL, features = NULL, ...) \method{glex}{xgb.Booster}(object, x, max_interaction = NULL, features = NULL, ...) } @@ -23,6 +23,8 @@ Defaults to using all possible interactions available in the model.\cr For \code{\link[xgboost:xgb.train]{xgboost}}, this defaults to the \code{max_depth} parameter of the model fit.\cr If not set in \code{xgboost}, the default value of \code{6} is assumed.} +\item{features}{Vector of column names in x to calculate components for. If \code{NULL}, all features are used.} + \item{...}{Further arguments passed to methods.} } \value{ From 9eabc2c623a766e3f78de75590e753c72c979f47 Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Thu, 11 Jan 2024 22:08:21 +0100 Subject: [PATCH 6/6] add tests for features argument --- tests/testthat/test-glex-xgboost.R | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/testthat/test-glex-xgboost.R b/tests/testthat/test-glex-xgboost.R index 8eb0201..0f88ff9 100644 --- a/tests/testthat/test-glex-xgboost.R +++ b/tests/testthat/test-glex-xgboost.R @@ -30,3 +30,27 @@ test_that("max_interaction respects xgb's max_depth", { expect_equal(max_degree, 7) expect_error(glex(xg, x, max_interaction = 8)) }) + +test_that("features argument only calculates for given feautures", { + x <- as.matrix(mtcars[, -1]) + xg <- xgboost(x, mtcars$mpg, nrounds = 10, verbose = 0) + glexb <- glex(xg, x, features = c("cyl", "disp")) + expect_equal(colnames(glexb$m), c("cyl", "disp", "cyl:disp")) +}) + +test_that("features argument results in same values as without", { + x <- as.matrix(mtcars[, -1]) + xg <- xgboost(x, mtcars$mpg, nrounds = 10, verbose = 0) + glexb1 <- glex(xg, x, features = c("cyl", "disp")) + glexb2 <- glex(xg, x) + cols <- c("cyl", "disp", "cyl:disp") + expect_equal(glexb1$m[, ..cols], glexb2$m[, ..cols]) +}) + +test_that("features argument works together with max_interaction", { + x <- as.matrix(mtcars[, -1]) + xg <- xgboost(x, mtcars$mpg, nrounds = 10, verbose = 0) + glexb <- glex(xg, x, features = c("cyl", "disp"), max_interaction = 1) + expect_equal(colnames(glexb$m), c("cyl", "disp")) +}) +