From fa49b452c6a4ab4ac46b1c0767187c2dbe445ddc Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Thu, 29 Jan 2026 17:51:44 +0000 Subject: [PATCH] update to air v0.8.1 --- R/initial_validation_split.R | 8 ++++++-- R/make_groups.R | 4 +++- R/mc.R | 8 ++++++-- R/tidy.R | 4 +++- tests/testthat/test-reshuffle_rset.R | 6 ++++-- tests/testthat/test-rolling_origin.R | 18 ++++++++++++------ 6 files changed, 34 insertions(+), 14 deletions(-) diff --git a/R/initial_validation_split.R b/R/initial_validation_split.R index 6c24579b..e912d667 100644 --- a/R/initial_validation_split.R +++ b/R/initial_validation_split.R @@ -107,7 +107,9 @@ initial_validation_split <- function( ) # include those so that they can be attached to the `rset` later in `validation_set()` - if (!is.null(strata)) names(strata) <- NULL + if (!is.null(strata)) { + names(strata) <- NULL + } val_att <- list( prop = prop, strata = strata, @@ -264,7 +266,9 @@ group_initial_validation_split <- function( ) # include those so that they can be attached to the `rset` later in `validation_set()` - if (!is.null(strata)) names(strata) <- NULL + if (!is.null(strata)) { + names(strata) <- NULL + } val_att <- list( group = group, prop = prop, diff --git a/R/make_groups.R b/R/make_groups.R index 8ebeaa7d..54e36ad0 100644 --- a/R/make_groups.R +++ b/R/make_groups.R @@ -274,7 +274,9 @@ balance_prop_helper <- function(prop, data_ind, v, replace) { # if we somehow got the smallest group every time. # If sampling without replacement, just reshuffle all the groups. n <- nrow(freq_table) - if (replace) n <- n * prop * sum(freq_table$count) / min(freq_table$count) + if (replace) { + n <- n * prop * sum(freq_table$count) / min(freq_table$count) + } n <- ceiling(n) purrr::map( diff --git a/R/mc.R b/R/mc.R index b8b2c223..7971be8c 100644 --- a/R/mc.R +++ b/R/mc.R @@ -83,7 +83,9 @@ mc_cv <- function( split_objs$splits <- map(split_objs$splits, rm_out) - if (!is.null(strata)) names(strata) <- NULL + if (!is.null(strata)) { + names(strata) <- NULL + } mc_att <- list( prop = prop, times = times, @@ -206,7 +208,9 @@ group_mc_cv <- function( ) # This is needed for printing checks; strata can't be missing for mc_cv - if (is.null(strata)) strata <- FALSE + if (is.null(strata)) { + strata <- FALSE + } ## We remove the holdout indices since it will save space and we can ## derive them later when they are needed. diff --git a/R/tidy.R b/R/tidy.R index aefa9f9b..fc1bffc9 100644 --- a/R/tidy.R +++ b/R/tidy.R @@ -59,7 +59,9 @@ #' @export tidy.rsplit <- function(x, unique_ind = TRUE, ...) { check_dots_empty() - if (unique_ind) x$in_id <- unique(x$in_id) + if (unique_ind) { + x$in_id <- unique(x$in_id) + } out <- tibble( Row = c(x$in_id, complement(x)), Data = rep( diff --git a/tests/testthat/test-reshuffle_rset.R b/tests/testthat/test-reshuffle_rset.R index b54e3008..282949d4 100644 --- a/tests/testthat/test-reshuffle_rset.R +++ b/tests/testthat/test-reshuffle_rset.R @@ -31,8 +31,9 @@ test_that("reshuffle_rset is working", { # Select any non-grouped function in rset_subclasses with a strata argument: supports_strata <- purrr::map_lgl( names(supported_subclasses), - \(.x) + \(.x) { any(names(formals(.x)) == "strata") && !any(names(formals(.x)) == "group") + } ) supports_strata <- names(supported_subclasses)[supports_strata] @@ -57,8 +58,9 @@ test_that("reshuffle_rset is working", { # Select any grouped function in rset_subclasses with a strata argument: grouped_strata <- purrr::map_lgl( names(supported_subclasses), - \(.x) + \(.x) { any(names(formals(.x)) == "strata") && any(names(formals(.x)) == "group") + } ) grouped_strata <- names(supported_subclasses)[grouped_strata] diff --git a/tests/testthat/test-rolling_origin.R b/tests/testthat/test-rolling_origin.R index 3e847787..6a1ca963 100644 --- a/tests/testthat/test-rolling_origin.R +++ b/tests/testthat/test-rolling_origin.R @@ -36,8 +36,10 @@ test_that("larger holdout", { ) expect_equal( rs2$splits[[i]]$out_id, - (i + attr(rs2, "initial")): # fmt: skip - (i + attr(rs2, "initial") + attr(rs2, "assess") - 1) # fmt: skip + # fmt: skip + (i + attr(rs2, "initial")): + (i + attr(rs2, "initial") + attr(rs2, "assess") - 1) + # fmt: skip ) } }) @@ -72,8 +74,10 @@ test_that("skipping", { for (i in 1:nrow(rs4)) { expect_equal( rs4$splits[[i]]$in_id, - (i + attr(rs4, "skip") * (i - 1)): # fmt: skip - (i + attr(rs4, "skip") * (i - 1) + attr(rs4, "initial") - 1) # fmt: skip + # fmt: skip + (i + attr(rs4, "skip") * (i - 1)): + (i + attr(rs4, "skip") * (i - 1) + attr(rs4, "initial") - 1) + # fmt: skip ) expect_equal( rs4$splits[[i]]$out_id, @@ -103,8 +107,10 @@ test_that("lag", { ) expect_equal( rs5$splits[[i]]$out_id, - (i + attr(rs5, "initial") - attr(rs5, "lag")): # fmt: skip - (i + attr(rs5, "initial") + attr(rs5, "assess") - 1) # fmt: skip + # fmt: skip + (i + attr(rs5, "initial") - attr(rs5, "lag")): + (i + attr(rs5, "initial") + attr(rs5, "assess") - 1) + # fmt: skip ) }