diff --git a/R/chains.R b/R/chains.R index 43d0f48..4c867e8 100644 --- a/R/chains.R +++ b/R/chains.R @@ -56,8 +56,9 @@ #' coerce the average acceptance rate to a target value using a dual-averaging #' algorithm, and adapting the shape to an estimate of the covariance of the #' target distribution. -#' @param show_progress_bar Whether to show progress bars during sampling. -#' Requires `progress` package to be installed to have an effect. +#' @param show_progress_bar Whether to show progress bars during sampling. If the +#' `progress` package is installed, displays a progress bar; otherwise prints +#' periodic progress messages to the console. #' @param trace_warm_up Whether to record chain traces and adaptation / #' transition statistics during (adaptive) warm-up iterations in addition to #' (non-adaptive) main chain iterations. @@ -103,8 +104,12 @@ sample_chain <- function( show_progress_bar = TRUE, trace_warm_up = FALSE ) { - progress_available <- requireNamespace("progress", quietly = TRUE) - use_progress_bar <- progress_available && show_progress_bar + progress_available <- is_progress_package_available() + if (show_progress_bar && !progress_available) { + message( + "progress package is not installed, so will print progress updates below." + ) + } initial_state <- check_and_process_initial_state(initial_state) target_distribution <- check_and_process_target_distribution( target_distribution @@ -118,7 +123,8 @@ sample_chain <- function( target_distribution = target_distribution, proposal = proposal, adapters = adapters, - use_progress_bar = use_progress_bar, + show_progress_bar = show_progress_bar, + progress_available = progress_available, record_traces_and_statistics = trace_warm_up, trace_function = trace_function, statistic_names = statistic_names @@ -130,7 +136,8 @@ sample_chain <- function( target_distribution = target_distribution, proposal = proposal, adapters = NULL, - use_progress_bar = use_progress_bar, + show_progress_bar = show_progress_bar, + progress_available = progress_available, record_traces_and_statistics = TRUE, trace_function = trace_function, statistic_names = statistic_names @@ -184,11 +191,28 @@ default_trace_function <- function(target_distribution) { } } -get_progress_bar <- function(use_progress_bar, n_iteration, label) { +is_package_available <- function(pkg) requireNamespace(pkg, quietly = TRUE) + +is_progress_package_available <- function() is_package_available("progress") + +print_fallback_progress <- function( + stage_name, chain_iteration, n_iteration, start_time +) { + elapsed <- proc.time()[["elapsed"]] - start_time + pct <- round(100 * chain_iteration / n_iteration) + message(sprintf( + "%s: %d%% done (%d/%d iterations) | elapsed: %.1fs", + stage_name, pct, chain_iteration, n_iteration, elapsed + )) +} + +get_progress_bar <- function( + show_progress_bar, progress_available, n_iteration, label +) { progress_bar_format <- ( "%s :percent |:bar| :current/:total [:elapsed<:eta] :tick_rate it/s" ) - if (use_progress_bar) { + if (show_progress_bar && progress_available) { progress::progress_bar$new( format = sprintf(progress_bar_format, label), total = n_iteration, @@ -237,19 +261,22 @@ finalize_adapters <- function(adapters, proposal) { invisible(adapters) } -chain_loop <- function( +chain_loop <- function( # nolint: cyclocomp_linter. styler: off stage_name, n_iteration, state, target_distribution, proposal, adapters, - use_progress_bar, + show_progress_bar, + progress_available, record_traces_and_statistics, trace_function, statistic_names ) { - progress_bar <- get_progress_bar(use_progress_bar, n_iteration, stage_name) + progress_bar <- get_progress_bar( + show_progress_bar, progress_available, n_iteration, stage_name + ) # Only show 10% increments in progress bar to avoid progress bar updates being # a bottleneck when chain iteration rate is high tick_amount <- max(n_iteration %/% 10, 1) @@ -261,6 +288,7 @@ chain_loop <- function( traces <- NULL statistics <- NULL } + start_time <- proc.time()[["elapsed"]] for (chain_iteration in seq_len(n_iteration)) { state_and_statistics <- sample_metropolis_hastings( state, target_distribution, proposal @@ -274,13 +302,23 @@ chain_loop <- function( c(state_and_statistics$statistics, adapter_states) ) } - if (!is.null(progress_bar) && (chain_iteration %% tick_amount == 0)) { + do_tick <- chain_iteration %% tick_amount == 0 + if (!is.null(progress_bar) && do_tick) { progress_bar$tick(tick_amount) + } else if (show_progress_bar && do_tick) { # fallback progress updates + print_fallback_progress( + stage_name, chain_iteration, n_iteration, start_time + ) } } # Ensure progress bar shows completed in cases tick_amount not a factor of # n_iteration - if (!is.null(progress_bar) && !progress_bar$finished) progress_bar$update(1) + progress_unfinished <- n_iteration > 0 && (n_iteration %% tick_amount != 0) + if (!is.null(progress_bar) && progress_unfinished) { + progress_bar$update(1) + } else if (show_progress_bar && progress_unfinished) { + print_fallback_progress(stage_name, n_iteration, n_iteration, start_time) + } finalize_adapters(adapters, proposal) list(final_state = state, traces = traces, statistics = statistics) } diff --git a/man/sample_chain.Rd b/man/sample_chain.Rd index 4e7c774..3a11f99 100644 --- a/man/sample_chain.Rd +++ b/man/sample_chain.Rd @@ -76,8 +76,9 @@ coerce the average acceptance rate to a target value using a dual-averaging algorithm, and adapting the shape to an estimate of the covariance of the target distribution.} -\item{show_progress_bar}{Whether to show progress bars during sampling. -Requires \code{progress} package to be installed to have an effect.} +\item{show_progress_bar}{Whether to show progress bars during sampling. If the +\code{progress} package is installed, displays a progress bar; otherwise prints +periodic progress messages to the console.} \item{trace_warm_up}{Whether to record chain traces and adaptation / transition statistics during (adaptive) warm-up iterations in addition to diff --git a/tests/testthat/test-chains.R b/tests/testthat/test-chains.R index 38c2809..94edc70 100644 --- a/tests/testthat/test-chains.R +++ b/tests/testthat/test-chains.R @@ -1,5 +1,5 @@ for (n_warm_up_iteration in c(0, 1, 10)) { - for (n_main_iteration in c(0, 1, 10)) { + for (n_main_iteration in c(0, 1, 10, 21)) { for (dimension in c(1, 2)) { for (trace_warm_up in c(TRUE, FALSE)) { for (show_progress_bar in c(TRUE, FALSE)) { @@ -113,3 +113,73 @@ test_that("Sample chains with invalid target_distribution raises error", { "target_distribution" ) }) + +make_fallback_test_inputs <- function() { + target_distribution <- standard_normal_target_distribution() + adapters <- list(scale_adapter("stochastic_approximation", initial_scale = 1.)) + withr::with_seed(default_seed(), { + position <- rnorm(2) + }) + list( + target_distribution = target_distribution, + adapters = adapters, + position = position + ) +} + +test_that("Manual progress fallback prints messages when progress unavailable", { + inputs <- make_fallback_test_inputs() + n_warm_up_iteration <- 10 + # use non-multiple of 10 to test finalisation of progress updates + n_main_iteration <- 21 + # Simulate progress package being unavailable by mocking + with_mocked_bindings( + is_progress_package_available = function() FALSE, + .package = "rmcmc", + { + msgs <- capture_messages( + sample_chain( + target_distribution = inputs$target_distribution, + initial_state = inputs$position, + n_warm_up_iteration = n_warm_up_iteration, + n_main_iteration = n_main_iteration, + adapters = inputs$adapters, + show_progress_bar = TRUE + ) + ) + } + ) + expected_n_progress_messages <- function(n_iteration) { + tick_amount <- max(n_iteration %/% 10, 1) + n_iteration %/% tick_amount + (n_iteration %% tick_amount != 0) + } + # 1 upfront warning + n_iteration dependent number of messages (warm-up + main) + expect_length( + msgs, + 1 + expected_n_progress_messages(n_warm_up_iteration) + + expected_n_progress_messages(n_main_iteration) + ) + expect_true(any(grepl("progress package is not installed", msgs))) + expect_true(any(grepl("10%", msgs))) + expect_true(any(grepl("100%", msgs))) +}) + +test_that("No manual progress output when show_progress_bar is FALSE", { + inputs <- make_fallback_test_inputs() + with_mocked_bindings( + is_progress_package_available = function() FALSE, + .package = "rmcmc", + { + expect_no_message( + sample_chain( + target_distribution = inputs$target_distribution, + initial_state = inputs$position, + n_warm_up_iteration = 10, + n_main_iteration = 10, + adapters = inputs$adapters, + show_progress_bar = FALSE + ) + ) + } + ) +})