diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 314f818209..27120b7f2a 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -320,7 +320,7 @@ function AbstractMCMC.step( samplers = spl.samplers vi = initial_varinfo(rng, model, spl, initial_params) - vi, states = gibbs_initialstep_recursive( + vi, states, stats = gibbs_initialstep_recursive( rng, model, AbstractMCMC.step, @@ -330,7 +330,7 @@ function AbstractMCMC.step( initial_params=initial_params, kwargs..., ) - transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(vi, model) + transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(vi, model, stats) return transition, GibbsState(vi, states) end @@ -346,7 +346,7 @@ function AbstractMCMC.step_warmup( samplers = spl.samplers vi = initial_varinfo(rng, model, spl, initial_params) - vi, states = gibbs_initialstep_recursive( + vi, states, stats = gibbs_initialstep_recursive( rng, model, AbstractMCMC.step_warmup, @@ -356,7 +356,7 @@ function AbstractMCMC.step_warmup( initial_params=initial_params, kwargs..., ) - transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(vi, model) + transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(vi, model, stats) return transition, GibbsState(vi, states) end @@ -375,13 +375,15 @@ function gibbs_initialstep_recursive( varname_vecs, samplers, vi, - states=(); + states=(), + stats=NamedTuple(), + sampler_number=1; initial_params, kwargs..., ) # End recursion if isempty(varname_vecs) && isempty(samplers) - return vi, states + return vi, states, stats end varnames, varname_vecs_tail... = varname_vecs @@ -391,7 +393,7 @@ function gibbs_initialstep_recursive( conditioned_model, context = make_conditional(model, varnames, vi) # Take initial step with the current sampler. - _, new_state = step_function( + transition, new_state = step_function( rng, conditioned_model, sampler; @@ -399,8 +401,13 @@ function gibbs_initialstep_recursive( # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc. initial_params=initial_params, kwargs..., - discard_sample=true, + discard_sample=false, ) + ts_stats = NamedTuple( + Symbol("spl$(sampler_number)_$k") => v for (k, v) in pairs(transition.stats) + ) + stats = merge(stats, ts_stats) + new_vi_local = get_varinfo(new_state) # Merge in any new variables that were introduced during the step, but that # were not in the domain of the current sampler. @@ -416,7 +423,9 @@ function gibbs_initialstep_recursive( varname_vecs_tail, samplers_tail, vi, - states; + states, + stats, + sampler_number + 1; initial_params=initial_params, kwargs..., ) @@ -436,10 +445,10 @@ function AbstractMCMC.step( states = state.states @assert length(samplers) == length(state.states) - vi, states = gibbs_step_recursive( - rng, model, AbstractMCMC.step, varnames, samplers, states, vi; kwargs... + vi, states, stats = gibbs_step_recursive( + rng, model, AbstractMCMC.step, varnames, samplers, states, (;), 1, vi; kwargs... ) - transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(vi, model) + transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(vi, model, stats) return transition, GibbsState(vi, states) end @@ -457,10 +466,20 @@ function AbstractMCMC.step_warmup( states = state.states @assert length(samplers) == length(state.states) - vi, states = gibbs_step_recursive( - rng, model, AbstractMCMC.step_warmup, varnames, samplers, states, vi; kwargs... + vi, states, stats = gibbs_step_recursive( + rng, + model, + AbstractMCMC.step_warmup, + varnames, + samplers, + states, + (;), + 1, + vi, + (); + kwargs..., ) - transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(vi, model) + transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(vi, model, stats) return transition, GibbsState(vi, states) end @@ -586,13 +605,15 @@ function gibbs_step_recursive( varname_vecs, samplers, states, + stats, + sampler_number, global_vi, new_states=(); kwargs..., ) # End recursion. if isempty(varname_vecs) && isempty(samplers) && isempty(states) - return global_vi, new_states + return global_vi, new_states, stats end varnames, varname_vecs_tail... = varname_vecs @@ -619,9 +640,13 @@ function gibbs_step_recursive( # Note that we pass `discard_sample=true` after `kwargs...`, because AbstractMCMC will # tell Gibbs that _this Gibbs sample_ should be kept, and so `kwargs` will actually # contain `discard_sample=false`! - _, new_state = step_function( - rng, conditioned_model, sampler, state; kwargs..., discard_sample=true + transition, new_state = step_function( + rng, conditioned_model, sampler, state; kwargs..., discard_sample=false + ) + ts_stats = NamedTuple( + Symbol("spl$(sampler_number)_$k") => v for (k, v) in pairs(transition.stats) ) + stats = merge(stats, ts_stats) new_vi_local = get_varinfo(new_state) # Merge the latest values for all the variables in the current sampler. @@ -636,6 +661,8 @@ function gibbs_step_recursive( varname_vecs_tail, samplers_tail, states_tail, + stats, + sampler_number + 1, new_global_vi, new_states; kwargs...,