diff --git a/R/glex.R b/R/glex.R index 4b437d6..44fccfc 100644 --- a/R/glex.R +++ b/R/glex.R @@ -15,6 +15,7 @@ #' Defaults to using all possible interactions available in the model.\cr #' For [`xgboost`][xgboost::xgb.train], this defaults to the `max_depth` parameter of the model fit.\cr #' If not set in `xgboost`, the default value of `6` is assumed. +#' @param features Vector of column names in x to calculate components for. If \code{NULL}, all features are used. #' @param ... Further arguments passed to methods. #' #' @return Decomposition of the regression or classification function. @@ -26,7 +27,7 @@ #' with `:` separating interaction terms as one would specify in a [`formula`] interface. #' * `intercept`: Intercept term, the expected value of the prediction. #' @export -glex <- function(object, x, max_interaction = NULL, ...) { +glex <- function(object, x, max_interaction = NULL, features = NULL, ...) { UseMethod("glex") } @@ -51,7 +52,7 @@ glex.default <- function(object, ...) { #' glex_rpf <- glex(rp, mtcars[27:32, ]) #' str(glex_rpf, list.len = 5) #' } -glex.rpf <- function(object, x, max_interaction = NULL, ...) { +glex.rpf <- function(object, x, max_interaction = NULL, features = NULL, ...) { if (!requireNamespace("randomPlantedForest", quietly = TRUE)) { stop(paste0("randomPlantedForest needs to be installed: ", "remotes::install_github(\"PlantedML/randomPlantedForest\")")) @@ -59,7 +60,7 @@ glex.rpf <- function(object, x, max_interaction = NULL, ...) { ret <- randomPlantedForest::predict_components( object = object, new_data = x, max_interaction = max_interaction, - predictors = NULL + predictors = features ) # class(ret) <- c("glex", "rpf_components", class(ret)) ret @@ -91,7 +92,7 @@ glex.rpf <- function(object, x, max_interaction = NULL, ...) { #' glex(xg, x[27:32, ]) #' } #' } -glex.xgb.Booster <- function(object, x, max_interaction = NULL, ...) { +glex.xgb.Booster <- function(object, x, max_interaction = NULL, features = NULL, ...) { if (!requireNamespace("xgboost", quietly = TRUE)) { stop("xgboost needs to be installed: install.packages(\"xgboost\")") @@ -125,12 +126,21 @@ glex.xgb.Booster <- function(object, x, max_interaction = NULL, ...) { } # Convert features to numerics (leaf = 0) - trees[, Feature_num := as.numeric(factor(Feature, levels = c("Leaf", colnames(x)))) - 1] - - # All subsets S (that appear in any of the trees) - all_S <- unique(do.call(c,lapply(0:max(trees$Tree), function(tree) { - subsets(trees[Tree == tree & Feature_num > 0, sort(unique(as.integer(Feature_num)))]) - }))) + trees[, Feature_num := as.integer(factor(Feature, levels = c("Leaf", colnames(x)))) - 1L] + + if (is.null(features)) { + # All subsets S (that appear in any of the trees) + all_S <- unique(do.call(c,lapply(0:max(trees$Tree), function(tree) { + subsets(trees[Tree == tree & Feature_num > 0, sort(unique(Feature_num))]) + }))) + } else { + # All subsets with supplied features + if (!all(features %in% colnames(x))) { + stop("All selected features have to be column names of x.") + } + features_num <- as.integer(factor(features, levels = c("Leaf", colnames(x)))) - 1L + all_S <- subsets(sort(unique(features_num))) + } # Keep only those with not more than max_interaction involved features d <- lengths(all_S) @@ -141,7 +151,7 @@ glex.xgb.Booster <- function(object, x, max_interaction = NULL, ...) { # Calculate matrix tree_info <- trees[Tree == tree, ] - T <- setdiff(tree_info[, sort(unique(Feature_num))], 0) + T <- setdiff(tree_info[, sort(unique(Feature_num))], 0L) U <- subsets(T) mat <- recurse(x, tree_info$Feature_num, tree_info$Split, tree_info$Yes, tree_info$No, tree_info$Quality, tree_info$Cover, U, 0) @@ -151,14 +161,12 @@ glex.xgb.Booster <- function(object, x, max_interaction = NULL, ...) { # Init m matrix m_all <- matrix(0, nrow = nrow(x), ncol = length(all_S)) - #browser() colnames(m_all) <- vapply(all_S, function(s) { paste(sort(colnames(x)[s]), collapse = ":") }, FUN.VALUE = character(1)) - # Calculate contribution, use only subsets with not more than max_interaction involved features - d <- lengths(U) - for (S in U[d <= max_interaction]) { + # Calculate contribution, use only selected features and subsets with not more than max_interaction involved features + for (S in intersect(U, all_S)) { colname <- paste(sort(colnames(x)[S]), collapse = ":") if (nchar(colname) == 0) { colnum <- 1 diff --git a/man/glex.Rd b/man/glex.Rd index 2021586..a1def79 100644 --- a/man/glex.Rd +++ b/man/glex.Rd @@ -6,11 +6,11 @@ \alias{glex.xgb.Booster} \title{Global explanations for tree-based models.} \usage{ -glex(object, x, max_interaction = NULL, ...) +glex(object, x, max_interaction = NULL, features = NULL, ...) -\method{glex}{rpf}(object, x, max_interaction = NULL, ...) +\method{glex}{rpf}(object, x, max_interaction = NULL, features = NULL, ...) -\method{glex}{xgb.Booster}(object, x, max_interaction = NULL, ...) +\method{glex}{xgb.Booster}(object, x, max_interaction = NULL, features = NULL, ...) } \arguments{ \item{object}{Model to be explained, either of class \code{xgb.Booster} or \code{rpf}.} @@ -23,6 +23,8 @@ Defaults to using all possible interactions available in the model.\cr For \code{\link[xgboost:xgb.train]{xgboost}}, this defaults to the \code{max_depth} parameter of the model fit.\cr If not set in \code{xgboost}, the default value of \code{6} is assumed.} +\item{features}{Vector of column names in x to calculate components for. If \code{NULL}, all features are used.} + \item{...}{Further arguments passed to methods.} } \value{ diff --git a/tests/testthat/test-glex-xgboost.R b/tests/testthat/test-glex-xgboost.R index 8eb0201..0f88ff9 100644 --- a/tests/testthat/test-glex-xgboost.R +++ b/tests/testthat/test-glex-xgboost.R @@ -30,3 +30,27 @@ test_that("max_interaction respects xgb's max_depth", { expect_equal(max_degree, 7) expect_error(glex(xg, x, max_interaction = 8)) }) + +test_that("features argument only calculates for given feautures", { + x <- as.matrix(mtcars[, -1]) + xg <- xgboost(x, mtcars$mpg, nrounds = 10, verbose = 0) + glexb <- glex(xg, x, features = c("cyl", "disp")) + expect_equal(colnames(glexb$m), c("cyl", "disp", "cyl:disp")) +}) + +test_that("features argument results in same values as without", { + x <- as.matrix(mtcars[, -1]) + xg <- xgboost(x, mtcars$mpg, nrounds = 10, verbose = 0) + glexb1 <- glex(xg, x, features = c("cyl", "disp")) + glexb2 <- glex(xg, x) + cols <- c("cyl", "disp", "cyl:disp") + expect_equal(glexb1$m[, ..cols], glexb2$m[, ..cols]) +}) + +test_that("features argument works together with max_interaction", { + x <- as.matrix(mtcars[, -1]) + xg <- xgboost(x, mtcars$mpg, nrounds = 10, verbose = 0) + glexb <- glex(xg, x, features = c("cyl", "disp"), max_interaction = 1) + expect_equal(colnames(glexb$m), c("cyl", "disp")) +}) +