From c67724974bd1b3091e79ce277674ba2fb37c0771 Mon Sep 17 00:00:00 2001 From: Jinyang Liu Date: Fri, 15 Nov 2024 11:29:08 +0100 Subject: [PATCH 01/17] Split FastPD expectation into new file --- src/fastpd.cpp | 137 -------------------------------- src/fastpd_expectation.cpp | 157 +++++++++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+), 137 deletions(-) create mode 100644 src/fastpd_expectation.cpp diff --git a/src/fastpd.cpp b/src/fastpd.cpp index 6512ba9..dd4fd44 100644 --- a/src/fastpd.cpp +++ b/src/fastpd.cpp @@ -210,132 +210,6 @@ LeafData augmentTreeXgboost(NumericMatrix &tree, NumericMatrix &dataset) return leaf_data; } -double augmentExpectationRanger(NumericVector &x, NumericMatrix &tree, NumericVector &to_explain, LeafData &leaf_data) -{ - std::stack to_process; - to_process.push(0); - - double result = 0; - std::set to_explain_set(to_explain.begin(), to_explain.end()); - - while (!to_process.empty()) - { - int node_idx = to_process.top(); - to_process.pop(); - - NumericMatrix::Row current_node = tree(node_idx, _); - int current_feature = current_node[Index::FEATURE]; - double split = current_node[Index::SPLIT]; - - if (current_feature == -1) - { - std::set to_marginalize = {}; - if (!to_explain_set.empty()) - std::set_intersection( - to_explain_set.begin(), to_explain_set.end(), - leaf_data.encountered[node_idx].begin(), leaf_data.encountered[node_idx].end(), - std::inserter(to_marginalize, to_marginalize.begin())); - - double p = leaf_data.leafProbs[node_idx][to_marginalize]; - result += tree(node_idx, Index::QUALITY) * p; - continue; - } - - if (std::find(to_explain.begin(), to_explain.end(), current_feature) != to_explain.end()) - { - if (x[current_feature] <= split) - { - to_process.push(current_node[Index::YES]); - } - else - { - to_process.push(current_node[Index::NO]); - } - } - else - { - to_process.push(current_node[Index::YES]); - to_process.push(current_node[Index::NO]); - } - } - - return result; -} - -double 
augmentExpectationXgboost(NumericVector &x, NumericMatrix &tree, NumericVector &to_explain, LeafData &leaf_data) -{ - std::stack to_process; - to_process.push(0); - - double result = 0; - std::set to_explain_set(to_explain.begin(), to_explain.end()); - - while (!to_process.empty()) - { - int node_idx = to_process.top(); - to_process.pop(); - - NumericMatrix::Row current_node = tree(node_idx, _); - int current_feature = current_node[Index::FEATURE]; - double split = current_node[Index::SPLIT]; - - if (current_feature == -1) - { - std::set to_marginalize = {}; - if (!to_explain_set.empty()) - std::set_intersection( - to_explain_set.begin(), to_explain_set.end(), - leaf_data.encountered[node_idx].begin(), leaf_data.encountered[node_idx].end(), - std::inserter(to_marginalize, to_marginalize.begin())); - - double p = leaf_data.leafProbs[node_idx][to_marginalize]; - result += tree(node_idx, Index::QUALITY) * p; - continue; - } - - if (std::find(to_explain.begin(), to_explain.end(), current_feature) != to_explain.end()) - { - if (x[current_feature] < split) - { - to_process.push(current_node[Index::YES]); - } - else - { - to_process.push(current_node[Index::NO]); - } - } - else - { - to_process.push(current_node[Index::YES]); - to_process.push(current_node[Index::NO]); - } - } - - return result; -} - -// [[Rcpp::export]] -double augmentAndTakeExpectation(NumericVector &x, NumericMatrix &dataset, NumericMatrix &tree, NumericVector &to_explain, bool is_ranger) -{ - LeafData leaf_data = is_ranger ? augmentTreeRanger(tree, dataset) : augmentTreeXgboost(tree, dataset); - return is_ranger ? augmentExpectationRanger(x, tree, to_explain, leaf_data) : augmentExpectationXgboost(x, tree, to_explain, leaf_data); -} - -// [[Rcpp::export]] -XPtr augmentTree(NumericMatrix &tree, NumericMatrix &dataset, bool is_ranger) -{ - LeafData *leaf_data = new LeafData(is_ranger ? 
augmentTreeRanger(tree, dataset) : augmentTreeXgboost(tree, dataset)); // Dynamically allocate - XPtr ptr(leaf_data, true); // true enables automatic memory management - return ptr; -} - -// [[Rcpp::export]] -double augmentExpectation(NumericVector &x, NumericMatrix &tree, NumericVector &to_explain, SEXP leaf_data_ptr, bool is_ranger) -{ - const Rcpp::XPtr leaf_data(leaf_data_ptr); - return is_ranger ? augmentExpectationRanger(x, tree, to_explain, *leaf_data) : augmentExpectationXgboost(x, tree, to_explain, *leaf_data); -} - Rcpp::NumericMatrix recurseMarginalizeURanger( Rcpp::NumericMatrix &x, NumericMatrix &tree, std::vector> &U, unsigned int node, @@ -468,17 +342,6 @@ Rcpp::NumericMatrix recurseMarginalizeUXgboost( return mat; } -// [[Rcpp::export]] -Rcpp::NumericMatrix marginalizeAllSplittedSubsetsinTree( - Rcpp::NumericMatrix &x, - NumericMatrix &tree, - bool is_ranger) -{ - LeafData leaf_data = is_ranger ? augmentTreeRanger(tree, x) : augmentTreeXgboost(tree, x); - std::vector> U = get_all_subsets_(leaf_data.all_encountered); - return is_ranger ? 
recurseMarginalizeURanger(x, tree, U, 0, leaf_data) : recurseMarginalizeUXgboost(x, tree, U, 0, leaf_data); -} - void contributeFastPD( Rcpp::NumericMatrix &mat, Rcpp::NumericMatrix &m_all, diff --git a/src/fastpd_expectation.cpp b/src/fastpd_expectation.cpp new file mode 100644 index 0000000..9b1df0d --- /dev/null +++ b/src/fastpd_expectation.cpp @@ -0,0 +1,157 @@ + +#include +#include +#include +#include "../inst/include/glex.h" + +using namespace Rcpp; +LeafData augmentTreeRanger(NumericMatrix &tree, NumericMatrix &dataset); +LeafData augmentTreeXgboost(NumericMatrix &tree, NumericMatrix &dataset); +std::vector> get_all_subsets_(std::set &set); + +Rcpp::NumericMatrix recurseMarginalizeURanger( + Rcpp::NumericMatrix &x, NumericMatrix &tree, + std::vector> &U, unsigned int node, + LeafData &leaf_data); + +Rcpp::NumericMatrix recurseMarginalizeUXgboost( + Rcpp::NumericMatrix &x, NumericMatrix &tree, + std::vector> &U, unsigned int node, + LeafData &leaf_data); + +// [[Rcpp::export]] +double augmentAndTakeExpectation(NumericVector &x, NumericMatrix &dataset, NumericMatrix &tree, NumericVector &to_explain, bool is_ranger) +{ + LeafData leaf_data = is_ranger ? augmentTreeRanger(tree, dataset) : augmentTreeXgboost(tree, dataset); + return is_ranger ? augmentExpectationRanger(x, tree, to_explain, leaf_data) : augmentExpectationXgboost(x, tree, to_explain, leaf_data); +} + +// [[Rcpp::export]] +XPtr augmentTree(NumericMatrix &tree, NumericMatrix &dataset, bool is_ranger) +{ + LeafData *leaf_data = new LeafData(is_ranger ? augmentTreeRanger(tree, dataset) : augmentTreeXgboost(tree, dataset)); // Dynamically allocate + XPtr ptr(leaf_data, true); // true enables automatic memory management + return ptr; +} + +// [[Rcpp::export]] +double augmentExpectation(NumericVector &x, NumericMatrix &tree, NumericVector &to_explain, SEXP leaf_data_ptr, bool is_ranger) +{ + const Rcpp::XPtr leaf_data(leaf_data_ptr); + return is_ranger ? 
augmentExpectationRanger(x, tree, to_explain, *leaf_data) : augmentExpectationXgboost(x, tree, to_explain, *leaf_data); +} + +// [[Rcpp::export]] +Rcpp::NumericMatrix marginalizeAllSplittedSubsetsinTree( + Rcpp::NumericMatrix &x, + NumericMatrix &tree, + bool is_ranger) +{ + LeafData leaf_data = is_ranger ? augmentTreeRanger(tree, x) : augmentTreeXgboost(tree, x); + std::vector> U = get_all_subsets_(leaf_data.all_encountered); + return is_ranger ? recurseMarginalizeURanger(x, tree, U, 0, leaf_data) : recurseMarginalizeUXgboost(x, tree, U, 0, leaf_data); +} + +double augmentExpectationRanger(NumericVector &x, NumericMatrix &tree, NumericVector &to_explain, LeafData &leaf_data) +{ + std::stack to_process; + to_process.push(0); + + double result = 0; + std::set to_explain_set(to_explain.begin(), to_explain.end()); + + while (!to_process.empty()) + { + int node_idx = to_process.top(); + to_process.pop(); + + NumericMatrix::Row current_node = tree(node_idx, _); + int current_feature = current_node[Index::FEATURE]; + double split = current_node[Index::SPLIT]; + + if (current_feature == -1) + { + std::set to_marginalize = {}; + if (!to_explain_set.empty()) + std::set_intersection( + to_explain_set.begin(), to_explain_set.end(), + leaf_data.encountered[node_idx].begin(), leaf_data.encountered[node_idx].end(), + std::inserter(to_marginalize, to_marginalize.begin())); + + double p = leaf_data.leafProbs[node_idx][to_marginalize]; + result += tree(node_idx, Index::QUALITY) * p; + continue; + } + + if (std::find(to_explain.begin(), to_explain.end(), current_feature) != to_explain.end()) + { + if (x[current_feature] <= split) + { + to_process.push(current_node[Index::YES]); + } + else + { + to_process.push(current_node[Index::NO]); + } + } + else + { + to_process.push(current_node[Index::YES]); + to_process.push(current_node[Index::NO]); + } + } + + return result; +} + +double augmentExpectationXgboost(NumericVector &x, NumericMatrix &tree, NumericVector &to_explain, LeafData 
&leaf_data) +{ + std::stack to_process; + to_process.push(0); + + double result = 0; + std::set to_explain_set(to_explain.begin(), to_explain.end()); + + while (!to_process.empty()) + { + int node_idx = to_process.top(); + to_process.pop(); + + NumericMatrix::Row current_node = tree(node_idx, _); + int current_feature = current_node[Index::FEATURE]; + double split = current_node[Index::SPLIT]; + + if (current_feature == -1) + { + std::set to_marginalize = {}; + if (!to_explain_set.empty()) + std::set_intersection( + to_explain_set.begin(), to_explain_set.end(), + leaf_data.encountered[node_idx].begin(), leaf_data.encountered[node_idx].end(), + std::inserter(to_marginalize, to_marginalize.begin())); + + double p = leaf_data.leafProbs[node_idx][to_marginalize]; + result += tree(node_idx, Index::QUALITY) * p; + continue; + } + + if (std::find(to_explain.begin(), to_explain.end(), current_feature) != to_explain.end()) + { + if (x[current_feature] < split) + { + to_process.push(current_node[Index::YES]); + } + else + { + to_process.push(current_node[Index::NO]); + } + } + else + { + to_process.push(current_node[Index::YES]); + to_process.push(current_node[Index::NO]); + } + } + + return result; +} From 2b9387ea3a43e41f2ae46a4ecece634cb493c2be Mon Sep 17 00:00:00 2001 From: Jinyang Liu Date: Fri, 15 Nov 2024 11:37:35 +0100 Subject: [PATCH 02/17] Split rest into multiple files --- src/fastpd.cpp | 420 ------------------------------------ src/fastpd_augment.cpp | 208 ++++++++++++++++++ src/fastpd_contribute.cpp | 36 ++++ src/fastpd_explain_tree.cpp | 70 ++++++ src/fastpd_recurse.cpp | 138 ++++++++++++ 5 files changed, 452 insertions(+), 420 deletions(-) delete mode 100644 src/fastpd.cpp create mode 100644 src/fastpd_augment.cpp create mode 100644 src/fastpd_contribute.cpp create mode 100644 src/fastpd_explain_tree.cpp create mode 100644 src/fastpd_recurse.cpp diff --git a/src/fastpd.cpp b/src/fastpd.cpp deleted file mode 100644 index dd4fd44..0000000 --- a/src/fastpd.cpp 
+++ /dev/null @@ -1,420 +0,0 @@ -#include -#include -#include -#include "../inst/include/glex.h" - -using namespace Rcpp; -std::vector> get_all_subsets_(std::set &set); - -void augmentTreeRecurseStepRanger( - AugmentedData passed_down, - LeafData &leaf_data, - NumericMatrix &tree, - NumericMatrix &dataset, - unsigned int node) -{ - NumericMatrix::Row current_node = tree(node, _); - - int current_feature = current_node[Index::FEATURE]; - double split = current_node[Index::SPLIT]; - - if (current_feature == -1) - { - // Leaf node - leaf_data.encountered[node] = passed_down.encountered; - leaf_data.leafProbs[node] = ProbsMap(); - for (const auto &[subset, path_dat] : passed_down.pathData) - { - leaf_data.leafProbs[node][subset] = (double)path_dat.size() / dataset.nrow(); - } - return; - } - else if (leaf_data.all_encountered.find(current_feature) == leaf_data.all_encountered.end()) - { - leaf_data.all_encountered.insert(current_feature); - } - - AugmentedData passed_down_yes = { - .encountered = passed_down.encountered, - .pathData = PathData(), - }; - AugmentedData passed_down_no = { - .encountered = passed_down.encountered, - .pathData = PathData(), - }; - - for (const auto &[subset, path_dat] : passed_down.pathData) - { - if (subset.find(current_feature) != subset.end()) - { - // Feature is in the path - passed_down_yes.pathData[subset] = path_dat; - passed_down_no.pathData[subset] = path_dat; - } - else - { - // Feature is not in the path - passed_down_yes.pathData[subset] = std::vector(); - passed_down_no.pathData[subset] = std::vector(); - for (int i : path_dat) - { - if (dataset(i, current_feature) <= split) - { - passed_down_yes.pathData[subset].push_back(i); - } - else - { - passed_down_no.pathData[subset].push_back(i); - } - } - } - } - - if (passed_down.encountered.find(current_feature) == passed_down.encountered.end()) - { - passed_down_yes.encountered.insert(current_feature); - passed_down_no.encountered.insert(current_feature); - - for (const auto 
&[subset, path_dat] : passed_down.pathData) - { - std::set to_add = subset; - to_add.insert(current_feature); - - passed_down_yes.pathData[to_add] = path_dat; - passed_down_no.pathData[to_add] = path_dat; - } - } - - unsigned int yes = current_node[Index::YES]; - unsigned int no = current_node[Index::NO]; - - augmentTreeRecurseStepRanger(passed_down_yes, leaf_data, tree, dataset, yes); - augmentTreeRecurseStepRanger(passed_down_no, leaf_data, tree, dataset, no); -} - -void augmentTreeRecurseStepXgboost( - AugmentedData passed_down, - LeafData &leaf_data, - NumericMatrix &tree, - NumericMatrix &dataset, - unsigned int node) -{ - NumericMatrix::Row current_node = tree(node, _); - - int current_feature = current_node[Index::FEATURE]; - double split = current_node[Index::SPLIT]; - - if (current_feature == -1) - { - // Leaf node - leaf_data.encountered[node] = passed_down.encountered; - leaf_data.leafProbs[node] = ProbsMap(); - for (const auto &[subset, path_dat] : passed_down.pathData) - { - leaf_data.leafProbs[node][subset] = (double)path_dat.size() / dataset.nrow(); - } - return; - } - else if (leaf_data.all_encountered.find(current_feature) == leaf_data.all_encountered.end()) - { - leaf_data.all_encountered.insert(current_feature); - } - - AugmentedData passed_down_yes = { - .encountered = passed_down.encountered, - .pathData = PathData(), - }; - AugmentedData passed_down_no = { - .encountered = passed_down.encountered, - .pathData = PathData(), - }; - - for (const auto &[subset, path_dat] : passed_down.pathData) - { - if (subset.find(current_feature) != subset.end()) - { - // Feature is in the path - passed_down_yes.pathData[subset] = path_dat; - passed_down_no.pathData[subset] = path_dat; - } - else - { - // Feature is not in the path - passed_down_yes.pathData[subset] = std::vector(); - passed_down_no.pathData[subset] = std::vector(); - for (int i : path_dat) - { - if (dataset(i, current_feature) < split) - { - passed_down_yes.pathData[subset].push_back(i); - } - 
else - { - passed_down_no.pathData[subset].push_back(i); - } - } - } - } - - if (passed_down.encountered.find(current_feature) == passed_down.encountered.end()) - { - passed_down_yes.encountered.insert(current_feature); - passed_down_no.encountered.insert(current_feature); - - for (const auto &[subset, path_dat] : passed_down.pathData) - { - std::set to_add = subset; - to_add.insert(current_feature); - - passed_down_yes.pathData[to_add] = path_dat; - passed_down_no.pathData[to_add] = path_dat; - } - } - - unsigned int yes = current_node[Index::YES]; - unsigned int no = current_node[Index::NO]; - - augmentTreeRecurseStepXgboost(passed_down_yes, leaf_data, tree, dataset, yes); - augmentTreeRecurseStepXgboost(passed_down_no, leaf_data, tree, dataset, no); -} - -LeafData augmentTreeRanger(NumericMatrix &tree, NumericMatrix &dataset) -{ - AugmentedData result; - - AugmentedData to_pass_down = { - .encountered = std::set(), - .pathData = PathData(), - }; - to_pass_down.pathData[{}] = std::vector(dataset.nrow()); - std::iota(to_pass_down.pathData[{}].begin(), to_pass_down.pathData[{}].end(), 0); - LeafData leaf_data = LeafData(); - - augmentTreeRecurseStepRanger(to_pass_down, leaf_data, tree, dataset, 0); - return leaf_data; -} - -LeafData augmentTreeXgboost(NumericMatrix &tree, NumericMatrix &dataset) -{ - AugmentedData result; - - AugmentedData to_pass_down = { - .encountered = std::set(), - .pathData = PathData(), - }; - to_pass_down.pathData[{}] = std::vector(dataset.nrow()); - std::iota(to_pass_down.pathData[{}].begin(), to_pass_down.pathData[{}].end(), 0); - LeafData leaf_data = LeafData(); - - augmentTreeRecurseStepXgboost(to_pass_down, leaf_data, tree, dataset, 0); - return leaf_data; -} - -Rcpp::NumericMatrix recurseMarginalizeURanger( - Rcpp::NumericMatrix &x, NumericMatrix &tree, - std::vector> &U, unsigned int node, - LeafData &leaf_data) -{ - NumericMatrix::Row current_node = tree(node, _); - int current_feature = current_node[Index::FEATURE]; - - // Start 
with all 0 - unsigned int n = x.nrow(); - Rcpp::NumericMatrix mat(n, U.size()); - - // If leaf, just return value - if (current_feature == -1) - { - for (unsigned int j = 0; j < U.size(); ++j) - { - std::set to_explain = {}; - std::set_difference( - leaf_data.encountered[node].begin(), leaf_data.encountered[node].end(), - U[j].begin(), U[j].end(), - std::inserter(to_explain, to_explain.begin())); - - double p = leaf_data.leafProbs[node][to_explain]; - double quality = tree(node, Index::QUALITY); - double expected_value = quality * p; - - Rcpp::NumericMatrix::Column to_fill = mat(Rcpp::_, j); - std::fill(to_fill.begin(), to_fill.end(), expected_value); - } - } - else - { - unsigned int yes = current_node[Index::YES]; - unsigned int no = current_node[Index::NO]; - - // Call both children, they give a matrix each of all obs and subsets - Rcpp::NumericMatrix mat_yes = recurseMarginalizeURanger(x, tree, U, yes, leaf_data); - Rcpp::NumericMatrix mat_no = recurseMarginalizeURanger(x, tree, U, no, leaf_data); - - for (unsigned int j = 0; j < U.size(); ++j) - { - // Is splitting feature out in this subset? - - if (U[j].find(current_feature) != U[j].end()) - { - // For subsets where feature is out, weighted average of left/right - for (unsigned int i = 0; i < n; ++i) - mat(i, j) += mat_yes(i, j) + mat_no(i, j); - } - else - { - double split = current_node[Index::SPLIT]; - // For subsets where feature is in, split to left/right - for (unsigned int i = 0; i < n; ++i) - { - mat(i, j) += (x(i, current_feature) <= split) ? 
mat_yes(i, j) : mat_no(i, j); - } - } - } - } - - // Return combined matrix - return mat; -} - -Rcpp::NumericMatrix recurseMarginalizeUXgboost( - Rcpp::NumericMatrix &x, NumericMatrix &tree, - std::vector> &U, unsigned int node, - LeafData &leaf_data) -{ - NumericMatrix::Row current_node = tree(node, _); - int current_feature = current_node[Index::FEATURE]; - - // Start with all 0 - unsigned int n = x.nrow(); - Rcpp::NumericMatrix mat(n, U.size()); - - // If leaf, just return value - if (current_feature == -1) - { - for (unsigned int j = 0; j < U.size(); ++j) - { - std::set to_explain = {}; - std::set_difference( - leaf_data.encountered[node].begin(), leaf_data.encountered[node].end(), - U[j].begin(), U[j].end(), - std::inserter(to_explain, to_explain.begin())); - - double p = leaf_data.leafProbs[node][to_explain]; - double quality = tree(node, Index::QUALITY); - double expected_value = quality * p; - - Rcpp::NumericMatrix::Column to_fill = mat(Rcpp::_, j); - std::fill(to_fill.begin(), to_fill.end(), expected_value); - } - } - else - { - unsigned int yes = current_node[Index::YES]; - unsigned int no = current_node[Index::NO]; - - // Call both children, they give a matrix each of all obs and subsets - Rcpp::NumericMatrix mat_yes = recurseMarginalizeUXgboost(x, tree, U, yes, leaf_data); - Rcpp::NumericMatrix mat_no = recurseMarginalizeUXgboost(x, tree, U, no, leaf_data); - - for (unsigned int j = 0; j < U.size(); ++j) - { - // Is splitting feature out in this subset? - - if (U[j].find(current_feature) != U[j].end()) - { - // For subsets where feature is out, weighted average of left/right - for (unsigned int i = 0; i < n; ++i) - mat(i, j) += mat_yes(i, j) + mat_no(i, j); - } - else - { - double split = current_node[Index::SPLIT]; - // For subsets where feature is in, split to left/right - for (unsigned int i = 0; i < n; ++i) - { - mat(i, j) += (x(i, current_feature) < split) ? 
mat_yes(i, j) : mat_no(i, j); - } - } - } - } - - // Return combined matrix - return mat; -} - -void contributeFastPD( - Rcpp::NumericMatrix &mat, - Rcpp::NumericMatrix &m_all, - std::set &S, - std::set &T, - std::vector> &T_subsets, - unsigned int colnum) -{ - std::set sTS; - std::set_difference(T.begin(), T.end(), S.begin(), S.end(), std::inserter(sTS, sTS.begin())); - - for (unsigned int i = 0; i < T_subsets.size(); ++i) - { - std::set U = T_subsets[i]; - if (sTS.size() != 0) - { - std::set ssTSU; - std::set_difference(sTS.begin(), sTS.end(), U.begin(), U.end(), std::inserter(ssTSU, ssTSU.begin())); - if (ssTSU.size() != 0) - continue; - } - - std::set sTU; - std::set_difference(T.begin(), T.end(), U.begin(), U.end(), std::inserter(sTU, sTU.begin())); - - if (((S.size() - sTU.size()) % 2) == 0) - m_all(Rcpp::_, colnum) = m_all(Rcpp::_, colnum) + mat(Rcpp::_, i); - else - m_all(Rcpp::_, colnum) = m_all(Rcpp::_, colnum) - mat(Rcpp::_, i); - } -} - -// [[Rcpp::export]] -Rcpp::NumericMatrix explainTreeFastPD( - Rcpp::NumericMatrix &x, - NumericMatrix &tree, - Rcpp::List &to_explain_list, - bool is_ranger) -{ - std::vector> to_explain; - for (int i = 0; i < to_explain_list.size(); i++) - { - std::set to_explain_set = std::set(as(to_explain_list[i]).begin(), as(to_explain_list[i]).end()); - to_explain.push_back(to_explain_set); - } - LeafData leaf_data = is_ranger ? augmentTreeRanger(tree, x) : augmentTreeXgboost(tree, x); - std::vector> U = get_all_subsets_(leaf_data.all_encountered); - NumericMatrix mat = is_ranger ? 
recurseMarginalizeURanger(x, tree, U, 0, leaf_data) : recurseMarginalizeUXgboost(x, tree, U, 0, leaf_data); - - unsigned int to_explain_size = to_explain.size(); - NumericMatrix m_all = NumericMatrix(x.nrow(), to_explain_size); - CharacterVector m_all_col_names = CharacterVector(to_explain_size); - CharacterVector x_col_names = colnames(x); - - for (int S_idx = 0; S_idx < to_explain_size; S_idx++) - { - std::set S = to_explain[S_idx]; - - if (int k = S.size(); k != 0) - { - auto it = S.begin(); - std::ostringstream oss; - for (int i = 0; i < k - 1; i++, it++) - oss << x_col_names[*it] << ":"; - oss << x_col_names[*it]; - m_all_col_names[S_idx] = oss.str(); - } - - if (std::find(U.begin(), U.end(), S) == U.end()) - continue; - // std::set S_set = std::set(S.begin(), S.end()); - contributeFastPD(mat, m_all, S, leaf_data.all_encountered, U, S_idx); - } - colnames(m_all) = m_all_col_names; - return m_all; -} diff --git a/src/fastpd_augment.cpp b/src/fastpd_augment.cpp new file mode 100644 index 0000000..866d5b5 --- /dev/null +++ b/src/fastpd_augment.cpp @@ -0,0 +1,208 @@ +#include +#include "../inst/include/glex.h" + +using namespace Rcpp; + +LeafData augmentTreeRanger(NumericMatrix &tree, NumericMatrix &dataset) +{ + AugmentedData result; + + AugmentedData to_pass_down = { + .encountered = std::set(), + .pathData = PathData(), + }; + to_pass_down.pathData[{}] = std::vector(dataset.nrow()); + std::iota(to_pass_down.pathData[{}].begin(), to_pass_down.pathData[{}].end(), 0); + LeafData leaf_data = LeafData(); + + augmentTreeRecurseStepRanger(to_pass_down, leaf_data, tree, dataset, 0); + return leaf_data; +} + +LeafData augmentTreeXgboost(NumericMatrix &tree, NumericMatrix &dataset) +{ + AugmentedData result; + + AugmentedData to_pass_down = { + .encountered = std::set(), + .pathData = PathData(), + }; + to_pass_down.pathData[{}] = std::vector(dataset.nrow()); + std::iota(to_pass_down.pathData[{}].begin(), to_pass_down.pathData[{}].end(), 0); + LeafData leaf_data = 
LeafData(); + + augmentTreeRecurseStepXgboost(to_pass_down, leaf_data, tree, dataset, 0); + return leaf_data; +} + +void augmentTreeRecurseStepRanger( + AugmentedData passed_down, + LeafData &leaf_data, + NumericMatrix &tree, + NumericMatrix &dataset, + unsigned int node) +{ + NumericMatrix::Row current_node = tree(node, _); + + int current_feature = current_node[Index::FEATURE]; + double split = current_node[Index::SPLIT]; + + if (current_feature == -1) + { + // Leaf node + leaf_data.encountered[node] = passed_down.encountered; + leaf_data.leafProbs[node] = ProbsMap(); + for (const auto &[subset, path_dat] : passed_down.pathData) + { + leaf_data.leafProbs[node][subset] = (double)path_dat.size() / dataset.nrow(); + } + return; + } + else if (leaf_data.all_encountered.find(current_feature) == leaf_data.all_encountered.end()) + { + leaf_data.all_encountered.insert(current_feature); + } + + AugmentedData passed_down_yes = { + .encountered = passed_down.encountered, + .pathData = PathData(), + }; + AugmentedData passed_down_no = { + .encountered = passed_down.encountered, + .pathData = PathData(), + }; + + for (const auto &[subset, path_dat] : passed_down.pathData) + { + if (subset.find(current_feature) != subset.end()) + { + // Feature is in the path + passed_down_yes.pathData[subset] = path_dat; + passed_down_no.pathData[subset] = path_dat; + } + else + { + // Feature is not in the path + passed_down_yes.pathData[subset] = std::vector(); + passed_down_no.pathData[subset] = std::vector(); + for (int i : path_dat) + { + if (dataset(i, current_feature) <= split) + { + passed_down_yes.pathData[subset].push_back(i); + } + else + { + passed_down_no.pathData[subset].push_back(i); + } + } + } + } + + if (passed_down.encountered.find(current_feature) == passed_down.encountered.end()) + { + passed_down_yes.encountered.insert(current_feature); + passed_down_no.encountered.insert(current_feature); + + for (const auto &[subset, path_dat] : passed_down.pathData) + { + std::set 
to_add = subset; + to_add.insert(current_feature); + + passed_down_yes.pathData[to_add] = path_dat; + passed_down_no.pathData[to_add] = path_dat; + } + } + + unsigned int yes = current_node[Index::YES]; + unsigned int no = current_node[Index::NO]; + + augmentTreeRecurseStepRanger(passed_down_yes, leaf_data, tree, dataset, yes); + augmentTreeRecurseStepRanger(passed_down_no, leaf_data, tree, dataset, no); +} + +void augmentTreeRecurseStepXgboost( + AugmentedData passed_down, + LeafData &leaf_data, + NumericMatrix &tree, + NumericMatrix &dataset, + unsigned int node) +{ + NumericMatrix::Row current_node = tree(node, _); + + int current_feature = current_node[Index::FEATURE]; + double split = current_node[Index::SPLIT]; + + if (current_feature == -1) + { + // Leaf node + leaf_data.encountered[node] = passed_down.encountered; + leaf_data.leafProbs[node] = ProbsMap(); + for (const auto &[subset, path_dat] : passed_down.pathData) + { + leaf_data.leafProbs[node][subset] = (double)path_dat.size() / dataset.nrow(); + } + return; + } + else if (leaf_data.all_encountered.find(current_feature) == leaf_data.all_encountered.end()) + { + leaf_data.all_encountered.insert(current_feature); + } + + AugmentedData passed_down_yes = { + .encountered = passed_down.encountered, + .pathData = PathData(), + }; + AugmentedData passed_down_no = { + .encountered = passed_down.encountered, + .pathData = PathData(), + }; + + for (const auto &[subset, path_dat] : passed_down.pathData) + { + if (subset.find(current_feature) != subset.end()) + { + // Feature is in the path + passed_down_yes.pathData[subset] = path_dat; + passed_down_no.pathData[subset] = path_dat; + } + else + { + // Feature is not in the path + passed_down_yes.pathData[subset] = std::vector(); + passed_down_no.pathData[subset] = std::vector(); + for (int i : path_dat) + { + if (dataset(i, current_feature) < split) + { + passed_down_yes.pathData[subset].push_back(i); + } + else + { + passed_down_no.pathData[subset].push_back(i); + 
} + } + } + } + + if (passed_down.encountered.find(current_feature) == passed_down.encountered.end()) + { + passed_down_yes.encountered.insert(current_feature); + passed_down_no.encountered.insert(current_feature); + + for (const auto &[subset, path_dat] : passed_down.pathData) + { + std::set to_add = subset; + to_add.insert(current_feature); + + passed_down_yes.pathData[to_add] = path_dat; + passed_down_no.pathData[to_add] = path_dat; + } + } + + unsigned int yes = current_node[Index::YES]; + unsigned int no = current_node[Index::NO]; + + augmentTreeRecurseStepXgboost(passed_down_yes, leaf_data, tree, dataset, yes); + augmentTreeRecurseStepXgboost(passed_down_no, leaf_data, tree, dataset, no); +} diff --git a/src/fastpd_contribute.cpp b/src/fastpd_contribute.cpp new file mode 100644 index 0000000..6fc4289 --- /dev/null +++ b/src/fastpd_contribute.cpp @@ -0,0 +1,36 @@ +#include +#include "../inst/include/glex.h" + +using namespace Rcpp; + +void contributeFastPD( + Rcpp::NumericMatrix &mat, + Rcpp::NumericMatrix &m_all, + std::set &S, + std::set &T, + std::vector> &T_subsets, + unsigned int colnum) +{ + std::set sTS; + std::set_difference(T.begin(), T.end(), S.begin(), S.end(), std::inserter(sTS, sTS.begin())); + + for (unsigned int i = 0; i < T_subsets.size(); ++i) + { + std::set U = T_subsets[i]; + if (sTS.size() != 0) + { + std::set ssTSU; + std::set_difference(sTS.begin(), sTS.end(), U.begin(), U.end(), std::inserter(ssTSU, ssTSU.begin())); + if (ssTSU.size() != 0) + continue; + } + + std::set sTU; + std::set_difference(T.begin(), T.end(), U.begin(), U.end(), std::inserter(sTU, sTU.begin())); + + if (((S.size() - sTU.size()) % 2) == 0) + m_all(Rcpp::_, colnum) = m_all(Rcpp::_, colnum) + mat(Rcpp::_, i); + else + m_all(Rcpp::_, colnum) = m_all(Rcpp::_, colnum) - mat(Rcpp::_, i); + } +} diff --git a/src/fastpd_explain_tree.cpp b/src/fastpd_explain_tree.cpp new file mode 100644 index 0000000..3054a44 --- /dev/null +++ b/src/fastpd_explain_tree.cpp @@ -0,0 +1,70 @@ 
+#include +#include "../inst/include/glex.h" + +using namespace Rcpp; + +std::vector> get_all_subsets_(std::set &set); + +LeafData augmentTreeRanger(NumericMatrix &tree, NumericMatrix &dataset); +LeafData augmentTreeXgboost(NumericMatrix &tree, NumericMatrix &dataset); +Rcpp::NumericMatrix recurseMarginalizeURanger( + Rcpp::NumericMatrix &x, NumericMatrix &tree, + std::vector> &U, unsigned int node, + LeafData &leaf_data); +Rcpp::NumericMatrix recurseMarginalizeUXgboost( + Rcpp::NumericMatrix &x, NumericMatrix &tree, + std::vector> &U, unsigned int node, + LeafData &leaf_data); + +void contributeFastPD( + Rcpp::NumericMatrix &mat, + Rcpp::NumericMatrix &m_all, + std::set &S, + std::set &T, + std::vector> &T_subsets, + unsigned int colnum); + +// [[Rcpp::export]] +Rcpp::NumericMatrix explainTreeFastPD( + Rcpp::NumericMatrix &x, + NumericMatrix &tree, + Rcpp::List &to_explain_list, + bool is_ranger) +{ + std::vector> to_explain; + for (int i = 0; i < to_explain_list.size(); i++) + { + std::set to_explain_set = std::set(as(to_explain_list[i]).begin(), as(to_explain_list[i]).end()); + to_explain.push_back(to_explain_set); + } + LeafData leaf_data = is_ranger ? augmentTreeRanger(tree, x) : augmentTreeXgboost(tree, x); + std::vector> U = get_all_subsets_(leaf_data.all_encountered); + NumericMatrix mat = is_ranger ? 
recurseMarginalizeURanger(x, tree, U, 0, leaf_data) : recurseMarginalizeUXgboost(x, tree, U, 0, leaf_data); + + unsigned int to_explain_size = to_explain.size(); + NumericMatrix m_all = NumericMatrix(x.nrow(), to_explain_size); + CharacterVector m_all_col_names = CharacterVector(to_explain_size); + CharacterVector x_col_names = colnames(x); + + for (int S_idx = 0; S_idx < to_explain_size; S_idx++) + { + std::set S = to_explain[S_idx]; + + if (int k = S.size(); k != 0) + { + auto it = S.begin(); + std::ostringstream oss; + for (int i = 0; i < k - 1; i++, it++) + oss << x_col_names[*it] << ":"; + oss << x_col_names[*it]; + m_all_col_names[S_idx] = oss.str(); + } + + if (std::find(U.begin(), U.end(), S) == U.end()) + continue; + // std::set S_set = std::set(S.begin(), S.end()); + contributeFastPD(mat, m_all, S, leaf_data.all_encountered, U, S_idx); + } + colnames(m_all) = m_all_col_names; + return m_all; +} diff --git a/src/fastpd_recurse.cpp b/src/fastpd_recurse.cpp new file mode 100644 index 0000000..2f2d09b --- /dev/null +++ b/src/fastpd_recurse.cpp @@ -0,0 +1,138 @@ +#include +#include +#include "../inst/include/glex.h" + +using namespace Rcpp; +std::vector> get_all_subsets_(std::set &set); + +Rcpp::NumericMatrix recurseMarginalizeURanger( + Rcpp::NumericMatrix &x, NumericMatrix &tree, + std::vector> &U, unsigned int node, + LeafData &leaf_data) +{ + NumericMatrix::Row current_node = tree(node, _); + int current_feature = current_node[Index::FEATURE]; + + // Start with all 0 + unsigned int n = x.nrow(); + Rcpp::NumericMatrix mat(n, U.size()); + + // If leaf, just return value + if (current_feature == -1) + { + for (unsigned int j = 0; j < U.size(); ++j) + { + std::set to_explain = {}; + std::set_difference( + leaf_data.encountered[node].begin(), leaf_data.encountered[node].end(), + U[j].begin(), U[j].end(), + std::inserter(to_explain, to_explain.begin())); + + double p = leaf_data.leafProbs[node][to_explain]; + double quality = tree(node, Index::QUALITY); + double 
expected_value = quality * p; + + Rcpp::NumericMatrix::Column to_fill = mat(Rcpp::_, j); + std::fill(to_fill.begin(), to_fill.end(), expected_value); + } + } + else + { + unsigned int yes = current_node[Index::YES]; + unsigned int no = current_node[Index::NO]; + + // Call both children, they give a matrix each of all obs and subsets + Rcpp::NumericMatrix mat_yes = recurseMarginalizeURanger(x, tree, U, yes, leaf_data); + Rcpp::NumericMatrix mat_no = recurseMarginalizeURanger(x, tree, U, no, leaf_data); + + for (unsigned int j = 0; j < U.size(); ++j) + { + // Is splitting feature out in this subset? + + if (U[j].find(current_feature) != U[j].end()) + { + // For subsets where feature is out, weighted average of left/right + for (unsigned int i = 0; i < n; ++i) + mat(i, j) += mat_yes(i, j) + mat_no(i, j); + } + else + { + double split = current_node[Index::SPLIT]; + // For subsets where feature is in, split to left/right + for (unsigned int i = 0; i < n; ++i) + { + mat(i, j) += (x(i, current_feature) <= split) ? 
mat_yes(i, j) : mat_no(i, j); + } + } + } + } + + // Return combined matrix + return mat; +} + +Rcpp::NumericMatrix recurseMarginalizeUXgboost( + Rcpp::NumericMatrix &x, NumericMatrix &tree, + std::vector> &U, unsigned int node, + LeafData &leaf_data) +{ + NumericMatrix::Row current_node = tree(node, _); + int current_feature = current_node[Index::FEATURE]; + + // Start with all 0 + unsigned int n = x.nrow(); + Rcpp::NumericMatrix mat(n, U.size()); + + // If leaf, just return value + if (current_feature == -1) + { + for (unsigned int j = 0; j < U.size(); ++j) + { + std::set to_explain = {}; + std::set_difference( + leaf_data.encountered[node].begin(), leaf_data.encountered[node].end(), + U[j].begin(), U[j].end(), + std::inserter(to_explain, to_explain.begin())); + + double p = leaf_data.leafProbs[node][to_explain]; + double quality = tree(node, Index::QUALITY); + double expected_value = quality * p; + + Rcpp::NumericMatrix::Column to_fill = mat(Rcpp::_, j); + std::fill(to_fill.begin(), to_fill.end(), expected_value); + } + } + else + { + unsigned int yes = current_node[Index::YES]; + unsigned int no = current_node[Index::NO]; + + // Call both children, they give a matrix each of all obs and subsets + Rcpp::NumericMatrix mat_yes = recurseMarginalizeUXgboost(x, tree, U, yes, leaf_data); + Rcpp::NumericMatrix mat_no = recurseMarginalizeUXgboost(x, tree, U, no, leaf_data); + + for (unsigned int j = 0; j < U.size(); ++j) + { + // Is splitting feature out in this subset? + + if (U[j].find(current_feature) != U[j].end()) + { + // For subsets where feature is out, weighted average of left/right + for (unsigned int i = 0; i < n; ++i) + mat(i, j) += mat_yes(i, j) + mat_no(i, j); + } + else + { + double split = current_node[Index::SPLIT]; + // For subsets where feature is in, split to left/right + for (unsigned int i = 0; i < n; ++i) + { + mat(i, j) += (x(i, current_feature) < split) ? 
mat_yes(i, j) : mat_no(i, j); + } + } + } + } + + // Return combined matrix + return mat; +} From b4a5ca84f94652fa8a89ec42a70736b67f0b57f1 Mon Sep 17 00:00:00 2001 From: Jinyang Liu Date: Fri, 15 Nov 2024 11:39:36 +0100 Subject: [PATCH 03/17] Add helper.cpp --- src/glex.cpp | 36 ++---------------------------------- src/helper.cpp | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 34 deletions(-) create mode 100644 src/helper.cpp diff --git a/src/glex.cpp b/src/glex.cpp index 2877f40..7e009a9 100644 --- a/src/glex.cpp +++ b/src/glex.cpp @@ -1,40 +1,8 @@ #include #include -bool containsNumber(Rcpp::IntegerVector &vec, int target) -{ - // Use std::find to check if the target is in the vector - return std::find(vec.begin(), vec.end(), target) != vec.end(); -} - -Rcpp::IntegerVector removeValue(Rcpp::IntegerVector &vec, int valueToRemove) -{ - // Use std::remove to move the value to the end - vec.erase(std::remove(vec.begin(), vec.end(), valueToRemove), vec.end()); - - return vec; -} - -std::vector> get_all_subsets_(std::set &set) -{ - std::vector> result; - unsigned int n = set.size(); - for (unsigned int i = 0; i < (1 << n); ++i) - { - std::set subset; - auto it = set.begin(); - for (unsigned int j = 0; j < n; ++j) - { - if (i & (1 << j)) - { - subset.insert(*it); - } - it++; - } - result.push_back(subset); - } - return result; -} +bool containsNumber(Rcpp::IntegerVector &vec, int target); +Rcpp::IntegerVector removeValue(Rcpp::IntegerVector &vec, int valueToRemove); // [[Rcpp::export]] double empProbFunction(Rcpp::NumericMatrix &x, Rcpp::IntegerVector &coords, Rcpp::NumericVector &lb, Rcpp::NumericVector &ub) diff --git a/src/helper.cpp b/src/helper.cpp new file mode 100644 index 0000000..b764e81 --- /dev/null +++ b/src/helper.cpp @@ -0,0 +1,36 @@ +#include + +bool containsNumber(Rcpp::IntegerVector &vec, int target) +{ + // Use std::find to check if the target is in the vector + return std::find(vec.begin(), vec.end(), target) != vec.end(); 
+} + +Rcpp::IntegerVector removeValue(Rcpp::IntegerVector &vec, int valueToRemove) +{ + // Use std::remove to move the value to the end + vec.erase(std::remove(vec.begin(), vec.end(), valueToRemove), vec.end()); + + return vec; +} + +std::vector> get_all_subsets_(std::set &set) +{ + std::vector> result; + unsigned int n = set.size(); + for (unsigned int i = 0; i < (1 << n); ++i) + { + std::set subset; + auto it = set.begin(); + for (unsigned int j = 0; j < n; ++j) + { + if (i & (1 << j)) + { + subset.insert(*it); + } + it++; + } + result.push_back(subset); + } + return result; +} From b079e3b4bb347b04b9cd399d3e2b6764bcfc53e2 Mon Sep 17 00:00:00 2001 From: Jinyang Liu Date: Fri, 15 Nov 2024 11:51:27 +0100 Subject: [PATCH 04/17] Augment up to max_interaction --- src/fastpd_augment.cpp | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/src/fastpd_augment.cpp b/src/fastpd_augment.cpp index 866d5b5..f9320c2 100644 --- a/src/fastpd_augment.cpp +++ b/src/fastpd_augment.cpp @@ -3,7 +3,7 @@ using namespace Rcpp; -LeafData augmentTreeRanger(NumericMatrix &tree, NumericMatrix &dataset) +LeafData augmentTreeRanger(NumericMatrix &tree, NumericMatrix &dataset, unsigned int max_interaction) { AugmentedData result; @@ -15,11 +15,11 @@ LeafData augmentTreeRanger(NumericMatrix &tree, NumericMatrix &dataset) std::iota(to_pass_down.pathData[{}].begin(), to_pass_down.pathData[{}].end(), 0); LeafData leaf_data = LeafData(); - augmentTreeRecurseStepRanger(to_pass_down, leaf_data, tree, dataset, 0); + augmentTreeRecurseStepRanger(to_pass_down, leaf_data, tree, dataset, 0, max_interaction); return leaf_data; } -LeafData augmentTreeXgboost(NumericMatrix &tree, NumericMatrix &dataset) +LeafData augmentTreeXgboost(NumericMatrix &tree, NumericMatrix &dataset, unsigned int max_interaction) { AugmentedData result; @@ -31,16 +31,27 @@ LeafData augmentTreeXgboost(NumericMatrix &tree, NumericMatrix &dataset) 
std::iota(to_pass_down.pathData[{}].begin(), to_pass_down.pathData[{}].end(), 0); LeafData leaf_data = LeafData(); - augmentTreeRecurseStepXgboost(to_pass_down, leaf_data, tree, dataset, 0); + augmentTreeRecurseStepXgboost(to_pass_down, leaf_data, tree, dataset, 0, max_interaction); return leaf_data; } +LeafData augmentTreeXgboost(NumericMatrix &tree, NumericMatrix &dataset) +{ + return augmentTreeXgboost(tree, dataset, dataset.ncol()); +} + +LeafData augmentTreeRanger(NumericMatrix &tree, NumericMatrix &dataset) +{ + return augmentTreeRanger(tree, dataset, dataset.ncol()); +} + void augmentTreeRecurseStepRanger( AugmentedData passed_down, LeafData &leaf_data, NumericMatrix &tree, NumericMatrix &dataset, - unsigned int node) + unsigned int node, + unsigned int max_interaction) { NumericMatrix::Row current_node = tree(node, _); @@ -106,6 +117,9 @@ void augmentTreeRecurseStepRanger( for (const auto &[subset, path_dat] : passed_down.pathData) { + if (subset.size() >= max_interaction) + continue; + std::set to_add = subset; to_add.insert(current_feature); @@ -117,8 +131,8 @@ void augmentTreeRecurseStepRanger( unsigned int yes = current_node[Index::YES]; unsigned int no = current_node[Index::NO]; - augmentTreeRecurseStepRanger(passed_down_yes, leaf_data, tree, dataset, yes); - augmentTreeRecurseStepRanger(passed_down_no, leaf_data, tree, dataset, no); + augmentTreeRecurseStepRanger(passed_down_yes, leaf_data, tree, dataset, yes, max_interaction); + augmentTreeRecurseStepRanger(passed_down_no, leaf_data, tree, dataset, no, max_interaction); } void augmentTreeRecurseStepXgboost( @@ -126,7 +140,8 @@ void augmentTreeRecurseStepXgboost( LeafData &leaf_data, NumericMatrix &tree, NumericMatrix &dataset, - unsigned int node) + unsigned int node, + unsigned int max_interaction) { NumericMatrix::Row current_node = tree(node, _); @@ -192,6 +207,8 @@ void augmentTreeRecurseStepXgboost( for (const auto &[subset, path_dat] : passed_down.pathData) { + if (subset.size() >= 
max_interaction) + continue; std::set to_add = subset; to_add.insert(current_feature); @@ -203,6 +220,6 @@ void augmentTreeRecurseStepXgboost( unsigned int yes = current_node[Index::YES]; unsigned int no = current_node[Index::NO]; - augmentTreeRecurseStepXgboost(passed_down_yes, leaf_data, tree, dataset, yes); - augmentTreeRecurseStepXgboost(passed_down_no, leaf_data, tree, dataset, no); + augmentTreeRecurseStepXgboost(passed_down_yes, leaf_data, tree, dataset, yes, max_interaction); + augmentTreeRecurseStepXgboost(passed_down_no, leaf_data, tree, dataset, no, max_interaction); } From 6d577ed2417d90553a142e2fcb1e31c3d282baba Mon Sep 17 00:00:00 2001 From: Jinyang Liu Date: Fri, 15 Nov 2024 11:58:01 +0100 Subject: [PATCH 05/17] Marginalize by looking at S - features to explain --- src/fastpd_contribute.cpp | 26 ++++++++ src/fastpd_recurse.cpp | 133 +++++++++++++++++++++++++++++++++++++- 2 files changed, 158 insertions(+), 1 deletion(-) diff --git a/src/fastpd_contribute.cpp b/src/fastpd_contribute.cpp index 6fc4289..2ef91f5 100644 --- a/src/fastpd_contribute.cpp +++ b/src/fastpd_contribute.cpp @@ -2,6 +2,32 @@ #include "../inst/include/glex.h" using namespace Rcpp; +std::vector> get_all_subsets_(std::set &set); + +void contributeFastPD2( + NumericMatrix &mat, + NumericMatrix &m_all, + std::set &S, + std::vector> &T_subsets, + unsigned int colnum) +{ + std::vector> Vs = get_all_subsets_(S); + for (unsigned int i = 0; i < Vs.size(); ++i) + { + std::set V = Vs[i]; + auto it = std::find(T_subsets.begin(), T_subsets.end(), V); + unsigned int idx = std::distance(T_subsets.begin(), it); + + if ((S.size() - V.size()) % 2 == 0) + { + m_all(_, colnum) = m_all(_, colnum) + mat(_, idx); + } + else + { + m_all(_, colnum) = m_all(_, colnum) - mat(_, idx); + } + } +} void contributeFastPD( Rcpp::NumericMatrix &mat, diff --git a/src/fastpd_recurse.cpp b/src/fastpd_recurse.cpp index 2f2d09b..fc6aa18 100644 --- a/src/fastpd_recurse.cpp +++ b/src/fastpd_recurse.cpp @@ -3,7 +3,138 @@ 
#include "../inst/include/glex.h" using namespace Rcpp; -std::vector> get_all_subsets_(std::set &set); + +NumericMatrix recurseMarginalizeSRanger( + NumericMatrix &x, NumericMatrix &tree, + std::vector> &Ss, unsigned int node, + LeafData &leaf_data) +{ + NumericMatrix::Row current_node = tree(node, _); + int current_feature = current_node[Index::FEATURE]; + + // Start with all 0 + unsigned int n = x.nrow(); + NumericMatrix mat(n, Ss.size()); + + // If leaf, just return value + if (current_feature == -1) + { + for (unsigned int j = 0; j < Ss.size(); ++j) + { + std::set to_explain = {}; + std::set_intersection( + Ss[j].begin(), Ss[j].end(), + leaf_data.encountered[node].begin(), leaf_data.encountered[node].end(), + std::inserter(to_explain, to_explain.begin())); + + double p = leaf_data.leafProbs[node][to_explain]; + double quality = tree(node, Index::QUALITY); + double expected_value = quality * p; + + NumericMatrix::Column to_fill = mat(_, j); + std::fill(to_fill.begin(), to_fill.end(), expected_value); + } + } + else + { + unsigned int yes = current_node[Index::YES]; + unsigned int no = current_node[Index::NO]; + + // Call both children, they give a matrix each of all obs and subsets + NumericMatrix mat_yes = recurseMarginalizeSRanger(x, tree, Ss, yes, leaf_data); + NumericMatrix mat_no = recurseMarginalizeSRanger(x, tree, Ss, no, leaf_data); + + for (unsigned int j = 0; j < Ss.size(); ++j) + { + // Is splitting feature out in this subset? + + if (Ss[j].find(current_feature) == Ss[j].end()) + { + // For subsets where feature is out, weighted average of left/right + for (unsigned int i = 0; i < n; ++i) + mat(i, j) += mat_yes(i, j) + mat_no(i, j); + } + else + { + double split = current_node[Index::SPLIT]; + // For subsets where feature is in, split to left/right + for (unsigned int i = 0; i < n; ++i) + { + mat(i, j) += (x(i, current_feature) <= split) ? 
mat_yes(i, j) : mat_no(i, j); + } + } + } + } + + // Return combined matrix + return mat; +} + +NumericMatrix recurseMarginalizeSXgboost( + NumericMatrix &x, NumericMatrix &tree, + std::vector> &Ss, unsigned int node, + LeafData &leaf_data) +{ + NumericMatrix::Row current_node = tree(node, _); + int current_feature = current_node[Index::FEATURE]; + + // Start with all 0 + unsigned int n = x.nrow(); + NumericMatrix mat(n, Ss.size()); + + // If leaf, just return value + if (current_feature == -1) + { + for (unsigned int j = 0; j < Ss.size(); ++j) + { + std::set to_explain = {}; + std::set_intersection( + Ss[j].begin(), Ss[j].end(), + leaf_data.encountered[node].begin(), leaf_data.encountered[node].end(), + std::inserter(to_explain, to_explain.begin())); + + double p = leaf_data.leafProbs[node][to_explain]; + double quality = tree(node, Index::QUALITY); + double expected_value = quality * p; + + NumericMatrix::Column to_fill = mat(_, j); + std::fill(to_fill.begin(), to_fill.end(), expected_value); + } + } + else + { + unsigned int yes = current_node[Index::YES]; + unsigned int no = current_node[Index::NO]; + + // Call both children, they give a matrix each of all obs and subsets + NumericMatrix mat_yes = recurseMarginalizeSRanger(x, tree, Ss, yes, leaf_data); + NumericMatrix mat_no = recurseMarginalizeSRanger(x, tree, Ss, no, leaf_data); + + for (unsigned int j = 0; j < Ss.size(); ++j) + { + // Is splitting feature out in this subset? + + if (Ss[j].find(current_feature) == Ss[j].end()) + { + // For subsets where feature is out, weighted average of left/right + for (unsigned int i = 0; i < n; ++i) + mat(i, j) += mat_yes(i, j) + mat_no(i, j); + } + else + { + double split = current_node[Index::SPLIT]; + // For subsets where feature is in, split to left/right + for (unsigned int i = 0; i < n; ++i) + { + mat(i, j) += (x(i, current_feature) < split) ? 
mat_yes(i, j) : mat_no(i, j); + } + } + } + } + + // Return combined matrix + return mat; +} Rcpp::NumericMatrix recurseMarginalizeURanger( Rcpp::NumericMatrix &x, NumericMatrix &tree, From 24d8d4914c339680a7b6f6f801b778d1df026e47 Mon Sep 17 00:00:00 2001 From: Jinyang Liu Date: Fri, 15 Nov 2024 12:05:28 +0100 Subject: [PATCH 06/17] Add get_all_subsets up to max size --- src/helper.cpp | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/src/helper.cpp b/src/helper.cpp index b764e81..cb60ce1 100644 --- a/src/helper.cpp +++ b/src/helper.cpp @@ -34,3 +34,41 @@ std::vector> get_all_subsets_(std::set &set } return result; } + +void recurse_subset( + const std::set &arr, + std::set &currentSubset, + std::vector> &subsets, + int maxSize, + int index) +{ + // Add the current subset to the result if it's within the maxSize limit + if (currentSubset.size() <= maxSize) + { + subsets.push_back(currentSubset); + } + + // Stop recursion if maxSize is reached + if (currentSubset.size() == maxSize || index >= arr.size()) + return; + + // Recursively generate subsets starting from the current index + auto it = std::next(arr.begin(), index); + do + { + currentSubset.insert(*it); + recurse_subset(arr, currentSubset, subsets, maxSize, ++index); + currentSubset.erase(*it); + ++it; + } while (index < arr.size()); +} + +std::vector> get_all_subsets( + std::set &arr, + unsigned int maxSize = UINT_MAX) +{ + std::vector> subsets = {}; + std::set currentSubset = {}; + recurse_subset(arr, currentSubset, subsets, maxSize, 0); + return subsets; +} From 0f6a7fd9b4d1022b6761f4aa6817db3d3af3282c Mon Sep 17 00:00:00 2001 From: Jinyang Liu Date: Fri, 15 Nov 2024 12:06:44 +0100 Subject: [PATCH 07/17] Declare functions --- src/fastpd_augment.cpp | 14 ++++++++++++++ src/fastpd_expectation.cpp | 3 +++ 2 files changed, 17 insertions(+) diff --git a/src/fastpd_augment.cpp b/src/fastpd_augment.cpp index f9320c2..2d1e586 100644 --- a/src/fastpd_augment.cpp +++
b/src/fastpd_augment.cpp @@ -2,6 +2,20 @@ #include "../inst/include/glex.h" using namespace Rcpp; +void augmentTreeRecurseStepRanger( + AugmentedData passed_down, + LeafData &leaf_data, + NumericMatrix &tree, + NumericMatrix &dataset, + unsigned int node, + unsigned int max_interaction); +void augmentTreeRecurseStepXgboost( + AugmentedData passed_down, + LeafData &leaf_data, + NumericMatrix &tree, + NumericMatrix &dataset, + unsigned int node, + unsigned int max_interaction); LeafData augmentTreeRanger(NumericMatrix &tree, NumericMatrix &dataset, unsigned int max_interaction) { diff --git a/src/fastpd_expectation.cpp b/src/fastpd_expectation.cpp index 9b1df0d..06968c1 100644 --- a/src/fastpd_expectation.cpp +++ b/src/fastpd_expectation.cpp @@ -19,6 +19,9 @@ Rcpp::NumericMatrix recurseMarginalizeUXgboost( std::vector> &U, unsigned int node, LeafData &leaf_data); +double augmentExpectationRanger(NumericVector &x, NumericMatrix &tree, NumericVector &to_explain, LeafData &leaf_data); +double augmentExpectationXgboost(NumericVector &x, NumericMatrix &tree, NumericVector &to_explain, LeafData &leaf_data); + // [[Rcpp::export]] double augmentAndTakeExpectation(NumericVector &x, NumericMatrix &dataset, NumericMatrix &tree, NumericVector &to_explain, bool is_ranger) { From f578b95c08e1648553cc280531349cf2d6d1f5d7 Mon Sep 17 00:00:00 2001 From: Jinyang Liu Date: Fri, 15 Nov 2024 12:18:29 +0100 Subject: [PATCH 08/17] Marginalize up to max_interaction --- R/RcppExports.R | 4 +-- R/glex.R | 10 +++---- src/RcppExports.cpp | 9 +++--- src/fastpd_explain_tree.cpp | 59 ++++++++++++++++++++++--------------- src/fastpd_recurse.cpp | 4 +-- 5 files changed, 49 insertions(+), 37 deletions(-) diff --git a/R/RcppExports.R b/R/RcppExports.R index d7b406a..1971b22 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -17,8 +17,8 @@ marginalizeAllSplittedSubsetsinTree <- function(x, tree, is_ranger) { .Call(`_glex_marginalizeAllSplittedSubsetsinTree`, x, tree, is_ranger) } -explainTreeFastPD 
<- function(x, tree, to_explain_list, is_ranger) { - .Call(`_glex_explainTreeFastPD`, x, tree, to_explain_list, is_ranger) +explainTreeFastPD <- function(x, tree, to_explain_list, max_interaction, is_ranger) { + .Call(`_glex_explainTreeFastPD`, x, tree, to_explain_list, max_interaction, is_ranger) } find_term_matches <- function(main_term, terms) { diff --git a/R/glex.R b/R/glex.R index d7b6802..191995c 100644 --- a/R/glex.R +++ b/R/glex.R @@ -316,7 +316,7 @@ tree_fun_emp <- function(tree, trees, x, all_S, probFunction = NULL) { m_all } -tree_fun_emp_fastPD <- function(tree, trees, x, all_S) { +tree_fun_emp_fastPD <- function(tree, trees, x, all_S, max_interaction) { # Calculate matrix tree_info <- trees[get("Tree") == tree, ] tree_info[, "Feature" := get("Feature_num") - 1L] @@ -326,7 +326,7 @@ tree_fun_emp_fastPD <- function(tree, trees, x, all_S) { tree_mat <- as.matrix(tree_mat) is_ranger <- tree_info$Type[1] == "Ranger" - m_all <- explainTreeFastPD(x, tree_mat, lapply(all_S, function(S) S - 1L), is_ranger) + m_all <- explainTreeFastPD(x, tree_mat, lapply(all_S, function(S) S - 1L), max_interaction, is_ranger) m_all } @@ -338,7 +338,7 @@ tree_fun_emp_fastPD <- function(tree, trees, x, all_S) { #' @param probFunction probFunction that was supplied to \code{glex} #' @keywords internal #' @noRd -tree_fun_wrapper <- function(trees, x, all_S, probFunction) { +tree_fun_wrapper <- function(trees, x, all_S, probFunction, max_interaction) { if (is.character(probFunction)) { if (probFunction == "path-dependent") { return(function(tree) tree_fun_path_dependent(tree, trees, x, all_S)) @@ -348,7 +348,7 @@ tree_fun_wrapper <- function(trees, x, all_S, probFunction) { stop("The probability function can either be 'path-dependent' or 'empirical' when specified as a string") } } else if (is.function(probFunction) || is.null(probFunction)) { - return(function(tree) tree_fun_emp_fastPD(tree, trees, x, all_S)) + return(function(tree) tree_fun_emp_fastPD(tree, trees, x, all_S, 
max_interaction)) } else { stop("The probability function can either be a string ('path-dependent', 'empirical'), NULL, or a function(coords, lb, ub) type function") } @@ -393,7 +393,7 @@ calc_components <- function(trees, x, max_interaction, features, probFunction = j <- NULL idx <- 0:max(trees$Tree) - tree_fun <- tree_fun_wrapper(trees, x, all_S, probFunction) + tree_fun <- tree_fun_wrapper(trees, x, all_S, probFunction, max_interaction) if (foreach::getDoParRegistered()) { m_all <- foreach(j = idx, .combine = "+") %dopar% tree_fun(j) diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 1b3cc6a..90b7269 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -68,16 +68,17 @@ BEGIN_RCPP END_RCPP } // explainTreeFastPD -Rcpp::NumericMatrix explainTreeFastPD(Rcpp::NumericMatrix& x, NumericMatrix& tree, Rcpp::List& to_explain_list, bool is_ranger); -RcppExport SEXP _glex_explainTreeFastPD(SEXP xSEXP, SEXP treeSEXP, SEXP to_explain_listSEXP, SEXP is_rangerSEXP) { +Rcpp::NumericMatrix explainTreeFastPD(Rcpp::NumericMatrix& x, NumericMatrix& tree, Rcpp::List& to_explain_list, unsigned int max_interaction, bool is_ranger); +RcppExport SEXP _glex_explainTreeFastPD(SEXP xSEXP, SEXP treeSEXP, SEXP to_explain_listSEXP, SEXP max_interactionSEXP, SEXP is_rangerSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< Rcpp::NumericMatrix& >::type x(xSEXP); Rcpp::traits::input_parameter< NumericMatrix& >::type tree(treeSEXP); Rcpp::traits::input_parameter< Rcpp::List& >::type to_explain_list(to_explain_listSEXP); + Rcpp::traits::input_parameter< unsigned int >::type max_interaction(max_interactionSEXP); Rcpp::traits::input_parameter< bool >::type is_ranger(is_rangerSEXP); - rcpp_result_gen = Rcpp::wrap(explainTreeFastPD(x, tree, to_explain_list, is_ranger)); + rcpp_result_gen = Rcpp::wrap(explainTreeFastPD(x, tree, to_explain_list, max_interaction, is_ranger)); return rcpp_result_gen; END_RCPP } @@ -190,7 
+191,7 @@ static const R_CallMethodDef CallEntries[] = { {"_glex_augmentTree", (DL_FUNC) &_glex_augmentTree, 3}, {"_glex_augmentExpectation", (DL_FUNC) &_glex_augmentExpectation, 5}, {"_glex_marginalizeAllSplittedSubsetsinTree", (DL_FUNC) &_glex_marginalizeAllSplittedSubsetsinTree, 3}, - {"_glex_explainTreeFastPD", (DL_FUNC) &_glex_explainTreeFastPD, 4}, + {"_glex_explainTreeFastPD", (DL_FUNC) &_glex_explainTreeFastPD, 5}, {"_glex_find_term_matches", (DL_FUNC) &_glex_find_term_matches, 2}, {"_glex_empProbFunction", (DL_FUNC) &_glex_empProbFunction, 4}, {"_glex_recurseRcppEmpProbfunction", (DL_FUNC) &_glex_recurseRcppEmpProbfunction, 11}, diff --git a/src/fastpd_explain_tree.cpp b/src/fastpd_explain_tree.cpp index 3054a44..37311ad 100644 --- a/src/fastpd_explain_tree.cpp +++ b/src/fastpd_explain_tree.cpp @@ -3,24 +3,24 @@ using namespace Rcpp; -std::vector> get_all_subsets_(std::set &set); +std::vector> get_all_subsets(std::set &set, unsigned int maxSize); -LeafData augmentTreeRanger(NumericMatrix &tree, NumericMatrix &dataset); -LeafData augmentTreeXgboost(NumericMatrix &tree, NumericMatrix &dataset); -Rcpp::NumericMatrix recurseMarginalizeURanger( - Rcpp::NumericMatrix &x, NumericMatrix &tree, - std::vector> &U, unsigned int node, +LeafData augmentTreeRanger(NumericMatrix &tree, NumericMatrix &dataset, unsigned int max_interaction); +LeafData augmentTreeXgboost(NumericMatrix &tree, NumericMatrix &dataset, unsigned int max_interaction); + +NumericMatrix recurseMarginalizeSRanger( + NumericMatrix &x, NumericMatrix &tree, + std::vector> &Ss, unsigned int node, LeafData &leaf_data); -Rcpp::NumericMatrix recurseMarginalizeUXgboost( - Rcpp::NumericMatrix &x, NumericMatrix &tree, - std::vector> &U, unsigned int node, +NumericMatrix recurseMarginalizeSXgboost( + NumericMatrix &x, NumericMatrix &tree, + std::vector> &Ss, unsigned int node, LeafData &leaf_data); -void contributeFastPD( +void contributeFastPD2( Rcpp::NumericMatrix &mat, Rcpp::NumericMatrix &m_all, std::set 
&S, - std::set &T, std::vector> &T_subsets, unsigned int colnum); @@ -29,26 +29,27 @@ Rcpp::NumericMatrix explainTreeFastPD( Rcpp::NumericMatrix &x, NumericMatrix &tree, Rcpp::List &to_explain_list, + unsigned int max_interaction, bool is_ranger) { - std::vector> to_explain; - for (int i = 0; i < to_explain_list.size(); i++) - { - std::set to_explain_set = std::set(as(to_explain_list[i]).begin(), as(to_explain_list[i]).end()); - to_explain.push_back(to_explain_set); - } - LeafData leaf_data = is_ranger ? augmentTreeRanger(tree, x) : augmentTreeXgboost(tree, x); - std::vector> U = get_all_subsets_(leaf_data.all_encountered); - NumericMatrix mat = is_ranger ? recurseMarginalizeURanger(x, tree, U, 0, leaf_data) : recurseMarginalizeUXgboost(x, tree, U, 0, leaf_data); + // Augment step + LeafData leaf_data = is_ranger ? augmentTreeRanger(tree, x, max_interaction) : augmentTreeXgboost(tree, x, max_interaction); + std::vector> U = get_all_subsets(leaf_data.all_encountered, max_interaction); - unsigned int to_explain_size = to_explain.size(); + // Explain/expectation/marginalization step + std::vector> to_explain; // List of S'es to explain + unsigned int to_explain_size = to_explain_list.size(); NumericMatrix m_all = NumericMatrix(x.nrow(), to_explain_size); CharacterVector m_all_col_names = CharacterVector(to_explain_size); CharacterVector x_col_names = colnames(x); + std::set needToComputePDfunctionsFor; // The set of coords we need to compute PD functions of for (int S_idx = 0; S_idx < to_explain_size; S_idx++) { - std::set S = to_explain[S_idx]; + std::set S = std::set( + as(to_explain_list[S_idx]).begin(), + as(to_explain_list[S_idx]).end()); + to_explain.push_back(S); if (int k = S.size(); k != 0) { @@ -60,10 +61,20 @@ Rcpp::NumericMatrix explainTreeFastPD( m_all_col_names[S_idx] = oss.str(); } + // Check if S \subset T if (std::find(U.begin(), U.end(), S) == U.end()) continue; - // std::set S_set = std::set(S.begin(), S.end()); - contributeFastPD(mat, m_all, S, 
leaf_data.all_encountered, U, S_idx); + + needToComputePDfunctionsFor.insert(S_idx); + } + + // Compute expectation of all necessary subsets + NumericMatrix mat = is_ranger ? recurseMarginalizeSRanger(x, tree, U, 0, leaf_data) : recurseMarginalizeSXgboost(x, tree, U, 0, leaf_data); + + for (int S_idx : needToComputePDfunctionsFor) + { + std::set S = to_explain[S_idx]; + contributeFastPD2(mat, m_all, S, U, S_idx); } colnames(m_all) = m_all_col_names; return m_all; diff --git a/src/fastpd_recurse.cpp b/src/fastpd_recurse.cpp index fc6aa18..2bb30ee 100644 --- a/src/fastpd_recurse.cpp +++ b/src/fastpd_recurse.cpp @@ -107,8 +107,8 @@ NumericMatrix recurseMarginalizeSXgboost( unsigned int no = current_node[Index::NO]; // Call both children, they give a matrix each of all obs and subsets - NumericMatrix mat_yes = recurseMarginalizeSRanger(x, tree, Ss, yes, leaf_data); - NumericMatrix mat_no = recurseMarginalizeSRanger(x, tree, Ss, no, leaf_data); + NumericMatrix mat_yes = recurseMarginalizeSXgboost(x, tree, Ss, yes, leaf_data); + NumericMatrix mat_no = recurseMarginalizeSXgboost(x, tree, Ss, no, leaf_data); for (unsigned int j = 0; j < Ss.size(); ++j) { From 0889f9548101fb1b99d78d938290ba5026e1d697 Mon Sep 17 00:00:00 2001 From: Jinyang Liu Date: Sat, 16 Nov 2024 11:26:48 +0100 Subject: [PATCH 09/17] Add test to ensure correctness of interactions --- tests/testthat/test-fastpd-equals-empirical.R | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test-fastpd-equals-empirical.R b/tests/testthat/test-fastpd-equals-empirical.R index 43d7793..0ef473a 100644 --- a/tests/testthat/test-fastpd-equals-empirical.R +++ b/tests/testthat/test-fastpd-equals-empirical.R @@ -4,7 +4,7 @@ test_that("FastPD equals empirical leaf weighting", { p <- 2 x <- matrix(rnorm(n * p), ncol = p) colnames(x) <- paste0("x", 1:p) - x[, 2] <- 0.3 * x[, 1] + sqrt(1 - 0.3^2) * x[, 2] # Add covariance + x[, 2] <- 0.3 * x[, 1] + sqrt(1 - 0.3^2) * x[, 2] # Add 
covariance y <- x[, 1] + x[, 2] + 2 * x[, 1] * x[, 2] + rnorm(n) dtrain <- xgb.DMatrix(data = x, label = y) @@ -14,5 +14,19 @@ test_that("FastPD equals empirical leaf weighting", { empirical_leaf_weighting <- glex(xg, x, probFunction = "empirical") expect_equal(fastpd$m, empirical_leaf_weighting$m) - expect_equal(fastpd$intercept, empirical_leaf_weighting$intercept) # Check intercept + expect_equal(fastpd$intercept, empirical_leaf_weighting$intercept) # Check intercept +}) + +test_that("FastPD equals empirical leaf weighting for lower interactions", { + x <- as.matrix(mtcars[, -1]) + xg <- xgboost(x, mtcars$mpg, nrounds = 15, verbose = 0) + + fastpd <- glex(xg, x, max_interaction = 2) + empirical_leaf_weighting <- glex(xg, x, probFunction = "empirical", max_interaction = 2) + + expect_equal( + fastpd$m, + empirical_leaf_weighting$m, + tolerance = 1e-5 + ) }) From 1ab207400187d510f358d78fa633b74a8b7f7d8e Mon Sep 17 00:00:00 2001 From: Jinyang Liu Date: Sat, 16 Nov 2024 12:19:43 +0100 Subject: [PATCH 10/17] Fix empirical leaf-weighting for ranger only --- R/glex.R | 3 +++ src/glex.cpp | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/R/glex.R b/R/glex.R index 191995c..9c8a1e1 100644 --- a/R/glex.R +++ b/R/glex.R @@ -343,6 +343,9 @@ tree_fun_wrapper <- function(trees, x, all_S, probFunction, max_interaction) { if (probFunction == "path-dependent") { return(function(tree) tree_fun_path_dependent(tree, trees, x, all_S)) } else if (probFunction == "empirical") { + if (trees$Type[1] != "Ranger") { + warning("Using 'empirical' for XGBoost models can result in inaccuracies. 
Use the default (probFunction = NULL) instead.") + } return(function(tree) tree_fun_emp(tree, trees, x, all_S, NULL)) } else { stop("The probability function can either be 'path-dependent' or 'empirical' when specified as a string") diff --git a/src/glex.cpp b/src/glex.cpp index 7e009a9..bc3589d 100644 --- a/src/glex.cpp +++ b/src/glex.cpp @@ -17,7 +17,7 @@ double empProbFunction(Rcpp::NumericMatrix &x, Rcpp::IntegerVector &coords, Rcpp for (int j = 0; j < m; ++j) { // Loop over selected columns int col = coords[j] - 1; // Adjust for 0-based indexing in C++ - if (x(i, col) <= lb[j] || x(i, col) >= ub[j]) + if (x(i, col) <= lb[j] || x(i, col) > ub[j]) { withinBounds = false; break; // Exit loop if any variable is out of bounds From 2fead1d9d55f793914be3c42d461d8e893e84a9d Mon Sep 17 00:00:00 2001 From: Jinyang Liu Date: Sat, 16 Nov 2024 12:25:31 +0100 Subject: [PATCH 11/17] Fix test --- tests/testthat/test-fastpd-equals-empirical.R | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/testthat/test-fastpd-equals-empirical.R b/tests/testthat/test-fastpd-equals-empirical.R index 0ef473a..d708130 100644 --- a/tests/testthat/test-fastpd-equals-empirical.R +++ b/tests/testthat/test-fastpd-equals-empirical.R @@ -1,3 +1,5 @@ +library(ranger) + test_that("FastPD equals empirical leaf weighting", { set.seed(1) n <- 5e2 @@ -7,11 +9,10 @@ test_that("FastPD equals empirical leaf weighting", { x[, 2] <- 0.3 * x[, 1] + sqrt(1 - 0.3^2) * x[, 2] # Add covariance y <- x[, 1] + x[, 2] + 2 * x[, 1] * x[, 2] + rnorm(n) - dtrain <- xgb.DMatrix(data = x, label = y) - xg <- xgboost(data = dtrain, max_depth = 4, eta = 1, nrounds = 15, objective = "reg:squarederror") + rf <- ranger(x = x, y = y, num.trees = 5, max.depth = 4, node.stats = TRUE) - fastpd <- glex(xg, x) - empirical_leaf_weighting <- glex(xg, x, probFunction = "empirical") + fastpd <- glex(rf, x) + empirical_leaf_weighting <- glex(rf, x, probFunction = "empirical") expect_equal(fastpd$m, 
empirical_leaf_weighting$m) expect_equal(fastpd$intercept, empirical_leaf_weighting$intercept) # Check intercept @@ -19,14 +20,18 @@ test_that("FastPD equals empirical leaf weighting", { test_that("FastPD equals empirical leaf weighting for lower interactions", { x <- as.matrix(mtcars[, -1]) - xg <- xgboost(x, mtcars$mpg, nrounds = 15, verbose = 0) + rf <- ranger( + x = x, y = mtcars$mpg, + node.stats = TRUE, + num.trees = 5, max.depth = 4 + ) - fastpd <- glex(xg, x, max_interaction = 2) - empirical_leaf_weighting <- glex(xg, x, probFunction = "empirical", max_interaction = 2) + fastpd <- glex(rf, x, max_interaction = 2) + empirical_leaf_weighting <- glex(rf, x, probFunction = "empirical", max_interaction = 2) expect_equal( fastpd$m, empirical_leaf_weighting$m, - tolerance = 1e-5 + ignore_attr = TRUE ) }) From 62261a68c3e9bc2f7e4794157b009ea3c09e59e1 Mon Sep 17 00:00:00 2001 From: Jinyang Liu Date: Mon, 18 Nov 2024 16:28:55 +0100 Subject: [PATCH 12/17] Use more generalizeable naming --- R/RcppExports.R | 20 +++++++++---------- R/glex.R | 12 +++++------ src/RcppExports.cpp | 40 ++++++++++++++++++------------------- src/fastpd_augment.cpp | 32 ++++++++++++++--------------- src/fastpd_expectation.cpp | 38 +++++++++++++++++------------------ src/fastpd_explain_tree.cpp | 14 ++++++------- src/fastpd_recurse.cpp | 24 +++++++++++----------- 7 files changed, 90 insertions(+), 90 deletions(-) diff --git a/R/RcppExports.R b/R/RcppExports.R index 1971b22..2c9b2db 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -1,24 +1,24 @@ # Generated by using Rcpp::compileAttributes() -> do not edit by hand # Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 -augmentAndTakeExpectation <- function(x, dataset, tree, to_explain, is_ranger) { - .Call(`_glex_augmentAndTakeExpectation`, x, dataset, tree, to_explain, is_ranger) +augmentAndTakeExpectation <- function(x, dataset, tree, to_explain, is_weak_inequality) { + .Call(`_glex_augmentAndTakeExpectation`, x, dataset, tree, 
to_explain, is_weak_inequality) } -augmentTree <- function(tree, dataset, is_ranger) { - .Call(`_glex_augmentTree`, tree, dataset, is_ranger) +augmentTree <- function(tree, dataset, is_weak_inequality) { + .Call(`_glex_augmentTree`, tree, dataset, is_weak_inequality) } -augmentExpectation <- function(x, tree, to_explain, leaf_data_ptr, is_ranger) { - .Call(`_glex_augmentExpectation`, x, tree, to_explain, leaf_data_ptr, is_ranger) +augmentExpectation <- function(x, tree, to_explain, leaf_data_ptr, is_weak_inequality) { + .Call(`_glex_augmentExpectation`, x, tree, to_explain, leaf_data_ptr, is_weak_inequality) } -marginalizeAllSplittedSubsetsinTree <- function(x, tree, is_ranger) { - .Call(`_glex_marginalizeAllSplittedSubsetsinTree`, x, tree, is_ranger) +marginalizeAllSplittedSubsetsinTree <- function(x, tree, is_weak_inequality) { + .Call(`_glex_marginalizeAllSplittedSubsetsinTree`, x, tree, is_weak_inequality) } -explainTreeFastPD <- function(x, tree, to_explain_list, max_interaction, is_ranger) { - .Call(`_glex_explainTreeFastPD`, x, tree, to_explain_list, max_interaction, is_ranger) +explainTreeFastPD <- function(x, tree, to_explain_list, max_interaction, is_weak_inequality) { + .Call(`_glex_explainTreeFastPD`, x, tree, to_explain_list, max_interaction, is_weak_inequality) } find_term_matches <- function(main_term, terms) { diff --git a/R/glex.R b/R/glex.R index 9c8a1e1..734ecbc 100644 --- a/R/glex.R +++ b/R/glex.R @@ -110,7 +110,7 @@ glex.xgb.Booster <- function(object, x, max_interaction = NULL, features = NULL, # Convert model trees <- xgboost::xgb.model.dt.tree(model = object, use_int_id = TRUE) - trees$Type <- "XGB" + trees$Type <- "<" # Calculate components res <- calc_components(trees, x, max_interaction, features, probFunction) @@ -184,7 +184,7 @@ glex.ranger <- function(object, x, max_interaction = NULL, features = NULL, prob trees[, terminal := NULL] trees[, prediction := NULL] colnames(trees) <- c("Node", "Yes", "No", "Feature", "Split", "Cover", 
"Quality", "Tree") - trees$Type <- "Ranger" + trees$Type <- "<=" # Calculate components res <- calc_components(trees, x, max_interaction, features, probFunction) @@ -325,8 +325,8 @@ tree_fun_emp_fastPD <- function(tree, trees, x, all_S, max_interaction) { tree_mat[is.na(tree_mat)] <- -1L tree_mat <- as.matrix(tree_mat) - is_ranger <- tree_info$Type[1] == "Ranger" - m_all <- explainTreeFastPD(x, tree_mat, lapply(all_S, function(S) S - 1L), max_interaction, is_ranger) + is_weak_inequality <- tree_info$Type[1] == "<=" + m_all <- explainTreeFastPD(x, tree_mat, lapply(all_S, function(S) S - 1L), max_interaction, is_weak_inequality) m_all } @@ -343,8 +343,8 @@ tree_fun_wrapper <- function(trees, x, all_S, probFunction, max_interaction) { if (probFunction == "path-dependent") { return(function(tree) tree_fun_path_dependent(tree, trees, x, all_S)) } else if (probFunction == "empirical") { - if (trees$Type[1] != "Ranger") { - warning("Using 'empirical' for XGBoost models can result in inaccuracies. Use the default (probFunction = NULL) instead.") + if (trees$Type[1] != "<=") { + warning("Using `probFunction = 'empirical'` with models that apply strict inequality (<) in the splitting rule may lead to inaccuracies. 
It is recommended to use the default setting (`probFunction = NULL`) instead.") } return(function(tree) tree_fun_emp(tree, trees, x, all_S, NULL)) } else { diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 90b7269..2b95163 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -12,8 +12,8 @@ Rcpp::Rostream& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get(); #endif // augmentAndTakeExpectation -double augmentAndTakeExpectation(NumericVector& x, NumericMatrix& dataset, NumericMatrix& tree, NumericVector& to_explain, bool is_ranger); -RcppExport SEXP _glex_augmentAndTakeExpectation(SEXP xSEXP, SEXP datasetSEXP, SEXP treeSEXP, SEXP to_explainSEXP, SEXP is_rangerSEXP) { +double augmentAndTakeExpectation(NumericVector& x, NumericMatrix& dataset, NumericMatrix& tree, NumericVector& to_explain, bool is_weak_inequality); +RcppExport SEXP _glex_augmentAndTakeExpectation(SEXP xSEXP, SEXP datasetSEXP, SEXP treeSEXP, SEXP to_explainSEXP, SEXP is_weak_inequalitySEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -21,27 +21,27 @@ BEGIN_RCPP Rcpp::traits::input_parameter< NumericMatrix& >::type dataset(datasetSEXP); Rcpp::traits::input_parameter< NumericMatrix& >::type tree(treeSEXP); Rcpp::traits::input_parameter< NumericVector& >::type to_explain(to_explainSEXP); - Rcpp::traits::input_parameter< bool >::type is_ranger(is_rangerSEXP); - rcpp_result_gen = Rcpp::wrap(augmentAndTakeExpectation(x, dataset, tree, to_explain, is_ranger)); + Rcpp::traits::input_parameter< bool >::type is_weak_inequality(is_weak_inequalitySEXP); + rcpp_result_gen = Rcpp::wrap(augmentAndTakeExpectation(x, dataset, tree, to_explain, is_weak_inequality)); return rcpp_result_gen; END_RCPP } // augmentTree -XPtr augmentTree(NumericMatrix& tree, NumericMatrix& dataset, bool is_ranger); -RcppExport SEXP _glex_augmentTree(SEXP treeSEXP, SEXP datasetSEXP, SEXP is_rangerSEXP) { +XPtr augmentTree(NumericMatrix& tree, NumericMatrix& dataset, bool is_weak_inequality); 
+RcppExport SEXP _glex_augmentTree(SEXP treeSEXP, SEXP datasetSEXP, SEXP is_weak_inequalitySEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< NumericMatrix& >::type tree(treeSEXP); Rcpp::traits::input_parameter< NumericMatrix& >::type dataset(datasetSEXP); - Rcpp::traits::input_parameter< bool >::type is_ranger(is_rangerSEXP); - rcpp_result_gen = Rcpp::wrap(augmentTree(tree, dataset, is_ranger)); + Rcpp::traits::input_parameter< bool >::type is_weak_inequality(is_weak_inequalitySEXP); + rcpp_result_gen = Rcpp::wrap(augmentTree(tree, dataset, is_weak_inequality)); return rcpp_result_gen; END_RCPP } // augmentExpectation -double augmentExpectation(NumericVector& x, NumericMatrix& tree, NumericVector& to_explain, SEXP leaf_data_ptr, bool is_ranger); -RcppExport SEXP _glex_augmentExpectation(SEXP xSEXP, SEXP treeSEXP, SEXP to_explainSEXP, SEXP leaf_data_ptrSEXP, SEXP is_rangerSEXP) { +double augmentExpectation(NumericVector& x, NumericMatrix& tree, NumericVector& to_explain, SEXP leaf_data_ptr, bool is_weak_inequality); +RcppExport SEXP _glex_augmentExpectation(SEXP xSEXP, SEXP treeSEXP, SEXP to_explainSEXP, SEXP leaf_data_ptrSEXP, SEXP is_weak_inequalitySEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -49,27 +49,27 @@ BEGIN_RCPP Rcpp::traits::input_parameter< NumericMatrix& >::type tree(treeSEXP); Rcpp::traits::input_parameter< NumericVector& >::type to_explain(to_explainSEXP); Rcpp::traits::input_parameter< SEXP >::type leaf_data_ptr(leaf_data_ptrSEXP); - Rcpp::traits::input_parameter< bool >::type is_ranger(is_rangerSEXP); - rcpp_result_gen = Rcpp::wrap(augmentExpectation(x, tree, to_explain, leaf_data_ptr, is_ranger)); + Rcpp::traits::input_parameter< bool >::type is_weak_inequality(is_weak_inequalitySEXP); + rcpp_result_gen = Rcpp::wrap(augmentExpectation(x, tree, to_explain, leaf_data_ptr, is_weak_inequality)); return rcpp_result_gen; END_RCPP } // 
marginalizeAllSplittedSubsetsinTree -Rcpp::NumericMatrix marginalizeAllSplittedSubsetsinTree(Rcpp::NumericMatrix& x, NumericMatrix& tree, bool is_ranger); -RcppExport SEXP _glex_marginalizeAllSplittedSubsetsinTree(SEXP xSEXP, SEXP treeSEXP, SEXP is_rangerSEXP) { +Rcpp::NumericMatrix marginalizeAllSplittedSubsetsinTree(Rcpp::NumericMatrix& x, NumericMatrix& tree, bool is_weak_inequality); +RcppExport SEXP _glex_marginalizeAllSplittedSubsetsinTree(SEXP xSEXP, SEXP treeSEXP, SEXP is_weak_inequalitySEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< Rcpp::NumericMatrix& >::type x(xSEXP); Rcpp::traits::input_parameter< NumericMatrix& >::type tree(treeSEXP); - Rcpp::traits::input_parameter< bool >::type is_ranger(is_rangerSEXP); - rcpp_result_gen = Rcpp::wrap(marginalizeAllSplittedSubsetsinTree(x, tree, is_ranger)); + Rcpp::traits::input_parameter< bool >::type is_weak_inequality(is_weak_inequalitySEXP); + rcpp_result_gen = Rcpp::wrap(marginalizeAllSplittedSubsetsinTree(x, tree, is_weak_inequality)); return rcpp_result_gen; END_RCPP } // explainTreeFastPD -Rcpp::NumericMatrix explainTreeFastPD(Rcpp::NumericMatrix& x, NumericMatrix& tree, Rcpp::List& to_explain_list, unsigned int max_interaction, bool is_ranger); -RcppExport SEXP _glex_explainTreeFastPD(SEXP xSEXP, SEXP treeSEXP, SEXP to_explain_listSEXP, SEXP max_interactionSEXP, SEXP is_rangerSEXP) { +Rcpp::NumericMatrix explainTreeFastPD(Rcpp::NumericMatrix& x, NumericMatrix& tree, Rcpp::List& to_explain_list, unsigned int max_interaction, bool is_weak_inequality); +RcppExport SEXP _glex_explainTreeFastPD(SEXP xSEXP, SEXP treeSEXP, SEXP to_explain_listSEXP, SEXP max_interactionSEXP, SEXP is_weak_inequalitySEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -77,8 +77,8 @@ BEGIN_RCPP Rcpp::traits::input_parameter< NumericMatrix& >::type tree(treeSEXP); Rcpp::traits::input_parameter< Rcpp::List& >::type 
to_explain_list(to_explain_listSEXP); Rcpp::traits::input_parameter< unsigned int >::type max_interaction(max_interactionSEXP); - Rcpp::traits::input_parameter< bool >::type is_ranger(is_rangerSEXP); - rcpp_result_gen = Rcpp::wrap(explainTreeFastPD(x, tree, to_explain_list, max_interaction, is_ranger)); + Rcpp::traits::input_parameter< bool >::type is_weak_inequality(is_weak_inequalitySEXP); + rcpp_result_gen = Rcpp::wrap(explainTreeFastPD(x, tree, to_explain_list, max_interaction, is_weak_inequality)); return rcpp_result_gen; END_RCPP } diff --git a/src/fastpd_augment.cpp b/src/fastpd_augment.cpp index 2d1e586..77da340 100644 --- a/src/fastpd_augment.cpp +++ b/src/fastpd_augment.cpp @@ -2,14 +2,14 @@ #include "../inst/include/glex.h" using namespace Rcpp; -void augmentTreeRecurseStepRanger( +void augmentTreeRecurseStepWeakComparison( AugmentedData passed_down, LeafData &leaf_data, NumericMatrix &tree, NumericMatrix &dataset, unsigned int node, unsigned int max_interaction); -void augmentTreeRecurseStepXgboost( +void augmentTreeRecurseStrictComparison( AugmentedData passed_down, LeafData &leaf_data, NumericMatrix &tree, @@ -17,7 +17,7 @@ void augmentTreeRecurseStepXgboost( unsigned int node, unsigned int max_interaction); -LeafData augmentTreeRanger(NumericMatrix &tree, NumericMatrix &dataset, unsigned int max_interaction) +LeafData augmentTreeWeakComparison(NumericMatrix &tree, NumericMatrix &dataset, unsigned int max_interaction) { AugmentedData result; @@ -29,11 +29,11 @@ LeafData augmentTreeRanger(NumericMatrix &tree, NumericMatrix &dataset, unsigned std::iota(to_pass_down.pathData[{}].begin(), to_pass_down.pathData[{}].end(), 0); LeafData leaf_data = LeafData(); - augmentTreeRecurseStepRanger(to_pass_down, leaf_data, tree, dataset, 0, max_interaction); + augmentTreeRecurseStepWeakComparison(to_pass_down, leaf_data, tree, dataset, 0, max_interaction); return leaf_data; } -LeafData augmentTreeXgboost(NumericMatrix &tree, NumericMatrix &dataset, unsigned int 
max_interaction) +LeafData augmentTreeStrictComparison(NumericMatrix &tree, NumericMatrix &dataset, unsigned int max_interaction) { AugmentedData result; @@ -45,21 +45,21 @@ LeafData augmentTreeXgboost(NumericMatrix &tree, NumericMatrix &dataset, unsigne std::iota(to_pass_down.pathData[{}].begin(), to_pass_down.pathData[{}].end(), 0); LeafData leaf_data = LeafData(); - augmentTreeRecurseStepXgboost(to_pass_down, leaf_data, tree, dataset, 0, max_interaction); + augmentTreeRecurseStrictComparison(to_pass_down, leaf_data, tree, dataset, 0, max_interaction); return leaf_data; } -LeafData augmentTreeXgboost(NumericMatrix &tree, NumericMatrix &dataset) +LeafData augmentTreeStrictComparison(NumericMatrix &tree, NumericMatrix &dataset) { - return augmentTreeXgboost(tree, dataset, dataset.ncol()); + return augmentTreeStrictComparison(tree, dataset, dataset.ncol()); } -LeafData augmentTreeRanger(NumericMatrix &tree, NumericMatrix &dataset) +LeafData augmentTreeWeakComparison(NumericMatrix &tree, NumericMatrix &dataset) { - return augmentTreeRanger(tree, dataset, dataset.ncol()); + return augmentTreeWeakComparison(tree, dataset, dataset.ncol()); } -void augmentTreeRecurseStepRanger( +void augmentTreeRecurseStepWeakComparison( AugmentedData passed_down, LeafData &leaf_data, NumericMatrix &tree, @@ -145,11 +145,11 @@ void augmentTreeRecurseStepRanger( unsigned int yes = current_node[Index::YES]; unsigned int no = current_node[Index::NO]; - augmentTreeRecurseStepRanger(passed_down_yes, leaf_data, tree, dataset, yes, max_interaction); - augmentTreeRecurseStepRanger(passed_down_no, leaf_data, tree, dataset, no, max_interaction); + augmentTreeRecurseStepWeakComparison(passed_down_yes, leaf_data, tree, dataset, yes, max_interaction); + augmentTreeRecurseStepWeakComparison(passed_down_no, leaf_data, tree, dataset, no, max_interaction); } -void augmentTreeRecurseStepXgboost( +void augmentTreeRecurseStrictComparison( AugmentedData passed_down, LeafData &leaf_data, NumericMatrix &tree, 
@@ -234,6 +234,6 @@ void augmentTreeRecurseStepXgboost( unsigned int yes = current_node[Index::YES]; unsigned int no = current_node[Index::NO]; - augmentTreeRecurseStepXgboost(passed_down_yes, leaf_data, tree, dataset, yes, max_interaction); - augmentTreeRecurseStepXgboost(passed_down_no, leaf_data, tree, dataset, no, max_interaction); + augmentTreeRecurseStrictComparison(passed_down_yes, leaf_data, tree, dataset, yes, max_interaction); + augmentTreeRecurseStrictComparison(passed_down_no, leaf_data, tree, dataset, no, max_interaction); } diff --git a/src/fastpd_expectation.cpp b/src/fastpd_expectation.cpp index 06968c1..a46f6a3 100644 --- a/src/fastpd_expectation.cpp +++ b/src/fastpd_expectation.cpp @@ -5,57 +5,57 @@ #include "../inst/include/glex.h" using namespace Rcpp; -LeafData augmentTreeRanger(NumericMatrix &tree, NumericMatrix &dataset); -LeafData augmentTreeXgboost(NumericMatrix &tree, NumericMatrix &dataset); +LeafData augmentTreeWeakComparison(NumericMatrix &tree, NumericMatrix &dataset); +LeafData augmentTreeStrictComparison(NumericMatrix &tree, NumericMatrix &dataset); std::vector> get_all_subsets_(std::set &set); -Rcpp::NumericMatrix recurseMarginalizeURanger( +Rcpp::NumericMatrix recurseMarginalizeUWeakComparison( Rcpp::NumericMatrix &x, NumericMatrix &tree, std::vector> &U, unsigned int node, LeafData &leaf_data); -Rcpp::NumericMatrix recurseMarginalizeUXgboost( +Rcpp::NumericMatrix recurseMarginalizeUStrictComparison( Rcpp::NumericMatrix &x, NumericMatrix &tree, std::vector> &U, unsigned int node, LeafData &leaf_data); -double augmentExpectationRanger(NumericVector &x, NumericMatrix &tree, NumericVector &to_explain, LeafData &leaf_data); -double augmentExpectationXgboost(NumericVector &x, NumericMatrix &tree, NumericVector &to_explain, LeafData &leaf_data); +double augmentExpectationWeakComparison(NumericVector &x, NumericMatrix &tree, NumericVector &to_explain, LeafData &leaf_data); +double augmentExpectationStrictComparison(NumericVector &x, 
NumericMatrix &tree, NumericVector &to_explain, LeafData &leaf_data); // [[Rcpp::export]] -double augmentAndTakeExpectation(NumericVector &x, NumericMatrix &dataset, NumericMatrix &tree, NumericVector &to_explain, bool is_ranger) +double augmentAndTakeExpectation(NumericVector &x, NumericMatrix &dataset, NumericMatrix &tree, NumericVector &to_explain, bool is_weak_inequality) { - LeafData leaf_data = is_ranger ? augmentTreeRanger(tree, dataset) : augmentTreeXgboost(tree, dataset); - return is_ranger ? augmentExpectationRanger(x, tree, to_explain, leaf_data) : augmentExpectationXgboost(x, tree, to_explain, leaf_data); + LeafData leaf_data = is_weak_inequality ? augmentTreeWeakComparison(tree, dataset) : augmentTreeStrictComparison(tree, dataset); + return is_weak_inequality ? augmentExpectationWeakComparison(x, tree, to_explain, leaf_data) : augmentExpectationStrictComparison(x, tree, to_explain, leaf_data); } // [[Rcpp::export]] -XPtr augmentTree(NumericMatrix &tree, NumericMatrix &dataset, bool is_ranger) +XPtr augmentTree(NumericMatrix &tree, NumericMatrix &dataset, bool is_weak_inequality) { - LeafData *leaf_data = new LeafData(is_ranger ? augmentTreeRanger(tree, dataset) : augmentTreeXgboost(tree, dataset)); // Dynamically allocate - XPtr ptr(leaf_data, true); // true enables automatic memory management + LeafData *leaf_data = new LeafData(is_weak_inequality ? augmentTreeWeakComparison(tree, dataset) : augmentTreeStrictComparison(tree, dataset)); // Dynamically allocate + XPtr ptr(leaf_data, true); // true enables automatic memory management return ptr; } // [[Rcpp::export]] -double augmentExpectation(NumericVector &x, NumericMatrix &tree, NumericVector &to_explain, SEXP leaf_data_ptr, bool is_ranger) +double augmentExpectation(NumericVector &x, NumericMatrix &tree, NumericVector &to_explain, SEXP leaf_data_ptr, bool is_weak_inequality) { const Rcpp::XPtr leaf_data(leaf_data_ptr); - return is_ranger ? 
augmentExpectationRanger(x, tree, to_explain, *leaf_data) : augmentExpectationXgboost(x, tree, to_explain, *leaf_data); + return is_weak_inequality ? augmentExpectationWeakComparison(x, tree, to_explain, *leaf_data) : augmentExpectationStrictComparison(x, tree, to_explain, *leaf_data); } // [[Rcpp::export]] Rcpp::NumericMatrix marginalizeAllSplittedSubsetsinTree( Rcpp::NumericMatrix &x, NumericMatrix &tree, - bool is_ranger) + bool is_weak_inequality) { - LeafData leaf_data = is_ranger ? augmentTreeRanger(tree, x) : augmentTreeXgboost(tree, x); + LeafData leaf_data = is_weak_inequality ? augmentTreeWeakComparison(tree, x) : augmentTreeStrictComparison(tree, x); std::vector> U = get_all_subsets_(leaf_data.all_encountered); - return is_ranger ? recurseMarginalizeURanger(x, tree, U, 0, leaf_data) : recurseMarginalizeUXgboost(x, tree, U, 0, leaf_data); + return is_weak_inequality ? recurseMarginalizeUWeakComparison(x, tree, U, 0, leaf_data) : recurseMarginalizeUStrictComparison(x, tree, U, 0, leaf_data); } -double augmentExpectationRanger(NumericVector &x, NumericMatrix &tree, NumericVector &to_explain, LeafData &leaf_data) +double augmentExpectationWeakComparison(NumericVector &x, NumericMatrix &tree, NumericVector &to_explain, LeafData &leaf_data) { std::stack to_process; to_process.push(0); @@ -107,7 +107,7 @@ double augmentExpectationRanger(NumericVector &x, NumericMatrix &tree, NumericVe return result; } -double augmentExpectationXgboost(NumericVector &x, NumericMatrix &tree, NumericVector &to_explain, LeafData &leaf_data) +double augmentExpectationStrictComparison(NumericVector &x, NumericMatrix &tree, NumericVector &to_explain, LeafData &leaf_data) { std::stack to_process; to_process.push(0); diff --git a/src/fastpd_explain_tree.cpp b/src/fastpd_explain_tree.cpp index 37311ad..ba3ac8b 100644 --- a/src/fastpd_explain_tree.cpp +++ b/src/fastpd_explain_tree.cpp @@ -5,14 +5,14 @@ using namespace Rcpp; std::vector> get_all_subsets(std::set &set, unsigned int 
maxSize); -LeafData augmentTreeRanger(NumericMatrix &tree, NumericMatrix &dataset, unsigned int max_interaction); -LeafData augmentTreeXgboost(NumericMatrix &tree, NumericMatrix &dataset, unsigned int max_interaction); +LeafData augmentTreeWeakComparison(NumericMatrix &tree, NumericMatrix &dataset, unsigned int max_interaction); +LeafData augmentTreeStrictComparison(NumericMatrix &tree, NumericMatrix &dataset, unsigned int max_interaction); -NumericMatrix recurseMarginalizeSRanger( +NumericMatrix recurseMarginalizeSWeakComparison( NumericMatrix &x, NumericMatrix &tree, std::vector> &Ss, unsigned int node, LeafData &leaf_data); -NumericMatrix recurseMarginalizeSXgboost( +NumericMatrix recurseMarginalizeSStrictComparison( NumericMatrix &x, NumericMatrix &tree, std::vector> &Ss, unsigned int node, LeafData &leaf_data); @@ -30,10 +30,10 @@ Rcpp::NumericMatrix explainTreeFastPD( NumericMatrix &tree, Rcpp::List &to_explain_list, unsigned int max_interaction, - bool is_ranger) + bool is_weak_inequality) { // Augment step - LeafData leaf_data = is_ranger ? augmentTreeRanger(tree, x, max_interaction) : augmentTreeXgboost(tree, x, max_interaction); + LeafData leaf_data = is_weak_inequality ? augmentTreeWeakComparison(tree, x, max_interaction) : augmentTreeStrictComparison(tree, x, max_interaction); std::vector> U = get_all_subsets(leaf_data.all_encountered, max_interaction); // Explain/expectation/marginalization step @@ -69,7 +69,7 @@ Rcpp::NumericMatrix explainTreeFastPD( } // Compute expectation of all necessary subsets - NumericMatrix mat = is_ranger ? recurseMarginalizeSRanger(x, tree, U, 0, leaf_data) : recurseMarginalizeSXgboost(x, tree, U, 0, leaf_data); + NumericMatrix mat = is_weak_inequality ? 
recurseMarginalizeSWeakComparison(x, tree, U, 0, leaf_data) : recurseMarginalizeSStrictComparison(x, tree, U, 0, leaf_data); for (int S_idx : needToComputePDfunctionsFor) { diff --git a/src/fastpd_recurse.cpp b/src/fastpd_recurse.cpp index 2bb30ee..3e3a0b1 100644 --- a/src/fastpd_recurse.cpp +++ b/src/fastpd_recurse.cpp @@ -4,7 +4,7 @@ using namespace Rcpp; -NumericMatrix recurseMarginalizeSRanger( +NumericMatrix recurseMarginalizeSWeakComparison( NumericMatrix &x, NumericMatrix &tree, std::vector> &Ss, unsigned int node, LeafData &leaf_data) @@ -41,8 +41,8 @@ NumericMatrix recurseMarginalizeSRanger( unsigned int no = current_node[Index::NO]; // Call both children, they give a matrix each of all obs and subsets - NumericMatrix mat_yes = recurseMarginalizeSRanger(x, tree, Ss, yes, leaf_data); - NumericMatrix mat_no = recurseMarginalizeSRanger(x, tree, Ss, no, leaf_data); + NumericMatrix mat_yes = recurseMarginalizeSWeakComparison(x, tree, Ss, yes, leaf_data); + NumericMatrix mat_no = recurseMarginalizeSWeakComparison(x, tree, Ss, no, leaf_data); for (unsigned int j = 0; j < Ss.size(); ++j) { @@ -70,7 +70,7 @@ NumericMatrix recurseMarginalizeSRanger( return mat; } -NumericMatrix recurseMarginalizeSXgboost( +NumericMatrix recurseMarginalizeSStrictComparison( NumericMatrix &x, NumericMatrix &tree, std::vector> &Ss, unsigned int node, LeafData &leaf_data) @@ -107,8 +107,8 @@ NumericMatrix recurseMarginalizeSXgboost( unsigned int no = current_node[Index::NO]; // Call both children, they give a matrix each of all obs and subsets - NumericMatrix mat_yes = recurseMarginalizeSXgboost(x, tree, Ss, yes, leaf_data); - NumericMatrix mat_no = recurseMarginalizeSXgboost(x, tree, Ss, no, leaf_data); + NumericMatrix mat_yes = recurseMarginalizeSStrictComparison(x, tree, Ss, yes, leaf_data); + NumericMatrix mat_no = recurseMarginalizeSStrictComparison(x, tree, Ss, no, leaf_data); for (unsigned int j = 0; j < Ss.size(); ++j) { @@ -136,7 +136,7 @@ NumericMatrix 
recurseMarginalizeSXgboost( return mat; } -Rcpp::NumericMatrix recurseMarginalizeURanger( +Rcpp::NumericMatrix recurseMarginalizeUWeakComparison( Rcpp::NumericMatrix &x, NumericMatrix &tree, std::vector> &U, unsigned int node, LeafData &leaf_data) @@ -173,8 +173,8 @@ Rcpp::NumericMatrix recurseMarginalizeURanger( unsigned int no = current_node[Index::NO]; // Call both children, they give a matrix each of all obs and subsets - Rcpp::NumericMatrix mat_yes = recurseMarginalizeURanger(x, tree, U, yes, leaf_data); - Rcpp::NumericMatrix mat_no = recurseMarginalizeURanger(x, tree, U, no, leaf_data); + Rcpp::NumericMatrix mat_yes = recurseMarginalizeUWeakComparison(x, tree, U, yes, leaf_data); + Rcpp::NumericMatrix mat_no = recurseMarginalizeUWeakComparison(x, tree, U, no, leaf_data); for (unsigned int j = 0; j < U.size(); ++j) { @@ -202,7 +202,7 @@ Rcpp::NumericMatrix recurseMarginalizeURanger( return mat; } -Rcpp::NumericMatrix recurseMarginalizeUXgboost( +Rcpp::NumericMatrix recurseMarginalizeUStrictComparison( Rcpp::NumericMatrix &x, NumericMatrix &tree, std::vector> &U, unsigned int node, LeafData &leaf_data) @@ -239,8 +239,8 @@ Rcpp::NumericMatrix recurseMarginalizeUXgboost( unsigned int no = current_node[Index::NO]; // Call both children, they give a matrix each of all obs and subsets - Rcpp::NumericMatrix mat_yes = recurseMarginalizeUXgboost(x, tree, U, yes, leaf_data); - Rcpp::NumericMatrix mat_no = recurseMarginalizeUXgboost(x, tree, U, no, leaf_data); + Rcpp::NumericMatrix mat_yes = recurseMarginalizeUStrictComparison(x, tree, U, yes, leaf_data); + Rcpp::NumericMatrix mat_no = recurseMarginalizeUStrictComparison(x, tree, U, no, leaf_data); for (unsigned int j = 0; j < U.size(); ++j) { From 508a56de90603ada0b72a4984e59cf1d8841e77d Mon Sep 17 00:00:00 2001 From: Jinyang Liu Date: Mon, 18 Nov 2024 12:52:25 +0100 Subject: [PATCH 13/17] Don't remove intercept column --- R/glex.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/glex.R b/R/glex.R 
index 734ecbc..afed798 100644 --- a/R/glex.R +++ b/R/glex.R @@ -403,7 +403,7 @@ calc_components <- function(trees, x, max_interaction, features, probFunction = } else { m_all <- foreach(j = idx, .combine = "+") %do% tree_fun(j) } - + colnames(m_all)[1] <- "intercept" d <- get_degree(colnames(m_all)) # Overall feature effect is sum of all elements where feature is involved @@ -423,7 +423,7 @@ calc_components <- function(trees, x, max_interaction, features, probFunction = # Return shap values, decomposition and intercept ret <- list( shap = data.table::setDT(as.data.frame(shap)), - m = data.table::setDT(as.data.frame(m_all[, -1])), + m = data.table::setDT(as.data.frame(m_all)), intercept = unique(m_all[, 1]), x = data.table::setDT(as.data.frame(x)) ) From 47ecc7bc80fb15abdc76f4ab00a9e69b6c662d71 Mon Sep 17 00:00:00 2001 From: Jinyang Liu Date: Mon, 18 Nov 2024 15:09:31 +0100 Subject: [PATCH 14/17] Add rest column for regression tasks --- R/glex.R | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/R/glex.R b/R/glex.R index afed798..4444d0c 100644 --- a/R/glex.R +++ b/R/glex.R @@ -64,6 +64,17 @@ glex.rpf <- function(object, x, max_interaction = NULL, features = NULL, ...) 
{ predictors = features ) # class(ret) <- c("glex", "rpf_components", class(ret)) + ret$m <- cbind(ret$intercept, ret$m) + colnames(ret$m)[1] <- "intercept" + if ( + object$mode == "regression" && + is.numeric(max_interaction) && + max_interaction < object$params$max_interaction + ) { + preds <- predict(object, x) + ret$m[, "rest"] <- preds$.pred - rowSums(ret$m) + } + ret } @@ -115,7 +126,17 @@ glex.xgb.Booster <- function(object, x, max_interaction = NULL, features = NULL, # Calculate components res <- calc_components(trees, x, max_interaction, features, probFunction) res$intercept <- res$intercept + 0.5 - + res$m[, "intercept"] <- res$intercept + + if ( + "objective" %in% names(object$params) && + startsWith(object$params$objective, "reg:") && + is.numeric(max_interaction) && + max_interaction < xgb_max_depth + ) { + preds <- predict(object, x) + res$m[, "rest"] <- preds - rowSums(res$m) + } # Return components res } @@ -194,6 +215,14 @@ glex.ranger <- function(object, x, max_interaction = NULL, features = NULL, prob res$m <- res$m / object$num.trees res$intercept <- res$intercept / object$num.trees + if ( + object$treetype == "Regression" && + is.numeric(max_interaction) && + max_interaction < rf_max_depth + ) { + preds <- predict(object, x) + res$m[, "rest"] <- preds$predictions - rowSums(res$m) + } # Return components res } From b8f2f90bd24afdb7b72e8520a3cbc7a306194594 Mon Sep 17 00:00:00 2001 From: Jinyang Liu Date: Mon, 18 Nov 2024 15:09:42 +0100 Subject: [PATCH 15/17] Convert to matrix if not so --- R/glex.R | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/R/glex.R b/R/glex.R index 4444d0c..f9bc961 100644 --- a/R/glex.R +++ b/R/glex.R @@ -108,7 +108,9 @@ glex.xgb.Booster <- function(object, x, max_interaction = NULL, features = NULL, if (!requireNamespace("xgboost", quietly = TRUE)) { stop("xgboost needs to be installed: install.packages(\"xgboost\")") } - + if (!is.matrix(x)) { + x <- as.matrix(x) + } # If max_interaction is not 
specified, we set it to the max_depth param of the xgb model. # If max_depth is not defined in xgb, we assume its default of 6. xgb_max_depth <- ifelse(is.null(object$params$max_depth), 6L, object$params$max_depth) @@ -184,7 +186,9 @@ glex.ranger <- function(object, x, max_interaction = NULL, features = NULL, prob if (is.null(object$forest$num.samples.nodes)) { stop("ranger needs to be called with node.stats=TRUE for glex.") } - + if (!is.matrix(x)) { + x <- as.matrix(x) + } # If max_interaction is not specified, we set it to the max.depth param of the ranger model. # If max.depth is not defined in ranger, we assume 6 as in xgboost. rf_max_depth <- ifelse((is.null(object$max.depth) || object$max.depth == 0), 6L, object$max.depth) From 30470b77235eb6c1fb9096d3e911c16050d38b3b Mon Sep 17 00:00:00 2001 From: Jinyang Liu Date: Mon, 18 Nov 2024 16:05:57 +0100 Subject: [PATCH 16/17] Update tests to accomodate for intercept --- R/glex_vi.R | 50 +++++++------- R/print.R | 20 +++--- R/utils-components-reshaping.R | 9 +-- tests/testthat/_snaps/print-glex.md | 25 +++---- tests/testthat/test-fastpd-equals-empirical.R | 1 - tests/testthat/test-glex-xgboost.R | 6 +- tests/testthat/test-mtcars-ranger.R | 32 +++++---- tests/testthat/test-mtcars.R | 24 ++++--- tests/testthat/test-sim.R | 65 +++++++++++-------- 9 files changed, 130 insertions(+), 102 deletions(-) diff --git a/R/glex_vi.R b/R/glex_vi.R index aabfc83..53b87f5 100644 --- a/R/glex_vi.R +++ b/R/glex_vi.R @@ -28,38 +28,40 @@ #' set.seed(1) #' # Random Planted Forest ----- #' if (requireNamespace("randomPlantedForest", quietly = TRUE)) { -#' library(randomPlantedForest) +#' library(randomPlantedForest) #' -#' rp <- rpf(mpg ~ ., data = mtcars[1:26, ], max_interaction = 3) +#' rp <- rpf(mpg ~ ., data = mtcars[1:26, ], max_interaction = 3) #' -#' glex_rpf <- glex(rp, mtcars[27:32, ]) +#' glex_rpf <- glex(rp, mtcars[27:32, ]) #' -#' # All terms -#' vi_rpf <- glex_vi(glex_rpf) +#' # All terms +#' vi_rpf <- glex_vi(glex_rpf) #' 
-#' library(ggplot2) -#' # Filter to contributions greater 0.05 on the scale of the target -#' autoplot(vi_rpf, threshold = 0.05) -#' # Summarize by degree of interaction -#' autoplot(vi_rpf, by_degree = TRUE) -#' # Filter by relative contributions greater 0.1% -#' autoplot(vi_rpf, scale = "relative", threshold = 0.001) +#' library(ggplot2) +#' # Filter to contributions greater 0.05 on the scale of the target +#' autoplot(vi_rpf, threshold = 0.05) +#' # Summarize by degree of interaction +#' autoplot(vi_rpf, by_degree = TRUE) +#' # Filter by relative contributions greater 0.1% +#' autoplot(vi_rpf, scale = "relative", threshold = 0.001) #' } #' #' # xgboost ----- #' if (requireNamespace("xgboost", quietly = TRUE)) { -#' library(xgboost) -#' x <- as.matrix(mtcars[, -1]) -#' y <- mtcars$mpg -#' xg <- xgboost(data = x[1:26, ], label = y[1:26], -#' params = list(max_depth = 4, eta = .1), -#' nrounds = 10, verbose = 0) -#' glex_xgb <- glex(xg, x[27:32, ]) -#' vi_xgb <- glex_vi(glex_xgb) +#' library(xgboost) +#' x <- as.matrix(mtcars[, -1]) +#' y <- mtcars$mpg +#' xg <- xgboost( +#' data = x[1:26, ], label = y[1:26], +#' params = list(max_depth = 4, eta = .1), +#' nrounds = 10, verbose = 0 +#' ) +#' glex_xgb <- glex(xg, x[27:32, ]) +#' vi_xgb <- glex_vi(glex_xgb) #' -#' library(ggplot2) -#' autoplot(vi_xgb) -#' autoplot(vi_xgb, by_degree = TRUE) +#' library(ggplot2) +#' autoplot(vi_xgb) +#' autoplot(vi_xgb, by_degree = TRUE) #' } glex_vi <- function(object, ...) { checkmate::assert_class(object, classes = "glex") @@ -67,7 +69,7 @@ glex_vi <- function(object, ...) 
{ # FIXME: data.table NSE warnings term <- degree <- m <- m_rel <- NULL - m_long <- melt_m(object$m, object$target_levels) + m_long <- melt_m(object$m[, -c("intercept")], object$target_levels) if (!is.null(object$target_levels)) { vars_by <- c("term", "class") diff --git a/R/print.R b/R/print.R index fc95907..2a336cc 100644 --- a/R/print.R +++ b/R/print.R @@ -11,21 +11,25 @@ #' @examples #' # Random Planted Forest ----- #' if (requireNamespace("randomPlantedForest", quietly = TRUE)) { -#' library(randomPlantedForest) -#' rp <- rpf(mpg ~ hp + wt + drat, data = mtcars[1:26, ], max_interaction = 2) +#' library(randomPlantedForest) +#' rp <- rpf(mpg ~ hp + wt + drat, data = mtcars[1:26, ], max_interaction = 2) #' -#' glex(rp, mtcars[27:32, ]) +#' glex(rp, mtcars[27:32, ]) #' } print.glex <- function(x, ...) { - n <- nrow(x$x) - n_m <- ncol(x$m) + n_m <- ncol(x$m) - 1 max_deg <- max(get_degree(names(x$m))) - max_deg_lab <- switch(as.character(max_deg), "1" = "degree", "degrees") + max_deg_lab <- switch(as.character(max_deg), + "1" = "degree", + "degrees" + ) cat("glex object of subclass", class(x)[[2]], "\n") - cat("Explaining predictions of", n, "observations with", n_m, - "terms of up to", max_deg, max_deg_lab) + cat( + "Explaining predictions of", n, "observations with", n_m, + "terms of up to", max_deg, max_deg_lab + ) cat("\n\n") str(x, list.len = 5) } diff --git a/R/utils-components-reshaping.R b/R/utils-components-reshaping.R index 8d52fef..8995ac1 100644 --- a/R/utils-components-reshaping.R +++ b/R/utils-components-reshaping.R @@ -27,8 +27,10 @@ melt_m <- function(m, levels = NULL) { term <- NULL # We need an id.var for melt to avoid a warning, but don't want to modify m permanently m[, ".id" := .I] - m_long <- data.table::melt(m, id.vars = ".id", value.name = "m", - variable.name = "term", variable.factor = FALSE) + m_long <- data.table::melt(m, + id.vars = ".id", value.name = "m", + variable.name = "term", variable.factor = FALSE + ) # clean up that temporary 
id column again while modifying by reference m[, ".id" := NULL] @@ -45,7 +47,7 @@ reshape_m_multiclass <- function(object) { checkmate::assert_character(object$target_levels, min.len = 2) mlong <- melt_m(object$m, object$target_levels) - data.table::dcast(mlong, .id + class ~ term, value.var = "m") + data.table::dcast(mlong, .id + class ~ term, value.var = "m") } #' Helper to split multiclass terms of format __class: @@ -62,4 +64,3 @@ split_names <- function(mn, split_string = "__class:", target_index = 2) { unlist(strsplit(x, split = split_string, fixed = TRUE))[target_index] }, character(1), USE.NAMES = FALSE) } - diff --git a/tests/testthat/_snaps/print-glex.md b/tests/testthat/_snaps/print-glex.md index 1f00a77..65e1a22 100644 --- a/tests/testthat/_snaps/print-glex.md +++ b/tests/testthat/_snaps/print-glex.md @@ -3,16 +3,17 @@ Code out Output - [1] "glex object of subclass rpf_components " - [2] "Explaining predictions of 32 observations with 1 terms of up to 1 degree" - [3] "" - [4] "List of 3" - [5] " $ m :Classes 'data.table' and 'data.frame':\t32 obs. of 1 variable:" - [6] " ..$ cyl: num [1:32] -0.527 -0.527 7.084 -0.527 -5.302 ..." - [7] " ..- attr(*, \".internal.selfref\")= " - [8] " $ intercept: num 20.2" - [9] " $ x :Classes 'data.table' and 'data.frame':\t32 obs. of 1 variable:" - [10] " ..$ cyl: num [1:32] 6 6 4 6 8 6 8 4 4 6 ..." - [11] " ..- attr(*, \".internal.selfref\")= " - [12] " - attr(*, \"class\")= chr [1:3] \"glex\" \"rpf_components\" \"list\"" + [1] "glex object of subclass rpf_components " + [2] "Explaining predictions of 32 observations with 1 terms of up to 1 degree" + [3] "" + [4] "List of 3" + [5] " $ m :Classes 'data.table' and 'data.frame':\t32 obs. of 2 variables:" + [6] " ..$ intercept: num [1:32] 20.2 20.2 20.2 20.2 20.2 ..." + [7] " ..$ cyl : num [1:32] -0.527 -0.527 7.084 -0.527 -5.302 ..." + [8] " ..- attr(*, \".internal.selfref\")= " + [9] " $ intercept: num 20.2" + [10] " $ x :Classes 'data.table' and 'data.frame':\t32 obs. 
of 1 variable:" + [11] " ..$ cyl: num [1:32] 6 6 4 6 8 6 8 4 4 6 ..." + [12] " ..- attr(*, \".internal.selfref\")= " + [13] " - attr(*, \"class\")= chr [1:3] \"glex\" \"rpf_components\" \"list\"" diff --git a/tests/testthat/test-fastpd-equals-empirical.R b/tests/testthat/test-fastpd-equals-empirical.R index d708130..7adff17 100644 --- a/tests/testthat/test-fastpd-equals-empirical.R +++ b/tests/testthat/test-fastpd-equals-empirical.R @@ -15,7 +15,6 @@ test_that("FastPD equals empirical leaf weighting", { empirical_leaf_weighting <- glex(rf, x, probFunction = "empirical") expect_equal(fastpd$m, empirical_leaf_weighting$m) - expect_equal(fastpd$intercept, empirical_leaf_weighting$intercept) # Check intercept }) test_that("FastPD equals empirical leaf weighting for lower interactions", { diff --git a/tests/testthat/test-glex-xgboost.R b/tests/testthat/test-glex-xgboost.R index 1392e5e..3c7307a 100644 --- a/tests/testthat/test-glex-xgboost.R +++ b/tests/testthat/test-glex-xgboost.R @@ -34,7 +34,7 @@ test_that("features argument only calculates for given feautures", { x <- as.matrix(mtcars[, -1]) xg <- xgboost(x, mtcars$mpg, nrounds = 10, verbose = 0) glexb <- glex(xg, x, features = c("cyl", "disp")) - expect_equal(colnames(glexb$m), c("cyl", "disp", "cyl:disp")) + expect_equal(colnames(glexb$m), c("intercept", "cyl", "disp", "cyl:disp")) }) test_that("features argument results in same values as without", { @@ -50,7 +50,7 @@ test_that("features argument works together with max_interaction", { x <- as.matrix(mtcars[, -1]) xg <- xgboost(x, mtcars$mpg, nrounds = 10, verbose = 0) glexb <- glex(xg, x, features = c("cyl", "disp"), max_interaction = 1) - expect_equal(colnames(glexb$m), c("cyl", "disp")) + expect_equal(colnames(glexb$m), c("intercept", "cyl", "disp")) }) test_that("Prediction is approx. same as sum of decomposition + intercept, xgboost", { @@ -59,7 +59,7 @@ test_that("Prediction is approx. 
same as sum of decomposition + intercept, xgboo pred_train <- predict(xg, x) res_train <- glex(xg, x) - expect_equal(res_train$intercept + rowSums(res_train$m), + expect_equal(rowSums(res_train$m), pred_train, tolerance = 1e-5 ) diff --git a/tests/testthat/test-mtcars-ranger.R b/tests/testthat/test-mtcars-ranger.R index b6dc3d4..7bc7536 100644 --- a/tests/testthat/test-mtcars-ranger.R +++ b/tests/testthat/test-mtcars-ranger.R @@ -5,9 +5,11 @@ x_test <- as.matrix(mtcars[27:32, -1]) y_train <- mtcars$mpg[1:26] y_test <- mtcars$mpg[27:32] -# xgboost -rf <- ranger(x = x_train, y = y_train, node.stats = TRUE, - num.trees = 5, max.depth = 4) +# ranger +rf <- ranger( + x = x_train, y = y_train, node.stats = TRUE, + num.trees = 5, max.depth = 4 +) pred_train <- predict(rf, x_train)$predictions pred_test <- predict(rf, x_test)$predictions @@ -17,24 +19,28 @@ res_test <- glex(rf, x_test, max_interaction = 4) test_that("Prediction is approx. same as sum of shap + intercept, training data", { expect_equal(res_train$intercept + rowSums(res_train$shap), - pred_train, - tolerance = 1e-5) + pred_train, + tolerance = 1e-5 + ) }) test_that("Prediction is approx. same as sum of shap + intercept, test data", { expect_equal(res_test$intercept + rowSums(res_test$shap), - pred_test, - tolerance = 1e-5) + pred_test, + tolerance = 1e-5 + ) }) test_that("Prediction is approx. same as sum of decomposition + intercept, training data", { - expect_equal(res_train$intercept + rowSums(res_train$m), - pred_train, - tolerance = 1e-5) + expect_equal(rowSums(res_train$m), + pred_train, + tolerance = 1e-5 + ) }) test_that("Prediction is approx. 
same as sum of decomposition + intercept, test data", { - expect_equal(res_test$intercept + rowSums(res_test$m), - pred_test, - tolerance = 1e-5) + expect_equal(rowSums(res_test$m), + pred_test, + tolerance = 1e-5 + ) }) diff --git a/tests/testthat/test-mtcars.R b/tests/testthat/test-mtcars.R index 948d48d..d0c3db2 100644 --- a/tests/testthat/test-mtcars.R +++ b/tests/testthat/test-mtcars.R @@ -14,24 +14,28 @@ res_test <- glex(xg, x_test) test_that("Prediction is approx. same as sum of shap + intercept, training data", { expect_equal(res_train$intercept + rowSums(res_train$shap), - pred_train, - tolerance = 1e-5) + pred_train, + tolerance = 1e-5 + ) }) test_that("Prediction is approx. same as sum of shap + intercept, test data", { expect_equal(res_test$intercept + rowSums(res_test$shap), - pred_test, - tolerance = 1e-5) + pred_test, + tolerance = 1e-5 + ) }) test_that("Prediction is approx. same as sum of decomposition + intercept, training data", { - expect_equal(res_train$intercept + rowSums(res_train$m), - pred_train, - tolerance = 1e-5) + expect_equal(rowSums(res_train$m), + pred_train, + tolerance = 1e-5 + ) }) test_that("Prediction is approx. 
same as sum of decomposition + intercept, test data", { - expect_equal(res_test$intercept + rowSums(res_test$m), - pred_test, - tolerance = 1e-5) + expect_equal(rowSums(res_test$m), + pred_test, + tolerance = 1e-5 + ) }) diff --git a/tests/testthat/test-sim.R b/tests/testthat/test-sim.R index de110fd..a7ee4c7 100644 --- a/tests/testthat/test-sim.R +++ b/tests/testthat/test-sim.R @@ -3,9 +3,11 @@ n <- 100 p <- 2 beta <- c(1, 1) beta0 <- 0 -x <- matrix(rnorm(n = n * p), ncol = p, - dimnames = list(NULL, paste0('x', seq_len(p)))) -lp <- x %*% beta + beta0 + 2*x[, 1] * x[, 2] +x <- matrix(rnorm(n = n * p), + ncol = p, + dimnames = list(NULL, paste0("x", seq_len(p))) +) +lp <- x %*% beta + beta0 + 2 * x[, 1] * x[, 2] y <- lp + rnorm(n) x_train <- x[1:50, ] @@ -37,50 +39,58 @@ resbin_test <- glex(xgbin, x_test) test_that("regr: Prediction is approx. same as sum of shap + intercept, training data", { expect_equal(res_train$intercept + rowSums(res_train$shap), - pred_train, - tolerance = 1e-5) + pred_train, + tolerance = 1e-5 + ) }) test_that("binary: Prediction is approx. same as sum of shap + intercept, training data", { expect_equal(resbin_train$intercept + rowSums(resbin_train$shap), - predbin_train, - tolerance = 1e-5) + predbin_train, + tolerance = 1e-5 + ) }) test_that("regr: Prediction is approx. same as sum of shap + intercept, test data", { expect_equal(res_test$intercept + rowSums(res_test$shap), - pred_test, - tolerance = 1e-5) + pred_test, + tolerance = 1e-5 + ) }) test_that("binary: Prediction is approx. same as sum of shap + intercept, test data", { expect_equal(resbin_test$intercept + rowSums(resbin_test$shap), - predbin_test, - tolerance = 1e-5) + predbin_test, + tolerance = 1e-5 + ) }) test_that("regr: Prediction is approx. 
same as sum of decomposition + intercept, training data", { - expect_equal(res_train$intercept + rowSums(res_train$m), - pred_train, - tolerance = 1e-5) + expect_equal(rowSums(res_train$m), + pred_train, + tolerance = 1e-5 + ) }) test_that("classif: Prediction is approx. same as sum of decomposition + intercept, training data", { - expect_equal(resbin_train$intercept + rowSums(resbin_train$m), - predbin_train, - tolerance = 1e-5) + expect_equal(rowSums(resbin_train$m), + predbin_train, + tolerance = 1e-5 + ) }) test_that("regr: Prediction is approx. same as sum of decomposition + intercept, test data", { - expect_equal(res_test$intercept + rowSums(res_test$m), - pred_test, - tolerance = 1e-5) + expect_equal(rowSums(res_test$m), + pred_test, + tolerance = 1e-5 + ) }) test_that("binary: Prediction is approx. same as sum of decomposition + intercept, test data", { - expect_equal(resbin_test$intercept + rowSums(resbin_test$m), - predbin_test, - tolerance = 1e-5) + expect_equal(rowSums(resbin_test$m), + predbin_test, + tolerance = 1e-5 + ) }) test_that("Shap is computed correctly with overlapping colnames", { @@ -88,10 +98,12 @@ test_that("Shap is computed correctly with overlapping colnames", { p <- 2 beta <- c(1, 1) beta0 <- 0 - x <- matrix(rnorm(n = n * p), ncol = p, - dimnames = list(NULL, paste0('x', seq_len(p)))) + x <- matrix(rnorm(n = n * p), + ncol = p, + dimnames = list(NULL, paste0("x", seq_len(p))) + ) - lp <- x %*% beta + beta0 + 2*x[, 1] * x[, 2] + lp <- x %*% beta + beta0 + 2 * x[, 1] * x[, 2] y <- lp + rnorm(n) # xgboost @@ -113,5 +125,4 @@ test_that("Shap is computed correctly with overlapping colnames", { # We only check values for equality, colnames will of course differ expect_equal(unname(overlapping_names), unname(unique_names)) - }) From 3c596c7a1757a1dce843a08cab8a7f79c872d9e0 Mon Sep 17 00:00:00 2001 From: Jinyang Liu Date: Mon, 18 Nov 2024 16:06:11 +0100 Subject: [PATCH 17/17] Update README --- README.Rmd | 31 ++++++++++++++++--------------- 
README.md | 4 ++-- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/README.Rmd b/README.Rmd index ee044a9..7fd28c0 100644 --- a/README.Rmd +++ b/README.Rmd @@ -85,10 +85,11 @@ rp <- rpf(mpg ~ ., data = mtcars[1:26, ], max_interaction = 3) x <- as.matrix(mtcars[, -1]) y <- mtcars$mpg -xg <- xgboost(data = x[1:26, ], label = y[1:26], - params = list(max_depth = 3, eta = .1), - nrounds = 30, verbose = 0) - +xg <- xgboost( + data = x[1:26, ], label = y[1:26], + params = list(max_depth = 3, eta = .1), + nrounds = 30, verbose = 0 +) ``` Using the model objects and a dataset to explain (such as a test set in this case), we can create `glex` objects @@ -106,8 +107,8 @@ Both `m` and `shap` satisfy the property that their sums (per observation) toget ```{r comp-sum} # Calculating sum of components and sum of SHAP values -sum_m_rpf <- rowSums(glex_rpf$m) + glex_rpf$intercept -sum_m_xgb <- rowSums(glex_xgb$m) + glex_xgb$intercept +sum_m_rpf <- rowSums(glex_rpf$m) +sum_m_xgb <- rowSums(glex_xgb$m) sum_shap_xgb <- rowSums(glex_xgb$shap) + glex_xgb$intercept # Model predictions @@ -137,10 +138,10 @@ vi_xgb[1:5, c("degree", "term", "m")] The output additionally contains the degree of interaction, which can be used for filtering and aggregating. 
Here we filter for terms with contributions above a `threshold` of `0.05` to get a more compact plot, with terms below the threshold aggregated into one labelled "Remaining terms": ```{r glex_vi-plot, fig.width=12} -p_vi1 <- autoplot(vi_rpf, threshold = .05) + +p_vi1 <- autoplot(vi_rpf, threshold = .05) + labs(title = NULL, subtitle = "RPF") -p_vi2 <- autoplot(vi_xgb, threshold = .05) + +p_vi2 <- autoplot(vi_xgb, threshold = .05) + labs(title = NULL, subtitle = "XGBoost") p_vi1 + p_vi2 + @@ -150,14 +151,14 @@ p_vi1 + p_vi2 + We can also sum values within each degree of interaction for a more aggregated view, which can be useful as it allows us to judge interactions above a certain degree to not be particularly relevant for a given model. ```{r glex_vi-plot-by-degree, fig.height=5} -p_vi1 <- autoplot(vi_rpf, by_degree = TRUE) + +p_vi1 <- autoplot(vi_rpf, by_degree = TRUE) + labs(title = NULL, subtitle = "RPF") -p_vi2 <- autoplot(vi_xgb, by_degree = TRUE) + +p_vi2 <- autoplot(vi_xgb, by_degree = TRUE) + labs(title = NULL, subtitle = "XGBoost") p_vi1 + p_vi2 + - plot_annotation(title = "Variable importance scores by degree") + plot_annotation(title = "Variable importance scores by degree") ``` ### Feature Effects @@ -168,13 +169,13 @@ We can also plot prediction components against observed feature values, which ad p1 <- autoplot(glex_rpf, "hp") + labs(subtitle = "RPF") p2 <- autoplot(glex_xgb, "hp") + labs(subtitle = "XGBoost") -p1 + p2 + +p1 + p2 + plot_annotation(title = "Main effect for 'hp'") p1 <- autoplot(glex_rpf, c("hp", "wt")) + labs(subtitle = "RPF") p2 <- autoplot(glex_xgb, c("hp", "wt")) + labs(subtitle = "XGBoost") -p1 + p2 + +p1 + p2 + plot_annotation(title = "Two-way effects for 'hp' and 'wt'") ``` @@ -192,9 +193,9 @@ Finally, we can explore the prediction for a single observation by displaying it For compactness, we only plot one feature and collapse all interaction terms above the second degree into one as their combined effect is very small. 
```{r glex_explain} -p1 <- glex_explain(glex_rpf, id = 2, predictors = "hp", max_interaction = 2) + +p1 <- glex_explain(glex_rpf, id = 2, predictors = "hp", max_interaction = 2) + labs(tag = "RPF") -p2 <- glex_explain(glex_xgb, id = 2, predictors = "hp", max_interaction = 2) + +p2 <- glex_explain(glex_xgb, id = 2, predictors = "hp", max_interaction = 2) + labs(tag = "XGBoost") p1 + p2 & theme(plot.tag.position = "bottom") diff --git a/README.md b/README.md index 1a39b5d..688a80d 100644 --- a/README.md +++ b/README.md @@ -105,8 +105,8 @@ prediction for each observation: ``` r # Calculating sum of components and sum of SHAP values -sum_m_rpf <- rowSums(glex_rpf$m) + glex_rpf$intercept -sum_m_xgb <- rowSums(glex_xgb$m) + glex_xgb$intercept +sum_m_rpf <- rowSums(glex_rpf$m) +sum_m_xgb <- rowSums(glex_xgb$m) sum_shap_xgb <- rowSums(glex_xgb$shap) + glex_xgb$intercept # Model predictions