Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,7 @@
^doc$
^Meta$
^.vscode$
^\.DS_Store$
^SomeFile\.diff$
^src/.*\.(o|so|a)$
^src/lib/.*\.(o|so|a)$
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ S3method(rpf,recipe)
S3method(str,rpf_forest)
export(is_purified)
export(predict_components)
export(preprocess_predictors_predict)
export(purify)
export(rpf)
import(checkmate)
Expand Down
7 changes: 6 additions & 1 deletion R/predict_components.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,12 @@ predict_components <- function(object, new_data, max_interaction = NULL, predict
}

# Check if forest is purified, if not we do that now
if (!is_purified(object)) purify(object)
if (!is_purified(object)) {
# Purify using default policy: mode=2 (fast exact),
# maxp_interaction=0 (uncapped),
# nthreads defaults to min(training nthreads, available cores)
object$fit$purify_threads(0L, 0L, 2L)
}

# If max_interaction is greater than number of predictors requested we need to adjust that
max_interaction <- min(max_interaction, length(predictors))
Expand Down
23 changes: 13 additions & 10 deletions R/predict_rpf.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,24 +108,27 @@ predict_rpf_prob <- function(object, new_data, ...) {
pred_prob <- 1 / (1 + exp(-pred_raw))
} else if (object$params$loss %in% c("L1", "L2")) {
# Truncate probabilities at [0,1] for L1/L2 loss
pred_prob <- apply(pred_raw, 2, function(col) pmax(0, pmin(1, col)))
pred_prob <- pmax(0, pmin(1, pred_raw))
}

# Binary classif yields n x 1 prediction matrix, append complementary class prob
# Ensure a plain numeric vector in binary case
pred_prob <- as.numeric(pred_prob)
# Binary classif yields two columns ordered by outcome levels
pred_prob <- cbind(1 - pred_prob, pred_prob)

} else { # Multiclass

if (object$params$loss %in% c("logit", "exponential")) {
# FIXME:
# softmax() defined in utils.R, should be identical to logit^-1 for
# binary case but not properly tested yet
# softmax for multi-class
pred_prob <- softmax(pred_raw)
} else if (object$params$loss %in% c("L1", "L2")) {
# Truncate probabilities at [0,1] for L1/L2 loss
pred_prob <- apply(pred_raw, 2, function(col) pmax(0, pmin(1, col)))
# Normalise such that sum of class probs is always 1
pred_prob <- pred_prob/rowSums(pred_prob)
# Clamp to [0,1] and renormalize rows
pred_prob <- pmin(1, pmax(0, pred_raw))
# pmin/pmax copy attributes from their first argument (a scalar here), so the
# matrix dims of pred_raw are lost; restore the matrix shape explicitly
dim(pred_prob) <- dim(pred_raw)
rs <- rowSums(pred_prob)
rs[!is.finite(rs) | rs <= 0] <- 1
pred_prob <- pred_prob / rs
}
}

Expand All @@ -140,7 +143,7 @@ predict_rpf_class <- function(object, new_data, ...) {
pred_prob <- predict_rpf_prob(object, new_data, 0, ...)

# For each instance, pick the class with the highest predicted probability
pred_class <- factor(outcome_levels[max.col(pred_prob)], levels = outcome_levels)
pred_class <- factor(outcome_levels[max.col(as.matrix(pred_prob))], levels = outcome_levels)
out <- hardhat::spruce_class(pred_class)

out
Expand Down
25 changes: 22 additions & 3 deletions R/purify.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#' Purify a Random Planted Forest
#'
#' TODO: Explain what this does
#' Purifies an rpf object.
#'
#' Unless [`rpf()`] is called with `purify = TRUE`, the forest has to be purified after fit
#' to ensure the components extracted by [`predict_components()`] are valid.
Expand Down Expand Up @@ -28,11 +28,28 @@ purify.default <- function(x, ...) {
)
}

#' @param maxp_interaction integer or NULL: Only compute/store purified components
#' up to this interaction order. Higher-order purified trees are zeroed (not
#' computed) but still implicitly influence lower orders during purification.
#' If NULL, purify all orders (default behavior).
#' @param mode integer(1): Purification algorithm mode. 1 = legacy grid path
#' used by `fit$fit$purify()`; 2 = fast exact KD-tree based path. Defaults to 2.
#' @param nthreads integer or NULL: Number of threads to use. If NULL, defaults
#' to the smaller of the object's configured `nthreads` and the number of
#' available threads.
#' @export
#' @rdname purify
#' @importFrom utils capture.output
purify.rpf <- function(x, ...) {
x$fit$purify()
purify.rpf <- function(x, ..., maxp_interaction = NULL, mode = 2L, nthreads = NULL) {
checkmate::assert_class(x, "rpf")
checkmate::assert_int(mode, lower = 1, upper = 2)
if (!is.null(nthreads)) checkmate::assert_int(nthreads, lower = 1)
if (is.null(maxp_interaction)) {
# Default: no cap on interaction order (0L = uncapped); `mode` is passed through as given
x$fit$purify_threads(0L, as.integer(if (is.null(nthreads)) 0L else nthreads), as.integer(mode))
} else {
checkmate::assert_int(maxp_interaction, lower = 1)
x$fit$purify_threads(as.integer(maxp_interaction), as.integer(if (is.null(nthreads)) 0L else nthreads), as.integer(mode))
}
x
}

Expand All @@ -43,3 +60,5 @@ is_purified <- function(x) {
checkmate::assert_class(x, "rpf")
x$fit$is_purified()
}


Loading
Loading