From 35d1b18f8d0fa097554ecb2874bc53453a788faa Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 14 Mar 2026 00:14:13 +0000 Subject: [PATCH 01/13] Bump minor version --- HISTORY.md | 4 ++++ Project.toml | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index a6b5e3f1fe..30b4d5c7c9 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,7 @@ +# 0.44.0 + +[...] + # 0.43.2 Throw an `ArgumentError` when a `Gibbs` sampler is missing component samplers for any variable in the model. diff --git a/Project.toml b/Project.toml index f9d396cc11..e04815c993 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.43.2" +version = "0.44.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 191116feaed02e504e1c681f184e8a345262e0c1 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 16 Apr 2026 18:43:10 +0100 Subject: [PATCH 02/13] DynamicPPL@0.41 (#2800) Changes from new DPPL version + VI interface --- Project.toml | 5 +- docs/Project.toml | 3 + src/mcmc/mh.jl | 18 ++- src/optimisation/init.jl | 48 +++--- src/variational/Variational.jl | 209 +++++++++++++++------------ test/Project.toml | 5 +- test/floattypes/Project.toml | 1 + test/integration/enzyme/Project.toml | 1 + 8 files changed, 167 insertions(+), 123 deletions(-) diff --git a/Project.toml b/Project.toml index ac4273856c..6c0acf9411 100644 --- a/Project.toml +++ b/Project.toml @@ -61,7 +61,7 @@ DifferentiationInterface = "0.7" Distributions = "0.25.77" DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" -DynamicPPL = "0.40.20" +DynamicPPL = "0.41" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3, 1" Libtask = "0.9.14" @@ -84,3 +84,6 @@ julia = "1.10.8" [extras] DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" + +[sources] +DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"} diff --git a/docs/Project.toml b/docs/Project.toml index 
8cf000eca9..c9552ae361 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -8,3 +8,6 @@ OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" + +[sources] +DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"} diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index 51a3455075..fb13804ebf 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -175,8 +175,8 @@ struct InitFromProposals{V<:DynamicPPL.VarNamedTuple} <: DynamicPPL.AbstractInit "A mapping of VarNames to Tuple{Bool,Distribution}s that they should be sampled from. If the VarName is not in this VarNamedTuple, then it will be sampled from the prior. The Bool indicates whether the proposal is in linked space (true, i.e., the strategy should - return a `LinkedVectorValue`); or in untransformed space (false, i.e., the strategy - should return an `UntransformedValue`)." + return a linked vector value); or in untransformed space (false, i.e., the strategy + should return an untransformed value)." proposals::V "Whether to print the proposals as they are being sampled" verbose::Bool @@ -195,12 +195,11 @@ function DynamicPPL.init( end end if is_linkedrw - transform = Bijectors.VectorBijectors.from_linked_vec(prior) linked_vec = rand(rng, dist) - return DynamicPPL.LinkedVectorValue(linked_vec, transform) + return DynamicPPL.TransformedValue(linked_vec, DynamicPPL.DynamicLink()) else # Static or conditional proposal in untransformed space. - return DynamicPPL.UntransformedValue(rand(rng, dist)) + return DynamicPPL.TransformedValue(rand(rng, dist), DynamicPPL.NoTransform()) end else strategy.verbose && @info "varname $vn: no proposal specified, drawing from prior" @@ -459,12 +458,15 @@ end # Accumulator to store linked values; but only the ones that have a LinkedRW proposal. 
Since # model evaluation should have happened with `s.transform_strategy`, any variables that are -# marked by `s.transform_strategy` as being linked should generate a LinkedVectorValue here. +# marked by `s.transform_strategy` as being linked should generate a +# TransformedValue{V,DynamicLink} here. const MH_ACC_NAME = :MHLinkedValues -function store_linked_values(val, tval::DynamicPPL.LinkedVectorValue, logjac, vn, dist) +function store_linked_values( + val, tval::DynamicPPL.TransformedValue{V,DynamicPPL.DynamicLink}, logjac, vn, dist +) where {V} return DynamicPPL.get_internal_value(tval) end -function store_linked_values(val, ::DynamicPPL.AbstractTransformedValue, logjac, vn, dist) +function store_linked_values(val, ::DynamicPPL.TransformedValue, logjac, vn, dist) return DynamicPPL.DoNotAccumulate() end function MHLinkedValuesAccumulator() diff --git a/src/optimisation/init.jl b/src/optimisation/init.jl index 06109c349e..eda0f41c56 100644 --- a/src/optimisation/init.jl +++ b/src/optimisation/init.jl @@ -133,10 +133,8 @@ function DynamicPPL.init( # The inner `init` might (for whatever reason) return linked or otherwise # transformed values. We need to transform them back into to unlinked space, # so that we can check the constraints properly. 
- maybe_transformed_val = DynamicPPL.init(rng, vn, dist, c.actual_strategy) - proposed_val = DynamicPPL.get_transform(maybe_transformed_val)( - DynamicPPL.get_internal_value(maybe_transformed_val) - ) + tv = DynamicPPL.init(rng, vn, dist, c.actual_strategy) + proposed_val = DynamicPPL.get_raw_value(tv, dist) attempts = 1 while !satisfies_constraints(lb, ub, proposed_val, dist) if attempts >= MAX_ATTEMPTS @@ -146,13 +144,11 @@ function DynamicPPL.init( ), ) end - maybe_transformed_val = DynamicPPL.init(rng, vn, dist, c.actual_strategy) - proposed_val = DynamicPPL.get_transform(maybe_transformed_val)( - DynamicPPL.get_internal_value(maybe_transformed_val) - ) + tv = DynamicPPL.init(rng, vn, dist, c.actual_strategy) + proposed_val = DynamicPPL.get_raw_value(tv, dist) attempts += 1 end - return DynamicPPL.UntransformedValue(proposed_val) + return tv end can_have_linked_constraints(::Distribution) = false @@ -231,29 +227,35 @@ function DynamicPPL.accumulate_assume!!( ), ) end - transform = - if DynamicPPL.target_transform(acc.transform_strategy, vn) isa - DynamicPPL.DynamicLink - Bijectors.VectorBijectors.to_linked_vec(dist) - elseif DynamicPPL.target_transform(acc.transform_strategy, vn) isa DynamicPPL.Unlink - Bijectors.VectorBijectors.to_vec(dist) - else - error( - "don't know how to handle transform strategy $(acc.transform_strategy) for variable $(vn)", - ) - end + target_tfm = DynamicPPL.target_transform(acc.transform_strategy, vn) + transform_fn = if target_tfm isa DynamicPPL.DynamicLink + Bijectors.VectorBijectors.to_linked_vec(dist) + elseif target_tfm isa DynamicPPL.Unlink + Bijectors.VectorBijectors.to_vec(dist) + elseif target_tfm isa DynamicPPL.FixedTransform + Bijectors.inverse(target_tfm.transform) + else + error( + "don't know how to handle transform strategy $(acc.transform_strategy) for variable $(vn)", + ) + end # Transform the value and store it. 
- vectorised_val = transform(val) + vectorised_val = transform_fn(val) + if !(vectorised_val isa AbstractVector) + error( + "The transform strategy used ($(acc.transform_strategy)) generated a value for variable $(vn) that is not a vector; in general this cannot be handled by Turing", + ) + end acc.init_vecs[vn] = vectorised_val nelems = length(vectorised_val) # Then generate the constraints using the same transform. if lb !== nothing - acc.lb_vecs[vn] = transform(lb) + acc.lb_vecs[vn] = transform_fn(lb) else acc.lb_vecs[vn] = fill(-Inf, nelems) end if ub !== nothing - acc.ub_vecs[vn] = transform(ub) + acc.ub_vecs[vn] = transform_fn(ub) else acc.ub_vecs[vn] = fill(Inf, nelems) end diff --git a/src/variational/Variational.jl b/src/variational/Variational.jl index 0ff18483f8..94d4e9a904 100644 --- a/src/variational/Variational.jl +++ b/src/variational/Variational.jl @@ -14,7 +14,7 @@ using AdvancedVI: using ADTypes using Bijectors: Bijectors using Distributions -using DynamicPPL: DynamicPPL +using DynamicPPL: DynamicPPL, LogDensityFunction using LinearAlgebra using LogDensityProblems: LogDensityProblems using Random @@ -43,8 +43,8 @@ requires_unconstrained_space(::AdvancedVI.FisherMinBatchMatch) = true """ q_initialize_scale( - [rng::Random.AbstractRNG,] - model::DynamicPPL.Model, + rng::Random.AbstractRNG, + ldf::DynamicPPL.LogDensityFunction, location::AbstractVector, scale::AbstractMatrix, basedist::Distributions.UnivariateDistribution; @@ -53,7 +53,7 @@ requires_unconstrained_space(::AdvancedVI.FisherMinBatchMatch) = true reduce_factor::Real = one(eltype(scale)) / 2 ) -Given an initial location-scale distribution `q` formed by `location`, `scale`, and `basedist`, shrink `scale` until the expectation of log-densities of `model` taken over `q` are finite. +Given an initial location-scale distribution `q` formed by `location`, `scale`, and `basedist`, shrink `scale` until the expectation of log-densities of `ldf` taken over `q` are finite. 
If the log-densities are not finite even after `num_max_trials`, throw an error. For reference, a location-scale distribution \$q\$ formed by `location`, `scale`, and `basedist` is a distribution where its sampling process \$z \\sim q\$ can be represented as @@ -63,7 +63,7 @@ z = scale * u + location ``` # Arguments -- `model`: The target `DynamicPPL.Model`. +- `ldf`: The target log-density function. - `location`: The location parameter of the initialization. - `scale`: The scale parameter of the initialization. - `basedist`: The base distribution of the location-scale family. @@ -78,7 +78,7 @@ z = scale * u + location """ function q_initialize_scale( rng::Random.AbstractRNG, - model::DynamicPPL.Model, + ldf::LogDensityFunction, location::AbstractVector, scale::AbstractMatrix, basedist::Distributions.UnivariateDistribution; @@ -86,15 +86,16 @@ function q_initialize_scale( num_max_trials::Int=10, reduce_factor::Real=one(eltype(scale)) / 2, ) - prob = DynamicPPL.LogDensityFunction(model) - ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) - + num_max_trials > 0 || error("num_max_trials must be a positive integer") n_trial = 0 while true q = AdvancedVI.MvLocationScale(location, scale, basedist) - b = Bijectors.bijector(model) - q_trans = Bijectors.transformed(q, Bijectors.inverse(b)) - energy = mean(ℓπ, eachcol(rand(rng, q_trans, num_samples))) + energy = mean( + map(1:num_samples) do _ + z = rand(rng, q) + LogDensityProblems.logdensity(ldf, z) + end, + ) if isfinite(energy) return scale @@ -109,15 +110,15 @@ end """ q_locationscale( - [rng::Random.AbstractRNG,] - model::DynamicPPL.Model; + rng::Random.AbstractRNG, + ldf::DynamicPPL.LogDensityFunction; location::Union{Nothing,<:AbstractVector} = nothing, scale::Union{Nothing,<:Diagonal,<:LowerTriangular} = nothing, meanfield::Bool = true, basedist::Distributions.UnivariateDistribution = Normal() ) -Find a numerically non-degenerate variational distribution `q` for approximating the target `model` within the 
location-scale variational family formed by the type of `scale` and `basedist`. +Find a numerically non-degenerate variational distribution `q` for approximating the target `LogDensityFunction` within the location-scale variational family formed by the type of `scale` and `basedist`. The distribution can be manually specified by setting `location`, `scale`, and `basedist`. Otherwise, it chooses a Gaussian with zero-mean and scale `0.6*I` (covariance of `0.6^2*I`) by default. @@ -133,7 +134,7 @@ z = scale * u + location ``` # Arguments -- `model`: The target `DynamicPPL.Model`. +- `ldf`: The target log-density function. # Keyword Arguments - `location`: The location parameter of the initialization. If `nothing`, a vector of zeros is used. @@ -144,22 +145,18 @@ z = scale * u + location The remaining keywords are passed to `q_initialize_scale`. # Returns -- `q::Bijectors.TransformedDistribution`: A `AdvancedVI.LocationScale` distribution matching the support of `model`. +- An `AdvancedVI.LocationScale` distribution matching the support of `ldf`. """ function q_locationscale( rng::Random.AbstractRNG, - model::DynamicPPL.Model; + ldf::LogDensityFunction; location::Union{Nothing,<:AbstractVector}=nothing, scale::Union{Nothing,<:Diagonal,<:LowerTriangular}=nothing, meanfield::Bool=true, basedist::Distributions.UnivariateDistribution=Normal(), kwargs..., ) - varinfo = DynamicPPL.VarInfo(model) - # Use linked `varinfo` to determine the correct number of parameters. - # TODO: Replace with `length` once this is implemented for `VarInfo`. - varinfo_linked = DynamicPPL.link(varinfo, model) - num_params = length(varinfo_linked[:]) + num_params = LogDensityProblems.dimension(ldf) μ = if isnothing(location) zeros(num_params) @@ -171,11 +168,11 @@ function q_locationscale( L = if isnothing(scale) if meanfield q_initialize_scale( - rng, model, μ, Diagonal(fill(0.6, num_params)), basedist; kwargs... + rng, ldf, μ, Diagonal(fill(0.6, num_params)), basedist; kwargs... 
) else L0 = LowerTriangular(Matrix{Float64}(0.6 * I, num_params, num_params)) - q_initialize_scale(rng, model, μ, L0, basedist; kwargs...) + q_initialize_scale(rng, ldf, μ, L0, basedist; kwargs...) end else @assert size(scale) == (num_params, num_params) "Dimensions of the provided scale matrix, $(size(scale)), does not match the dimension of the target distribution, $(num_params)." @@ -185,32 +182,26 @@ function q_locationscale( LowerTriangular(Matrix(scale)) end end - q = AdvancedVI.MvLocationScale(μ, L, basedist) - b = Bijectors.bijector(model) - return Bijectors.transformed(q, Bijectors.inverse(b)) -end - -function q_locationscale(model::DynamicPPL.Model; kwargs...) - return q_locationscale(Random.default_rng(), model; kwargs...) + return AdvancedVI.MvLocationScale(μ, L, basedist) end """ q_meanfield_gaussian( - [rng::Random.AbstractRNG,] - model::DynamicPPL.Model; + rng::Random.AbstractRNG, + ldf::DynamicPPL.LogDensityFunction; location::Union{Nothing,<:AbstractVector} = nothing, scale::Union{Nothing,<:Diagonal} = nothing, kwargs... ) -Find a numerically non-degenerate mean-field Gaussian `q` for approximating the target `model`. +Find a numerically non-degenerate mean-field Gaussian `q` for approximating the target `ldf::LogDensityFunction`. If the `scale` set as `nothing`, the default value will be a zero-mean Gaussian with a `Diagonal` scale matrix (the "mean-field" approximation) no larger than `0.6*I` (covariance of `0.6^2*I`). This guarantees that the samples from the initial variational approximation will fall in the range of (-2, 2) with 99.9% probability, which mimics the behavior of the `Turing.InitFromUniform()` strategy. Whether the default choice is used or not, the `scale` may be adjusted via `q_initialize_scale` so that the log-densities of `model` are finite over the samples from `q`. # Arguments -- `model`: The target `DynamicPPL.Model`. +- `ldf`: The target log-density function. 
# Keyword Arguments - `location`: The location parameter of the initialization. If `nothing`, a vector of zeros is used. @@ -219,41 +210,37 @@ Whether the default choice is used or not, the `scale` may be adjusted via `q_in The remaining keyword arguments are passed to `q_locationscale`. # Returns -- `q::Bijectors.TransformedDistribution`: A `AdvancedVI.LocationScale` distribution matching the support of `model`. +- An `AdvancedVI.LocationScale` distribution matching the support of `ldf`. """ function q_meanfield_gaussian( rng::Random.AbstractRNG, - model::DynamicPPL.Model; + ldf::LogDensityFunction, location::Union{Nothing,<:AbstractVector}=nothing, scale::Union{Nothing,<:Diagonal}=nothing, kwargs..., ) return q_locationscale( - rng, model; location, scale, meanfield=true, basedist=Normal(), kwargs... + rng, ldf; location, scale, meanfield=true, basedist=Normal(), kwargs... ) end -function q_meanfield_gaussian(model::DynamicPPL.Model; kwargs...) - return q_meanfield_gaussian(Random.default_rng(), model; kwargs...) -end - """ q_fullrank_gaussian( - [rng::Random.AbstractRNG,] - model::DynamicPPL.Model; + rng::Random.AbstractRNG, + ldf::DynamicPPL.LogDensityFunction; location::Union{Nothing,<:AbstractVector} = nothing, scale::Union{Nothing,<:LowerTriangular} = nothing, kwargs... ) -Find a numerically non-degenerate Gaussian `q` with a scale with full-rank factors (traditionally referred to as a "full-rank family") for approximating the target `model`. +Find a numerically non-degenerate Gaussian `q` with a scale with full-rank factors (traditionally referred to as a "full-rank family") for approximating the target `ldf::LogDensityFunction`. If the `scale` set as `nothing`, the default value will be a zero-mean Gaussian with a `LowerTriangular` scale matrix (resulting in a covariance with "full-rank" factors) no larger than `0.6*I` (covariance of `0.6^2*I`). 
This guarantees that the samples from the initial variational approximation will fall in the range of (-2, 2) with 99.9% probability, which mimics the behavior of the `Turing.InitFromUniform()` strategy. Whether the default choice is used or not, the `scale` may be adjusted via `q_initialize_scale` so that the log-densities of `model` are finite over the samples from `q`. # Arguments -- `model`: The target `DynamicPPL.Model`. +- `ldf`: The target log-density function. # Keyword Arguments - `location`: The location parameter of the initialization. If `nothing`, a vector of zeros is used. @@ -262,50 +249,109 @@ Whether the default choice is used or not, the `scale` may be adjusted via `q_in The remaining keyword arguments are passed to `q_locationscale`. # Returns -- `q::Bijectors.TransformedDistribution`: A `AdvancedVI.LocationScale` distribution matching the support of `model`. +- An `AdvancedVI.LocationScale` distribution matching the support of `ldf`. """ function q_fullrank_gaussian( rng::Random.AbstractRNG, - model::DynamicPPL.Model; + ldf::LogDensityFunction, location::Union{Nothing,<:AbstractVector}=nothing, scale::Union{Nothing,<:LowerTriangular}=nothing, kwargs..., ) return q_locationscale( - rng, model; location, scale, meanfield=false, basedist=Normal(), kwargs... + rng, ldf; location, scale, meanfield=false, basedist=Normal(), kwargs... ) end -function q_fullrank_gaussian(model::DynamicPPL.Model; kwargs...) - return q_fullrank_gaussian(Random.default_rng(), model; kwargs...) +""" + VIResult(ldf, q, info, state) + +- `ldf`: A [`DynamicPPL.LogDensityFunction`](@extref) corresponding to the target model (the + original model can be accessed as `ldf.model`). If the VI process was run in unconstrained + space, this LogDensityFunction will also be in unconstrained space. +- `q`: Output variational distribution of `algorithm`. Note that, as above, this will + typically also be in unconstrained space. +- `state`: Collection of states used by `algorithm`. 
This can be used to resume from a past + call to `vi`. +- `info`: Information generated while executing `algorithm`. +""" +struct VIResult{L<:LogDensityFunction,Q<:Distribution,I<:AbstractArray{<:NamedTuple},S} + ldf::L + q::Q + info::I + state::S +end + +function Base.show(io::IO, ::MIME"text/plain", r::VIResult) + printstyled(io, "VIResult\n"; bold=true) + println(io, " ├ q : $(nameof(typeof(r.q)))") + n_iters = length(r.info) + println(io, " ├ info : $(length(r.info))-element $(typeof(r.info))") + if n_iters > 0 + println(io, " │ final iteration:") + last_info = r.info[end] + for (i, (k, v)) in enumerate(pairs(last_info)) + tree_char = i == length(last_info) ? "└" : "├" + println(io, " │ $(tree_char) $k = $v") + end + else + end + print(io, " └ (2 more fields: state, ldf)") + return nothing +end + +""" + Base.rand(rng::Random.AbstractRNG, res::VIResult, sz...) + +Draw a sample, or array of samples, from the variational distribution `q` in `res`. Each +sample is a [`DynamicPPL.VarNamedTuple`](@ref) containing parameter values (in original, +untransformed space). +""" +function Base.rand(rng::Random.AbstractRNG, res::VIResult, sz::Integer...) + # TODO(penelopeysm): Should we expose a way to get colon_eq results as well -- maybe a + # kwarg? + function to_vnt(v::AbstractVector) + pws = DynamicPPL.ParamsWithStats( + v, res.ldf; include_colon_eq=false, include_log_probs=false + ) + return pws.params + end + if sz == () + return to_vnt(rand(rng, res.q)) + else + # re. stack: https://github.com/TuringLang/AdvancedVI.jl/issues/245 + x = stack(rand(rng, res.q, sz...)) + return map(to_vnt, eachslice(x; dims=ntuple(i -> i + 1, length(sz)))) + end end +Base.rand(res::VIResult, sz::Integer...) = Base.rand(Random.default_rng(), res, sz...) 
""" vi( [rng::Random.AbstractRNG,] model::DynamicPPL.Model, - q, + family, max_iter::Int; adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE, algorithm::AdvancedVI.AbstractVariationalAlgorithm = KLMinRepGradProxDescent( adtype; n_samples=10 ), + unconstrained::Bool=requires_unconstrained_space(algorithm), + fix_transforms::Bool=false, show_progress::Bool = Turing.PROGRESS[], kwargs... ) -Approximate the target `model` via the variational inference algorithm `algorithm` by starting from the initial variational approximation `q`. +Approximate the target `model` via the variational inference algorithm `algorithm` using a variational family specified by `family`. This is a thin wrapper around `AdvancedVI.optimize`. -If the chosen variational inference algorithm operates in an unconstrained space, then the provided initial variational approximation `q` must be a `Bijectors.TransformedDistribution` of an unconstrained distribution. -For example, the initialization supplied by `q_meanfield_gaussian`,`q_fullrank_gaussian`, `q_locationscale`. - -The default `algorithm`, `KLMinRepGradProxDescent` ([relevant docs](https://turinglang.org/AdvancedVI.jl/dev/klminrepgradproxdescent/)), assumes `q` uses `AdvancedVI.MvLocationScale`, which can be constructed by invoking `q_fullrank_gaussian` or `q_meanfield_gaussian`. +The default `algorithm`, `KLMinRepGradProxDescent` ([relevant docs](https://turinglang.org/AdvancedVI.jl/dev/klminrepgradproxdescent/)), assumes `family` returns a `AdvancedVI.MvLocationScale`, which is true if `family` is `q_fullrank_gaussian` or `q_meanfield_gaussian`. For other variational families, refer to the documentation of `AdvancedVI` to determine the best algorithm and other options. # Arguments - `model`: The target `DynamicPPL.Model`. -- `q`: The initial variational approximation. +- `family`: A function which is used to generate an initial variational approximation. 
+ Existing choices in Turing are [`q_locationscale`](@ref), [`q_meanfield_gaussian`](@ref), and [`q_fullrank_gaussian`](@ref). - `max_iter`: Maximum number of steps. - Any additional arguments are passed on to `AdvancedVI.optimize`. @@ -314,20 +360,20 @@ For other variational families, refer to the documentation of `AdvancedVI` to de - `algorithm`: Variational inference algorithm. The default is `KLMinRepGradProxDescent`, please refer to [AdvancedVI docs](https://turinglang.org/AdvancedVI.jl/stable/) for all the options. - `show_progress`: Whether to show the progress bar. - `unconstrained`: Whether to transform the posterior to be unconstrained for running the variational inference algorithm. If `true`, then the output `q` will be wrapped into a `Bijectors.TransformedDistribution` with the transformation matching the support of the posterior. The default value depends on the chosen `algorithm`. -- Any additional keyword arguments are passed on to `AdvancedVI.optimize`. +- `fix_transforms`: Whether to precompute the transforms needed to convert model parameters to (possibly unconstrained) vectors. This can lead to performance improvements, but if any transforms depend on model parameters, setting `fix_transforms=true` can silently yield incorrect results. +- Any additional keyword arguments are passed on both to the function `initial_approx`, and also to `AdvancedVI.optimize`. See the docs of `AdvancedVI.optimize` for additional keyword arguments. # Returns -- `q`: Output variational distribution of `algorithm`. -- `state`: Collection of states used by `algorithm`. This can be used to resume from a past call to `vi`. -- `info`: Information generated while executing `algorithm`. + +A [`VIResult`](@ref) object: please see its docstring for information. 
""" function vi( rng::Random.AbstractRNG, model::DynamicPPL.Model, - q, + family, max_iter::Int, args...; adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE, @@ -335,40 +381,23 @@ function vi( adtype; n_samples=10 ), unconstrained::Bool=requires_unconstrained_space(algorithm), + fix_transforms::Bool=false, show_progress::Bool=PROGRESS[], kwargs..., ) - prob, q, trans = if unconstrained - if !(q isa Bijectors.TransformedDistribution) - throw( - ArgumentError( - "The algorithm $(algorithm) operates in an unconstrained space. Therefore, the initial variational approximation is expected to be a Bijectors.TransformedDistribution of an unconstrained distribution.", - ), - ) - end - vi = DynamicPPL.VarInfo(model) - vi = DynamicPPL.link!!(vi, model) - prob = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, vi; adtype - ) - prob, q.dist, q.transform - else - prob = DynamicPPL.LogDensityFunction(model; adtype) - prob, q, nothing - end + transform_strategy = unconstrained ? DynamicPPL.LinkAll() : DynamicPPL.UnlinkAll() + prob = LogDensityFunction( + model, DynamicPPL.getlogjoint_internal, transform_strategy; adtype, fix_transforms + ) + q = family(rng, prob; kwargs...) q, info, state = AdvancedVI.optimize( rng, algorithm, max_iter, prob, q, args...; show_progress=show_progress, kwargs... ) - q = if unconstrained - Bijectors.TransformedDistribution(q, trans) - else - q - end - return q, info, state + return VIResult(prob, q, info, state) end -function vi(model::DynamicPPL.Model, q, max_iter::Int; kwargs...) - return vi(Random.default_rng(), model, q, max_iter; kwargs...) +function vi(model::DynamicPPL.Model, family, max_iter::Int; kwargs...) + return vi(Random.default_rng(), model, family, max_iter; kwargs...) 
end end diff --git a/test/Project.toml b/test/Project.toml index 4456e58c9c..9cf5ee6dfc 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -54,7 +54,7 @@ Combinatorics = "1" DifferentiationInterface = "0.7" Distributions = "0.25" DynamicHMC = "2.1.6, 3.0" -DynamicPPL = "0.40.20" +DynamicPPL = "0.41" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10, 1" HypothesisTests = "0.11" @@ -78,3 +78,6 @@ StatsBase = "0.33, 0.34" StatsFuns = "0.9.5, 1" TimerOutputs = "0.5" julia = "1.10" + +[sources] +DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"} diff --git a/test/floattypes/Project.toml b/test/floattypes/Project.toml index eb8425d3a8..7b361a1629 100644 --- a/test/floattypes/Project.toml +++ b/test/floattypes/Project.toml @@ -5,3 +5,4 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [sources] Turing = {path = "../../"} +DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"} diff --git a/test/integration/enzyme/Project.toml b/test/integration/enzyme/Project.toml index 5069b35d89..2a7f3c47f3 100644 --- a/test/integration/enzyme/Project.toml +++ b/test/integration/enzyme/Project.toml @@ -9,3 +9,4 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [sources] Turing = {path = "../../../"} +DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"} From 1faa78da9ca5bb306f379bcddf401c32db560b17 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 16 Apr 2026 18:44:01 +0100 Subject: [PATCH 03/13] Remove sources --- Project.toml | 3 --- docs/Project.toml | 3 --- test/Project.toml | 3 --- test/floattypes/Project.toml | 1 - test/integration/enzyme/Project.toml | 1 - 5 files changed, 11 deletions(-) diff --git a/Project.toml b/Project.toml index 6c0acf9411..1bf1add086 100644 --- a/Project.toml +++ b/Project.toml @@ -84,6 +84,3 @@ julia = "1.10.8" [extras] DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" - -[sources] -DynamicPPL = {url = 
"https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"} diff --git a/docs/Project.toml b/docs/Project.toml index c9552ae361..8cf000eca9 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -8,6 +8,3 @@ OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" - -[sources] -DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"} diff --git a/test/Project.toml b/test/Project.toml index 9cf5ee6dfc..6562c81a2f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -78,6 +78,3 @@ StatsBase = "0.33, 0.34" StatsFuns = "0.9.5, 1" TimerOutputs = "0.5" julia = "1.10" - -[sources] -DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"} diff --git a/test/floattypes/Project.toml b/test/floattypes/Project.toml index 7b361a1629..eb8425d3a8 100644 --- a/test/floattypes/Project.toml +++ b/test/floattypes/Project.toml @@ -5,4 +5,3 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [sources] Turing = {path = "../../"} -DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"} diff --git a/test/integration/enzyme/Project.toml b/test/integration/enzyme/Project.toml index 2a7f3c47f3..5069b35d89 100644 --- a/test/integration/enzyme/Project.toml +++ b/test/integration/enzyme/Project.toml @@ -9,4 +9,3 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [sources] Turing = {path = "../../../"} -DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"} From dbc4d397037921e0e186e78a3ff85d7e36ce4a4c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 21 Apr 2026 13:30:30 +0100 Subject: [PATCH 04/13] Rework of Gibbs sampler interface (#2803) This PR does a minor overhaul of the Gibbs interface, in line with various plans that have been made over the months. 
Specifically, this PR makes the following changes: ## VNT over VarInfo Instead of `GibbsContext` holding a global **`VarInfo`**, which may contain linked or unlinked values, it now holds a global **`VarNamedTuple`** which always holds **raw** values. The motivation for this is twofold: 1. `VarInfo` holding linked values and cached transforms leads to correctness issues, such as #2801, which I discovered while trying to write this refactor. I can confirm that this PR fixes that issue. 2. `VarInfo` is slow because it does a lot of extra bookkeeping. The performance issues are in fact not limited to Gibbs: they leak through to every sampler which Gibbs uses, because Gibbs forces every component sampler to carry around a VarInfo so that it can talk to Gibbs. This refactor therefore also means that some component samplers will be faster (see below). ## Sampler <=> Gibbs interface The old interface that Gibbs expected for samplers was as follows: - `get_varinfo(state)` Return a VarInfo that the Gibbs sampler will then merge into the global VarInfo. - `setparams_varinfo!!(state, varinfo, ...)` Update the sampler's state with the new global VarInfo. In addition to this interface, Gibbs also had a function `match_linking!!`, which 'lined up' the transform status of the global VarInfo with each state's individual VarInfo. This PR directly replaces them with VNT equivalents: - `gibbs_get_raw_values(state)` Return a VNT that the Gibbs sampler will then merge into the global VNT. - `gibbs_update_state!!(state, vnt, ...)` Update the sampler's state with the new global VNT. `match_linking!!` is no longer required as a separate function: inside `gibbs_update_state!!` each sampler can just use its own transform strategy to reconstruct the correct state from the raw values. ## Component samplers Component samplers have been updated to stop using VarInfo.
## Performance Gibbs itself is slightly faster: ```julia using Turing @model function g() x ~ Normal() y ~ Normal() end @time sample(g(), Gibbs(:x => MH(), :y => MH()), 100_000; chain_type=Any, progress=false, verbose=false); ``` - main: 1.106086 seconds (23.50 M allocations: 1.398 GiB, 7.60% gc time) - this PR: 0.807256 seconds (17.60 M allocations: 896.448 MiB, 7.04% gc time) This speedup probably comes from just Gibbs, because performance of MH on these models is very similar (see below). On top of that, as promised above, this gives some nice speedups for other samplers, and specifically for ESS. These are trivial models. Naturally, with larger models the overhead from samplers will be smaller. I also specify `chain_type=Any` to cut out the constant overhead from MCMCChains, which is actually quite significant. ```julia using Turing @model function f() x ~ Normal() 1.0 ~ Normal(x) end @time sample(f(), spl, 100_000; chain_type=Any, progress=false, verbose=false); ``` `spl = ESS()`: - main: 0.711739 seconds (10.20 M allocations: 514.196 MiB, 5.33% gc time) - this PR: 0.166084 seconds (6.00 M allocations: 259.379 MiB, 9.10% gc time) I thought `spl = MH()` would also show some speedup, but it's way less drastic. - main: 0.226019 seconds (5.53 M allocations: 271.970 MiB, 9.79% gc time) - this PR: 0.218487 seconds (5.33 M allocations: 253.618 MiB, 8.29% gc time) SMC and PG have also been switched over to using `OnlyAccsVarInfo`, but like MH, there's no significant difference in timings for those either (presumably Libtask overhead dominates everything else). `spl = HMC(0.1, 20)` and `spl = NUTS()` are basically the same; they were already optimised before this.
## Closes Closes #2801 Closes #2764 Closes #2762 Closes #2642 --- Project.toml | 2 +- src/mcmc/emcee.jl | 6 +- src/mcmc/ess.jl | 150 ++++++++---- src/mcmc/external_sampler.jl | 59 +++-- src/mcmc/gibbs.jl | 393 +++++++++++++++----------------- src/mcmc/gibbs_conditional.jl | 83 ++++--- src/mcmc/hmc.jl | 104 ++++++--- src/mcmc/mh.jl | 55 +++-- src/mcmc/particle_mcmc.jl | 97 +++++--- src/mcmc/repeat_sampler.jl | 8 +- src/optimisation/init.jl | 13 +- src/optimisation/stats.jl | 2 +- src/variational/Variational.jl | 19 +- test/ad.jl | 5 +- test/essential/container.jl | 14 +- test/integration/enzyme/main.jl | 4 +- test/mcmc/Inference.jl | 19 +- test/mcmc/abstractmcmc.jl | 57 +++-- test/mcmc/gibbs.jl | 125 +++++----- test/mcmc/particle_mcmc.jl | 9 +- test/variational/vi.jl | 87 ++++--- 21 files changed, 730 insertions(+), 581 deletions(-) diff --git a/Project.toml b/Project.toml index 1bf1add086..74a3f43547 100644 --- a/Project.toml +++ b/Project.toml @@ -61,7 +61,7 @@ DifferentiationInterface = "0.7" Distributions = "0.25.77" DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" -DynamicPPL = "0.41" +DynamicPPL = "0.41.2" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3, 1" Libtask = "0.9.14" diff --git a/src/mcmc/emcee.jl b/src/mcmc/emcee.jl index a776579fbe..38373436d3 100644 --- a/src/mcmc/emcee.jl +++ b/src/mcmc/emcee.jl @@ -73,7 +73,11 @@ function AbstractMCMC.step( transition = if discard_sample nothing else - [DynamicPPL.ParamsWithStats(vi, model) for vi in vis] + [ + DynamicPPL.ParamsWithStats( + DynamicPPL.InitFromParams(DynamicPPL.get_values(vi)), model + ) for vi in vis + ] end linked_vi = DynamicPPL.link!!(vis[1], model) diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index 0fdaf7219f..eae57427d6 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -22,11 +22,21 @@ Mean """ struct ESS <: AbstractSampler end -struct TuringESSState{V<:DynamicPPL.AbstractVarInfo,VNT<:DynamicPPL.VarNamedTuple} - vi::V - priors::VNT +struct TuringESSState{ + 
L<:DynamicPPL.LogDensityFunction, + P<:AbstractVector{<:Real}, + R<:Real, + Va<:DynamicPPL.VarNamedTuple, + Vb<:DynamicPPL.VarNamedTuple, +} + ldf::L + params::P + loglikelihood::R + priors::Va + # Minor optimisation: we cache a VNT storing vectorised values here to avoid having to + # reconstruct it each time in `gibbs_update_state!!`. + _vector_vnt::Vb end -get_varinfo(state::TuringESSState) = state.vi # always accept in the first step function AbstractMCMC.step( @@ -37,20 +47,33 @@ function AbstractMCMC.step( initial_params, kwargs..., ) - vi = DynamicPPL.VarInfo() - vi = DynamicPPL.setacc!!(vi, DynamicPPL.RawValueAccumulator(true)) - prior_acc = DynamicPPL.PriorDistributionAccumulator() - prior_accname = DynamicPPL.accumulator_name(prior_acc) - vi = DynamicPPL.setacc!!(vi, prior_acc) - _, vi = DynamicPPL.init!!(rng, model, vi, initial_params, DynamicPPL.UnlinkAll()) - priors = DynamicPPL.get_priors(vi) - + # Add some extra accumulators so that we can compute everything we need to in one pass. + oavi = DynamicPPL.OnlyAccsVarInfo() + oavi = DynamicPPL.setacc!!(oavi, DynamicPPL.PriorDistributionAccumulator()) + oavi = DynamicPPL.setacc!!(oavi, DynamicPPL.RawValueAccumulator(true)) + oavi = DynamicPPL.setacc!!(oavi, DynamicPPL.VectorValueAccumulator()) + _, oavi = DynamicPPL.init!!(rng, model, oavi, initial_params, DynamicPPL.UnlinkAll()) + + # Check that priors are all Gaussian + priors = DynamicPPL.get_priors(oavi) for dist in values(priors) EllipticalSliceSampling.isgaussian(typeof(dist)) || error("ESS only supports Gaussian prior distributions") end - transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(vi, model) - return transition, TuringESSState(vi, priors) + + # Set up a LogDensityFunction which evaluates the model's log-likelihood. + # TODO(penelopeysm): We could conceivably use fixed transforms here because every prior + # distribution is Gaussian, which by definition must have a fixed bijector. 
Need to + # benchmark to see if it's worth it. + loglike_ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, oavi) + # And the corresponding vector of params, and the likelihood at this point + vecvals = DynamicPPL.get_vector_values(oavi) + vector_params = DynamicPPL.internal_values_as_vector(vecvals) + loglike = DynamicPPL.getloglikelihood(oavi) + + transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(oavi) + state = TuringESSState(loglike_ldf, vector_params, loglike, priors, vecvals) + return transition, state end function AbstractMCMC.step( @@ -61,54 +84,54 @@ function AbstractMCMC.step( discard_sample=false, kwargs..., ) - # obtain previous sample - vi = state.vi - f = vi[:] - # define previous sampler state # (do not use cache to avoid in-place sampling from prior) wrapped_state = EllipticalSliceSampling.ESSState( - f, DynamicPPL.getloglikelihood(vi), nothing + state.params, state.loglikelihood, nothing ) # compute next state sample, new_wrapped_state = AbstractMCMC.step( rng, EllipticalSliceSampling.ESSModel( - ESSPrior(model, vi, state.priors), ESSLikelihood(model, vi) + ESSPrior(state.ldf, state.priors), ESSLikelihood(state.ldf) ), EllipticalSliceSampling.ESS(), wrapped_state, ) - # update sample and log-likelihood - vi = DynamicPPL.unflatten!!(vi, sample) - vi = DynamicPPL.setloglikelihood!!(vi, new_wrapped_state.loglikelihood) - - transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(vi, model) - return transition, TuringESSState(vi, state.priors) + transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(sample, state.ldf) + new_state = TuringESSState( + state.ldf, sample, new_wrapped_state.loglikelihood, state.priors, state._vector_vnt + ) + return transition, new_state end +# NOTE: This is a quick and easy definition but it assumes that _vec(x) is the same as +# Bijectors.VectorBijectors.from_vec(dist) for all distributions we care about in the +# priors. 
If that ever becomes untrue, then this could silently cause bugs. _vec(x::Real) = [x] _vec(x::AbstractArray) = vec(x) # Prior distribution of considered random variable -struct ESSPrior{M<:Model,V<:AbstractVarInfo,T} - model::M - varinfo::V - μ::T - - function ESSPrior( - model::Model, varinfo::AbstractVarInfo, priors::DynamicPPL.VarNamedTuple - ) - μ = mapreduce(vcat, priors; init=Float64[]) do pair - prior_dist = pair.second - EllipticalSliceSampling.isgaussian(typeof(prior_dist)) || error( - "[ESS] only supports Gaussian prior distributions, but found $(typeof(prior_dist))", +struct ESSPrior{L<:DynamicPPL.LogDensityFunction,T<:AbstractVector{<:Real}} + ldf::L + means::T + + function ESSPrior(ldf::DynamicPPL.LogDensityFunction, priors::DynamicPPL.VarNamedTuple) + # Calculate means from priors. + means = fill(NaN, LogDensityProblems.dimension(ldf)) + for (vn, dist) in pairs(priors) + range = DynamicPPL.get_range_and_transform(ldf, vn).range + this_mean = _vec(mean(dist)) + means[range] .= this_mean + end + if any(isnan, means) + error( + "Some means were not filled in when constructing ESSPrior. This is likely a bug in Turing.jl, please report it.", ) - _vec(mean(prior_dist)) end - return new{typeof(model),typeof(varinfo),typeof(μ)}(model, varinfo, μ) + return new{typeof(ldf),typeof(means)}(ldf, means) end end @@ -117,14 +140,11 @@ EllipticalSliceSampling.isgaussian(::Type{<:ESSPrior}) = true # Only define out-of-place sampling function Base.rand(rng::Random.AbstractRNG, p::ESSPrior) - _, vi = DynamicPPL.init!!( - rng, p.model, p.varinfo, DynamicPPL.InitFromPrior(), DynamicPPL.UnlinkAll() - ) - return vi[:] + return Base.rand(rng, p.ldf) end # Mean of prior distribution -Distributions.mean(p::ESSPrior) = p.μ +Distributions.mean(p::ESSPrior) = p.means # Evaluate log-likelihood of proposals. 
We need this struct because # EllipticalSliceSampling.jl expects a callable struct / a function as its @@ -133,8 +153,13 @@ struct ESSLikelihood{L<:DynamicPPL.LogDensityFunction} ldf::L # Force usage of `getloglikelihood` in inner constructor - function ESSLikelihood(model::Model, varinfo::AbstractVarInfo) - ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, varinfo) + function ESSLikelihood(ldf::DynamicPPL.LogDensityFunction) + logp_callable = DynamicPPL.get_logdensity_callable(ldf) + if logp_callable !== DynamicPPL.getloglikelihood + error( + "The log-density function passed to ESSLikelihood must use `getloglikelihood` as its log-density function, but found $(logp_callable). This is likely a bug in Turing.jl, please report it!", + ) + end return new{typeof(ldf)}(ldf) end end @@ -155,3 +180,34 @@ function AbstractMCMC.step( "This method is not implemented! If you want to use the ESS sampler in Turing.jl, please use `Turing.ESS()` instead. If you want the default behaviour in EllipticalSliceSampling.jl, wrap your model in a different subtype of `AbstractMCMC.AbstractModel`, and then implement the necessary EllipticalSliceSampling.jl methods on it.", ) end + +#### +#### Gibbs interface +#### + +function gibbs_get_raw_values(state::TuringESSState) + pws = DynamicPPL.ParamsWithStats( + state.params, state.ldf; include_log_probs=false, include_colon_eq=false + ) + return pws.params +end + +function gibbs_update_state!!( + ::ESS, + state::TuringESSState, + model::DynamicPPL.Model, + global_vals::DynamicPPL.VarNamedTuple, +) + # We need to update everything in `state` except for the priors (which are constant). We + # pass an extra LogLikelihoodAccumulator here so that we can calculate the new loglike in + # one pass. 
+ new_ldf, new_params, accs = gibbs_recompute_ldf_and_params( + state.ldf, + model, + state._vector_vnt, + global_vals, + (DynamicPPL.LogLikelihoodAccumulator(),), + ) + new_loglike = DynamicPPL.getloglikelihood(accs) + return TuringESSState(new_ldf, new_params, new_loglike, state.priors, state._vector_vnt) +end diff --git a/src/mcmc/external_sampler.jl b/src/mcmc/external_sampler.jl index f654d1ea54..7509e26ddd 100644 --- a/src/mcmc/external_sampler.jl +++ b/src/mcmc/external_sampler.jl @@ -121,23 +121,17 @@ function externalsampler( return ExternalSampler(sampler, adtype, Val(unconstrained)) end -# TODO(penelopeysm): Can't we clean this up somehow? -struct TuringState{S,V,P<:AbstractVector,L<:DynamicPPL.LogDensityFunction} +struct TuringState{ + S,P<:AbstractVector,L<:DynamicPPL.LogDensityFunction,V<:DynamicPPL.VarNamedTuple +} state::S - # Note that this varinfo is used only for structure. Its parameters and other info do - # not need to be accurate - varinfo::V - # These are the actual parameters that this state is at params::P ldf::L + # Cached vector VNT, used to construct new LDFs in gibbs_update_state!! without + # reevaluating the model. Same role as HMCState._vector_vnt. + _vector_vnt::V end -# get_varinfo must return something from which the correct parameters can be obtained -function get_varinfo(state::TuringState) - return DynamicPPL.unflatten!!(state.varinfo, state.params) -end -get_varinfo(state::AbstractVarInfo) = state - function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, @@ -151,8 +145,11 @@ function AbstractMCMC.step( # Construct LogDensityFunction tfm_strategy = unconstrained ? 
DynamicPPL.LinkAll() : DynamicPPL.UnlinkAll() + oavi = DynamicPPL.OnlyAccsVarInfo(DynamicPPL.VectorValueAccumulator()) + _, oavi = DynamicPPL.init!!(model, oavi, DynamicPPL.InitFromPrior(), tfm_strategy) + vecvals = DynamicPPL.get_vector_values(oavi) f = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, tfm_strategy; adtype=sampler_wrapper.adtype + model, DynamicPPL.getlogjoint_internal, vecvals; adtype=sampler_wrapper.adtype ) x = find_initial_params_ldf(rng, f, initial_params) @@ -181,12 +178,7 @@ function AbstractMCMC.step( DynamicPPL.ParamsWithStats(new_parameters, f, new_stats) end - # TODO(penelopeysm): this varinfo is only needed for Gibbs. The external sampler itself - # has no use for it. Get rid of this as soon as possible. - vi = DynamicPPL.link!!(VarInfo(model), model) - vi = DynamicPPL.unflatten!!(vi, x) - - return (new_transition, TuringState(state_inner, vi, new_parameters, f)) + return (new_transition, TuringState(state_inner, new_parameters, f, vecvals)) end function AbstractMCMC.step( @@ -212,5 +204,32 @@ function AbstractMCMC.step( new_stats = AbstractMCMC.getstats(state_inner) DynamicPPL.ParamsWithStats(new_parameters, f, new_stats) end - return (new_transition, TuringState(state_inner, state.varinfo, new_parameters, f)) + return (new_transition, TuringState(state_inner, new_parameters, f, state._vector_vnt)) +end + +#### +#### Gibbs interface +#### + +function gibbs_get_raw_values(state::TuringState) + pws = DynamicPPL.ParamsWithStats( + state.params, state.ldf; include_log_probs=false, include_colon_eq=false + ) + return pws.params +end + +function gibbs_update_state!!( + ::ExternalSampler, + state::TuringState, + model::DynamicPPL.Model, + global_vals::DynamicPPL.VarNamedTuple, +) + new_ldf, new_params, _ = gibbs_recompute_ldf_and_params( + state.ldf, model, state._vector_vnt, global_vals + ) + # Update the inner sampler's state with the new parameters. 
+ new_inner_state = AbstractMCMC.setparams!!( + AbstractMCMC.LogDensityModel(new_ldf), state.state, new_params + ) + return TuringState(new_inner_state, new_params, new_ldf, state._vector_vnt) end diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 710ed0460b..5f9992eeb2 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -1,3 +1,7 @@ +################################################### +# Interface for other samplers to work with Gibbs # +################################################### + """ isgibbscomponent(spl::AbstractSampler) @@ -6,16 +10,82 @@ Return a boolean indicating whether `spl` is a valid component for a Gibbs sampl Defaults to `true` if no method has been defined for a particular sampler. """ isgibbscomponent(::AbstractSampler) = true - isgibbscomponent(spl::RepeatSampler) = isgibbscomponent(spl.sampler) isgibbscomponent(spl::ExternalSampler) = isgibbscomponent(spl.sampler) - isgibbscomponent(::Prior) = false isgibbscomponent(::Emcee) = false isgibbscomponent(::SGLD) = false isgibbscomponent(::SGHMC) = false isgibbscomponent(::SMC) = false +""" + Turing.Inference.gibbs_get_raw_values(state) + +Return a `VarNamedTuple` containing the raw values of all variables in the sampler state. +""" +function gibbs_get_raw_values(state::AbstractVarInfo) + return DynamicPPL.get_raw_values(state) +end + +""" + Turing.Inference.gibbs_update_state!!( + sampler::AbstractSampler, state, model::Model, global_vals::VarNamedTuple + ) + +Update the state of a Gibbs component sampler to be consistent with the new values in +`global_vals`. The exact meaning of this depends on what the sampler state contains. + +Each sampler should implement a method for its respective state type. +""" +function gibbs_update_state!! end + +""" + gibbs_recompute_ldf_and_params( + old_ldf, model, vector_vnt, global_vals, extra_accs + ) + +Shared helper that is used in `gibbs_update_state!!` for any sampler that uses a +LogDensityFunction. 
+ +Creates a new `LogDensityFunction` from the newly conditioned `model` (using a cached +`vector_vnt` to avoid an extra model evaluation), then reevaluates the model to obtain the +correct vectorised parameters corresponding to the raw values in `global_vals`. + +If extra information is needed (e.g. log-probabilities), `extra_accs` can be used to pass in +other accumulators to be used in the same model evaluation, to avoid having to recompute +them later. + +Returns `(new_ldf, new_params, accs)` where `accs` is the accumulator VarInfo after +evaluation, from which extra accumulators (e.g. `LogLikelihoodAccumulator`) can be read. +""" +function gibbs_recompute_ldf_and_params( + old_ldf::DynamicPPL.LogDensityFunction, + model::DynamicPPL.Model, + vector_vnt::DynamicPPL.VarNamedTuple, + global_vals::DynamicPPL.VarNamedTuple, + extra_accs::NTuple{N,<:DynamicPPL.AbstractAccumulator}=(), +) where {N} + new_ldf = DynamicPPL.LogDensityFunction( + model, + DynamicPPL.get_logdensity_callable(old_ldf), + vector_vnt; + adtype=old_ldf.adtype, + ) + accs = DynamicPPL.OnlyAccsVarInfo( + DynamicPPL.VectorParamAccumulator(new_ldf), extra_accs... + ) + init_strategy = DynamicPPL.InitFromParams(global_vals, nothing) + _, accs = DynamicPPL.init!!( + new_ldf.model, accs, init_strategy, new_ldf.transform_strategy + ) + new_params = DynamicPPL.get_vector_params(accs) + return new_ldf, new_params, accs +end + +############################### +# Gibbs implementation itself # +############################### + can_be_wrapped(::DynamicPPL.AbstractContext) = true can_be_wrapped(::DynamicPPL.AbstractParentContext) = false can_be_wrapped(ctx::DynamicPPL.PrefixContext) = can_be_wrapped(DynamicPPL.childcontext(ctx)) @@ -33,70 +103,70 @@ can_be_wrapped(ctx::DynamicPPL.PrefixContext) = can_be_wrapped(DynamicPPL.childc # - `GibbsContext` allows us to perform conditioning while still hit the `assume` pipeline # rather than the `observe` pipeline for the conditioned variables. 
""" - GibbsContext(target_varnames, global_varinfo, context) + GibbsContext(target_varnames, global_vnt, context) A context used in the implementation of the Turing.jl Gibbs sampler. There will be one `GibbsContext` for each iteration of a component sampler. -`target_varnames` is a a tuple of `VarName`s that the current component sampler -is sampling. For those `VarName`s, `GibbsContext` will just pass `tilde_assume!!` -calls to its child context. For other variables, their values will be fixed to -the values they have in `global_varinfo`. +`target_varnames` is a a tuple of `VarName`s that the current component sampler is sampling. +For those `VarName`s, `GibbsContext` will just pass `tilde_assume!!` calls to its child +context. For other variables, their values will be fixed to the values they have in +`global_vnt`. # Fields $(FIELDS) """ struct GibbsContext{ - VNs<:Tuple{Vararg{VarName}},GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext + VNs<:Tuple{Vararg{VarName}}, + GV<:Ref{<:DynamicPPL.VarNamedTuple}, + Ctx<:DynamicPPL.AbstractContext, } <: DynamicPPL.AbstractParentContext """ the VarNames being sampled """ target_varnames::VNs """ - a `Ref` to the global `AbstractVarInfo` object that holds values for all variables, both - those fixed and those being sampled. We use a `Ref` because this field may need to be - updated if new variables are introduced. + a `Ref` to the global `VarNamedTuple` object that holds raw values for all variables, + both those fixed and those being sampled. We use a `Ref` because this field may need + to be updated if new variables are introduced. """ - global_varinfo::GVI + global_vnt::GV """ the child context that tilde calls will eventually be passed onto. 
""" context::Ctx - function GibbsContext(target_varnames, global_varinfo, context) + function GibbsContext(target_varnames, global_vnt, context) if !can_be_wrapped(context) error("GibbsContext can only wrap a leaf or prefix context, not a $(context).") end target_varnames = tuple(target_varnames...) # Allow vectors. - return new{typeof(target_varnames),typeof(global_varinfo),typeof(context)}( - target_varnames, global_varinfo, context + return new{typeof(target_varnames),typeof(global_vnt),typeof(context)}( + target_varnames, global_vnt, context ) end end -function GibbsContext(target_varnames, global_varinfo) - return GibbsContext(target_varnames, global_varinfo, DynamicPPL.DefaultContext()) +function GibbsContext(target_varnames, global_vnt) + return GibbsContext(target_varnames, global_vnt, DynamicPPL.DefaultContext()) end DynamicPPL.childcontext(context::GibbsContext) = context.context function DynamicPPL.setchildcontext(context::GibbsContext, childcontext) - return GibbsContext( - context.target_varnames, Ref(context.global_varinfo[]), childcontext - ) + return GibbsContext(context.target_varnames, context.global_vnt, childcontext) end -get_global_varinfo(context::GibbsContext) = context.global_varinfo[] +get_global_vnt(context::GibbsContext) = context.global_vnt[] -function set_global_varinfo!(context::GibbsContext, new_global_varinfo) - context.global_varinfo[] = new_global_varinfo +function set_global_vnt!(context::GibbsContext, new_global_varinfo) + context.global_vnt[] = new_global_varinfo return nothing end # has and get function has_conditioned_gibbs(context::GibbsContext, vn::VarName) - return DynamicPPL.haskey(get_global_varinfo(context), vn) + return DynamicPPL.haskey(get_global_vnt(context), vn) end function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) num_conditioned = count(Iterators.map(Base.Fix1(has_conditioned_gibbs, context), vns)) @@ -110,7 +180,7 @@ function has_conditioned_gibbs(context::GibbsContext, 
vns::AbstractArray{<:VarNa end function get_conditioned_gibbs(context::GibbsContext, vn::VarName) - return get_global_varinfo(context)[vn] + return get_global_vnt(context)[vn] end function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) return map(Base.Fix1(get_conditioned_gibbs, context), vns) @@ -189,7 +259,7 @@ function DynamicPPL.tilde_assume!!( # variable. From the perspective of this sampler, this variable is # conditioned on, so we can just treat it as an observation. # The only catch is that the value that we need is to be obtained from - # the global VarInfo (since the local VarInfo has no knowledge of it). + # the global VNT (since the local VarInfo has no knowledge of it). # Note that tilde_observe!! will trigger resampling in particle methods # for variables that are handled by other Gibbs component samplers. val = get_conditioned_gibbs(context, vn) @@ -199,50 +269,49 @@ function DynamicPPL.tilde_assume!!( # presumably a new variable that should be sampled from its prior. We need to add # this new variable to the global `varinfo` of the context, but not to the local one # being used by the current sampler. - value, new_global_vi = DynamicPPL.tilde_assume!!( - # We assume that the new variable should just be sampled in unlinked space. - DynamicPPL.setleafcontext( - child_context, - DynamicPPL.InitContext(DynamicPPL.InitFromPrior(), DynamicPPL.UnlinkAll()), - ), - right, - vn, - template, - get_global_varinfo(context), - ) - set_global_varinfo!(context, new_global_vi) + # + # TODO(penelopeysm): How is the RNG controlled here? 
+ value = rand(right) + vnt = get_global_vnt(context) + vnt = DynamicPPL.templated_setindex!!(vnt, value, vn, template) + set_global_vnt!(context, vnt) + # Return the value (so that it can be used in the model), plus the unmodified local + # varinfo value, vi end end """ - make_conditional(model, target_variables, varinfo) + make_conditional(model, target_variables, global_vnt) Return a new, conditioned model for a component of a Gibbs sampler. # Arguments -# + - `model::DynamicPPL.Model`: The model to condition. - `target_variables::AbstractVector{<:VarName}`: The target variables of the component sampler. These will _not_ be conditioned. -- `varinfo::DynamicPPL.AbstractVarInfo`: Values for all variables in the model. All the - values in `varinfo` but not in `target_variables` will be conditioned to the values they - have in `varinfo`. +- `global_vnt::DynamicPPL.VarNamedTuple`: Raw values for all variables in the model, which + are used for all variables that are *not* in `target_variables`. # Returns + - A new model with the variables _not_ in `target_variables` conditioned. + - The `GibbsContext` object that will be used to condition the variables. This is necessary because evaluation can mutate its `global_varinfo` field, which we need to access later. """ function make_conditional( - model::DynamicPPL.Model, target_variables::AbstractVector{<:VarName}, varinfo + model::DynamicPPL.Model, + target_variables::AbstractVector{<:VarName}, + global_vnt::DynamicPPL.VarNamedTuple, ) # Insert the `GibbsContext` just before the leaf. # 1. Extract the `leafcontext` from `model` and wrap in `GibbsContext`. gibbs_context_inner = GibbsContext( - target_variables, Ref(varinfo), DynamicPPL.leafcontext(model.context) + target_variables, Ref(global_vnt), DynamicPPL.leafcontext(model.context) ) # 2. Set the leaf context to be the `GibbsContext` wrapping `leafcontext(model.context)`. 
gibbs_context = DynamicPPL.setleafcontext(model.context, gibbs_context_inner) @@ -275,6 +344,15 @@ Gibbs(@varname(x) => NUTS(), @varname(y) => MH()) Gibbs((@varname(x), :y) => NUTS(), :z => MH()) ``` +Note that all variables in the model should be handled by one or more samplers. The +behaviour of Gibbs when there are unhandled variables is undefined: depending on the version +of Turing, it may either crash, or it may sample once from the prior and not update values +after that. See https://github.com/TuringLang/Turing.jl/issues/2810 for more information. + +There is currently no way to specify a different initialisation strategy for for each +component sampler individually. When sampling with Gibbs, `initial_params` applies to the +model as a whole. + # Fields $(TYPEDFIELDS) """ @@ -307,13 +385,11 @@ function Gibbs(algs::Pair...) return Gibbs(map(first, algs), map(last, algs)) end -struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S} - vi::V +struct GibbsState{V<:DynamicPPL.VarNamedTuple,S} + vnt::V states::S end -get_varinfo(state::GibbsState) = state.vi - function check_all_variables_handled(vns, spl::Gibbs) handled_vars = Iterators.flatten(spl.varnames) missing_vars = [ @@ -344,20 +420,26 @@ function AbstractMCMC.step( ) varnames = spl.varnames samplers = spl.samplers - _, vi = DynamicPPL.init!!(rng, model, VarInfo(), initial_params, DynamicPPL.UnlinkAll()) + accs = DynamicPPL.OnlyAccsVarInfo(DynamicPPL.RawValueAccumulator(false)) + _, accs = DynamicPPL.init!!(rng, model, accs, initial_params, DynamicPPL.UnlinkAll()) + vnt = DynamicPPL.get_raw_values(accs) - vi, states = gibbs_initialstep_recursive( + vnt, states = gibbs_initialstep_recursive( rng, model, AbstractMCMC.step, varnames, samplers, - vi; + vnt; initial_params=initial_params, kwargs..., ) - transition = discard_sample ? 
nothing : DynamicPPL.ParamsWithStats(vi, model) - return transition, GibbsState(vi, states) + transition = if discard_sample + nothing + else + DynamicPPL.ParamsWithStats(DynamicPPL.InitFromParams(vnt), model) + end + return transition, GibbsState(vnt, states) end function AbstractMCMC.step_warmup( @@ -370,20 +452,27 @@ function AbstractMCMC.step_warmup( ) varnames = spl.varnames samplers = spl.samplers - _, vi = DynamicPPL.init!!(rng, model, VarInfo(), initial_params, DynamicPPL.UnlinkAll()) + # Sample a set of initial values + accs = DynamicPPL.OnlyAccsVarInfo(DynamicPPL.RawValueAccumulator(false)) + _, accs = DynamicPPL.init!!(rng, model, accs, initial_params, DynamicPPL.UnlinkAll()) + vnt = DynamicPPL.get_raw_values(accs) - vi, states = gibbs_initialstep_recursive( + vnt, states = gibbs_initialstep_recursive( rng, model, AbstractMCMC.step_warmup, varnames, samplers, - vi; + vnt; initial_params=initial_params, kwargs..., ) - transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(vi, model) - return transition, GibbsState(vi, states) + transition = if discard_sample + nothing + else + DynamicPPL.ParamsWithStats(DynamicPPL.InitFromParams(vnt), model) + end + return transition, GibbsState(vnt, states) end """ @@ -400,21 +489,21 @@ function gibbs_initialstep_recursive( step_function::Function, varname_vecs, samplers, - vi, + vnt, states=(); initial_params, kwargs..., ) # End recursion if isempty(varname_vecs) && isempty(samplers) - return vi, states + return vnt, states end varnames, varname_vecs_tail... = varname_vecs sampler, samplers_tail... = samplers # Construct the conditioned model. - conditioned_model, context = make_conditional(model, varnames, vi) + conditioned_model, context = make_conditional(model, varnames, vnt) # Take initial step with the current sampler. 
_, new_state = step_function( @@ -427,12 +516,12 @@ function gibbs_initialstep_recursive( kwargs..., discard_sample=true, ) - new_vi_local = get_varinfo(new_state) + new_vnt_local = gibbs_get_raw_values(new_state) # Merge in any new variables that were introduced during the step, but that # were not in the domain of the current sampler. - vi = merge(vi, get_global_varinfo(context)) + vnt = merge(vnt, get_global_vnt(context)) # Merge the new values for all the variables sampled by the current sampler. - vi = merge(vi, new_vi_local) + vnt = merge(vnt, new_vnt_local) states = (states..., new_state) return gibbs_initialstep_recursive( @@ -441,7 +530,7 @@ function gibbs_initialstep_recursive( step_function, varname_vecs_tail, samplers_tail, - vi, + vnt, states; initial_params=initial_params, kwargs..., @@ -456,17 +545,21 @@ function AbstractMCMC.step( discard_sample=false, kwargs..., ) - vi = get_varinfo(state) varnames = spl.varnames samplers = spl.samplers states = state.states @assert length(samplers) == length(state.states) - vi, states = gibbs_step_recursive( - rng, model, AbstractMCMC.step, varnames, samplers, states, vi; kwargs... + vnt, states = gibbs_step_recursive( + rng, model, AbstractMCMC.step, varnames, samplers, states, state.vnt; kwargs... ) - transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(vi, model) - return transition, GibbsState(vi, states) + + transition = if discard_sample + nothing + else + DynamicPPL.ParamsWithStats(DynamicPPL.InitFromParams(vnt), model) + end + return transition, GibbsState(vnt, states) end function AbstractMCMC.step_warmup( @@ -477,128 +570,27 @@ function AbstractMCMC.step_warmup( discard_sample=false, kwargs..., ) - vi = get_varinfo(state) varnames = spl.varnames samplers = spl.samplers states = state.states @assert length(samplers) == length(state.states) - vi, states = gibbs_step_recursive( - rng, model, AbstractMCMC.step_warmup, varnames, samplers, states, vi; kwargs... - ) - transition = discard_sample ? 
nothing : DynamicPPL.ParamsWithStats(vi, model) - return transition, GibbsState(vi, states) -end - -""" - setparams_varinfo!!(model::DynamicPPL.Model, sampler::AbstractSampler, state, params::AbstractVarInfo) - -A lot like AbstractMCMC.setparams!!, but instead of taking a vector of parameters, takes an -`AbstractVarInfo` object. Also takes the `sampler` as an argument. By default, falls back to -`AbstractMCMC.setparams!!(model, state, params[:])`. -""" -function setparams_varinfo!!( - model::DynamicPPL.Model, ::AbstractSampler, state, params::AbstractVarInfo -) - return AbstractMCMC.setparams!!(model, state, params[:]) -end - -function setparams_varinfo!!( - model::DynamicPPL.Model, spl::MH, ::AbstractVarInfo, params::AbstractVarInfo -) - # Setting `params` into `state` really just means using `params` itself, but we - # need to update the logprob. We also need to be a bit more careful, because - # the `state` here carries a VAIMAcc, which is needed for the MH step() function - # but may not be present in `params`. So we need to make sure that the value - # we return from this function also has a VAIMAcc which corresponds to the - # values in `params`. Likewise with the other MH-specific accumulators. - params = DynamicPPL.setacc!!(params, DynamicPPL.RawValueAccumulator(false)) - params = DynamicPPL.setacc!!(params, MHLinkedValuesAccumulator()) - params = DynamicPPL.setacc!!( - params, MHUnspecifiedPriorsAccumulator(spl.vns_with_proposal) - ) - # TODO(penelopeysm): Remove need for evaluate_nowarn here, by allowing MH-in-Gibbs to - # use OAVI. - return last(DynamicPPL.evaluate_nowarn!!(model, params)) -end - -function setparams_varinfo!!( - model::DynamicPPL.Model, ::ESS, state::TuringESSState, params::AbstractVarInfo -) - # The state is basically a VarInfo (plus a constant `priors` field), so we can just - # return `params`, but first we need to update its logprob. 
- # TODO(penelopeysm): Remove need for evaluate_nowarn here, by allowing ESS-in-Gibbs to - # use OAVI. - new_vi = last(DynamicPPL.evaluate_nowarn!!(model, params)) - return TuringESSState(new_vi, state.priors) -end - -function setparams_varinfo!!( - model::DynamicPPL.Model, - sampler::ExternalSampler, - state::TuringState, - params::AbstractVarInfo, -) - new_ldf = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, params; adtype=sampler.adtype - ) - new_inner_state = AbstractMCMC.setparams!!( - AbstractMCMC.LogDensityModel(new_ldf), state.state, params[:] + vnt, states = gibbs_step_recursive( + rng, + model, + AbstractMCMC.step_warmup, + varnames, + samplers, + states, + state.vnt; + kwargs..., ) - return TuringState(new_inner_state, params, params[:], new_ldf) -end - -function setparams_varinfo!!( - model::DynamicPPL.Model, sampler::Hamiltonian, state::HMCState, params::AbstractVarInfo -) - θ_new = params[:] - hamiltonian = get_hamiltonian(model, sampler, params, state, length(θ_new)) - - # Update the parameter values in `state.z`. - # TODO: Avoid mutation - z = state.z - resize!(z.θ, length(θ_new)) - z.θ .= θ_new - return HMCState(params, state.i, state.kernel, hamiltonian, z, state.adaptor, state.ldf) -end - -function setparams_varinfo!!( - ::DynamicPPL.Model, ::PG, state::PGState, params::AbstractVarInfo -) - return PGState(params, state.rng) -end - -""" - match_linking!!(varinfo_local, prev_state_local, model) - -Make sure the linked/invlinked status of varinfo_local matches that of the previous -state for this sampler. This is relevant when multiple samplers are sampling the same -variables, and one might need it to be linked while the other doesn't. 
-""" -function match_linking!!(varinfo_local, prev_state_local, model) - prev_varinfo_local = get_varinfo(prev_state_local) - # Get a set of all previously linked variables - linked_vns = Set{VarName}() - unlinked_vns = Set{VarName}() - for vn in keys(prev_varinfo_local) - if DynamicPPL.is_transformed(prev_varinfo_local, vn) - push!(linked_vns, vn) - else - push!(unlinked_vns, vn) - end - end - transform_strategy = if isempty(unlinked_vns) - # All variables were linked - DynamicPPL.LinkAll() - elseif isempty(linked_vns) - # No variables were linked - DynamicPPL.UnlinkAll() + transition = if discard_sample + nothing else - DynamicPPL.LinkSome( - linked_vns, DynamicPPL.UnlinkSome(unlinked_vns, DynamicPPL.LinkAll()) - ) + DynamicPPL.ParamsWithStats(DynamicPPL.InitFromParams(vnt), model) end - return DynamicPPL.update_link_status!!(varinfo_local, transform_strategy, model) + return transition, GibbsState(vnt, states) end """ @@ -615,34 +607,24 @@ function gibbs_step_recursive( varname_vecs, samplers, states, - global_vi, + global_vnt, new_states=(); kwargs..., ) # End recursion. if isempty(varname_vecs) && isempty(samplers) && isempty(states) - return global_vi, new_states + return global_vnt, new_states end varnames, varname_vecs_tail... = varname_vecs sampler, samplers_tail... = samplers state, states_tail... = states - # Construct the conditional model and the varinfo that this sampler should use. - conditioned_model, context = make_conditional(model, varnames, global_vi) - vi = DynamicPPL.subset(global_vi, varnames) - vi = match_linking!!(vi, state, conditioned_model) - - # TODO(mhauru) The below may be overkill. If the varnames for this sampler are not - # sampled by other samplers, we don't need to `setparams`, but could rather simply - # recompute the log probability. More over, in some cases the recomputation could also - # be avoided, if e.g. the previous sampler has done all the necessary work already. 
- # However, we've judged that doing any caching or other tricks to avoid this now would - # be premature optimization. In most use cases of Gibbs a single model call here is not - # going to be a significant expense anyway. - # Set the state of the current sampler, accounting for any changes made by other + # Construct the conditional model that this sampler should use. + conditioned_model, context = make_conditional(model, varnames, global_vnt) + # Update the sampler's state based on global values that were provided by other # samplers. - state = setparams_varinfo!!(conditioned_model, sampler, state, vi) + state = gibbs_update_state!!(sampler, state, conditioned_model, global_vnt) # Take a step with the local sampler. We don't need the actual sample, only the state. # Note that we pass `discard_sample=true` after `kwargs...`, because AbstractMCMC will @@ -652,10 +634,9 @@ function gibbs_step_recursive( rng, conditioned_model, sampler, state; kwargs..., discard_sample=true ) - new_vi_local = get_varinfo(new_state) - # Merge the latest values for all the variables in the current sampler. - new_global_vi = merge(get_global_varinfo(context), new_vi_local) - new_global_vi = DynamicPPL.setlogp!!(new_global_vi, DynamicPPL.getlogp(new_vi_local)) + # The current sampler will return some raw values, which we update the global VNT with. 
+ new_vnt_local = gibbs_get_raw_values(new_state) + new_global_vnt = merge(get_global_vnt(context), new_vnt_local) new_states = (new_states..., new_state) return gibbs_step_recursive( @@ -665,7 +646,7 @@ function gibbs_step_recursive( varname_vecs_tail, samplers_tail, states_tail, - new_global_vi, + new_global_vnt, new_states; kwargs..., ) diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index ff65e3a800..c3a53350c7 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -109,9 +109,7 @@ function build_values_vnt(model::DynamicPPL.Model) # model.args is a NamedTuple arg_vals = DynamicPPL.VarNamedTuple(model.args) # Extract values from the GibbsContext itself, as a VNT. - init_strat = DynamicPPL.InitFromParams( - get_gibbs_global_varinfo(context).values, nothing - ) + init_strat = DynamicPPL.InitFromParams(get_gibbs_global_vnt(context), nothing) oavi = DynamicPPL.OnlyAccsVarInfo((DynamicPPL.RawValueAccumulator(false),)) # We need to remove the Gibbs conditioning so that we can get all variables in the # accumulator (otherwise those that are conditioned on in `model` will not be included). @@ -131,47 +129,36 @@ function replace_gibbs_context(m::DynamicPPL.Model) return DynamicPPL.contextualize(m, replace_gibbs_context(m.context)) end -function get_gibbs_global_varinfo(context::GibbsContext) - return get_global_varinfo(context) +function get_gibbs_global_vnt(context::GibbsContext) + return get_global_vnt(context) end -function get_gibbs_global_varinfo(context::DynamicPPL.AbstractParentContext) - return get_gibbs_global_varinfo(DynamicPPL.childcontext(context)) +function get_gibbs_global_vnt(context::DynamicPPL.AbstractParentContext) + return get_gibbs_global_vnt(DynamicPPL.childcontext(context)) end -function get_gibbs_global_varinfo(::DynamicPPL.AbstractContext) +function get_gibbs_global_vnt(::DynamicPPL.AbstractContext) msg = """No GibbsContext found in context stack. 
Are you trying to use \ GibbsConditional outside of Gibbs? """ throw(ArgumentError(msg)) end -function Turing.Inference.initialstep( - ::Random.AbstractRNG, - model::DynamicPPL.Model, - ::GibbsConditional, - vi::DynamicPPL.VarInfo; - kwargs..., -) - state = DynamicPPL.is_transformed(vi) ? DynamicPPL.invlink(vi, model) : vi - # Since GibbsConditional is only used within Gibbs, it does not need to return a - # transition. - return nothing, state -end - -@inline _to_varnamedtuple(dists::NamedTuple, ::DynamicPPL.VarInfo) = +@inline _to_varnamedtuple(dists::NamedTuple, ::DynamicPPL.VarNamedTuple) = DynamicPPL.VarNamedTuple(dists) -@inline _to_varnamedtuple(dists::DynamicPPL.VarNamedTuple, ::DynamicPPL.VarInfo) = dists -function _to_varnamedtuple(dists::AbstractDict{<:VarName}, state::DynamicPPL.VarInfo) - template_vnt = state.values +@inline _to_varnamedtuple(dists::DynamicPPL.VarNamedTuple, ::DynamicPPL.VarNamedTuple) = + dists +function _to_varnamedtuple( + dists::AbstractDict{<:VarName}, raw_values::DynamicPPL.VarNamedTuple +) vnt = DynamicPPL.VarNamedTuple() for (vn, dist) in dists top_sym = AbstractPPL.getsym(vn) - template = get(template_vnt.data, top_sym, DynamicPPL.NoTemplate()) + template = get(raw_values.data, top_sym, DynamicPPL.NoTemplate()) vnt = DynamicPPL.templated_setindex!!(vnt, dist, vn, template) end return vnt end -function _to_varnamedtuple(dist::Distribution, state::DynamicPPL.VarInfo) - vns = keys(state) +function _to_varnamedtuple(dist::Distribution, raw_values::DynamicPPL.VarNamedTuple) + vns = keys(raw_values) if length(vns) > 1 msg = ( "In GibbsConditional, `get_cond_dists` returned a single distribution," * @@ -182,7 +169,7 @@ function _to_varnamedtuple(dist::Distribution, state::DynamicPPL.VarInfo) end vn = only(vns) top_sym = AbstractPPL.getsym(vn) - template = get(state.values.data, top_sym, DynamicPPL.NoTemplate()) + template = get(raw_values.data, top_sym, DynamicPPL.NoTemplate()) return 
DynamicPPL.templated_setindex!!(DynamicPPL.VarNamedTuple(), dist, vn, template) end @@ -192,14 +179,30 @@ end function DynamicPPL.init( rng::Random.AbstractRNG, vn::VarName, ::Distribution, init_strat::InitFromCondDists ) - return DynamicPPL.UntransformedValue(rand(rng, init_strat.cond_dists[vn])) + return DynamicPPL.TransformedValue( + rand(rng, init_strat.cond_dists[vn]), DynamicPPL.NoTransform() + ) +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + ::GibbsConditional; + initial_params, + kwargs..., +) + accs = DynamicPPL.OnlyAccsVarInfo(DynamicPPL.RawValueAccumulator(false)) + _, accs = DynamicPPL.init!!(rng, model, accs, initial_params, DynamicPPL.UnlinkAll()) + # Since GibbsConditional is only used within Gibbs, it does not need to return a + # transition. + return nothing, accs end function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, sampler::GibbsConditional, - state::DynamicPPL.VarInfo; + state::DynamicPPL.OnlyAccsVarInfo; kwargs..., ) # Get all the conditioned variable values from the model context. This is assumed to @@ -212,17 +215,25 @@ function AbstractMCMC.step( # - a VarNamedTuple of distributions # - a NamedTuple of distributions # - an AbstractDict mapping VarNames to distributions - conddists = _to_varnamedtuple(sampler.get_cond_dists(condvals), state) + raw_values = DynamicPPL.get_raw_values(state) + conddists = _to_varnamedtuple(sampler.get_cond_dists(condvals), raw_values) init_strategy = InitFromCondDists(conddists) - _, new_state = DynamicPPL.init!!(rng, model, state, init_strategy) + _, new_state = DynamicPPL.init!!( + rng, model, state, init_strategy, DynamicPPL.UnlinkAll() + ) # Since GibbsConditional is only used within Gibbs, it does not need to return a # transition. 
return nothing, new_state end -function setparams_varinfo!!( - ::DynamicPPL.Model, ::GibbsConditional, ::Any, params::DynamicPPL.VarInfo +function gibbs_update_state!!( + ::GibbsConditional, + state::DynamicPPL.OnlyAccsVarInfo, + ::DynamicPPL.Model, + ::DynamicPPL.VarNamedTuple, ) - return params + # Nothing in the state is used in the next iteration (we overwrite it immediately with + # init!! anyway), so we can just return the state as is. + return state end diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index aec5397dcc..e086400f03 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -8,28 +8,33 @@ Turing.allow_discrete_variables(sampler::Hamiltonian) = false ### struct HMCState{ - TV<:AbstractVarInfo, TKernel<:AHMC.HMCKernel, THam<:AHMC.Hamiltonian, PhType<:AHMC.PhasePoint, TAdapt<:AHMC.Adaptation.AbstractAdaptor, L<:DynamicPPL.LogDensityFunction, + V<:DynamicPPL.VarNamedTuple, } - vi::TV i::Int kernel::TKernel hamiltonian::THam z::PhType adaptor::TAdapt ldf::L + # TODO(penelopeysm): This field is needed for Gibbs because each time we call + # gibbs_update_state!! on this, we need to reconstruct a LogDensityFunction. + # In general this would require reevaluating the model, unless we supply a + # VarNamedTuple which already contains vectorised parameters. This can probably + # be improved in DynamicPPL, but for now we will just store an extra VNT in + # the state. + # NOTE: The actual values of this field should never be used or relied on! + _vector_vnt::V end ### ### Hamiltonian Monte Carlo samplers. 
### -get_varinfo(state::HMCState) = state.vi - """ HMC(ϵ::Float64, n_leapfrog::Int; adtype::ADTypes.AbstractADType = AutoForwardDiff()) @@ -162,10 +167,16 @@ function AbstractMCMC.step( verbose::Bool=true, kwargs..., ) - # Create a Hamiltonian + # Create a LogDensityFunction + oavi = DynamicPPL.OnlyAccsVarInfo(DynamicPPL.VectorValueAccumulator()) + _, oavi = DynamicPPL.init!!( + model, oavi, DynamicPPL.InitFromPrior(), DynamicPPL.LinkAll() + ) + vecvals = DynamicPPL.get_vector_values(oavi) ldf = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, DynamicPPL.LinkAll(); adtype=spl.adtype + model, DynamicPPL.getlogjoint_internal, vecvals; adtype=spl.adtype ) + # And a Hamiltonian metricT = getmetricT(spl) metric = metricT(LogDensityProblems.dimension(ldf)) lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf) @@ -193,12 +204,7 @@ function AbstractMCMC.step( DynamicPPL.ParamsWithStats(theta, ldf, NamedTuple()) end - # TODO(penelopeysm): this varinfo is only needed for Gibbs. HMC itself has no use for - # it. Get rid of this as soon as possible. - vi = DynamicPPL.link!!(VarInfo(model), model) - vi = DynamicPPL.unflatten!!(vi, theta) - - state = HMCState(vi, 0, kernel, hamiltonian, z, adaptor, ldf) + state = HMCState(0, kernel, hamiltonian, z, adaptor, ldf, vecvals) return transition, state end @@ -236,33 +242,19 @@ function AbstractMCMC.step( kernel = state.kernel end - # Update variables - vi = state.vi - if t.stat.is_accept - vi = DynamicPPL.unflatten!!(vi, t.z.θ) - end - # Compute next transition and state. 
transition = if discard_sample nothing else DynamicPPL.ParamsWithStats(t.z.θ, state.ldf, t.stat) end - newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor, state.ldf) + newstate = HMCState( + i, kernel, hamiltonian, t.z, state.adaptor, state.ldf, state._vector_vnt + ) return transition, newstate end -function get_hamiltonian(model, spl, vi, state, n) - metric = gen_metric(n, spl, state) - ldf = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype - ) - lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf) - lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf) - return AHMC.Hamiltonian(metric, lp_func, lp_grad_func) -end - """ HMCDA( n_adapts::Int, δ::Float64, λ::Float64; ϵ::Float64 = 0.0; @@ -422,16 +414,15 @@ end ##### getstepsize(sampler::Hamiltonian, state) = sampler.ϵ -getstepsize(sampler::AdaptiveHamiltonian, state) = AHMC.getϵ(state.adaptor) +getstepsize(::AdaptiveHamiltonian, state) = AHMC.getϵ(state.adaptor) function getstepsize( - sampler::AdaptiveHamiltonian, - state::HMCState{TV,TKernel,THam,PhType,AHMC.Adaptation.NoAdaptation}, -) where {TV,TKernel,THam,PhType} + ::AdaptiveHamiltonian, state::HMCState{TKernel,THam,PhType,AHMC.Adaptation.NoAdaptation} +) where {TKernel,THam,PhType} return state.kernel.τ.integrator.ϵ end -gen_metric(dim::Int, spl::Hamiltonian, state) = AHMC.UnitEuclideanMetric(dim) -function gen_metric(dim::Int, spl::AdaptiveHamiltonian, state) +gen_metric(dim::Int, ::Hamiltonian, state) = AHMC.UnitEuclideanMetric(dim) +function gen_metric(::Int, ::AdaptiveHamiltonian, state) return AHMC.renew(state.hamiltonian.metric, AHMC.getM⁻¹(state.adaptor.pc)) end @@ -480,3 +471,46 @@ end function AHMCAdaptor(::Hamiltonian, ::AHMC.AbstractMetric, nadapts::Int; kwargs...) 
return AHMC.Adaptation.NoAdaptation() end + +#### +#### Gibbs interface +#### + +function gibbs_get_raw_values(state::HMCState) + # In general this needs reevaluation (unless the LDF has all fixed transforms -- + # DynamicPPL handles this.) + pws = DynamicPPL.ParamsWithStats( + state.z.θ, state.ldf; include_log_probs=false, include_colon_eq=false + ) + return pws.params +end + +function gibbs_update_state!!( + spl::Hamiltonian, + state::HMCState, + model::DynamicPPL.Model, + global_vals::DynamicPPL.VarNamedTuple, +) + # Construct a new LDF with the newly conditioned `model` (not `state.ldf.model`, which + # contains stale conditioned values) and recompute the vectorised params. + new_ldf, new_params, _ = gibbs_recompute_ldf_and_params( + state.ldf, model, state._vector_vnt, global_vals + ) + # Update the Hamiltonian (because that depends on the LDF). + metric = gen_metric(LogDensityProblems.dimension(new_ldf), spl, state) + lp_func = Base.Fix1(LogDensityProblems.logdensity, new_ldf) + lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, new_ldf) + new_hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func) + # Apart from the Hamiltonian, we also need to update the position variables. It would be + # nice to do this without mutating, but it's probably fine for now. + state.z.θ .= new_params + return HMCState( + state.i, + state.kernel, + new_hamiltonian, + state.z, + state.adaptor, + new_ldf, + state._vector_vnt, + ) +end diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index fb13804ebf..42f946962b 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -204,7 +204,7 @@ function DynamicPPL.init( else strategy.verbose && @info "varname $vn: no proposal specified, drawing from prior" # No proposal was specified for this variable, so we sample from the prior. 
- return DynamicPPL.UntransformedValue(rand(rng, prior)) + return DynamicPPL.TransformedValue(rand(rng, prior), DynamicPPL.NoTransform()) end end @@ -268,22 +268,8 @@ function AbstractMCMC.step( verbose=true, kwargs..., ) - # Generate and return initial parameters. We need to use VAIMAcc because that will - # generate the VNT for us that provides the values (as opposed to `vi.values` which - # stores `AbstractTransformedValues`). - # - # TODO(penelopeysm): This in fact could very well be OnlyAccsVarInfo. Indeed, if you - # only run MH, OnlyAccsVarInfo already works right now. The problem is that using MH - # inside Gibbs needs a full VarInfo. - # - # see e.g. - # @model f() = x ~ Beta(2, 2) - # sample(f(), MH(:x => LinkedRW(0.4)), 100_000; progress=false) - # with full VarInfo: - # 2.302728 seconds (18.81 M allocations: 782.125 MiB, 9.00% gc time) - # with OnlyAccsVarInfo: - # 1.196674 seconds (18.51 M allocations: 722.256 MiB, 5.11% gc time) - vi = DynamicPPL.VarInfo() + # Generate and return initial parameters. + vi = DynamicPPL.OnlyAccsVarInfo() vi = DynamicPPL.setacc!!(vi, DynamicPPL.RawValueAccumulator(false)) vi = DynamicPPL.setacc!!(vi, MHLinkedValuesAccumulator()) vi = DynamicPPL.setacc!!(vi, MHUnspecifiedPriorsAccumulator(spl.vns_with_proposal)) @@ -342,7 +328,7 @@ function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, spl::MH, - old_vi::DynamicPPL.AbstractVarInfo; + old_vi::DynamicPPL.OnlyAccsVarInfo; discard_sample=false, kwargs..., ) @@ -359,7 +345,7 @@ function AbstractMCMC.step( ) # Evaluate the model with a new proposal. 
- new_vi = DynamicPPL.VarInfo() + new_vi = DynamicPPL.OnlyAccsVarInfo() new_vi = DynamicPPL.setacc!!(new_vi, DynamicPPL.RawValueAccumulator(false)) new_vi = DynamicPPL.setacc!!(new_vi, MHLinkedValuesAccumulator()) new_vi = DynamicPPL.setacc!!( @@ -401,7 +387,7 @@ end """ log_proposal_density( - old_vi::DynamicPPL.AbstractVarInfo, + old_vi::DynamicPPL.OnlyAccsVarInfo, init_strategy_given_new::DynamicPPL.AbstractInitStrategy, old_unspecified_priors::DynamicPPL.VarNamedTuple ) @@ -421,14 +407,14 @@ from: distribution. """ function log_proposal_density( - vi::DynamicPPL.AbstractVarInfo, ::DynamicPPL.InitFromPrior, ::DynamicPPL.VarNamedTuple + vi::DynamicPPL.OnlyAccsVarInfo, ::DynamicPPL.InitFromPrior, ::DynamicPPL.VarNamedTuple ) # All samples were drawn from the prior -- in this case g(x|x') = g(x) = prior # probability of x. return DynamicPPL.getlogprior(vi) end function log_proposal_density( - vi::DynamicPPL.AbstractVarInfo, + vi::DynamicPPL.OnlyAccsVarInfo, strategy::InitFromProposals, unspecified_priors::DynamicPPL.VarNamedTuple, ) @@ -500,3 +486,28 @@ end function MH(cov_matrix::Any) return externalsampler(AdvancedMH.RWMH(MvNormal(cov_matrix)); unconstrained=true) end + +#### +#### Gibbs interface +#### + +function gibbs_update_state!!( + spl::MH, + state::AbstractVarInfo, + model::DynamicPPL.Model, + global_vals::DynamicPPL.VarNamedTuple, +) + # `state` here is an AbstractVarInfo; the MH sampler since Turing v0.40 only uses + # the accumulator part of the state. We do need to reevaluate the model though + # because it's necessary for the log-probability to be updated to reflect the new + # values in `global_vals`. If we don't do that, the MH acceptance step will return + # wrong results. + # + # In order to make sure that our logjacs are consistent with what the sampler expects, + # we can use `spl.transform_strategy` here. + # + # TODO(penelopeysm): Is the `nothing` fallback OK, or do we need InitFromPrior as + # a fallback?
In the latter case, how do we control `rng`? + init_strat = DynamicPPL.InitFromParams(global_vals, nothing) + return last(DynamicPPL.init!!(model, state, init_strat, spl.transform_strategy)) +end diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 190d1f0933..414f422296 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -20,6 +20,10 @@ end struct ParticleMCMCContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext rng::R end +# Because pMCMC uses OnlyAccsVarInfo, we need to overload this. It's fine to use Any (see +# the docstring of get_param_eltype in DynamicPPL) because pMCMC doesn't involve AD or any +# other tracer types. +DynamicPPL.get_param_eltype(::DynamicPPL.AbstractVarInfo, ::ParticleMCMCContext) = Any mutable struct TracedModel{M<:Model,T<:Tuple,NT<:NamedTuple} <: AdvancedPS.AbstractGenericModel @@ -156,23 +160,23 @@ function AbstractMCMC.sample( ) end -function Turing.Inference.initialstep( +function AbstractMCMC.step( rng::AbstractRNG, model::DynamicPPL.Model, - spl::SMC, - vi::AbstractVarInfo; + spl::SMC; nparticles::Int, + initial_params, discard_sample=false, kwargs..., ) - # Reset the VarInfo. - vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) - vi = DynamicPPL.empty!!(vi) + # Create an empty VarInfo + accs = DynamicPPL.OnlyAccsVarInfo() + accs = DynamicPPL.setacc!!(accs, ProduceLogLikelihoodAccumulator()) + accs = DynamicPPL.setacc!!(accs, DynamicPPL.RawValueAccumulator(true)) # Create a new set of particles. 
particles = AdvancedPS.ParticleContainer( - # her - [AdvancedPS.Trace(model, vi, AdvancedPS.TracedRNG(), true) for _ in 1:nparticles], + [AdvancedPS.Trace(model, accs, AdvancedPS.TracedRNG(), true) for _ in 1:nparticles], AdvancedPS.TracedRNG(), rng, ) @@ -189,7 +193,7 @@ function Turing.Inference.initialstep( transition = if discard_sample nothing else - DynamicPPL.ParamsWithStats(deepcopy(particle.model.f.varinfo), model, stats) + DynamicPPL.ParamsWithStats(particle.model.f.varinfo, stats) end state = SMCState(particles, 2, logevidence) @@ -217,7 +221,7 @@ function AbstractMCMC.step( transition = if discard_sample nothing else - DynamicPPL.ParamsWithStats(deepcopy(particle.model.f.varinfo), model, stats) + DynamicPPL.ParamsWithStats(deepcopy(particle.model.f.varinfo), stats) end nextstate = SMCState(state.particles, index + 1, state.average_logevidence) @@ -272,29 +276,24 @@ Equivalent to [`PG`](@ref). """ const CSMC = PG # type alias of PG as Conditional SMC -struct PGState - vi::AbstractVarInfo - rng::Random.AbstractRNG +struct PGState{V<:DynamicPPL.AbstractVarInfo,R<:Random.AbstractRNG} + vi::V + rng::R end -get_varinfo(state::PGState) = state.vi - -function Turing.Inference.initialstep( - rng::AbstractRNG, - model::DynamicPPL.Model, - spl::PG, - vi::AbstractVarInfo; - discard_sample=false, - kwargs..., +function AbstractMCMC.step( + rng::AbstractRNG, model::DynamicPPL.Model, spl::PG; discard_sample=false, kwargs... 
) error_if_threadsafe_eval(model) - vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) + oavi = DynamicPPL.OnlyAccsVarInfo() + oavi = DynamicPPL.setacc!!(oavi, ProduceLogLikelihoodAccumulator()) + oavi = DynamicPPL.setacc!!(oavi, DynamicPPL.RawValueAccumulator(true)) # Create a new set of particles num_particles = spl.nparticles particles = AdvancedPS.ParticleContainer( [ - AdvancedPS.Trace(model, vi, AdvancedPS.TracedRNG(), true) for + AdvancedPS.Trace(model, oavi, AdvancedPS.TracedRNG(), true) for _ in 1:num_particles ], AdvancedPS.TracedRNG(), @@ -314,7 +313,7 @@ function Turing.Inference.initialstep( transition = if discard_sample nothing else - DynamicPPL.ParamsWithStats(deepcopy(_vi), model, (; logevidence=logevidence)) + DynamicPPL.ParamsWithStats(deepcopy(_vi), (; logevidence=logevidence)) end return transition, PGState(_vi, reference.rng) @@ -328,18 +327,26 @@ function AbstractMCMC.step( discard_sample=false, kwargs..., ) - # Reset the VarInfo before new sweep. - vi = state.vi - vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) + # Reset log-prob accs in reference particle, to avoid accumulating into the same accs + # across iterations. If the chosen particle for this iteration is the reference + # particle, this allows us to just read off the log-probs from the accumulators, + # without having to re-evaluate the model. + reference_vi = state.vi + reference_vi = DynamicPPL.setacc!!(reference_vi, ProduceLogLikelihoodAccumulator()) + reference_vi = DynamicPPL.setacc!!(reference_vi, DynamicPPL.LogPriorAccumulator()) + reference_vi = DynamicPPL.setacc!!(reference_vi, DynamicPPL.LogJacobianAccumulator()) # Create reference particle for which the samples will be retained. - reference = AdvancedPS.forkr(AdvancedPS.Trace(model, vi, state.rng, false)) + reference = AdvancedPS.forkr(AdvancedPS.Trace(model, reference_vi, state.rng, false)) - # Create a new set of particles. 
+ # Create a new set of particles with newly emptied accs + empty_accs = DynamicPPL.OnlyAccsVarInfo() + empty_accs = DynamicPPL.setacc!!(empty_accs, ProduceLogLikelihoodAccumulator()) + empty_accs = DynamicPPL.setacc!!(empty_accs, DynamicPPL.RawValueAccumulator(true)) num_particles = spl.nparticles x = map(1:num_particles) do i if i != num_particles - return AdvancedPS.Trace(model, vi, AdvancedPS.TracedRNG(), true) + return AdvancedPS.Trace(model, empty_accs, AdvancedPS.TracedRNG(), true) else return reference end @@ -359,7 +366,7 @@ function AbstractMCMC.step( transition = if discard_sample nothing else - DynamicPPL.ParamsWithStats(deepcopy(_vi), model, (; logevidence=logevidence)) + DynamicPPL.ParamsWithStats(deepcopy(_vi), (; logevidence=logevidence)) end return transition, PGState(_vi, newreference.rng) @@ -428,12 +435,14 @@ function DynamicPPL.tilde_assume!!( trng = get_trace_local_rng() resample = get_trace_local_resampled() # Modify the varinfo as appropriate. - dispatch_ctx = if ~haskey(vi, vn) || resample - DynamicPPL.InitContext(trng, DynamicPPL.InitFromPrior(), DynamicPPL.UnlinkAll()) + values = DynamicPPL.get_raw_values(vi) + init_strat = if ~haskey(values, vn) || resample + DynamicPPL.InitFromPrior() else - DynamicPPL.DefaultContext() + DynamicPPL.InitFromParams(values, nothing) end - x, vi = DynamicPPL.tilde_assume!!(dispatch_ctx, dist, vn, template, vi) + ctx = DynamicPPL.InitContext(trng, init_strat, DynamicPPL.UnlinkAll()) + x, vi = DynamicPPL.tilde_assume!!(ctx, dist, vn, template, vi) # Set the varinfo back in the trace. set_trace_local_varinfo(vi) return x, vi @@ -527,3 +536,19 @@ Libtask.@might_produce(DynamicPPL.tilde_assume!!) # DynamicPPL.Model as an argument, so we can just check for that. See # https://github.com/TuringLang/Libtask.jl/issues/217. 
Libtask.might_produce_if_sig_contains(::Type{<:DynamicPPL.Model}) = true + +#### +#### Gibbs interface +#### + +function gibbs_get_raw_values(state::PGState) + return DynamicPPL.get_raw_values(state.vi) +end + +function gibbs_update_state!!( + ::PG, state::PGState, model::DynamicPPL.Model, global_vals::DynamicPPL.VarNamedTuple +) + init_strat = DynamicPPL.InitFromParams(global_vals, nothing) + new_vi = last(DynamicPPL.init!!(model, state.vi, init_strat, DynamicPPL.UnlinkAll())) + return PGState(new_vi, state.rng) +end diff --git a/src/mcmc/repeat_sampler.jl b/src/mcmc/repeat_sampler.jl index 494f455b0d..351014f24a 100644 --- a/src/mcmc/repeat_sampler.jl +++ b/src/mcmc/repeat_sampler.jl @@ -24,13 +24,13 @@ struct RepeatSampler{S<:AbstractMCMC.AbstractSampler} <: AbstractMCMC.AbstractSa end end -function setparams_varinfo!!( - model::DynamicPPL.Model, +function gibbs_update_state!!( sampler::RepeatSampler, state, - params::DynamicPPL.AbstractVarInfo, + model::DynamicPPL.Model, + global_vnt::DynamicPPL.VarNamedTuple, ) - return setparams_varinfo!!(model, sampler.sampler, state, params) + return gibbs_update_state!!(sampler.sampler, state, model, global_vnt) end function AbstractMCMC.step( diff --git a/src/optimisation/init.jl b/src/optimisation/init.jl index eda0f41c56..e3cf11ce90 100644 --- a/src/optimisation/init.jl +++ b/src/optimisation/init.jl @@ -287,17 +287,6 @@ function DynamicPPL.combine(acc1::ConstraintAccumulator, acc2::ConstraintAccumul return combined end -function _get_ldf_range(ldf::LogDensityFunction, vn::VarName) - if haskey(ldf._varname_ranges, vn) - return ldf._varname_ranges[vn].range - elseif haskey(ldf._iden_varname_ranges, AbstractPPL.getsym(vn)) - return ldf._iden_varname_ranges[AbstractPPL.getsym(vn)].range - else - # Should not happen. 
- error("could not find range for variable name $(vn) in LogDensityFunction") - end -end - """ make_optim_bounds_and_init( rng::Random.AbstractRNG, @@ -340,7 +329,7 @@ function make_optim_bounds_and_init( lb = fill(et(-Inf), nelems) ub = fill(et(Inf), nelems) for (vn, init_val) in constraint_acc.init_vecs - range = _get_ldf_range(ldf, vn) + range = DynamicPPL.get_range_and_transform(ldf, vn).range inits[range] = init_val if haskey(constraint_acc.lb_vecs, vn) lb[range] = constraint_acc.lb_vecs[vn] diff --git a/src/optimisation/stats.jl b/src/optimisation/stats.jl index cacb511361..0cf1562d69 100644 --- a/src/optimisation/stats.jl +++ b/src/optimisation/stats.jl @@ -37,7 +37,7 @@ function vector_names_and_params(m::ModeResult) priors = DynamicPPL.get_priors(accs) vector_varnames = Vector{VarName}(undef, length(vector_params)) for (vn, dist) in pairs(priors) - range = ldf._varname_ranges[vn].range + range = DynamicPPL.get_range_and_transform(ldf, vn).range optics = optic_vec(dist) # Really shouldn't happen, but catch in case optic_vec isn't properly defined if any(isnothing, optics) diff --git a/src/variational/Variational.jl b/src/variational/Variational.jl index 94d4e9a904..ac00667e6b 100644 --- a/src/variational/Variational.jl +++ b/src/variational/Variational.jl @@ -110,7 +110,7 @@ end """ q_locationscale( - rng::Random.AbstractRNG, + [rng::Random.AbstractRNG,] ldf::DynamicPPL.LogDensityFunction; location::Union{Nothing,<:AbstractVector} = nothing, scale::Union{Nothing,<:Diagonal,<:LowerTriangular} = nothing, @@ -184,10 +184,13 @@ function q_locationscale( end return AdvancedVI.MvLocationScale(μ, L, basedist) end +function q_locationscale(ldf::LogDensityFunction; kwargs...) + return q_locationscale(Random.default_rng(), ldf; kwargs...) 
+end """ q_meanfield_gaussian( - rng::Random.AbstractRNG, + [rng::Random.AbstractRNG,] ldf::DynamicPPL.LogDensityFunction; location::Union{Nothing,<:AbstractVector} = nothing, scale::Union{Nothing,<:Diagonal} = nothing, @@ -214,7 +217,7 @@ The remaining keyword arguments are passed to `q_locationscale`. """ function q_meanfield_gaussian( rng::Random.AbstractRNG, - ldf::LogDensityFunction, + ldf::LogDensityFunction; location::Union{Nothing,<:AbstractVector}=nothing, scale::Union{Nothing,<:Diagonal}=nothing, kwargs..., @@ -223,10 +226,13 @@ function q_meanfield_gaussian( rng, ldf; location, scale, meanfield=true, basedist=Normal(), kwargs... ) end +function q_meanfield_gaussian(ldf::LogDensityFunction; kwargs...) + return q_meanfield_gaussian(Random.default_rng(), ldf; kwargs...) +end """ q_fullrank_gaussian( - rng::Random.AbstractRNG, + [rng::Random.AbstractRNG,] ldf::DynamicPPL.LogDensityFunction; location::Union{Nothing,<:AbstractVector} = nothing, scale::Union{Nothing,<:LowerTriangular} = nothing, @@ -253,7 +259,7 @@ The remaining keyword arguments are passed to `q_locationscale`. """ function q_fullrank_gaussian( rng::Random.AbstractRNG, - ldf::LogDensityFunction, + ldf::LogDensityFunction; location::Union{Nothing,<:AbstractVector}=nothing, scale::Union{Nothing,<:LowerTriangular}=nothing, kwargs..., @@ -262,6 +268,9 @@ function q_fullrank_gaussian( rng, ldf; location, scale, meanfield=false, basedist=Normal(), kwargs... ) end +function q_fullrank_gaussian(ldf::LogDensityFunction; kwargs...) + return q_fullrank_gaussian(Random.default_rng(), ldf; kwargs...) 
+end """ VIResult(ldf, q, info, state) diff --git a/test/ad.jl b/test/ad.jl index 8d87eea3d8..1f0f4648e3 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -149,7 +149,6 @@ function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInf param_eltype = DynamicPPL.get_param_eltype(vi, context) valids = valid_eltypes(context) if !(any(param_eltype .<: valids)) - @show context throw(IncompatibleADTypeError(param_eltype, adtype(context))) end end @@ -247,11 +246,11 @@ end # with a gradient-based sampler (say HMC(0.1, 10)). # This means we need to construct one with only `s`, and one model with # only `m`. - global_vi = DynamicPPL.VarInfo(model) + global_vnt = rand(model) @testset for varnames in ([@varname(s)], [@varname(m)]) @info "Testing Gibbs AD with model=$(model.f), varnames=$varnames" conditioned_model = Turing.Inference.make_conditional( - model, varnames, deepcopy(global_vi) + model, varnames, deepcopy(global_vnt) ) @test run_ad( model, adtype; rng=StableRNG(123), test=true, benchmark=false diff --git a/test/essential/container.jl b/test/essential/container.jl index 19609b6b51..3554526c97 100644 --- a/test/essential/container.jl +++ b/test/essential/container.jl @@ -18,11 +18,12 @@ using Turing end @testset "constructor" begin - vi = DynamicPPL.VarInfo() - vi = DynamicPPL.setacc!!(vi, Turing.Inference.ProduceLogLikelihoodAccumulator()) + accs = DynamicPPL.OnlyAccsVarInfo() + accs = DynamicPPL.setacc!!(accs, Turing.Inference.ProduceLogLikelihoodAccumulator()) + accs = DynamicPPL.setacc!!(accs, DynamicPPL.RawValueAccumulator(true)) sampler = PG(10) model = test() - trace = AdvancedPS.Trace(model, vi, AdvancedPS.TracedRNG(), false) + trace = AdvancedPS.Trace(model, accs, AdvancedPS.TracedRNG(), false) # Make sure the backreference from taped_globals to the trace is in place. 
@test trace.model.ctask.taped_globals.other === trace @@ -43,12 +44,13 @@ using Turing 1.5 ~ Normal(b, 2) return a, b end - vi = DynamicPPL.VarInfo() - vi = DynamicPPL.setacc!!(vi, Turing.Inference.ProduceLogLikelihoodAccumulator()) + accs = DynamicPPL.OnlyAccsVarInfo() + accs = DynamicPPL.setacc!!(accs, Turing.Inference.ProduceLogLikelihoodAccumulator()) + accs = DynamicPPL.setacc!!(accs, DynamicPPL.RawValueAccumulator(true)) sampler = PG(10) model = normal() - trace = AdvancedPS.Trace(model, vi, AdvancedPS.TracedRNG(), false) + trace = AdvancedPS.Trace(model, accs, AdvancedPS.TracedRNG(), false) newtrace = AdvancedPS.forkr(trace) # Catch broken replay mechanism diff --git a/test/integration/enzyme/main.jl b/test/integration/enzyme/main.jl index 93a6a4c16a..6d8ec66213 100644 --- a/test/integration/enzyme/main.jl +++ b/test/integration/enzyme/main.jl @@ -16,11 +16,11 @@ MODELS = DynamicPPL.TestUtils.DEMO_MODELS @testset verbose = true "AD / GibbsContext" begin @testset "adtype=$adtype_name" for (adtype_name, adtype) in ADTYPES @testset "model=$(model.f)" for model in MODELS - global_vi = DynamicPPL.VarInfo(model) + global_vnt = rand(model) @testset for varnames in ([@varname(s)], [@varname(m)]) @info "Testing Gibbs AD with adtype=$(adtype_name), model=$(model.f), varnames=$varnames" conditioned_model = Turing.Inference.make_conditional( - model, varnames, deepcopy(global_vi) + model, varnames, deepcopy(global_vnt) ) @test run_ad( model, adtype; rng=StableRNG(468), test=true, benchmark=false diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 73f83040fc..9d7cabd083 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -72,20 +72,27 @@ using Turing @testset "save/resume correctly reloads state" begin struct StaticSampler <: AbstractMCMC.AbstractSampler end - function Turing.Inference.initialstep(rng, model, ::StaticSampler, vi; kwargs...) 
- return DynamicPPL.ParamsWithStats(vi, model), vi + function AbstractMCMC.step( + rng::Random.AbstractRNG, model::DynamicPPL.Model, ::StaticSampler; kwargs... + ) + t = DynamicPPL.ParamsWithStats(rand(rng, model), (;)) + return t, t end function AbstractMCMC.step( - rng, model, ::StaticSampler, vi::DynamicPPL.AbstractVarInfo; kwargs... + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + ::StaticSampler, + t::DynamicPPL.ParamsWithStats; + kwargs..., ) - return DynamicPPL.ParamsWithStats(vi, model), vi + return t, t end @model demo() = x ~ Normal() @testset "single-chain" begin chn1 = sample(demo(), StaticSampler(), 10; save_state=true) - @test chn1.info.samplerstate isa DynamicPPL.AbstractVarInfo + @test chn1.info.samplerstate isa DynamicPPL.ParamsWithStats chn2 = sample(demo(), StaticSampler(), 10; initial_state=loadstate(chn1)) xval = chn1[:x][1] @test all(chn2[:x] .== xval) @@ -95,7 +102,7 @@ using Turing chn1 = sample( demo(), StaticSampler(), MCMCThreads(), 10, nchains; save_state=true ) - @test chn1.info.samplerstate isa AbstractVector{<:DynamicPPL.AbstractVarInfo} + @test chn1.info.samplerstate isa AbstractVector{<:DynamicPPL.ParamsWithStats} @test length(chn1.info.samplerstate) == nchains chn2 = sample( demo(), diff --git a/test/mcmc/abstractmcmc.jl b/test/mcmc/abstractmcmc.jl index 04478b58f2..68bd71fcf8 100644 --- a/test/mcmc/abstractmcmc.jl +++ b/test/mcmc/abstractmcmc.jl @@ -83,14 +83,19 @@ end struct OnlyInitDefault <: OnlyInit end struct OnlyInitUniform <: OnlyInit end Turing.Inference.init_strategy(::OnlyInitUniform) = InitFromUniform() - function Turing.Inference.initialstep( + function AbstractMCMC.step( rng::AbstractRNG, model::DynamicPPL.Model, - ::OnlyInit, - vi::DynamicPPL.VarInfo=DynamicPPL.VarInfo(rng, model); + ::OnlyInit; + initial_params::DynamicPPL.AbstractInitStrategy, kwargs..., ) - return vi, nothing + accs = DynamicPPL.OnlyAccsVarInfo() + accs = DynamicPPL.setacc!!(accs, DynamicPPL.RawValueAccumulator(false)) + _, accs = 
DynamicPPL.init!!( + rng, model, accs, initial_params, DynamicPPL.UnlinkAll() + ) + return accs, nothing end @testset "init_strategy" begin @@ -108,10 +113,10 @@ end model = coinflip() lptrue = logpdf(Binomial(25, 0.2), 10) let inits = InitFromParams((; p=0.2)) - varinfos = sample(model, spl, 1; initial_params=inits, progress=false) - varinfo = only(varinfos) - @test varinfo[@varname(p)] == 0.2 - @test DynamicPPL.getlogjoint(varinfo) == lptrue + oavis = sample(model, spl, 1; initial_params=inits, progress=false) + oavi = only(oavis) + @test DynamicPPL.get_raw_values(oavi)[@varname(p)] == 0.2 + @test DynamicPPL.getlogjoint(oavi) == lptrue # parallel sampling chains = sample( @@ -124,9 +129,9 @@ end progress=false, ) for c in chains - varinfo = only(c) - @test varinfo[@varname(p)] == 0.2 - @test DynamicPPL.getlogjoint(varinfo) == lptrue + oavi = only(c) + @test DynamicPPL.get_raw_values(oavi)[@varname(p)] == 0.2 + @test DynamicPPL.getlogjoint(oavi) == lptrue end end @@ -152,10 +157,11 @@ end Dict(@varname(s) => 4, @varname(m) => -1), ) chain = sample(model, spl, 1; initial_params=inits, progress=false) - varinfo = only(chain) - @test varinfo[@varname(s)] == 4 - @test varinfo[@varname(m)] == -1 - @test DynamicPPL.getlogjoint(varinfo) == lptrue + oavi = only(chain) + vnt = DynamicPPL.get_raw_values(oavi) + @test vnt[@varname(s)] == 4 + @test vnt[@varname(m)] == -1 + @test DynamicPPL.getlogjoint(oavi) == lptrue # parallel sampling chains = sample( @@ -168,10 +174,11 @@ end progress=false, ) for c in chains - varinfo = only(c) - @test varinfo[@varname(s)] == 4 - @test varinfo[@varname(m)] == -1 - @test DynamicPPL.getlogjoint(varinfo) == lptrue + oavi = only(c) + vnt = DynamicPPL.get_raw_values(oavi) + @test vnt[@varname(s)] == 4 + @test vnt[@varname(m)] == -1 + @test DynamicPPL.getlogjoint(oavi) == lptrue end end @@ -187,9 +194,9 @@ end Dict(@varname(m) => -1), ) chain = sample(model, spl, 1; initial_params=inits, progress=false) - varinfo = only(chain) - @test 
!ismissing(varinfo[@varname(s)]) - @test varinfo[@varname(m)] == -1 + vnt = DynamicPPL.get_raw_values(only(chain)) + @test !ismissing(vnt[@varname(s)]) + @test vnt[@varname(m)] == -1 # parallel sampling chains = sample( @@ -202,9 +209,9 @@ end progress=false, ) for c in chains - varinfo = only(c) - @test !ismissing(varinfo[@varname(s)]) - @test varinfo[@varname(m)] == -1 + vnt = DynamicPPL.get_raw_values(only(c)) + @test !ismissing(vnt[@varname(s)]) + @test vnt[@varname(m)] == -1 end end end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index c9e5ed1393..8fd3d34a0d 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -66,16 +66,16 @@ end ) @testset "$(target_vns)" for target_vns in target_vn_combinations - global_varinfo = DynamicPPL.VarInfo(model) + global_vnt = rand(model) target_vns = collect(target_vns) - local_varinfo = DynamicPPL.subset(global_varinfo, target_vns) + local_vnt = DynamicPPL.subset(global_vnt, target_vns) ctx = Turing.Inference.GibbsContext( - target_vns, Ref(global_varinfo), DynamicPPL.DefaultContext() + target_vns, Ref(global_vnt), DynamicPPL.DefaultContext() ) # Check that the correct varnames are conditioned, and that getting their # values is type stable when the varinfo is. - for k in keys(global_varinfo) + for k in keys(global_vnt) is_target = any(Iterators.map(vn -> DynamicPPL.subsumes(vn, k), target_vns)) @test Turing.Inference.is_target_varname(ctx, k) == is_target if !is_target @@ -87,7 +87,7 @@ end for k in all_varnames # The map(identity, ...) part is there to concretise the eltype. subkeys = map( - identity, filter(vn -> DynamicPPL.subsumes(k, vn), keys(global_varinfo)) + identity, filter(vn -> DynamicPPL.subsumes(k, vn), keys(global_vnt)) ) is_target = (k in target_vns) @test Turing.Inference.is_target_varname(ctx, subkeys) == is_target @@ -96,14 +96,12 @@ end end end - # Check that evaluate_nowarn!! and the result it returns are type stable. + # Check that init!! is type stable. 
conditioned_model = DynamicPPL.contextualize(model, ctx) - _, post_eval_varinfo = @inferred DynamicPPL.evaluate_nowarn!!( - conditioned_model, local_varinfo + accs = DynamicPPL.OnlyAccsVarInfo() + _, accs = @inferred DynamicPPL.init!!( + conditioned_model, accs, DynamicPPL.InitFromPrior(), DynamicPPL.UnlinkAll() ) - for k in keys(post_eval_varinfo) - @inferred post_eval_varinfo[k] - end end end end @@ -156,14 +154,17 @@ end # Methods we need to define to be able to use AlgWrapper instead of an actual algorithm. # They all just propagate the call to the inner algorithm. - Inference.isgibbscomponent(wrap::AlgWrapper) = Inference.isgibbscomponent(wrap.inner) - function Inference.setparams_varinfo!!( - model::DynamicPPL.Model, + Turing.Inference.isgibbscomponent(wrap::AlgWrapper) = + Turing.Inference.isgibbscomponent(wrap.inner) + function Turing.Inference.gibbs_update_state!!( sampler::AlgWrapper, state, - params::DynamicPPL.AbstractVarInfo, + model::DynamicPPL.Model, + global_vnt::DynamicPPL.VarNamedTuple, ) - return Inference.setparams_varinfo!!(model, sampler.inner, state, params) + return Turing.Inference.gibbs_update_state!!( + sampler.inner, state, model, global_vnt + ) end # targets_and_algs will be a list of tuples, where the first element is the target_vns @@ -193,17 +194,6 @@ end return AbstractMCMC.step(rng, model, sampler.inner, args...; kwargs...) end - function Turing.Inference.initialstep( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - sampler::AlgWrapper, - args...; - kwargs..., - ) - capture_targets_and_algs(sampler.inner, model.context) - return Turing.Inference.initialstep(rng, model, sampler.inner, args...; kwargs...) - end - struct Wrapper{T<:Real} a::T end @@ -296,57 +286,44 @@ end Turing.Inference.isgibbscomponent(::WarmupCounter) = true - # A trivial state that holds nothing but a VarInfo, to be used with WarmupCounter. 
- struct VarInfoState{T} - vi::T - end - - Turing.Inference.get_varinfo(state::VarInfoState) = state.vi - function Turing.Inference.setparams_varinfo!!( - ::DynamicPPL.Model, - ::WarmupCounter, - ::VarInfoState, - params::DynamicPPL.AbstractVarInfo, - ) - return VarInfoState(params) - end - - function AbstractMCMC.step( - ::Random.AbstractRNG, model::DynamicPPL.Model, spl::WarmupCounter; kwargs... + # we need some state type to implement the Gibbs interface (we can't just use `nothing`) + struct TrivialState end + Turing.Inference.gibbs_get_raw_values(::TrivialState) = VarNamedTuple() + function Turing.Inference.gibbs_update_state!!( + ::WarmupCounter, s::TrivialState, ::DynamicPPL.Model, ::DynamicPPL.VarNamedTuple ) - spl.non_warmup_init_count += 1 - vi = DynamicPPL.VarInfo(model) - return (DynamicPPL.ParamsWithStats(vi, model), VarInfoState(vi)) - end - - function AbstractMCMC.step_warmup( - ::Random.AbstractRNG, model::DynamicPPL.Model, spl::WarmupCounter; kwargs... - ) - spl.warmup_init_count += 1 - vi = DynamicPPL.VarInfo(model) - return (DynamicPPL.ParamsWithStats(vi, model), VarInfoState(vi)) + return s end function AbstractMCMC.step( - ::Random.AbstractRNG, + rng::Random.AbstractRNG, model::DynamicPPL.Model, spl::WarmupCounter, - s::VarInfoState; + state::Union{Nothing,TrivialState}=nothing; kwargs..., ) - spl.non_warmup_count += 1 - return DynamicPPL.ParamsWithStats(s.vi, model), s + if state === nothing + spl.non_warmup_init_count += 1 + else + spl.non_warmup_count += 1 + end + # no need for a transition since we never check the actual outputs + return nothing, TrivialState() end function AbstractMCMC.step_warmup( ::Random.AbstractRNG, model::DynamicPPL.Model, spl::WarmupCounter, - s::VarInfoState; + state::Union{Nothing,TrivialState}=nothing; kwargs..., ) - spl.warmup_count += 1 - return DynamicPPL.ParamsWithStats(s.vi, model), s + if state === nothing + spl.warmup_init_count += 1 + else + spl.warmup_count += 1 + end + return nothing, TrivialState() end
@model f() = x ~ Normal() @@ -839,6 +816,30 @@ end sampler = Gibbs(:w => HMC(0.05, 10)) @test (sample(model, sampler, 10); true) end + + @testset "dynamic transformations with linked samplers" begin + # See https://github.com/TuringLang/Turing.jl/issues/2801. + # The issue there was that the linked value for `y` was never updated when `x` + # changed, even though it should (the transform for `y` depends on `x`), leading to + # incorrect results. + @model function dyn() + x ~ Uniform(-5, 5) + return y ~ truncated(Normal(); lower=x) + end + model = dyn() + + for spl in ( + Gibbs(:x => MH(), :y => HMC(0.1, 20)), + Gibbs(:x => MH(), :y => MH(:y => LinkedRW(1.0))), + ) + chn = sample( + StableRNG(468), model, spl, MCMCThreads(), 100000, 4; verbose=false + ) + # ground truth obtained from NUTS + @test mean(chn[:x]) ≈ 0.0 atol = 0.1 + @test mean(chn[:y]) ≈ 1.5 atol = 0.1 + end + end end end diff --git a/test/mcmc/particle_mcmc.jl b/test/mcmc/particle_mcmc.jl index 8b3c68ff15..71e0b84f7d 100644 --- a/test/mcmc/particle_mcmc.jl +++ b/test/mcmc/particle_mcmc.jl @@ -24,7 +24,7 @@ using Turing @test s.resampler === resample_systematic end - @testset "models" begin + @testset "basic model" begin @model function normal() a ~ Normal(4, 5) 3 ~ Normal(a, 2) @@ -32,11 +32,10 @@ using Turing 1.5 ~ Normal(b, 2) return a, b end - tested = sample(normal(), SMC(), 100) + end - # TODO(mhauru) This needs an explanation for why it fails. 
- # failing test + @testset "errors when number of observations is not fixed" begin @model function fail_smc() a ~ Normal(4, 5) 3 ~ Normal(a, 2) @@ -46,8 +45,8 @@ using Turing end return a, b end - @test_throws ErrorException sample(fail_smc(), SMC(), 100) + @test_throws "number of observations" sample(fail_smc(), SMC(), 100) end @testset "chain log-density metadata" begin diff --git a/test/variational/vi.jl b/test/variational/vi.jl index fff1d0118b..f856a7e94d 100644 --- a/test/variational/vi.jl +++ b/test/variational/vi.jl @@ -4,10 +4,13 @@ module AdvancedVITests using ..Models: gdemo_default using ..NumericalTests: check_gdemo +using AbstractMCMC: AbstractMCMC using AdvancedVI using Bijectors: Bijectors using Distributions: Dirichlet, Normal +using DynamicPPL: DynamicPPL using LinearAlgebra +using LogDensityProblems: LogDensityProblems using MCMCChains: Chains using Random using ReverseDiff @@ -22,34 +25,37 @@ begin @testset "q initialization" begin m = gdemo_default - d = length(Turing.DynamicPPL.VarInfo(m)[:]) - for q in [q_meanfield_gaussian(m), q_fullrank_gaussian(m)] + l = LogDensityFunction(m, DynamicPPL.getlogjoint_internal, DynamicPPL.LinkAll()) + d = LogDensityProblems.dimension(l) + + for q in [q_meanfield_gaussian(l), q_fullrank_gaussian(l)] rand(q) end μ = ones(d) - q = q_meanfield_gaussian(m; location=μ) - @assert mean(q.dist) ≈ μ + q = q_meanfield_gaussian(l; location=μ) + @test mean(q) ≈ μ - q = q_fullrank_gaussian(m; location=μ) - @assert mean(q.dist) ≈ μ + q = q_fullrank_gaussian(l; location=μ) + @test mean(q) ≈ μ L = Diagonal(fill(0.1, d)) - q = q_meanfield_gaussian(m; scale=L) - @assert cov(q.dist) ≈ L * L + q = q_meanfield_gaussian(l; scale=L) + @test cov(q) ≈ L * L L = LowerTriangular(tril(0.01 * ones(d, d) + I)) - q = q_fullrank_gaussian(m; scale=L) - @assert cov(q.dist) ≈ L * L' + q = q_fullrank_gaussian(l; scale=L) + @test cov(q) ≈ L * L' end @testset "default interface" begin - for q0 in [q_meanfield_gaussian(gdemo_default), 
q_fullrank_gaussian(gdemo_default)] - q, _, _ = vi(gdemo_default, q0, 100; show_progress=Turing.PROGRESS[], adtype) - c1 = rand(q, 10) - end - @test_throws "unconstrained" begin - q, _, _ = vi(gdemo_default, Normal(), 1; adtype) + for q0 in [q_meanfield_gaussian, q_fullrank_gaussian] + result = vi(gdemo_default, q0, 100; show_progress=Turing.PROGRESS[], adtype) + @test rand(result) isa DynamicPPL.VarNamedTuple + @test rand(result, 2) isa Vector{<:DynamicPPL.VarNamedTuple} + @test size(rand(result, 2)) == (2,) + @test rand(result, 5, 2) isa Matrix{<:DynamicPPL.VarNamedTuple} + @test size(rand(result, 5, 2)) == (5, 2) end end @@ -65,15 +71,15 @@ begin ("FisherMinBatchMatch", FisherMinBatchMatch()), ] T = 1000 - q, _, _ = vi( + result = vi( gdemo_default, - q_fullrank_gaussian(gdemo_default), + q_fullrank_gaussian, T; algorithm, show_progress=Turing.PROGRESS[], ) - N = 1000 - c2 = rand(q, N) + c2 = rand(result, 10) + @test c2 isa Vector{<:DynamicPPL.VarNamedTuple} end @testset "inference $name" for (name, algorithm) in [ @@ -90,27 +96,27 @@ begin ("KLMinWassFwdBwd", KLMinWassFwdBwd(; stepsize=1e-2, n_samples=10)), ("FisherMinBatchMatch", FisherMinBatchMatch()), ] - rng = StableRNG(0x517e1d9bf89bf94f) + rng = StableRNG(468) T = 1000 - q, _, _ = vi( + result = vi( rng, gdemo_default, - q_fullrank_gaussian(gdemo_default), + q_fullrank_gaussian, T; algorithm, show_progress=Turing.PROGRESS[], ) N = 1000 - samples = transpose(rand(rng, q, N)) - chn = Chains(reshape(samples, size(samples)..., 1), ["s", "m"]) + samples = rand(rng, result, N) + chn = AbstractMCMC.from_samples(MCMCChains.Chains, hcat(samples)) check_gdemo(chn; atol=0.5) end - # regression test for: - # https://github.com/TuringLang/Turing.jl/issues/2065 + # regression test for https://github.com/TuringLang/Turing.jl/issues/2065 + # and https://github.com/TuringLang/Turing.jl/issues/2160 @testset "simplex bijector" begin rng = StableRNG(0x517e1d9bf89bf94f) @@ -118,20 +124,10 @@ begin x ~ Dirichlet([1.0, 1.0]) 
return x end - m = dirichlet() - b = Bijectors.bijector(m) - x0 = m() - z0 = b(x0) - @test size(z0) == (1,) - x0_inv = Bijectors.inverse(b)(z0) - @test size(x0_inv) == size(x0) - @test all(x0 .≈ x0_inv) - - # And regression for https://github.com/TuringLang/Turing.jl/issues/2160. - q, _, _ = vi(rng, m, q_meanfield_gaussian(m), 1000) - x = rand(rng, q, 1000) - @test mean(eachcol(x)) ≈ [0.5, 0.5] atol = 0.1 + result = vi(rng, m, q_meanfield_gaussian, 1000) + samples = rand(rng, result, 1000) + @test mean(s -> s[@varname(x)], samples) ≈ [0.5, 0.5] atol = 0.1 end # Ref: https://github.com/TuringLang/Turing.jl/issues/2205 @@ -144,16 +140,15 @@ begin end model = demo_issue2205() | (y=1.0,) - q, _, _ = vi(rng, model, q_meanfield_gaussian(model), 1000) + result = vi(rng, model, q_meanfield_gaussian, 1000) # True mean. mean_true = 1 / 2 var_true = 1 / 2 # Check the mean and variance of the posterior. - samples = rand(rng, q, 1000) - mean_est = mean(samples) - var_est = var(samples) - @test mean_est ≈ mean_true atol = 0.2 - @test var_est ≈ var_true atol = 0.2 + samples = rand(rng, result, 1000) + xs = [s[@varname(x)] for s in samples] + @test mean(xs) ≈ mean_true atol = 0.2 + @test var(xs) ≈ var_true atol = 0.2 end end From a095ccc0db1c5b0a410320d0802f090f630f2602 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 21 Apr 2026 15:40:15 +0100 Subject: [PATCH 05/13] Add changelog and docstrings for Gibbs --- HISTORY.md | 38 ++++++++++++++++++++++++++++++++- src/mcmc/gibbs.jl | 54 +++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 82 insertions(+), 10 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 7b25ca8662..42bb9bd993 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,6 +1,42 @@ # 0.44.0 -[...] +## Breaking changes + +**Variational inference interface** + +**Gibbs sampler interface** + +This section is only relevant if you are writing a sampler that is intended to be *directly* used as a component sampler in Turing's Gibbs sampler. 
+(If Gibbs calls your sampler via Turing's `externalsampler` interface, this section does not apply to you.) + +Turing's Gibbs sampler has been reworked in this release to fix a number of correctness and performance issues. +The main change is that the Gibbs state carries a `VarNamedTuple` of raw values, instead of a `VarInfo` of vectorised (transformed) parameters. +This fixes correctness issues that arise with value-dependent transformations, and also leads to much reduced overhead when sampling with Gibbs: see https://github.com/TuringLang/Turing.jl/pull/2803 for some representative benchmarks. + +In Turing v0.43, you would have to define two methods + + - `Turing.Inference.get_varinfo(::MyState)` -> returns a VarInfo of values from your sampler's state + - `Turing.Inference.setparams_varinfo!!(::DynamicPPL.Model, ::MySampler, ::MyState, params::AbstractVarInfo)` -> uses a VarInfo of values to update your sampler's state + +The corresponding methods in Turing v0.44 are + + - `Turing.Inference.gibbs_get_raw_values(::MyState)` -> returns a VarNamedTuple of *raw* values from your sampler's state. Note that these values should not be transformed or wrapped in any way: if `x` is `3` in the model, then we should have that `gibbs_get_raw_values(state)[@varname(x)] == 3`. + + - `Turing.Inference.gibbs_update_state!!(::MySampler, ::MyState, ::DynamicPPL.Model, global_vals::VarNamedTuple)` -> uses the values in `global_vals` to update your sampler's state and return a new state. + Note that the values in `global_vals` are raw values. + Also, note that the model argument passed in here will be 'conditioned' on the *new* values inside `global_vals`. + That means that if any part of your state relies on caching the model being evaluated, you have to update those parts of your state to reflect the new model as well.
+ +As can be seen, the interface is almost entirely the same except that we use `VarNamedTuple`s of raw values instead of `VarInfo`s of potentially transformed parameters. + +Please see the docstrings in the Turing.jl API page for more information. + +## Other changes + +Previously many of Turing's samplers needed to carry around a VarInfo so that they could communicate with Gibbs. +This release frees them up from having to do so, and in particular allows for usage of `DynamicPPL.OnlyAccsVarInfo`, which is much cheaper as it avoids unnecessary computations. + +As a result, samplers such as MH and ESS are faster in this release. # 0.43.7 diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 5f9992eeb2..19b5d6d0ad 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -22,6 +22,25 @@ isgibbscomponent(::SMC) = false Turing.Inference.gibbs_get_raw_values(state) Return a `VarNamedTuple` containing the raw values of all variables in the sampler state. + +Turing's Gibbs sampler maintains, at all points during the sampling process, a single global +`VarNamedTuple` that contains the **raw** values for all variables in the model. During the +sampling process, it calls each component sampler in turn and updates the global +`VarNamedTuple` with the new raw values returned by each sampler. + +This function is used to pass that information *from* a component sampler *to* the Gibbs +sampler. Note that this means that the `VarNamedTuple` returned by this function should +**only** contain raw values for the variables that the component sampler is responsible for +sampling, and should not contain any values for other variables. +""" +function gibbs_get_raw_values end + +""" + Turing.Inference.gibbs_get_raw_values(state::AbstractVarInfo) + +If your sampler state is an `AbstractVarInfo`, there is a default method available for this, +which reads the values stored in its `RawValueAccumulator`. 
(This means that the `VarInfo` +used for evaluation in the component sampler *must* contain a `RawValueAccumulator`.) """ function gibbs_get_raw_values(state::AbstractVarInfo) return DynamicPPL.get_raw_values(state) @@ -33,9 +52,25 @@ end ) Update the state of a Gibbs component sampler to be consistent with the new values in -`global_vals`. The exact meaning of this depends on what the sampler state contains. +`global_vals`. Each sampler should implement a method for its respective state type. + +Note that the `model` argument passed in here will be 'conditioned' on the *new* values +inside `global_vals`. Thus, evaluating it will reflect the log-probability associated with +the new values. + +Exactly what this function should do will depend on what the sampler state contains, but +for example, it will often mean: + +- Updating any raw or vectorised values stored in the sampler state to be consistent with + `global_vals`. +- Reevaluating the (new) model to update any cached log-probabilities. +- Updating any log-density callables (such as a `DynamicPPL.LogDensityFunction`) stored in + the sampler state, to be consistent with the new model. -Each sampler should implement a method for its respective state type. +For examples of this, please see the implementations of this function for the samplers in +Turing.jl. In particular, the `HMC` and `ExternalSampler` implementations work with +`LogDensityFunction` and demonstrate how information such as that can be updated based on +the new model. """ function gibbs_update_state!! end @@ -55,7 +90,7 @@ If extra information is needed (e.g. log-probabilities), `extra_accs` can be use other accumulators to be used in the same model evaluation, to avoid having to recompute them later. -Returns `(new_ldf, new_params, accs)` where `accs` is the accumulator VarInfo after +Returns `(new_ldf, new_params, accs)` where `accs` is the set of accumulators after evaluation, from which extra accumulators (e.g.
`LogLikelihoodAccumulator`) can be read. """ function gibbs_recompute_ldf_and_params( @@ -159,8 +194,8 @@ end get_global_vnt(context::GibbsContext) = context.global_vnt[] -function set_global_vnt!(context::GibbsContext, new_global_varinfo) - context.global_vnt[] = new_global_varinfo +function set_global_vnt!(context::GibbsContext, new_global_vnt) + context.global_vnt[] = new_global_vnt return nothing end @@ -267,10 +302,11 @@ function DynamicPPL.tilde_assume!!( else # If the varname has not been conditioned on, nor is it a target variable, its # presumably a new variable that should be sampled from its prior. We need to add - # this new variable to the global `varinfo` of the context, but not to the local one + # this new variable to the global `vnt` of the context, but not to the local one # being used by the current sampler. # - # TODO(penelopeysm): How is the RNG controlled here? + # TODO(penelopeysm): This branch is very hard to hit, but it will crash if it is + # hit: see https://github.com/TuringLang/Turing.jl/issues/2810 value = rand(right) vnt = get_global_vnt(context) vnt = DynamicPPL.templated_setindex!!(vnt, value, vn, template) @@ -301,7 +337,7 @@ Return a new, conditioned model for a component of a Gibbs sampler. - A new model with the variables _not_ in `target_variables` conditioned. - The `GibbsContext` object that will be used to condition the variables. This is necessary -because evaluation can mutate its `global_varinfo` field, which we need to access later. +because evaluation can mutate its `global_vnt` field, which we need to access later. """ function make_conditional( model::DynamicPPL.Model, @@ -477,7 +513,7 @@ end """ Take the first step of MCMC for the first component sampler, and call the same function -recursively on the remaining samplers, until no samplers remain. Return the global VarInfo +recursively on the remaining samplers, until no samplers remain. Return the global VNT and a tuple of initial states for all component samplers. 
The `step_function` argument should always be either AbstractMCMC.step or From f52cbcfd70830debd0980a1cb14dc09475504de1 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 21 Apr 2026 16:00:05 +0100 Subject: [PATCH 06/13] Add docs and changelog stuff --- HISTORY.md | 40 +++++++++++++++++++++++++++++++++- docs/Project.toml | 2 -- src/mcmc/particle_mcmc.jl | 2 +- src/variational/Variational.jl | 6 ++--- 4 files changed, 43 insertions(+), 7 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 42bb9bd993..518676389e 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,6 +4,28 @@ **Variational inference interface** +The VI interface in Turing has been modified to make it more interoperable with the rest of Turing. + + - The arguments to `vi(...)` are slightly different: instead of specifying a `q_init` argument (the initial variational approximation), you now directly pass a function that constructs this for you. For example, instead of + + ```julia + q_init = q_meanfield_gaussian(model) + vi(model, q_init, n_iters) + ``` + + you would now do + + ```julia + vi(model, q_meanfield_gaussian, n_iters) + ``` + + - The return value of `vi` is now a `VIResult` struct (please see the documentation for information), which bundles the previous return values together in a more cohesive way. + Most importantly, you can now call `rand([rng,] result::VIResult)` to obtain new samples from the variational approximation. + This returns a `VarNamedTuple` of raw values, which can be used directly in all other Turing interfaces without any further wrangling. + (In contrast, the previous return value of `rand(q)` would yield a vector of transformed parameters.) + +Internally, the VI interface has been reworked to directly use `DynamicPPL.LogDensityFunction` instead of relying on a transformed distribution from Bijectors.jl. 
+ **Gibbs sampler interface** This section is only relevant if you are writing a sampler that is intended to be *directly* used as a component sampler in Turing's Gibbs sampler. @@ -16,6 +38,7 @@ This fixes correctness issues that arise with value-dependent transformations, a In Turing v0.43, you would have to define two methods - `Turing.Inference.get_varinfo(::MyState)` -> returns a VarInfo of values from your sampler's state + - `Turing.Inference.setparams_varinfo!!(::DynamicPPL.Model, ::MySampler, ::MyState, params::AbstractVarInfo)` -> uses a VarInfo of values to update your sampler's state The corresponding methods in Turing v0.44 are @@ -33,10 +56,25 @@ Please see the docstrings in the Turing.jl API page for more information. ## Other changes +**Performance** + Previously many of Turing's samplers needed to carry around a VarInfo so that they could communicate with Gibbs. This release frees them up from having to do so, and in particular allows for usage of `DynamicPPL.OnlyAccsVarInfo`, which is much cheaper as it avoids unnecessary computations. -As a result, samplers such as MH and ESS are faster in this release. +As a result, samplers such as MH and ESS are faster in this release, sometimes by up to 5x. + +**Fixed transforms** + +The various inference methods in Turing (MCMC sampling, optimisation, and VI) all accept an extra `fix_transforms` keyword argument, which specifies that all transforms in the model should be determined once at the start of inference and then fixed to those values for the rest of inference. +(In contrast, the default behaviour is to rederive transforms each time the model is run.) + +The reason why Turing rederives transforms is to ensure correctness in cases where the transform *depends on the value of another random variable*. +For example, if `a` is a parameter, then `b ~ Uniform(-a, a)` has a transform that depends on the value of `a`. 
+Since `a` is not static, this in turn means that the transform associated with `b` is not static. + +If you know that this consideration is not relevant for you, setting `fix_transforms=true` can sometimes lead to performance benefits if the transforms are expensive to compute. +However, in many cases the benefits are negligible and you should always benchmark this on a case-by-case basis. +Please see https://turinglang.org/DynamicPPL.jl/stable/fixed_transforms/ for further details about this. # 0.43.7 diff --git a/docs/Project.toml b/docs/Project.toml index 8cf000eca9..14157ee724 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,8 +3,6 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" -Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 414f422296..01333ab9f8 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -270,7 +270,7 @@ function PG(nparticles::Int, threshold::Real) end """ -CSMC(...) + CSMC(...) Equivalent to [`PG`](@ref). """ diff --git a/src/variational/Variational.jl b/src/variational/Variational.jl index ac00667e6b..6cadd8a6f4 100644 --- a/src/variational/Variational.jl +++ b/src/variational/Variational.jl @@ -313,8 +313,8 @@ end Base.rand(rng::Random.AbstractRNG, res::VIResult, sz...) Draw a sample, or array of samples, from the variational distribution `q` in `res`. Each -sample is a [`DynamicPPL.VarNamedTuple`](@ref) containing parameter values (in original, -untransformed space).
+sample is a [`DynamicPPL.VarNamedTuple`](@extref DynamicPPL.VarNamedTuples.VarNamedTuple) +containing raw parameter values. """ function Base.rand(rng::Random.AbstractRNG, res::VIResult, sz::Integer...) # TODO(penelopeysm): Should we expose a way to get colon_eq results as well -- maybe a @@ -368,7 +368,7 @@ For other variational families, refer to the documentation of `AdvancedVI` to de - `adtype`: Automatic differentiation backend to be applied to the log-density. The default value for `algorithm` also uses this backend for differentiating the variational objective. - `algorithm`: Variational inference algorithm. The default is `KLMinRepGradProxDescent`, please refer to [AdvancedVI docs](https://turinglang.org/AdvancedVI.jl/stable/) for all the options. - `show_progress`: Whether to show the progress bar. -- `unconstrained`: Whether to transform the posterior to be unconstrained for running the variational inference algorithm. If `true`, then the output `q` will be wrapped into a `Bijectors.TransformedDistribution` with the transformation matching the support of the posterior. The default value depends on the chosen `algorithm`. +- `unconstrained`: Whether to transform the posterior to be unconstrained for running the variational inference algorithm. The default value depends on the chosen `algorithm` (most algorithms require unconstrained space). - `fix_transforms`: Whether to precompute the transforms needed to convert model parameters to (possibly unconstrained) vectors. This can lead to performance improvements, but if any transforms depend on model parameters, setting `fix_transforms=true` can silently yield incorrect results. - Any additional keyword arguments are passed on both to the function `initial_approx`, and also to `AdvancedVI.optimize`. 
From dca01573d62f387052aa8e5b942741ced106aa9a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 21 Apr 2026 16:09:54 +0100 Subject: [PATCH 07/13] Pass fix_transform flag to LDF --- HISTORY.md | 5 ++++- ext/TuringDynamicHMCExt.jl | 7 ++++++- src/mcmc/ess.jl | 11 +++++++---- src/mcmc/external_sampler.jl | 7 ++++++- src/mcmc/hmc.jl | 7 ++++++- src/mcmc/sghmc.jl | 14 ++++++++++++-- src/optimisation/Optimisation.jl | 5 ++++- 7 files changed, 45 insertions(+), 11 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 518676389e..da53bc6fc8 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -65,9 +65,12 @@ As a result, samplers such as MH and ESS are faster in this release, sometimes b **Fixed transforms** -The various inference methods in Turing (MCMC sampling, optimisation, and VI) all accept an extra `fix_transforms` keyword argument, which specifies that all transforms in the model should be determined once at the start of inference and then fixed to those values for the rest of inference. +The various inference methods in Turing (MCMC sampling, optimisation, and VI) now accept an extra `fix_transforms` keyword argument, which specifies that all transforms in the model should be determined once at the start of inference and then fixed to those values for the rest of inference. (In contrast, the default behaviour is to rederive transforms each time the model is run.) +Note that not all MCMC samplers currently support fixed transforms. +In particular, HMC, NUTS, ESS and external samplers currently do, but all other samplers do not (including MH and Gibbs). + The reason why Turing rederives transforms is to ensure correctness in cases where the transform *depends on the value of another random variable*. For example, if `a` is a parameter, then `b ~ Uniform(-a, a)` has a transform that depends on the value of `a`. Since `a` is not static, this in turn means that the transform associated with `b` is not static. 
diff --git a/ext/TuringDynamicHMCExt.jl b/ext/TuringDynamicHMCExt.jl index b4cfadaf61..7b69ba6ab5 100644 --- a/ext/TuringDynamicHMCExt.jl +++ b/ext/TuringDynamicHMCExt.jl @@ -48,12 +48,17 @@ function AbstractMCMC.step( model::DynamicPPL.Model, spl::DynamicNUTS; initial_params, + fix_transforms::Bool=false, kwargs..., ) # Construct LogDensityFunction tfm_strategy = DynamicPPL.LinkAll() ldf = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, tfm_strategy; adtype=spl.adtype + model, + DynamicPPL.getlogjoint_internal, + tfm_strategy; + adtype=spl.adtype, + fix_transforms=fix_transforms, ) x = Turing.Inference.find_initial_params_ldf(rng, ldf, initial_params) diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index eae57427d6..607b4f96b1 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -44,6 +44,7 @@ function AbstractMCMC.step( model::DynamicPPL.Model, ::ESS; discard_sample=false, + fix_transforms::Bool=false, initial_params, kwargs..., ) @@ -62,10 +63,12 @@ function AbstractMCMC.step( end # Set up a LogDensityFunction which evaluates the model's log-likelihood. - # TODO(penelopeysm): We could conceivably use fixed transforms here because every prior - # distribution is Gaussian, which by definition must have a fixed bijector. Need to - # benchmark to see if it's worth it. - loglike_ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, oavi) + # TODO(penelopeysm): We could conceivably always use fixed transforms here because every + # prior distribution is Gaussian, which by definition must have a fixed bijector. Need + # to benchmark to see if it's worth it. 
+ loglike_ldf = DynamicPPL.LogDensityFunction( + model, DynamicPPL.getloglikelihood, oavi; fix_transforms=fix_transforms + ) # And the corresponding vector of params, and the likelihood at this point vecvals = DynamicPPL.get_vector_values(oavi) vector_params = DynamicPPL.internal_values_as_vector(vecvals) diff --git a/src/mcmc/external_sampler.jl b/src/mcmc/external_sampler.jl index 7509e26ddd..e6ba6a2bca 100644 --- a/src/mcmc/external_sampler.jl +++ b/src/mcmc/external_sampler.jl @@ -139,6 +139,7 @@ function AbstractMCMC.step( initial_state=nothing, initial_params, # passed through from sample discard_sample=false, + fix_transforms::Bool=false, kwargs..., ) where {unconstrained} sampler = sampler_wrapper.sampler @@ -149,7 +150,11 @@ function AbstractMCMC.step( _, oavi = DynamicPPL.init!!(model, oavi, DynamicPPL.InitFromPrior(), tfm_strategy) vecvals = DynamicPPL.get_vector_values(oavi) f = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, vecvals; adtype=sampler_wrapper.adtype + model, + DynamicPPL.getlogjoint_internal, + vecvals; + adtype=sampler_wrapper.adtype, + fix_transforms=fix_transforms, ) x = find_initial_params_ldf(rng, f, initial_params) diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index e086400f03..479a82ae9a 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -165,6 +165,7 @@ function AbstractMCMC.step( nadapts=0, discard_sample=false, verbose::Bool=true, + fix_transforms::Bool=false, kwargs..., ) # Create a LogDensityFunction @@ -174,7 +175,11 @@ function AbstractMCMC.step( ) vecvals = DynamicPPL.get_vector_values(oavi) ldf = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, vecvals; adtype=spl.adtype + model, + DynamicPPL.getlogjoint_internal, + vecvals; + adtype=spl.adtype, + fix_transforms=fix_transforms, ) # And a Hamiltonian metricT = getmetricT(spl) diff --git a/src/mcmc/sghmc.jl b/src/mcmc/sghmc.jl index 35ed102daf..8882f9d9ef 100644 --- a/src/mcmc/sghmc.jl +++ b/src/mcmc/sghmc.jl @@ -57,11 
+57,16 @@ function AbstractMCMC.step( spl::SGHMC; initial_params::DynamicPPL.AbstractInitStrategy, discard_sample=false, + fix_transforms::Bool=false, kwargs..., ) tfm_strategy = DynamicPPL.LinkAll() ldf = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, tfm_strategy; adtype=spl.adtype + model, + DynamicPPL.getlogjoint_internal, + tfm_strategy; + adtype=spl.adtype, + fix_transforms=fix_transforms, ) x = Turing.Inference.find_initial_params_ldf(rng, ldf, initial_params) @@ -191,11 +196,16 @@ function AbstractMCMC.step( spl::SGLD; initial_params::DynamicPPL.AbstractInitStrategy, discard_sample=false, + fix_transforms::Bool=false, kwargs..., ) tfm_strategy = DynamicPPL.LinkAll() ldf = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, tfm_strategy; adtype=spl.adtype + model, + DynamicPPL.getlogjoint_internal, + tfm_strategy; + adtype=spl.adtype, + fix_transforms=fix_transforms, ) x = Turing.Inference.find_initial_params_ldf(rng, ldf, initial_params) diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl index 8994888d41..86a1cae3d3 100644 --- a/src/optimisation/Optimisation.jl +++ b/src/optimisation/Optimisation.jl @@ -329,6 +329,7 @@ function estimate_mode( adtype=ADTypes.AutoForwardDiff(), check_model::Bool=true, check_constraints_at_runtime::Bool=true, + fix_transforms::Bool=false, solve_kwargs..., ) check_model && Turing._check_model(model) @@ -347,7 +348,9 @@ function estimate_mode( end # Note that we don't need adtype to construct the LDF, because it's specified inside the # OptimizationProblem. - ldf = LogDensityFunction(model, getlogdensity, tfm_strategy, accs) + ldf = LogDensityFunction( + model, getlogdensity, tfm_strategy, accs; fix_transforms=fix_transforms + ) # Generate bounds and initial parameters in the unlinked or linked space as requested. 
lb_vec, ub_vec, inits_vec = make_optim_bounds_and_init( From 021c9985f75aaf1b8912c3eb83b45c3be89a92ff Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 21 Apr 2026 16:36:21 +0100 Subject: [PATCH 08/13] Test that fix_transforms is respected --- HISTORY.md | 6 +++-- src/mcmc/ess.jl | 8 +----- test/mcmc/Inference.jl | 45 ++++++++++++++++++++++++++++++- test/optimisation/Optimisation.jl | 39 +++++++++++++++++++++++++++ test/variational/vi.jl | 34 +++++++++++++++++++++++ 5 files changed, 122 insertions(+), 10 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index da53bc6fc8..0af535557c 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -65,11 +65,13 @@ As a result, samplers such as MH and ESS are faster in this release, sometimes b **Fixed transforms** -The various inference methods in Turing (MCMC sampling, optimisation, and VI) now accept an extra `fix_transforms` keyword argument, which specifies that all transforms in the model should be determined once at the start of inference and then fixed to those values for the rest of inference. +The MCMC sampling (`sample`), optimisation (`estimate_mode` / `maximum_likelihood` / `maximum_a_posteriori`), and VI (`vi`) entry points now accept an extra `fix_transforms` keyword argument, which specifies that all transforms in the model should be determined once at the start of inference and then fixed to those values for the rest of inference. (In contrast, the default behaviour is to rederive transforms each time the model is run.) Note that not all MCMC samplers currently support fixed transforms. -In particular, HMC, NUTS, ESS and external samplers currently do, but all other samplers do not (including MH and Gibbs). +In particular, HMC, NUTS, and external samplers currently do, but all other samplers do not (including MH's `LinkedRW`, and Gibbs). +For some samplers such as ESS and particle MCMC, fixed transforms do not affect the sampling +process at all (in such cases the keyword argument is accepted but ignored).
The reason why Turing rederives transforms is to ensure correctness in cases where the transform *depends on the value of another random variable*. For example, if `a` is a parameter, then `b ~ Uniform(-a, a)` has a transform that depends on the value of `a`. diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index 607b4f96b1..665bee441f 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -44,7 +44,6 @@ function AbstractMCMC.step( model::DynamicPPL.Model, ::ESS; discard_sample=false, - fix_transforms::Bool=false, initial_params, kwargs..., ) @@ -63,12 +62,7 @@ function AbstractMCMC.step( end # Set up a LogDensityFunction which evaluates the model's log-likelihood. - # TODO(penelopeysm): We could conceivably always use fixed transforms here because every - # prior distribution is Gaussian, which by definition must have a fixed bijector. Need - # to benchmark to see if it's worth it. - loglike_ldf = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getloglikelihood, oavi; fix_transforms=fix_transforms - ) + loglike_ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, oavi) # And the corresponding vector of params, and the likelihood at this point vecvals = DynamicPPL.get_vector_values(oavi) vector_params = DynamicPPL.internal_values_as_vector(vecvals) diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 9d7cabd083..464e055914 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -2,9 +2,11 @@ module InferenceTests using ..Models: gdemo_d, gdemo_default using ..NumericalTests: check_gdemo, check_numerical -using Distributions: Bernoulli, Beta, InverseGamma, Normal +using Bijectors: Bijectors +using Distributions: Bernoulli, Beta, InverseGamma, Normal, ContinuousUnivariateDistribution using Distributions: sample using AbstractMCMC: AbstractMCMC +import AdvancedMH import DynamicPPL using DynamicPPL: filldist import ForwardDiff @@ -13,6 +15,7 @@ import MCMCChains import Random using Random: Xoshiro import ReverseDiff +import 
Statistics using StableRNGs: StableRNG using StatsFuns: logsumexp using Test: @test, @test_throws, @testset @@ -631,6 +634,46 @@ using Turing # https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/802 @test sample(e(), Prior(), 100) isa MCMCChains.Chains end + + @testset "fix_transforms" begin + # Create a new distribution that increments a counter each time we derive its + # transformation. + struct MyNormal <: ContinuousUnivariateDistribution end + Distributions.logpdf(::MyNormal, x) = logpdf(Normal(), x) + Distributions.rand(rng::Random.AbstractRNG, ::MyNormal) = rand(rng, Normal()) + counter = Ref(0) + struct VectAndIncrement end + (::VectAndIncrement)(x) = [x] + Bijectors.with_logabsdet_jacobian(::VectAndIncrement, x) = [x], 0.0 + Bijectors.inverse(::VectAndIncrement) = OnlyAndIncrement() + struct OnlyAndIncrement end + (::OnlyAndIncrement)(x) = x[] + Bijectors.with_logabsdet_jacobian(::OnlyAndIncrement, x) = x[], 0.0 + Bijectors.inverse(::OnlyAndIncrement) = VectAndIncrement() + function Bijectors.VectorBijectors.to_linked_vec(::MyNormal) + counter[] += 1 + return VectAndIncrement() + end + function Bijectors.VectorBijectors.from_linked_vec(::MyNormal) + counter[] += 1 + return OnlyAndIncrement() + end + + @model f() = x ~ MyNormal() + model = f() + + @testset "$spl" for spl in ( + NUTS(), HMC(0.1, 5), externalsampler(AdvancedMH.RWMH(MvNormal([0.0], [1.0;;]))) + ) + counter[] = 0 + sample(model, spl, 100) + @test counter[] > 100 + + counter[] = 0 + sample(model, spl, 100; fix_transforms=true) + @test counter[] < 100 + end + end end end diff --git a/test/optimisation/Optimisation.jl b/test/optimisation/Optimisation.jl index f878fe39a7..15e7b748cb 100644 --- a/test/optimisation/Optimisation.jl +++ b/test/optimisation/Optimisation.jl @@ -667,6 +667,45 @@ end ) isa ModeResult end + @testset "fix_transforms" begin + struct MyNormal <: ContinuousUnivariateDistribution end + Distributions.logpdf(::MyNormal, x) = logpdf(Normal(), x) + 
Distributions.rand(rng::Random.AbstractRNG, ::MyNormal) = rand(rng, Normal()) + counter = Ref(0) + struct VectAndIncrement end + (::VectAndIncrement)(x) = [x] + Bijectors.with_logabsdet_jacobian(::VectAndIncrement, x) = [x], 0.0 + Bijectors.inverse(::VectAndIncrement) = OnlyAndIncrement() + struct OnlyAndIncrement end + (::OnlyAndIncrement)(x) = x[] + Bijectors.with_logabsdet_jacobian(::OnlyAndIncrement, x) = x[], 0.0 + Bijectors.inverse(::OnlyAndIncrement) = VectAndIncrement() + function Bijectors.VectorBijectors.to_linked_vec(::MyNormal) + counter[] += 1 + return VectAndIncrement() + end + function Bijectors.VectorBijectors.from_linked_vec(::MyNormal) + counter[] += 1 + return OnlyAndIncrement() + end + + @model f() = x ~ MyNormal() + model = f() + + # It's very hard to determine how many times the transforms will be called during + # optimisation, because it mostly depends on how many iterations the solver has to + # do. We make sure to start it at a terrible location so that we ensure that it + # _does_ get called at least a few times. 
+ inits = InitFromParams(VarNamedTuple(; x=-100.0)) + counter[] = 0 + maximum_a_posteriori(model; initial_params=inits, fix_transforms=false) + counter_without_fixed_tfms = counter[] + + counter[] = 0 + maximum_a_posteriori(model; initial_params=inits, fix_transforms=true) + @test counter[] < counter_without_fixed_tfms + end + @testset "using ModeResult to initialise MCMC" begin @model function f(y) μ ~ Normal(0, 1) diff --git a/test/variational/vi.jl b/test/variational/vi.jl index f856a7e94d..029a1a1915 100644 --- a/test/variational/vi.jl +++ b/test/variational/vi.jl @@ -150,6 +150,40 @@ begin @test mean(xs) ≈ mean_true atol = 0.2 @test var(xs) ≈ var_true atol = 0.2 end + + @testset "fix_transforms" begin + struct MyNormal <: ContinuousUnivariateDistribution end + Distributions.logpdf(::MyNormal, x) = logpdf(Normal(), x) + Distributions.rand(rng::Random.AbstractRNG, ::MyNormal) = rand(rng, Normal()) + counter = Ref(0) + struct VectAndIncrement end + (::VectAndIncrement)(x) = [x] + Bijectors.with_logabsdet_jacobian(::VectAndIncrement, x) = [x], 0.0 + Bijectors.inverse(::VectAndIncrement) = OnlyAndIncrement() + struct OnlyAndIncrement end + (::OnlyAndIncrement)(x) = x[] + Bijectors.with_logabsdet_jacobian(::OnlyAndIncrement, x) = x[], 0.0 + Bijectors.inverse(::OnlyAndIncrement) = VectAndIncrement() + function Bijectors.VectorBijectors.to_linked_vec(::MyNormal) + counter[] += 1 + return VectAndIncrement() + end + function Bijectors.VectorBijectors.from_linked_vec(::MyNormal) + counter[] += 1 + return OnlyAndIncrement() + end + + @model f() = x ~ MyNormal() + model = f() + + counter[] = 0 + vi(model, q_meanfield_gaussian, 100) + @test counter[] > 100 + + counter[] = 0 + vi(model, q_meanfield_gaussian, 100; fix_transforms=true) + @test counter[] < 100 + end end end From f6d073dda2ec7cae74339dae364b78317feaeaef Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 21 Apr 2026 16:38:58 +0100 Subject: [PATCH 09/13] Add one more VI test --- test/variational/vi.jl | 1 + 1 
file changed, 1 insertion(+) diff --git a/test/variational/vi.jl b/test/variational/vi.jl index 029a1a1915..d0407eb305 100644 --- a/test/variational/vi.jl +++ b/test/variational/vi.jl @@ -51,6 +51,7 @@ begin @testset "default interface" begin for q0 in [q_meanfield_gaussian, q_fullrank_gaussian] result = vi(gdemo_default, q0, 100; show_progress=Turing.PROGRESS[], adtype) + @test result isa Turing.Variational.VIResult @test rand(result) isa DynamicPPL.VarNamedTuple @test rand(result, 2) isa Vector{<:DynamicPPL.VarNamedTuple} @test size(rand(result, 2)) == (2,) From 4d650005ac79072adeac12b93964756f1c2bac0f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 21 Apr 2026 16:44:46 +0100 Subject: [PATCH 10/13] Fix some grammar, copy to be safe --- HISTORY.md | 2 +- src/mcmc/gibbs.jl | 4 +++- src/mcmc/hmc.jl | 8 ++++---- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 0af535557c..76242a6ee2 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -29,7 +29,7 @@ Internally, the VI interface has been reworked to directly use `DynamicPPL.LogDe **Gibbs sampler interface** This section is only relevant if you are writing a sampler that is intended to be *directly* used as a component sampler in Turing's Gibbs sampler. -(If Gibbs calls your sampler via Turing's `externalsampler` interface, this section does not apply toyou.) +(If Gibbs calls your sampler via Turing's `externalsampler` interface, this section does not apply to you.) Turing's Gibbs sampler has been reworked in this release to fix a number of correctness and performance issues. The main change is that the Gibbs state carries a `VarNamedTuple` of raw values, instead of a `VarInfo` of vectorised (transformed) parameters. diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 19b5d6d0ad..73ac91c438 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -58,7 +58,7 @@ Note that the `model` argument passed in here will be 'conditioned' on the *new* inside `global_vals`. 
Thus, evaluating it will reflect the log-probability associated with the new values. -Exactly what this function should do will depends on what the sampler state contains, but +Exactly what this function should do will depend on what the sampler state contains, but for example, it will often mean: - Updating any raw or vectorised values stored in the sampler state to be consistent with @@ -100,6 +100,8 @@ function gibbs_recompute_ldf_and_params( global_vals::DynamicPPL.VarNamedTuple, extra_accs::NTuple{N,<:DynamicPPL.AbstractAccumulator}=(), ) where {N} + # TODO(penelopeysm): If old_ldf has fixed transforms, this will overwrite it. This + # probably needs to be fixed by improving the constructor in DynamicPPL. new_ldf = DynamicPPL.LogDensityFunction( model, DynamicPPL.get_logdensity_callable(old_ldf), diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index 479a82ae9a..be8b5fe81c 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -506,14 +506,14 @@ function gibbs_update_state!!( lp_func = Base.Fix1(LogDensityProblems.logdensity, new_ldf) lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, new_ldf) new_hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func) - # Apart from the Hamiltonian, we also need to update the position variables. It would be - # nice to do this without mutating, but it's probably fine for now. - state.z.θ .= new_params + # We also need to update the position variables in the PhasePoint. 
+ new_z = deepcopy(state.z) + new_z.θ .= new_params return HMCState( state.i, state.kernel, new_hamiltonian, - state.z, + new_z, state.adaptor, new_ldf, state._vector_vnt, From d5c64961ca9be22aa885fa64b4479353dfc1bdd1 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 21 Apr 2026 16:45:35 +0100 Subject: [PATCH 11/13] some typos --- src/mcmc/gibbs.jl | 6 +++--- src/variational/Variational.jl | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 73ac91c438..eb2709b39c 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -387,9 +387,9 @@ behaviour of Gibbs when there are unhandled variables is undefined: depending on of Turing, it may either crash, or it may sample once from the prior and not update values after that. See https://github.com/TuringLang/Turing.jl/issues/2810 for more information. -There is currently no way to specify a different initialisation strategy for for each -component sampler individually. When sampling with Gibbs, `initial_params` applies to the -model as a whole. +There is currently no way to specify a different initialisation strategy for each component +sampler individually. When sampling with Gibbs, `initial_params` applies to the model as a +whole. # Fields $(TYPEDFIELDS) diff --git a/src/variational/Variational.jl b/src/variational/Variational.jl index 6cadd8a6f4..8c2a42e8f9 100644 --- a/src/variational/Variational.jl +++ b/src/variational/Variational.jl @@ -303,7 +303,6 @@ function Base.show(io::IO, ::MIME"text/plain", r::VIResult) tree_char = i == length(last_info) ? 
"└" : "├" println(io, " │ $(tree_char) $k = $v") end - else end print(io, " └ (2 more fields: state, ldf)") return nothing From 47bf4eb122bb60362687a3a44e56774b71786b79 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 21 Apr 2026 16:54:46 +0100 Subject: [PATCH 12/13] add statistics as a test dep --- test/Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index 6562c81a2f..55af1f41bd 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -34,6 +34,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -74,6 +75,7 @@ Random = "1" ReverseDiff = "1.4.2" SpecialFunctions = "0.10.3, 1, 2" StableRNGs = "1" +Statistics = "1" StatsBase = "0.33, 0.34" StatsFuns = "0.9.5, 1" TimerOutputs = "0.5" From af73c7f4c56bf03c26b5167359df59cb063f73fc Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 22 Apr 2026 12:19:49 +0100 Subject: [PATCH 13/13] Remove unnecessary hacks in Gibbs interface (#2811) --- HISTORY.md | 5 ++-- Project.toml | 2 +- src/mcmc/ess.jl | 57 ++++++++++++++++++------------------ src/mcmc/external_sampler.jl | 22 ++++---------- src/mcmc/gibbs.jl | 22 +++++++++----- src/mcmc/hmc.jl | 36 ++++------------------- test/mcmc/Inference.jl | 5 +++- 7 files changed, 60 insertions(+), 89 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 76242a6ee2..1311d8cdde 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -69,9 +69,8 @@ The MCMC sampling (`sample`), optimisation (`mode_estimate` / `maximum_likelihoo (In contrast, the default behaviour is to rederive transforms each time the model is run.) 
Note that not all MCMC samplers currently support fixed transforms. -In particular, HMC, NUTS, and external samplers currently do, but all other samplers do not (including MH's `LinkedRW`, and Gibbs). -For some samplers such as ESS and particle MCMC, fixed transforms do not affect the sampling -process at all (in such cases the keyword argument is accepted but ignored). +In particular, HMC, NUTS, and external samplers currently do support it, and Gibbs will pass the flag through to its component samplers, but MH's `LinkedRW` option will currently ignore the `fix_transforms` argument. +For some samplers such as ESS and particle MCMC, fixed transforms do not affect the sampling process at all (in such cases the keyword argument is accepted but ignored). The reason why Turing rederives transforms is to ensure correctness in cases where the transform *depends on the value of another random variable*. For example, if `a` is a parameter, then `b ~ Uniform(-a, a)` has a transform that depends on the value of `a`. diff --git a/Project.toml b/Project.toml index 74a3f43547..e89a3e5c8c 100644 --- a/Project.toml +++ b/Project.toml @@ -61,7 +61,7 @@ DifferentiationInterface = "0.7" Distributions = "0.25.77" DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" -DynamicPPL = "0.41.2" +DynamicPPL = "0.41.3" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3, 1" Libtask = "0.9.14" diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index 665bee441f..54313e3639 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -26,16 +26,12 @@ struct TuringESSState{ L<:DynamicPPL.LogDensityFunction, P<:AbstractVector{<:Real}, R<:Real, - Va<:DynamicPPL.VarNamedTuple, - Vb<:DynamicPPL.VarNamedTuple, + V<:DynamicPPL.VarNamedTuple, } ldf::L params::P loglikelihood::R - priors::Va - # Minor optimisation: we cache a VNT storing vectorised values here to avoid having to - # reconstruct it each time in `gibbs_update_state!!`. 
- _vector_vnt::Vb + priors::V end # always accept in the first step @@ -47,29 +43,36 @@ function AbstractMCMC.step( initial_params, kwargs..., ) - # Add some extra accumulators so that we can compute everything we need to in one pass. - oavi = DynamicPPL.OnlyAccsVarInfo() - oavi = DynamicPPL.setacc!!(oavi, DynamicPPL.PriorDistributionAccumulator()) - oavi = DynamicPPL.setacc!!(oavi, DynamicPPL.RawValueAccumulator(true)) - oavi = DynamicPPL.setacc!!(oavi, DynamicPPL.VectorValueAccumulator()) - _, oavi = DynamicPPL.init!!(rng, model, oavi, initial_params, DynamicPPL.UnlinkAll()) + # Set up a LogDensityFunction which evaluates the model's log-likelihood. + # Note that this costs one model evaluation (fine since it's only in the first step) + loglike_ldf = DynamicPPL.LogDensityFunction( + model, DynamicPPL.getloglikelihood, DynamicPPL.UnlinkAll() + ) + + # Run the model using the specified initialisation strategy and extract all necessary + # information. + accs = DynamicPPL.OnlyAccsVarInfo( + # no transforms so no need for LogJacobianAccumulator + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + DynamicPPL.VectorParamAccumulator(loglike_ldf), + DynamicPPL.PriorDistributionAccumulator(), + DynamicPPL.RawValueAccumulator(true), # for ParamsWithStats later + ) + _, accs = DynamicPPL.init!!(rng, model, accs, initial_params, DynamicPPL.UnlinkAll()) + + priors = DynamicPPL.get_priors(accs) + vector_params = DynamicPPL.get_vector_params(accs) + loglike = DynamicPPL.getloglikelihood(accs) # Check that priors are all Gaussian - priors = DynamicPPL.get_priors(oavi) for dist in values(priors) EllipticalSliceSampling.isgaussian(typeof(dist)) || error("ESS only supports Gaussian prior distributions") end - # Set up a LogDensityFunction which evaluates the model's log-likelihood. 
- loglike_ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, oavi) - # And the corresponding vector of params, and the likelihood at this point - vecvals = DynamicPPL.get_vector_values(oavi) - vector_params = DynamicPPL.internal_values_as_vector(vecvals) - loglike = DynamicPPL.getloglikelihood(oavi) - - transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(oavi) - state = TuringESSState(loglike_ldf, vector_params, loglike, priors, vecvals) + transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(accs) + state = TuringESSState(loglike_ldf, vector_params, loglike, priors) return transition, state end @@ -99,7 +102,7 @@ function AbstractMCMC.step( transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(sample, state.ldf) new_state = TuringESSState( - state.ldf, sample, new_wrapped_state.loglikelihood, state.priors, state._vector_vnt + state.ldf, sample, new_wrapped_state.loglikelihood, state.priors ) return transition, new_state end @@ -199,12 +202,8 @@ function gibbs_update_state!!( # pass an extra LogLikelihoodAccumulator here so that we can calculate the new loglike in # one pass. 
new_ldf, new_params, accs = gibbs_recompute_ldf_and_params( - state.ldf, - model, - state._vector_vnt, - global_vals, - (DynamicPPL.LogLikelihoodAccumulator(),), + state.ldf, model, global_vals, (DynamicPPL.LogLikelihoodAccumulator(),) ) new_loglike = DynamicPPL.getloglikelihood(accs) - return TuringESSState(new_ldf, new_params, new_loglike, state.priors, state._vector_vnt) + return TuringESSState(new_ldf, new_params, new_loglike, state.priors) end diff --git a/src/mcmc/external_sampler.jl b/src/mcmc/external_sampler.jl index e6ba6a2bca..14be5afc7d 100644 --- a/src/mcmc/external_sampler.jl +++ b/src/mcmc/external_sampler.jl @@ -121,15 +121,10 @@ function externalsampler( return ExternalSampler(sampler, adtype, Val(unconstrained)) end -struct TuringState{ - S,P<:AbstractVector,L<:DynamicPPL.LogDensityFunction,V<:DynamicPPL.VarNamedTuple -} +struct TuringState{S,P<:AbstractVector,L<:DynamicPPL.LogDensityFunction} state::S params::P ldf::L - # Cached vector VNT, used to construct new LDFs in gibbs_update_state!! without - # reevaluating the model. Same role as HMCState._vector_vnt. - _vector_vnt::V end function AbstractMCMC.step( @@ -146,13 +141,10 @@ function AbstractMCMC.step( # Construct LogDensityFunction tfm_strategy = unconstrained ? 
DynamicPPL.LinkAll() : DynamicPPL.UnlinkAll() - oavi = DynamicPPL.OnlyAccsVarInfo(DynamicPPL.VectorValueAccumulator()) - _, oavi = DynamicPPL.init!!(model, oavi, DynamicPPL.InitFromPrior(), tfm_strategy) - vecvals = DynamicPPL.get_vector_values(oavi) f = DynamicPPL.LogDensityFunction( model, DynamicPPL.getlogjoint_internal, - vecvals; + tfm_strategy; adtype=sampler_wrapper.adtype, fix_transforms=fix_transforms, ) @@ -183,7 +175,7 @@ function AbstractMCMC.step( DynamicPPL.ParamsWithStats(new_parameters, f, new_stats) end - return (new_transition, TuringState(state_inner, new_parameters, f, vecvals)) + return (new_transition, TuringState(state_inner, new_parameters, f)) end function AbstractMCMC.step( @@ -209,7 +201,7 @@ function AbstractMCMC.step( new_stats = AbstractMCMC.getstats(state_inner) DynamicPPL.ParamsWithStats(new_parameters, f, new_stats) end - return (new_transition, TuringState(state_inner, new_parameters, f, state._vector_vnt)) + return (new_transition, TuringState(state_inner, new_parameters, f)) end #### @@ -229,12 +221,10 @@ function gibbs_update_state!!( model::DynamicPPL.Model, global_vals::DynamicPPL.VarNamedTuple, ) - new_ldf, new_params, _ = gibbs_recompute_ldf_and_params( - state.ldf, model, state._vector_vnt, global_vals - ) + new_ldf, new_params, _ = gibbs_recompute_ldf_and_params(state.ldf, model, global_vals) # Update the inner sampler's state with the new parameters. new_inner_state = AbstractMCMC.setparams!!( AbstractMCMC.LogDensityModel(new_ldf), state.state, new_params ) - return TuringState(new_inner_state, new_params, new_ldf, state._vector_vnt) + return TuringState(new_inner_state, new_params, new_ldf) end diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index eb2709b39c..40aa935627 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -76,15 +76,18 @@ function gibbs_update_state!! 
end """ gibbs_recompute_ldf_and_params( - old_ldf, model, vector_vnt, global_vals, extra_accs + old_ldf::LogDensityFunction, + model::Model, + global_vals::VarNamedTuple, + extra_accs=() ) Shared helper that is used in `gibbs_update_state!!` for any sampler that uses a LogDensityFunction. -Creates a new `LogDensityFunction` from the newly conditioned `model` (using a cached -`vector_vnt` to avoid an extra model evaluation), then reevaluates the model to obtain the -correct vectorised parameters corresponding to the raw values in `global_vals`. +Creates a new `LogDensityFunction` from the newly conditioned `model`, then reevaluates the +model to obtain the correct vectorised parameters corresponding to the raw values in +`global_vals`. If extra information is needed (e.g. log-probabilities), `extra_accs` can be used to pass in other accumulators to be used in the same model evaluation, to avoid having to recompute @@ -92,20 +95,23 @@ them later. Returns `(new_ldf, new_params, accs)` where `accs` is the set of accumulators after evaluation, from which extra accumulators (e.g. `LogLikelihoodAccumulator`) can be read. + +!!! warning + This assumes that `old_ldf.model` (i.e., the model conditioned on the previous values) + and `model` (i.e., the model conditioned on the new values) have the same structure, i.e., + all other components of the LogDensityFunction can be reused. """ function gibbs_recompute_ldf_and_params( old_ldf::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model, - vector_vnt::DynamicPPL.VarNamedTuple, global_vals::DynamicPPL.VarNamedTuple, extra_accs::NTuple{N,<:DynamicPPL.AbstractAccumulator}=(), ) where {N} - # TODO(penelopeysm): If old_ldf has fixed transforms, this will overwrite it. This - # probably needs to be fixed by improving the constructor in DynamicPPL. 
new_ldf = DynamicPPL.LogDensityFunction( model, DynamicPPL.get_logdensity_callable(old_ldf), - vector_vnt; + DynamicPPL.get_all_ranges_and_transforms(old_ldf), + DynamicPPL.get_sample_input_vector(old_ldf); adtype=old_ldf.adtype, ) accs = DynamicPPL.OnlyAccsVarInfo( diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index be8b5fe81c..ebdc7ed56b 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -13,7 +13,6 @@ struct HMCState{ PhType<:AHMC.PhasePoint, TAdapt<:AHMC.Adaptation.AbstractAdaptor, L<:DynamicPPL.LogDensityFunction, - V<:DynamicPPL.VarNamedTuple, } i::Int kernel::TKernel @@ -21,14 +20,6 @@ struct HMCState{ z::PhType adaptor::TAdapt ldf::L - # TODO(penelopeysm): This field is needed for Gibbs because each time we call - # gibbs_update_state!! on this, we need to reconstruct a LogDensityFunction. - # In general this would require reevaluating the model, unless we supply a - # VarNamedTuple which already contains vectorised parameters. This can probably - # be improved in DynamicPPL, but for now we will just store an extra VNT in - # the state. - # NOTE: The actual values of this field should never be used or relied on! 
- _vector_vnt::V end ### @@ -169,15 +160,10 @@ function AbstractMCMC.step( kwargs..., ) # Create a LogDensityFunction - oavi = DynamicPPL.OnlyAccsVarInfo(DynamicPPL.VectorValueAccumulator()) - _, oavi = DynamicPPL.init!!( - model, oavi, DynamicPPL.InitFromPrior(), DynamicPPL.LinkAll() - ) - vecvals = DynamicPPL.get_vector_values(oavi) ldf = DynamicPPL.LogDensityFunction( model, DynamicPPL.getlogjoint_internal, - vecvals; + DynamicPPL.LinkAll(); adtype=spl.adtype, fix_transforms=fix_transforms, ) @@ -209,7 +195,7 @@ function AbstractMCMC.step( DynamicPPL.ParamsWithStats(theta, ldf, NamedTuple()) end - state = HMCState(0, kernel, hamiltonian, z, adaptor, ldf, vecvals) + state = HMCState(0, kernel, hamiltonian, z, adaptor, ldf) return transition, state end @@ -253,9 +239,7 @@ function AbstractMCMC.step( else DynamicPPL.ParamsWithStats(t.z.θ, state.ldf, t.stat) end - newstate = HMCState( - i, kernel, hamiltonian, t.z, state.adaptor, state.ldf, state._vector_vnt - ) + newstate = HMCState(i, kernel, hamiltonian, t.z, state.adaptor, state.ldf) return transition, newstate end @@ -498,9 +482,7 @@ function gibbs_update_state!!( ) # Construct a new LDF with the newly conditioned `model` (not `state.ldf.model`, which # contains stale conditioned values) and recompute the vectorised params. - new_ldf, new_params, _ = gibbs_recompute_ldf_and_params( - state.ldf, model, state._vector_vnt, global_vals - ) + new_ldf, new_params, _ = gibbs_recompute_ldf_and_params(state.ldf, model, global_vals) # Update the Hamiltonian (because that depends on the LDF). metric = gen_metric(LogDensityProblems.dimension(new_ldf), spl, state) lp_func = Base.Fix1(LogDensityProblems.logdensity, new_ldf) @@ -509,13 +491,5 @@ function gibbs_update_state!!( # We also need to update the position variables in the PhasePoint. 
new_z = deepcopy(state.z) new_z.θ .= new_params - return HMCState( - state.i, - state.kernel, - new_hamiltonian, - new_z, - state.adaptor, - new_ldf, - state._vector_vnt, - ) + return HMCState(state.i, state.kernel, new_hamiltonian, new_z, state.adaptor, new_ldf) end diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 464e055914..c8d78425c9 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -663,7 +663,10 @@ using Turing model = f() @testset "$spl" for spl in ( - NUTS(), HMC(0.1, 5), externalsampler(AdvancedMH.RWMH(MvNormal([0.0], [1.0;;]))) + NUTS(), + HMC(0.1, 5), + externalsampler(AdvancedMH.RWMH(MvNormal([0.0], [1.0;;]))), + Gibbs(:x => HMC(0.1, 5)), ) counter[] = 0 sample(model, spl, 100)