-
Notifications
You must be signed in to change notification settings - Fork 20
Add AbstractMCMC-based Gibbs combinator #204
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,196 @@ | ||
|
|
||
| """ | ||
| AbstractMCMC.condition(model, target_varnames, global_values) | ||
|
|
||
| Return a new model conditioned on all variables *not* in `target_varnames`, | ||
| using the values in `global_values`. | ||
|
|
||
| `target_varnames` is a vector of variable identifiers. `global_values` is an | ||
| associative collection mapping identifiers to their current raw values. | ||
|
|
||
| Downstream packages implement this per model type. There is no default. | ||
| """ | ||
| function condition end | ||
|
|
||
| """ | ||
| Gibbs(varnames => sampler, ...) | ||
|
|
||
| A composable Gibbs sampler built on the AbstractMCMC interface. | ||
|
|
||
| Each pair maps a group of variable names to the component sampler responsible | ||
| for updating those variables. All model variables must be covered by exactly | ||
| one group. | ||
|
|
||
| Component samplers must implement: | ||
| - `AbstractMCMC.step(rng, conditioned_model, sampler[, state])` | ||
| - `AbstractMCMC.getparams(conditioned_model, state) → Vector{<:Real}` | ||
| - `AbstractMCMC.setparams!!(conditioned_model, state, params) → state` | ||
|
|
||
| # Example | ||
|
|
||
| ```julia | ||
| sampler = Gibbs(@varname(μ) => NUTS(), @varname(σ) => MH()) | ||
| chain = sample(model, sampler, 1000) | ||
|
Comment on lines
+32
to
+33
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| ``` | ||
| """ | ||
| struct Gibbs{N,S} <: AbstractSampler | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Constructor validation is too thin. The docstring says "All model variables must be covered by exactly one group," but the only check is |
||
| varnames::N | ||
| samplers::S | ||
|
|
||
| function Gibbs(varnames::N, samplers::S) where {N<:Tuple,S<:Tuple} | ||
| if length(varnames) != length(samplers) | ||
| throw(ArgumentError("Number of varname groups and samplers must match.")) | ||
| end | ||
| return new{N,S}(varnames, samplers) | ||
| end | ||
| end | ||
|
|
||
| function Gibbs(pairs::Pair...) | ||
| varnames = map(p -> _normalise_varnames(first(p)), pairs) | ||
| samplers = map(last, pairs) | ||
| return Gibbs(varnames, samplers) | ||
| end | ||
|
|
||
| _normalise_varnames(x::AbstractVector) = x | ||
| _normalise_varnames(x) = [x] | ||
|
|
||
| """ | ||
| GibbsState(global_values, sub_states) | ||
|
|
||
| State for the `Gibbs` sampler. | ||
|
|
||
| - `global_values`: current raw values of all model variables (type is | ||
| model-specific: `VarNamedTuple` for Turing, `NamedTuple` for JuliaBUGS). | ||
| - `sub_states`: tuple of component-sampler states, one per `Gibbs.samplers` entry. | ||
| """ | ||
| struct GibbsState{G,S<:Tuple} | ||
| global_values::G | ||
| sub_states::S | ||
| end | ||
|
|
||
| function step( | ||
| rng::Random.AbstractRNG, | ||
| model::AbstractModel, | ||
| sampler::Gibbs; | ||
| initial_params=nothing, | ||
| kwargs..., | ||
| ) | ||
| global_values = initial_params | ||
| sub_states = _gibbs_initial_steps( | ||
| rng, model, sampler.varnames, sampler.samplers, global_values; kwargs... | ||
| ) | ||
| global_values = _collect_global_values( | ||
| model, sampler.varnames, sampler.samplers, sub_states | ||
| ) | ||
| return _build_gibbs_transition(global_values), GibbsState(global_values, sub_states) | ||
| end | ||
|
|
||
| function step( | ||
| rng::Random.AbstractRNG, | ||
| model::AbstractModel, | ||
| sampler::Gibbs, | ||
| state::GibbsState; | ||
| kwargs..., | ||
| ) | ||
| global_values, sub_states = _gibbs_sweep( | ||
| rng, | ||
| model, | ||
| sampler.varnames, | ||
| sampler.samplers, | ||
| state.sub_states, | ||
| state.global_values; | ||
| kwargs..., | ||
| ) | ||
| return _build_gibbs_transition(global_values), GibbsState(global_values, sub_states) | ||
| end | ||
|
|
||
| function _gibbs_sweep( | ||
| rng, model, varname_groups, samplers, sub_states, global_values; kwargs... | ||
| ) | ||
| new_sub_states = () | ||
| for i in eachindex(varname_groups) | ||
| target_vars = varname_groups[i] | ||
| spl = samplers[i] | ||
| sub_state = sub_states[i] | ||
|
|
||
| cond_model = condition(model, target_vars, global_values) | ||
| synced_state = setparams!!(cond_model, sub_state, getparams(cond_model, sub_state)) | ||
| _, new_sub_state = step(rng, cond_model, spl, synced_state; kwargs...) | ||
|
|
||
| new_params = getparams(cond_model, new_sub_state) | ||
| global_values = _update_global_values( | ||
| model, global_values, target_vars, cond_model, new_params | ||
| ) | ||
| new_sub_states = (new_sub_states..., new_sub_state) | ||
| end | ||
| return global_values, new_sub_states | ||
| end | ||
|
|
||
| function _gibbs_initial_steps( | ||
| rng, model, varname_groups, samplers, global_values; kwargs... | ||
| ) | ||
| sub_states = () | ||
| for i in eachindex(varname_groups) | ||
| target_vars = varname_groups[i] | ||
| spl = samplers[i] | ||
|
|
||
| cond_model = condition(model, target_vars, global_values) | ||
| _, sub_state = step(rng, cond_model, spl; kwargs..., discard_sample=true) | ||
|
|
||
| if global_values === nothing | ||
| global_values = _init_global_values(model, target_vars, cond_model, sub_state) | ||
| else | ||
| new_params = getparams(cond_model, sub_state) | ||
| global_values = _update_global_values( | ||
| model, global_values, target_vars, cond_model, new_params | ||
| ) | ||
| end | ||
| sub_states = (sub_states..., sub_state) | ||
| end | ||
| return sub_states | ||
| end | ||
|
|
||
| """ | ||
| AbstractMCMC._init_global_values(model, target_vars, cond_model, sub_state) | ||
|
|
||
| Initialise the global parameter store from the first component's initial state. | ||
| Called once only. Downstream packages implement this per model type. | ||
| """ | ||
| function _init_global_values end | ||
|
|
||
| """ | ||
| AbstractMCMC._update_global_values(model, global_values, target_vars, cond_model, new_params) | ||
|
|
||
| Return an updated `global_values` with the target variables set to the values | ||
| encoded in `new_params`. Downstream packages implement this per model type. | ||
| """ | ||
| function _update_global_values end | ||
|
|
||
| function _collect_global_values(model, varname_groups, samplers, sub_states) | ||
| global_values = nothing | ||
| for i in eachindex(varname_groups) | ||
| target_vars = varname_groups[i] | ||
| sub_state = sub_states[i] | ||
| cond_model = condition(model, target_vars, global_values) | ||
| if global_values === nothing | ||
| global_values = _init_global_values(model, target_vars, cond_model, sub_state) | ||
| else | ||
| global_values = _update_global_values( | ||
| model, | ||
| global_values, | ||
| target_vars, | ||
| cond_model, | ||
| getparams(cond_model, sub_state), | ||
| ) | ||
| end | ||
| end | ||
| return global_values | ||
| end | ||
|
|
||
| """ | ||
| AbstractMCMC._build_gibbs_transition(global_values) | ||
|
|
||
| Build the transition object for each `step` call. Defaults to `global_values` | ||
| itself. Override to return a richer transition type. | ||
| """ | ||
| _build_gibbs_transition(global_values) = global_values | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not included in test/runtests.jl, so you only know this PR "works" because of whatever you ran locally. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| using Test | ||
| using Random | ||
| using AbstractMCMC | ||
| using Turing | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Project.toml's [extras] / test target not updated |
||
| using Statistics | ||
|
|
||
| @testset "AbstractMCMC Gibbs combinator" begin | ||
| @model function normal_model(x) | ||
| μ ~ Normal(0.0, 10.0) | ||
| σ ~ truncated(Normal(0.0, 5.0), 0.0, Inf) | ||
| for i in eachindex(x) | ||
| x[i] ~ Normal(μ, σ) | ||
| end | ||
| end | ||
|
|
||
| rng = MersenneTwister(42) | ||
| true_μ = 3.0 | ||
| true_σ = 1.5 | ||
| x_obs = true_μ .+ true_σ .* randn(rng, 50) | ||
| model = normal_model(x_obs) | ||
|
|
||
| @testset "Gibbs(μ=>MH, σ=>MH) recovers posterior" begin | ||
| spl = AbstractMCMC.Gibbs(@varname(μ) => MH(), @varname(σ) => MH()) | ||
| chain = sample(rng, model, spl, 1000; progress=false) | ||
| @test abs(mean(chain[:μ]) - mean(x_obs)) < 0.5 | ||
| @test abs(mean(chain[:σ]) - true_σ) < 0.5 | ||
| end | ||
|
|
||
| @testset "Gibbs is AbstractSampler" begin | ||
| spl = AbstractMCMC.Gibbs(@varname(μ) => MH(), @varname(σ) => MH()) | ||
| @test spl isa AbstractMCMC.AbstractSampler | ||
| end | ||
|
|
||
| @testset "GibbsState has correct structure" begin | ||
| spl = AbstractMCMC.Gibbs(@varname(μ) => MH(), @varname(σ) => MH()) | ||
| _, state = AbstractMCMC.step(rng, model, spl) | ||
| @test state isa AbstractMCMC.GibbsState | ||
| @test length(state.sub_states) == 2 | ||
| end | ||
|
|
||
| @testset "Mismatched varnames/samplers raises ArgumentError" begin | ||
| @test_throws ArgumentError AbstractMCMC.Gibbs( | ||
| ([@varname(μ), @varname(σ)],), (MH(), MH()) | ||
| ) | ||
| end | ||
|
|
||
| @testset "Existing Turing.Inference.Gibbs still works (no regression)" begin | ||
| chain = sample( | ||
| rng, | ||
| model, | ||
| Turing.Inference.Gibbs(@varname(μ) => MH(), @varname(σ) => MH()), | ||
| 500; | ||
| progress=false, | ||
| ) | ||
| @test abs(mean(chain[:μ]) - mean(x_obs)) < 0.5 | ||
| end | ||
| end | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you sure
conditionbelongs here? AbstractMCMC's design (https://turinglang.org/AbstractMCMC.jl/stable/design/) is deliberately minimal and PPL-agnostic.conditionis a PPL concept andAbstractPPL.conditionalready exists.