From a97771c3bbcf69a5915084436daaf74b68eb0700 Mon Sep 17 00:00:00 2001 From: "C. Regouby" Date: Fri, 1 May 2026 19:14:03 +0200 Subject: [PATCH 1/9] rename files for clearer organization add a epoch to test remove processing of y in front of pretraining --- .Rbuildignore | 1 + R/hardhat.R | 7 ++----- R/{explain.R => model_explain.R} | 0 R/{pretraining.R => model_pretraining.R} | 0 R/{model.R => model_training.R} | 0 R/{tab-network.R => tabnet_network.R} | 0 man/tabnet_config.Rd | 2 +- man/tabnet_explain.Rd | 2 +- man/tabnet_nn.Rd | 2 +- man/tabnet_pretrain.Rd | 2 +- .../_snaps/{pretraining.md => model_pretraining.md} | 0 tests/testthat/{test-explain.R => test-model_explain.R} | 0 .../{test-pretraining.R => test-model_pretraining.R} | 0 tests/testthat/{test-model.R => test-model_training.R} | 0 tests/testthat/test-parsnip.R | 1 + .../testthat/{test_translations.R => test-translations.R} | 0 16 files changed, 8 insertions(+), 9 deletions(-) rename R/{explain.R => model_explain.R} (100%) rename R/{pretraining.R => model_pretraining.R} (100%) rename R/{model.R => model_training.R} (100%) rename R/{tab-network.R => tabnet_network.R} (100%) rename tests/testthat/_snaps/{pretraining.md => model_pretraining.md} (100%) rename tests/testthat/{test-explain.R => test-model_explain.R} (100%) rename tests/testthat/{test-pretraining.R => test-model_pretraining.R} (100%) rename tests/testthat/{test-model.R => test-model_training.R} (100%) rename tests/testthat/{test_translations.R => test-translations.R} (100%) diff --git a/.Rbuildignore b/.Rbuildignore index 4117da74..94ba712b 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -14,3 +14,4 @@ ^CRAN-SUBMISSION$ ^revdep$ ^vignettes/*_files$ +^\.claude$ diff --git a/R/hardhat.R b/R/hardhat.R index 1d57a1b4..48cb076a 100644 --- a/R/hardhat.R +++ b/R/hardhat.R @@ -272,9 +272,7 @@ tabnet_pretrain.default <- function(x, ...) { #' @export #' @rdname tabnet_pretrain -tabnet_pretrain.data.frame <- function(x, y, tabnet_model = NULL, config = tabnet_config(), ..., from_epoch = NULL) { - processed <- hardhat::mold(x, y) - +tabnet_pretrain.data.frame <- function(x, y = NULL, tabnet_model = NULL, config = tabnet_config(), ..., from_epoch = NULL) { config <- merge_config_and_dots(config, ...) tabnet_bridge(processed, config = config, tabnet_model, from_epoch, task = "unsupervised") } @@ -309,8 +307,7 @@ tabnet_pretrain.Node <- function(x, tabnet_model = NULL, config = tabnet_config( check_compliant_node(x) # get tree leaves and extract attributes into data.frames xy_df <- node_to_df(x) - tabnet_pretrain(xy_df$x, xy_df$y, tabnet_model = tabnet_model, config = config, ..., from_epoch = from_epoch) - + tabnet_pretrain(xy_df$x, tabnet_model = tabnet_model, config = config, ..., from_epoch = from_epoch) } new_tabnet_pretrain <- function(pretrain, blueprint) { diff --git a/R/explain.R b/R/model_explain.R similarity index 100% rename from R/explain.R rename to R/model_explain.R diff --git a/R/pretraining.R b/R/model_pretraining.R similarity index 100% rename from R/pretraining.R rename to R/model_pretraining.R diff --git a/R/model.R b/R/model_training.R similarity index 100% rename from R/model.R rename to R/model_training.R diff --git a/R/tab-network.R b/R/tabnet_network.R similarity index 100% rename from R/tab-network.R rename to R/tabnet_network.R diff --git a/man/tabnet_config.Rd b/man/tabnet_config.Rd index d20ecd87..7295409d 100644 --- a/man/tabnet_config.Rd +++ b/man/tabnet_config.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/model.R +% Please edit documentation in R/model_training.R \name{tabnet_config} \alias{tabnet_config} \title{Configuration for TabNet models} diff --git a/man/tabnet_explain.Rd b/man/tabnet_explain.Rd index f750c5f5..1327c039 100644 --- a/man/tabnet_explain.Rd +++ b/man/tabnet_explain.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/explain.R +% Please edit documentation in R/model_explain.R \name{tabnet_explain} \alias{tabnet_explain} \alias{tabnet_explain.default} diff --git a/man/tabnet_nn.Rd b/man/tabnet_nn.Rd index ae5e8f76..3fa483a1 100644 --- a/man/tabnet_nn.Rd +++ b/man/tabnet_nn.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/tab-network.R +% Please edit documentation in R/tabnet_network.R \name{tabnet_nn} \alias{tabnet_nn} \title{TabNet Model Architecture} diff --git a/man/tabnet_pretrain.Rd b/man/tabnet_pretrain.Rd index 5c1e42e7..b3777ae5 100644 --- a/man/tabnet_pretrain.Rd +++ b/man/tabnet_pretrain.Rd @@ -15,7 +15,7 @@ tabnet_pretrain(x, ...) \method{tabnet_pretrain}{data.frame}( x, - y, + y = NULL, tabnet_model = NULL, config = tabnet_config(), ..., diff --git a/tests/testthat/_snaps/pretraining.md b/tests/testthat/_snaps/model_pretraining.md similarity index 100% rename from tests/testthat/_snaps/pretraining.md rename to tests/testthat/_snaps/model_pretraining.md diff --git a/tests/testthat/test-explain.R b/tests/testthat/test-model_explain.R similarity index 100% rename from tests/testthat/test-explain.R rename to tests/testthat/test-model_explain.R diff --git a/tests/testthat/test-pretraining.R b/tests/testthat/test-model_pretraining.R similarity index 100% rename from tests/testthat/test-pretraining.R rename to tests/testthat/test-model_pretraining.R diff --git a/tests/testthat/test-model.R b/tests/testthat/test-model_training.R similarity index 100% rename from tests/testthat/test-model.R rename to tests/testthat/test-model_training.R diff --git a/tests/testthat/test-parsnip.R b/tests/testthat/test-parsnip.R index 742281e8..6373e6ed 100644 --- a/tests/testthat/test-parsnip.R +++ b/tests/testthat/test-parsnip.R @@ -98,6 +98,7 @@ test_that("Check we can finalize a workflow from a tune_grid", { model <- tabnet(epochs = tune(), checkpoint_epochs = 1) %>% parsnip::set_mode("regression") %>% + parsnip::set_args(epochs = 2) %>% parsnip::set_engine("torch") wf <- workflows::workflow() %>% diff --git a/tests/testthat/test_translations.R b/tests/testthat/test-translations.R similarity index 100% rename from tests/testthat/test_translations.R rename to tests/testthat/test-translations.R From 2f5b71a77d8028f97fa09bdc9989f69f4cfbea2d Mon Sep 17 00:00:00 2001 From: "C. Regouby" Date: Sat, 2 May 2026 15:30:47 +0200 Subject: [PATCH 2/9] fix #187 --- .Rbuildignore | 1 + .gitignore | 1 + DESCRIPTION | 2 +- R/hardhat.R | 15 ++++++++++----- R/model_training.R | 20 +++++++++++++------- po/R-fr.po | 2 +- tests/testthat/test-hardhat_hierarchical.R | 10 ++++++---- 7 files changed, 33 insertions(+), 18 deletions(-) diff --git a/.Rbuildignore b/.Rbuildignore index 94ba712b..15030367 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -15,3 +15,4 @@ ^revdep$ ^vignettes/*_files$ ^\.claude$ +^\.positai$ diff --git a/.gitignore b/.gitignore index 4b3d4c83..985fdf11 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ tabnet_*.tar.gz tabnet.Rproj po/glossary.csv inst/IMPORTLIST +.positai diff --git a/DESCRIPTION b/DESCRIPTION index 95044dcc..93bc2b5a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -68,5 +68,5 @@ Config/testthat/parallel: false Config/testthat/start-first: interface, explain, params Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.3 Language: en-US +Config/roxygen2/version: 8.0.0 diff --git a/R/hardhat.R b/R/hardhat.R index 48cb076a..af85eaba 100644 --- a/R/hardhat.R +++ b/R/hardhat.R @@ -164,14 +164,17 @@ tabnet_fit.Node <- function(x, tabnet_model = NULL, config = tabnet_config(), .. processed <- hardhat::mold(xy_df$x, xy_df$y) # Given n classes, M is an (n x n) matrix where M_ij = 1 if class i is descendant of class j ancestor <- data.tree::ToDataFrameNetwork(x) %>% - mutate_if(is.character, ~.x %>% as.factor %>% as.numeric) - # TODO check correctness - # embed the M matrix in the config$ancestor variable - dims <- c(max(ancestor), max(ancestor)) - ancestor_m <- Matrix::sparseMatrix(ancestor$from, ancestor$to, dims = dims, x = 1) + mutate_if(is.character, ~.x %>% as.factor %>% as.integer) + + # embed the M matrix in the config$ancestor_tt variable + ancestor_tt <- torch::torch_sparse_coo_tensor( + matrix(c(ancestor$from, ancestor$to), nrow = 2), + rep(TRUE, length(ancestor$from))) + check_type(processed$outcomes) config <- merge_config_and_dots(config, ...) + config$ancestor <- ancestor_tt tabnet_bridge(processed, config = config, tabnet_model, from_epoch, task = "supervised") } @@ -273,6 +276,8 @@ tabnet_pretrain.default <- function(x, ...) { #' @export #' @rdname tabnet_pretrain tabnet_pretrain.data.frame <- function(x, y = NULL, tabnet_model = NULL, config = tabnet_config(), ..., from_epoch = NULL) { + processed <- hardhat::mold(x, y) + config <- merge_config_and_dots(config, ...) tabnet_bridge(processed, config = config, tabnet_model, from_epoch, task = "unsupervised") } diff --git a/R/model_training.R b/R/model_training.R index 30084761..50553ef6 100644 --- a/R/model_training.R +++ b/R/model_training.R @@ -175,7 +175,8 @@ tabnet_config <- function(batch_size = 1024^2, early_stopping_tolerance = 0, early_stopping_patience = 0L, num_workers=0L, - skip_importance = FALSE) { + skip_importance = FALSE + ) { if (is.null(decision_width) && is.null(attention_width)) { decision_width <- 8 # default is 8 } @@ -249,7 +250,7 @@ resolve_loss <- function(config, dtype) { loss_fn <- loss else if (loss %in% c("mse", "auto") && !dtype == torch::torch_long()) loss_fn <- torch::nn_mse_loss() - else if ((loss %in% c("bce", "cross_entropy", "auto") && dtype == torch::torch_long()) || !is.null(config$ancestor_tt)) + else if ((loss %in% c("bce", "cross_entropy", "auto") && dtype == torch::torch_long()) || !is.null(config$.ancestor_tt)) # cross entropy loss is required loss_fn <- torch::nn_cross_entropy_loss() else @@ -278,14 +279,14 @@ train_batch <- function(network, optimizer, batch, config) { if (max(batch$output_dim$shape) > 1) { # multi-outcome outcome_nlevels <- as.numeric(batch$output_dim$to(device="cpu")) - if (!is.null(config$ancestor_tt)) { + if (!is.null(config$.ancestor_tt)) { # hierarchical mandates use of `max_constraint_output` loss <- torch::torch_sum(torch::torch_stack(purrr::pmap( list( torch::torch_split(out, outcome_nlevels, dim = 2), torch::torch_split(batch$y, rep(1, length(outcome_nlevels)), dim = 2) ), - ~config$loss_fn(max_constraint_output(.x, .y$squeeze(2), config$ancestor_tt)) + ~config$loss_fn(max_constraint_output(.x, .y$squeeze(2), config$.ancestor_tt)) )), dim = 1) } else { @@ -332,14 +333,14 @@ valid_batch <- function(network, batch, config) { if (max(batch$output_dim$shape) > 1) { # multi-outcome outcome_nlevels <- as.numeric(batch$output_dim$to(device="cpu")) - if (!is.null(config$ancestor_tt)) { + if (!is.null(config$.ancestor_tt)) { # hierarchical mandates use of `max_constraint_output` loss <- torch::torch_sum(torch::torch_stack(purrr::pmap( list( torch::torch_split(out, outcome_nlevels, dim = 2), torch::torch_split(batch$y, rep(1, length(outcome_nlevels)), dim = 2) ), - ~config$loss_fn(max_constraint_output(.x, .y$squeeze(2), config$ancestor_tt)) + ~config$loss_fn(max_constraint_output(.x, .y$squeeze(2), config$.ancestor_tt)) )), dim = 1) } else { @@ -513,7 +514,12 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s # provide ancestor to torch tensor in case of hierarchical classification if (!is.null(config$ancestor)) { - config$ancestor_tt <- torch::torch_tensor(config$ancestor)$to(torch::torch_bool(), device = device) + if (config$ancestor$is_spase()) { + # config is expected to carry the sparse tensor + config$.ancestor_tt <- config$ancestor + } else { + config$.ancestor_tt <- NULL + } } # instantiate optimizer diff --git a/po/R-fr.po b/po/R-fr.po index 14b4b8a1..efa406ad 100644 --- a/po/R-fr.po +++ b/po/R-fr.po @@ -225,7 +225,7 @@ msgid "" " Please change those names as they will lead to unexpected " "tabnet behavior." msgstr "" -"Les attributs ou noms de colonnes dans l’objet hiérarchique fournit utilise " +"Les `attributs` (noms de colonne) dans l’objet hiérarchique fournit utilisent " "les noms réservés suivants : {.vars {actual_names[actual_names %in% " "reserved_names]}}. Veuillez changer ces noms pour éviter un comportement " "imprévisible de TabNet." diff --git a/tests/testthat/test-hardhat_hierarchical.R b/tests/testthat/test-hardhat_hierarchical.R index a6fa0983..14f54595 100644 --- a/tests/testthat/test-hardhat_hierarchical.R +++ b/tests/testthat/test-hardhat_hierarchical.R @@ -20,7 +20,7 @@ test_that("C-HMCNN get_constr_output works ", { test_that("C-HMCNN max_constraint_output works ", { output <- torch::torch_rand(c(3, 5)) labels <- torch::torch_diag(rep(1,5))[1:3, ]$to(dtype = torch::torch_bool()) - ancestor <- torch::torch_tril(torch::torch_zeros(c(5, 5))$bernoulli(p = 0.2) )$to(dtype = torch::torch_bool()) + ancestor <- torch::torch_triu(torch::torch_zeros(c(5, 5))$bernoulli(p = 0.2) )$to(dtype = torch::torch_bool()) expect_no_error( MC_output <- max_constraint_output(output, labels, ancestor) @@ -34,7 +34,7 @@ test_that("C-HMCNN max_constraint_output works ", { ) # max_constraint_output provides more than 35% null values expect_gte( - as.matrix(torch::torch_sum(MC_output == 0), device="cpu"), .30 * output$shape[1] * output$shape[2] + as.matrix(torch::torch_sum(MC_output == 0), device="cpu"), .30 * prod(output$shape) ) }) @@ -69,12 +69,13 @@ test_that("Training hierarchical classification for {data.tree} Node", { expect_no_error( fit <- tabnet_fit(acme, epochs = 1) ) + expect_named(fit$fit$config, "ancestor") expect_no_error( result <- predict(fit, acme_df, type = "prob") ) expect_equal(ncol(result), 3) - outcome_levels <-levels(fit$blueprint$ptypes$outcomes[[1]]) + outcome_levels <- levels(fit$blueprint$ptypes$outcomes[[1]]) # we get back outcomes vars with a `.pred_` prefix expect_equal(stringr::str_remove(names(result), ".pred_"), outcome_levels) expect_no_error( @@ -106,7 +107,8 @@ test_that("Training hierarchical classification for {data.tree} Node with valida expect_no_error( fit <- tabnet_fit(attrition_tree, valid_split = 0.2, epochs = 1) ) - + expect_named(fit$fit$config, "ancestor") + expect_no_error( result <- predict(fit, attrition_tree, type = "prob") ) From 81b1d5a7695ef3340a9c02b61df8098cdeaad62d Mon Sep 17 00:00:00 2001 From: "C. Regouby" Date: Sat, 2 May 2026 16:09:26 +0200 Subject: [PATCH 3/9] lint with jarl --- R/hardhat.R | 2 +- R/model_training.R | 2 +- R/parsnip.R | 2 +- R/plot.R | 2 +- tests/testthat/setup.R | 2 +- tests/testthat/test-hardhat_hierarchical.R | 9 ++++++--- tests/testthat/test-hardhat_multi-outcome.R | 8 ++++---- tests/testthat/test-model_explain.R | 4 ++-- tests/testthat/test-parsnip.R | 10 +++++----- vignettes/Hierarchical_classification.Rmd | 22 ++++++++++++++++----- 10 files changed, 39 insertions(+), 24 deletions(-) diff --git a/R/hardhat.R b/R/hardhat.R index af85eaba..77cafd7c 100644 --- a/R/hardhat.R +++ b/R/hardhat.R @@ -439,7 +439,7 @@ predict_tabnet_bridge <- function(type, object, predictors, epoch, batch_size) { is_multi_outcome <- ncol(object$blueprint$ptypes$outcomes) > 1 outcome_nlevels <- NULL if (is_multi_outcome & type != "numeric") { - outcome_nlevels <- purrr::map_dbl(object$blueprint$ptypes$outcomes, ~length(levels(.x))) + outcome_nlevels <- purrr::map_dbl(object$blueprint$ptypes$outcomes, ~nlevels(.x)) } if (!is.null(epoch)) { diff --git a/R/model_training.R b/R/model_training.R index 50553ef6..11bdd0d0 100644 --- a/R/model_training.R +++ b/R/model_training.R @@ -696,7 +696,7 @@ predict_impl_numeric <- function(obj, x, batch_size) { predict_impl_numeric_multiple <- function(obj, x, batch_size) { p <- as.matrix(predict_impl(obj, x, batch_size)) # TODO use a cleaner function to turn matrix into vectors - hardhat::spruce_numeric_multiple(!!!purrr::map(1:ncol(p), ~p[,.x])) + hardhat::spruce_numeric_multiple(!!!purrr::map(seq_len(ncol(p)), ~p[,.x])) } #' single-outcome level blueprint diff --git a/R/parsnip.R b/R/parsnip.R index 00627a59..33d16fd6 100644 --- a/R/parsnip.R +++ b/R/parsnip.R @@ -539,7 +539,7 @@ multi_predict._tabnet_fit <- function(object, new_data, type = NULL, epochs = NU pred <- predict(object$fit, new_data, type = type, epoch = epoch) nms <- names(pred) pred[["epochs"]] <- epoch - pred[[".row"]] <- 1:nrow(new_data) + pred[[".row"]] <- seq_len(nrow(new_data)) pred[, c(".row", "epochs", nms)] }) diff --git a/R/plot.R b/R/plot.R index f84ea638..4d3985d5 100644 --- a/R/plot.R +++ b/R/plot.R @@ -41,7 +41,7 @@ autoplot.tabnet_fit <- function(object, ...) { if ("checkpoint" %in% names(collect_metrics)) { checkpoints <- collect_metrics %>% - dplyr::filter(checkpoint == TRUE, dataset == "train") %>% + dplyr::filter(checkpoint, dataset == "train") %>% dplyr::select(-checkpoint) %>% dplyr::mutate(size = 2) p + diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index 4b78d736..589ba492 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -13,7 +13,7 @@ y <- ames[ids,]$Sale_Price # ames common models ames_pretrain <- tabnet_pretrain(x, y, epoch = 2, checkpoint_epochs = 1) -ames_pretrain_vsplit <- tabnet_pretrain(x, y, epochs = 3, valid_split=.2, +ames_pretrain_vsplit <- tabnet_pretrain(x, y, epochs = 3, valid_split=0.2, num_steps = 1, attention_width = 1, num_shared = 1, num_independent = 1) ames_fit <- tabnet_fit(x, y, epochs = 5 , checkpoint_epochs = 2) ames_fit_vsplit <- tabnet_fit(x, y, tabnet_model=ames_pretrain_vsplit, epochs = 3, diff --git a/tests/testthat/test-hardhat_hierarchical.R b/tests/testthat/test-hardhat_hierarchical.R index 14f54595..c8acdf06 100644 --- a/tests/testthat/test-hardhat_hierarchical.R +++ b/tests/testthat/test-hardhat_hierarchical.R @@ -34,7 +34,7 @@ test_that("C-HMCNN max_constraint_output works ", { ) # max_constraint_output provides more than 35% null values expect_gte( - as.matrix(torch::torch_sum(MC_output == 0), device="cpu"), .30 * prod(output$shape) + as.matrix(torch::torch_sum(MC_output == 0), device="cpu"), 0.30 * prod(output$shape) ) }) @@ -70,6 +70,8 @@ test_that("Training hierarchical classification for {data.tree} Node", { fit <- tabnet_fit(acme, epochs = 1) ) expect_named(fit$fit$config, "ancestor") + expect_true(fit$fit$config$ancestor$is_sparse()) + expect_no_error( result <- predict(fit, acme_df, type = "prob") ) @@ -92,7 +94,7 @@ test_that("Training hierarchical classification for {data.tree} Node", { expect_equal(ncol(result), 2) # 2 outcomes levels_ - outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~length(levels(.x))) + outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~nlevels(.x)) # we get back outcomes vars with a `.pred_` prefix expect_equal(stringr::str_remove(names(result), ".pred_"), names(outcome_nlevels)) @@ -108,6 +110,7 @@ test_that("Training hierarchical classification for {data.tree} Node with valida fit <- tabnet_fit(attrition_tree, valid_split = 0.2, epochs = 1) ) expect_named(fit$fit$config, "ancestor") + expect_true(fit$fit$config$ancestor$is_sparse()) expect_no_error( result <- predict(fit, attrition_tree, type = "prob") @@ -115,7 +118,7 @@ test_that("Training hierarchical classification for {data.tree} Node with valida expect_equal(ncol(result), 2) # 2 outcomes levels_ - outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~length(levels(.x))) + outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~nlevels(.x)) # we get back outcomes vars with a `.pred_` prefix expect_equal(stringr::str_remove(names(result), ".pred_"), names(outcome_nlevels)) diff --git a/tests/testthat/test-hardhat_multi-outcome.R b/tests/testthat/test-hardhat_multi-outcome.R index fa3eafd2..d1a79d43 100644 --- a/tests/testthat/test-hardhat_multi-outcome.R +++ b/tests/testthat/test-hardhat_multi-outcome.R @@ -54,7 +54,7 @@ test_that("Training multilabel classification from data.frame", { ) expect_equal(ncol(result), 3) - outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~length(levels(.x))) + outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~nlevels(.x)) # we get back outcomes vars with a `.pred_` prefix expect_equal(stringr::str_remove(names(result), ".pred_"), names(outcome_nlevels)) @@ -82,7 +82,7 @@ test_that("Training multilabel classification from formula", { ) expect_equal(ncol(result), 2) - outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~length(levels(.x))) + outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~nlevels(.x)) # we get back outcomes vars with a `.pred_` prefix expect_equal(stringr::str_remove(names(result), ".pred_"), names(outcome_nlevels)) @@ -108,7 +108,7 @@ test_that("Training multilabel classification from recipe", { ) expect_equal(ncol(result), 2) - outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~length(levels(.x))) + outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~nlevels(.x)) expect_equal(stringr::str_remove(names(result), ".pred_class_"), names(outcome_nlevels)) }) @@ -126,7 +126,7 @@ test_that("Training multilabel classification from data.frame with validation sp expect_equal(ncol(result), 3) - outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~length(levels(.x))) + outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~nlevels(.x)) # we get back outcomes vars with a `.pred_` prefix expect_equal(stringr::str_remove(names(result), ".pred_"), names(outcome_nlevels)) diff --git a/tests/testthat/test-model_explain.R b/tests/testthat/test-model_explain.R index 549afed4..db0ace3b 100644 --- a/tests/testthat/test-model_explain.R +++ b/tests/testthat/test-model_explain.R @@ -52,7 +52,7 @@ test_that("explain works for dataframe, formula and recipe", { # formula - tabnet_pretrain <- tabnet_pretrain(Sale_Price ~., data=small_ames, epochs = 3, valid_split=.2, + tabnet_pretrain <- tabnet_pretrain(Sale_Price ~., data=small_ames, epochs = 3, valid_split=0.2, num_steps = 1, attention_width = 1, num_shared = 1, num_independent = 1) expect_no_error( tabnet_explain(tabnet_pretrain, new_data=small_ames) @@ -69,7 +69,7 @@ test_that("explain works for dataframe, formula and recipe", { step_zv(all_predictors()) %>% step_normalize(all_numeric_predictors()) - tabnet_pretrain <- tabnet_pretrain(rec, data=small_ames, epochs = 3, valid_split=.2, + tabnet_pretrain <- tabnet_pretrain(rec, data=small_ames, epochs = 3, valid_split=0.2, num_steps = 1, attention_width = 1, num_shared = 1, num_independent = 1) expect_no_error( tabnet_explain(tabnet_pretrain, new_data=small_ames) diff --git a/tests/testthat/test-parsnip.R b/tests/testthat/test-parsnip.R index 6373e6ed..f322bcbe 100644 --- a/tests/testthat/test-parsnip.R +++ b/tests/testthat/test-parsnip.R @@ -135,7 +135,7 @@ test_that("tabnet grid reduction - torch", { expect_equal(reg_grid_smol$epochs, rep(3, 2)) expect_equal(reg_grid_smol$penalty, 1:2) - for (i in 1:nrow(reg_grid_smol)) { + for (i in seq_len(nrow(reg_grid_smol))) { expect_equal(reg_grid_smol$.submodels[[i]], list(epochs = 1:2)) } @@ -156,7 +156,7 @@ test_that("tabnet grid reduction - torch", { expect_equal(reg_grid_extra_smol$epochs, rep(3, 6)) expect_equal(reg_grid_extra_smol$penalty, rep(1:2, each = 3)) expect_equal(reg_grid_extra_smol$batch_size, rep(10:12, 2)) - for (i in 1:nrow(reg_grid_extra_smol)) { + for (i in seq_len(nrow(reg_grid_extra_smol))) { expect_equal(reg_grid_extra_smol$.submodels[[i]], list(epochs = 1:2)) } @@ -173,7 +173,7 @@ test_that("tabnet grid reduction - torch", { expect_equal(no_sub_smol$epochs, rep(1, 2)) expect_equal(no_sub_smol$penalty, 1:2) - for (i in 1:nrow(no_sub_smol)) { + for (i in seq_len(nrow(no_sub_smol))) { expect_length(no_sub_smol$.submodels[[i]], 0) } @@ -185,7 +185,7 @@ test_that("tabnet grid reduction - torch", { expect_equal(reg_grid_smol$Amos, rep(3, 2)) expect_equal(reg_grid_smol$penalty, 1:2) - for (i in 1:nrow(reg_grid_smol)) { + for (i in seq_len(nrow(reg_grid_smol))) { expect_equal(reg_grid_smol$.submodels[[i]], list(Amos = 1:2)) } @@ -203,7 +203,7 @@ test_that("tabnet grid reduction - torch", { expect_equal(reg_grid_smol$`Ade Tukunbo`, rep(3, 4)) expect_equal(reg_grid_smol$penalty, rep(1:2, each = 2)) expect_equal(reg_grid_smol$` \t123`, rep(10:11, 2)) - for (i in 1:nrow(reg_grid_smol)) { + for (i in seq_len(nrow(reg_grid_smol))) { expect_equal(reg_grid_smol$.submodels[[i]], list(`Ade Tukunbo` = 1:2)) } }) diff --git a/vignettes/Hierarchical_classification.Rmd b/vignettes/Hierarchical_classification.Rmd index ab7bee93..e4f804d9 100644 --- a/vignettes/Hierarchical_classification.Rmd +++ b/vignettes/Hierarchical_classification.Rmd @@ -25,20 +25,32 @@ library(tibble) set.seed(202307) ``` -## Data preparation +## Data format The supported data format for hierarchical classification is the `Node` object format from package `{data.tree}`. This is a general purpose format that fits generic hierarchical tree encoding needs. Each node of the tree is associated with predictor values through the `attributes` in the data `Node` object. - - A very basic example is the `acme` dataset to show you how the two predictors values `cost` and `p` are associates attributes of each node in the hierarchy : +|{tabnet} concept| {data.tree} concept |see command| +|---|---|---| +|dataset predictor| Node `attributesAll` | acme example | +|dataset multi-label target| Node hierarchy | print(acme) | + + +A very basic example is the `acme` dataset to show you how the two predictors values `cost` and `p` are associates attributes of each node in the hierarchy : ```{r} data(acme, package = "data.tree") acme$attributesAll print(acme, "cost", "p" , limit = 8) ``` +So printing Node objects reverse the usual ordering, as target is printed first in column `levelName`, and predictors printed right of it. + +As you can see, only leaf nodes of the tree gets predictors value. {tabnet} will take this into account via an `ancestor` square sparse tensor registering all possible parent-child relation among the target labels. + + +## Data preparation -- Multiple manual or programmatic methods are available to create or update predictors. They are detailled in the `vignette("data.tree", package = "data.tree")`. +Multiple manual or programmatic methods are available to create or update predictors. They are detailled in the `vignette("data.tree", package = "data.tree")`. - a lot of native hierarchical data-format conversion from files to `Node` are covered by the`{data.tree}` package. You can find them in the "Create tree from a file" section of the same vignette. If needed, the `{ape}` package covers a lot of conversion format to the `philo` format. Thus you can reach the `Node` format in maybe two transformation steps... @@ -74,13 +86,13 @@ As `as.Node()` will only consider the as.numeric() values of a factor(), you sho Your dataset hierarchy will be turn internally into multi-outcomes named `level_1` to `level_n`, n beeing the depth of your tree. Thus column names starting with `level_` should be avoided. -### Ensure the last hierarchy of the tree is the observation id +### Ensure the last hierarchy of the tree is the **observation id** The tree only keeps a single row of attributes per tree leaf. Thus in order to transfer your complete predictors dataset into the Node object, you must keep the last level of the hierarchy to be a unique observation identifier (last resort beeing `rowid_to_column()` to achieve it). The classification will be done **removing the last level of hierarchy** in any case. -### Ensure there is a root level in the hierarchy +### Ensure there is a **root level** in the hierarchy The tree should have a single root for all nodes to be consistent. Thus you have to use a constant prefix to all `pathString`. From 57295f7ee78e8d88eb2877df9f90961f58c5bdd0 Mon Sep 17 00:00:00 2001 From: "C. Regouby" Date: Sat, 2 May 2026 16:13:00 +0200 Subject: [PATCH 4/9] lint && fix --- R/hardhat.R | 2 +- R/model_pretraining.R | 4 ++-- R/model_training.R | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/R/hardhat.R b/R/hardhat.R index 77cafd7c..fc056347 100644 --- a/R/hardhat.R +++ b/R/hardhat.R @@ -438,7 +438,7 @@ predict_tabnet_bridge <- function(type, object, predictors, epoch, batch_size) { type <- check_type(object$blueprint$ptypes$outcomes, type) is_multi_outcome <- ncol(object$blueprint$ptypes$outcomes) > 1 outcome_nlevels <- NULL - if (is_multi_outcome & type != "numeric") { + if (is_multi_outcome && type != "numeric") { outcome_nlevels <- purrr::map_dbl(object$blueprint$ptypes$outcomes, ~nlevels(.x)) } diff --git a/R/model_pretraining.R b/R/model_pretraining.R index 2ec315d6..0f48482e 100644 --- a/R/model_pretraining.R +++ b/R/model_pretraining.R @@ -178,9 +178,9 @@ tabnet_train_unsupervised <- function(x, config = tabnet_config(), epoch_shift = metrics[[epoch]][["valid"]] <- transpose_metrics(valid_metrics)$loss } - if (config$verbose & !has_valid) + if (config$verbose && !has_valid) message(gettextf("[Epoch %03d] Loss: %3f", epoch, mean(metrics[[epoch]]$train))) - if (config$verbose & has_valid) + if (config$verbose && has_valid) message(gettextf("[Epoch %03d] Loss: %3f, Valid loss: %3f", epoch, mean(metrics[[epoch]]$train), mean(metrics[[epoch]]$valid))) # Early-stopping checks diff --git a/R/model_training.R b/R/model_training.R index 11bdd0d0..23547b9d 100644 --- a/R/model_training.R +++ b/R/model_training.R @@ -585,9 +585,9 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s metrics[[epoch]][["valid"]] <- transpose_metrics(valid_metrics)$loss } - if (config$verbose & !has_valid) + if (config$verbose && !has_valid) message(gettextf("[Epoch %03d] Loss: %3f", epoch, mean(metrics[[epoch]]$train))) - if (config$verbose & has_valid) + if (config$verbose && has_valid) message(gettextf("[Epoch %03d] Loss: %3f, Valid loss: %3f", epoch, mean(metrics[[epoch]]$train), mean(metrics[[epoch]]$valid))) From b1600245ce5d3e7923ae47f59f5332490c3e0c6f Mon Sep 17 00:00:00 2001 From: "C. Regouby" Date: Sat, 2 May 2026 20:11:02 +0200 Subject: [PATCH 5/9] do not duplicate ancestor transport --- R/hardhat.R | 8 ++++---- R/model_training.R | 16 +++++++--------- tests/testthat/test-hardhat_hierarchical.R | 10 +++++----- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/R/hardhat.R b/R/hardhat.R index fc056347..51a37888 100644 --- a/R/hardhat.R +++ b/R/hardhat.R @@ -166,15 +166,15 @@ tabnet_fit.Node <- function(x, tabnet_model = NULL, config = tabnet_config(), .. ancestor <- data.tree::ToDataFrameNetwork(x) %>% mutate_if(is.character, ~.x %>% as.factor %>% as.integer) - # embed the M matrix in the config$ancestor_tt variable - ancestor_tt <- torch::torch_sparse_coo_tensor( + # embed the M matrix in the config$ancestor variable + ancestor <- torch::torch_sparse_coo_tensor( matrix(c(ancestor$from, ancestor$to), nrow = 2), rep(TRUE, length(ancestor$from))) check_type(processed$outcomes) config <- merge_config_and_dots(config, ...) - config$ancestor <- ancestor_tt + config$ancestor <- ancestor tabnet_bridge(processed, config = config, tabnet_model, from_epoch, task = "supervised") } @@ -606,4 +606,4 @@ nn_prune_head.tabnet_pretrain <- function(x, head_size) { nn_prune_head(x$fit$network, head_size=head_size) } -} \ No newline at end of file +} diff --git a/R/model_training.R b/R/model_training.R index 23547b9d..977ac550 100644 --- a/R/model_training.R +++ b/R/model_training.R @@ -250,7 +250,7 @@ resolve_loss <- function(config, dtype) { loss_fn <- loss else if (loss %in% c("mse", "auto") && !dtype == torch::torch_long()) loss_fn <- torch::nn_mse_loss() - else if ((loss %in% c("bce", "cross_entropy", "auto") && dtype == torch::torch_long()) || !is.null(config$.ancestor_tt)) + else if ((loss %in% c("bce", "cross_entropy", "auto") && dtype == torch::torch_long()) || !is.null(config$ancestor)) # cross entropy loss is required loss_fn <- torch::nn_cross_entropy_loss() else @@ -279,14 +279,14 @@ train_batch <- function(network, optimizer, batch, config) { if (max(batch$output_dim$shape) > 1) { # multi-outcome outcome_nlevels <- as.numeric(batch$output_dim$to(device="cpu")) - if (!is.null(config$.ancestor_tt)) { + if (!is.null(config$ancestor)) { # hierarchical mandates use of `max_constraint_output` loss <- torch::torch_sum(torch::torch_stack(purrr::pmap( list( torch::torch_split(out, outcome_nlevels, dim = 2), torch::torch_split(batch$y, rep(1, length(outcome_nlevels)), dim = 2) ), - ~config$loss_fn(max_constraint_output(.x, .y$squeeze(2), config$.ancestor_tt)) + ~config$loss_fn(max_constraint_output(.x, .y$squeeze(2), config$ancestor)) )), dim = 1) } else { @@ -333,14 +333,14 @@ valid_batch <- function(network, batch, config) { if (max(batch$output_dim$shape) > 1) { # multi-outcome outcome_nlevels <- as.numeric(batch$output_dim$to(device="cpu")) - if (!is.null(config$.ancestor_tt)) { + if (!is.null(config$ancestor)) { # hierarchical mandates use of `max_constraint_output` loss <- torch::torch_sum(torch::torch_stack(purrr::pmap( list( torch::torch_split(out, outcome_nlevels, dim = 2), torch::torch_split(batch$y, rep(1, length(outcome_nlevels)), dim = 2) ), - ~config$loss_fn(max_constraint_output(.x, .y$squeeze(2), config$.ancestor_tt)) + ~config$loss_fn(max_constraint_output(.x, .y$squeeze(2), config$ancestor)) )), dim = 1) } else { @@ -514,11 +514,9 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s # provide ancestor to torch tensor in case of hierarchical classification if (!is.null(config$ancestor)) { - if (config$ancestor$is_spase()) { + if (!config$ancestor$is_sparse()) { # config is expected to carry the sparse tensor - config$.ancestor_tt <- config$ancestor - } else { - config$.ancestor_tt <- NULL + runtime_error("ancestor was configured. Expecting a sparse tensor but got {.class {class(config$ancestor)}}") } } diff --git a/tests/testthat/test-hardhat_hierarchical.R b/tests/testthat/test-hardhat_hierarchical.R index c8acdf06..9d5e4ef3 100644 --- a/tests/testthat/test-hardhat_hierarchical.R +++ b/tests/testthat/test-hardhat_hierarchical.R @@ -18,9 +18,9 @@ test_that("C-HMCNN get_constr_output works ", { }) test_that("C-HMCNN max_constraint_output works ", { - output <- torch::torch_rand(c(3, 5)) - labels <- torch::torch_diag(rep(1,5))[1:3, ]$to(dtype = torch::torch_bool()) - ancestor <- torch::torch_triu(torch::torch_zeros(c(5, 5))$bernoulli(p = 0.2) )$to(dtype = torch::torch_bool()) + output <- torch::torch_rand(c(5, 7)) + labels <- torch::torch_diag(rep(1,7))[1:5, ]$to(dtype = torch::torch_bool()) + ancestor <- torch::torch_triu(torch::torch_zeros(c(7, 7))$bernoulli(p = 0.1) )$to(dtype = torch::torch_bool()) expect_no_error( MC_output <- max_constraint_output(output, labels, ancestor) @@ -32,9 +32,9 @@ test_that("C-HMCNN max_constraint_output works ", { expect_not_equal_to_tensor( MC_output, output ) - # max_constraint_output provides more than 35% null values + # max_constraint_output provides more than 50% null values expect_gte( - as.matrix(torch::torch_sum(MC_output == 0), device="cpu"), 0.30 * prod(output$shape) + as.numeric((MC_output == 0)$sum()), 0.50 * prod(output$shape) ) }) From 9acdeda52eab189bda836c177348643c9a79341c Mon Sep 17 00:00:00 2001 From: "C. Regouby" Date: Sun, 3 May 2026 14:52:13 +0200 Subject: [PATCH 6/9] augment and fix get_constr_output add a proper `build_ancestor_matrix` [ FAIL 12 | WARN 0 | SKIP 0 | PASS 126 ] --- NAMESPACE | 3 + R/hardhat.R | 120 +++++++- R/model_training.R | 10 +- tests/testthat/helper-tensor.R | 9 +- tests/testthat/setup.R | 2 + tests/testthat/test-hardhat_hierarchical.R | 341 +++++++++++++++++++-- 6 files changed, 439 insertions(+), 46 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 8f99e229..7cc1064a 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -72,5 +72,8 @@ importFrom(stats,predict) importFrom(stats,update) importFrom(tidyr,replace_na) importFrom(torch,nn_prune_head) +importFrom(torch,torch_int64) +importFrom(torch,torch_ones) +importFrom(torch,torch_sparse_coo_tensor) importFrom(tune,min_grid) importFrom(zeallot,"%<-%") diff --git a/R/hardhat.R b/R/hardhat.R index 51a37888..d1111ef1 100644 --- a/R/hardhat.R +++ b/R/hardhat.R @@ -162,19 +162,17 @@ tabnet_fit.Node <- function(x, tabnet_model = NULL, config = tabnet_config(), .. # get tree leaves and extract attributes into data.frames xy_df <- node_to_df(x) processed <- hardhat::mold(xy_df$x, xy_df$y) + check_type(processed$outcomes) + # Given n classes, M is an (n x n) matrix where M_ij = 1 if class i is descendant of class j - ancestor <- data.tree::ToDataFrameNetwork(x) %>% + edges <- data.tree::ToDataFrameNetwork(x) %>% mutate_if(is.character, ~.x %>% as.factor %>% as.integer) # embed the M matrix in the config$ancestor variable - ancestor <- torch::torch_sparse_coo_tensor( - matrix(c(ancestor$from, ancestor$to), nrow = 2), - rep(TRUE, length(ancestor$from))) + ancestor_tt <- build_ancestor_matrix(edges) - check_type(processed$outcomes) - config <- merge_config_and_dots(config, ...) - config$ancestor <- ancestor + config$ancestor <- ancestor_tt tabnet_bridge(processed, config = config, tabnet_model, from_epoch, task = "supervised") } @@ -607,3 +605,111 @@ nn_prune_head.tabnet_pretrain <- function(x, head_size) { } } + +#' Build a sparse ancestor-descendant matrix from a hierarchy edge list +#' +#' Given a directed graph where edges point from descendant to ancestor, +#' computes the full transitive closure via BFS, then transposes so that +#' the resulting sparse matrix R satisfies R\[i, j\] = 1 whenever class j +#' is a descendant of class i (including i itself). This is the +#' orientation expected by \code{get_constr_output} and the +#' max-constraint-margin (MCM) loss. +#' +#' @param edges A \code{data.frame} with exactly two integer columns +#' named \code{"from"} and \code{"to"}. Each row represents a +#' directed edge from a descendant node to one of its ancestors. +#' Node IDs must be positive integers. Self-loops (e.g. \code{1 -> 1}) +#' are allowed but not required; the diagonal is always set to 1 +#' for every node in \code{1:n_classes}. +#' @param n_classes `integer(1)` or `code{NULL}`. Total number of +#' classes. When \code{NULL} (the default), it is computed as +#' \code{max(edges$from, edges$to)} so that every node appearing in +#' the edge list is represented. Supply an explicit value when there +#' are classes with no edges at all that must still appear in the +#' matrix. +#' +#' @return A \code{torch_sparse_coo_tensor} of shape +#' \code{(n_classes, n_classes)} and dtype \code{torch_double()}. +#' Entry \code{R[i, j] = 1} means class \code{j} is a descendant of +#' class \code{i}. Indices follow torch's 0-based convention. +#' +#' @details +#' The algorithm proceeds in three stages: +#' \enumerate{ +#' \item Build an adjacency list from the edge \code{data.frame} using +#' \code{split()}, grouping by the \code{from} column. Each entry +#' \code{adj[[i]]} contains the direct ancestors reachable from node +#' \code{i} in one hop. +#' \item Run a breadth-first search from every node \code{i = 1, ..., +#' n_classes}, following outgoing edges to discover the full set of +#' ancestors (the transitive closure). A logical \code{visited} +#' vector provides O(1) membership tests and prevents infinite loops +#' when cycles are present. +#' \item Transpose the collected COO index pairs so that the final +#' matrix is oriented for MCM: \code{R[i, j] = 1} means "j is a +#' descendant of i". +#' } +#' Because the diagonal is always filled, every class is its own +#' descendant, ensuring that the MCM constraint is at least as +#' permissive as the unconstrained prediction. +#' +#' @examples +#' \dontrun{ +#' edges <- data.frame( +#' from = c(1L, 1L, 2L, 2L, 3L), +#' to = c(2L, 3L, 4L, 5L, 5L) +#' ) +#' R <- build_ancestor_matrix(edges, n_classes = 5L) +#' } +#' +#' @importFrom torch torch_ones torch_int64 torch_sparse_coo_tensor +#' @noRd +build_ancestor_matrix <- function(edges, n_classes = NULL) { + + if (is.null(n_classes)) { + n_classes <- max(c(edges$from, edges$to)) + } + + # Build adjacency list efficiently: adj[[i]] = direct ancestors of i + adj <- vector("list", n_classes) + by_from <- split(edges$to, edges$from) + for (nm in names(by_from)) { + adj[[as.integer(nm)]] <- as.integer(by_from[[nm]]) + } + + # BFS from each node to find all reachable nodes via outgoing edges + # (i.e., all ancestors). Collect COO indices. + idx_list <- lapply(seq_len(n_classes), function(i) { + visited <- rep(FALSE, n_classes) + visited[i] <- TRUE + frontier <- adj[[i]] + + while (length(frontier) > 0L) { + next_frontier <- integer(0) + for (node in frontier) { + if (!visited[node]) { + visited[node] <- TRUE + next_frontier <- c(next_frontier, adj[[node]]) + } + } + frontier <- next_frontier + } + + reached <- which(visited) + cbind(rep(i, length(reached)), reached) + }) + + # Combine all pairs: before transpose, (i, j) means j is ancestor of i + idx_mat <- do.call(rbind, idx_list) + + # Transpose: swap columns so that (i, j) means j is descendant of i + idx_mat <- cbind(idx_mat[, 2L], idx_mat[, 1L]) + + # idx <- torch::torch_tensor( + # matrix(idx_mat - 1L, nrow = 2L), + # dtype = torch::torch_int64() + # ) + # vals <- torch::torch_ones(nrow(idx_mat), dtype = torch::torch_double()) + + torch::torch_sparse_coo_tensor(t(idx_mat), rep(TRUE, nrow(idx_mat)), c(n_classes, n_classes)) +} diff --git a/R/model_training.R b/R/model_training.R index 977ac550..739a0beb 100644 --- a/R/model_training.R +++ b/R/model_training.R @@ -229,8 +229,8 @@ tabnet_config <- function(batch_size = 1024^2, get_constr_output <- function(x, R) { # MCM of the prediction given the hierarchy constraint expressed in the matrix R """ - c_out <- x$unsqueeze(2)$expand(c(x$shape[1], R$shape[2], R$shape[2])) - R_batch <- R$expand(c(x$shape[1], R$shape[2], R$shape[2])) + c_out <- x$to(dtype = torch::torch_double())$unsqueeze(2)$expand(c(x$shape[1], R$shape[2], R$shape[2])) + R_batch <- R$unsqueeze(1)$expand(c(x$shape[1], R$shape[2], R$shape[2])) final_out <- torch::torch_max(R_batch * c_out, dim = 3) final_out[[1]] } @@ -238,7 +238,7 @@ get_constr_output <- function(x, R) { max_constraint_output <- function(output, labels, ancestor) { constr_output <- get_constr_output(output, ancestor) train_output <- get_constr_output(labels * output, ancestor) - labels$bitwise_not() * constr_output + labels * train_output + torch::torch_logical_not(labels) * constr_output + labels * train_output } resolve_loss <- function(config, dtype) { @@ -271,7 +271,7 @@ resolve_early_stop_monitor <- function(early_stopping_monitor, valid_split) { } train_batch <- function(network, optimizer, batch, config) { - # NULLing values to avoid a R-CMD Check Note "No visible binding for global variable" + # NULL-ing values to avoid a R-CMD Check Note "No visible binding for global variable" out <- M_loss <- NULL # forward pass c(out, M_loss) %<-% network(batch$x, batch$x_na_mask) @@ -516,7 +516,7 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s if (!is.null(config$ancestor)) { if (!config$ancestor$is_sparse()) { # config is expected to carry the sparse tensor - runtime_error("ancestor was configured. Expecting a sparse tensor but got {.class {class(config$ancestor)}}") + runtime_error("ancestor was configured. Expecting a sparse tensor but got {.cls {class(config$ancestor)}}") } } diff --git a/tests/testthat/helper-tensor.R b/tests/testthat/helper-tensor.R index 31b5c9bd..534bff14 100644 --- a/tests/testthat/helper-tensor.R +++ b/tests/testthat/helper-tensor.R @@ -38,7 +38,7 @@ expect_no_error <- function(object, ...) { expect_tensor <- function(object) { expect_true(torch:::is_torch_tensor(object)) - expect_no_error(torch::as_array(object$to(device = "cpu"))) + expect_no_error(torch::as_array(object$to_dense()$to(device = "cpu"))) } expect_equal_to_r <- function(object, expected, ...) { @@ -50,6 +50,13 @@ expect_tensor_shape <- function(object, expected) { expect_equal(object$shape, expected) } + +expect_tensor_dtype <- function(object, expected_dtype) { + expect_tensor(object) + expect_true(object$dtype == expected_dtype) +} + + expect_undefined_tensor <- function(object) { # TODO } diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index 589ba492..74c72ae9 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -38,6 +38,8 @@ attr_fitted_vsplit <- tabnet_fit(attrix, attriy, epochs = 12, valid_split=0.3) utils::data("acme", package = "data.tree") acme_df <- data.tree::ToDataFrameTypeCol(acme, acme$attributesAll) %>% select(-starts_with("level_")) +# acme2 <- acme$clone() +# acme2$RemoveAttribute("level_3") attrition_tree <- attrition %>% tibble::rowid_to_column() %>% diff --git a/tests/testthat/test-hardhat_hierarchical.R b/tests/testthat/test-hardhat_hierarchical.R index 9d5e4ef3..8ce70b4a 100644 --- a/tests/testthat/test-hardhat_hierarchical.R +++ b/tests/testthat/test-hardhat_hierarchical.R @@ -1,41 +1,316 @@ -test_that("C-HMCNN get_constr_output works ", { - x <- torch::torch_rand(c(2,4)) - R <- torch::torch_tril(torch::torch_zeros(c(4,4))$bernoulli(p = 0.2) + torch::torch_diag(rep(1,4)))$to(dtype = torch::torch_bool()) - expect_no_error( - constr_output <- get_constr_output(x, R) - ) - expect_tensor_shape( - constr_output, x$shape - ) - # expect_equal( - # constr_output$dtype, torch_tensor(0.1)$dtype - # ) +test_that("get_constr_output handles basic 2D input with identity constraint", { + x <- torch_tensor(matrix(c(1, 2, 3, 4), nrow = 2, ncol = 2), dtype = torch_float32()) + R <- torch_eye(2, dtype = torch_float32()) + result <- get_constr_output(x, R) + expect_tensor(result) + expect_tensor_shape(result, c(2, 2)) + expect_equal_to_r(result, matrix(c(1, 2, 3, 4), nrow = 2, ncol = 2)) +}) - R <- torch::torch_zeros(c(4,4))$to(dtype = torch::torch_bool()) - expect_equal_to_tensor( - get_constr_output(x, R), torch::torch_zeros_like(x) - ) +test_that("get_constr_output applies hierarchy constraint correctly", { + x <- torch_tensor(matrix(c(1, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE), dtype = torch_float64()) + R <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE), dtype = torch_float64()) + result <- get_constr_output(x, R) + expect_tensor_shape(result, c(2, 2)) + expected <- matrix(c(5, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE) + expect_equal_to_r(result, expected, tolerance = 1e-6) }) -test_that("C-HMCNN max_constraint_output works ", { - output <- torch::torch_rand(c(5, 7)) - labels <- torch::torch_diag(rep(1,7))[1:5, ]$to(dtype = torch::torch_bool()) - ancestor <- torch::torch_triu(torch::torch_zeros(c(7, 7))$bernoulli(p = 0.1) )$to(dtype = torch::torch_bool()) +test_that("get_constr_output preserves input dtype", { + x_f32 <- torch_tensor(matrix(1:4, nrow = 2), dtype = torch_float32()) + x_f64 <- torch_tensor(matrix(1:4, nrow = 2), dtype = torch_float64()) + R <- torch_eye(2) + expect_tensor_dtype(get_constr_output(x_f32, R), torch_float64()) + expect_tensor_dtype(get_constr_output(x_f64, R), torch_float64()) +}) - expect_no_error( - MC_output <- max_constraint_output(output, labels, ancestor) - ) - expect_tensor_shape( - MC_output, output$shape - ) - # max_constraint_output is not identity - expect_not_equal_to_tensor( - MC_output, output +test_that("get_constr_output handles batch dimension correctly", { + x <- torch_tensor(matrix(1:12, nrow = 3, ncol = 4)) + R <- torch_tensor(matrix(c(1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1), nrow = 4, ncol = 4, byrow = TRUE)) + result <- get_constr_output(x, R) + expect_tensor_shape(result, c(3, 4)) + + for (i in 1:3) { + row_result <- as_array(result[i, ]) + max_grp1 <- max(as_array(x[i, 1:2])) + max_grp2 <- max(as_array(x[i, 3:4])) + + expect_equal(row_result[1:2], rep(max_grp1, 2), tolerance = 1e-6) + expect_equal(row_result[3:4], rep(max_grp2, 2), tolerance = 1e-6) + } +}) +test_that("get_constr_output works with single sample", { + x <- torch_tensor(matrix(c(2, 1, 4, 3), nrow = 1, ncol = 4, byrow = TRUE)) + R <- torch_tensor(matrix(c(1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1), nrow = 4, ncol = 4, byrow = TRUE)) + result <- get_constr_output(x, R) + expect_tensor_shape(result, c(1, 4)) + expected <- matrix(c(2, 2, 4, 4), nrow = 1, byrow = TRUE) + expect_equal_to_r(result, expected) +}) + +test_that("get_constr_output handles all-zeros constraint matrix", { + x <- torch_tensor(matrix(1:6, nrow = 2, ncol = 3)) + R <- torch_zeros(c(3, 3)) + result <- get_constr_output(x, R) + expect_tensor_shape(result, c(2, 3)) + expect_equal_to_r(result, matrix(0, nrow = 2, ncol = 3)) +}) + +test_that("get_constr_output handles all-ones constraint matrix", { + x <- torch_tensor(matrix(c(1, 5, 3, 2, 4, 6), nrow = 2, ncol = 3, byrow = TRUE)) + R <- torch_ones(c(3, 3)) + result <- get_constr_output(x, R) + expect_tensor_shape(result, c(2, 3)) + # Each row is filled with its own row-wise maximum + expected <- matrix(c(5, 5, 5, 6, 6, 6), nrow = 2, ncol = 3, byrow = TRUE) + expect_equal_to_r(result, expected, tolerance = 1e-6) +}) + +test_that("get_constr_output throws error for dimension mismatch", { + x <- torch_tensor(matrix(1:4, nrow = 2, ncol = 2)) + R <- torch_eye(3) + expect_error(get_constr_output(x, R), "must match the existing size") +}) + +test_that("get_constr_output throws error for non-2D R", { + x <- torch_tensor(matrix(1:4, nrow = 2, ncol = 2)) + R <- torch_tensor(array(1:8, dim = c(2, 2, 2))) + expect_error(get_constr_output(x, R), "dimension") +}) + +test_that("max_constraint_output returns original output when ancestor is identity", { + output <- torch_tensor(matrix(1:6, nrow = 2, ncol = 3)) + labels <- torch_tensor(matrix(c(TRUE, FALSE, TRUE, FALSE, TRUE, FALSE), nrow = 2, ncol = 3), dtype = torch_bool()) + ancestor <- torch_eye(3) + result <- max_constraint_output(output, labels, ancestor) + expect_tensor_shape(result, c(2, 3)) + # With an identity ancestor matrix, constraint propagation is neutral. + # The formula simplifies to: (~labels * output) + (labels * output) == output + expect_equal_to_r(result, matrix(1:6, nrow = 2, ncol = 3)) +}) + +test_that("max_constraint_output applies constraint to positive labels", { + output <- torch_tensor(matrix(c(1, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE)) + labels <- torch_tensor(matrix(c(1, 0, 1, 0), nrow = 2, ncol = 2, byrow = TRUE), dtype = torch_bool()) + ancestor <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE)) + result <- max_constraint_output(output, labels, ancestor) + expect_tensor_shape(result, c(2, 2)) + # Unlabelled positions get propagated raw max, labelled get propagated masked max + expected <- matrix(c(1, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE) + expect_equal_to_r(result, expected) +}) + +test_that("max_constraint_output handles all-zero labels", { + output <- torch_tensor(matrix(1:4, nrow = 2, ncol = 2)) + labels <- torch_zeros(c(2, 2), dtype = torch_bool()) + ancestor <- torch_eye(2) + result <- max_constraint_output(output, labels, ancestor) + # With all false labels, result equals constr_output. With identity ancestor, constr_output == output + expect_equal_to_r(result, matrix(1:4, nrow = 2, ncol = 2)) +}) + +test_that("max_constraint_output handles all-one labels", { + output <- torch_tensor(matrix(c(1, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE)) + labels <- torch_ones(c(2, 2), dtype = torch_bool()) + ancestor <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE)) + result <- max_constraint_output(output, labels, ancestor) + expect_tensor_shape(result, c(2, 2)) + # When all labels are TRUE, (~labels) is 0, so result = train_output. + expected <- matrix(c(5, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE) + expect_equal_to_r(result, expected) +}) + +test_that("max_constraint_output preserves output dtype", { + output_f32 <- torch_tensor(matrix(1:4, nrow = 2), dtype = torch_float32()) + output_f64 <- torch_tensor(matrix(1:4, nrow = 2), dtype = torch_float64()) + labels <- torch_ones(c(2, 2), dtype = torch_bool()) + ancestor <- torch_eye(2) + expect_tensor_dtype(max_constraint_output(output_f32, labels, ancestor), torch_float64()) + expect_tensor_dtype(max_constraint_output(output_f64, labels, ancestor), torch_float64()) +}) + + +test_that("max_constraint_output works with complex hierarchy", { + output <- torch_tensor(matrix(c(1, 2, 3, 4, 5, 6), nrow = 2, ncol = 3, byrow = TRUE)) + labels <- torch_tensor(matrix(c(1, 0, 0, 0, 1, 0), nrow = 2, ncol = 3, byrow = TRUE), dtype = torch_bool()) + ancestor <- torch_tensor(matrix(c(1, 1, 1, 0, 1, 1, 0, 0, 1), nrow = 3, ncol = 3, byrow = TRUE)) + result <- max_constraint_output(output, labels, ancestor) + expect_tensor_shape(result, c(2, 3)) + # Row 1: label on col 1 -> train_output[1,1]=1, others get constr_output=3 + # Row 2: label on col 2 -> train_output[2,2]=5, others get constr_output=6 + row1_expected <- c(1, 3, 3) + row2_expected <- c(6, 5, 6) + expect_equal_to_r(result[1, ], row1_expected) + expect_equal_to_r(result[2, ], row2_expected) +}) + +test_that("max_constraint_output handles single element tensors", { + output <- torch_tensor(matrix(5, nrow = 1, ncol = 1)) + labels <- torch_tensor(matrix(TRUE, nrow = 1, ncol = 1), dtype = torch_bool()) + ancestor <- torch_tensor(matrix(1, nrow = 1, ncol = 1)) + result <- max_constraint_output(output, labels, ancestor) + expect_tensor_shape(result, c(1, 1)) + # Compare against 1x1 matrix instead of scalar to match torch array output + expect_equal_to_r(result, matrix(5, nrow = 1, ncol = 1)) +}) + +test_that("max_constraint_output throws error for dimension mismatch", { + output <- torch_tensor(matrix(1:4, nrow = 2, ncol = 2)) + labels <- torch_ones(c(2, 3), dtype = torch_bool()) + ancestor <- torch_eye(2) + expect_error(max_constraint_output(output, labels, ancestor), "dimension") +}) + +test_that("max_constraint_output handles float labels without error", { + # torch_logical_not works on float tensors (0.0 -> TRUE, others -> FALSE) + # No explicit type check exists in the function, so it should run successfully + output <- torch_tensor(matrix(1:4, nrow = 2, ncol = 2)) + labels <- torch_ones(c(2, 2), dtype = torch_float32()) + ancestor <- torch_eye(2) + expect_silent(max_constraint_output(output, labels, ancestor)) + result <- max_constraint_output(output, labels, ancestor) + expect_tensor_shape(result, c(2, 2)) +}) + +test_that("get_constr_output and max_constraint_output compose correctly", { + output <- torch_tensor(matrix(c(1, 4, 2, 3), nrow = 2, ncol = 2, byrow = TRUE)) + labels <- torch_tensor(matrix(c(TRUE, FALSE, TRUE, FALSE), nrow = 2, ncol = 2, byrow = TRUE), dtype = torch_bool()) + ancestor <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE)) + direct <- max_constraint_output(output, labels, ancestor) + constr_out <- get_constr_output(output, ancestor) + train_out <- get_constr_output(labels * output, ancestor) + manual <- torch_logical_not(labels) * constr_out + labels * train_out + expect_equal_to_r(direct, as_array(manual)) +}) + +test_that("get_constr_output handles negative values correctly", { + x <- torch_tensor(matrix(c(-5, -1, -3, -2), nrow = 2, ncol = 2, byrow = TRUE)) + R <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE)) + result <- get_constr_output(x, R) + expected <- matrix(c(-1, 0, -2, 0), nrow = 2, ncol = 2, byrow = TRUE) + expect_equal_to_r(result, expected) +}) + +test_that("max_constraint_output handles mixed positive-negative with constraints", { + output <- torch_tensor(matrix(c(-5, 3, -1, 4), nrow = 2, ncol = 2, byrow = TRUE)) + labels <- torch_tensor(matrix(c(TRUE, FALSE, FALSE, TRUE), nrow = 2, ncol = 2, byrow = TRUE), dtype = torch_bool()) + ancestor <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE)) + result <- max_constraint_output(output, labels, ancestor) + expected <- matrix(c(-1, 0, 0, 4), nrow = 2, ncol = 2, byrow = TRUE) + expect_equal_to_r(result, expected) +}) + +test_that("build-ancestor-matrix diagonal is always 1 for every class", { + edges <- data.frame(from = c(1L, 2L, 3L), to = c(2L, 3L, 1L)) + R <- build_ancestor_matrix(edges, n_classes = 3L) + R_dense <- R$to_dense() + + expect_equal_to_r(R_dense[1, 1], TRUE) + expect_equal_to_r(R_dense[2, 2], TRUE) + expect_equal_to_r(R_dense[3, 3], TRUE) +}) + +test_that("build-ancestor-matrix: single edge produces correct transitive pair", { + # 1 -> 2 means "2 is ancestor of 1", so transposed: R[2, 1] = 1 + edges <- data.frame(from = c(1L, 2L), to = c(2L, 2L)) + R <- build_ancestor_matrix(edges, n_classes = 2L) + R_dense <- R$to_dense() + + # 2 is descendant of 2 (self) + expect_equal_to_r(R_dense[2, 2], TRUE) + # 1 is descendant of 1 (self) + expect_equal_to_r(R_dense[1, 1], TRUE) + # 1 is descendant of 2 (because 1 -> 2) + expect_equal_to_r(R_dense[2, 1], TRUE) + # 2 is NOT descendant of 1 + expect_equal_to_r(R_dense[1, 2], FALSE) +}) + +test_that("build-ancestor-matrix: multi-hop ancestor chain is fully resolved", { + # Chain: 1 -> 2 -> 3 -> 4 (each is ancestor of the previous) + # After transpose: 4 is descendant of 1, 2, 3, 4 + # 3 is descendant of 1, 2, 3 + # 2 is descendant of 1, 2 + # 1 is descendant of 1 + edges <- data.frame( + from = c(1L, 2L, 3L, 1L, 2L, 3L, 4L), + to = c(2L, 3L, 4L, 1L, 2L, 3L, 4L) ) - # max_constraint_output provides more than 50% null values - expect_gte( - as.numeric((MC_output == 0)$sum()), 0.50 * prod(output$shape) + R <- build_ancestor_matrix(edges, n_classes = 4L) + R_dense <- R$to_dense() + + # Row 1: only node 1 is its own descendant + expect_equal_to_r(R_dense[1, ], c(TRUE, FALSE, FALSE, FALSE)) + # Row 2: nodes 1 and 2 are descendants of 2 + expect_equal_to_r(R_dense[2, ], c(TRUE, TRUE, FALSE, FALSE)) + # Row 3: nodes 1, 2, 3 are descendants of 3 + expect_equal_to_r(R_dense[3, ], c(TRUE, TRUE, TRUE, FALSE)) + # Row 4: all nodes are descendants of 4 + expect_equal_to_r(R_dense[4, ], c(TRUE, TRUE, TRUE, TRUE)) +}) + +test_that("build-ancestor-matrix: diamond hierarchy merges both paths", { + # Diamond: 1 -> 2 -> 4, 1 -> 3 -> 4 + # After transpose: 4 is descendant of all; 2 and 3 are descendants of + # 1 and themselves only + edges <- data.frame( + from = c(1L, 1L, 2L, 3L), + to = c(2L, 3L, 4L, 4L) ) + R <- build_ancestor_matrix(edges, n_classes = 4L) + R_dense <- R$to_dense() + + expect_equal_to_r(R_dense[4, 1], TRUE) + expect_equal_to_r(R_dense[4, 2], TRUE) + expect_equal_to_r(R_dense[4, 3], TRUE) + expect_equal_to_r(R_dense[4, 4], TRUE) + expect_equal_to_r(R_dense[2, 3], FALSE) + expect_equal_to_r(R_dense[3, 2], FALSE) +}) + +test_that("build-ancestor-matrix: isolated nodes have only a diagonal entry", { + edges <- data.frame(from = c(1L, 1L), to = c(2L, 1L)) + R <- build_ancestor_matrix(edges, n_classes = 5L) + R_dense <- R$to_dense() + + # Nodes 3, 4, 5 have no edges + expect_equal_to_r(R_dense[3, 3], TRUE) + expect_equal_to_r(R_dense[3, ], c(FALSE, FALSE, TRUE, FALSE, FALSE)) + expect_equal_to_r(R_dense[4, 4], TRUE) + expect_equal_to_r(R_dense[5, 5], TRUE) +}) + +test_that("build-ancestor-matrix: n_classes defaults to max node id when NULL", { + edges <- data.frame(from = c(1L, 1L), to = c(5L, 1L)) + R <- build_ancestor_matrix(edges) + + # n_classes should be max(1, 5, TRUE) = 5 + expect_tensor_shape(R, c(5, 5)) +}) + +test_that("build-ancestor-matrix: output has correct shape and dtype", { + edges <- data.frame(from = c(1L, 2L), to = c(2L, 1L)) + R <- build_ancestor_matrix(edges, n_classes = 3L) + + expect_tensor_shape(R, c(3L, 3L)) + expect_tensor_dtype(R, torch::torch_bool()) + expect_true(R$is_sparse()) +}) + +test_that("build-ancestor-matrix: output uses 0-based indices internally", { + # Verify that torch sees correct values when converted to dense + edges <- data.frame(from = c(1L, 2L, 3L), to = c(2L, 3L, 1L)) + R <- build_ancestor_matrix(edges, n_classes = 3L) + # 3 -> 1 -> 2 -> 3 is a cycle; after transpose every node is + # descendant of every other node + expect_equal_to_r(R$to_dense(), matrix(rep(TRUE, 9), nrow=3)) +}) + +test_that("build-ancestor-matrix: single-node graph produces identity-like matrix", { + edges <- data.frame(from = 1L, to = 1L) + R <- build_ancestor_matrix(edges, n_classes = 1L) + expect_tensor_shape(R$to_dense(), c(1L, 1L)) + expect_equal_to_r(R$to_dense()[1, 1], TRUE) }) test_that("node_to_df works ", { @@ -69,7 +344,7 @@ test_that("Training hierarchical classification for {data.tree} Node", { expect_no_error( fit <- tabnet_fit(acme, epochs = 1) ) - expect_named(fit$fit$config, "ancestor") + expect_true("ancestor" %in% fit$fit$config$names) expect_true(fit$fit$config$ancestor$is_sparse()) expect_no_error( From b6e52c0da62460a34916235c46da4a3173661979 Mon Sep 17 00:00:00 2001 From: "C. Regouby" Date: Mon, 4 May 2026 17:28:47 +0200 Subject: [PATCH 7/9] prune leafs in adjacency matrix extraction --- R/hardhat.R | 133 ++++++----------- R/model_training.R | 2 +- tests/testthat/test-hardhat_hierarchical.R | 165 +++++++++------------ 3 files changed, 116 insertions(+), 184 deletions(-) diff --git a/R/hardhat.R b/R/hardhat.R index d1111ef1..5cdaf4b9 100644 --- a/R/hardhat.R +++ b/R/hardhat.R @@ -164,12 +164,7 @@ tabnet_fit.Node <- function(x, tabnet_model = NULL, config = tabnet_config(), .. processed <- hardhat::mold(xy_df$x, xy_df$y) check_type(processed$outcomes) - # Given n classes, M is an (n x n) matrix where M_ij = 1 if class i is descendant of class j - edges <- data.tree::ToDataFrameNetwork(x) %>% - mutate_if(is.character, ~.x %>% as.factor %>% as.integer) - - # embed the M matrix in the config$ancestor variable - ancestor_tt <- build_ancestor_matrix(edges) + ancestor_tt <- build_ancestor_matrix(x) config <- merge_config_and_dots(config, ...) config$ancestor <- ancestor_tt @@ -615,101 +610,67 @@ nn_prune_head.tabnet_pretrain <- function(x, head_size) { #' orientation expected by \code{get_constr_output} and the #' max-constraint-margin (MCM) loss. #' -#' @param edges A \code{data.frame} with exactly two integer columns -#' named \code{"from"} and \code{"to"}. Each row represents a -#' directed edge from a descendant node to one of its ancestors. -#' Node IDs must be positive integers. Self-loops (e.g. \code{1 -> 1}) -#' are allowed but not required; the diagonal is always set to 1 -#' for every node in \code{1:n_classes}. -#' @param n_classes `integer(1)` or `code{NULL}`. Total number of -#' classes. When \code{NULL} (the default), it is computed as -#' \code{max(edges$from, edges$to)} so that every node appearing in -#' the edge list is represented. Supply an explicit value when there -#' are classes with no edges at all that must still appear in the -#' matrix. -#' +#' @param x a Node object. #' @return A \code{torch_sparse_coo_tensor} of shape #' \code{(n_classes, n_classes)} and dtype \code{torch_double()}. #' Entry \code{R[i, j] = 1} means class \code{j} is a descendant of #' class \code{i}. Indices follow torch's 0-based convention. #' -#' @details -#' The algorithm proceeds in three stages: -#' \enumerate{ -#' \item Build an adjacency list from the edge \code{data.frame} using -#' \code{split()}, grouping by the \code{from} column. Each entry -#' \code{adj[[i]]} contains the direct ancestors reachable from node -#' \code{i} in one hop. -#' \item Run a breadth-first search from every node \code{i = 1, ..., -#' n_classes}, following outgoing edges to discover the full set of -#' ancestors (the transitive closure). A logical \code{visited} -#' vector provides O(1) membership tests and prevents infinite loops -#' when cycles are present. -#' \item Transpose the collected COO index pairs so that the final -#' matrix is oriented for MCM: \code{R[i, j] = 1} means "j is a -#' descendant of i". -#' } -#' Because the diagonal is always filled, every class is its own -#' descendant, ensuring that the MCM constraint is at least as -#' permissive as the unconstrained prediction. -#' -#' @examples -#' \dontrun{ -#' edges <- data.frame( -#' from = c(1L, 1L, 2L, 2L, 3L), -#' to = c(2L, 3L, 4L, 5L, 5L) -#' ) -#' R <- build_ancestor_matrix(edges, n_classes = 5L) -#' } #' #' @importFrom torch torch_ones torch_int64 torch_sparse_coo_tensor #' @noRd -build_ancestor_matrix <- function(edges, n_classes = NULL) { +build_ancestor_matrix <- function(x) { + # 1. Extract edges + edges <- data.tree::ToDataFrameNetwork(x) + # 2. prune tree from root and from leafs + non_root_edges <- edges$from != x$path + non_leaf_targets <- edges$to %in% unique(edges$from) - if (is.null(n_classes)) { - n_classes <- max(c(edges$from, edges$to)) - } - - # Build adjacency list efficiently: adj[[i]] = direct ancestors of i - adj <- vector("list", n_classes) - by_from <- split(edges$to, edges$from) - for (nm in names(by_from)) { - adj[[as.integer(nm)]] <- as.integer(by_from[[nm]]) + edges <- edges[non_root_edges & non_leaf_targets, ] + + # 3. Map node names to integer indices + all_nodes <- unique(c(edges$from, edges$to)) + n <- length(all_nodes) + # Handle case where no edges match the filter + if (n == 0) { + return(matrix(nrow = 0, ncol = 2)) } - # BFS from each node to find all reachable nodes via outgoing edges - # (i.e., all ancestors). Collect COO indices. - idx_list <- lapply(seq_len(n_classes), function(i) { - visited <- rep(FALSE, n_classes) - visited[i] <- TRUE - frontier <- adj[[i]] + # Create a lookup map: name -> index + node_map <- setNames(seq_along(all_nodes), all_nodes) + + # Conversion of edges to integer indices + from_idx <- node_map[edges$from] + to_idx <- node_map[edges$to] + + # 4. Build Adjacency Matrix + # adj_mat[i, j] = 1 means i is a direct parent of j + adj_mat <- matrix(0L, nrow = n, ncol = n) + adj_mat[cbind(from_idx, to_idx)] <- 1L + + # 5. Compute Transitive Closure (Ancestors) + # Initialize reachability matrix with self-loops (Identity) + direct connections + reachability <- adj_mat + diag(n) + + # Use Boolean Matrix Multiplication to find all reachable nodes + # (i, j) = 1 if j is reachable from i (i is ancestor of j) + repeat { + # reachability %*% reachability finds paths of length 2*k + # Multiplying the matrix by itself effectively extends the reachable frontier + next_reachability <- (reachability %*% reachability) > 0 - while (length(frontier) > 0L) { - next_frontier <- integer(0) - for (node in frontier) { - if (!visited[node]) { - visited[node] <- TRUE - next_frontier <- c(next_frontier, adj[[node]]) - } - } - frontier <- next_frontier + # Check for convergence + if (identical(next_reachability, reachability)) { + break } - reached <- which(visited) - cbind(rep(i, length(reached)), reached) - }) - - # Combine all pairs: before transpose, (i, j) means j is ancestor of i - idx_mat <- do.call(rbind, idx_list) - - # Transpose: swap columns so that (i, j) means j is descendant of i - idx_mat <- cbind(idx_mat[, 2L], idx_mat[, 1L]) + # Convert back to integer/numeric for next iteration + reachability <- next_reachability * 1L + } - # idx <- torch::torch_tensor( - # matrix(idx_mat - 1L, nrow = 2L), - # dtype = torch::torch_int64() - # ) - # vals <- torch::torch_ones(nrow(idx_mat), dtype = torch::torch_double()) + # 6. Extract indices (COO format) + # which(arr.ind = TRUE) returns a matrix where col 1 is row (Ancestor) and col 2 is column (Descendant) + idx_mat <- which(reachability == 1L, arr.ind = TRUE) torch::torch_sparse_coo_tensor(t(idx_mat), rep(TRUE, nrow(idx_mat)), c(n_classes, n_classes)) } diff --git a/R/model_training.R b/R/model_training.R index 739a0beb..f5fe82dd 100644 --- a/R/model_training.R +++ b/R/model_training.R @@ -286,7 +286,7 @@ train_batch <- function(network, optimizer, batch, config) { torch::torch_split(out, outcome_nlevels, dim = 2), torch::torch_split(batch$y, rep(1, length(outcome_nlevels)), dim = 2) ), - ~config$loss_fn(max_constraint_output(.x, .y$squeeze(2), config$ancestor)) + ~config$loss_fn(max_constraint_output(.x, .y, config$ancestor), .y$squeeze(2)) )), dim = 1) } else { diff --git a/tests/testthat/test-hardhat_hierarchical.R b/tests/testthat/test-hardhat_hierarchical.R index 8ce70b4a..bddc2286 100644 --- a/tests/testthat/test-hardhat_hierarchical.R +++ b/tests/testthat/test-hardhat_hierarchical.R @@ -133,15 +133,14 @@ test_that("max_constraint_output preserves output dtype", { test_that("max_constraint_output works with complex hierarchy", { output <- torch_tensor(matrix(c(1, 2, 3, 4, 5, 6), nrow = 2, ncol = 3, byrow = TRUE)) labels <- torch_tensor(matrix(c(1, 0, 0, 0, 1, 0), nrow = 2, ncol = 3, byrow = TRUE), dtype = torch_bool()) - ancestor <- torch_tensor(matrix(c(1, 1, 1, 0, 1, 1, 0, 0, 1), nrow = 3, ncol = 3, byrow = TRUE)) + ancestor <- torch_triu(torch_ones(c(3,3))) result <- max_constraint_output(output, labels, ancestor) expect_tensor_shape(result, c(2, 3)) # Row 1: label on col 1 -> train_output[1,1]=1, others get constr_output=3 # Row 2: label on col 2 -> train_output[2,2]=5, others get constr_output=6 - row1_expected <- c(1, 3, 3) - row2_expected <- c(6, 5, 6) - expect_equal_to_r(result[1, ], row1_expected) - expect_equal_to_r(result[2, ], row2_expected) + expected <- matrix(c(1, 3, 3, + 6, 5, 6), nrow = 2, ncol = 3, byrow = TRUE) + expect_equal_to_r(result, expected) }) test_that("max_constraint_output handles single element tensors", { @@ -192,19 +191,21 @@ test_that("get_constr_output handles negative values correctly", { }) test_that("max_constraint_output handles mixed positive-negative with constraints", { - output <- torch_tensor(matrix(c(-5, 3, -1, 4), nrow = 2, ncol = 2, byrow = TRUE)) - labels <- torch_tensor(matrix(c(TRUE, FALSE, FALSE, TRUE), nrow = 2, ncol = 2, byrow = TRUE), dtype = torch_bool()) + output <- torch_tensor(matrix(c(-5, -3, -1, 4), nrow = 2, ncol = 2, byrow = TRUE)) + labels <- torch_tensor(matrix(c(TRUE, TRUE, FALSE, TRUE), nrow = 2, ncol = 2, byrow = TRUE), dtype = torch_bool()) ancestor <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE)) result <- max_constraint_output(output, labels, ancestor) - expected <- matrix(c(-1, 0, 0, 4), nrow = 2, ncol = 2, byrow = TRUE) + expected <- matrix(c(-3, 0, 4, 4), nrow = 2, ncol = 2, byrow = TRUE) expect_equal_to_r(result, expected) }) +# need rework as FromDataFrameNetwork(edges) gives "cannot find root name" error test_that("build-ancestor-matrix diagonal is always 1 for every class", { - edges <- data.frame(from = c(1L, 2L, 3L), to = c(2L, 3L, 1L)) - R <- build_ancestor_matrix(edges, n_classes = 3L) + edges <- data.frame(from = c(1L, 2L, 2L), + to = c(2L, 3L, 4L)) + R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) R_dense <- R$to_dense() - + expect_equal_to_r(R_dense[1, 1], TRUE) expect_equal_to_r(R_dense[2, 2], TRUE) expect_equal_to_r(R_dense[3, 3], TRUE) @@ -212,10 +213,10 @@ test_that("build-ancestor-matrix diagonal is always 1 for every class", { test_that("build-ancestor-matrix: single edge produces correct transitive pair", { # 1 -> 2 means "2 is ancestor of 1", so transposed: R[2, 1] = 1 - edges <- data.frame(from = c(1L, 2L), to = c(2L, 2L)) - R <- build_ancestor_matrix(edges, n_classes = 2L) + edges <- data.frame(from = c(1L, 2L), to = c(2L, 3L)) + R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) R_dense <- R$to_dense() - + # 2 is descendant of 2 (self) expect_equal_to_r(R_dense[2, 2], TRUE) # 1 is descendant of 1 (self) @@ -227,26 +228,20 @@ test_that("build-ancestor-matrix: single edge produces correct transitive pair", }) test_that("build-ancestor-matrix: multi-hop ancestor chain is fully resolved", { - # Chain: 1 -> 2 -> 3 -> 4 (each is ancestor of the previous) + # Chain: 2 -> 3 -> 4 -> 5 (each is ancestor of the previous) # After transpose: 4 is descendant of 1, 2, 3, 4 # 3 is descendant of 1, 2, 3 # 2 is descendant of 1, 2 # 1 is descendant of 1 edges <- data.frame( - from = c(1L, 2L, 3L, 1L, 2L, 3L, 4L), - to = c(2L, 3L, 4L, 1L, 2L, 3L, 4L) + from = c(1L, 2L, 3L, 4L), + to = c(2L, 3L, 4L, 5L) ) - R <- build_ancestor_matrix(edges, n_classes = 4L) + R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) R_dense <- R$to_dense() - + # Row 1: only node 1 is its own descendant - expect_equal_to_r(R_dense[1, ], c(TRUE, FALSE, FALSE, FALSE)) - # Row 2: nodes 1 and 2 are descendants of 2 - expect_equal_to_r(R_dense[2, ], c(TRUE, TRUE, FALSE, FALSE)) - # Row 3: nodes 1, 2, 3 are descendants of 3 - expect_equal_to_r(R_dense[3, ], c(TRUE, TRUE, TRUE, FALSE)) - # Row 4: all nodes are descendants of 4 - expect_equal_to_r(R_dense[4, ], c(TRUE, TRUE, TRUE, TRUE)) + expect_equal_to_r(R_dense, lower.tri(diag(4), diag = TRUE)) }) test_that("build-ancestor-matrix: diamond hierarchy merges both paths", { @@ -257,61 +252,58 @@ test_that("build-ancestor-matrix: diamond hierarchy merges both paths", { from = c(1L, 1L, 2L, 3L), to = c(2L, 3L, 4L, 4L) ) - R <- build_ancestor_matrix(edges, n_classes = 4L) - R_dense <- R$to_dense() - - expect_equal_to_r(R_dense[4, 1], TRUE) - expect_equal_to_r(R_dense[4, 2], TRUE) - expect_equal_to_r(R_dense[4, 3], TRUE) - expect_equal_to_r(R_dense[4, 4], TRUE) - expect_equal_to_r(R_dense[2, 3], FALSE) - expect_equal_to_r(R_dense[3, 2], FALSE) -}) - -test_that("build-ancestor-matrix: isolated nodes have only a diagonal entry", { - edges <- data.frame(from = c(1L, 1L), to = c(2L, 1L)) - R <- build_ancestor_matrix(edges, n_classes = 5L) + R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) R_dense <- R$to_dense() - - # Nodes 3, 4, 5 have no edges - expect_equal_to_r(R_dense[3, 3], TRUE) - expect_equal_to_r(R_dense[3, ], c(FALSE, FALSE, TRUE, FALSE, FALSE)) - expect_equal_to_r(R_dense[4, 4], TRUE) - expect_equal_to_r(R_dense[5, 5], TRUE) -}) - -test_that("build-ancestor-matrix: n_classes defaults to max node id when NULL", { - edges <- data.frame(from = c(1L, 1L), to = c(5L, 1L)) - R <- build_ancestor_matrix(edges) - - # n_classes should be max(1, 5, TRUE) = 5 - expect_tensor_shape(R, c(5, 5)) -}) - -test_that("build-ancestor-matrix: output has correct shape and dtype", { - edges <- data.frame(from = c(1L, 2L), to = c(2L, 1L)) - R <- build_ancestor_matrix(edges, n_classes = 3L) - - expect_tensor_shape(R, c(3L, 3L)) - expect_tensor_dtype(R, torch::torch_bool()) - expect_true(R$is_sparse()) -}) + expect_equal_to_r(R_dense[3, 1], TRUE) + expect_equal_to_r(R_dense[4, 2], TRUE) + expect_equal_to_r(R_dense[4, 3], FALSE) + expect_equal_to_r(R_dense[2, 4], FALSE) +}) + +# test_that("build-ancestor-matrix: isolated nodes have only a diagonal entry", { +# edges <- data.frame(from = c(1L, 1L), +# to = c(2L, 1L)) +# R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) +# R_dense <- R$to_dense() +# +# # Nodes 3, 4, 5 have no edges +# expect_equal_to_r(R_dense[3, 3], TRUE) +# expect_equal_to_r(R_dense[3, ], c(FALSE, FALSE, TRUE, FALSE, FALSE)) +# expect_equal_to_r(R_dense[4, 4], TRUE) +# expect_equal_to_r(R_dense[5, 5], TRUE) +# }) +# +# test_that("build-ancestor-matrix: n_classes defaults to max node id when NULL", { +# edges <- data.frame(from = c(1L, 1L), to = c(5L, 1L)) +# R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) +# +# # n_classes should be max(1, 5, TRUE) = 5 +# expect_tensor_shape(R, c(5, 5)) +# }) + +# test_that("build-ancestor-matrix: output has correct shape and dtype", { +# edges <- data.frame(from = c(1L, 2L), to = c(2L, 1L)) +# R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) +# +# expect_tensor_shape(R, c(3L, 3L)) +# expect_tensor_dtype(R, torch::torch_bool()) +# expect_true(R$is_sparse()) +# }) +# test_that("build-ancestor-matrix: output uses 0-based indices internally", { # Verify that torch sees correct values when converted to dense - edges <- data.frame(from = c(1L, 2L, 3L), to = c(2L, 3L, 1L)) - R <- build_ancestor_matrix(edges, n_classes = 3L) - # 3 -> 1 -> 2 -> 3 is a cycle; after transpose every node is - # descendant of every other node - expect_equal_to_r(R$to_dense(), matrix(rep(TRUE, 9), nrow=3)) + edges <- data.frame(from = c(1L, 2L, 1L), to = c(2L, 3L, 3L)) + R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) + expect_equal_to_r(R$to_dense(), matrix(c(TRUE, TRUE, FALSE, TRUE), nrow=2)) }) -test_that("build-ancestor-matrix: single-node graph produces identity-like matrix", { - edges <- data.frame(from = 1L, to = 1L) - R <- build_ancestor_matrix(edges, n_classes = 1L) - expect_tensor_shape(R$to_dense(), c(1L, 1L)) - expect_equal_to_r(R$to_dense()[1, 1], TRUE) -}) +# test_that("build-ancestor-matrix: single-node graph produces identity-like matrix", { +# edges <- data.frame(from = 1L, to = 1L) +# R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) +# expect_tensor_shape(R$to_dense(), c(1L, 1L)) +# expect_equal_to_r(R$to_dense()[1, 1], TRUE) +# }) test_that("node_to_df works ", { expect_no_error( @@ -338,28 +330,7 @@ test_that("node_to_df works ", { }) - -test_that("Training hierarchical classification for {data.tree} Node", { - - expect_no_error( - fit <- tabnet_fit(acme, epochs = 1) - ) - expect_true("ancestor" %in% fit$fit$config$names) - expect_true(fit$fit$config$ancestor$is_sparse()) - - expect_no_error( - result <- predict(fit, acme_df, type = "prob") - ) - - expect_equal(ncol(result), 3) - outcome_levels <- levels(fit$blueprint$ptypes$outcomes[[1]]) - # we get back outcomes vars with a `.pred_` prefix - expect_equal(stringr::str_remove(names(result), ".pred_"), outcome_levels) - expect_no_error( - result <- predict(fit, acme_df) - ) - expect_equal(ncol(result), 1) - +test_that("Training hierarchical classification for {data.tree} Node attrition_tree", { expect_no_error( fit <- tabnet_fit(attrition_tree, epochs = 1) ) From 3d702d8296880678ccefbd4031caac7c44ac1c65 Mon Sep 17 00:00:00 2001 From: "C. Regouby" Date: Mon, 4 May 2026 17:29:14 +0200 Subject: [PATCH 8/9] split tests bw utils and modeling --- tests/testthat/test-hardhat_hierarchical.R | 332 --------------------- tests/testthat/test-hierarchical_utils.R | 331 ++++++++++++++++++++ 2 files changed, 331 insertions(+), 332 deletions(-) create mode 100644 tests/testthat/test-hierarchical_utils.R diff --git a/tests/testthat/test-hardhat_hierarchical.R b/tests/testthat/test-hardhat_hierarchical.R index bddc2286..1b3d9ef8 100644 --- a/tests/testthat/test-hardhat_hierarchical.R +++ b/tests/testthat/test-hardhat_hierarchical.R @@ -1,335 +1,3 @@ -test_that("get_constr_output handles basic 2D input with identity constraint", { - x <- torch_tensor(matrix(c(1, 2, 3, 4), nrow = 2, ncol = 2), dtype = torch_float32()) - R <- torch_eye(2, dtype = torch_float32()) - result <- get_constr_output(x, R) - expect_tensor(result) - expect_tensor_shape(result, c(2, 2)) - expect_equal_to_r(result, matrix(c(1, 2, 3, 4), nrow = 2, ncol = 2)) -}) - -test_that("get_constr_output applies hierarchy constraint correctly", { - x <- torch_tensor(matrix(c(1, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE), dtype = torch_float64()) - R <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE), dtype = torch_float64()) - result <- get_constr_output(x, R) - expect_tensor_shape(result, c(2, 2)) - expected <- matrix(c(5, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE) - expect_equal_to_r(result, expected, tolerance = 1e-6) -}) - -test_that("get_constr_output preserves input dtype", { - x_f32 <- torch_tensor(matrix(1:4, nrow = 2), dtype = torch_float32()) - x_f64 <- torch_tensor(matrix(1:4, nrow = 2), dtype = torch_float64()) - R <- torch_eye(2) - expect_tensor_dtype(get_constr_output(x_f32, R), torch_float64()) - expect_tensor_dtype(get_constr_output(x_f64, R), torch_float64()) -}) - -test_that("get_constr_output handles batch dimension correctly", { - x <- torch_tensor(matrix(1:12, nrow = 3, ncol = 4)) - R <- torch_tensor(matrix(c(1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1), nrow = 4, ncol = 4, byrow = TRUE)) - result <- get_constr_output(x, R) - expect_tensor_shape(result, c(3, 4)) - - for (i in 1:3) { - row_result <- as_array(result[i, ]) - max_grp1 <- max(as_array(x[i, 1:2])) - max_grp2 <- max(as_array(x[i, 3:4])) - - expect_equal(row_result[1:2], rep(max_grp1, 2), tolerance = 1e-6) - expect_equal(row_result[3:4], rep(max_grp2, 2), tolerance = 1e-6) - } -}) -test_that("get_constr_output works with single sample", { - x <- torch_tensor(matrix(c(2, 1, 4, 3), nrow = 1, ncol = 4, byrow = TRUE)) - R <- torch_tensor(matrix(c(1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1), nrow = 4, ncol = 4, byrow = TRUE)) - result <- get_constr_output(x, R) - expect_tensor_shape(result, c(1, 4)) - expected <- matrix(c(2, 2, 4, 4), nrow = 1, byrow = TRUE) - expect_equal_to_r(result, expected) -}) - -test_that("get_constr_output handles all-zeros constraint matrix", { - x <- torch_tensor(matrix(1:6, nrow = 2, ncol = 3)) - R <- torch_zeros(c(3, 3)) - result <- get_constr_output(x, R) - expect_tensor_shape(result, c(2, 3)) - expect_equal_to_r(result, matrix(0, nrow = 2, ncol = 3)) -}) - -test_that("get_constr_output handles all-ones constraint matrix", { - x <- torch_tensor(matrix(c(1, 5, 3, 2, 4, 6), nrow = 2, ncol = 3, byrow = TRUE)) - R <- torch_ones(c(3, 3)) - result <- get_constr_output(x, R) - expect_tensor_shape(result, c(2, 3)) - # Each row is filled with its own row-wise maximum - expected <- matrix(c(5, 5, 5, 6, 6, 6), nrow = 2, ncol = 3, byrow = TRUE) - expect_equal_to_r(result, expected, tolerance = 1e-6) -}) - -test_that("get_constr_output throws error for dimension mismatch", { - x <- torch_tensor(matrix(1:4, nrow = 2, ncol = 2)) - R <- torch_eye(3) - expect_error(get_constr_output(x, R), "must match the existing size") -}) - -test_that("get_constr_output throws error for non-2D R", { - x <- torch_tensor(matrix(1:4, nrow = 2, ncol = 2)) - R <- torch_tensor(array(1:8, dim = c(2, 2, 2))) - expect_error(get_constr_output(x, R), "dimension") -}) - -test_that("max_constraint_output returns original output when ancestor is identity", { - output <- torch_tensor(matrix(1:6, nrow = 2, ncol = 3)) - labels <- torch_tensor(matrix(c(TRUE, FALSE, TRUE, FALSE, TRUE, FALSE), nrow = 2, ncol = 3), dtype = torch_bool()) - ancestor <- torch_eye(3) - result <- max_constraint_output(output, labels, ancestor) - expect_tensor_shape(result, c(2, 3)) - # With an identity ancestor matrix, constraint propagation is neutral. - # The formula simplifies to: (~labels * output) + (labels * output) == output - expect_equal_to_r(result, matrix(1:6, nrow = 2, ncol = 3)) -}) - -test_that("max_constraint_output applies constraint to positive labels", { - output <- torch_tensor(matrix(c(1, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE)) - labels <- torch_tensor(matrix(c(1, 0, 1, 0), nrow = 2, ncol = 2, byrow = TRUE), dtype = torch_bool()) - ancestor <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE)) - result <- max_constraint_output(output, labels, ancestor) - expect_tensor_shape(result, c(2, 2)) - # Unlabelled positions get propagated raw max, labelled get propagated masked max - expected <- matrix(c(1, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE) - expect_equal_to_r(result, expected) -}) - -test_that("max_constraint_output handles all-zero labels", { - output <- torch_tensor(matrix(1:4, nrow = 2, ncol = 2)) - labels <- torch_zeros(c(2, 2), dtype = torch_bool()) - ancestor <- torch_eye(2) - result <- max_constraint_output(output, labels, ancestor) - # With all false labels, result equals constr_output. With identity ancestor, constr_output == output - expect_equal_to_r(result, matrix(1:4, nrow = 2, ncol = 2)) -}) - -test_that("max_constraint_output handles all-one labels", { - output <- torch_tensor(matrix(c(1, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE)) - labels <- torch_ones(c(2, 2), dtype = torch_bool()) - ancestor <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE)) - result <- max_constraint_output(output, labels, ancestor) - expect_tensor_shape(result, c(2, 2)) - # When all labels are TRUE, (~labels) is 0, so result = train_output. - expected <- matrix(c(5, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE) - expect_equal_to_r(result, expected) -}) - -test_that("max_constraint_output preserves output dtype", { - output_f32 <- torch_tensor(matrix(1:4, nrow = 2), dtype = torch_float32()) - output_f64 <- torch_tensor(matrix(1:4, nrow = 2), dtype = torch_float64()) - labels <- torch_ones(c(2, 2), dtype = torch_bool()) - ancestor <- torch_eye(2) - expect_tensor_dtype(max_constraint_output(output_f32, labels, ancestor), torch_float64()) - expect_tensor_dtype(max_constraint_output(output_f64, labels, ancestor), torch_float64()) -}) - - -test_that("max_constraint_output works with complex hierarchy", { - output <- torch_tensor(matrix(c(1, 2, 3, 4, 5, 6), nrow = 2, ncol = 3, byrow = TRUE)) - labels <- torch_tensor(matrix(c(1, 0, 0, 0, 1, 0), nrow = 2, ncol = 3, byrow = TRUE), dtype = torch_bool()) - ancestor <- torch_triu(torch_ones(c(3,3))) - result <- max_constraint_output(output, labels, ancestor) - expect_tensor_shape(result, c(2, 3)) - # Row 1: label on col 1 -> train_output[1,1]=1, others get constr_output=3 - # Row 2: label on col 2 -> train_output[2,2]=5, others get constr_output=6 - expected <- matrix(c(1, 3, 3, - 6, 5, 6), nrow = 2, ncol = 3, byrow = TRUE) - expect_equal_to_r(result, expected) -}) - -test_that("max_constraint_output handles single element tensors", { - output <- torch_tensor(matrix(5, nrow = 1, ncol = 1)) - labels <- torch_tensor(matrix(TRUE, nrow = 1, ncol = 1), dtype = torch_bool()) - ancestor <- torch_tensor(matrix(1, nrow = 1, ncol = 1)) - result <- max_constraint_output(output, labels, ancestor) - expect_tensor_shape(result, c(1, 1)) - # Compare against 1x1 matrix instead of scalar to match torch array output - expect_equal_to_r(result, matrix(5, nrow = 1, ncol = 1)) -}) - -test_that("max_constraint_output throws error for dimension mismatch", { - output <- torch_tensor(matrix(1:4, nrow = 2, ncol = 2)) - labels <- torch_ones(c(2, 3), dtype = torch_bool()) - ancestor <- torch_eye(2) - expect_error(max_constraint_output(output, labels, ancestor), "dimension") -}) - -test_that("max_constraint_output handles float labels without error", { - # torch_logical_not works on float tensors (0.0 -> TRUE, others -> FALSE) - # No explicit type check exists in the function, so it should run successfully - output <- torch_tensor(matrix(1:4, nrow = 2, ncol = 2)) - labels <- torch_ones(c(2, 2), dtype = torch_float32()) - ancestor <- torch_eye(2) - expect_silent(max_constraint_output(output, labels, ancestor)) - result <- max_constraint_output(output, labels, ancestor) - expect_tensor_shape(result, c(2, 2)) -}) - -test_that("get_constr_output and max_constraint_output compose correctly", { - output <- torch_tensor(matrix(c(1, 4, 2, 3), nrow = 2, ncol = 2, byrow = TRUE)) - labels <- torch_tensor(matrix(c(TRUE, FALSE, TRUE, FALSE), nrow = 2, ncol = 2, byrow = TRUE), dtype = torch_bool()) - ancestor <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE)) - direct <- max_constraint_output(output, labels, ancestor) - constr_out <- get_constr_output(output, ancestor) - train_out <- get_constr_output(labels * output, ancestor) - manual <- torch_logical_not(labels) * constr_out + labels * train_out - expect_equal_to_r(direct, as_array(manual)) -}) - -test_that("get_constr_output handles negative values correctly", { - x <- torch_tensor(matrix(c(-5, -1, -3, -2), nrow = 2, ncol = 2, byrow = TRUE)) - R <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE)) - result <- get_constr_output(x, R) - expected <- matrix(c(-1, 0, -2, 0), nrow = 2, ncol = 2, byrow = TRUE) - expect_equal_to_r(result, expected) -}) - -test_that("max_constraint_output handles mixed positive-negative with constraints", { - output <- torch_tensor(matrix(c(-5, -3, -1, 4), nrow = 2, ncol = 2, byrow = TRUE)) - labels <- torch_tensor(matrix(c(TRUE, TRUE, FALSE, TRUE), nrow = 2, ncol = 2, byrow = TRUE), dtype = torch_bool()) - ancestor <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE)) - result <- max_constraint_output(output, labels, ancestor) - expected <- matrix(c(-3, 0, 4, 4), nrow = 2, ncol = 2, byrow = TRUE) - expect_equal_to_r(result, expected) -}) - -# need rework as FromDataFrameNetwork(edges) gives "cannot find root name" error -test_that("build-ancestor-matrix diagonal is always 1 for every class", { - edges <- data.frame(from = c(1L, 2L, 2L), - to = c(2L, 3L, 4L)) - R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) - R_dense <- R$to_dense() - - expect_equal_to_r(R_dense[1, 1], TRUE) - expect_equal_to_r(R_dense[2, 2], TRUE) - expect_equal_to_r(R_dense[3, 3], TRUE) -}) - -test_that("build-ancestor-matrix: single edge produces correct transitive pair", { - # 1 -> 2 means "2 is ancestor of 1", so transposed: R[2, 1] = 1 - edges <- data.frame(from = c(1L, 2L), to = c(2L, 3L)) - R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) - R_dense <- R$to_dense() - - # 2 is descendant of 2 (self) - expect_equal_to_r(R_dense[2, 2], TRUE) - # 1 is descendant of 1 (self) - expect_equal_to_r(R_dense[1, 1], TRUE) - # 1 is descendant of 2 (because 1 -> 2) - expect_equal_to_r(R_dense[2, 1], TRUE) - # 2 is NOT descendant of 1 - expect_equal_to_r(R_dense[1, 2], FALSE) -}) - -test_that("build-ancestor-matrix: multi-hop ancestor chain is fully resolved", { - # Chain: 2 -> 3 -> 4 -> 5 (each is ancestor of the previous) - # After transpose: 4 is descendant of 1, 2, 3, 4 - # 3 is descendant of 1, 2, 3 - # 2 is descendant of 1, 2 - # 1 is descendant of 1 - edges <- data.frame( - from = c(1L, 2L, 3L, 4L), - to = c(2L, 3L, 4L, 5L) - ) - R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) - R_dense <- R$to_dense() - - # Row 1: only node 1 is its own descendant - expect_equal_to_r(R_dense, lower.tri(diag(4), diag = TRUE)) -}) - -test_that("build-ancestor-matrix: diamond hierarchy merges both paths", { - # Diamond: 1 -> 2 -> 4, 1 -> 3 -> 4 - # After transpose: 4 is descendant of all; 2 and 3 are descendants of - # 1 and themselves only - edges <- data.frame( - from = c(1L, 1L, 2L, 3L), - to = c(2L, 3L, 4L, 4L) - ) - R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) - R_dense <- R$to_dense() - - expect_equal_to_r(R_dense[3, 1], TRUE) - expect_equal_to_r(R_dense[4, 2], TRUE) - expect_equal_to_r(R_dense[4, 3], FALSE) - expect_equal_to_r(R_dense[2, 4], FALSE) -}) - -# test_that("build-ancestor-matrix: isolated nodes have only a diagonal entry", { -# edges <- data.frame(from = c(1L, 1L), -# to = c(2L, 1L)) -# R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) -# R_dense <- R$to_dense() -# -# # Nodes 3, 4, 5 have no edges -# expect_equal_to_r(R_dense[3, 3], TRUE) -# expect_equal_to_r(R_dense[3, ], c(FALSE, FALSE, TRUE, FALSE, FALSE)) -# expect_equal_to_r(R_dense[4, 4], TRUE) -# expect_equal_to_r(R_dense[5, 5], TRUE) -# }) -# -# test_that("build-ancestor-matrix: n_classes defaults to max node id when NULL", { -# edges <- data.frame(from = c(1L, 1L), to = c(5L, 1L)) -# R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) -# -# # n_classes should be max(1, 5, TRUE) = 5 -# expect_tensor_shape(R, c(5, 5)) -# }) - -# test_that("build-ancestor-matrix: output has correct shape and dtype", { -# edges <- data.frame(from = c(1L, 2L), to = c(2L, 1L)) -# R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) -# -# expect_tensor_shape(R, c(3L, 3L)) -# expect_tensor_dtype(R, torch::torch_bool()) -# expect_true(R$is_sparse()) -# }) -# -test_that("build-ancestor-matrix: output uses 0-based indices internally", { - # Verify that torch sees correct values when converted to dense - edges <- data.frame(from = c(1L, 2L, 1L), to = c(2L, 3L, 3L)) - R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) - expect_equal_to_r(R$to_dense(), matrix(c(TRUE, TRUE, FALSE, TRUE), nrow=2)) -}) - -# test_that("build-ancestor-matrix: single-node graph produces identity-like matrix", { -# edges <- data.frame(from = 1L, to = 1L) -# R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) -# expect_tensor_shape(R$to_dense(), c(1L, 1L)) -# expect_equal_to_r(R$to_dense()[1, 1], TRUE) -# }) - -test_that("node_to_df works ", { - expect_no_error( - node_to_df(acme) - ) - expect_no_error( - attrition_df <- node_to_df(attrition_tree) - ) - # node_to_df removes first and last level of the hierarchy - outcome_levels <- paste0("level_", seq(2, attrition_tree$height - 1)) - expect_equal(names(attrition_df$y), outcome_levels) - - # node_to_df do not shuffle outcome rows - df <- tibble(pred_1 = seq(1,26), pred_2 = seq(26,1), - level_2 = factor(LETTERS[1:26]), level_3 = factor(letters[26:1])) - df_node_df <- df %>% - mutate(pathString = paste("synth", level_2, level_3, level_3, sep = "/")) %>% - select(-level_2, -level_3) %>% - as.Node() %>% - node_to_df() - - expect_equal(df_node_df$y %>% as_tibble(), df %>% select(starts_with("level_"))) - expect_equal(df_node_df$x %>% as_tibble(), df %>% select(starts_with("pred_"))) - -}) - test_that("Training hierarchical classification for {data.tree} Node attrition_tree", { expect_no_error( fit <- tabnet_fit(attrition_tree, epochs = 1) diff --git a/tests/testthat/test-hierarchical_utils.R b/tests/testthat/test-hierarchical_utils.R new file mode 100644 index 00000000..905d94b4 --- /dev/null +++ b/tests/testthat/test-hierarchical_utils.R @@ -0,0 +1,331 @@ +test_that("get_constr_output handles basic 2D input with identity constraint", { + x <- torch_tensor(matrix(c(1, 2, 3, 4), nrow = 2, ncol = 2), dtype = torch_float32()) + R <- torch_eye(2, dtype = torch_float32()) + result <- get_constr_output(x, R) + expect_tensor(result) + expect_tensor_shape(result, c(2, 2)) + expect_equal_to_r(result, matrix(c(1, 2, 3, 4), nrow = 2, ncol = 2)) +}) + +test_that("get_constr_output applies hierarchy constraint correctly", { + x <- torch_tensor(matrix(c(1, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE), dtype = torch_float64()) + R <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE), dtype = torch_float64()) + result <- get_constr_output(x, R) + expect_tensor_shape(result, c(2, 2)) + expected <- matrix(c(5, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE) + expect_equal_to_r(result, expected, tolerance = 1e-6) +}) + +test_that("get_constr_output preserves input dtype", { + x_f32 <- torch_tensor(matrix(1:4, nrow = 2), dtype = torch_float32()) + x_f64 <- torch_tensor(matrix(1:4, nrow = 2), dtype = torch_float64()) + R <- torch_eye(2) + expect_tensor_dtype(get_constr_output(x_f32, R), torch_float64()) + expect_tensor_dtype(get_constr_output(x_f64, R), torch_float64()) +}) + +test_that("get_constr_output handles batch dimension correctly", { + x <- torch_tensor(matrix(1:12, nrow = 3, ncol = 4)) + R <- torch_tensor(matrix(c(1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1), nrow = 4, ncol = 4, byrow = TRUE)) + result <- get_constr_output(x, R) + expect_tensor_shape(result, c(3, 4)) + + for (i in 1:3) { + row_result <- as_array(result[i, ]) + max_grp1 <- max(as_array(x[i, 1:2])) + max_grp2 <- max(as_array(x[i, 3:4])) + + expect_equal(row_result[1:2], rep(max_grp1, 2), tolerance = 1e-6) + expect_equal(row_result[3:4], rep(max_grp2, 2), tolerance = 1e-6) + } +}) +test_that("get_constr_output works with single sample", { + x <- torch_tensor(matrix(c(2, 1, 4, 3), nrow = 1, ncol = 4, byrow = TRUE)) + R <- torch_tensor(matrix(c(1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1), nrow = 4, ncol = 4, byrow = TRUE)) + result <- get_constr_output(x, R) + expect_tensor_shape(result, c(1, 4)) + expected <- matrix(c(2, 2, 4, 4), nrow = 1, byrow = TRUE) + expect_equal_to_r(result, expected) +}) + +test_that("get_constr_output handles all-zeros constraint matrix", { + x <- torch_tensor(matrix(1:6, nrow = 2, ncol = 3)) + R <- torch_zeros(c(3, 3)) + result <- get_constr_output(x, R) + expect_tensor_shape(result, c(2, 3)) + expect_equal_to_r(result, matrix(0, nrow = 2, ncol = 3)) +}) + +test_that("get_constr_output handles all-ones constraint matrix", { + x <- torch_tensor(matrix(c(1, 5, 3, 2, 4, 6), nrow = 2, ncol = 3, byrow = TRUE)) + R <- torch_ones(c(3, 3)) + result <- get_constr_output(x, R) + expect_tensor_shape(result, c(2, 3)) + # Each row is filled with its own row-wise maximum + expected <- matrix(c(5, 5, 5, 6, 6, 6), nrow = 2, ncol = 3, byrow = TRUE) + expect_equal_to_r(result, expected, tolerance = 1e-6) +}) + +test_that("get_constr_output throws error for dimension mismatch", { + x <- torch_tensor(matrix(1:4, nrow = 2, ncol = 2)) + R <- torch_eye(3) + expect_error(get_constr_output(x, R), "must match the existing size") +}) + +test_that("get_constr_output throws error for non-2D R", { + x <- torch_tensor(matrix(1:4, nrow = 2, ncol = 2)) + R <- torch_tensor(array(1:8, dim = c(2, 2, 2))) + expect_error(get_constr_output(x, R), "dimension") +}) + +test_that("max_constraint_output returns original output when ancestor is identity", { + output <- torch_tensor(matrix(1:6, nrow = 2, ncol = 3)) + labels <- torch_tensor(matrix(c(TRUE, FALSE, TRUE, FALSE, TRUE, FALSE), nrow = 2, ncol = 3), dtype = torch_bool()) + ancestor <- torch_eye(3) + result <- max_constraint_output(output, labels, ancestor) + expect_tensor_shape(result, c(2, 3)) + # With an identity ancestor matrix, constraint propagation is neutral. + # The formula simplifies to: (~labels * output) + (labels * output) == output + expect_equal_to_r(result, matrix(1:6, nrow = 2, ncol = 3)) +}) + +test_that("max_constraint_output applies constraint to positive labels", { + output <- torch_tensor(matrix(c(1, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE)) + labels <- torch_tensor(matrix(c(1, 0, 1, 0), nrow = 2, ncol = 2, byrow = TRUE), dtype = torch_bool()) + ancestor <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE)) + result <- max_constraint_output(output, labels, ancestor) + expect_tensor_shape(result, c(2, 2)) + # Unlabelled positions get propagated raw max, labelled get propagated masked max + expected <- matrix(c(1, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE) + expect_equal_to_r(result, expected) +}) + +test_that("max_constraint_output handles all-zero labels", { + output <- torch_tensor(matrix(1:4, nrow = 2, ncol = 2)) + labels <- torch_zeros(c(2, 2), dtype = torch_bool()) + ancestor <- torch_eye(2) + result <- max_constraint_output(output, labels, ancestor) + # With all false labels, result equals constr_output. With identity ancestor, constr_output == output + expect_equal_to_r(result, matrix(1:4, nrow = 2, ncol = 2)) +}) + +test_that("max_constraint_output handles all-one labels", { + output <- torch_tensor(matrix(c(1, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE)) + labels <- torch_ones(c(2, 2), dtype = torch_bool()) + ancestor <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE)) + result <- max_constraint_output(output, labels, ancestor) + expect_tensor_shape(result, c(2, 2)) + # When all labels are TRUE, (~labels) is 0, so result = train_output. + expected <- matrix(c(5, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE) + expect_equal_to_r(result, expected) +}) + +test_that("max_constraint_output preserves output dtype", { + output_f32 <- torch_tensor(matrix(1:4, nrow = 2), dtype = torch_float32()) + output_f64 <- torch_tensor(matrix(1:4, nrow = 2), dtype = torch_float64()) + labels <- torch_ones(c(2, 2), dtype = torch_bool()) + ancestor <- torch_eye(2) + expect_tensor_dtype(max_constraint_output(output_f32, labels, ancestor), torch_float64()) + expect_tensor_dtype(max_constraint_output(output_f64, labels, ancestor), torch_float64()) +}) + + +test_that("max_constraint_output works with complex hierarchy", { + output <- torch_tensor(matrix(c(1, 2, 3, 4, 5, 6), nrow = 2, ncol = 3, byrow = TRUE)) + labels <- torch_tensor(matrix(c(1, 0, 0, 0, 1, 0), nrow = 2, ncol = 3, byrow = TRUE), dtype = torch_bool()) + ancestor <- torch_triu(torch_ones(c(3,3))) + result <- max_constraint_output(output, labels, ancestor) + expect_tensor_shape(result, c(2, 3)) + # Row 1: label on col 1 -> train_output[1,1]=1, others get constr_output=3 + # Row 2: label on col 2 -> train_output[2,2]=5, others get constr_output=6 + expected <- matrix(c(1, 3, 3, + 6, 5, 6), nrow = 2, ncol = 3, byrow = TRUE) + expect_equal_to_r(result, expected) +}) + +test_that("max_constraint_output handles single element tensors", { + output <- torch_tensor(matrix(5, nrow = 1, ncol = 1)) + labels <- torch_tensor(matrix(TRUE, nrow = 1, ncol = 1), dtype = torch_bool()) + ancestor <- torch_tensor(matrix(1, nrow = 1, ncol = 1)) + result <- max_constraint_output(output, labels, ancestor) + expect_tensor_shape(result, c(1, 1)) + # Compare against 1x1 matrix instead of scalar to match torch array output + expect_equal_to_r(result, matrix(5, nrow = 1, ncol = 1)) +}) + +test_that("max_constraint_output throws error for dimension mismatch", { + output <- torch_tensor(matrix(1:4, nrow = 2, ncol = 2)) + labels <- torch_ones(c(2, 3), dtype = torch_bool()) + ancestor <- torch_eye(2) + expect_error(max_constraint_output(output, labels, ancestor), "dimension") +}) + +test_that("max_constraint_output handles float labels without error", { + # torch_logical_not works on float tensors (0.0 -> TRUE, others -> FALSE) + # No explicit type check exists in the function, so it should run successfully + output <- torch_tensor(matrix(1:4, nrow = 2, ncol = 2)) + labels <- torch_ones(c(2, 2), dtype = torch_float32()) + ancestor <- torch_eye(2) + expect_silent(max_constraint_output(output, labels, ancestor)) + result <- max_constraint_output(output, labels, ancestor) + expect_tensor_shape(result, c(2, 2)) +}) + +test_that("get_constr_output and max_constraint_output compose correctly", { + output <- torch_tensor(matrix(c(1, 4, 2, 3), nrow = 2, ncol = 2, byrow = TRUE)) + labels <- torch_tensor(matrix(c(TRUE, FALSE, TRUE, FALSE), nrow = 2, ncol = 2, byrow = TRUE), dtype = torch_bool()) + ancestor <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE)) + direct <- max_constraint_output(output, labels, ancestor) + constr_out <- get_constr_output(output, ancestor) + train_out <- get_constr_output(labels * output, ancestor) + manual <- torch_logical_not(labels) * constr_out + labels * train_out + expect_equal_to_r(direct, as_array(manual)) +}) + +test_that("get_constr_output handles negative values correctly", { + x <- torch_tensor(matrix(c(-5, -1, -3, -2), nrow = 2, ncol = 2, byrow = TRUE)) + R <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE)) + result <- get_constr_output(x, R) + expected <- matrix(c(-1, 0, -2, 0), nrow = 2, ncol = 2, byrow = TRUE) + expect_equal_to_r(result, expected) +}) + +test_that("max_constraint_output handles mixed positive-negative with constraints", { + output <- torch_tensor(matrix(c(-5, -3, -1, 4), nrow = 2, ncol = 2, byrow = TRUE)) + labels <- torch_tensor(matrix(c(TRUE, TRUE, FALSE, TRUE), nrow = 2, ncol = 2, byrow = TRUE), dtype = torch_bool()) + ancestor <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE)) + result <- max_constraint_output(output, labels, ancestor) + expected <- matrix(c(-3, 0, 4, 4), nrow = 2, ncol = 2, byrow = TRUE) + expect_equal_to_r(result, expected) +}) + +# need rework as FromDataFrameNetwork(edges) gives "cannot find root name" error +test_that("build-ancestor-matrix diagonal is always 1 for every class", { + edges <- data.frame(from = c(1L, 2L, 2L), + to = c(2L, 3L, 4L)) + R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) + R_dense <- R$to_dense() + + expect_equal_to_r(R_dense[1, 1], TRUE) + expect_equal_to_r(R_dense[2, 2], TRUE) + expect_equal_to_r(R_dense[3, 3], TRUE) +}) + +test_that("build-ancestor-matrix: single edge produces correct transitive pair", { + # 1 -> 2 means "2 is ancestor of 1", so transposed: R[2, 1] = 1 + edges <- data.frame(from = c(1L, 2L), to = c(2L, 3L)) + R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) + R_dense <- R$to_dense() + + # 2 is descendant of 2 (self) + expect_equal_to_r(R_dense[2, 2], TRUE) + # 1 is descendant of 1 (self) + expect_equal_to_r(R_dense[1, 1], TRUE) + # 1 is descendant of 2 (because 1 -> 2) + expect_equal_to_r(R_dense[2, 1], TRUE) + # 2 is NOT descendant of 1 + expect_equal_to_r(R_dense[1, 2], FALSE) +}) + +test_that("build-ancestor-matrix: multi-hop ancestor chain is fully resolved", { + # Chain: 2 -> 3 -> 4 -> 5 (each is ancestor of the previous) + # After transpose: 4 is descendant of 1, 2, 3, 4 + # 3 is descendant of 1, 2, 3 + # 2 is descendant of 1, 2 + # 1 is descendant of 1 + edges <- data.frame( + from = c(1L, 2L, 3L, 4L), + to = c(2L, 3L, 4L, 5L) + ) + R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) + R_dense <- R$to_dense() + + # Row 1: only node 1 is its own descendant + expect_equal_to_r(R_dense, lower.tri(diag(4), diag = TRUE)) +}) + +test_that("build-ancestor-matrix: diamond hierarchy merges both paths", { + # Diamond: 1 -> 2 -> 4, 1 -> 3 -> 4 + # After transpose: 4 is descendant of all; 2 and 3 are descendants of + # 1 and themselves only + edges <- data.frame( + from = c(1L, 1L, 2L, 3L), + to = c(2L, 3L, 4L, 4L) + ) + R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) + R_dense <- R$to_dense() + + expect_equal_to_r(R_dense[3, 1], TRUE) + expect_equal_to_r(R_dense[4, 2], TRUE) + expect_equal_to_r(R_dense[4, 3], FALSE) + expect_equal_to_r(R_dense[2, 4], FALSE) +}) + +# test_that("build-ancestor-matrix: isolated nodes have only a diagonal entry", { +# edges <- data.frame(from = c(1L, 1L), +# to = c(2L, 1L)) +# R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) +# R_dense <- R$to_dense() +# +# # Nodes 3, 4, 5 have no edges +# expect_equal_to_r(R_dense[3, 3], TRUE) +# expect_equal_to_r(R_dense[3, ], c(FALSE, FALSE, TRUE, FALSE, FALSE)) +# expect_equal_to_r(R_dense[4, 4], TRUE) +# expect_equal_to_r(R_dense[5, 5], TRUE) +# }) +# +# test_that("build-ancestor-matrix: n_classes defaults to max node id when NULL", { +# edges <- data.frame(from = c(1L, 1L), to = c(5L, 1L)) +# R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) +# +# # n_classes should be max(1, 5, TRUE) = 5 +# expect_tensor_shape(R, c(5, 5)) +# }) + +# test_that("build-ancestor-matrix: output has correct shape and dtype", { +# edges <- data.frame(from = c(1L, 2L), to = c(2L, 1L)) +# R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) +# +# expect_tensor_shape(R, c(3L, 3L)) +# expect_tensor_dtype(R, torch::torch_bool()) +# expect_true(R$is_sparse()) +# }) +# +test_that("build-ancestor-matrix: output uses 0-based indices internally", { + # Verify that torch sees correct values when converted to dense + edges <- data.frame(from = c(1L, 2L, 1L), to = c(2L, 3L, 3L)) + R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) + expect_equal_to_r(R$to_dense(), matrix(c(TRUE, TRUE, FALSE, TRUE), nrow=2)) +}) + +# test_that("build-ancestor-matrix: single-node graph produces identity-like matrix", { +# edges <- data.frame(from = 1L, to = 1L) +# R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) +# expect_tensor_shape(R$to_dense(), c(1L, 1L)) +# expect_equal_to_r(R$to_dense()[1, 1], TRUE) +# }) + +test_that("node_to_df works ", { + expect_no_error( + node_to_df(acme) + ) + expect_no_error( + attrition_df <- node_to_df(attrition_tree) + ) + # node_to_df removes first and last level of the hierarchy + outcome_levels <- paste0("level_", seq(2, attrition_tree$height - 1)) + expect_equal(names(attrition_df$y), outcome_levels) + + # node_to_df do not shuffle outcome rows + df <- tibble(pred_1 = seq(1,26), pred_2 = seq(26,1), + level_2 = factor(LETTERS[1:26]), level_3 = factor(letters[26:1])) + df_node_df <- df %>% + mutate(pathString = paste("synth", level_2, level_3, level_3, sep = "/")) %>% + select(-level_2, -level_3) %>% + as.Node() %>% + node_to_df() + + expect_equal(df_node_df$y %>% as_tibble(), df %>% select(starts_with("level_"))) + expect_equal(df_node_df$x %>% as_tibble(), df %>% select(starts_with("pred_"))) + +}) From 07a639cd7d9fabc1158d736b9c42ee823c5a46e8 Mon Sep 17 00:00:00 2001 From: "C. Regouby" Date: Mon, 4 May 2026 23:00:54 +0200 Subject: [PATCH 9/9] adapt build_ancestor_matrix test to the new pruning of leaf nodes --- tests/testthat/setup.R | 15 ++ tests/testthat/test-hierarchical_utils.R | 248 +++++++++++++++-------- 2 files changed, 173 insertions(+), 90 deletions(-) diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index 74c72ae9..abb33507 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -47,5 +47,20 @@ attrition_tree <- attrition %>% select(-Department, -JobRole, -rowid) %>% data.tree::as.Node() +# --- Helper function to create test trees easily --- +create_test_tree <- function(structure) { + # structure: list of path strings or a nested list + # Simple parser for "Root/A/B" style paths + paths <- structure + root_name <- unique(sapply(strsplit(paths, "/"), `[`, 1)) + + tree <- Node$new(root_name) + for (p in paths) { + if (p == root_name) next + tree$AddChild(p) + } + return(tree) +} + # Run after all tests withr::defer(testthat::teardown_env()) diff --git a/tests/testthat/test-hierarchical_utils.R b/tests/testthat/test-hierarchical_utils.R index 905d94b4..de54009b 100644 --- a/tests/testthat/test-hierarchical_utils.R +++ b/tests/testthat/test-hierarchical_utils.R @@ -199,112 +199,180 @@ test_that("max_constraint_output handles mixed positive-negative with constraint expect_equal_to_r(result, expected) }) -# need rework as FromDataFrameNetwork(edges) gives "cannot find root name" error -test_that("build-ancestor-matrix diagonal is always 1 for every class", { - edges <- data.frame(from = c(1L, 2L, 2L), - to = c(2L, 3L, 4L)) - R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) - R_dense <- R$to_dense() +test_that("build_ancestor_matrix handles basic hierarchy", { + # Tree: Root -> A -> C + # Root -> B -> D + # Edges: R->A, A->C, R->B, B->D + # Pruning Logic: + # 1. Remove Root: Keeps A->C, B->D + # 2. Remove leaves (C, D are not in 'from'): Keeps A->C? No. C is not a parent. + # Keeps B->D? No. D is not a parent. + # Result: No edges match criteria. Empty matrix. - expect_equal_to_r(R_dense[1, 1], TRUE) - expect_equal_to_r(R_dense[2, 2], TRUE) - expect_equal_to_r(R_dense[3, 3], TRUE) + paths <- c("Root/A", "Root/A/C", "Root/B", "Root/B/D") + tree <- create_test_tree(paths) + + result <- build_ancestor_matrix(tree) + + # Expectation: No internal nodes exist that are also children (excluding Root) + # A and B are children of Root, but their children (C, D) are leaves. + # Thus A and B are effectively leaves in the "internal structure". + expect_equal(nrow(result), 0) }) -test_that("build-ancestor-matrix: single edge produces correct transitive pair", { - # 1 -> 2 means "2 is ancestor of 1", so transposed: R[2, 1] = 1 - edges <- data.frame(from = c(1L, 2L), to = c(2L, 3L)) - R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) - R_dense <- R$to_dense() +test_that("build_ancestor_matrix handles linear chain of internal nodes", { + # Tree: Root -> A -> B -> C + # Edges: R->A, A->B, B->C + # Pruning Logic: + # 1. Remove Root: Keeps A->B, B->C + # 2. Keep only if target is a parent: + # - A->B: B is a parent (of C). Keep. + # - B->C: C is a leaf. Drop. + # Remaining Edges: A -> B + # Nodes: A(1), B(2) + # Matrix: A->A, A->B, B->B + + paths <- c("Root/A", "Root/A/B", "Root/A/B/C") + tree <- create_test_tree(paths) - # 2 is descendant of 2 (self) - expect_equal_to_r(R_dense[2, 2], TRUE) - # 1 is descendant of 1 (self) - expect_equal_to_r(R_dense[1, 1], TRUE) - # 1 is descendant of 2 (because 1 -> 2) - expect_equal_to_r(R_dense[2, 1], TRUE) - # 2 is NOT descendant of 1 - expect_equal_to_r(R_dense[1, 2], FALSE) + result <- build_ancestor_matrix(tree) + + expected <- matrix(c( + 1, 1, # A -> A + 1, 2, # A -> B + 2, 2 # B -> B + ), ncol = 2, byrow = TRUE) + + expect_equal(result, expected) }) -test_that("build-ancestor-matrix: multi-hop ancestor chain is fully resolved", { - # Chain: 2 -> 3 -> 4 -> 5 (each is ancestor of the previous) - # After transpose: 4 is descendant of 1, 2, 3, 4 - # 3 is descendant of 1, 2, 3 - # 2 is descendant of 1, 2 - # 1 is descendant of 1 - edges <- data.frame( - from = c(1L, 2L, 3L, 4L), - to = c(2L, 3L, 4L, 5L) - ) - R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) - R_dense <- R$to_dense() +test_that("build_ancestor_matrix calculates transitive closure correctly", { + # Tree: Root -> A -> B -> C -> D + # Edges: R->A, A->B, B->C, C->D + # Pruning Logic: + # 1. Remove Root: A->B, B->C, C->D + # 2. Keep if target is parent: + # - A->B: B is parent (of C). Keep. + # - B->C: C is parent (of D). Keep. + # - C->D: D is leaf. Drop. + # Remaining Edges: A -> B, B -> C + # Nodes: A(1), B(2), C(3) - # Row 1: only node 1 is its own descendant - expect_equal_to_r(R_dense, lower.tri(diag(4), diag = TRUE)) + paths <- c("Root/A", "Root/A/B", "Root/A/B/C", "Root/A/B/C/D") + tree <- create_test_tree(paths) + + result <- build_ancestor_matrix(tree) + + # Expected Relations: + # A -> A, A -> B, A -> C + # B -> B, B -> C + # C -> C + + # Sorted by column then row (default behavior of which(arr.ind=TRUE)) + expected <- matrix(c( + 1, 1, # A->A + 1, 2, # A->B + 2, 2, # B->B + 1, 3, # A->C (transitive) + 2, 3, # B->C + 3, 3 # C->C + ), ncol = 2, byrow = TRUE) + + expect_equal(result, expected) }) -test_that("build-ancestor-matrix: diamond hierarchy merges both paths", { - # Diamond: 1 -> 2 -> 4, 1 -> 3 -> 4 - # After transpose: 4 is descendant of all; 2 and 3 are descendants of - # 1 and themselves only - edges <- data.frame( - from = c(1L, 1L, 2L, 3L), - to = c(2L, 3L, 4L, 4L) - ) - R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) - R_dense <- R$to_dense() +test_that("build_ancestor_matrix handles branching internal nodes", { + # Tree: R -> A -> C + # R -> B -> C (Diamond shape, merging back to C) + # *Note: data.tree allows this structure (multiple parents)? + # Actually standard trees are single parent. Let's stick to standard tree. + + # Tree: R -> A -> C -> E + # R -> B -> D -> E + # Edges: R->A, A->C, C->E, R->B, B->D, D->E + # Pruning: + # 1. Remove R: A->C, C->E, B->D, D->E + # 2. Keep target if parent: + # - A->C (C is parent of E). Keep. + # - C->E (E is leaf). Drop. + # - B->D (D is parent of E). Keep. + # - D->E (E is leaf). Drop. + # Nodes: A, C, B, D + # Edges: A->C, B->D - expect_equal_to_r(R_dense[3, 1], TRUE) - expect_equal_to_r(R_dense[4, 2], TRUE) - expect_equal_to_r(R_dense[4, 3], FALSE) - expect_equal_to_r(R_dense[2, 4], FALSE) + paths <- c("Root/A", "Root/A/C", "Root/A/C/E", + "Root/B", "Root/B/D", "Root/B/D/E") + tree <- create_test_tree(paths) + + result <- build_ancestor_matrix(tree) + + # We have two disconnected components in the adjacency matrix: (A,C) and (B,D) + # A(1), C(2), B(3), D(4) (Order depends on unique(c(edges$from, edges$to))) + # Edges order: A->C, B->D. + # Unique nodes: A, C, B, D. + + # Expected: Self loops + A->C, B->D + # (1,1), (1,2), (2,2), (3,3), (3,4), (4,4) + + expected <- matrix(c( + 1, 1, # A->A + 1, 2, # A->C + 2, 2, # C->C + 3, 3, # B->B + 3, 4, # B->D + 4, 4 # D->D + ), ncol = 2, byrow = TRUE) + + expect_equal(result, expected) }) -# test_that("build-ancestor-matrix: isolated nodes have only a diagonal entry", { -# edges <- data.frame(from = c(1L, 1L), -# to = c(2L, 1L)) -# R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) -# R_dense <- R$to_dense() -# -# # Nodes 3, 4, 5 have no edges -# expect_equal_to_r(R_dense[3, 3], TRUE) -# expect_equal_to_r(R_dense[3, ], c(FALSE, FALSE, TRUE, FALSE, FALSE)) -# expect_equal_to_r(R_dense[4, 4], TRUE) -# expect_equal_to_r(R_dense[5, 5], TRUE) -# }) -# -# test_that("build-ancestor-matrix: n_classes defaults to max node id when NULL", { -# edges <- data.frame(from = c(1L, 1L), to = c(5L, 1L)) -# R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) -# -# # n_classes should be max(1, 5, TRUE) = 5 -# expect_tensor_shape(R, c(5, 5)) -# }) +test_that("build_ancestor_matrix returns empty for Root-only tree", { + tree <- Node$new("Root") + result <- build_ancestor_matrix(tree) + expect_equal(nrow(result), 0) +}) -# test_that("build-ancestor-matrix: output has correct shape and dtype", { -# edges <- data.frame(from = c(1L, 2L), to = c(2L, 1L)) -# R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) -# -# expect_tensor_shape(R, c(3L, 3L)) -# expect_tensor_dtype(R, torch::torch_bool()) -# expect_true(R$is_sparse()) -# }) -# -test_that("build-ancestor-matrix: output uses 0-based indices internally", { - # Verify that torch sees correct values when converted to dense - edges <- data.frame(from = c(1L, 2L, 1L), to = c(2L, 3L, 3L)) - R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) - expect_equal_to_r(R$to_dense(), matrix(c(TRUE, TRUE, FALSE, TRUE), nrow=2)) +test_that("build_ancestor_matrix returns empty for Root + Leaf", { + # Tree: Root -> A + # Edges: R->A + # 1. Remove R (from!=Root): Result empty. + tree <- Node$new("Root") + tree$AddChild("A") + result <- build_ancestor_matrix(tree) + expect_equal(nrow(result), 0) }) -# test_that("build-ancestor-matrix: single-node graph produces identity-like matrix", { -# edges <- data.frame(from = 1L, to = 1L) -# R <- build_ancestor_matrix(FromDataFrameNetwork(mutate_all(edges, as.character))) -# expect_tensor_shape(R$to_dense(), c(1L, 1L)) -# expect_equal_to_r(R$to_dense()[1, 1], TRUE) -# }) +test_that("build_ancestor_matrix handles deep wide tree (Stress Test)", { + # Create a binary tree of depth 4 + tree <- Node$new("R") + add_children <- function(node, depth) { + if (depth == 0) return() + node$AddChild(paste0(node.name, "L")) + node$AddChild(paste0(node.name, "R")) + add_children(node$children[[1]], depth - 1) + add_children(node$children[[2]], depth - 1) + } + add_children(tree, 4) + + result <- build_ancestor_matrix(tree) + + # Validate structure without checking exact numbers (too complex for hardcoded) + # 1. Must be integer matrix + expect_type(result, "integer") + # 2. Must have 2 columns + expect_equal(ncol(result), 2) + # 3. First column (Ancestor) <= Second column (Descendant) implies topological sort check? + # Actually indices are arbitrary based on sorting, but relationships are directional. + # Just ensure no NA or Inf + expect_true(!any(is.na(result))) +}) +test_that("build_ancestor_matrix preserves integer type", { + paths <- c("Root/A", "Root/A/B") + tree <- create_test_tree(paths) + result <- build_ancestor_matrix(tree) + expect_type(result, "integer") +}) test_that("node_to_df works ", { expect_no_error( node_to_df(acme)