From a76fa1dbff07c27721d2378b8d9b3a8dcdf6d5c4 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 1 May 2026 17:16:14 +0100 Subject: [PATCH 1/4] Move FlexiChains ext --- HISTORY.md | 4 + Project.toml | 8 +- ext/DynamicPPLFlexiChainsExt.jl | 663 ++++++++++++++++++ ext/DynamicPPLFlexiChainsPosteriorStatsExt.jl | 41 ++ test/Project.toml | 2 + test/ext/DynamicPPLFlexiChainsExt.jl | 654 +++++++++++++++++ 6 files changed, 1371 insertions(+), 1 deletion(-) create mode 100644 ext/DynamicPPLFlexiChainsExt.jl create mode 100644 ext/DynamicPPLFlexiChainsPosteriorStatsExt.jl create mode 100644 test/ext/DynamicPPLFlexiChainsExt.jl diff --git a/HISTORY.md b/HISTORY.md index a9a82fc86..e61d09660 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,7 @@ +# 0.41.7 + +Move FlexiChainsDynamicPPLExt to DynamicPPLFlexiChainsExt. + # 0.41.6 Add a `factorize::Bool` keyword argument for `pointwise_logdensities(model, values)`, which controls whether pointwise logdensities for factorisable distributions (e.g. `MvNormal`, `product_distribution`, etc.) are returned as a single log-density for the whole distribution, or as an array of log-densities for each factor. 
diff --git a/Project.toml b/Project.toml index 8e09a07bf..db72e1103 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.41.6" +version = "0.41.7" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -31,15 +31,19 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [weakdeps] EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +FlexiChains = "4a37a8b9-6e57-4b92-8664-298d46e639f7" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +PosteriorStats = "7f36be82-ad55-44ba-a5c0-b8b5480d7aa5" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" [extensions] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] +DynamicPPLFlexiChainsExt = ["FlexiChains"] +DynamicPPLFlexiChainsPosteriorStatsExt = ["FlexiChains", "PosteriorStats"] DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLMCMCChainsExt = ["MCMCChains"] DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"] @@ -61,6 +65,7 @@ Distributions = "0.25" DocStringExtensions = "0.9" EnzymeCore = "0.6 - 0.8" FillArrays = "1.16.0" +FlexiChains = "0.4 - 0.5" ForwardDiff = "0.10.12, 1" InteractiveUtils = "1" KernelAbstractions = "0.9.33" @@ -72,6 +77,7 @@ MarginalLogDensities = "0.4.3" Mooncake = "0.4.147, 0.5" OrderedCollections = "1" PartitionedDistributions = "0.0.1" +PosteriorStats = "0.4" PrecompileTools = "1.2.1" Preferences = "1.5.2" Printf = "1.10" diff --git a/ext/DynamicPPLFlexiChainsExt.jl b/ext/DynamicPPLFlexiChainsExt.jl new file mode 100644 index 000000000..e4a18b0ac --- /dev/null +++ b/ext/DynamicPPLFlexiChainsExt.jl @@ -0,0 +1,663 @@ +module DynamicPPLFlexiChainsExt + +using FlexiChains: + FlexiChains, FlexiChain, VarName, Parameter, Extra, ParameterOrExtra, VNChain +using DynamicPPL: + DynamicPPL, + 
AbstractPPL, + AbstractMCMC, + Distributions, + UnlinkAll, + TransformedValue, + NoTransform, + VarNamedTuple +using OrderedCollections: OrderedDict +using Random: Random + +################## +# bundle_samples # +################## +function AbstractMCMC.bundle_samples( + # TODO(penelopeysm): When VarNamedTuple is moved into AbstractPPL, this can go back + # into src/ rather than the extension. + transitions::AbstractVector, + @nospecialize(m::AbstractMCMC.AbstractModel), + @nospecialize(s::AbstractMCMC.AbstractSampler), + last_sampler_state::Any, + chain_type::Type{FlexiChain{VarName}}; + save_state=false, + stats=missing, + discard_initial::Int=0, + thinning::Int=1, + _kwargs..., +)::FlexiChain{VarName} + niters = length(transitions) + vnts_and_stats = map(FlexiChains.to_vnt_and_stats, transitions) + dicts = map(vnts_and_stats) do (vnt, stat) + d = OrderedDict{ParameterOrExtra{<:VarName},Any}( + Parameter(vn) => val for (vn, val) in pairs(vnt) + ) + for (stat_vn, stat_val) in pairs(stat) + d[Extra(stat_vn)] = stat_val + end + d + end + # note that FlexiChains constructor expects structures to have size (niters x nchains), + # so a vector won't do + skeletons = hcat(map(DynamicPPL.skeleton ∘ first, vnts_and_stats)) + # timings + tm = stats === missing ? missing : stats.stop - stats.start + # last sampler state + st = save_state ? 
last_sampler_state : missing + # calculate iteration indices + start = discard_initial + 1 + iter_indices = if thinning != 1 + range(start; step=thinning, length=niters) + else + # This returns UnitRange not StepRange -- a bit cleaner + start:(start + niters - 1) + end + return FlexiChain{VarName}( + niters, + 1, + dicts; + structures=skeletons, + iter_indices=iter_indices, + # 1:1 gives nicer DimMatrix output than just [1] + chain_indices=1:1, + sampling_time=[tm], + last_sampler_state=[st], + ) +end + +function FlexiChains.reconstruct_values(chn::VNChain, i, j, structure::VarNamedTuple) + vnt = DynamicPPL.VarNamedTuple() + nt = NamedTuple() + for param_or_extra in keys(chn) + val = chn[param_or_extra][i, j] + if param_or_extra isa Parameter + ismissing(val) && continue + vn = param_or_extra.name + top_sym = AbstractPPL.getsym(vn) + template = get(structure.data, top_sym, DynamicPPL.NoTemplate()) + vnt = DynamicPPL.templated_setindex!!(vnt, val, vn, template) + elseif param_or_extra isa Extra + nt = merge(nt, (; Symbol(param_or_extra.name) => val)) + end + end + return DynamicPPL.ParamsWithStats(vnt, nt) +end + +function FlexiChains.reconstruct_parameters(chn::VNChain, i, j, structure::VarNamedTuple) + vnt = DynamicPPL.VarNamedTuple() + for vn in FlexiChains.parameters(chn) + val = chn[Parameter(vn)][i, j] + ismissing(val) && continue + top_sym = AbstractPPL.getsym(vn) + template = get(structure.data, top_sym, DynamicPPL.NoTemplate()) + vnt = DynamicPPL.templated_setindex!!(vnt, val, vn, template) + end + return vnt +end + +################################################## +# AbstractMCMC.{to,from}_samples implementations # +################################################## + +""" + AbstractMCMC.from_samples( + ::Type{<:VNChain}, + params_and_stats::AbstractMatrix{<:DynamicPPL.ParamsWithStats} + )::VNChain + +Convert a matrix of [`DynamicPPL.ParamsWithStats`](@extref) to a `VNChain`. 
+""" +function AbstractMCMC.from_samples( + ::Type{<:VNChain}, params_and_stats::AbstractMatrix{<:DynamicPPL.ParamsWithStats} +)::VNChain + # Just need to convert the `ParamsWithStats` to Dicts of ParameterOrExtra. + dicts = map(params_and_stats) do ps + # Parameters + d = OrderedDict{ParameterOrExtra{<:VarName},Any}( + Parameter(vn) => val for (vn, val) in pairs(ps.params) + ) + # Stats + for (stat_vn, stat_val) in pairs(ps.stats) + d[Extra(stat_vn)] = stat_val + end + d + end + # And get the structures. + structures = map(ps -> DynamicPPL.skeleton(ps.params), params_and_stats) + return VNChain( + size(params_and_stats, 1), size(params_and_stats, 2), dicts; structures=structures + ) +end +""" + AbstractMCMC.from_samples( + ::Type{<:VNChain}, + params_and_stats::AbstractMatrix{<:DynamicPPL.VarNamedTuple} + )::VNChain + +Convert a matrix of [`DynamicPPL.VarNamedTuple`](@extref +DynamicPPL.VarNamedTuples.VarNamedTuple) to a `VNChain`. +""" +function AbstractMCMC.from_samples( + ::Type{<:VNChain}, vnts::AbstractMatrix{<:DynamicPPL.VarNamedTuple} +)::VNChain + pwss = map(vnts) do vnt + DynamicPPL.ParamsWithStats(vnt, (;)) + end + return AbstractMCMC.from_samples(VNChain, pwss) +end + +""" + AbstractMCMC.to_samples( + ::Type{DynamicPPL.ParamsWithStats}, + chain::VNChain, + [model::DynamicPPL.Model] + )::DimensionalData.DimMatrix{DynamicPPL.ParamsWithStats} + +Convert a `VNChain` to a `DimMatrix` of [`DynamicPPL.ParamsWithStats`](@extref). + +The axes of the `DimMatrix` are the same as those of the input `VNChain`. +""" +function AbstractMCMC.to_samples( + ::Type{DynamicPPL.ParamsWithStats}, chain::FlexiChain{T}, model::DynamicPPL.Model +) where {T<:VarName} + template_vnt = nothing # Set later on-demand. + # If there is no skeletal VNT structure stored, then values_at will return a Dict. 
+ # Otherwise it will return a ParamsWithStats + dicts_or_pwss = FlexiChains.values_at(chain; iter=:, chain=:) + pwss = map(dicts_or_pwss) do d_or_pws + if d_or_pws isa DynamicPPL.ParamsWithStats + d_or_pws + else + # No skeleton -- rerun the model once to get a template, and pray that + # it's accurate. + if template_vnt === nothing + template_vnt = rand(model) + end + # Then attempt to reconstruct + vnt = DynamicPPL.VarNamedTuple() + for (vn_param, val) in pairs(d_or_pws) + if vn_param isa Parameter + vn = vn_param.name + top_sym = AbstractPPL.getsym(vn) + template = get(template_vnt.data, top_sym, DynamicPPL.NoTemplate()) + vnt = DynamicPPL.templated_setindex!!(vnt, val, vn, template) + end + end + # Stats + stats_nt = NamedTuple( + Symbol(extra_param.name) => val for + (extra_param, val) in d_or_pws if extra_param isa Extra + ) + DynamicPPL.ParamsWithStats(vnt, stats_nt) + end + end + return FlexiChains._raw_to_user_data(chain, pwss) +end +function AbstractMCMC.to_samples( + ::Type{DynamicPPL.VarNamedTuple}, chain::FlexiChain{T}, model::DynamicPPL.Model +) where {T<:VarName} + pwss = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain, model) + return map(pws -> pws.params, pwss) +end + +function AbstractMCMC.to_samples( + ::Type{DynamicPPL.ParamsWithStats}, chain::FlexiChain{T} +) where {T<:VarName} + # If there is no skeletal VNT structure stored, then values_at will return a Dict. + # Otherwise it will return a ParamsWithStats + dicts_or_pwss = FlexiChains.values_at(chain; iter=:, chain=:) + pwss = map(dicts_or_pwss) do d_or_pws + if d_or_pws isa DynamicPPL.ParamsWithStats + d_or_pws + else + # No skeleton. Just cry and use setindex!!. 
+ vnt = DynamicPPL.VarNamedTuple() + for (vn_param, val) in pairs(d_or_pws) + if vn_param isa Parameter + vnt = DynamicPPL.setindex!!(vnt, val, vn_param.name) + end + end + # Stats + stats_nt = NamedTuple( + Symbol(extra_param.name) => val for + (extra_param, val) in d_or_pws if extra_param isa Extra + ) + DynamicPPL.ParamsWithStats(vnt, stats_nt) + end + end + return FlexiChains._raw_to_user_data(chain, pwss) +end +function AbstractMCMC.to_samples( + ::Type{DynamicPPL.VarNamedTuple}, chain::FlexiChain{T} +) where {T<:VarName} + pwss = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) + return map(pws -> pws.params, pwss) +end + +# This method will make `bundle_samples` 'just work' +function FlexiChains.to_vnt_and_stats(pws::DynamicPPL.ParamsWithStats) + return (pws.params, pws.stats) +end +function FlexiChains.to_vnt_and_stats(vnt::DynamicPPL.VarNamedTuple) + return (vnt, (;)) +end + +############################ +# InitFromParams extension # +############################ +""" + DynamicPPL.InitFromParams( + chn::FlexiChain{<:VarName}, + iter::Union{Int,DimensionalData.At}, + chain::Union{Int,DimensionalData.At}, + fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() + )::DynamicPPL.InitFromParams + +Use the parameters stored in a FlexiChain as an initialisation strategy. +""" +function DynamicPPL.InitFromParams( + chn::FlexiChain{<:VarName}, + iter::Union{Int,FlexiChains.At}, + chain::Union{Int,FlexiChains.At}, + fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=DynamicPPL.InitFromPrior(), +) + return InitFromFlexiChain(chn, iter, chain, fallback) +end + +""" + InitFromFlexiChain( + chain::FlexiChain, iter_index::Int, chain_index::Int, fallback=nothing + ) + +A DynamicPPL initialisation strategy that obtains values from the given `FlexiChain` at the +specified iteration and chain indices. + +In order for `InitFromFlexiChain` to work correctly, two things must be ensured: + +1. 
The variables being asked for must **exactly** match those stored in the FlexiChain. That + is, if the chain contains `@varname(x)` and the model asks for `@varname(y)`, this will + either error (if no fallback is provided) or silently use the fallback. + +2. The `iter_index` and `chain_index` arguments must be 1-based indices. + +These requirements allow us to skip the usual `getindex` method when retrieving values from +the `FlexiChain`, and instead index directly into the data storage, which is much faster. + +These conditions, especially (1), can be guaranteed if and only if the chain used to +re-evaluate the model was generated from the same model (or a model with the same +structure). + +`fallback` provides the same functionality as it does in `DynamicPPL.InitFromParams`, that +is, if a variable is not found in the `FlexiChain`, the fallback strategy is used to +generate its value. This is necessary for `predict`. +""" +struct InitFromFlexiChain{ + C<:FlexiChains.VNChain,S<:Union{DynamicPPL.AbstractInitStrategy,Nothing} +} <: DynamicPPL.AbstractInitStrategy + chain::C + iter_index::Int + chain_index::Int + fallback::S + _top_syms::Set{Symbol} + # Lazily populated cache of the full parameters VNT for this (iter, chain). 
+ _vnt_cache::Ref{Union{Nothing,VarNamedTuple}} + function InitFromFlexiChain( + chain::C, iter_index::Int, chain_index::Int, fallback::S=nothing + ) where {C<:FlexiChains.VNChain,S<:Union{DynamicPPL.AbstractInitStrategy,Nothing}} + top_syms = Set{Symbol}( + AbstractPPL.getsym(vn) for vn in FlexiChains.parameters(chain) + ) + return new{C,S}( + chain, + iter_index, + chain_index, + fallback, + top_syms, + Ref{Union{Nothing,VarNamedTuple}}(nothing), + ) + end +end +function _get_parameters_vnt(strategy::InitFromFlexiChain) + if strategy._vnt_cache[] === nothing + strategy._vnt_cache[] = FlexiChains.parameters_at( + strategy.chain; iter=strategy.iter_index, chain=strategy.chain_index + ) + end + return strategy._vnt_cache[] +end +function DynamicPPL.init( + rng::Random.AbstractRNG, + vn::VarName, + dist::Distributions.Distribution, + strategy::InitFromFlexiChain, +) + param = FlexiChains.Parameter(vn) + # First check if there's an exact match in the chain, and if so, use that. + # + # Otherwise, attempt to construct the full dictionary of parameters and use + # that. (That guards against cases where the chain has a densified variable + # e.g. `x`, but the model has `x[1]` and `x[2]`: e.g. x = zeros(2); x .~ + # Normal().) + # + # Finally, if even that isn't found, just use the fallback strategy (if + # provided). + if haskey(strategy.chain._data, param) + x = strategy.chain._data[param][strategy.iter_index, strategy.chain_index] + return TransformedValue(x, NoTransform()) + else + # Check if any parameter in the chain shares the same top-level symbol. If not, + # we can skip the expensive VNT reconstruction and go straight to the fallback. 
+ has_matching_sym = AbstractPPL.getsym(vn) in strategy._top_syms + if has_matching_sym + vnt = _get_parameters_vnt(strategy) + augmented_fallback = DynamicPPL.InitFromParams(vnt, strategy.fallback) + return DynamicPPL.init(rng, vn, dist, augmented_fallback) + elseif strategy.fallback !== nothing + return DynamicPPL.init(rng, vn, dist, strategy.fallback) + else + error("Variable $vn not found in chain and no fallback strategy provided.") + end + end +end + +########################################### +# DynamicPPL: predict, returned, logjoint # +########################################### + +function _default_reevaluate_accs() + return ( + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogJacobianAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + DynamicPPL.RawValueAccumulator(true), + ) +end + +""" +Returns a tuple of (retval, varinfo) for each iteration in the chain. +""" +function reevaluate( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + chain::FlexiChain{<:VarName}, + accs::NTuple{N,DynamicPPL.AbstractAccumulator}=_default_reevaluate_accs(), + fallback_strategy::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing, +) where {N} + niters, nchains = size(chain) + tuples = Iterators.product(1:niters, 1:nchains) + vi = DynamicPPL.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple(accs)) + retvals_and_varinfos = map(tuples) do (i, j) + DynamicPPL.init!!( + rng, + model, + vi, + InitFromFlexiChain(chain, i, j, fallback_strategy), + UnlinkAll(), + ) + end + return FlexiChains._raw_to_user_data(chain, retvals_and_varinfos) +end +function reevaluate( + model::DynamicPPL.Model, + chain::FlexiChain{<:VarName}, + accs::NTuple{N,DynamicPPL.AbstractAccumulator}=_default_reevaluate_accs(), + fallback_strategy::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing, +) where {N} + return reevaluate(Random.default_rng(), model, chain, accs, fallback_strategy) +end + +""" + DynamicPPL.returned(model::DynamicPPL.Model, chain::FlexiChain{<:VarName}) + +Returns a 
`DimMatrix` of the model's return values, re-evaluated using the parameters in +each iteration of the chain. + +If the return value is a `DimArray`, the dimensions will be stacked. +""" +function DynamicPPL.returned(model::DynamicPPL.Model, chain::FlexiChain{<:VarName}) + return FlexiChains._raw_to_user_data(chain, map(first, reevaluate(model, chain))) +end + +""" + DynamicPPL.logjoint(model::DynamicPPL.Model, chain::FlexiChain{<:VarName}) + +Returns a `DimMatrix` of the log-joint probabilities, re-evaluated using the parameters at +each iteration of the chain. +""" +function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::FlexiChain{<:VarName}) + accs = (DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator()) + return map(DynamicPPL.getlogjoint ∘ last, reevaluate(model, chain, accs)) +end + +""" + DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::FlexiChain{<:VarName}) + +Returns a `DimMatrix` of the log-likelihoods, re-evaluated using the parameters at each +iteration of the chain. +""" +function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::FlexiChain{<:VarName}) + accs = (DynamicPPL.LogLikelihoodAccumulator(),) + return map(DynamicPPL.getloglikelihood ∘ last, reevaluate(model, chain, accs)) +end + +""" + DynamicPPL.logprior(model::DynamicPPL.Model, chain::FlexiChain{<:VarName}) + +Returns a `DimMatrix` of the log-prior probabilities, re-evaluated using the parameters at +each iteration of the chain. +""" +function DynamicPPL.logprior(model::DynamicPPL.Model, chain::FlexiChain{<:VarName}) + accs = (DynamicPPL.LogPriorAccumulator(),) + return map(DynamicPPL.getlogprior ∘ last, reevaluate(model, chain, accs)) +end + +""" + DynamicPPL.predict( + [rng::Random.AbstractRNG,] + model::DynamicPPL.Model, + chain::FlexiChain{<:VarName}; + include_all::Bool=true, + ) + +Returns a new `FlexiChain` containing predictions for variables in the model, conditioned on +the parameters in each iteration of the input `chain`. 
+ +The returned `FlexiChain` by default will contain all the predicted variables, as well as the +variables already present in the input `chain`. If you only want the predicted variables, +set `include_all=false`. + +The returned chain will also contain log-probabilities corresponding to the re-evaluation of +the model. In particular, the log probability for the newly predicted variables are +now considered as prior terms. However, note that the log-prior of the returned chain will +also contain the log-prior terms of the parameters already present in the input `chain`. +Thus, if you want to obtain the log-probability of the predicted variables only, you can +subtract the two log-prior terms. The `include_all` keyword argument has no effect on the +log-probability fields. +""" +function DynamicPPL.predict( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + chain::FlexiChain{<:VarName}; + include_all::Bool=true, +)::FlexiChain{VarName} + existing_parameters = Set(FlexiChains.parameters(chain)) + accs = _default_reevaluate_accs() + fallback = DynamicPPL.InitFromPrior() + param_dicts_and_skeletons = + map(reevaluate(rng, model, chain, accs, fallback)) do (_, vi) + vnt = DynamicPPL.densify!!(DynamicPPL.get_raw_values(vi)) + p_dict = OrderedDict{ParameterOrExtra{<:VarName},Any}( + Parameter(vn) => val for + (vn, val) in pairs(vnt) if (include_all || !(vn in existing_parameters)) + ) + # Use skeletons from reevaluation, since they will be appropriate for the new + # chain that we are constructing. 
+ skeleton = DynamicPPL.skeleton(vnt) + # Tack on the probabilities + p_dict[FlexiChains._LOGPRIOR_KEY] = DynamicPPL.getlogprior(vi) + p_dict[FlexiChains._LOGJOINT_KEY] = DynamicPPL.getlogjoint(vi) + p_dict[FlexiChains._LOGLIKELIHOOD_KEY] = DynamicPPL.getloglikelihood(vi) + (p_dict, skeleton) + end + ni, nc = size(chain) + predictions_chain = FlexiChain{VarName}( + ni, + nc, + map(first, param_dicts_and_skeletons); + structures=map(last, param_dicts_and_skeletons), + iter_indices=FlexiChains.iter_indices(chain), + chain_indices=FlexiChains.chain_indices(chain), + ) + old_extras_chain = FlexiChains.subset_extras(chain) + return merge(old_extras_chain, predictions_chain) +end +function DynamicPPL.predict( + model::DynamicPPL.Model, chain::FlexiChain{<:VarName}; include_all::Bool=true +)::FlexiChain{VarName} + return DynamicPPL.predict(Random.default_rng(), model, chain; include_all=include_all) +end + +# Shared internal helper function. +function _pointwise_logprobs( + model::DynamicPPL.Model, + chain::FlexiChain{<:VarName}, + ::Val{Prior}, + ::Val{Likelihood}; + factorize::Bool=false, +) where {Prior,Likelihood} + acc = DynamicPPL.VNTAccumulator{DynamicPPL.POINTWISE_ACCNAME}( + DynamicPPL.PointwiseLogProb{Prior,Likelihood,factorize}() + ) + pointwise_dicts = map(reevaluate(model, chain, (acc,), nothing)) do (_, oavi) + logprobs = DynamicPPL.densify!!(DynamicPPL.get_pointwise_logprobs(oavi)) + OrderedDict{ParameterOrExtra{<:VarName},Any}( + Parameter(vn) => val for (vn, val) in pairs(logprobs) + ) + end + ni, nc = size(chain) + return FlexiChain{VarName}( + ni, + nc, + pointwise_dicts; + iter_indices=FlexiChains.iter_indices(chain), + chain_indices=FlexiChains.chain_indices(chain), + ) +end + +# Copied from DynamicPPL +const _FACTORIZE_KWARG_DOC = """ +If `factorize=true`, additionally attempt to provide factorised log-densities for +distributions that can be partitioned into blocks, using PartitionedDistributions.jl. 
+ +For example, if `factorize=true`, then `y ~ MvNormal(...)` will return a vector of +log-densities, one for each element of `y`. The `i`-th element of this vector will be the +conditional log-probability of `y[i]` given all the other elements of `y` (often denoted +`log p(y_{i} | y_{-i})`): in particular this is exactly the log-density required for +leave-one-out cross-validation. + +In contrast, if `factorize=false`, then the log-density for `y ~ MvNormal(...)` will be a +single scalar corresponding to `logpdf(MvNormal(...), y)`. + +Note that the sum of the factorised log-densities may not, in general, be equal to the +log-density of the full distribution: they will only be equal if the original distribution +can be completely factorised into independent components. For example, if `y ~ MvNormal(μ, +Σ)` where `Σ` is diagonal, then each element of `y` is independent and the sum of the +factorised log-densities will be equal to the log-density of the full distribution. In +contrast, if `Σ` has off-diagonal entries, then the elements of `y` are not independent. +""" + +""" + DynamicPPL.pointwise_logdensities( + model::Model, + chain::FlexiChain{<:VarName}; + factorize::Bool=false, + )::FlexiChain{VarName} + +Calculate the log probability density associated with each variable in the model, for each +iteration in the `FlexiChain`. + +Returns a new `FlexiChain` with the same structure as the input `chain`, mapping the +variables to their log probabilities. + +$(_FACTORIZE_KWARG_DOC) +""" +function DynamicPPL.pointwise_logdensities( + model::DynamicPPL.Model, chain::FlexiChain{<:VarName}; factorize::Bool=false +) + return _pointwise_logprobs(model, chain, Val(true), Val(true); factorize=factorize) +end + +""" + DynamicPPL.pointwise_loglikelihoods( + model::Model, + chain::FlexiChain{<:VarName}; + factorize::Bool=false, + )::FlexiChain{VarName} + +Calculate the log likelihood associated with each observed variable in the model, for each +iteration in the `FlexiChain`. 
+ +Returns a new `FlexiChain` with the same structure as the input `chain`, mapping the +observed variables to their log probabilities. + +$(_FACTORIZE_KWARG_DOC) +""" +function DynamicPPL.pointwise_loglikelihoods( + model::DynamicPPL.Model, chain::FlexiChain{<:VarName}; factorize::Bool=false +) + return _pointwise_logprobs(model, chain, Val(false), Val(true); factorize=factorize) +end + +""" + DynamicPPL.pointwise_prior_logdensities( + model::Model, + chain::FlexiChain{<:VarName}; + factorize::Bool=false, + )::FlexiChain{VarName} + +Calculate the log prior associated with each random variable in the model, for each +iteration in the `FlexiChain`. + +Returns a new `FlexiChain` with the same structure as the input `chain`, mapping the +random variables to their log probabilities. + +$(_FACTORIZE_KWARG_DOC) +""" +function DynamicPPL.pointwise_prior_logdensities( + model::DynamicPPL.Model, chain::FlexiChain{<:VarName}; factorize::Bool=false +) + return _pointwise_logprobs(model, chain, Val(true), Val(false); factorize=factorize) +end + +####################### +# Precompile workload # +####################### + +using DynamicPPL: + DynamicPPL, Distributions, AbstractMCMC, @model, ParamsWithStats, InitFromPrior +using FlexiChains: VNChain, summarystats +using PrecompileTools: @setup_workload, @compile_workload + +# dummy, needed to satisfy interface of bundle_samples +struct NotASampler <: AbstractMCMC.AbstractSampler end +@setup_workload begin + @model function f() + x ~ Distributions.MvNormal(zeros(2), [1.0 0.5; 0.5 1.0]) + return y ~ Distributions.Normal() + end + model = f() + transitions = [ParamsWithStats(InitFromPrior(), model) for _ in 1:10] + @compile_workload begin + chn = AbstractMCMC.bundle_samples( + transitions, model, NotASampler(), nothing, VNChain + ) + summarystats(chn) + end +end + +end # module DynamicPPLFlexiChainsExt diff --git a/ext/DynamicPPLFlexiChainsPosteriorStatsExt.jl b/ext/DynamicPPLFlexiChainsPosteriorStatsExt.jl new file mode 100644 
index 000000000..47edfaff3 --- /dev/null +++ b/ext/DynamicPPLFlexiChainsPosteriorStatsExt.jl @@ -0,0 +1,41 @@ +module DynamicPPLFlexiChainsPosteriorStatsExt + +using PosteriorStats: PosteriorStats +using FlexiChains: FlexiChain +using DynamicPPL: DynamicPPL + +""" + PosteriorStats.loo( + model::DynamicPPL.Model, + posterior_chn::FlexiChain; + factorize::Bool=false, + kwargs... + ) + +Calculates the leave-one-out cross-validation (LOO) statistic, given a model plus a +posterior chain. This first uses the model and posterior chain to calculate pointwise +log-likelihoods, and then uses those to calculate the LOO statistic. + +Returns a struct with the following fields: + +- `param_names::Vector`: A vector of parameter names whose log-likelihood values were used. + +- `loo::PosteriorStats.PSISLOOResult`: The return value of `PosteriorStats.loo` applied to + the log-likelihood values extracted from the `FlexiChain`. This contains the statistics + of interest. + +The `factorize` keyword argument is passed to `DynamicPPL.pointwise_loglikelihoods`. If +`factorize=true`, factorised log-densities will be calculated for distributions that can be +partitioned into blocks (e.g. `MvNormal`). Please see the docstring of +[`DynamicPPL.pointwise_loglikelihoods`](@ref) for more details. + +Additional keyword arguments are forwarded to [`PosteriorStats.loo`](@extref). +""" +function PosteriorStats.loo( + model::DynamicPPL.Model, posterior_chn::FlexiChain; factorize::Bool=false, kwargs... +) + lls_chn = DynamicPPL.pointwise_loglikelihoods(model, posterior_chn; factorize=factorize) + return PosteriorStats.loo(lls_chn; kwargs...) 
+end + +end diff --git a/test/Project.toml b/test/Project.toml index 73cff23ed..28ea4b65b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,6 +14,7 @@ DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +FlexiChains = "4a37a8b9-6e57-4b92-8664-298d46e639f7" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -24,6 +25,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +PosteriorStats = "7f36be82-ad55-44ba-a5c0-b8b5480d7aa5" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" diff --git a/test/ext/DynamicPPLFlexiChainsExt.jl b/test/ext/DynamicPPLFlexiChainsExt.jl new file mode 100644 index 000000000..0bb057c67 --- /dev/null +++ b/test/ext/DynamicPPLFlexiChainsExt.jl @@ -0,0 +1,654 @@ +module DynamicPPLFlexiChainsExtTests + +using Dates: now +@info "Testing $(@__FILE__)..." 
+__now__ = now() + +using AbstractMCMC: AbstractMCMC +using DimensionalData: DimensionalData as DD +using Distributions +using DynamicPPL +using FlexiChains: FlexiChains, FlexiChain, VNChain, Parameter, Extra +using LinearAlgebra: I +using OffsetArrays: OffsetArray +using PosteriorStats: PosteriorStats +using Random: Random, Xoshiro +using StableRNGs: StableRNG +using Test + +_LOGPRIOR_KEY = Extra(:logprior) +_LOGLIKELIHOOD_KEY = Extra(:loglikelihood) +_LOGJOINT_KEY = Extra(:logjoint) + +function sample_from_prior( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + n_iters::Int, + n_chains::Int=1; + make_chain=true, +) + vi = DynamicPPL.OnlyAccsVarInfo(( + DynamicPPL.default_accumulators()..., DynamicPPL.RawValueAccumulator(true) + )) + ps = [ + DynamicPPL.ParamsWithStats( + last(DynamicPPL.init!!(rng, model, vi, InitFromPrior(), UnlinkAll())) + ) for _ in 1:n_iters, _ in 1:n_chains + ] + return if make_chain + AbstractMCMC.from_samples(VNChain, ps) + else + ps + end +end +function sample_from_prior( + model::DynamicPPL.Model, n_iters::Int, n_chains::Int=1; make_chain=true +) + return sample_from_prior(Random.default_rng(), model, n_iters, n_chains; make_chain) +end + +# For some of the `predict` tests, we need some way to draw from the posterior. We'll use +# importance sampling here since it's simple to implement. 
+function sample_from_posterior(rng::Random.AbstractRNG, model::DynamicPPL.Model)
+    # Draw 20000 prior samples and weight each one by its log-likelihood.
+    prior_samples = sample_from_prior(rng, model, 20000; make_chain=false)
+    log_weights = vec([p.stats.loglikelihood for p in prior_samples])
+    # Shift by the maximum log-weight before exponentiating, then normalise so that the
+    # weights sum to one.
+    max_logw = maximum(log_weights)
+    weights = exp.(log_weights .- max_logw)
+    weights ./= sum(weights)
+    # Resample 2000 of the prior draws (with replacement) according to those weights.
+    dist = Categorical(weights)
+    idxs = rand(rng, dist, 2000)
+    # `hcat` on the resampled vector yields a 2000 × 1 matrix, i.e. a single chain.
+    return AbstractMCMC.from_samples(VNChain, hcat(prior_samples[idxs]))
+end
+# Convenience method: same as above, drawing from `Random.default_rng()`.
+function sample_from_posterior(model::DynamicPPL.Model)
+    return sample_from_posterior(Random.default_rng(), model)
+end
+
+@testset "FlexiChainsExt" begin
+    # Re-initialising the model from `InitFromParams(chn, i, 1)` must reproduce the raw
+    # values stored at iteration `i` of chain 1.
+    @testset "InitFromParams(chain, i, j)" begin
+        @model function f()
+            x ~ Normal()
+            return y ~ Normal(x)
+        end
+        model = f()
+        chn = sample_from_prior(model, 50)
+
+        for i in 1:50
+            accs = OnlyAccsVarInfo(DynamicPPL.RawValueAccumulator(false))
+            _, accs = DynamicPPL.init!!(model, accs, InitFromParams(chn, i, 1), UnlinkAll())
+            raw_values = get_raw_values(accs)
+            for vn in (@varname(x), @varname(y))
+                @test raw_values[vn] == chn[vn, iter=i, chain=1]
+            end
+        end
+    end
+
+    @testset "AbstractMCMC.from_samples" begin
+        @model function f(z)
+            x ~ Normal()
+            y := x + 1
+            return z ~ Normal(y)
+        end
+
+        z = 1.0
+        model = f(z)
+
+        # Both calls use the same seed, so `ps` and `c` are built from the same draws.
+        ps = sample_from_prior(Xoshiro(468), model, 50, 3; make_chain=false)
+        c = sample_from_prior(Xoshiro(468), model, 50, 3; make_chain=true)
+        @test FlexiChains.parameters(c) == [@varname(x), @varname(y)]
+        @test c[@varname(x)] == map(p -> p.params[@varname(x)], ps)
+        @test c[@varname(y)] == c[@varname(x)] .+ 1
+        @test logpdf.(Normal(), c[@varname(x)]) ≈ c[Extra(:logprior)]
+
+        # test with VarNamedTuple
+        vnts = [rand(model) for _ in 1:100, _ in 1:3]
+        c2 = AbstractMCMC.from_samples(VNChain, vnts)
+        @test c2 isa VNChain
+        @test size(c2) == (100, 3)
+        @test Set(FlexiChains.parameters(c2)) == Set(keys(rand(model)))
+        @test c2[@varname(x)] == map(vnt -> vnt[@varname(x)], vnts)
+    end
+
+    @testset "parameters_at and values_at" begin
+        @model function f()
+            x ~ Normal()
+            y = zeros(3)
+            y[2] ~ Normal()
+            z = (; a=nothing)
+            return z.a ~ Normal()
+        end
+        Ni, Nc = 10, 2
+
+        # These should give the same results, but chn is just the ParamsWithStats
+        # bundled into a VNChain.
+        chn = sample_from_prior(Xoshiro(468), f(), Ni, Nc; make_chain=true)
+        pwss = sample_from_prior(Xoshiro(468), f(), Ni, Nc; make_chain=false)
+
+        for i in 1:Ni, c in 1:Nc
+            prms = FlexiChains.parameters_at(chn; iter=i, chain=c)
+            @test prms isa VarNamedTuple
+            @test prms == pwss[i, c].params
+            vals = FlexiChains.values_at(chn; iter=i, chain=c)
+            @test vals isa DynamicPPL.ParamsWithStats
+            @test vals == pwss[i, c]
+        end
+    end
+
+    @testset "return type of rand" begin
+        @model function f()
+            x ~ Normal()
+            y ~ Normal()
+            return nothing
+        end
+        chn = sample_from_prior(f(), 10; make_chain=true)
+        @test rand(chn) isa DynamicPPL.ParamsWithStats
+        @test rand(chn; parameters_only=true) isa DynamicPPL.VarNamedTuple
+        @test rand(chn, 5) isa Vector{<:DynamicPPL.ParamsWithStats}
+        @test rand(chn, 5; parameters_only=true) isa Vector{<:DynamicPPL.VarNamedTuple}
+    end
+
+    @testset "AbstractMCMC.to_samples" begin
+        @model function f(z)
+            x ~ Normal()
+            y := x + 1
+            return z ~ Normal(y)
+        end
+
+        # Make the chain first
+        z = 1.0
+        model = f(z)
+        ps = sample_from_prior(Xoshiro(468), model, 50; make_chain=false)
+        c = sample_from_prior(Xoshiro(468), model, 50; make_chain=true)
+
+        # Then convert back to ParamsWithStats
+        @model function newmodel()
+            error(
+                "This model should never be run, because there is structure info" *
+                " in the chain.",
+            )
+            x ~ Normal()
+            return nothing
+        end
+
+        @testset "with model" begin
+            # Make sure that the model isn't actually ever used, by passing one that
+            # errors when run.
+            arr_pss = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, c, newmodel())
+            @test arr_pss == ps
+            arr_pss = AbstractMCMC.to_samples(DynamicPPL.VarNamedTuple, c, newmodel())
+            @test arr_pss == map(p -> p.params, ps)
+        end
+        @testset "without model" begin
+            arr_pss = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, c)
+            @test arr_pss == ps
+            arr_pss = AbstractMCMC.to_samples(DynamicPPL.VarNamedTuple, c)
+            @test arr_pss == map(p -> p.params, ps)
+        end
+    end
+
+    @testset "logp(model, chain)" begin
+        @model function f()
+            x ~ Normal()
+            return y ~ Normal(x)
+        end
+        # `y` is conditioned to 1.0, so the expected log-likelihood below is evaluated
+        # at 1.0 for each sampled `x`.
+        model = f() | (; y=1.0)
+        chn = sample_from_prior(model, 100; make_chain=true)
+        xs = chn[@varname(x)]
+        expected_logprior = logpdf.(Normal(), xs)
+        expected_loglike = logpdf.(Normal.(xs), 1.0)
+
+        @testset "logprior" begin
+            lprior = logprior(model, chn)
+            @test isapprox(lprior, expected_logprior)
+            @test parent(parent(DD.dims(lprior, :iter))) == FlexiChains.iter_indices(chn)
+            @test parent(parent(DD.dims(lprior, :chain))) == FlexiChains.chain_indices(chn)
+        end
+        @testset "loglikelihood" begin
+            llike = loglikelihood(model, chn)
+            @test isapprox(llike, expected_loglike)
+            @test parent(parent(DD.dims(llike, :iter))) == FlexiChains.iter_indices(chn)
+            @test parent(parent(DD.dims(llike, :chain))) == FlexiChains.chain_indices(chn)
+        end
+        @testset "logjoint" begin
+            ljoint = logjoint(model, chn)
+            @test isapprox(ljoint, expected_logprior .+ expected_loglike)
+            @test parent(parent(DD.dims(ljoint, :iter))) == FlexiChains.iter_indices(chn)
+            @test parent(parent(DD.dims(ljoint, :chain))) == FlexiChains.chain_indices(chn)
+        end
+
+        @testset "errors on missing variables" begin
+            @model function xonly()
+                return x ~ Normal()
+            end
+            @model function xy()
+                x ~ Normal()
+                return y ~ Normal()
+            end
+            # The chain only contains `x`, so evaluating a model that also needs `y`
+            # is expected to throw.
+            chn = sample_from_prior(xonly(), 100; make_chain=true)
+            @test_throws "not found in chain" logprior(xy(), chn)
+            @test_throws "not found in chain" loglikelihood(xy(), chn)
+            @test_throws "not found in chain" logjoint(xy(), chn)
+        end
+
+        @testset "with non-standard Array variables" begin
+            @model function offset_lp(y)
+                x = OffsetArray(zeros(2), -2:-1)
+                x[-2] ~ Normal()
+                y ~ Normal(x[-2])
+                return nothing
+            end
+            model = offset_lp(2.0)
+            chn = sample_from_prior(model, 50; make_chain=true)
+            # NOTE(review): `lprior` is not used below; the next line calls `logprior`
+            # again.
+            lprior = logprior(model, chn)
+            @test logprior(model, chn) ≈ logpdf.(Normal(), chn[@varname(x[-2])])
+            @test loglikelihood(model, chn) ≈ logpdf.(Normal.(chn[@varname(x[-2])]), 2.0)
+        end
+    end
+
+    @testset "pointwise logprobs" begin
+        @model function f(y)
+            x ~ Normal()
+            return y ~ Normal(x)
+        end
+        model = f(1.0)
+
+        chn = sample_from_prior(model, 100; make_chain=true)
+        xs = chn[@varname(x)]
+
+        @testset "logdensities" begin
+            pld = DynamicPPL.pointwise_logdensities(model, chn)
+            @test pld isa VNChain
+            @test FlexiChains.iter_indices(pld) == FlexiChains.iter_indices(chn)
+            @test FlexiChains.chain_indices(pld) == FlexiChains.chain_indices(chn)
+            @test length(keys(pld)) == 2
+            @test isapprox(pld[@varname(x)], logpdf.(Normal(), xs))
+            @test isapprox(pld[@varname(y)], logpdf.(Normal.(xs), 1.0))
+        end
+
+        @testset "loglikelihoods" begin
+            pld = DynamicPPL.pointwise_loglikelihoods(model, chn)
+            @test pld isa VNChain
+            @test FlexiChains.iter_indices(pld) == FlexiChains.iter_indices(chn)
+            @test FlexiChains.chain_indices(pld) == FlexiChains.chain_indices(chn)
+            @test length(keys(pld)) == 1
+            @test isapprox(pld[@varname(y)], logpdf.(Normal.(xs), 1.0))
+        end
+
+        @testset "logpriors" begin
+            pld = DynamicPPL.pointwise_prior_logdensities(model, chn)
+            @test pld isa VNChain
+            @test FlexiChains.iter_indices(pld) == FlexiChains.iter_indices(chn)
+            @test FlexiChains.chain_indices(pld) == FlexiChains.chain_indices(chn)
+            @test length(keys(pld)) == 1
+            @test isapprox(pld[@varname(x)], logpdf.(Normal(), xs))
+        end
+
+        @testset "errors on missing variables" begin
+            @model function xonly()
+                return x ~ Normal()
+            end
+            @model function xy()
+                x ~ Normal()
+                return y ~ Normal()
+            end
+            chn = sample_from_prior(xonly(), 100; make_chain=true)
+            @test_throws "not found in chain" DynamicPPL.pointwise_logdensities(xy(), chn)
+            @test_throws "not found in chain" DynamicPPL.pointwise_loglikelihoods(xy(), chn)
+            @test_throws "not found in chain" DynamicPPL.pointwise_prior_logdensities(
+                xy(), chn
+            )
+        end
+
+        @testset "with non-standard Array variables" begin
+            @model function offset_pld(y)
+                x = OffsetArray(zeros(2), -2:-1)
+                x[-2] ~ Normal()
+                y ~ Normal(x[-2])
+                return nothing
+            end
+            model = offset_pld(2.0)
+            chn = sample_from_prior(model, 50; make_chain=true)
+            plds = DynamicPPL.pointwise_logdensities(model, chn)
+            @test plds[@varname(x[-2])] == logpdf.(Normal(), chn[@varname(x[-2])])
+            @test plds[@varname(y)] == logpdf.(Normal.(chn[@varname(x[-2])]), 2.0)
+        end
+
+        @testset "factorize=true" begin
+            @model function array_pw(y, z)
+                x ~ MvNormal(zeros(2), I)
+                return y ~ MvNormal(x, I)
+                # Doesn't work yet https://github.com/sethaxen/PartitionedDistributions.jl/issues/20
+                # z ~ Normal()
+            end
+            y = randn(2)
+            # NOTE(review): `z` is passed to the model but is not used by it while the
+            # `z ~ Normal()` statement above stays commented out.
+            z = randn(2)
+            model = array_pw(y, z)
+            chn = sample_from_prior(model, 50; make_chain=true)
+
+            # Without `factorize`, each variable gets a single log-density per sample.
+            plds = DynamicPPL.pointwise_logdensities(model, chn)
+            @test plds[@varname(x)] == logpdf.(Ref(MvNormal(zeros(2), I)), chn[@varname(x)])
+            @test plds[@varname(y)] == logpdf.(MvNormal.(chn[@varname(x)], Ref(I)), Ref(y))
+
+            plls = DynamicPPL.pointwise_loglikelihoods(model, chn)
+            @test plls[@varname(y)] == logpdf.(MvNormal.(chn[@varname(x)], Ref(I)), Ref(y))
+            @test !haskey(plls, @varname(x))
+
+            pplds = DynamicPPL.pointwise_prior_logdensities(model, chn)
+            @test pplds[@varname(x)] ==
+                logpdf.(Ref(MvNormal(zeros(2), I)), chn[@varname(x)])
+            @test !haskey(pplds, @varname(y))
+
+            # With `factorize=true`, each sample holds a length-2 vector of per-element
+            # log-densities instead.
+            plds = DynamicPPL.pointwise_logdensities(model, chn; factorize=true)
+            for (x_pld, y_pld, x_val) in
+                zip(plds[@varname(x)], plds[@varname(y)], chn[@varname(x)])
+                @test x_pld isa Vector{<:Real}
+                @test length(x_pld) == 2
+                @test x_pld[1] == logpdf(Normal(), x_val[1])
+                @test x_pld[2] == logpdf(Normal(), x_val[2])
+                @test y_pld isa Vector{<:Real}
+                @test length(y_pld) == 2
+                @test y_pld[1] == logpdf(Normal(x_val[1], 1), y[1])
+                @test y_pld[2] == logpdf(Normal(x_val[2], 1), y[2])
+            end
+
+            plls = DynamicPPL.pointwise_loglikelihoods(model, chn; factorize=true)
+            for (y_pll, x_val) in zip(plls[@varname(y)], chn[@varname(x)])
+                @test y_pll isa Vector{<:Real}
+                @test length(y_pll) == 2
+                @test y_pll[1] == logpdf(Normal(x_val[1], 1), y[1])
+                @test y_pll[2] == logpdf(Normal(x_val[2], 1), y[2])
+            end
+
+            pplds = DynamicPPL.pointwise_prior_logdensities(model, chn; factorize=true)
+            for (x_ppld, x_val) in zip(pplds[@varname(x)], chn[@varname(x)])
+                @test x_ppld isa Vector{<:Real}
+                @test length(x_ppld) == 2
+                @test x_ppld[1] == logpdf(Normal(), x_val[1])
+                @test x_ppld[2] == logpdf(Normal(), x_val[2])
+            end
+        end
+    end
+
+    @testset "returned" begin
+        @model function f()
+            x ~ Normal()
+            y ~ MvNormal(zeros(2), I)
+            return x + y[1] + y[2]
+        end
+        model = f()
+        chn = sample_from_prior(model, 100; make_chain=true)
+        expected_rtnd = chn[@varname(x)] .+ chn[@varname(y[1])] .+ chn[@varname(y[2])]
+
+        rtnd = returned(model, chn)
+        @test isapprox(rtnd, expected_rtnd)
+        @test rtnd isa DD.DimMatrix
+        @test parent(parent(DD.dims(rtnd, :iter))) == FlexiChains.iter_indices(chn)
+        @test parent(parent(DD.dims(rtnd, :chain))) == FlexiChains.chain_indices(chn)
+
+        @testset "works even for dists that hasvalue isn't implemented for" begin
+            @model function f_product()
+                return x ~ product_distribution((; a=Normal()))
+            end
+            model = f_product()
+            chn = sample_from_prior(model, 100; make_chain=true)
+            rets = returned(f_product(), chn)
+            @test chn[@varname(x)] == rets
+        end
+
+        @testset "errors on missing variables" begin
+            @model function xonly()
+                return x ~ Normal()
+            end
+            @model function xy()
+                x ~ Normal()
+                return y ~ Normal()
+            end
+            chn = sample_from_prior(xonly(), 100; make_chain=true)
+            @test_throws "not found in chain" returned(xy(), chn)
+        end
+
+        @testset "stacks DimArray return values" begin
+            @model function return_dimarray()
+                x ~ Normal()
+                return DD.DimArray(randn(2, 3), (:a, :b))
+            end
+            chn = sample_from_prior(return_dimarray(), 50; make_chain=true)
+            rets = returned(return_dimarray(), chn)
+            # The (2, 3) return value of each sample is expected to be stacked behind
+            # the (iter, chain) dimensions.
+            @test rets isa DD.DimArray{T,4} where {T}
+            @test size(rets) == (50, 1, 2, 3)
+            @test DD.name.(DD.dims(rets)) == (:iter, :chain, :a, :b)
+        end
+
+        @testset "with non-standard Array variables" begin
+            # This essentially tests that templates are correctly used when calling
+            # returned()
+            @model function offset()
+                x = OffsetArray(zeros(2), -2:-1)
+                # Don't sample all elements of `x` to prevent it from being densified,
+                # thus bypassing the code that we want to check.
+                x[-2] ~ Normal()
+                return first(x)
+            end
+            model = offset()
+            chn = sample_from_prior(model, 50; make_chain=true)
+            rets = returned(model, chn)
+            @test rets == chn[@varname(x[-2])]
+        end
+    end
+
+    @testset "predict" begin
+        @model function f()
+            # By default, FlexiChains will store `m` as a single variable. However, this
+            # also lets us check behaviour after splitting up VarNames (i.e., if the chain
+            # has m[1] and m[2] but the model has m).
+            m ~ MvNormal(zeros(2), I)
+            # Same but with dot tilde; on DPPL v0.40 onwards, the model will have p[1] and
+            # p[2] but since the VNT is densified before chain construction, the chain will
+            # have p.
+            p = zeros(2)
+            p .~ Normal()
+            # Then some normal parameters.
+            x ~ Normal()
+            return y ~ Normal(x)
+        end
+        model = f() | (; y=4.0)
+
+        # Sanity check
+        chn = sample_from_posterior(StableRNG(468), model)
+        @test isapprox(mean(chn[@varname(x)]), 2.0; atol=0.1)
+        @test isapprox(mean(chn[@varname(m[1])]), 0.0; atol=0.1)
+        @test isapprox(mean(chn[@varname(m[2])]), 0.0; atol=0.1)
+        @test isapprox(mean(chn[@varname(p[1])]), 0.0; atol=0.1)
+        @test isapprox(mean(chn[@varname(p[2])]), 0.0; atol=0.1)
+
+        @testset "chain values are actually used" begin
+            pdns = predict(StableRNG(468), f(), chn)
+            # Sanity check.
+            @test pdns[@varname(x)] == chn[@varname(x)]
+            @test pdns[@varname(m)] == chn[@varname(m)]
+            @test pdns[@varname(p)] == chn[@varname(p)]
+            # Since the model was conditioned with y = 4.0, we should
+            # expect that the chain's mean of x is approx 2.0.
+            # So the posterior predictions for y should be centred on
+            # 2.0 (ish).
+            @test isapprox(mean(pdns[@varname(y)]), 2.0; atol=0.1)
+        end
+
+        @testset "logp" begin
+            pdns = predict(f(), chn)
+            # Since we deconditioned `y`, there are no likelihood terms.
+            @test all(iszero, pdns[FlexiChains._LOGLIKELIHOOD_KEY])
+            # The logprior should be the same as that of the original chain, but
+            # with an extra term for y ~ Normal(x)
+            chn_logprior = chn[FlexiChains._LOGPRIOR_KEY]
+            pdns_logprior = pdns[FlexiChains._LOGPRIOR_KEY]
+            expected_diff = logpdf.(Normal.(chn[@varname(x)]), pdns[@varname(y)])
+            @test isapprox(pdns_logprior, chn_logprior .+ expected_diff)
+            # Logjoint should be the same as logprior
+            @test pdns[FlexiChains._LOGJOINT_KEY] == pdns[FlexiChains._LOGPRIOR_KEY]
+        end
+
+        @testset "non-parameter keys are preserved" begin
+            pdns = predict(f(), chn)
+            display(chn)
+            display(pdns)
+            # Check that the only new thing added was the prediction for y.
+            @test only(setdiff(Set(keys(pdns)), Set(keys(chn)))) == Parameter(@varname(y))
+            # Check that no other keys originally in `chn` were removed.
+            @test isempty(setdiff(Set(keys(chn)), Set(keys(pdns))))
+        end
+
+        @testset "include_all=false" begin
+            pdns = predict(f(), chn; include_all=false)
+            # Check that the only parameter in the chain is the prediction for y.
+            @test only(Set(FlexiChains.parameters(pdns))) == @varname(y)
+        end
+
+        @testset "indices are preserved" begin
+            pdns = predict(f(), chn)
+            @test FlexiChains.iter_indices(pdns) == FlexiChains.iter_indices(chn)
+            @test FlexiChains.chain_indices(pdns) == FlexiChains.chain_indices(chn)
+        end
+
+        @testset "no sampling time and sampler state" begin
+            # It doesn't really make sense for the predictions to carry that
+            # information.
+            pdns = predict(f(), chn)
+            @test all(ismissing, FlexiChains.sampling_time(pdns))
+            @test all(ismissing, FlexiChains.last_sampler_state(pdns))
+        end
+
+        @testset "rng is respected" begin
+            # Same seed => identical predictions; different seed => different predictions.
+            pdns1 = predict(Xoshiro(468), f(), chn)
+            pdns2 = predict(Xoshiro(468), f(), chn)
+            @test FlexiChains.has_same_data(pdns1, pdns2)
+            pdns3 = predict(Xoshiro(469), f(), chn)
+            @test !FlexiChains.has_same_data(pdns1, pdns3)
+        end
+
+        @testset "with non-standard Array variables" begin
+            # This essentially tests that templates are correctly used when calling
+            # predict().
+            @model function offset2()
+                x = OffsetArray(zeros(2), -2:-1)
+                # Don't sample all elements of `x` to prevent it from being densified,
+                # thus bypassing the code that we want to check.
+                x[-2] ~ Normal()
+                return y ~ Normal(x[-2])
+            end
+            cond_model = offset2() | (; y=2.0)
+            chn = sample_from_posterior(StableRNG(468), cond_model)
+            @test mean(chn[@varname(x[-2])]) ≈ 1.0 atol = 0.05
+            pdns = predict(StableRNG(468), offset2(), chn)
+            @test pdns[@varname(x[-2])] == chn[@varname(x[-2])]
+            @test mean(pdns[@varname(y)]) ≈ 1.0 atol = 0.05
+        end
+    end
+
+    @testset "Models with variable-length parameters" begin
+        # These tests are mainly to check the interaction of VarNamedTuple with chains.
+        @testset "single variable" begin
+            @model function varlen_single()
+                n ~ DiscreteUniform(2, 5)
+                x ~ MvNormal(zeros(n), I)
+                y ~ Normal(sum(x))
+                return prod(x)
+            end
+            cond_model = varlen_single() | (; y=1.0)
+            chn = sample_from_prior(cond_model, 100; make_chain=true)
+            # Sanity check
+            @test chn[@varname(n)] == length.(chn[@varname(x)])
+            # Check that returned and predict both work. For returned we can also
+            # check correctness, but for predict we just check that it runs.
+            @test isapprox(returned(cond_model, chn), prod.(chn[@varname(x)]))
+            pdns = predict(varlen_single(), chn)
+            @test pdns isa VNChain
+            for vn in FlexiChains.parameters(chn)
+                @test pdns[vn] == chn[vn]
+            end
+            @test @varname(y) in FlexiChains.parameters(pdns)
+        end
+
+        @testset "dense vector" begin
+            # For this model, `x` should still be represented in the chain as a single
+            # variable, since the PartialArray will get densified.
+            @model function varlen_dense()
+                n ~ DiscreteUniform(2, 5)
+                x = zeros(n)
+                x .~ Normal()
+                y ~ Normal(sum(x))
+                return prod(x)
+            end
+            cond_model = varlen_dense() | (; y=1.0)
+            chn = sample_from_prior(cond_model, 100; make_chain=true)
+            # Sanity check
+            @test chn[@varname(n)] == length.(chn[@varname(x)])
+            # Check that returned and predict both work. For returned we can also
+            # check correctness, but for predict we just check that it runs.
+            @test isapprox(returned(cond_model, chn), prod.(chn[@varname(x)]))
+            pdns = predict(varlen_dense(), chn)
+            @test pdns isa VNChain
+            for vn in FlexiChains.parameters(chn)
+                @test pdns[vn] == chn[vn]
+            end
+            @test @varname(y) in FlexiChains.parameters(pdns)
+        end
+
+        @testset "nondense (sparse?) vector" begin
+            # For this model, `x` will be broken up in the chain, because not
+            # all entries in the PartialArray are filled
+            @model function varlen_nondense()
+                n ~ DiscreteUniform(2, 5)
+                x = zeros(n + 2)
+                for i in 1:n
+                    x[i] ~ Normal()
+                end
+                y ~ Normal(sum(x[1:n]))
+                return prod(x[1:n])
+            end
+            cond_model = varlen_nondense() | (; y=1.0)
+            chn = sample_from_prior(cond_model, 100; make_chain=true)
+            # Check that returned and predict both work.
+            @test returned(cond_model, chn) isa DD.DimArray
+            pdns = predict(varlen_nondense(), chn)
+            @test pdns isa VNChain
+            for vn in FlexiChains.parameters(chn)
+                @test isequal(pdns[vn], chn[vn]) # might have missing so need isequal
+            end
+            @test @varname(y) in FlexiChains.parameters(pdns)
+        end
+    end
+
+    @testset "PosteriorStats.loo" begin
+        @testset "no factorisation" begin
+            @model function f(y)
+                x ~ Normal()
+                return y .~ Normal(x)
+            end
+            model = f(randn(10))
+            chain = sample_from_prior(model, 500, 3; make_chain=true)
+            result = PosteriorStats.loo(model, chain)
+            @test result.param_names == [@varname(y[i]) for i in 1:10]
+            @test result.loo isa PosteriorStats.PSISLOOResult
+        end
+
+        @testset "factorize kwarg" begin
+            @model function farray(y)
+                x ~ MvNormal(zeros(2), I)
+                return y ~ MvNormal(x, I)
+            end
+            model = farray(randn(2))
+            chain = sample_from_prior(model, 500, 3; make_chain=true)
+            # With `factorize=true` there is one entry per element of `y`; with
+            # `factorize=false` there is a single entry for `y` as a whole.
+            result = PosteriorStats.loo(model, chain; factorize=true)
+            @test result.param_names == [@varname(y[i]) for i in 1:2]
+            @test result.loo isa PosteriorStats.PSISLOOResult
+
+            result = PosteriorStats.loo(model, chain; factorize=false)
+            @test result.param_names == [@varname(y)]
+            @test result.loo isa PosteriorStats.PSISLOOResult
+        end
+    end
+end
+
+@info "Completed $(@__FILE__) in $(now() - __now__)."
+ +end # module From 7c4291f6b91fa85fd8dbb1dc95c4d2acf9906420 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 1 May 2026 17:17:29 +0100 Subject: [PATCH 2/4] remove duplicated comment --- ext/DynamicPPLFlexiChainsExt.jl | 28 +++------------------------- 1 file changed, 3 insertions(+), 25 deletions(-) diff --git a/ext/DynamicPPLFlexiChainsExt.jl b/ext/DynamicPPLFlexiChainsExt.jl index e4a18b0ac..9ec7cfe8b 100644 --- a/ext/DynamicPPLFlexiChainsExt.jl +++ b/ext/DynamicPPLFlexiChainsExt.jl @@ -549,28 +549,6 @@ function _pointwise_logprobs( ) end -# Copied from DynamicPPL -const _FACTORIZE_KWARG_DOC = """ -If `factorize=true`, additionally attempt to provide factorised log-densities for -distributions that can be partitioned into blocks, using PartitionedDistributions.jl. - -For example, if `factorize=true`, then `y ~ MvNormal(...)` will return a vector of -log-densities, one for each element of `y`. The `i`-th element of this vector will be the -conditional log-probability of `y[i]` given all the other elements of `y` (often denoted -`log p(y_{i} | y_{-i})`): in particular this is exactly the log-density required for -leave-one-out cross-validation. - -In contrast, if `factorize=false`, then the log-density for `y ~ MvNormal(...)` will be a -single scalar corresponding to `logpdf(MvNormal(...), y)`. - -Note that the sum of the factorised log-densities may not, in general, be equal to the -log-density of the full distribution: they will only be equal if the original distribution -can be completely factorised into independent components. For example, if `y ~ MvNormal(μ, -Σ)` where `Σ` is diagonal, then each element of `y` is independent and the sum of the -factorised log-densities will be equal to the log-density of the full distribution. In -contrast, if `Σ` has off-diagonal entries, then the elements of `y` are not independent. -""" - """ DynamicPPL.pointwise_logdensities( model::Model, @@ -584,7 +562,7 @@ iteration in the `FlexiChain`. 
Returns a new `FlexiChain` with the same structure as the input `chain`, mapping the variables to their log probabilities. -$(_FACTORIZE_KWARG_DOC) +$(DynamicPPL._FACTORIZE_KWARG_DOC) """ function DynamicPPL.pointwise_logdensities( model::DynamicPPL.Model, chain::FlexiChain{<:VarName}; factorize::Bool=false @@ -605,7 +583,7 @@ iteration in the `FlexiChain`. Returns a new `FlexiChain` with the same structure as the input `chain`, mapping the observed variables to their log probabilities. -$(_FACTORIZE_KWARG_DOC) +$(DynamicPPL._FACTORIZE_KWARG_DOC) """ function DynamicPPL.pointwise_loglikelihoods( model::DynamicPPL.Model, chain::FlexiChain{<:VarName}; factorize::Bool=false @@ -626,7 +604,7 @@ iteration in the `FlexiChain`. Returns a new `FlexiChain` with the same structure as the input `chain`, mapping the observed variables to their log probabilities. -$(_FACTORIZE_KWARG_DOC) +$(DynamicPPL._FACTORIZE_KWARG_DOC) """ function DynamicPPL.pointwise_prior_logdensities( model::DynamicPPL.Model, chain::FlexiChain{<:VarName}; factorize::Bool=false From a680ad9a2b49056f279e1aeb71e2e4e27b744131 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 1 May 2026 17:28:44 +0100 Subject: [PATCH 3/4] forgot to actually test --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 1ba744c3f..f137a14eb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -51,6 +51,7 @@ Random.seed!(100) @testset "extensions" begin include("ext/DynamicPPLMarginalLogDensitiesExt.jl") include("ext/DynamicPPLMCMCChainsExt.jl") + include("ext/DynamicPPLFlexiChainsExt.jl") end @testset "ad" begin include("ext/DynamicPPLForwardDiffExt.jl") From 13e83763f29c899f9e312eb213f95f68df454d15 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 1 May 2026 17:34:12 +0100 Subject: [PATCH 4/4] fix some typos --- ext/DynamicPPLFlexiChainsExt.jl | 2 +- ext/DynamicPPLFlexiChainsPosteriorStatsExt.jl | 9 ++++++--- test/ext/DynamicPPLFlexiChainsExt.jl | 2 -- 3 
files changed, 7 insertions(+), 6 deletions(-) diff --git a/ext/DynamicPPLFlexiChainsExt.jl b/ext/DynamicPPLFlexiChainsExt.jl index 9ec7cfe8b..3034f864e 100644 --- a/ext/DynamicPPLFlexiChainsExt.jl +++ b/ext/DynamicPPLFlexiChainsExt.jl @@ -638,4 +638,4 @@ struct NotASampler <: AbstractMCMC.AbstractSampler end end end -end # module FlexiChainsDynamicPPLExt +end # module DynamicPPLFlexiChainsExt diff --git a/ext/DynamicPPLFlexiChainsPosteriorStatsExt.jl b/ext/DynamicPPLFlexiChainsPosteriorStatsExt.jl index 47edfaff3..d875ba413 100644 --- a/ext/DynamicPPLFlexiChainsPosteriorStatsExt.jl +++ b/ext/DynamicPPLFlexiChainsPosteriorStatsExt.jl @@ -1,13 +1,13 @@ module DynamicPPLFlexiChainsPosteriorStatsExt using PosteriorStats: PosteriorStats -using FlexiChains: FlexiChain +using FlexiChains: FlexiChain, VarName using DynamicPPL: DynamicPPL """ PosteriorStats.loo( model::DynamicPPL.Model, - posterior_chn::FlexiChains; + posterior_chn::FlexiChain{<:VarName}; factorize::Bool=false, kwargs... ) @@ -32,7 +32,10 @@ partitioned into blocks (e.g. `MvNormal`). Please see the docstring of Additional keyword arguments are forwarded to [`PosteriorStats.loo`](@extref). """ function PosteriorStats.loo( - model::DynamicPPL.Model, posterior_chn::FlexiChain; factorize::Bool=false, kwargs... + model::DynamicPPL.Model, + posterior_chn::FlexiChain{<:VarName}; + factorize::Bool=false, + kwargs..., ) lls_chn = DynamicPPL.pointwise_loglikelihoods(model, posterior_chn; factorize=factorize) return PosteriorStats.loo(lls_chn; kwargs...) diff --git a/test/ext/DynamicPPLFlexiChainsExt.jl b/test/ext/DynamicPPLFlexiChainsExt.jl index 0bb057c67..7187aff99 100644 --- a/test/ext/DynamicPPLFlexiChainsExt.jl +++ b/test/ext/DynamicPPLFlexiChainsExt.jl @@ -489,8 +489,6 @@ end @testset "non-parameter keys are preserved" begin pdns = predict(f(), chn) - display(chn) - display(pdns) # Check that the only new thing added was the prediction for y. 
@test only(setdiff(Set(keys(pdns)), Set(keys(chn)))) == Parameter(@varname(y)) # Check that no other keys originally in `chn` were removed.