Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 23 additions & 15 deletions R/glex.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#' Defaults to using all possible interactions available in the model.\cr
#' For [`xgboost`][xgboost::xgb.train], this defaults to the `max_depth` parameter of the model fit.\cr
#' If not set in `xgboost`, the default value of `6` is assumed.
#' @param features Vector of column names in x to calculate components for. If \code{NULL}, all features are used.
#' @param ... Further arguments passed to methods.
#'
#' @return Decomposition of the regression or classification function.
Expand All @@ -26,7 +27,7 @@
#' with `:` separating interaction terms as one would specify in a [`formula`] interface.
#' * `intercept`: Intercept term, the expected value of the prediction.
#' @export
glex <- function(object, x, max_interaction = NULL, ...) {
glex <- function(object, x, max_interaction = NULL, features = NULL, ...) {
UseMethod("glex")
}

Expand All @@ -51,15 +52,15 @@ glex.default <- function(object, ...) {
#' glex_rpf <- glex(rp, mtcars[27:32, ])
#' str(glex_rpf, list.len = 5)
#' }
glex.rpf <- function(object, x, max_interaction = NULL, ...) {
glex.rpf <- function(object, x, max_interaction = NULL, features = NULL, ...) {
if (!requireNamespace("randomPlantedForest", quietly = TRUE)) {
stop(paste0("randomPlantedForest needs to be installed: ",
"remotes::install_github(\"PlantedML/randomPlantedForest\")"))
}

ret <- randomPlantedForest::predict_components(
object = object, new_data = x, max_interaction = max_interaction,
predictors = NULL
predictors = features
)
# class(ret) <- c("glex", "rpf_components", class(ret))
ret
Expand Down Expand Up @@ -91,7 +92,7 @@ glex.rpf <- function(object, x, max_interaction = NULL, ...) {
#' glex(xg, x[27:32, ])
#' }
#' }
glex.xgb.Booster <- function(object, x, max_interaction = NULL, ...) {
glex.xgb.Booster <- function(object, x, max_interaction = NULL, features = NULL, ...) {

if (!requireNamespace("xgboost", quietly = TRUE)) {
stop("xgboost needs to be installed: install.packages(\"xgboost\")")
Expand Down Expand Up @@ -125,12 +126,21 @@ glex.xgb.Booster <- function(object, x, max_interaction = NULL, ...) {
}

# Convert features to numerics (leaf = 0)
trees[, Feature_num := as.numeric(factor(Feature, levels = c("Leaf", colnames(x)))) - 1]

# All subsets S (that appear in any of the trees)
all_S <- unique(do.call(c,lapply(0:max(trees$Tree), function(tree) {
subsets(trees[Tree == tree & Feature_num > 0, sort(unique(as.integer(Feature_num)))])
})))
trees[, Feature_num := as.integer(factor(Feature, levels = c("Leaf", colnames(x)))) - 1L]

if (is.null(features)) {
# All subsets S (that appear in any of the trees)
all_S <- unique(do.call(c,lapply(0:max(trees$Tree), function(tree) {
subsets(trees[Tree == tree & Feature_num > 0, sort(unique(Feature_num))])
})))
} else {
# All subsets with supplied features
if (!all(features %in% colnames(x))) {
stop("All selected features have to be column names of x.")
}
features_num <- as.integer(factor(features, levels = c("Leaf", colnames(x)))) - 1L
all_S <- subsets(sort(unique(features_num)))
}

# Keep only those with not more than max_interaction involved features
d <- lengths(all_S)
Expand All @@ -141,7 +151,7 @@ glex.xgb.Booster <- function(object, x, max_interaction = NULL, ...) {
# Calculate matrix
tree_info <- trees[Tree == tree, ]

T <- setdiff(tree_info[, sort(unique(Feature_num))], 0)
T <- setdiff(tree_info[, sort(unique(Feature_num))], 0L)
U <- subsets(T)
mat <- recurse(x, tree_info$Feature_num, tree_info$Split, tree_info$Yes, tree_info$No,
tree_info$Quality, tree_info$Cover, U, 0)
Expand All @@ -151,14 +161,12 @@ glex.xgb.Booster <- function(object, x, max_interaction = NULL, ...) {

# Init m matrix
m_all <- matrix(0, nrow = nrow(x), ncol = length(all_S))
#browser()
colnames(m_all) <- vapply(all_S, function(s) {
paste(sort(colnames(x)[s]), collapse = ":")
}, FUN.VALUE = character(1))

# Calculate contribution, use only subsets with not more than max_interaction involved features
d <- lengths(U)
for (S in U[d <= max_interaction]) {
# Calculate contribution, use only selected features and subsets with not more than max_interaction involved features
for (S in intersect(U, all_S)) {
colname <- paste(sort(colnames(x)[S]), collapse = ":")
if (nchar(colname) == 0) {
colnum <- 1
Expand Down
8 changes: 5 additions & 3 deletions man/glex.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions tests/testthat/test-glex-xgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,27 @@ test_that("max_interaction respects xgb's max_depth", {
expect_equal(max_degree, 7)
expect_error(glex(xg, x, max_interaction = 8))
})

test_that("features argument only calculates for given feautures", {
x <- as.matrix(mtcars[, -1])
xg <- xgboost(x, mtcars$mpg, nrounds = 10, verbose = 0)
glexb <- glex(xg, x, features = c("cyl", "disp"))
expect_equal(colnames(glexb$m), c("cyl", "disp", "cyl:disp"))
})

test_that("features argument results in same values as without", {
x <- as.matrix(mtcars[, -1])
xg <- xgboost(x, mtcars$mpg, nrounds = 10, verbose = 0)
glexb1 <- glex(xg, x, features = c("cyl", "disp"))
glexb2 <- glex(xg, x)
cols <- c("cyl", "disp", "cyl:disp")
expect_equal(glexb1$m[, ..cols], glexb2$m[, ..cols])
})

test_that("features argument works together with max_interaction", {
x <- as.matrix(mtcars[, -1])
xg <- xgboost(x, mtcars$mpg, nrounds = 10, verbose = 0)
glexb <- glex(xg, x, features = c("cyl", "disp"), max_interaction = 1)
expect_equal(colnames(glexb$m), c("cyl", "disp"))
})