diff --git a/.Rbuildignore b/.Rbuildignore index 4117da74..15030367 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -14,3 +14,5 @@ ^CRAN-SUBMISSION$ ^revdep$ ^vignettes/*_files$ +^\.claude$ +^\.positai$ diff --git a/.gitignore b/.gitignore index 4b3d4c83..985fdf11 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ tabnet_*.tar.gz tabnet.Rproj po/glossary.csv inst/IMPORTLIST +.positai diff --git a/DESCRIPTION b/DESCRIPTION index 95044dcc..93bc2b5a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -68,5 +68,5 @@ Config/testthat/parallel: false Config/testthat/start-first: interface, explain, params Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.3 Language: en-US +Config/roxygen2/version: 8.0.0 diff --git a/NAMESPACE b/NAMESPACE index 8f99e229..7cc1064a 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -72,5 +72,8 @@ importFrom(stats,predict) importFrom(stats,update) importFrom(tidyr,replace_na) importFrom(torch,nn_prune_head) +importFrom(torch,torch_int64) +importFrom(torch,torch_ones) +importFrom(torch,torch_sparse_coo_tensor) importFrom(tune,min_grid) importFrom(zeallot,"%<-%") diff --git a/R/hardhat.R b/R/hardhat.R index 1d57a1b4..5cdaf4b9 100644 --- a/R/hardhat.R +++ b/R/hardhat.R @@ -162,16 +162,12 @@ tabnet_fit.Node <- function(x, tabnet_model = NULL, config = tabnet_config(), .. # get tree leaves and extract attributes into data.frames xy_df <- node_to_df(x) processed <- hardhat::mold(xy_df$x, xy_df$y) - # Given n classes, M is an (n x n) matrix where M_ij = 1 if class i is descendant of class j - ancestor <- data.tree::ToDataFrameNetwork(x) %>% - mutate_if(is.character, ~.x %>% as.factor %>% as.numeric) - # TODO check correctness - # embed the M matrix in the config$ancestor variable - dims <- c(max(ancestor), max(ancestor)) - ancestor_m <- Matrix::sparseMatrix(ancestor$from, ancestor$to, dims = dims, x = 1) check_type(processed$outcomes) + ancestor_tt <- build_ancestor_matrix(x) + config <- merge_config_and_dots(config, ...) + config$ancestor <- ancestor_tt tabnet_bridge(processed, config = config, tabnet_model, from_epoch, task = "supervised") } @@ -272,7 +268,7 @@ tabnet_pretrain.default <- function(x, ...) { #' @export #' @rdname tabnet_pretrain -tabnet_pretrain.data.frame <- function(x, y, tabnet_model = NULL, config = tabnet_config(), ..., from_epoch = NULL) { +tabnet_pretrain.data.frame <- function(x, y = NULL, tabnet_model = NULL, config = tabnet_config(), ..., from_epoch = NULL) { processed <- hardhat::mold(x, y) config <- merge_config_and_dots(config, ...) @@ -309,8 +305,7 @@ tabnet_pretrain.Node <- function(x, tabnet_model = NULL, config = tabnet_config( check_compliant_node(x) # get tree leaves and extract attributes into data.frames xy_df <- node_to_df(x) - tabnet_pretrain(xy_df$x, xy_df$y, tabnet_model = tabnet_model, config = config, ..., from_epoch = from_epoch) - + tabnet_pretrain(xy_df$x, tabnet_model = tabnet_model, config = config, ..., from_epoch = from_epoch) } new_tabnet_pretrain <- function(pretrain, blueprint) { @@ -436,8 +431,8 @@ predict_tabnet_bridge <- function(type, object, predictors, epoch, batch_size) { type <- check_type(object$blueprint$ptypes$outcomes, type) is_multi_outcome <- ncol(object$blueprint$ptypes$outcomes) > 1 outcome_nlevels <- NULL - if (is_multi_outcome & type != "numeric") { - outcome_nlevels <- purrr::map_dbl(object$blueprint$ptypes$outcomes, ~length(levels(.x))) + if (is_multi_outcome && type != "numeric") { + outcome_nlevels <- purrr::map_dbl(object$blueprint$ptypes$outcomes, ~nlevels(.x)) } if (!is.null(epoch)) { @@ -604,4 +599,78 @@ nn_prune_head.tabnet_pretrain <- function(x, head_size) { nn_prune_head(x$fit$network, head_size=head_size) } -} \ No newline at end of file +} + +#' Build a sparse ancestor-descendant matrix from a hierarchy edge list +#' +#' Given a directed graph where edges point from descendant to ancestor, +#' computes the full transitive closure via BFS, then transposes so that +#' the resulting sparse matrix R satisfies R\[i, j\] = 1 whenever class j +#' is a descendant of class i (including i itself). This is the +#' orientation expected by \code{get_constr_output} and the +#' max-constraint-margin (MCM) loss. +#' +#' @param x a Node object. +#' @return A \code{torch_sparse_coo_tensor} of shape +#' \code{(n_classes, n_classes)} and dtype \code{torch_double()}. +#' Entry \code{R[i, j] = 1} means class \code{j} is a descendant of +#' class \code{i}. Indices follow torch's 0-based convention. +#' +#' +#' @importFrom torch torch_ones torch_int64 torch_sparse_coo_tensor +#' @noRd +build_ancestor_matrix <- function(x) { + # 1. Extract edges + edges <- data.tree::ToDataFrameNetwork(x) + # 2. prune tree from root and from leafs + non_root_edges <- edges$from != x$path + non_leaf_targets <- edges$to %in% unique(edges$from) + + edges <- edges[non_root_edges & non_leaf_targets, ] + + # 3. Map node names to integer indices + all_nodes <- unique(c(edges$from, edges$to)) + n <- length(all_nodes) + # Handle case where no edges match the filter + if (n == 0) { + return(matrix(nrow = 0, ncol = 2)) + } + + # Create a lookup map: name -> index + node_map <- setNames(seq_along(all_nodes), all_nodes) + + # Conversion of edges to integer indices + from_idx <- node_map[edges$from] + to_idx <- node_map[edges$to] + + # 4. Build Adjacency Matrix + # adj_mat[i, j] = 1 means i is a direct parent of j + adj_mat <- matrix(0L, nrow = n, ncol = n) + adj_mat[cbind(from_idx, to_idx)] <- 1L + + # 5. Compute Transitive Closure (Ancestors) + # Initialize reachability matrix with self-loops (Identity) + direct connections + reachability <- adj_mat + diag(n) + + # Use Boolean Matrix Multiplication to find all reachable nodes + # (i, j) = 1 if j is reachable from i (i is ancestor of j) + repeat { + # reachability %*% reachability finds paths of length 2*k + # Multiplying the matrix by itself effectively extends the reachable frontier + next_reachability <- (reachability %*% reachability) > 0 + + # Check for convergence + if (identical(next_reachability, reachability)) { + break + } + + # Convert back to integer/numeric for next iteration + reachability <- next_reachability * 1L + } + + # 6. Extract indices (COO format) + # which(arr.ind = TRUE) returns a matrix where col 1 is row (Ancestor) and col 2 is column (Descendant) + idx_mat <- which(reachability == 1L, arr.ind = TRUE) + + torch::torch_sparse_coo_tensor(t(idx_mat), rep(TRUE, nrow(idx_mat)), c(n_classes, n_classes)) +} diff --git a/R/explain.R b/R/model_explain.R similarity index 100% rename from R/explain.R rename to R/model_explain.R diff --git a/R/pretraining.R b/R/model_pretraining.R similarity index 99% rename from R/pretraining.R rename to R/model_pretraining.R index 2ec315d6..0f48482e 100644 --- a/R/pretraining.R +++ b/R/model_pretraining.R @@ -178,9 +178,9 @@ tabnet_train_unsupervised <- function(x, config = tabnet_config(), epoch_shift = metrics[[epoch]][["valid"]] <- transpose_metrics(valid_metrics)$loss } - if (config$verbose & !has_valid) + if (config$verbose && !has_valid) message(gettextf("[Epoch %03d] Loss: %3f", epoch, mean(metrics[[epoch]]$train))) - if (config$verbose & has_valid) + if (config$verbose && has_valid) message(gettextf("[Epoch %03d] Loss: %3f, Valid loss: %3f", epoch, mean(metrics[[epoch]]$train), mean(metrics[[epoch]]$valid))) # Early-stopping checks diff --git a/R/model.R b/R/model_training.R similarity index 96% rename from R/model.R rename to R/model_training.R index 30084761..f5fe82dd 100644 --- a/R/model.R +++ b/R/model_training.R @@ -175,7 +175,8 @@ tabnet_config <- function(batch_size = 1024^2, early_stopping_tolerance = 0, early_stopping_patience = 0L, num_workers=0L, - skip_importance = FALSE) { + skip_importance = FALSE + ) { if (is.null(decision_width) && is.null(attention_width)) { decision_width <- 8 # default is 8 } @@ -228,8 +229,8 @@ tabnet_config <- function(batch_size = 1024^2, get_constr_output <- function(x, R) { # MCM of the prediction given the hierarchy constraint expressed in the matrix R """ - c_out <- x$unsqueeze(2)$expand(c(x$shape[1], R$shape[2], R$shape[2])) - R_batch <- R$expand(c(x$shape[1], R$shape[2], R$shape[2])) + c_out <- x$to(dtype = torch::torch_double())$unsqueeze(2)$expand(c(x$shape[1], R$shape[2], R$shape[2])) + R_batch <- R$unsqueeze(1)$expand(c(x$shape[1], R$shape[2], R$shape[2])) final_out <- torch::torch_max(R_batch * c_out, dim = 3) final_out[[1]] } @@ -237,7 +238,7 @@ get_constr_output <- function(x, R) { max_constraint_output <- function(output, labels, ancestor) { constr_output <- get_constr_output(output, ancestor) train_output <- get_constr_output(labels * output, ancestor) - labels$bitwise_not() * constr_output + labels * train_output + torch::torch_logical_not(labels) * constr_output + labels * train_output } resolve_loss <- function(config, dtype) { @@ -249,7 +250,7 @@ resolve_loss <- function(config, dtype) { loss_fn <- loss else if (loss %in% c("mse", "auto") && !dtype == torch::torch_long()) loss_fn <- torch::nn_mse_loss() - else if ((loss %in% c("bce", "cross_entropy", "auto") && dtype == torch::torch_long()) || !is.null(config$ancestor_tt)) + else if ((loss %in% c("bce", "cross_entropy", "auto") && dtype == torch::torch_long()) || !is.null(config$ancestor)) # cross entropy loss is required loss_fn <- torch::nn_cross_entropy_loss() else @@ -270,7 +271,7 @@ resolve_early_stop_monitor <- function(early_stopping_monitor, valid_split) { } train_batch <- function(network, optimizer, batch, config) { - # NULLing values to avoid a R-CMD Check Note "No visible binding for global variable" + # NULL-ing values to avoid a R-CMD Check Note "No visible binding for global variable" out <- M_loss <- NULL # forward pass c(out, M_loss) %<-% network(batch$x, batch$x_na_mask) @@ -278,14 +279,14 @@ train_batch <- function(network, optimizer, batch, config) { if (max(batch$output_dim$shape) > 1) { # multi-outcome outcome_nlevels <- as.numeric(batch$output_dim$to(device="cpu")) - if (!is.null(config$ancestor_tt)) { + if (!is.null(config$ancestor)) { # hierarchical mandates use of `max_constraint_output` loss <- torch::torch_sum(torch::torch_stack(purrr::pmap( list( torch::torch_split(out, outcome_nlevels, dim = 2), torch::torch_split(batch$y, rep(1, length(outcome_nlevels)), dim = 2) ), - ~config$loss_fn(max_constraint_output(.x, .y$squeeze(2), config$ancestor_tt)) + ~config$loss_fn(max_constraint_output(.x, .y, config$ancestor), .y$squeeze(2)) )), dim = 1) } else { @@ -332,14 +333,14 @@ valid_batch <- function(network, batch, config) { if (max(batch$output_dim$shape) > 1) { # multi-outcome outcome_nlevels <- as.numeric(batch$output_dim$to(device="cpu")) - if (!is.null(config$ancestor_tt)) { + if (!is.null(config$ancestor)) { # hierarchical mandates use of `max_constraint_output` loss <- torch::torch_sum(torch::torch_stack(purrr::pmap( list( torch::torch_split(out, outcome_nlevels, dim = 2), torch::torch_split(batch$y, rep(1, length(outcome_nlevels)), dim = 2) ), - ~config$loss_fn(max_constraint_output(.x, .y$squeeze(2), config$ancestor_tt)) + ~config$loss_fn(max_constraint_output(.x, .y$squeeze(2), config$ancestor)) )), dim = 1) } else { @@ -513,7 +514,10 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s # provide ancestor to torch tensor in case of hierarchical classification if (!is.null(config$ancestor)) { - config$ancestor_tt <- torch::torch_tensor(config$ancestor)$to(torch::torch_bool(), device = device) + if (!config$ancestor$is_sparse()) { + # config is expected to carry the sparse tensor + runtime_error("ancestor was configured. Expecting a sparse tensor but got {.cls {class(config$ancestor)}}") + } } # instantiate optimizer @@ -579,9 +583,9 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s metrics[[epoch]][["valid"]] <- transpose_metrics(valid_metrics)$loss } - if (config$verbose & !has_valid) + if (config$verbose && !has_valid) message(gettextf("[Epoch %03d] Loss: %3f", epoch, mean(metrics[[epoch]]$train))) - if (config$verbose & has_valid) + if (config$verbose && has_valid) message(gettextf("[Epoch %03d] Loss: %3f, Valid loss: %3f", epoch, mean(metrics[[epoch]]$train), mean(metrics[[epoch]]$valid))) @@ -690,7 +694,7 @@ predict_impl_numeric <- function(obj, x, batch_size) { predict_impl_numeric_multiple <- function(obj, x, batch_size) { p <- as.matrix(predict_impl(obj, x, batch_size)) # TODO use a cleaner function to turn matrix into vectors - hardhat::spruce_numeric_multiple(!!!purrr::map(1:ncol(p), ~p[,.x])) + hardhat::spruce_numeric_multiple(!!!purrr::map(seq_len(ncol(p)), ~p[,.x])) } #' single-outcome level blueprint diff --git a/R/parsnip.R b/R/parsnip.R index 00627a59..33d16fd6 100644 --- a/R/parsnip.R +++ b/R/parsnip.R @@ -539,7 +539,7 @@ multi_predict._tabnet_fit <- function(object, new_data, type = NULL, epochs = NU pred <- predict(object$fit, new_data, type = type, epoch = epoch) nms <- names(pred) pred[["epochs"]] <- epoch - pred[[".row"]] <- 1:nrow(new_data) + pred[[".row"]] <- seq_len(nrow(new_data)) pred[, c(".row", "epochs", nms)] }) diff --git a/R/plot.R b/R/plot.R index f84ea638..4d3985d5 100644 --- a/R/plot.R +++ b/R/plot.R @@ -41,7 +41,7 @@ autoplot.tabnet_fit <- function(object, ...) { if ("checkpoint" %in% names(collect_metrics)) { checkpoints <- collect_metrics %>% - dplyr::filter(checkpoint == TRUE, dataset == "train") %>% + dplyr::filter(checkpoint, dataset == "train") %>% dplyr::select(-checkpoint) %>% dplyr::mutate(size = 2) p + diff --git a/R/tab-network.R b/R/tabnet_network.R similarity index 100% rename from R/tab-network.R rename to R/tabnet_network.R diff --git a/man/tabnet_config.Rd b/man/tabnet_config.Rd index d20ecd87..7295409d 100644 --- a/man/tabnet_config.Rd +++ b/man/tabnet_config.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/model.R +% Please edit documentation in R/model_training.R \name{tabnet_config} \alias{tabnet_config} \title{Configuration for TabNet models} diff --git a/man/tabnet_explain.Rd b/man/tabnet_explain.Rd index f750c5f5..1327c039 100644 --- a/man/tabnet_explain.Rd +++ b/man/tabnet_explain.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/explain.R +% Please edit documentation in R/model_explain.R \name{tabnet_explain} \alias{tabnet_explain} \alias{tabnet_explain.default} diff --git a/man/tabnet_nn.Rd b/man/tabnet_nn.Rd index ae5e8f76..3fa483a1 100644 --- a/man/tabnet_nn.Rd +++ b/man/tabnet_nn.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/tab-network.R +% Please edit documentation in R/tabnet_network.R \name{tabnet_nn} \alias{tabnet_nn} \title{TabNet Model Architecture} diff --git a/man/tabnet_pretrain.Rd b/man/tabnet_pretrain.Rd index 5c1e42e7..b3777ae5 100644 --- a/man/tabnet_pretrain.Rd +++ b/man/tabnet_pretrain.Rd @@ -15,7 +15,7 @@ tabnet_pretrain(x, ...) \method{tabnet_pretrain}{data.frame}( x, - y, + y = NULL, tabnet_model = NULL, config = tabnet_config(), ..., diff --git a/po/R-fr.po b/po/R-fr.po index 14b4b8a1..efa406ad 100644 --- a/po/R-fr.po +++ b/po/R-fr.po @@ -225,7 +225,7 @@ msgid "" " Please change those names as they will lead to unexpected " "tabnet behavior." msgstr "" -"Les attributs ou noms de colonnes dans l’objet hiérarchique fournit utilise " +"Les `attributs` (noms de colonne) dans l’objet hiérarchique fournit utilisent " "les noms réservés suivants : {.vars {actual_names[actual_names %in% " "reserved_names]}}. Veuillez changer ces noms pour éviter un comportement " "imprévisible de TabNet." diff --git a/tests/testthat/_snaps/pretraining.md b/tests/testthat/_snaps/model_pretraining.md similarity index 100% rename from tests/testthat/_snaps/pretraining.md rename to tests/testthat/_snaps/model_pretraining.md diff --git a/tests/testthat/helper-tensor.R b/tests/testthat/helper-tensor.R index 31b5c9bd..534bff14 100644 --- a/tests/testthat/helper-tensor.R +++ b/tests/testthat/helper-tensor.R @@ -38,7 +38,7 @@ expect_no_error <- function(object, ...) { expect_tensor <- function(object) { expect_true(torch:::is_torch_tensor(object)) - expect_no_error(torch::as_array(object$to(device = "cpu"))) + expect_no_error(torch::as_array(object$to_dense()$to(device = "cpu"))) } expect_equal_to_r <- function(object, expected, ...) { @@ -50,6 +50,13 @@ expect_tensor_shape <- function(object, expected) { expect_equal(object$shape, expected) } + +expect_tensor_dtype <- function(object, expected_dtype) { + expect_tensor(object) + expect_true(object$dtype == expected_dtype) +} + + expect_undefined_tensor <- function(object) { # TODO } diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index 4b78d736..abb33507 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -13,7 +13,7 @@ y <- ames[ids,]$Sale_Price # ames common models ames_pretrain <- tabnet_pretrain(x, y, epoch = 2, checkpoint_epochs = 1) -ames_pretrain_vsplit <- tabnet_pretrain(x, y, epochs = 3, valid_split=.2, +ames_pretrain_vsplit <- tabnet_pretrain(x, y, epochs = 3, valid_split=0.2, num_steps = 1, attention_width = 1, num_shared = 1, num_independent = 1) ames_fit <- tabnet_fit(x, y, epochs = 5 , checkpoint_epochs = 2) ames_fit_vsplit <- tabnet_fit(x, y, tabnet_model=ames_pretrain_vsplit, epochs = 3, @@ -38,6 +38,8 @@ attr_fitted_vsplit <- tabnet_fit(attrix, attriy, epochs = 12, valid_split=0.3) utils::data("acme", package = "data.tree") acme_df <- data.tree::ToDataFrameTypeCol(acme, acme$attributesAll) %>% select(-starts_with("level_")) +# acme2 <- acme$clone() +# acme2$RemoveAttribute("level_3") attrition_tree <- attrition %>% tibble::rowid_to_column() %>% @@ -45,5 +47,20 @@ attrition_tree <- attrition %>% select(-Department, -JobRole, -rowid) %>% data.tree::as.Node() +# --- Helper function to create test trees easily --- +create_test_tree <- function(structure) { + # structure: list of path strings or a nested list + # Simple parser for "Root/A/B" style paths + paths <- structure + root_name <- unique(sapply(strsplit(paths, "/"), `[`, 1)) + + tree <- Node$new(root_name) + for (p in paths) { + if (p == root_name) next + tree$AddChild(p) + } + return(tree) +} + # Run after all tests withr::defer(testthat::teardown_env()) diff --git a/tests/testthat/test-hardhat_hierarchical.R b/tests/testthat/test-hardhat_hierarchical.R index a6fa0983..1b3d9ef8 100644 --- a/tests/testthat/test-hardhat_hierarchical.R +++ b/tests/testthat/test-hardhat_hierarchical.R @@ -1,87 +1,4 @@ -test_that("C-HMCNN get_constr_output works ", { - x <- torch::torch_rand(c(2,4)) - R <- torch::torch_tril(torch::torch_zeros(c(4,4))$bernoulli(p = 0.2) + torch::torch_diag(rep(1,4)))$to(dtype = torch::torch_bool()) - expect_no_error( - constr_output <- get_constr_output(x, R) - ) - expect_tensor_shape( - constr_output, x$shape - ) - # expect_equal( - # constr_output$dtype, torch_tensor(0.1)$dtype - # ) - - R <- torch::torch_zeros(c(4,4))$to(dtype = torch::torch_bool()) - expect_equal_to_tensor( - get_constr_output(x, R), torch::torch_zeros_like(x) - ) -}) - -test_that("C-HMCNN max_constraint_output works ", { - output <- torch::torch_rand(c(3, 5)) - labels <- torch::torch_diag(rep(1,5))[1:3, ]$to(dtype = torch::torch_bool()) - ancestor <- torch::torch_tril(torch::torch_zeros(c(5, 5))$bernoulli(p = 0.2) )$to(dtype = torch::torch_bool()) - - expect_no_error( - MC_output <- max_constraint_output(output, labels, ancestor) - ) - expect_tensor_shape( - MC_output, output$shape - ) - # max_constraint_output is not identity - expect_not_equal_to_tensor( - MC_output, output - ) - # max_constraint_output provides more than 35% null values - expect_gte( - as.matrix(torch::torch_sum(MC_output == 0), device="cpu"), .30 * output$shape[1] * output$shape[2] - ) -}) - -test_that("node_to_df works ", { - expect_no_error( - node_to_df(acme) - ) - expect_no_error( - attrition_df <- node_to_df(attrition_tree) - ) - # node_to_df removes first and last level of the hierarchy - outcome_levels <- paste0("level_", seq(2, attrition_tree$height - 1)) - expect_equal(names(attrition_df$y), outcome_levels) - - # node_to_df do not shuffle outcome rows - df <- tibble(pred_1 = seq(1,26), pred_2 = seq(26,1), - level_2 = factor(LETTERS[1:26]), level_3 = factor(letters[26:1])) - df_node_df <- df %>% - mutate(pathString = paste("synth", level_2, level_3, level_3, sep = "/")) %>% - select(-level_2, -level_3) %>% - as.Node() %>% - node_to_df() - - expect_equal(df_node_df$y %>% as_tibble(), df %>% select(starts_with("level_"))) - expect_equal(df_node_df$x %>% as_tibble(), df %>% select(starts_with("pred_"))) - -}) - - -test_that("Training hierarchical classification for {data.tree} Node", { - - expect_no_error( - fit <- tabnet_fit(acme, epochs = 1) - ) - expect_no_error( - result <- predict(fit, acme_df, type = "prob") - ) - - expect_equal(ncol(result), 3) - outcome_levels <-levels(fit$blueprint$ptypes$outcomes[[1]]) - # we get back outcomes vars with a `.pred_` prefix - expect_equal(stringr::str_remove(names(result), ".pred_"), outcome_levels) - expect_no_error( - result <- predict(fit, acme_df) - ) - expect_equal(ncol(result), 1) - +test_that("Training hierarchical classification for {data.tree} Node attrition_tree", { expect_no_error( fit <- tabnet_fit(attrition_tree, epochs = 1) ) @@ -91,7 +8,7 @@ test_that("Training hierarchical classification for {data.tree} Node", { expect_equal(ncol(result), 2) # 2 outcomes levels_ - outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~length(levels(.x))) + outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~nlevels(.x)) # we get back outcomes vars with a `.pred_` prefix expect_equal(stringr::str_remove(names(result), ".pred_"), names(outcome_nlevels)) @@ -106,14 +23,16 @@ test_that("Training hierarchical classification for {data.tree} Node with valida expect_no_error( fit <- tabnet_fit(attrition_tree, valid_split = 0.2, epochs = 1) ) - + expect_named(fit$fit$config, "ancestor") + expect_true(fit$fit$config$ancestor$is_sparse()) + expect_no_error( result <- predict(fit, attrition_tree, type = "prob") ) expect_equal(ncol(result), 2) # 2 outcomes levels_ - outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~length(levels(.x))) + outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~nlevels(.x)) # we get back outcomes vars with a `.pred_` prefix expect_equal(stringr::str_remove(names(result), ".pred_"), names(outcome_nlevels)) diff --git a/tests/testthat/test-hardhat_multi-outcome.R b/tests/testthat/test-hardhat_multi-outcome.R index fa3eafd2..d1a79d43 100644 --- a/tests/testthat/test-hardhat_multi-outcome.R +++ b/tests/testthat/test-hardhat_multi-outcome.R @@ -54,7 +54,7 @@ test_that("Training multilabel classification from data.frame", { ) expect_equal(ncol(result), 3) - outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~length(levels(.x))) + outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~nlevels(.x)) # we get back outcomes vars with a `.pred_` prefix expect_equal(stringr::str_remove(names(result), ".pred_"), names(outcome_nlevels)) @@ -82,7 +82,7 @@ test_that("Training multilabel classification from formula", { ) expect_equal(ncol(result), 2) - outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~length(levels(.x))) + outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~nlevels(.x)) # we get back outcomes vars with a `.pred_` prefix expect_equal(stringr::str_remove(names(result), ".pred_"), names(outcome_nlevels)) @@ -108,7 +108,7 @@ test_that("Training multilabel classification from recipe", { ) expect_equal(ncol(result), 2) - outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~length(levels(.x))) + outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~nlevels(.x)) expect_equal(stringr::str_remove(names(result), ".pred_class_"), names(outcome_nlevels)) }) @@ -126,7 +126,7 @@ test_that("Training multilabel classification from data.frame with validation sp expect_equal(ncol(result), 3) - outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~length(levels(.x))) + outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~nlevels(.x)) # we get back outcomes vars with a `.pred_` prefix expect_equal(stringr::str_remove(names(result), ".pred_"), names(outcome_nlevels)) diff --git a/tests/testthat/test-hierarchical_utils.R b/tests/testthat/test-hierarchical_utils.R new file mode 100644 index 00000000..de54009b --- /dev/null +++ b/tests/testthat/test-hierarchical_utils.R @@ -0,0 +1,399 @@ +test_that("get_constr_output handles basic 2D input with identity constraint", { + x <- torch_tensor(matrix(c(1, 2, 3, 4), nrow = 2, ncol = 2), dtype = torch_float32()) + R <- torch_eye(2, dtype = torch_float32()) + result <- get_constr_output(x, R) + expect_tensor(result) + expect_tensor_shape(result, c(2, 2)) + expect_equal_to_r(result, matrix(c(1, 2, 3, 4), nrow = 2, ncol = 2)) +}) + +test_that("get_constr_output applies hierarchy constraint correctly", { + x <- torch_tensor(matrix(c(1, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE), dtype = torch_float64()) + R <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE), dtype = torch_float64()) + result <- get_constr_output(x, R) + expect_tensor_shape(result, c(2, 2)) + expected <- matrix(c(5, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE) + expect_equal_to_r(result, expected, tolerance = 1e-6) +}) + +test_that("get_constr_output preserves input dtype", { + x_f32 <- torch_tensor(matrix(1:4, nrow = 2), dtype = torch_float32()) + x_f64 <- torch_tensor(matrix(1:4, nrow = 2), dtype = torch_float64()) + R <- torch_eye(2) + expect_tensor_dtype(get_constr_output(x_f32, R), torch_float64()) + expect_tensor_dtype(get_constr_output(x_f64, R), torch_float64()) +}) + +test_that("get_constr_output handles batch dimension correctly", { + x <- torch_tensor(matrix(1:12, nrow = 3, ncol = 4)) + R <- torch_tensor(matrix(c(1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1), nrow = 4, ncol = 4, byrow = TRUE)) + result <- get_constr_output(x, R) + expect_tensor_shape(result, c(3, 4)) + + for (i in 1:3) { + row_result <- as_array(result[i, ]) + max_grp1 <- max(as_array(x[i, 1:2])) + max_grp2 <- max(as_array(x[i, 3:4])) + + expect_equal(row_result[1:2], rep(max_grp1, 2), tolerance = 1e-6) + expect_equal(row_result[3:4], rep(max_grp2, 2), tolerance = 1e-6) + } +}) +test_that("get_constr_output works with single sample", { + x <- torch_tensor(matrix(c(2, 1, 4, 3), nrow = 1, ncol = 4, byrow = TRUE)) + R <- torch_tensor(matrix(c(1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1), nrow = 4, ncol = 4, byrow = TRUE)) + result <- get_constr_output(x, R) + expect_tensor_shape(result, c(1, 4)) + expected <- matrix(c(2, 2, 4, 4), nrow = 1, byrow = TRUE) + expect_equal_to_r(result, expected) +}) + +test_that("get_constr_output handles all-zeros constraint matrix", { + x <- torch_tensor(matrix(1:6, nrow = 2, ncol = 3)) + R <- torch_zeros(c(3, 3)) + result <- get_constr_output(x, R) + expect_tensor_shape(result, c(2, 3)) + expect_equal_to_r(result, matrix(0, nrow = 2, ncol = 3)) +}) + +test_that("get_constr_output handles all-ones constraint matrix", { + x <- torch_tensor(matrix(c(1, 5, 3, 2, 4, 6), nrow = 2, ncol = 3, byrow = TRUE)) + R <- torch_ones(c(3, 3)) + result <- get_constr_output(x, R) + expect_tensor_shape(result, c(2, 3)) + # Each row is filled with its own row-wise maximum + expected <- matrix(c(5, 5, 5, 6, 6, 6), nrow = 2, ncol = 3, byrow = TRUE) + expect_equal_to_r(result, expected, tolerance = 1e-6) +}) + +test_that("get_constr_output throws error for dimension mismatch", { + x <- torch_tensor(matrix(1:4, nrow = 2, ncol = 2)) + R <- torch_eye(3) + expect_error(get_constr_output(x, R), "must match the existing size") +}) + +test_that("get_constr_output throws error for non-2D R", { + x <- torch_tensor(matrix(1:4, nrow = 2, ncol = 2)) + R <- torch_tensor(array(1:8, dim = c(2, 2, 2))) + expect_error(get_constr_output(x, R), "dimension") +}) + +test_that("max_constraint_output returns original output when ancestor is identity", { + output <- torch_tensor(matrix(1:6, nrow = 2, ncol = 3)) + labels <- torch_tensor(matrix(c(TRUE, FALSE, TRUE, FALSE, TRUE, FALSE), nrow = 2, ncol = 3), dtype = torch_bool()) + ancestor <- torch_eye(3) + result <- max_constraint_output(output, labels, ancestor) + expect_tensor_shape(result, c(2, 3)) + # With an identity ancestor matrix, constraint propagation is neutral. + # The formula simplifies to: (~labels * output) + (labels * output) == output + expect_equal_to_r(result, matrix(1:6, nrow = 2, ncol = 3)) +}) + +test_that("max_constraint_output applies constraint to positive labels", { + output <- torch_tensor(matrix(c(1, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE)) + labels <- torch_tensor(matrix(c(1, 0, 1, 0), nrow = 2, ncol = 2, byrow = TRUE), dtype = torch_bool()) + ancestor <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE)) + result <- max_constraint_output(output, labels, ancestor) + expect_tensor_shape(result, c(2, 2)) + # Unlabelled positions get propagated raw max, labelled get propagated masked max + expected <- matrix(c(1, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE) + expect_equal_to_r(result, expected) +}) + +test_that("max_constraint_output handles all-zero labels", { + output <- torch_tensor(matrix(1:4, nrow = 2, ncol = 2)) + labels <- torch_zeros(c(2, 2), dtype = torch_bool()) + ancestor <- torch_eye(2) + result <- max_constraint_output(output, labels, ancestor) + # With all false labels, result equals constr_output. With identity ancestor, constr_output == output + expect_equal_to_r(result, matrix(1:4, nrow = 2, ncol = 2)) +}) + +test_that("max_constraint_output handles all-one labels", { + output <- torch_tensor(matrix(c(1, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE)) + labels <- torch_ones(c(2, 2), dtype = torch_bool()) + ancestor <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE)) + result <- max_constraint_output(output, labels, ancestor) + expect_tensor_shape(result, c(2, 2)) + # When all labels are TRUE, (~labels) is 0, so result = train_output. + expected <- matrix(c(5, 5, 3, 2), nrow = 2, ncol = 2, byrow = TRUE) + expect_equal_to_r(result, expected) +}) + +test_that("max_constraint_output preserves output dtype", { + output_f32 <- torch_tensor(matrix(1:4, nrow = 2), dtype = torch_float32()) + output_f64 <- torch_tensor(matrix(1:4, nrow = 2), dtype = torch_float64()) + labels <- torch_ones(c(2, 2), dtype = torch_bool()) + ancestor <- torch_eye(2) + expect_tensor_dtype(max_constraint_output(output_f32, labels, ancestor), torch_float64()) + expect_tensor_dtype(max_constraint_output(output_f64, labels, ancestor), torch_float64()) +}) + + +test_that("max_constraint_output works with complex hierarchy", { + output <- torch_tensor(matrix(c(1, 2, 3, 4, 5, 6), nrow = 2, ncol = 3, byrow = TRUE)) + labels <- torch_tensor(matrix(c(1, 0, 0, 0, 1, 0), nrow = 2, ncol = 3, byrow = TRUE), dtype = torch_bool()) + ancestor <- torch_triu(torch_ones(c(3,3))) + result <- max_constraint_output(output, labels, ancestor) + expect_tensor_shape(result, c(2, 3)) + # Row 1: label on col 1 -> train_output[1,1]=1, others get constr_output=3 + # Row 2: label on col 2 -> train_output[2,2]=5, others get constr_output=6 + expected <- matrix(c(1, 3, 3, + 6, 5, 6), nrow = 2, ncol = 3, byrow = TRUE) + expect_equal_to_r(result, expected) +}) + +test_that("max_constraint_output handles single element tensors", { + output <- torch_tensor(matrix(5, nrow = 1, ncol = 1)) + labels <- torch_tensor(matrix(TRUE, nrow = 1, ncol = 1), dtype = torch_bool()) + ancestor <- torch_tensor(matrix(1, nrow = 1, ncol = 1)) + result <- max_constraint_output(output, labels, ancestor) + expect_tensor_shape(result, c(1, 1)) + # Compare against 1x1 matrix instead of scalar to match torch array output + expect_equal_to_r(result, matrix(5, nrow = 1, ncol = 1)) +}) + +test_that("max_constraint_output throws error for dimension mismatch", { + output <- torch_tensor(matrix(1:4, nrow = 2, ncol = 2)) + labels <- torch_ones(c(2, 3), dtype = torch_bool()) + ancestor <- torch_eye(2) + expect_error(max_constraint_output(output, labels, ancestor), "dimension") +}) + +test_that("max_constraint_output handles float labels without error", { + # torch_logical_not works on float tensors (0.0 -> TRUE, others -> FALSE) + # No explicit type check exists in the function, so it should run successfully + output <- torch_tensor(matrix(1:4, nrow = 2, ncol = 2)) + labels <- torch_ones(c(2, 2), dtype = torch_float32()) + ancestor <- torch_eye(2) + expect_silent(max_constraint_output(output, labels, ancestor)) + result <- max_constraint_output(output, labels, ancestor) + expect_tensor_shape(result, c(2, 2)) +}) + +test_that("get_constr_output and max_constraint_output compose correctly", { + output <- torch_tensor(matrix(c(1, 4, 2, 3), nrow = 2, ncol = 2, byrow = TRUE)) + labels <- torch_tensor(matrix(c(TRUE, FALSE, TRUE, FALSE), nrow = 2, ncol = 2, byrow = TRUE), dtype = torch_bool()) + ancestor <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE)) + direct <- max_constraint_output(output, labels, ancestor) + constr_out <- get_constr_output(output, ancestor) + train_out <- get_constr_output(labels * output, ancestor) + manual <- torch_logical_not(labels) * constr_out + labels * train_out + expect_equal_to_r(direct, as_array(manual)) +}) + +test_that("get_constr_output handles negative values correctly", { + x <- torch_tensor(matrix(c(-5, -1, -3, -2), nrow = 2, ncol = 2, byrow = TRUE)) + R <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE)) + result <- get_constr_output(x, R) + expected <- matrix(c(-1, 0, -2, 0), nrow = 2, ncol = 2, byrow = TRUE) + expect_equal_to_r(result, expected) +}) + +test_that("max_constraint_output handles mixed positive-negative with constraints", { + output <- torch_tensor(matrix(c(-5, -3, -1, 4), nrow = 2, ncol = 2, byrow = TRUE)) + labels <- torch_tensor(matrix(c(TRUE, TRUE, FALSE, TRUE), nrow = 2, ncol = 2, byrow = TRUE), dtype = torch_bool()) + ancestor <- torch_tensor(matrix(c(1, 1, 0, 1), nrow = 2, ncol = 2, byrow = TRUE)) + result <- max_constraint_output(output, labels, ancestor) + expected <- matrix(c(-3, 0, 4, 4), nrow = 2, ncol = 2, byrow = TRUE) + expect_equal_to_r(result, expected) +}) + +test_that("build_ancestor_matrix handles basic hierarchy", { + # Tree: Root -> A -> C + # Root -> B -> D + # Edges: R->A, A->C, R->B, B->D + # Pruning Logic: + # 1. Remove Root: Keeps A->C, B->D + # 2. Remove leaves (C, D are not in 'from'): Keeps A->C? No. C is not a parent. + # Keeps B->D? No. D is not a parent. + # Result: No edges match criteria. Empty matrix. + + paths <- c("Root/A", "Root/A/C", "Root/B", "Root/B/D") + tree <- create_test_tree(paths) + + result <- build_ancestor_matrix(tree) + + # Expectation: No internal nodes exist that are also children (excluding Root) + # A and B are children of Root, but their children (C, D) are leaves. + # Thus A and B are effectively leaves in the "internal structure". + expect_equal(nrow(result), 0) +}) + +test_that("build_ancestor_matrix handles linear chain of internal nodes", { + # Tree: Root -> A -> B -> C + # Edges: R->A, A->B, B->C + # Pruning Logic: + # 1. Remove Root: Keeps A->B, B->C + # 2. Keep only if target is a parent: + # - A->B: B is a parent (of C). Keep. + # - B->C: C is a leaf. Drop. + # Remaining Edges: A -> B + # Nodes: A(1), B(2) + # Matrix: A->A, A->B, B->B + + paths <- c("Root/A", "Root/A/B", "Root/A/B/C") + tree <- create_test_tree(paths) + + result <- build_ancestor_matrix(tree) + + expected <- matrix(c( + 1, 1, # A -> A + 1, 2, # A -> B + 2, 2 # B -> B + ), ncol = 2, byrow = TRUE) + + expect_equal(result, expected) +}) + +test_that("build_ancestor_matrix calculates transitive closure correctly", { + # Tree: Root -> A -> B -> C -> D + # Edges: R->A, A->B, B->C, C->D + # Pruning Logic: + # 1. Remove Root: A->B, B->C, C->D + # 2. Keep if target is parent: + # - A->B: B is parent (of C). Keep. + # - B->C: C is parent (of D). Keep. + # - C->D: D is leaf. Drop. + # Remaining Edges: A -> B, B -> C + # Nodes: A(1), B(2), C(3) + + paths <- c("Root/A", "Root/A/B", "Root/A/B/C", "Root/A/B/C/D") + tree <- create_test_tree(paths) + + result <- build_ancestor_matrix(tree) + + # Expected Relations: + # A -> A, A -> B, A -> C + # B -> B, B -> C + # C -> C + + # Sorted by column then row (default behavior of which(arr.ind=TRUE)) + expected <- matrix(c( + 1, 1, # A->A + 1, 2, # A->B + 2, 2, # B->B + 1, 3, # A->C (transitive) + 2, 3, # B->C + 3, 3 # C->C + ), ncol = 2, byrow = TRUE) + + expect_equal(result, expected) +}) + +test_that("build_ancestor_matrix handles branching internal nodes", { + # Tree: R -> A -> C + # R -> B -> C (Diamond shape, merging back to C) + # *Note: data.tree allows this structure (multiple parents)? + # Actually standard trees are single parent. Let's stick to standard tree. + + # Tree: R -> A -> C -> E + # R -> B -> D -> E + # Edges: R->A, A->C, C->E, R->B, B->D, D->E + # Pruning: + # 1. Remove R: A->C, C->E, B->D, D->E + # 2. Keep target if parent: + # - A->C (C is parent of E). Keep. + # - C->E (E is leaf). Drop. + # - B->D (D is parent of E). Keep. + # - D->E (E is leaf). Drop. + # Nodes: A, C, B, D + # Edges: A->C, B->D + + paths <- c("Root/A", "Root/A/C", "Root/A/C/E", + "Root/B", "Root/B/D", "Root/B/D/E") + tree <- create_test_tree(paths) + + result <- build_ancestor_matrix(tree) + + # We have two disconnected components in the adjacency matrix: (A,C) and (B,D) + # A(1), C(2), B(3), D(4) (Order depends on unique(c(edges$from, edges$to))) + # Edges order: A->C, B->D. + # Unique nodes: A, C, B, D. + + # Expected: Self loops + A->C, B->D + # (1,1), (1,2), (2,2), (3,3), (3,4), (4,4) + + expected <- matrix(c( + 1, 1, # A->A + 1, 2, # A->C + 2, 2, # C->C + 3, 3, # B->B + 3, 4, # B->D + 4, 4 # D->D + ), ncol = 2, byrow = TRUE) + + expect_equal(result, expected) +}) + +test_that("build_ancestor_matrix returns empty for Root-only tree", { + tree <- Node$new("Root") + result <- build_ancestor_matrix(tree) + expect_equal(nrow(result), 0) +}) + +test_that("build_ancestor_matrix returns empty for Root + Leaf", { + # Tree: Root -> A + # Edges: R->A + # 1. Remove R (from!=Root): Result empty. + tree <- Node$new("Root") + tree$AddChild("A") + result <- build_ancestor_matrix(tree) + expect_equal(nrow(result), 0) +}) + +test_that("build_ancestor_matrix handles deep wide tree (Stress Test)", { + # Create a binary tree of depth 4 + tree <- Node$new("R") + add_children <- function(node, depth) { + if (depth == 0) return() + node$AddChild(paste0(node.name, "L")) + node$AddChild(paste0(node.name, "R")) + add_children(node$children[[1]], depth - 1) + add_children(node$children[[2]], depth - 1) + } + add_children(tree, 4) + + result <- build_ancestor_matrix(tree) + + # Validate structure without checking exact numbers (too complex for hardcoded) + # 1. Must be integer matrix + expect_type(result, "integer") + # 2. Must have 2 columns + expect_equal(ncol(result), 2) + # 3. First column (Ancestor) <= Second column (Descendant) implies topological sort check? + # Actually indices are arbitrary based on sorting, but relationships are directional. + # Just ensure no NA or Inf + expect_true(!any(is.na(result))) +}) + +test_that("build_ancestor_matrix preserves integer type", { + paths <- c("Root/A", "Root/A/B") + tree <- create_test_tree(paths) + result <- build_ancestor_matrix(tree) + expect_type(result, "integer") +}) +test_that("node_to_df works ", { + expect_no_error( + node_to_df(acme) + ) + expect_no_error( + attrition_df <- node_to_df(attrition_tree) + ) + # node_to_df removes first and last level of the hierarchy + outcome_levels <- paste0("level_", seq(2, attrition_tree$height - 1)) + expect_equal(names(attrition_df$y), outcome_levels) + + # node_to_df do not shuffle outcome rows + df <- tibble(pred_1 = seq(1,26), pred_2 = seq(26,1), + level_2 = factor(LETTERS[1:26]), level_3 = factor(letters[26:1])) + df_node_df <- df %>% + mutate(pathString = paste("synth", level_2, level_3, level_3, sep = "/")) %>% + select(-level_2, -level_3) %>% + as.Node() %>% + node_to_df() + + expect_equal(df_node_df$y %>% as_tibble(), df %>% select(starts_with("level_"))) + expect_equal(df_node_df$x %>% as_tibble(), df %>% select(starts_with("pred_"))) + +}) diff --git a/tests/testthat/test-explain.R b/tests/testthat/test-model_explain.R similarity index 98% rename from tests/testthat/test-explain.R rename to tests/testthat/test-model_explain.R index 549afed4..db0ace3b 100644 --- a/tests/testthat/test-explain.R +++ b/tests/testthat/test-model_explain.R @@ -52,7 +52,7 @@ test_that("explain works for dataframe, formula and recipe", { # formula - tabnet_pretrain <- tabnet_pretrain(Sale_Price ~., data=small_ames, epochs = 3, valid_split=.2, + tabnet_pretrain <- tabnet_pretrain(Sale_Price ~., data=small_ames, epochs = 3, valid_split=0.2, num_steps = 1, attention_width = 1, num_shared = 1, num_independent = 1) expect_no_error( tabnet_explain(tabnet_pretrain, new_data=small_ames) @@ -69,7 +69,7 @@ test_that("explain works for dataframe, formula and recipe", { step_zv(all_predictors()) %>% step_normalize(all_numeric_predictors()) - tabnet_pretrain <- tabnet_pretrain(rec, data=small_ames, epochs = 3, valid_split=.2, + tabnet_pretrain <- tabnet_pretrain(rec, data=small_ames, epochs = 3, valid_split=0.2, num_steps = 1, attention_width = 1, num_shared = 1, num_independent = 1) expect_no_error( tabnet_explain(tabnet_pretrain, new_data=small_ames) diff --git a/tests/testthat/test-pretraining.R b/tests/testthat/test-model_pretraining.R similarity index 100% rename from tests/testthat/test-pretraining.R rename to tests/testthat/test-model_pretraining.R diff --git a/tests/testthat/test-model.R b/tests/testthat/test-model_training.R similarity index 100% rename from tests/testthat/test-model.R rename to tests/testthat/test-model_training.R diff --git a/tests/testthat/test-parsnip.R b/tests/testthat/test-parsnip.R index 742281e8..f322bcbe 100644 --- a/tests/testthat/test-parsnip.R +++ b/tests/testthat/test-parsnip.R @@ -98,6 +98,7 @@ test_that("Check we can finalize a workflow from a tune_grid", { model <- tabnet(epochs = tune(), checkpoint_epochs = 1) %>% parsnip::set_mode("regression") %>% + parsnip::set_args(epochs = 2) %>% parsnip::set_engine("torch") wf <- workflows::workflow() %>% @@ -134,7 +135,7 @@ test_that("tabnet grid reduction - torch", { expect_equal(reg_grid_smol$epochs, rep(3, 2)) expect_equal(reg_grid_smol$penalty, 1:2) - for (i in 1:nrow(reg_grid_smol)) { + for (i in seq_len(nrow(reg_grid_smol))) { expect_equal(reg_grid_smol$.submodels[[i]], list(epochs = 1:2)) } @@ -155,7 +156,7 @@ test_that("tabnet grid reduction - torch", { expect_equal(reg_grid_extra_smol$epochs, rep(3, 6)) expect_equal(reg_grid_extra_smol$penalty, rep(1:2, each = 3)) expect_equal(reg_grid_extra_smol$batch_size, rep(10:12, 2)) - for (i in 1:nrow(reg_grid_extra_smol)) { + for (i in seq_len(nrow(reg_grid_extra_smol))) { expect_equal(reg_grid_extra_smol$.submodels[[i]], list(epochs = 1:2)) } @@ -172,7 +173,7 @@ test_that("tabnet grid reduction - torch", { expect_equal(no_sub_smol$epochs, rep(1, 2)) expect_equal(no_sub_smol$penalty, 1:2) - for (i in 1:nrow(no_sub_smol)) { + for (i in seq_len(nrow(no_sub_smol))) { expect_length(no_sub_smol$.submodels[[i]], 0) } @@ -184,7 +185,7 @@ test_that("tabnet grid reduction - torch", { expect_equal(reg_grid_smol$Amos, rep(3, 2)) expect_equal(reg_grid_smol$penalty, 1:2) - for (i in 1:nrow(reg_grid_smol)) { + for (i in seq_len(nrow(reg_grid_smol))) { expect_equal(reg_grid_smol$.submodels[[i]], list(Amos = 1:2)) } @@ -202,7 +203,7 @@ test_that("tabnet grid reduction - torch", { expect_equal(reg_grid_smol$`Ade Tukunbo`, rep(3, 4)) expect_equal(reg_grid_smol$penalty, rep(1:2, each = 2)) expect_equal(reg_grid_smol$` \t123`, rep(10:11, 2)) - for (i in 1:nrow(reg_grid_smol)) { + for (i in seq_len(nrow(reg_grid_smol))) { expect_equal(reg_grid_smol$.submodels[[i]], list(`Ade Tukunbo` = 1:2)) } }) diff --git a/tests/testthat/test_translations.R b/tests/testthat/test-translations.R similarity index 100% rename from tests/testthat/test_translations.R rename to tests/testthat/test-translations.R diff --git a/vignettes/Hierarchical_classification.Rmd b/vignettes/Hierarchical_classification.Rmd index ab7bee93..e4f804d9 100644 --- a/vignettes/Hierarchical_classification.Rmd +++ b/vignettes/Hierarchical_classification.Rmd @@ -25,20 +25,32 @@ library(tibble) set.seed(202307) ``` -## Data preparation +## Data format The supported data format for hierarchical classification is the `Node` object format from package `{data.tree}`. This is a general purpose format that fits generic hierarchical tree encoding needs. Each node of the tree is associated with predictor values through the `attributes` in the data `Node` object. - - A very basic example is the `acme` dataset to show you how the two predictors values `cost` and `p` are associates attributes of each node in the hierarchy : +|{tabnet} concept| {data.tree} concept |see command| +|---|---|---| +|dataset predictor| Node `attributesAll` | acme example | +|dataset multi-label target| Node hierarchy | print(acme) | + + +A very basic example is the `acme` dataset to show you how the two predictors values `cost` and `p` are associates attributes of each node in the hierarchy : ```{r} data(acme, package = "data.tree") acme$attributesAll print(acme, "cost", "p" , limit = 8) ``` +So printing Node objects reverse the usual ordering, as target is printed first in column `levelName`, and predictors printed right of it. + +As you can see, only leaf nodes of the tree gets predictors value. {tabnet} will take this into account via an `ancestor` square sparse tensor registering all possible parent-child relation among the target labels. + + +## Data preparation -- Multiple manual or programmatic methods are available to create or update predictors. They are detailled in the `vignette("data.tree", package = "data.tree")`. +Multiple manual or programmatic methods are available to create or update predictors. They are detailled in the `vignette("data.tree", package = "data.tree")`. - a lot of native hierarchical data-format conversion from files to `Node` are covered by the`{data.tree}` package. You can find them in the "Create tree from a file" section of the same vignette. If needed, the `{ape}` package covers a lot of conversion format to the `philo` format. Thus you can reach the `Node` format in maybe two transformation steps... @@ -74,13 +86,13 @@ As `as.Node()` will only consider the as.numeric() values of a factor(), you sho Your dataset hierarchy will be turn internally into multi-outcomes named `level_1` to `level_n`, n beeing the depth of your tree. Thus column names starting with `level_` should be avoided. -### Ensure the last hierarchy of the tree is the observation id +### Ensure the last hierarchy of the tree is the **observation id** The tree only keeps a single row of attributes per tree leaf. Thus in order to transfer your complete predictors dataset into the Node object, you must keep the last level of the hierarchy to be a unique observation identifier (last resort beeing `rowid_to_column()` to achieve it). The classification will be done **removing the last level of hierarchy** in any case. -### Ensure there is a root level in the hierarchy +### Ensure there is a **root level** in the hierarchy The tree should have a single root for all nodes to be consistent. Thus you have to use a constant prefix to all `pathString`.