Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 51 additions & 13 deletions R/chains.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@
#' coerce the average acceptance rate to a target value using a dual-averaging
#' algorithm, and adapting the shape to an estimate of the covariance of the
#' target distribution.
#' @param show_progress_bar Whether to show progress bars during sampling.
#' Requires `progress` package to be installed to have an effect.
#' @param show_progress_bar Whether to show progress bars during sampling. If the
#' `progress` package is installed, displays a progress bar; otherwise prints
#' periodic progress messages to the console.
#' @param trace_warm_up Whether to record chain traces and adaptation /
#' transition statistics during (adaptive) warm-up iterations in addition to
#' (non-adaptive) main chain iterations.
Expand Down Expand Up @@ -103,8 +104,12 @@ sample_chain <- function(
show_progress_bar = TRUE,
trace_warm_up = FALSE
) {
progress_available <- requireNamespace("progress", quietly = TRUE)
use_progress_bar <- progress_available && show_progress_bar
progress_available <- is_progress_package_available()
if (show_progress_bar && !progress_available) {
message(
"progress package is not installed, so will print progress updates below."
)
}
initial_state <- check_and_process_initial_state(initial_state)
target_distribution <- check_and_process_target_distribution(
target_distribution
Expand All @@ -118,7 +123,8 @@ sample_chain <- function(
target_distribution = target_distribution,
proposal = proposal,
adapters = adapters,
use_progress_bar = use_progress_bar,
show_progress_bar = show_progress_bar,
progress_available = progress_available,
record_traces_and_statistics = trace_warm_up,
trace_function = trace_function,
statistic_names = statistic_names
Expand All @@ -130,7 +136,8 @@ sample_chain <- function(
target_distribution = target_distribution,
proposal = proposal,
adapters = NULL,
use_progress_bar = use_progress_bar,
show_progress_bar = show_progress_bar,
progress_available = progress_available,
record_traces_and_statistics = TRUE,
trace_function = trace_function,
statistic_names = statistic_names
Expand Down Expand Up @@ -184,11 +191,28 @@ default_trace_function <- function(target_distribution) {
}
}

get_progress_bar <- function(use_progress_bar, n_iteration, label) {
is_package_available <- function(pkg) requireNamespace(pkg, quietly = TRUE)

is_progress_package_available <- function() is_package_available("progress")

print_fallback_progress <- function(
stage_name, chain_iteration, n_iteration, start_time
) {
elapsed <- proc.time()[["elapsed"]] - start_time
pct <- round(100 * chain_iteration / n_iteration)
message(sprintf(
"%s: %d%% done (%d/%d iterations) | elapsed: %.1fs",
stage_name, pct, chain_iteration, n_iteration, elapsed
))
}

get_progress_bar <- function(
show_progress_bar, progress_available, n_iteration, label
) {
progress_bar_format <- (
"%s :percent |:bar| :current/:total [:elapsed<:eta] :tick_rate it/s"
)
if (use_progress_bar) {
if (show_progress_bar && progress_available) {
progress::progress_bar$new(
format = sprintf(progress_bar_format, label),
total = n_iteration,
Expand Down Expand Up @@ -237,19 +261,22 @@ finalize_adapters <- function(adapters, proposal) {
invisible(adapters)
}

chain_loop <- function(
chain_loop <- function( # nolint: cyclocomp_linter. styler: off
stage_name,
n_iteration,
state,
target_distribution,
proposal,
adapters,
use_progress_bar,
show_progress_bar,
progress_available,
record_traces_and_statistics,
trace_function,
statistic_names
) {
progress_bar <- get_progress_bar(use_progress_bar, n_iteration, stage_name)
progress_bar <- get_progress_bar(
show_progress_bar, progress_available, n_iteration, stage_name
)
# Only show 10% increments in progress bar to avoid progress bar updates being
# a bottleneck when chain iteration rate is high
tick_amount <- max(n_iteration %/% 10, 1)
Expand All @@ -261,6 +288,7 @@ chain_loop <- function(
traces <- NULL
statistics <- NULL
}
start_time <- proc.time()[["elapsed"]]
for (chain_iteration in seq_len(n_iteration)) {
state_and_statistics <- sample_metropolis_hastings(
state, target_distribution, proposal
Expand All @@ -274,13 +302,23 @@ chain_loop <- function(
c(state_and_statistics$statistics, adapter_states)
)
}
if (!is.null(progress_bar) && (chain_iteration %% tick_amount == 0)) {
do_tick <- chain_iteration %% tick_amount == 0
if (!is.null(progress_bar) && do_tick) {
progress_bar$tick(tick_amount)
} else if (show_progress_bar && do_tick) { # fallback progress updates
print_fallback_progress(
stage_name, chain_iteration, n_iteration, start_time
)
}
}
# Ensure progress bar shows completed in cases tick_amount not a factor of
# n_iteration
if (!is.null(progress_bar) && !progress_bar$finished) progress_bar$update(1)
progress_unfinished <- n_iteration > 0 && (n_iteration %% tick_amount != 0)
if (!is.null(progress_bar) && progress_unfinished) {
progress_bar$update(1)
} else if (show_progress_bar && progress_unfinished) {
print_fallback_progress(stage_name, n_iteration, n_iteration, start_time)
}
finalize_adapters(adapters, proposal)
list(final_state = state, traces = traces, statistics = statistics)
}
Expand Down
5 changes: 3 additions & 2 deletions man/sample_chain.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

72 changes: 71 additions & 1 deletion tests/testthat/test-chains.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
for (n_warm_up_iteration in c(0, 1, 10)) {
for (n_main_iteration in c(0, 1, 10)) {
for (n_main_iteration in c(0, 1, 10, 21)) {
for (dimension in c(1, 2)) {
for (trace_warm_up in c(TRUE, FALSE)) {
for (show_progress_bar in c(TRUE, FALSE)) {
Expand Down Expand Up @@ -113,3 +113,73 @@ test_that("Sample chains with invalid target_distribution raises error", {
"target_distribution"
)
})

make_fallback_test_inputs <- function() {
target_distribution <- standard_normal_target_distribution()
adapters <- list(scale_adapter("stochastic_approximation", initial_scale = 1.))
withr::with_seed(default_seed(), {
position <- rnorm(2)
})
list(
target_distribution = target_distribution,
adapters = adapters,
position = position
)
}

test_that("Manual progress fallback prints messages when progress unavailable", {
inputs <- make_fallback_test_inputs()
n_warm_up_iteration <- 10
# use non-multiple of 10 to test finalisation of progress updates
n_main_iteration <- 21
# Simulate progress package being unavailable by mocking
with_mocked_bindings(
is_progress_package_available = function() FALSE,
.package = "rmcmc",
{
msgs <- capture_messages(
sample_chain(
target_distribution = inputs$target_distribution,
initial_state = inputs$position,
n_warm_up_iteration = n_warm_up_iteration,
n_main_iteration = n_main_iteration,
adapters = inputs$adapters,
show_progress_bar = TRUE
)
)
}
)
expected_n_progress_messages <- function(n_iteration) {
tick_amount <- max(n_iteration %/% 10, 1)
n_iteration %/% tick_amount + (n_iteration %% tick_amount != 0)
}
# 1 upfront warning + n_iteration dependent number of messages (warm-up + main)
expect_length(
msgs,
1 + expected_n_progress_messages(n_warm_up_iteration)
+ expected_n_progress_messages(n_main_iteration)
)
expect_true(any(grepl("progress package is not installed", msgs)))
expect_true(any(grepl("10%", msgs)))
expect_true(any(grepl("100%", msgs)))
})

test_that("No manual progress output when show_progress_bar is FALSE", {
inputs <- make_fallback_test_inputs()
with_mocked_bindings(
is_progress_package_available = function() FALSE,
.package = "rmcmc",
{
expect_no_message(
sample_chain(
target_distribution = inputs$target_distribution,
initial_state = inputs$position,
n_warm_up_iteration = 10,
n_main_iteration = 10,
adapters = inputs$adapters,
show_progress_bar = FALSE
)
)
}
)
})
Loading