Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/AbstractMCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ include("sample.jl")
include("stepper.jl")
include("logdensityproblems.jl")
include("callbacks.jl")
include("gibbs.jl")

if isdefined(Base.Experimental, :register_error_hint)
function __init__()
Expand Down
196 changes: 196 additions & 0 deletions src/gibbs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@

"""
AbstractMCMC.condition(model, target_varnames, global_values)

Return a new model conditioned on all variables *not* in `target_varnames`,
using the values in `global_values`.

`target_varnames` is a vector of variable identifiers. `global_values` is an
associative collection mapping identifiers to their current raw values.

Downstream packages implement this per model type. There is no default.
"""
function condition end
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure condition belongs here? AbstractMCMC's design (https://turinglang.org/AbstractMCMC.jl/stable/design/) is deliberately minimal and PPL-agnostic. condition is a PPL concept and AbstractPPL.condition already exists.


"""
Gibbs(varnames => sampler, ...)

A composable Gibbs sampler built on the AbstractMCMC interface.

Each pair maps a group of variable names to the component sampler responsible
for updating those variables. All model variables must be covered by exactly
one group.

Component samplers must implement:
- `AbstractMCMC.step(rng, conditioned_model, sampler[, state])`
- `AbstractMCMC.getparams(conditioned_model, state) → Vector{<:Real}`
- `AbstractMCMC.setparams!!(conditioned_model, state, params) → state`

# Example

```julia
sampler = Gibbs(@varname(μ) => NUTS(), @varname(σ) => MH())
chain = sample(model, sampler, 1000)
Comment on lines +32 to +33
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@varname, NUTS, MH, and model are not defined in AbstractMCMC. A reader cannot run this example using only this package. Either rewrite it with AbstractMCMC primitives, or move it to a "Downstream usage" section that names Turing explicitly.

```
"""
struct Gibbs{N,S} <: AbstractSampler
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Constructor validation is too thin. The docstring says "All model variables must be covered by exactly one group," but the only check is length(varnames) == length(samplers). Overlapping groups, empty groups, and zero pairs all silently produce undefined behaviour. Please validate that the groups are disjoint and non-empty.

varnames::N
samplers::S

function Gibbs(varnames::N, samplers::S) where {N<:Tuple,S<:Tuple}
if length(varnames) != length(samplers)
throw(ArgumentError("Number of varname groups and samplers must match."))
end
return new{N,S}(varnames, samplers)
end
end

function Gibbs(pairs::Pair...)
varnames = map(p -> _normalise_varnames(first(p)), pairs)
samplers = map(last, pairs)
return Gibbs(varnames, samplers)
end

_normalise_varnames(x::AbstractVector) = x
_normalise_varnames(x) = [x]

"""
GibbsState(global_values, sub_states)

State for the `Gibbs` sampler.

- `global_values`: current raw values of all model variables (type is
model-specific: `VarNamedTuple` for Turing, `NamedTuple` for JuliaBUGS).
- `sub_states`: tuple of component-sampler states, one per `Gibbs.samplers` entry.
"""
struct GibbsState{G,S<:Tuple}
global_values::G
sub_states::S
end

function step(
rng::Random.AbstractRNG,
model::AbstractModel,
sampler::Gibbs;
initial_params=nothing,
kwargs...,
)
global_values = initial_params
sub_states = _gibbs_initial_steps(
rng, model, sampler.varnames, sampler.samplers, global_values; kwargs...
)
global_values = _collect_global_values(
model, sampler.varnames, sampler.samplers, sub_states
)
return _build_gibbs_transition(global_values), GibbsState(global_values, sub_states)
end

function step(
rng::Random.AbstractRNG,
model::AbstractModel,
sampler::Gibbs,
state::GibbsState;
kwargs...,
)
global_values, sub_states = _gibbs_sweep(
rng,
model,
sampler.varnames,
sampler.samplers,
state.sub_states,
state.global_values;
kwargs...,
)
return _build_gibbs_transition(global_values), GibbsState(global_values, sub_states)
end

function _gibbs_sweep(
rng, model, varname_groups, samplers, sub_states, global_values; kwargs...
)
new_sub_states = ()
for i in eachindex(varname_groups)
target_vars = varname_groups[i]
spl = samplers[i]
sub_state = sub_states[i]

cond_model = condition(model, target_vars, global_values)
synced_state = setparams!!(cond_model, sub_state, getparams(cond_model, sub_state))
_, new_sub_state = step(rng, cond_model, spl, synced_state; kwargs...)

new_params = getparams(cond_model, new_sub_state)
global_values = _update_global_values(
model, global_values, target_vars, cond_model, new_params
)
new_sub_states = (new_sub_states..., new_sub_state)
end
return global_values, new_sub_states
end

function _gibbs_initial_steps(
rng, model, varname_groups, samplers, global_values; kwargs...
)
sub_states = ()
for i in eachindex(varname_groups)
target_vars = varname_groups[i]
spl = samplers[i]

cond_model = condition(model, target_vars, global_values)
_, sub_state = step(rng, cond_model, spl; kwargs..., discard_sample=true)

if global_values === nothing
global_values = _init_global_values(model, target_vars, cond_model, sub_state)
else
new_params = getparams(cond_model, sub_state)
global_values = _update_global_values(
model, global_values, target_vars, cond_model, new_params
)
end
sub_states = (sub_states..., sub_state)
end
return sub_states
end

"""
AbstractMCMC._init_global_values(model, target_vars, cond_model, sub_state)

Initialise the global parameter store from the first component's initial state.
Called once only. Downstream packages implement this per model type.
"""
function _init_global_values end

"""
AbstractMCMC._update_global_values(model, global_values, target_vars, cond_model, new_params)

Return an updated `global_values` with the target variables set to the values
encoded in `new_params`. Downstream packages implement this per model type.
"""
function _update_global_values end

function _collect_global_values(model, varname_groups, samplers, sub_states)
global_values = nothing
for i in eachindex(varname_groups)
target_vars = varname_groups[i]
sub_state = sub_states[i]
cond_model = condition(model, target_vars, global_values)
if global_values === nothing
global_values = _init_global_values(model, target_vars, cond_model, sub_state)
else
global_values = _update_global_values(
model,
global_values,
target_vars,
cond_model,
getparams(cond_model, sub_state),
)
end
end
return global_values
end

"""
AbstractMCMC._build_gibbs_transition(global_values)

Build the transition object for each `step` call. Defaults to `global_values`
itself. Override to return a richer transition type.
"""
_build_gibbs_transition(global_values) = global_values
57 changes: 57 additions & 0 deletions test/gibbs_test.jl
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not included in test/runtests.jl, so you only know this PR "works" because of whatever you ran locally.

Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
using Test
using Random
using AbstractMCMC
using Turing
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Project.toml's [extras] / test target not updated

using Statistics

@testset "AbstractMCMC Gibbs combinator" begin
@model function normal_model(x)
μ ~ Normal(0.0, 10.0)
σ ~ truncated(Normal(0.0, 5.0), 0.0, Inf)
for i in eachindex(x)
x[i] ~ Normal(μ, σ)
end
end

rng = MersenneTwister(42)
true_μ = 3.0
true_σ = 1.5
x_obs = true_μ .+ true_σ .* randn(rng, 50)
model = normal_model(x_obs)

@testset "Gibbs(μ=>MH, σ=>MH) recovers posterior" begin
spl = AbstractMCMC.Gibbs(@varname(μ) => MH(), @varname(σ) => MH())
chain = sample(rng, model, spl, 1000; progress=false)
@test abs(mean(chain[:μ]) - mean(x_obs)) < 0.5
@test abs(mean(chain[:σ]) - true_σ) < 0.5
end

@testset "Gibbs is AbstractSampler" begin
spl = AbstractMCMC.Gibbs(@varname(μ) => MH(), @varname(σ) => MH())
@test spl isa AbstractMCMC.AbstractSampler
end

@testset "GibbsState has correct structure" begin
spl = AbstractMCMC.Gibbs(@varname(μ) => MH(), @varname(σ) => MH())
_, state = AbstractMCMC.step(rng, model, spl)
@test state isa AbstractMCMC.GibbsState
@test length(state.sub_states) == 2
end

@testset "Mismatched varnames/samplers raises ArgumentError" begin
@test_throws ArgumentError AbstractMCMC.Gibbs(
([@varname(μ), @varname(σ)],), (MH(), MH())
)
end

@testset "Existing Turing.Inference.Gibbs still works (no regression)" begin
chain = sample(
rng,
model,
Turing.Inference.Gibbs(@varname(μ) => MH(), @varname(σ) => MH()),
500;
progress=false,
)
@test abs(mean(chain[:μ]) - mean(x_obs)) < 0.5
end
end