diff --git a/Project.toml b/Project.toml index 28bc2ef5..2e09bffd 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probabilistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "5.12.0" +version = "5.13.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/src/interface.jl b/src/interface.jl index 60b31c59..d4d935a9 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -61,7 +61,7 @@ function _bundle_samples( end """ - step(rng, model, sampler[, state; kwargs...]) + step(rng, model, sampler[, state]; kwargs...) Return a 2-tuple of the next sample and the next state of the MCMC `sampler` for `model`. @@ -70,11 +70,23 @@ might include a vector of parameters sampled from a prior distribution. When sampling using [`sample`](@ref), every `step` call after the first has access to the current `state` of the sampler. + +## Keyword arguments + +If the step being taken is going to be discarded (e.g. during burn-in, or if thinning is +performed), this method will be called with a `discard_sample=true` keyword argument. +Conversely, if the step being taken is to be retained, this method will be called with +`discard_sample=false`. This allows implementations of `step` to customize their behavior +based on whether or not the sample will be kept. + +Other keyword arguments are passed through from the call to [`sample`](@ref). Because there +is no way of knowing in advance which keyword arguments will be passed, implementations of +`step` should include a `kwargs...` argument to capture any additional keyword arguments. """ function step end """ - step_warmup(rng, model, sampler[, state; kwargs...]) + step_warmup(rng, model, sampler[, state]; kwargs...) Return a 2-tuple of the next sample and the next state of the MCMC `sampler` for `model`. @@ -83,11 +95,25 @@ When sampling using [`sample`](@ref), this takes the place of [`AbstractMCMC.ste This is useful if the sampler has an initial "warmup"-stage that is different from the standard iteration. +By default, this defers to [`AbstractMCMC.step`](@ref), meaning that if a sampler does not +have special warmup behaviour, it only needs to implement `step`. + +## Keyword arguments + The total number of warmup steps requested in sampling will be passed to the `step_warmup` function as the `num_warmup` keyword argument. This allows implementations of `step_warmup` to customise their behavior based on this information. -By default, this simply calls [`AbstractMCMC.step`](@ref). +If the step being taken is going to be discarded (e.g. during burn-in, or if thinning is +performed), this method will be called with a `discard_sample=true` keyword argument. +Conversely, if the step being taken is to be retained, this method will be called with +`discard_sample=false`. This allows implementations of `step_warmup` to customize their +behavior based on whether or not the sample will be kept. + +Other keyword arguments are passed through from the call to [`sample`](@ref). Because there +is no way of knowing in advance which keyword arguments will be passed, implementations of +`step_warmup` should include a `kwargs...` argument to capture any additional keyword +arguments. """ step_warmup(rng, model, sampler; kwargs...) = step(rng, model, sampler; kwargs...) function step_warmup(rng, model, sampler, state; kwargs...) diff --git a/src/sample.jl b/src/sample.jl index afc68bcd..d999adbe 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -119,6 +119,47 @@ function _filter_initial_params_kwarg(kwargs) return pairs((; (k => v for (k, v) in pairs(kwargs) if k !== :initial_parameters)...)) end +# Dispatch to step or step_warmup based on the number of steps run so far. `nsteps` is the +# number of steps taken so far (prior to calling this function). +function _step_or_step_warmup(nsteps::Int, num_warmup::Int, args...; kwargs...) + sample, state = if nsteps <= num_warmup + step_warmup(args...; num_warmup=num_warmup, kwargs...) + else + step(args...; kwargs...) + end + return (nsteps + 1, sample, state) +end + +# Save and perform callback. Just like _step_or_step_warmup, this function is abstracted to +# ensure that callbacks and saving is always done together. `nsteps_kept` is the number of +# steps kept so far (prior to calling this function); i.e., the first time it's called, +# it's 0. +function _save!!_and_callback( + nsteps_kept::Integer, + samples, + sample, + state, + model, + sampler, + # `nothing` if no predefined number of samples + N::Union{Integer,Nothing}, + callback, + rng; + kwargs..., +) + nsteps_kept = nsteps_kept + 1 + if callback !== nothing + callback(rng, model, sampler, sample, state, nsteps_kept; kwargs...) + end + samples = if N === nothing + # no predefined number of samples -- e.g. in convergence sampling + save!!(samples, sample, nsteps_kept, model, sampler; kwargs...) + else + save!!(samples, sample, nsteps_kept, model, sampler, N; kwargs...) + end + return nsteps_kept, samples +end + # Default implementations of regular and parallel sampling. function mcmcsample( rng::Random.AbstractRNG, @@ -156,11 +197,6 @@ function mcmcsample( progress = NoLogging() end - # Determine how many samples to drop from `num_warmup` and the - # main sampling process before we start saving samples. - discard_from_warmup = min(num_warmup, discard_initial) - keep_from_warmup = num_warmup - discard_from_warmup - # Start the timer start = time() local state @@ -176,90 +212,134 @@ function mcmcsample( threshold = Ntotal / n_updates next_update = threshold + # Number of steps taken so far. This is incremented by every call to + # `_step_or_step_warmup`. + nsteps = 0 + # Number of steps that have been saved so far. + nsteps_kept = 0 + # Obtain the initial sample and state. - sample, state = if num_warmup > 0 - if initial_state === nothing - step_warmup(rng, model, sampler; num_warmup, kwargs...) - else - step_warmup(rng, model, sampler, initial_state; num_warmup, kwargs...) - end - else - if initial_state === nothing - step(rng, model, sampler; kwargs...) - else - step(rng, model, sampler, initial_state; kwargs...) - end - end + initial_arg = initial_state === nothing ? () : (initial_state,) + nsteps, sample, state = _step_or_step_warmup( + nsteps, + num_warmup, + rng, + model, + sampler, + initial_arg...; + # If discard_initial == 0 then this is the actual first sample that + # we will end up keeping + discard_sample=(discard_initial > 0), + kwargs..., + ) # Start the progress bar. - itotal = 1 - if itotal >= next_update - update_progress!(progress, itotal / Ntotal) + if nsteps >= next_update + update_progress!(progress, nsteps / Ntotal) next_update += threshold end - # Discard initial samples. + # Generate initial samples to be discarded. + # + # TODO(penelopeysm): This is actually not very pretty: what this does is + # generates the samples 2:(discard_initial + 1). If discard_initial is positive, + # then the last sample generated in this loop is the actual first one that is + # kept. It would be nice to refactor this, but it also gets a bit finicky with + # the initial_state keyword argument. + # This could be solved more nicely by having a separate function to generate + # the initial sample and state. see + # https://github.com/TuringLang/AbstractMCMC.jl/issues/135 for j in 1:discard_initial + discard_sample = j < discard_initial # Obtain the next sample and state. - sample, state = if j ≤ num_warmup - step_warmup(rng, model, sampler, state; num_warmup, kwargs...) - else - step(rng, model, sampler, state; kwargs...) - end + nsteps, sample, state = _step_or_step_warmup( + nsteps, + num_warmup, + rng, + model, + sampler, + state; + discard_sample, + kwargs..., + ) # Update the progress bar. - itotal += 1 - if itotal >= next_update - update_progress!(progress, itotal / Ntotal) + if nsteps >= next_update + update_progress!(progress, nsteps / Ntotal) next_update += threshold end end - # Run callback. - callback === nothing || - callback(rng, model, sampler, sample, state, 1; kwargs...) - - # Save the sample. + # Of those (1 + discard_initial) samples, we're going to save one of them (the + # last one). + # This is the first time we're saving a sample: need to initialise the samples + # object. (Note that samples() returns an empty vector.) samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...) - samples = save!!(samples, sample, 1, model, sampler, N; kwargs...) + nsteps_kept, samples = _save!!_and_callback( + nsteps_kept, + samples, + sample, + state, + model, + sampler, + N, + callback, + rng; + kwargs..., + ) # Step through the sampler. - for i in 2:N + for _ in 2:N # Discard thinned samples. for _ in 1:(thinning - 1) # Obtain the next sample and state. - sample, state = if i ≤ keep_from_warmup - step_warmup(rng, model, sampler, state; num_warmup, kwargs...) - else - step(rng, model, sampler, state; kwargs...) - end + nsteps, sample, state = _step_or_step_warmup( + nsteps, + num_warmup, + rng, + model, + sampler, + state; + discard_sample=true, + kwargs..., + ) # Update progress bar. - itotal += 1 - if itotal >= next_update - update_progress!(progress, itotal / Ntotal) + if nsteps >= next_update + update_progress!(progress, nsteps / Ntotal) next_update += threshold end end # Obtain the next sample and state. - sample, state = if i ≤ keep_from_warmup - step_warmup(rng, model, sampler, state; num_warmup, kwargs...) - else - step(rng, model, sampler, state; kwargs...) - end - - # Run callback. - callback === nothing || - callback(rng, model, sampler, sample, state, i; kwargs...) - - # Save the sample. - samples = save!!(samples, sample, i, model, sampler, N; kwargs...) + nsteps, sample, state = _step_or_step_warmup( + nsteps, + num_warmup, + rng, + model, + sampler, + state; + discard_sample=false, + kwargs..., + ) + + # Run callback and save + nsteps_kept, samples = _save!!_and_callback( + nsteps_kept, + samples, + sample, + state, + model, + sampler, + N, + callback, + rng; + kwargs..., + ) # Update the progress bar. - itotal += 1 - if itotal >= next_update - update_progress!(progress, itotal / Ntotal) + if nsteps >= next_update + update_progress!(progress, nsteps / Ntotal) next_update += threshold end end @@ -269,6 +349,9 @@ function mcmcsample( duration = stop - start stats = SamplingStats(start, stop, duration) + # Sanity check: ensure we have saved exactly N samples. + @assert nsteps_kept == N + return bundle_samples( samples, model, @@ -314,78 +397,107 @@ function mcmcsample( progress = NoLogging() end - # Determine how many samples to drop from `num_warmup` and the - # main sampling process before we start saving samples. - discard_from_warmup = min(num_warmup, discard_initial) - keep_from_warmup = num_warmup - discard_from_warmup - # Start the timer start = time() local state @maybewithricherlogger begin init_progress!(progress) + # Number of steps taken so far. This is incremented by every call to + # `_step_or_step_warmup`. + nsteps = 0 + # Number of steps that have been saved so far. + nsteps_kept = 0 + # Obtain the initial sample and state. - sample, state = if num_warmup > 0 - if initial_state === nothing - step_warmup(rng, model, sampler; num_warmup, kwargs...) - else - step_warmup(rng, model, sampler, initial_state; num_warmup, kwargs...) - end - else - if initial_state === nothing - step(rng, model, sampler; kwargs...) - else - step(rng, model, sampler, initial_state; kwargs...) - end - end + initial_arg = initial_state === nothing ? () : (initial_state,) + nsteps, sample, state = _step_or_step_warmup( + nsteps, + num_warmup, + rng, + model, + sampler, + initial_arg...; + # If discard_initial == 0 then this is the actual first sample that + # we will end up keeping + discard_sample=(discard_initial > 0), + kwargs..., + ) - # Discard initial samples. + # Discard initial samples. See method above for logic. for j in 1:discard_initial + discard_sample = j < discard_initial # Obtain the next sample and state. - sample, state = if j ≤ num_warmup - step_warmup(rng, model, sampler, state; num_warmup, kwargs...) - else - step(rng, model, sampler, state; kwargs...) - end + nsteps, sample, state = _step_or_step_warmup( + nsteps, num_warmup, rng, model, sampler, state; discard_sample, kwargs... + ) end - # Run callback. - callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...) - # Save the sample. samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) - samples = save!!(samples, sample, 1, model, sampler; kwargs...) + nsteps_kept, samples = _save!!_and_callback( + nsteps_kept, + samples, + sample, + state, + model, + sampler, + nothing, + callback, + rng; + kwargs..., + ) # Step through the sampler until stopping. - i = 2 - while !isdone(rng, model, sampler, samples, state, i; progress=progress, kwargs...) + while !isdone( + rng, + model, + sampler, + samples, + state, + (nsteps_kept + 1); + progress=progress, + kwargs..., + ) # Discard thinned samples. for _ in 1:(thinning - 1) - # Obtain the next sample and state. - sample, state = if i ≤ keep_from_warmup - step_warmup(rng, model, sampler, state; num_warmup, kwargs...) - else - step(rng, model, sampler, state; kwargs...) - end + nsteps, sample, state = _step_or_step_warmup( + nsteps, + num_warmup, + rng, + model, + sampler, + state; + discard_sample=true, + kwargs..., + ) end # Obtain the next sample and state. - sample, state = if i ≤ keep_from_warmup - step_warmup(rng, model, sampler, state; num_warmup, kwargs...) - else - step(rng, model, sampler, state; kwargs...) - end + nsteps, sample, state = _step_or_step_warmup( + nsteps, + num_warmup, + rng, + model, + sampler, + state; + discard_sample=false, + kwargs..., + ) # Run callback. - callback === nothing || - callback(rng, model, sampler, sample, state, i; kwargs...) - - # Save the sample. - samples = save!!(samples, sample, i, model, sampler; kwargs...) - - # Increment iteration counter. - i += 1 + nsteps_kept, samples = _save!!_and_callback( + nsteps_kept, + samples, + sample, + state, + model, + sampler, + nothing, + callback, + rng; + kwargs..., + ) end finish_progress!(progress) end diff --git a/test/sample.jl b/test/sample.jl index 0e15a456..e733035d 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -646,9 +646,11 @@ ) @test length(chain_warmup) == N @test all(chain_warmup[i].a == ref_chain[i + discard_initial].a for i in 1:N) - # Check that the first `num_warmup - discard_initial` samples are warmup samples. + # Check that the first `num_warmup - discard_initial` steps (not including the + # initial step; hence the +1) are warmup samples. @test all( - chain_warmup[i].is_warmup == (i <= num_warmup - discard_initial) for i in 1:N + chain_warmup[i].is_warmup == (i <= num_warmup - discard_initial + 1) for + i in 1:N ) end @@ -670,6 +672,88 @@ @test all(chain[i].b == ref_chain[(i - 1) * thinning + 1].b for i in 1:N) end + @testset "Interaction between thinning, warmup and discard" begin + struct M <: AbstractMCMC.AbstractModel end + struct S <: AbstractMCMC.AbstractSampler end + struct T + is_warmup::Bool + i::Int + end + function AbstractMCMC.step_warmup(rng, ::M, ::S; kwargs...) + return T(true, 1), 1 + end + function AbstractMCMC.step_warmup(rng, ::M, ::S, i::Int; kwargs...) + return T(true, i + 1), i + 1 + end + function AbstractMCMC.step(rng, ::M, ::S; kwargs...) + return T(false, 1), 1 + end + function AbstractMCMC.step(rng, ::M, ::S, i::Int; kwargs...) + return T(false, i + 1), i + 1 + end + + @testset "num_warmup + thinning" begin + N = 5 + num_warmup = 5 + thinning = 2 + chain = sample(M(), S(), N; thinning, num_warmup) + @test length(chain) == N + @test [chain[n].i for n in 1:N] == range(; start=num_warmup + 1, step=thinning, length=N) + # The first step is reached by warming up, but the others shouldn't + @test chain[1].is_warmup + @test all(chain[n].is_warmup == false for n in 2:N) + end + + @testset "num_warmup + discard_initial" begin + N = 5 + num_warmup = 5 + discard_initial = 2 + chain = sample(M(), S(), N; num_warmup, discard_initial) + @test length(chain) == N + @test [chain[n].i for n in 1:N] == range(; start=discard_initial + 1, step=1, length=N) + last_warmup_step = num_warmup - discard_initial + 1 + @test all([chain[n].is_warmup for n in 1:4]) + @test all([!chain[n].is_warmup for n in 5:5]) + end + + @testset "discard_initial + thinning" begin + N = 5 + thinning = 3 + discard_initial = 2 + chain = sample(M(), S(), N; discard_initial, thinning) + @test length(chain) == N + @test [chain[n].i for n in 1:N] == range(; start=discard_initial + 1, step=thinning, length=N) + @test all([!chain[n].is_warmup for n in 1:N]) + end + + @testset "(num_warmup < discard_initial) + thinning" begin + N = 5 + thinning = 3 + discard_initial = 10 + num_warmup = 5 + chain = sample(M(), S(), N; discard_initial, thinning, num_warmup) + @test length(chain) == N + @test [chain[n].i for n in 1:N] == range(; start=discard_initial + 1, step=thinning, length=N) + @test all([!chain[n].is_warmup for n in 1:N]) + end + + @testset "(num_warmup > discard_initial) + thinning" begin + N = 5 + thinning = 3 + discard_initial = 2 + num_warmup = 5 + chain = sample(M(), S(), N; discard_initial, thinning, num_warmup) + # Note that num_warmup=5 means that the sixth step should be + # obtained via step_warmup. Because we discard the first two + # steps, the steps in the returned chain are 3, 6, 9, 12, 15. + # That means that the first two returned steps are warmup steps. + @test length(chain) == N + @test [chain[n].i for n in 1:N] == range(; start=discard_initial + 1, step=thinning, length=N) + @test all([chain[n].is_warmup for n in 1:2]) + @test all([!chain[n].is_warmup for n in 3:N]) + end + end + @testset "Sample without predetermined N" begin Random.seed!(1234) chain = sample(MyModel(), MySampler()) @@ -734,6 +818,36 @@ end end + @testset "discard_sample keyword argument" begin + struct M2 <: AbstractMCMC.AbstractModel end + struct S2 <: AbstractMCMC.AbstractSampler end + # If any sample with discard_sample == true is returned, the test will fail. + function AbstractMCMC.step( + rng, ::M2, ::S2, state=nothing; discard_sample, kwargs... + ) + return discard_sample, nothing + end + function AbstractMCMC.step( + rng, ::M2, ::S2, state=nothing; discard_sample, kwargs... + ) + return discard_sample, nothing + end + N = 10 + for kwargs in [ + (; num_warmup=5), + (; discard_initial=5), + (; num_warmup=5, discard_initial=2), + (; num_warmup=5, discard_initial=10), + (; thinning=2), + ] + chain = sample(M2(), S2(), N; kwargs...) + @test all(!x for x in chain) + # test with thinning too + chain = sample(M2(), S2(), N; kwargs..., thinning=2) + @test all(!x for x in chain) + end + end + @testset "Sample vector of `NamedTuple`s" begin chain = sample(MyModel(), MySampler(), 1_000; chain_type=Vector{NamedTuple}) # Check output type