diff --git a/DESCRIPTION b/DESCRIPTION index b1bcbcb..0e37c8a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,7 +1,7 @@ Type: Package Package: glex Title: Global Explanations for Tree-Based Models -Version: 0.5.0 +Version: 0.5.1 Authors@R: c( person(c("Marvin", "N."), "Wright", , "cran@wrig.de", role = c("aut", "cre"), comment = c(ORCID = "0000-0002-8542-6291")), diff --git a/NEWS.md b/NEWS.md index 4ddb438..5f94ca8 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,9 @@ +# glex 0.5.1 + +* Fix path-dependent algorithm by computing the proper covers manually +* Allow `glex()` to accept data frames as input + + # glex 0.5.0 * Optimize FastPD to be able to handle more features using bitmask represenation (#29) diff --git a/R/glex.R b/R/glex.R index 0d884e6..6b385d6 100644 --- a/R/glex.R +++ b/R/glex.R @@ -88,7 +88,8 @@ glex.rpf <- function(object, x, max_interaction = NULL, features = NULL, ...) { #' params = list(max_depth = 4, eta = .1), #' nrounds = 10, verbose = 0) #' glex(xg, x[27:32, ]) -#' +#' glex(xg, mtcars[27:32, ]) +#' #' \dontrun{ #' # Parallel execution #' doParallel::registerDoParallel() @@ -96,6 +97,12 @@ glex.rpf <- function(object, x, max_interaction = NULL, features = NULL, ...) { #' } #' } glex.xgb.Booster <- function(object, x, max_interaction = NULL, features = NULL, max_background_sample_size = NULL, weighting_method = "fastpd", ...) { + if (!is.matrix(x)) { + if (is.data.frame(x) && any(!sapply(x, is.numeric))) { + stop("Input 'x' contains non-numeric columns. 
Please ensure all columns are numeric or convert them appropriately (e.g., using model.matrix) to match model training data before calling glex.") + } + x <- as.matrix(x) + } if (!requireNamespace("xgboost", quietly = TRUE)) { stop("xgboost needs to be installed: install.packages(\"xgboost\")") } @@ -146,7 +153,8 @@ glex.xgb.Booster <- function(object, x, max_interaction = NULL, features = NULL, #' num.trees = 5, max.depth = 3, #' node.stats = TRUE) #' glex(rf, x[27:32, ]) -#' +#' glex(rf, mtcars[27:32, ]) +#' #' \dontrun{ #' # Parallel execution #' doParallel::registerDoParallel() @@ -154,7 +162,12 @@ glex.xgb.Booster <- function(object, x, max_interaction = NULL, features = NULL, #' } #' } glex.ranger <- function(object, x, max_interaction = NULL, features = NULL, max_background_sample_size = NULL, weighting_method = "fastpd", ...) { - + if (!is.matrix(x)) { + if (is.data.frame(x) && any(!sapply(x, is.numeric))) { + stop("Input 'x' contains non-numeric columns. Please ensure all columns are numeric or convert them appropriately (e.g., using model.matrix) to match model training data before calling glex.") + } + x <- as.matrix(x) + } # To avoid data.table check issues terminal <- NULL splitvarName <- NULL @@ -209,7 +222,7 @@ tree_fun_path_dependent <- function(tree, trees, x, all_S, max_interaction) { # Prepare tree_info for C++ function tree_info <- trees[get("Tree") == tree, ] tree_info[, "Feature" := get("Feature_num") - 1L] # Adjust to 0-based for C++ bitmasks - to_select <- c("Feature", "Split", "Yes", "No", "Quality", "Cover") + to_select <- c("Feature", "Split", "Yes", "No", "Quality") tree_mat <- tree_info[, to_select, with = FALSE] # Use with=FALSE to avoid data.table check issues tree_mat[is.na(tree_mat)] <- -1L # Use -1 for leaf nodes tree_mat <- as.matrix(tree_mat) diff --git a/src/path_dependent_explain_tree.cpp b/src/path_dependent_explain_tree.cpp index d6d4e84..a4db517 100644 --- a/src/path_dependent_explain_tree.cpp +++ 
b/src/path_dependent_explain_tree.cpp @@ -13,22 +13,100 @@ void contributeFastPDBitmask( unsigned int colnum, unsigned int t_size); +namespace +{ + std::vector calculate_data_driven_covers( + const NumericMatrix &x, + const NumericMatrix &tree, + bool is_weak_inequality) + { + unsigned int n_nodes = tree.nrow(); + unsigned int n_samples = x.nrow(); + std::vector covers(n_nodes, 0); + + if (n_samples == 0 || n_nodes == 0) + { + return covers; + } + + for (unsigned int i = 0; i < n_samples; ++i) + { + // Route row i of x from the root towards a leaf and count it at every node it + // visits, so covers[k] ends up as the number of rows of x that reach node k. + // x and tree are only read in this loop; covers is the only thing written. + // The comparison used for the yes/no decision is selected by is_weak_inequality. + + unsigned int current_node_idx = 0; // Start at root + + while (current_node_idx < n_nodes) + { // Ensure current_node_idx is valid + covers[current_node_idx]++; + NumericMatrix::ConstRow node_data = tree.row(current_node_idx); + const int feature_idx = node_data[Index::FEATURE]; + + if (feature_idx == -1) + { // Leaf node + break; + } + + const double split_val = node_data[Index::SPLIT]; + // Ensure feature_idx is valid for x.ncol() + if (static_cast(feature_idx) >= x.ncol() || feature_idx < 0) + { + // The tree splits on a column index that x does not have. + // The walk for this row ends here without an error or warning, so + // nodes below this one get no count from it. NOTE(review): consider Rcpp::stop.
+ break; + } + const double obs_val = x(i, feature_idx); + + bool goes_yes; + if (is_weak_inequality) + { + goes_yes = glex::WeakComparison::compare(obs_val, split_val); + } + else + { + goes_yes = glex::StrictComparison::compare(obs_val, split_val); + } + + if (goes_yes) + { + current_node_idx = node_data[Index::YES]; + } + else + { + current_node_idx = node_data[Index::NO]; + } + // Out-of-range child row; feature_idx != -1 always holds here because leaves broke out above + if (current_node_idx >= n_nodes && feature_idx != -1) + { + // Invalid child index from tree, stop this path. + // This indicates a malformed tree. + break; + } + } + } + return covers; + } +} // anonymous namespace + template -NumericMatrix recursePathDependent(NumericMatrix &x, const NumericMatrix &tree, std::vector &U, unsigned int node) +NumericMatrix recursePathDependent( + NumericMatrix &x, + const NumericMatrix &tree, + std::vector &U, + unsigned int node, + const std::vector &node_covers) // per-node sample counts, indexed like the rows of tree { NumericMatrix::ConstRow current_node = tree(node, _); const int current_feature = current_node[Index::FEATURE]; - // Dimensions const unsigned int n = x.nrow(); const unsigned int n_subsets = U.size(); - - // We'll create an output matrix [n, n_subsets]. - // Using no_init to skip zero fill if we plan to overwrite everything. 
NumericMatrix mat(no_init(n, n_subsets)); - // If leaf, just return value if (current_feature == -1) { double pred = current_node[Index::QUALITY]; @@ -36,16 +114,12 @@ NumericMatrix recursePathDependent(NumericMatrix &x, const NumericMatrix &tree, } else { - // Call both children, they give a matrix each of all obs and subsets const unsigned int yes = current_node[Index::YES]; const unsigned int no = current_node[Index::NO]; const double split = current_node[Index::SPLIT]; - const double cover_yes = tree(yes, Index::COVER); - const double cover_no = tree(no, Index::COVER); - const double cover_node = current_node[Index::COVER]; - NumericMatrix mat_yes = recursePathDependent(x, tree, U, yes); - NumericMatrix mat_no = recursePathDependent(x, tree, U, no); + NumericMatrix mat_yes = recursePathDependent(x, tree, U, yes, node_covers); + NumericMatrix mat_no = recursePathDependent(x, tree, U, no, node_covers); for (unsigned int j = 0; j < U.size(); ++j) { @@ -53,17 +127,38 @@ NumericMatrix recursePathDependent(NumericMatrix &x, const NumericMatrix &tree, const double *col_no = &mat_no(0, j); double *col_out = &mat(0, j); - // Use bitmask to check if feature is out (not in subset U[j]) if (!hasFeature(U[j], current_feature)) { + double cover_at_current_node = static_cast(node_covers[node]); + double weight_yes = 0.5; + double weight_no = 0.5; + + if (cover_at_current_node > 0) + { + // Ensure child indices 'yes' and 'no' are valid for node_covers + // This should hold if tree structure and node_covers calculation are correct + double cover_at_yes_child = (yes < node_covers.size()) ? static_cast(node_covers[yes]) : 0.0; + double cover_at_no_child = (no < node_covers.size()) ? 
static_cast(node_covers[no]) : 0.0; + + weight_yes = cover_at_yes_child / cover_at_current_node; + weight_no = cover_at_no_child / cover_at_current_node; + + // No renormalisation is applied: the weights are plain cover ratios and sum to 1 + // only when node_covers[yes] + node_covers[no] == node_covers[node]. + // calculate_data_driven_covers ends a row's walk early on an out-of-range feature + // or child index; in that case the child covers add up to less than the cover of + // this node and the two weights sum to less than 1. + } + // If cover_at_current_node is 0, weights remain 0.5/0.5. + // This means if no samples from x reach this node, we average predictions from children. + for (unsigned int i = 0; i < n; ++i) { - col_out[i] = cover_yes / cover_node * col_yes[i] + cover_no / cover_node * col_no[i]; + col_out[i] = weight_yes * col_yes[i] + weight_no * col_no[i]; } } else { - // For subsets where feature is in, split to left/right for (unsigned int i = 0; i < n; ++i) { if (Comparison::compare(x(i, current_feature), split)) @@ -78,8 +173,6 @@ NumericMatrix recursePathDependent(NumericMatrix &x, const NumericMatrix &tree, } } } - - // Return combined matrix return (mat); } @@ -91,36 +184,42 @@ NumericMatrix explainTreePathDependent( unsigned int max_interaction, bool is_weak_inequality) { - // Get feature count for proper mask sizing unsigned int n_features = x.ncol(); unsigned int n_nodes = tree.nrow(); - // Determine all features encountered in the tree nodes FeatureMask all_encountered = createEmptyMask(n_features); for (unsigned int i = 0; i < n_nodes; ++i) { - NumericMatrix::ConstRow current_node = tree(i, _); - const int current_feature = current_node[Index::FEATURE]; - if (current_feature != -1) - { // -1 indicates a leaf node - all_encountered = setFeature(all_encountered, current_feature); // Adjust to 0-based + NumericMatrix::ConstRow current_node_row = tree(i, _); // row i of the tree matrix + const int feature_val = 
current_node_row[Index::FEATURE]; // -1 marks a leaf row + if (feature_val != -1) + { + all_encountered = setFeature(all_encountered, feature_val); } } - // Generate all subsets U of all_encountered features up to max_interaction std::vector U = get_all_subsets_of_mask(all_encountered, max_interaction); - // Explain/expectation/marginalization step using the optimized recursePathDependent - NumericMatrix mat = is_weak_inequality ? recursePathDependent(x, tree, U, 0) : recursePathDependent(x, tree, U, 0); + // Count, for each node of this tree, how many rows of x reach it; recursePathDependent uses these counts as child weights + std::vector node_covers = calculate_data_driven_covers(x, tree, is_weak_inequality); + + NumericMatrix mat; + if (is_weak_inequality) + { + mat = recursePathDependent(x, tree, U, 0, node_covers); + } + else + { + mat = recursePathDependent(x, tree, U, 0, node_covers); + } - // Process to_explain_list and calculate m_all - std::vector> to_explain; // List of S'es to explain + std::vector> to_explain; unsigned int to_explain_size = to_explain_list.size(); NumericMatrix m_all = NumericMatrix(x.nrow(), to_explain_size); CharacterVector m_all_col_names = CharacterVector(to_explain_size); CharacterVector x_col_names = colnames(x); - std::set needToComputePDfunctionsFor; // The set of coords we need to compute PD functions of + std::set needToComputePDfunctionsFor; for (int S_idx = 0; S_idx < to_explain_size; S_idx++) { std::set S = std::set( @@ -128,7 +227,7 @@ NumericMatrix explainTreePathDependent( as(to_explain_list[S_idx]).end()); to_explain.push_back(S); - int k = S.size(); // Declare k here + int k = S.size(); if (k != 0) { auto it = S.begin(); @@ -139,9 +238,7 @@ NumericMatrix explainTreePathDependent( m_all_col_names[S_idx] = oss.str(); } - // Calculate contribution for S using inclusion-exclusion - // Iterate over all subsets V of S (using bitmask logic) - FeatureMask S_mask = setToBitmask(S); // S_set is 0-based + FeatureMask S_mask = setToBitmask(S); bool is_subset = false; for (const auto &U_mask : 
U) { diff --git a/tests/testthat/test-glex-dataframe.R b/tests/testthat/test-glex-dataframe.R new file mode 100644 index 0000000..02abea7 --- /dev/null +++ b/tests/testthat/test-glex-dataframe.R @@ -0,0 +1,18 @@ +library(xgboost) +library(ranger) + +test_that("Explaining XGBoost model with dataframe input works", { + x <- as.matrix(mtcars[, -1]) + xg <- xgboost(x, mtcars$mpg, nrounds = 10, verbose = 0) + glex_matrix <- glex(xg, x) + glex_dataframe <- glex(xg, mtcars) + expect_equal(glex_matrix$m, glex_dataframe$m) +}) + +test_that("Explaining Ranger model with dataframe input works", { + x <- mtcars[, -1] + rf <- ranger(mpg ~ ., data = mtcars, node.stats = TRUE, num.trees = 10) + glex_matrix <- glex(rf, x) + glex_dataframe <- glex(rf, mtcars) + expect_equal(glex_matrix$m, glex_dataframe$m) +})