Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 45 additions & 18 deletions src/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ function AbstractMCMC.step(
samplers = spl.samplers
vi = initial_varinfo(rng, model, spl, initial_params)

vi, states = gibbs_initialstep_recursive(
vi, states, stats = gibbs_initialstep_recursive(
rng,
model,
AbstractMCMC.step,
Expand All @@ -330,7 +330,7 @@ function AbstractMCMC.step(
initial_params=initial_params,
kwargs...,
)
transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(vi, model)
transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(vi, model, stats)
return transition, GibbsState(vi, states)
end

Expand All @@ -346,7 +346,7 @@ function AbstractMCMC.step_warmup(
samplers = spl.samplers
vi = initial_varinfo(rng, model, spl, initial_params)

vi, states = gibbs_initialstep_recursive(
vi, states, stats = gibbs_initialstep_recursive(
rng,
model,
AbstractMCMC.step_warmup,
Expand All @@ -356,7 +356,7 @@ function AbstractMCMC.step_warmup(
initial_params=initial_params,
kwargs...,
)
transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(vi, model)
transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(vi, model, stats)
return transition, GibbsState(vi, states)
end

Expand All @@ -375,13 +375,15 @@ function gibbs_initialstep_recursive(
varname_vecs,
samplers,
vi,
states=();
states=(),
stats=NamedTuple(),
sampler_number=1;
initial_params,
kwargs...,
)
# End recursion
if isempty(varname_vecs) && isempty(samplers)
return vi, states
return vi, states, stats
end

varnames, varname_vecs_tail... = varname_vecs
Expand All @@ -391,16 +393,21 @@ function gibbs_initialstep_recursive(
conditioned_model, context = make_conditional(model, varnames, vi)

# Take initial step with the current sampler.
_, new_state = step_function(
transition, new_state = step_function(
rng,
conditioned_model,
sampler;
# FIXME: This will cause issues if the sampler expects initial params in unconstrained space.
# This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
initial_params=initial_params,
kwargs...,
discard_sample=true,
discard_sample=false,
)
ts_stats = NamedTuple(
Symbol("spl$(sampler_number)_$k") => v for (k, v) in pairs(transition.stats)
)
stats = merge(stats, ts_stats)

new_vi_local = get_varinfo(new_state)
# Merge in any new variables that were introduced during the step, but that
# were not in the domain of the current sampler.
Expand All @@ -416,7 +423,9 @@ function gibbs_initialstep_recursive(
varname_vecs_tail,
samplers_tail,
vi,
states;
states,
stats,
sampler_number + 1;
initial_params=initial_params,
kwargs...,
)
Expand All @@ -436,10 +445,10 @@ function AbstractMCMC.step(
states = state.states
@assert length(samplers) == length(state.states)

vi, states = gibbs_step_recursive(
rng, model, AbstractMCMC.step, varnames, samplers, states, vi; kwargs...
vi, states, stats = gibbs_step_recursive(
rng, model, AbstractMCMC.step, varnames, samplers, states, (;), 1, vi; kwargs...
)
transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(vi, model)
transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(vi, model, stats)
return transition, GibbsState(vi, states)
end

Expand All @@ -457,10 +466,20 @@ function AbstractMCMC.step_warmup(
states = state.states
@assert length(samplers) == length(state.states)

vi, states = gibbs_step_recursive(
rng, model, AbstractMCMC.step_warmup, varnames, samplers, states, vi; kwargs...
vi, states, stats = gibbs_step_recursive(
rng,
model,
AbstractMCMC.step_warmup,
varnames,
samplers,
states,
(;),
1,
vi,
();
kwargs...,
)
transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(vi, model)
transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(vi, model, stats)
return transition, GibbsState(vi, states)
end

Expand Down Expand Up @@ -586,13 +605,15 @@ function gibbs_step_recursive(
varname_vecs,
samplers,
states,
stats,
sampler_number,
global_vi,
new_states=();
kwargs...,
)
# End recursion.
if isempty(varname_vecs) && isempty(samplers) && isempty(states)
return global_vi, new_states
return global_vi, new_states, stats
end

varnames, varname_vecs_tail... = varname_vecs
Expand All @@ -619,9 +640,13 @@ function gibbs_step_recursive(
# Note that we pass `discard_sample=true` after `kwargs...`, because AbstractMCMC will
# tell Gibbs that _this Gibbs sample_ should be kept, and so `kwargs` will actually
# contain `discard_sample=false`!
_, new_state = step_function(
rng, conditioned_model, sampler, state; kwargs..., discard_sample=true
transition, new_state = step_function(
rng, conditioned_model, sampler, state; kwargs..., discard_sample=false
)
ts_stats = NamedTuple(
Symbol("spl$(sampler_number)_$k") => v for (k, v) in pairs(transition.stats)
)
stats = merge(stats, ts_stats)

new_vi_local = get_varinfo(new_state)
# Merge the latest values for all the variables in the current sampler.
Expand All @@ -636,6 +661,8 @@ function gibbs_step_recursive(
varname_vecs_tail,
samplers_tail,
states_tail,
stats,
sampler_number + 1,
new_global_vi,
new_states;
kwargs...,
Expand Down
Loading