Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: glex
Title: Global Explanations for Tree-Based Models
Version: 0.5.0
Version: 0.5.1
Authors@R: c(
person(c("Marvin", "N."), "Wright", , "cran@wrig.de", role = c("aut", "cre"),
comment = c(ORCID = "0000-0002-8542-6291")),
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# glex 0.5.1

* Fix path-dependent algorithm by computing the proper covers manually
* Allow `glex()` to accept data frames as input


# glex 0.5.0

* Optimize FastPD to be able to handle more features using bitmask representation (#29)
Expand Down
21 changes: 17 additions & 4 deletions R/glex.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,21 @@ glex.rpf <- function(object, x, max_interaction = NULL, features = NULL, ...) {
#' params = list(max_depth = 4, eta = .1),
#' nrounds = 10, verbose = 0)
#' glex(xg, x[27:32, ])
#'
#' glex(xg, mtcars[27:32, ])
#'
#' \dontrun{
#' # Parallel execution
#' doParallel::registerDoParallel()
#' glex(xg, x[27:32, ])
#' }
#' }
glex.xgb.Booster <- function(object, x, max_interaction = NULL, features = NULL, max_background_sample_size = NULL, weighting_method = "fastpd", ...) {
if (!is.matrix(x)) {
if (is.data.frame(x) && any(!sapply(x, is.numeric))) {
stop("Input 'x' contains non-numeric columns. Please ensure all columns are numeric or convert them appropriately (e.g., using model.matrix) to match model training data before calling glex.")
}
x <- as.matrix(x)
}
if (!requireNamespace("xgboost", quietly = TRUE)) {
stop("xgboost needs to be installed: install.packages(\"xgboost\")")
}
Expand Down Expand Up @@ -146,15 +153,21 @@ glex.xgb.Booster <- function(object, x, max_interaction = NULL, features = NULL,
#' num.trees = 5, max.depth = 3,
#' node.stats = TRUE)
#' glex(rf, x[27:32, ])
#'
#' glex(rf, mtcars[27:32, ])
#'
#' \dontrun{
#' # Parallel execution
#' doParallel::registerDoParallel()
#' glex(rf, x[27:32, ])
#' }
#' }
glex.ranger <- function(object, x, max_interaction = NULL, features = NULL, max_background_sample_size = NULL, weighting_method = "fastpd", ...) {

if (!is.matrix(x)) {
if (is.data.frame(x) && any(!sapply(x, is.numeric))) {
stop("Input 'x' contains non-numeric columns. Please ensure all columns are numeric or convert them appropriately (e.g., using model.matrix) to match model training data before calling glex.")
}
x <- as.matrix(x)
}
# To avoid data.table check issues
terminal <- NULL
splitvarName <- NULL
Expand Down Expand Up @@ -209,7 +222,7 @@ tree_fun_path_dependent <- function(tree, trees, x, all_S, max_interaction) {
# Prepare tree_info for C++ function
tree_info <- trees[get("Tree") == tree, ]
tree_info[, "Feature" := get("Feature_num") - 1L] # Adjust to 0-based for C++ bitmasks
to_select <- c("Feature", "Split", "Yes", "No", "Quality", "Cover")
to_select <- c("Feature", "Split", "Yes", "No", "Quality")
tree_mat <- tree_info[, to_select, with = FALSE] # Use with=FALSE to avoid data.table check issues
tree_mat[is.na(tree_mat)] <- -1L # Use -1 for leaf nodes
tree_mat <- as.matrix(tree_mat)
Expand Down
165 changes: 131 additions & 34 deletions src/path_dependent_explain_tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,57 +13,152 @@ void contributeFastPDBitmask(
unsigned int colnum,
unsigned int t_size);

namespace
{
    /**
     * Count, for every node of `tree`, how many rows of `x` pass through it.
     *
     * Each row of `x` is routed from the root (node 0) down the tree: at an
     * internal node the observation's value for the node's split feature is
     * compared against the split value, and the walk continues at the YES or
     * NO child accordingly. Every visited node has its counter incremented.
     *
     * @param x                  Observations, one per row; columns are features.
     * @param tree               Tree matrix, one node per row, with columns
     *                           addressed via Index::FEATURE, Index::SPLIT,
     *                           Index::YES and Index::NO. A feature of -1 marks
     *                           a leaf.
     * @param is_weak_inequality Selects glex::WeakComparison when true and
     *                           glex::StrictComparison otherwise.
     * @return Vector with one entry per tree node holding the number of rows
     *         of `x` that reached that node. All zeros if `x` has no rows.
     *
     * Malformed trees are tolerated rather than reported: a walk stops
     * silently at a node whose feature index is not a column of `x` or whose
     * selected child index is not a valid row of `tree`.
     */
    std::vector<unsigned int> calculate_data_driven_covers(
        const NumericMatrix &x,
        const NumericMatrix &tree,
        bool is_weak_inequality)
    {
        const int n_nodes = tree.nrow();
        const int n_samples = x.nrow();
        const int n_features = x.ncol();
        std::vector<unsigned int> covers(static_cast<std::size_t>(n_nodes), 0u);

        for (int i = 0; i < n_samples; ++i)
        {
            int node = 0; // every observation starts at the root

            while (node < n_nodes)
            {
                ++covers[node];
                NumericMatrix::ConstRow node_data = tree.row(node);
                const int feature_idx = node_data[Index::FEATURE];

                // -1 marks a leaf; any other index outside the columns of x
                // means the tree references a feature x does not have. Either
                // way the walk for this observation ends here.
                if (feature_idx < 0 || feature_idx >= n_features)
                {
                    break;
                }

                const double split_val = node_data[Index::SPLIT];
                const double obs_val = x(i, feature_idx);

                const bool goes_yes = is_weak_inequality
                                          ? glex::WeakComparison::compare(obs_val, split_val)
                                          : glex::StrictComparison::compare(obs_val, split_val);

                const double child = goes_yes
                                         ? static_cast<double>(node_data[Index::YES])
                                         : static_cast<double>(node_data[Index::NO]);

                // Validate while still a double: converting a negative or NaN
                // value to an integer index is undefined behaviour and could
                // otherwise send the walk back to the root or out of bounds.
                if (!(child >= 0.0 && child < static_cast<double>(n_nodes)))
                {
                    break;
                }
                node = static_cast<int>(child);
            }
        }
        return covers;
    }
} // anonymous namespace

template <typename Comparison>
NumericMatrix recursePathDependent(NumericMatrix &x, const NumericMatrix &tree, std::vector<FeatureMask> &U, unsigned int node)
NumericMatrix recursePathDependent(
NumericMatrix &x,
const NumericMatrix &tree,
std::vector<FeatureMask> &U,
unsigned int node,
const std::vector<unsigned int> &node_covers) // Added node_covers
{

NumericMatrix::ConstRow current_node = tree(node, _);
const int current_feature = current_node[Index::FEATURE];

// Dimensions
const unsigned int n = x.nrow();
const unsigned int n_subsets = U.size();

// We'll create an output matrix [n, n_subsets].
// Using no_init to skip zero fill if we plan to overwrite everything.
NumericMatrix mat(no_init(n, n_subsets));

// If leaf, just return value
if (current_feature == -1)
{
double pred = current_node[Index::QUALITY];
std::fill(mat.begin(), mat.end(), pred);
}
else
{
// Call both children, they give a matrix each of all obs and subsets
const unsigned int yes = current_node[Index::YES];
const unsigned int no = current_node[Index::NO];
const double split = current_node[Index::SPLIT];
const double cover_yes = tree(yes, Index::COVER);
const double cover_no = tree(no, Index::COVER);
const double cover_node = current_node[Index::COVER];

NumericMatrix mat_yes = recursePathDependent<Comparison>(x, tree, U, yes);
NumericMatrix mat_no = recursePathDependent<Comparison>(x, tree, U, no);
NumericMatrix mat_yes = recursePathDependent<Comparison>(x, tree, U, yes, node_covers);
NumericMatrix mat_no = recursePathDependent<Comparison>(x, tree, U, no, node_covers);

for (unsigned int j = 0; j < U.size(); ++j)
{
const double *col_yes = &mat_yes(0, j);
const double *col_no = &mat_no(0, j);
double *col_out = &mat(0, j);

// Use bitmask to check if feature is out (not in subset U[j])
if (!hasFeature(U[j], current_feature))
{
double cover_at_current_node = static_cast<double>(node_covers[node]);
double weight_yes = 0.5;
double weight_no = 0.5;

if (cover_at_current_node > 0)
{
// Ensure child indices 'yes' and 'no' are valid for node_covers
// This should hold if tree structure and node_covers calculation are correct
double cover_at_yes_child = (yes < node_covers.size()) ? static_cast<double>(node_covers[yes]) : 0.0;
double cover_at_no_child = (no < node_covers.size()) ? static_cast<double>(node_covers[no]) : 0.0;

weight_yes = cover_at_yes_child / cover_at_current_node;
weight_no = cover_at_no_child / cover_at_current_node;

// Small correction if sum is not exactly 1 due to all samples going one way and 0 the other
// for a child that might not be further split (leaf).
// The covers should reflect actual flow. If node_covers[yes] + node_covers[no] != node_covers[node]
// it means some samples terminated or tree is structured unexpectedly (e.g. missing values not handled here)
// For now, simple division is fine.
}
// If cover_at_current_node is 0, weights remain 0.5/0.5.
// This means if no samples from x reach this node, we average predictions from children.

for (unsigned int i = 0; i < n; ++i)
{
col_out[i] = cover_yes / cover_node * col_yes[i] + cover_no / cover_node * col_no[i];
col_out[i] = weight_yes * col_yes[i] + weight_no * col_no[i];
}
}
else
{
// For subsets where feature is in, split to left/right
for (unsigned int i = 0; i < n; ++i)
{
if (Comparison::compare(x(i, current_feature), split))
Expand All @@ -78,8 +173,6 @@ NumericMatrix recursePathDependent(NumericMatrix &x, const NumericMatrix &tree,
}
}
}

// Return combined matrix
return (mat);
}

Expand All @@ -91,44 +184,50 @@ NumericMatrix explainTreePathDependent(
unsigned int max_interaction,
bool is_weak_inequality)
{
// Get feature count for proper mask sizing
unsigned int n_features = x.ncol();
unsigned int n_nodes = tree.nrow();

// Determine all features encountered in the tree nodes
FeatureMask all_encountered = createEmptyMask(n_features);
for (unsigned int i = 0; i < n_nodes; ++i)
{
NumericMatrix::ConstRow current_node = tree(i, _);
const int current_feature = current_node[Index::FEATURE];
if (current_feature != -1)
{ // -1 indicates a leaf node
all_encountered = setFeature(all_encountered, current_feature); // Adjust to 0-based
NumericMatrix::ConstRow current_node_row = tree(i, _); // Renamed to avoid conflict
const int feature_val = current_node_row[Index::FEATURE]; // Renamed
if (feature_val != -1)
{
all_encountered = setFeature(all_encountered, feature_val);
}
}

// Generate all subsets U of all_encountered features up to max_interaction
std::vector<FeatureMask> U = get_all_subsets_of_mask(all_encountered, max_interaction);

// Explain/expectation/marginalization step using the optimized recursePathDependent
NumericMatrix mat = is_weak_inequality ? recursePathDependent<glex::WeakComparison>(x, tree, U, 0) : recursePathDependent<glex::StrictComparison>(x, tree, U, 0);
// Calculate data-driven covers based on input x and tree structure
std::vector<unsigned int> node_covers = calculate_data_driven_covers(x, tree, is_weak_inequality);

NumericMatrix mat;
if (is_weak_inequality)
{
mat = recursePathDependent<glex::WeakComparison>(x, tree, U, 0, node_covers);
}
else
{
mat = recursePathDependent<glex::StrictComparison>(x, tree, U, 0, node_covers);
}

// Process to_explain_list and calculate m_all
std::vector<std::set<unsigned int>> to_explain; // List of S'es to explain
std::vector<std::set<unsigned int>> to_explain;
unsigned int to_explain_size = to_explain_list.size();
NumericMatrix m_all = NumericMatrix(x.nrow(), to_explain_size);
CharacterVector m_all_col_names = CharacterVector(to_explain_size);
CharacterVector x_col_names = colnames(x);

std::set<unsigned int> needToComputePDfunctionsFor; // The set of coords we need to compute PD functions of
std::set<unsigned int> needToComputePDfunctionsFor;
for (int S_idx = 0; S_idx < to_explain_size; S_idx++)
{
std::set<unsigned int> S = std::set<unsigned int>(
as<IntegerVector>(to_explain_list[S_idx]).begin(),
as<IntegerVector>(to_explain_list[S_idx]).end());
to_explain.push_back(S);

int k = S.size(); // Declare k here
int k = S.size();
if (k != 0)
{
auto it = S.begin();
Expand All @@ -139,9 +238,7 @@ NumericMatrix explainTreePathDependent(
m_all_col_names[S_idx] = oss.str();
}

// Calculate contribution for S using inclusion-exclusion
// Iterate over all subsets V of S (using bitmask logic)
FeatureMask S_mask = setToBitmask(S); // S_set is 0-based
FeatureMask S_mask = setToBitmask(S);
bool is_subset = false;
for (const auto &U_mask : U)
{
Expand Down
18 changes: 18 additions & 0 deletions tests/testthat/test-glex-dataframe.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
library(xgboost)
library(ranger)

test_that("Explaining XGBoost model with dataframe input works", {
  # Fit the booster on the numeric predictor matrix (mtcars without its first column).
  predictors <- as.matrix(mtcars[, -1])
  booster <- xgboost(predictors, mtcars$mpg, nrounds = 10, verbose = 0)

  # Explain the same model once from the matrix and once from the raw data frame.
  from_matrix <- glex(booster, predictors)
  from_dataframe <- glex(booster, mtcars)

  # Both input types must produce the same decomposition terms.
  expect_equal(from_matrix$m, from_dataframe$m)
})

test_that("Explaining Ranger model with dataframe input works", {
  # The reference input must be a real matrix; `mtcars[, -1]` on its own is
  # still a data frame, which would make this a data-frame-vs-data-frame check.
  x <- as.matrix(mtcars[, -1])
  rf <- ranger(mpg ~ ., data = mtcars, node.stats = TRUE, num.trees = 10)

  # Explain the same forest from the matrix and from the raw data frame.
  glex_matrix <- glex(rf, x)
  glex_dataframe <- glex(rf, mtcars)

  # Both input types must produce the same decomposition terms.
  expect_equal(glex_matrix$m, glex_dataframe$m)
})
Loading