Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@
^CRAN-SUBMISSION$
^revdep$
^vignettes/*_files$
^\.claude$
^\.positai$
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ tabnet_*.tar.gz
tabnet.Rproj
po/glossary.csv
inst/IMPORTLIST
.positai
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,5 @@ Config/testthat/parallel: false
Config/testthat/start-first: interface, explain, params
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.3
Language: en-US
Config/roxygen2/version: 8.0.0
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,8 @@ importFrom(stats,predict)
importFrom(stats,update)
importFrom(tidyr,replace_na)
importFrom(torch,nn_prune_head)
importFrom(torch,torch_int64)
importFrom(torch,torch_ones)
importFrom(torch,torch_sparse_coo_tensor)
importFrom(tune,min_grid)
importFrom(zeallot,"%<-%")
95 changes: 82 additions & 13 deletions R/hardhat.R
Original file line number Diff line number Diff line change
Expand Up @@ -162,16 +162,12 @@ tabnet_fit.Node <- function(x, tabnet_model = NULL, config = tabnet_config(), ..
# get tree leaves and extract attributes into data.frames
xy_df <- node_to_df(x)
processed <- hardhat::mold(xy_df$x, xy_df$y)
# Given n classes, M is an (n x n) matrix where M_ij = 1 if class i is descendant of class j
ancestor <- data.tree::ToDataFrameNetwork(x) %>%
mutate_if(is.character, ~.x %>% as.factor %>% as.numeric)
# TODO check correctness
# embed the M matrix in the config$ancestor variable
dims <- c(max(ancestor), max(ancestor))
ancestor_m <- Matrix::sparseMatrix(ancestor$from, ancestor$to, dims = dims, x = 1)
check_type(processed$outcomes)

ancestor_tt <- build_ancestor_matrix(x)

config <- merge_config_and_dots(config, ...)
config$ancestor <- ancestor_tt
tabnet_bridge(processed, config = config, tabnet_model, from_epoch, task = "supervised")
}

Expand Down Expand Up @@ -272,7 +268,7 @@ tabnet_pretrain.default <- function(x, ...) {

#' @export
#' @rdname tabnet_pretrain
tabnet_pretrain.data.frame <- function(x, y, tabnet_model = NULL, config = tabnet_config(), ..., from_epoch = NULL) {
tabnet_pretrain.data.frame <- function(x, y = NULL, tabnet_model = NULL, config = tabnet_config(), ..., from_epoch = NULL) {
processed <- hardhat::mold(x, y)

config <- merge_config_and_dots(config, ...)
Expand Down Expand Up @@ -309,8 +305,7 @@ tabnet_pretrain.Node <- function(x, tabnet_model = NULL, config = tabnet_config(
check_compliant_node(x)
# get tree leaves and extract attributes into data.frames
xy_df <- node_to_df(x)
tabnet_pretrain(xy_df$x, xy_df$y, tabnet_model = tabnet_model, config = config, ..., from_epoch = from_epoch)

tabnet_pretrain(xy_df$x, tabnet_model = tabnet_model, config = config, ..., from_epoch = from_epoch)
}

new_tabnet_pretrain <- function(pretrain, blueprint) {
Expand Down Expand Up @@ -436,8 +431,8 @@ predict_tabnet_bridge <- function(type, object, predictors, epoch, batch_size) {
type <- check_type(object$blueprint$ptypes$outcomes, type)
is_multi_outcome <- ncol(object$blueprint$ptypes$outcomes) > 1
outcome_nlevels <- NULL
if (is_multi_outcome & type != "numeric") {
outcome_nlevels <- purrr::map_dbl(object$blueprint$ptypes$outcomes, ~length(levels(.x)))
if (is_multi_outcome && type != "numeric") {
outcome_nlevels <- purrr::map_dbl(object$blueprint$ptypes$outcomes, ~nlevels(.x))
}

if (!is.null(epoch)) {
Expand Down Expand Up @@ -604,4 +599,78 @@ nn_prune_head.tabnet_pretrain <- function(x, head_size) {
nn_prune_head(x$fit$network, head_size=head_size)
}

}
}

#' Build a sparse ancestor-descendant matrix from a hierarchy edge list
#'
#' Given a directed graph where edges point from parent to child, computes
#' the full transitive closure by repeatedly squaring the reachability
#' matrix, then emits the result as a sparse COO tensor R such that
#' R\[i, j\] = 1 whenever class j is a descendant of class i (including
#' i itself). This is the orientation expected by \code{get_constr_output}
#' and the max-constraint-margin (MCM) loss.
#'
#' @param x a Node object.
#' @return A \code{torch_sparse_coo_tensor} of shape
#'   \code{(n_classes, n_classes)} with boolean values. Entry
#'   \code{R[i, j] = TRUE} means class \code{j} is a descendant of
#'   class \code{i} (self included).
#'
#' @importFrom torch torch_ones torch_int64 torch_sparse_coo_tensor
#' @noRd
build_ancestor_matrix <- function(x) {
  # 1. Extract the edge list: one row per direct parent -> child link
  edges <- data.tree::ToDataFrameNetwork(x)

  # 2. Prune the tree from the root and from the leaves: keep only edges
  #    between internal nodes.
  # NOTE(review): `x$path` is a vector of path components -- confirm the
  # comparison against `edges$from` matches the intended root filtering.
  non_root_edges <- edges$from != x$path
  non_leaf_targets <- edges$to %in% unique(edges$from)
  edges <- edges[non_root_edges & non_leaf_targets, ]

  # 3. Map node names to integer indices
  all_nodes <- unique(c(edges$from, edges$to))
  n <- length(all_nodes)
  # Degenerate hierarchy: no internal edge survives the pruning. Return an
  # empty sparse tensor so callers always receive the documented type.
  if (n == 0) {
    return(torch::torch_sparse_coo_tensor(
      torch::torch_zeros(c(2, 0), dtype = torch::torch_int64()),
      torch::torch_zeros(0, dtype = torch::torch_bool()),
      c(0, 0)
    ))
  }

  # Create a lookup map: name -> index
  node_map <- setNames(seq_along(all_nodes), all_nodes)

  # Conversion of edges to integer indices
  from_idx <- node_map[edges$from]
  to_idx <- node_map[edges$to]

  # 4. Build Adjacency Matrix
  # adj_mat[i, j] = 1 means i is a direct parent of j
  adj_mat <- matrix(0L, nrow = n, ncol = n)
  adj_mat[cbind(from_idx, to_idx)] <- 1L

  # 5. Compute Transitive Closure (Ancestors)
  # Start from direct connections plus self-loops, kept as a *logical*
  # matrix throughout so that identical() can detect convergence.
  # (Mixing logical and numeric matrices here would make identical()
  # always FALSE and the loop would never terminate.)
  reachability <- (adj_mat + diag(n)) > 0

  repeat {
    # Squaring the reachability matrix doubles the reachable path length,
    # so convergence takes O(log(diameter)) iterations.
    # (logical matrices are coerced to numeric by %*%)
    next_reachability <- (reachability %*% reachability) > 0

    # Converged: no new ancestor/descendant pair was discovered
    if (identical(next_reachability, reachability)) {
      break
    }

    reachability <- next_reachability
  }

  # 6. Extract indices (COO format)
  # which(arr.ind = TRUE) returns a matrix where col 1 is the row index
  # (ancestor) and col 2 the column index (descendant).
  idx_mat <- which(reachability, arr.ind = TRUE)

  # NOTE(review): `which()` yields 1-based indices; confirm this matches the
  # index convention `torch_sparse_coo_tensor()` expects in R torch.
  torch::torch_sparse_coo_tensor(t(idx_mat), rep(TRUE, nrow(idx_mat)), c(n, n))
}
File renamed without changes.
4 changes: 2 additions & 2 deletions R/pretraining.R → R/model_pretraining.R
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,9 @@ tabnet_train_unsupervised <- function(x, config = tabnet_config(), epoch_shift =
metrics[[epoch]][["valid"]] <- transpose_metrics(valid_metrics)$loss
}

if (config$verbose & !has_valid)
if (config$verbose && !has_valid)
message(gettextf("[Epoch %03d] Loss: %3f", epoch, mean(metrics[[epoch]]$train)))
if (config$verbose & has_valid)
if (config$verbose && has_valid)
message(gettextf("[Epoch %03d] Loss: %3f, Valid loss: %3f", epoch, mean(metrics[[epoch]]$train), mean(metrics[[epoch]]$valid)))

# Early-stopping checks
Expand Down
32 changes: 18 additions & 14 deletions R/model.R → R/model_training.R
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ tabnet_config <- function(batch_size = 1024^2,
early_stopping_tolerance = 0,
early_stopping_patience = 0L,
num_workers=0L,
skip_importance = FALSE) {
skip_importance = FALSE
) {
if (is.null(decision_width) && is.null(attention_width)) {
decision_width <- 8 # default is 8
}
Expand Down Expand Up @@ -228,16 +229,16 @@ tabnet_config <- function(batch_size = 1024^2,

get_constr_output <- function(x, R) {
  # MCM (max-constraint module) of the prediction given the hierarchy
  # constraint expressed in the sparse matrix R.
  # x: (batch, n_classes) predictions; R: (n_classes, n_classes) constraint.
  # Returns, for each class, the max prediction over its constrained set.
  # Cast to double and broadcast both operands to (batch, n, n) before the
  # element-wise product; the max over dim 3 applies the constraint.
  c_out <- x$to(dtype = torch::torch_double())$unsqueeze(2)$expand(c(x$shape[1], R$shape[2], R$shape[2]))
  R_batch <- R$unsqueeze(1)$expand(c(x$shape[1], R$shape[2], R$shape[2]))
  final_out <- torch::torch_max(R_batch * c_out, dim = 3)
  # torch_max(dim = ...) returns (values, indices); keep the values only
  final_out[[1]]
}

max_constraint_output <- function(output, labels, ancestor) {
  # MCM loss input: constrain the raw predictions and the label-masked
  # predictions by the hierarchy, then combine so positive labels take the
  # masked path and negative labels take the unmasked one.
  constr_output <- get_constr_output(output, ancestor)
  train_output <- get_constr_output(labels * output, ancestor)
  # torch_logical_not() works for any label dtype, unlike $bitwise_not()
  torch::torch_logical_not(labels) * constr_output + labels * train_output
}

resolve_loss <- function(config, dtype) {
Expand All @@ -249,7 +250,7 @@ resolve_loss <- function(config, dtype) {
loss_fn <- loss
else if (loss %in% c("mse", "auto") && !dtype == torch::torch_long())
loss_fn <- torch::nn_mse_loss()
else if ((loss %in% c("bce", "cross_entropy", "auto") && dtype == torch::torch_long()) || !is.null(config$ancestor_tt))
else if ((loss %in% c("bce", "cross_entropy", "auto") && dtype == torch::torch_long()) || !is.null(config$ancestor))
# cross entropy loss is required
loss_fn <- torch::nn_cross_entropy_loss()
else
Expand All @@ -270,22 +271,22 @@ resolve_early_stop_monitor <- function(early_stopping_monitor, valid_split) {
}

train_batch <- function(network, optimizer, batch, config) {
# NULLing values to avoid a R-CMD Check Note "No visible binding for global variable"
# NULL-ing values to avoid a R-CMD Check Note "No visible binding for global variable"
out <- M_loss <- NULL
# forward pass
c(out, M_loss) %<-% network(batch$x, batch$x_na_mask)
# if target is multi-outcome, loss has to be applied to each label-group
if (max(batch$output_dim$shape) > 1) {
# multi-outcome
outcome_nlevels <- as.numeric(batch$output_dim$to(device="cpu"))
if (!is.null(config$ancestor_tt)) {
if (!is.null(config$ancestor)) {
# hierarchical mandates use of `max_constraint_output`
loss <- torch::torch_sum(torch::torch_stack(purrr::pmap(
list(
torch::torch_split(out, outcome_nlevels, dim = 2),
torch::torch_split(batch$y, rep(1, length(outcome_nlevels)), dim = 2)
),
~config$loss_fn(max_constraint_output(.x, .y$squeeze(2), config$ancestor_tt))
~config$loss_fn(max_constraint_output(.x, .y, config$ancestor), .y$squeeze(2))
)),
dim = 1)
} else {
Expand Down Expand Up @@ -332,14 +333,14 @@ valid_batch <- function(network, batch, config) {
if (max(batch$output_dim$shape) > 1) {
# multi-outcome
outcome_nlevels <- as.numeric(batch$output_dim$to(device="cpu"))
if (!is.null(config$ancestor_tt)) {
if (!is.null(config$ancestor)) {
# hierarchical mandates use of `max_constraint_output`
loss <- torch::torch_sum(torch::torch_stack(purrr::pmap(
list(
torch::torch_split(out, outcome_nlevels, dim = 2),
torch::torch_split(batch$y, rep(1, length(outcome_nlevels)), dim = 2)
),
~config$loss_fn(max_constraint_output(.x, .y$squeeze(2), config$ancestor_tt))
~config$loss_fn(max_constraint_output(.x, .y$squeeze(2), config$ancestor))
)),
dim = 1)
} else {
Expand Down Expand Up @@ -513,7 +514,10 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s

# provide ancestor to torch tensor in case of hierarchical classification
if (!is.null(config$ancestor)) {
config$ancestor_tt <- torch::torch_tensor(config$ancestor)$to(torch::torch_bool(), device = device)
if (!config$ancestor$is_sparse()) {
# config is expected to carry the sparse tensor
runtime_error("ancestor was configured. Expecting a sparse tensor but got {.cls {class(config$ancestor)}}")
}
}

# instantiate optimizer
Expand Down Expand Up @@ -579,9 +583,9 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s
metrics[[epoch]][["valid"]] <- transpose_metrics(valid_metrics)$loss
}

if (config$verbose & !has_valid)
if (config$verbose && !has_valid)
message(gettextf("[Epoch %03d] Loss: %3f", epoch, mean(metrics[[epoch]]$train)))
if (config$verbose & has_valid)
if (config$verbose && has_valid)
message(gettextf("[Epoch %03d] Loss: %3f, Valid loss: %3f", epoch, mean(metrics[[epoch]]$train), mean(metrics[[epoch]]$valid)))


Expand Down Expand Up @@ -690,7 +694,7 @@ predict_impl_numeric <- function(obj, x, batch_size) {
predict_impl_numeric_multiple <- function(obj, x, batch_size) {
p <- as.matrix(predict_impl(obj, x, batch_size))
# TODO use a cleaner function to turn matrix into vectors
hardhat::spruce_numeric_multiple(!!!purrr::map(1:ncol(p), ~p[,.x]))
hardhat::spruce_numeric_multiple(!!!purrr::map(seq_len(ncol(p)), ~p[,.x]))
}

#' single-outcome level blueprint
Expand Down
2 changes: 1 addition & 1 deletion R/parsnip.R
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ multi_predict._tabnet_fit <- function(object, new_data, type = NULL, epochs = NU
pred <- predict(object$fit, new_data, type = type, epoch = epoch)
nms <- names(pred)
pred[["epochs"]] <- epoch
pred[[".row"]] <- 1:nrow(new_data)
pred[[".row"]] <- seq_len(nrow(new_data))
pred[, c(".row", "epochs", nms)]
})

Expand Down
2 changes: 1 addition & 1 deletion R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ autoplot.tabnet_fit <- function(object, ...) {

if ("checkpoint" %in% names(collect_metrics)) {
checkpoints <- collect_metrics %>%
dplyr::filter(checkpoint == TRUE, dataset == "train") %>%
dplyr::filter(checkpoint, dataset == "train") %>%
dplyr::select(-checkpoint) %>%
dplyr::mutate(size = 2)
p +
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion man/tabnet_config.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/tabnet_explain.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/tabnet_nn.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/tabnet_pretrain.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion po/R-fr.po
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ msgid ""
" Please change those names as they will lead to unexpected "
"tabnet behavior."
msgstr ""
"Les attributs ou noms de colonnes dans l’objet hiérarchique fournit utilise "
"Les `attributs` (noms de colonne) dans l’objet hiérarchique fournit utilisent "
"les noms réservés suivants : {.vars {actual_names[actual_names %in% "
"reserved_names]}}. Veuillez changer ces noms pour éviter un comportement "
"imprévisible de TabNet."
Expand Down
9 changes: 8 additions & 1 deletion tests/testthat/helper-tensor.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ expect_no_error <- function(object, ...) {

expect_tensor <- function(object) {
  # Assert that `object` is a torch tensor that can be materialised as an
  # R array on the CPU.
  expect_true(torch:::is_torch_tensor(object))
  # to_dense() first so that sparse COO tensors also round-trip through
  # as_array(), which only accepts dense tensors
  expect_no_error(torch::as_array(object$to_dense()$to(device = "cpu")))
}

expect_equal_to_r <- function(object, expected, ...) {
Expand All @@ -50,6 +50,13 @@ expect_tensor_shape <- function(object, expected) {
expect_equal(object$shape, expected)
}


# Assert that `object` is a torch tensor (convertible to an R array on the
# CPU) and that its dtype equals `expected_dtype`
# (e.g. torch::torch_double()).
expect_tensor_dtype <- function(object, expected_dtype) {
  expect_tensor(object)
  # dtype objects are compared with torch's `==` method
  expect_true(object$dtype == expected_dtype)
}


expect_undefined_tensor <- function(object) {
# TODO
}
Expand Down
Loading
Loading