Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 38 additions & 5 deletions R/glex.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,17 @@ glex.rpf <- function(object, x, max_interaction = NULL, features = NULL, ...) {
predictors = features
)
# class(ret) <- c("glex", "rpf_components", class(ret))
ret$m <- cbind(ret$intercept, ret$m)
colnames(ret$m)[1] <- "intercept"
if (
object$mode == "regression" &&
is.numeric(max_interaction) &&
max_interaction < object$params$max_interaction
) {
preds <- predict(object, x)
ret$m[, "rest"] <- preds$.pred - rowSums(ret$m)
}

ret
}

Expand Down Expand Up @@ -97,7 +108,9 @@ glex.xgb.Booster <- function(object, x, max_interaction = NULL, features = NULL,
if (!requireNamespace("xgboost", quietly = TRUE)) {
stop("xgboost needs to be installed: install.packages(\"xgboost\")")
}

if (!is.matrix(x)) {
x <- as.matrix(x)
}
# If max_interaction is not specified, we set it to the max_depth param of the xgb model.
# If max_depth is not defined in xgb, we assume its default of 6.
xgb_max_depth <- ifelse(is.null(object$params$max_depth), 6L, object$params$max_depth)
Expand All @@ -115,7 +128,17 @@ glex.xgb.Booster <- function(object, x, max_interaction = NULL, features = NULL,
# Calculate components
res <- calc_components(trees, x, max_interaction, features, probFunction)
res$intercept <- res$intercept + 0.5

res$m[, "intercept"] <- res$intercept

if (
"objective" %in% names(object$params) &&
startsWith(object$params$objective, "reg:") &&
is.numeric(max_interaction) &&
max_interaction < xgb_max_depth
) {
preds <- predict(object, x)
res$m[, "rest"] <- preds - rowSums(res$m)
}
# Return components
res
}
Expand Down Expand Up @@ -163,7 +186,9 @@ glex.ranger <- function(object, x, max_interaction = NULL, features = NULL, prob
if (is.null(object$forest$num.samples.nodes)) {
stop("ranger needs to be called with node.stats=TRUE for glex.")
}

if (!is.matrix(x)) {
x <- as.matrix(x)
}
# If max_interaction is not specified, we set it to the max.depth param of the ranger model.
# If max.depth is not defined in ranger, we assume 6 as in xgboost.
rf_max_depth <- ifelse((is.null(object$max.depth) || object$max.depth == 0), 6L, object$max.depth)
Expand Down Expand Up @@ -194,6 +219,14 @@ glex.ranger <- function(object, x, max_interaction = NULL, features = NULL, prob
res$m <- res$m / object$num.trees
res$intercept <- res$intercept / object$num.trees

if (
object$treetype == "Regression" &&
is.numeric(max_interaction) &&
max_interaction < rf_max_depth
) {
preds <- predict(object, x)
res$m[, "rest"] <- preds$predictions - rowSums(res$m)
}
# Return components
res
}
Expand Down Expand Up @@ -403,7 +436,7 @@ calc_components <- function(trees, x, max_interaction, features, probFunction =
} else {
m_all <- foreach(j = idx, .combine = "+") %do% tree_fun(j)
}

colnames(m_all)[1] <- "intercept"
d <- get_degree(colnames(m_all))

# Overall feature effect is sum of all elements where feature is involved
Expand All @@ -423,7 +456,7 @@ calc_components <- function(trees, x, max_interaction, features, probFunction =
# Return shap values, decomposition and intercept
ret <- list(
shap = data.table::setDT(as.data.frame(shap)),
m = data.table::setDT(as.data.frame(m_all[, -1])),
m = data.table::setDT(as.data.frame(m_all)),
intercept = unique(m_all[, 1]),
x = data.table::setDT(as.data.frame(x))
)
Expand Down
50 changes: 26 additions & 24 deletions R/glex_vi.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,46 +28,48 @@
#' set.seed(1)
#' # Random Planted Forest -----
#' if (requireNamespace("randomPlantedForest", quietly = TRUE)) {
#' library(randomPlantedForest)
#' library(randomPlantedForest)
#'
#' rp <- rpf(mpg ~ ., data = mtcars[1:26, ], max_interaction = 3)
#' rp <- rpf(mpg ~ ., data = mtcars[1:26, ], max_interaction = 3)
#'
#' glex_rpf <- glex(rp, mtcars[27:32, ])
#' glex_rpf <- glex(rp, mtcars[27:32, ])
#'
#' # All terms
#' vi_rpf <- glex_vi(glex_rpf)
#' # All terms
#' vi_rpf <- glex_vi(glex_rpf)
#'
#' library(ggplot2)
#' # Filter to contributions greater 0.05 on the scale of the target
#' autoplot(vi_rpf, threshold = 0.05)
#' # Summarize by degree of interaction
#' autoplot(vi_rpf, by_degree = TRUE)
#' # Filter by relative contributions greater 0.1%
#' autoplot(vi_rpf, scale = "relative", threshold = 0.001)
#' library(ggplot2)
#' # Filter to contributions greater 0.05 on the scale of the target
#' autoplot(vi_rpf, threshold = 0.05)
#' # Summarize by degree of interaction
#' autoplot(vi_rpf, by_degree = TRUE)
#' # Filter by relative contributions greater 0.1%
#' autoplot(vi_rpf, scale = "relative", threshold = 0.001)
#' }
#'
#' # xgboost -----
#' if (requireNamespace("xgboost", quietly = TRUE)) {
#' library(xgboost)
#' x <- as.matrix(mtcars[, -1])
#' y <- mtcars$mpg
#' xg <- xgboost(data = x[1:26, ], label = y[1:26],
#' params = list(max_depth = 4, eta = .1),
#' nrounds = 10, verbose = 0)
#' glex_xgb <- glex(xg, x[27:32, ])
#' vi_xgb <- glex_vi(glex_xgb)
#' library(xgboost)
#' x <- as.matrix(mtcars[, -1])
#' y <- mtcars$mpg
#' xg <- xgboost(
#' data = x[1:26, ], label = y[1:26],
#' params = list(max_depth = 4, eta = .1),
#' nrounds = 10, verbose = 0
#' )
#' glex_xgb <- glex(xg, x[27:32, ])
#' vi_xgb <- glex_vi(glex_xgb)
#'
#' library(ggplot2)
#' autoplot(vi_xgb)
#' autoplot(vi_xgb, by_degree = TRUE)
#' library(ggplot2)
#' autoplot(vi_xgb)
#' autoplot(vi_xgb, by_degree = TRUE)
#' }
glex_vi <- function(object, ...) {
checkmate::assert_class(object, classes = "glex")

# FIXME: data.table NSE warnings
term <- degree <- m <- m_rel <- NULL

m_long <- melt_m(object$m, object$target_levels)
m_long <- melt_m(object$m[, -c("intercept")], object$target_levels)

if (!is.null(object$target_levels)) {
vars_by <- c("term", "class")
Expand Down
20 changes: 12 additions & 8 deletions R/print.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,25 @@
#' @examples
#' # Random Planted Forest -----
#' if (requireNamespace("randomPlantedForest", quietly = TRUE)) {
#' library(randomPlantedForest)
#' rp <- rpf(mpg ~ hp + wt + drat, data = mtcars[1:26, ], max_interaction = 2)
#' library(randomPlantedForest)
#' rp <- rpf(mpg ~ hp + wt + drat, data = mtcars[1:26, ], max_interaction = 2)
#'
#' glex(rp, mtcars[27:32, ])
#' glex(rp, mtcars[27:32, ])
#' }
print.glex <- function(x, ...) {

n <- nrow(x$x)
n_m <- ncol(x$m)
n_m <- ncol(x$m) - 1
max_deg <- max(get_degree(names(x$m)))
max_deg_lab <- switch(as.character(max_deg), "1" = "degree", "degrees")
max_deg_lab <- switch(as.character(max_deg),
"1" = "degree",
"degrees"
)

cat("glex object of subclass", class(x)[[2]], "\n")
cat("Explaining predictions of", n, "observations with", n_m,
"terms of up to", max_deg, max_deg_lab)
cat(
"Explaining predictions of", n, "observations with", n_m,
"terms of up to", max_deg, max_deg_lab
)
cat("\n\n")
str(x, list.len = 5)
}
9 changes: 5 additions & 4 deletions R/utils-components-reshaping.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ melt_m <- function(m, levels = NULL) {
term <- NULL
# We need an id.var for melt to avoid a warning, but don't want to modify m permanently
m[, ".id" := .I]
m_long <- data.table::melt(m, id.vars = ".id", value.name = "m",
variable.name = "term", variable.factor = FALSE)
m_long <- data.table::melt(m,
id.vars = ".id", value.name = "m",
variable.name = "term", variable.factor = FALSE
)
# clean up that temporary id column again while modifying by reference
m[, ".id" := NULL]

Expand All @@ -45,7 +47,7 @@ reshape_m_multiclass <- function(object) {
checkmate::assert_character(object$target_levels, min.len = 2)

mlong <- melt_m(object$m, object$target_levels)
data.table::dcast(mlong, .id + class ~ term, value.var = "m")
data.table::dcast(mlong, .id + class ~ term, value.var = "m")
}

#' Helper to split multiclass terms of format <term>__class:<class>
Expand All @@ -62,4 +64,3 @@ split_names <- function(mn, split_string = "__class:", target_index = 2) {
unlist(strsplit(x, split = split_string, fixed = TRUE))[target_index]
}, character(1), USE.NAMES = FALSE)
}

31 changes: 16 additions & 15 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,11 @@ rp <- rpf(mpg ~ ., data = mtcars[1:26, ], max_interaction = 3)

x <- as.matrix(mtcars[, -1])
y <- mtcars$mpg
xg <- xgboost(data = x[1:26, ], label = y[1:26],
params = list(max_depth = 3, eta = .1),
nrounds = 30, verbose = 0)

xg <- xgboost(
data = x[1:26, ], label = y[1:26],
params = list(max_depth = 3, eta = .1),
nrounds = 30, verbose = 0
)
```

Using the model objects and a dataset to explain (such as a test set in this case), we can create `glex` objects
Expand All @@ -106,8 +107,8 @@ Both `m` and `shap` satisfy the property that their sums (per observation) toget

```{r comp-sum}
# Calculating sum of components and sum of SHAP values
sum_m_rpf <- rowSums(glex_rpf$m) + glex_rpf$intercept
sum_m_xgb <- rowSums(glex_xgb$m) + glex_xgb$intercept
sum_m_rpf <- rowSums(glex_rpf$m)
sum_m_xgb <- rowSums(glex_xgb$m)
sum_shap_xgb <- rowSums(glex_xgb$shap) + glex_xgb$intercept

# Model predictions
Expand Down Expand Up @@ -137,10 +138,10 @@ vi_xgb[1:5, c("degree", "term", "m")]
The output additionally contains the degree of interaction, which can be used for filtering and aggregating. Here we filter for terms with contributions above a `threshold` of `0.05` to get a more compact plot, with terms below the threshold aggregated into one labelled "Remaining terms":

```{r glex_vi-plot, fig.width=12}
p_vi1 <- autoplot(vi_rpf, threshold = .05) +
p_vi1 <- autoplot(vi_rpf, threshold = .05) +
labs(title = NULL, subtitle = "RPF")

p_vi2 <- autoplot(vi_xgb, threshold = .05) +
p_vi2 <- autoplot(vi_xgb, threshold = .05) +
labs(title = NULL, subtitle = "XGBoost")

p_vi1 + p_vi2 +
Expand All @@ -150,14 +151,14 @@ p_vi1 + p_vi2 +
We can also sum values within each degree of interaction for a more aggregated view, which can be useful as it allows us to judge interactions above a certain degree to not be particularly relevant for a given model.

```{r glex_vi-plot-by-degree, fig.height=5}
p_vi1 <- autoplot(vi_rpf, by_degree = TRUE) +
p_vi1 <- autoplot(vi_rpf, by_degree = TRUE) +
labs(title = NULL, subtitle = "RPF")

p_vi2 <- autoplot(vi_xgb, by_degree = TRUE) +
p_vi2 <- autoplot(vi_xgb, by_degree = TRUE) +
labs(title = NULL, subtitle = "XGBoost")

p_vi1 + p_vi2 +
plot_annotation(title = "Variable importance scores by degree")
plot_annotation(title = "Variable importance scores by degree")
```

### Feature Effects
Expand All @@ -168,13 +169,13 @@ We can also plot prediction components against observed feature values, which ad
p1 <- autoplot(glex_rpf, "hp") + labs(subtitle = "RPF")
p2 <- autoplot(glex_xgb, "hp") + labs(subtitle = "XGBoost")

p1 + p2 +
p1 + p2 +
plot_annotation(title = "Main effect for 'hp'")

p1 <- autoplot(glex_rpf, c("hp", "wt")) + labs(subtitle = "RPF")
p2 <- autoplot(glex_xgb, c("hp", "wt")) + labs(subtitle = "XGBoost")

p1 + p2 +
p1 + p2 +
plot_annotation(title = "Two-way effects for 'hp' and 'wt'")
```

Expand All @@ -192,9 +193,9 @@ Finally, we can explore the prediction for a single observation by displaying it
For compactness, we only plot one feature and collapse all interaction terms above the second degree into one as their combined effect is very small.

```{r glex_explain}
p1 <- glex_explain(glex_rpf, id = 2, predictors = "hp", max_interaction = 2) +
p1 <- glex_explain(glex_rpf, id = 2, predictors = "hp", max_interaction = 2) +
labs(tag = "RPF")
p2 <- glex_explain(glex_xgb, id = 2, predictors = "hp", max_interaction = 2) +
p2 <- glex_explain(glex_xgb, id = 2, predictors = "hp", max_interaction = 2) +
labs(tag = "XGBoost")

p1 + p2 & theme(plot.tag.position = "bottom")
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ prediction for each observation:

``` r
# Calculating sum of components and sum of SHAP values
sum_m_rpf <- rowSums(glex_rpf$m) + glex_rpf$intercept
sum_m_xgb <- rowSums(glex_xgb$m) + glex_xgb$intercept
sum_m_rpf <- rowSums(glex_rpf$m)
sum_m_xgb <- rowSums(glex_xgb$m)
sum_shap_xgb <- rowSums(glex_xgb$shap) + glex_xgb$intercept

# Model predictions
Expand Down
25 changes: 13 additions & 12 deletions tests/testthat/_snaps/print-glex.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@
Code
out
Output
[1] "glex object of subclass rpf_components "
[2] "Explaining predictions of 32 observations with 1 terms of up to 1 degree"
[3] ""
[4] "List of 3"
[5] " $ m :Classes 'data.table' and 'data.frame':\t32 obs. of 1 variable:"
[6] " ..$ cyl: num [1:32] -0.527 -0.527 7.084 -0.527 -5.302 ..."
[7] " ..- attr(*, \".internal.selfref\")=<externalptr> "
[8] " $ intercept: num 20.2"
[9] " $ x :Classes 'data.table' and 'data.frame':\t32 obs. of 1 variable:"
[10] " ..$ cyl: num [1:32] 6 6 4 6 8 6 8 4 4 6 ..."
[11] " ..- attr(*, \".internal.selfref\")=<externalptr> "
[12] " - attr(*, \"class\")= chr [1:3] \"glex\" \"rpf_components\" \"list\""
[1] "glex object of subclass rpf_components "
[2] "Explaining predictions of 32 observations with 1 terms of up to 1 degree"
[3] ""
[4] "List of 3"
[5] " $ m :Classes 'data.table' and 'data.frame':\t32 obs. of 2 variables:"
[6] " ..$ intercept: num [1:32] 20.2 20.2 20.2 20.2 20.2 ..."
[7] " ..$ cyl : num [1:32] -0.527 -0.527 7.084 -0.527 -5.302 ..."
[8] " ..- attr(*, \".internal.selfref\")=<externalptr> "
[9] " $ intercept: num 20.2"
[10] " $ x :Classes 'data.table' and 'data.frame':\t32 obs. of 1 variable:"
[11] " ..$ cyl: num [1:32] 6 6 4 6 8 6 8 4 4 6 ..."
[12] " ..- attr(*, \".internal.selfref\")=<externalptr> "
[13] " - attr(*, \"class\")= chr [1:3] \"glex\" \"rpf_components\" \"list\""

1 change: 0 additions & 1 deletion tests/testthat/test-fastpd-equals-empirical.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ test_that("FastPD equals empirical leaf weighting", {
empirical_leaf_weighting <- glex(rf, x, probFunction = "empirical")

expect_equal(fastpd$m, empirical_leaf_weighting$m)
expect_equal(fastpd$intercept, empirical_leaf_weighting$intercept) # Check intercept
})

test_that("FastPD equals empirical leaf weighting for lower interactions", {
Expand Down
6 changes: 3 additions & 3 deletions tests/testthat/test-glex-xgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ test_that("features argument only calculates for given feautures", {
x <- as.matrix(mtcars[, -1])
xg <- xgboost(x, mtcars$mpg, nrounds = 10, verbose = 0)
glexb <- glex(xg, x, features = c("cyl", "disp"))
expect_equal(colnames(glexb$m), c("cyl", "disp", "cyl:disp"))
expect_equal(colnames(glexb$m), c("intercept", "cyl", "disp", "cyl:disp"))
})

test_that("features argument results in same values as without", {
Expand All @@ -50,7 +50,7 @@ test_that("features argument works together with max_interaction", {
x <- as.matrix(mtcars[, -1])
xg <- xgboost(x, mtcars$mpg, nrounds = 10, verbose = 0)
glexb <- glex(xg, x, features = c("cyl", "disp"), max_interaction = 1)
expect_equal(colnames(glexb$m), c("cyl", "disp"))
expect_equal(colnames(glexb$m), c("intercept", "cyl", "disp"))
})

test_that("Prediction is approx. same as sum of decomposition + intercept, xgboost", {
Expand All @@ -59,7 +59,7 @@ test_that("Prediction is approx. same as sum of decomposition + intercept, xgboo
pred_train <- predict(xg, x)

res_train <- glex(xg, x)
expect_equal(res_train$intercept + rowSums(res_train$m),
expect_equal(rowSums(res_train$m),
pred_train,
tolerance = 1e-5
)
Expand Down
Loading