diff --git a/AGENTS.md b/AGENTS.md index c051f5269..303401071 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -28,7 +28,7 @@ DynamicPPL builds on AbstractPPL.jl for shared PPL interfaces such as `VarName`, - `VarName` (AbstractPPL): address for model variables, including nested fields/indices. - `VarNamedTuple` (`src/varnamedtuple.jl`): named-tuple-like parameter storage keyed by `VarName`. - `LogDensityFunction` (`src/logdensityfunction.jl`): bridge from named parameters to flat `AbstractVector{<:Real}` for samplers, optimisers, and AD via LogDensityProblems.jl. - - `ext/`: `DynamicPPLForwardDiffExt`, `DynamicPPLMooncakeExt`, `DynamicPPLReverseDiffExt`, `DynamicPPLEnzymeCoreExt`, `DynamicPPLComponentArraysExt`, `DynamicPPLMCMCChainsExt`, and `DynamicPPLMarginalLogDensitiesExt`. + - `ext/`: `DynamicPPLForwardDiffExt`, `DynamicPPLMooncakeExt`, `DynamicPPLReverseDiffExt`, `DynamicPPLEnzymeCoreExt`, `DynamicPPLComponentArraysExt`, and `DynamicPPLMCMCChainsExt`. - `DynamicPPL.TestUtils`: analytical test models (`logprior_true`, `loglikelihood_true`, etc.), `run_ad`, `ADResult`. 
## DynamicPPL Invariants diff --git a/Project.toml b/Project.toml index f1be81a23..8b8e6fe1b 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,6 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" @@ -44,7 +43,6 @@ DynamicPPLComponentArraysExt = ["ComponentArrays"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLMCMCChainsExt = ["MCMCChains"] -DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"] DynamicPPLMooncakeExt = ["Mooncake", "DifferentiationInterface"] DynamicPPLReverseDiffExt = ["ReverseDiff"] @@ -71,7 +69,6 @@ LinearAlgebra = "1.6" LogDensityProblems = "2" MCMCChains = "6, 7" MacroTools = "0.5.6" -MarginalLogDensities = "0.4.3" Mooncake = "0.4.147, 0.5" OrderedCollections = "1" PartitionedDistributions = "0.0.1" diff --git a/docs/make.jl b/docs/make.jl index 83380bddc..8c2a2a76a 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -11,7 +11,6 @@ using AbstractPPL using Distributions using DocumenterMermaid using MCMCChains -using MarginalLogDensities using AbstractMCMC: AbstractMCMC using Random @@ -20,7 +19,5 @@ DocMeta.setdocmeta!( DynamicPPL, :DocTestSetup, :(using DynamicPPL, MCMCChains); recursive=true ) -# Need this to document a method which uses a type inside the extension -DPPLMLDExt = Base.get_extension(DynamicPPL, :DynamicPPLMarginalLogDensitiesExt) links = InterLinks("AbstractPPL" => "https://turinglang.org/AbstractPPL.jl/stable/") @@ -34,7 +31,6 @@ makedocs(; modules=[ DynamicPPL, Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt), - Base.get_extension(DynamicPPL, :DynamicPPLMarginalLogDensitiesExt), ], pages=[ "Home" => "index.md", diff --git 
a/docs/src/api.md b/docs/src/api.md index 33362634c..a4ceb8950 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -158,17 +158,6 @@ When using `predict` with `MCMCChains.Chains`, you can control which variables a - `include_all=false` (default): Include only newly predicted variables - `include_all=true`: Include both parameters from the original chain and predicted variables -## Marginalisation - -DynamicPPL provides the `marginalize` function to marginalise out variables from a model. -This requires `MarginalLogDensities.jl` to be loaded in your environment. - -```@docs -marginalize -``` - -A `MarginalLogDensity` object acts as a function which maps non-marginalised parameter values to a marginal log-probability. -To retrieve a VarInfo object from it, you can use [`InitFromVector`](@ref). ## Models within models diff --git a/docs/src/init.md b/docs/src/init.md index 0093b6c0e..af959bf3c 100644 --- a/docs/src/init.md +++ b/docs/src/init.md @@ -24,7 +24,6 @@ InitFromParams(::ParamsWithStats, ::Union{Nothing,AbstractInitStrategy}) InitFromUniform InitFromVector InitFromVector(::AbstractVector{<:Real}, ::LogDensityFunction) -InitFromVector(::MarginalLogDensities.MarginalLogDensity{<:DPPLMLDExt.LogDensityFunctionWrapper}, ::Union{AbstractVector,Nothing}) ``` However, sometimes you will need to implement your own initialisation strategy. diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl deleted file mode 100644 index 905ee168b..000000000 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ /dev/null @@ -1,215 +0,0 @@ -module DynamicPPLMarginalLogDensitiesExt - -using DynamicPPL: DynamicPPL, VarName, RangeAndTransform -using LogDensityProblems: LogDensityProblems -using MarginalLogDensities: MarginalLogDensities - -# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by -# MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type -# below. 
-struct LogDensityFunctionWrapper{L<:DynamicPPL.LogDensityFunction} - ldf::L -end -function (lw::LogDensityFunctionWrapper)(x, _) - return LogDensityProblems.logdensity(lw.ldf, x) -end - -""" - DynamicPPL.marginalize( - model::DynamicPPL.Model, - marginalized_varnames::AbstractVector{<:VarName}; - transform_strategy::DynamicPPL.AbstractTransformStrategy=DynamicPPL.LinkAll(), - getlogprob=DynamicPPL.getlogjoint, - method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(); - kwargs..., - ) - -Construct a `MarginalLogDensities.MarginalLogDensity` object that represents the marginal -log-density of the given `model`, after marginalizing out the variables specified in -`varnames`. - -The resulting object can be called with a vector of parameter values to compute the marginal -log-density. - -## Keyword arguments - -- `transform_strategy`: The transform strategy to use for the model, which determines - whether the marginalisation is performed in the original (possibly constrained) space - or in a transformed (unconstrained) space. By default, this is `DynamicPPL.LinkAll()`, - which transforms all variables to unconstrained space. - - To avoid this transformation and perform the marginalisation in the original space, use - `DynamicPPL.UnlinkAll()`. You can also use fixed transforms which can in specific - circumstances improve performance: see the DynamicPPL documentation for more details. - -- `getlogprob`: A function which specifies which kind of marginal log-density to compute. - Its default value is `DynamicPPL.getlogjoint` which returns the marginal log-joint - probability in the original space (i.e., log-Jacobians from the transformation to - unconstrained space are ignored). - -- `method`: The marginalization method; defaults to a Laplace approximation. Please see [the - MarginalLogDensities.jl package](https://github.com/ElOceanografo/MarginalLogDensities.jl/) - for other options. 
- -- Other keyword arguments are passed to the `MarginalLogDensities.MarginalLogDensity` - constructor. - -## Example - -```jldoctest -julia> using DynamicPPL, Distributions, MarginalLogDensities - -julia> @model function demo() - x ~ Normal(1.0) - y ~ Normal(2.0) - end -demo (generic function with 2 methods) - -julia> marginalized = marginalize(demo(), [@varname(x)]); - -julia> # The resulting callable computes the marginal log-density of `y`. - marginalized([1.0]) --1.4189385332046727 - -julia> logpdf(Normal(2.0), 1.0) --1.4189385332046727 -``` - - -!!! warning - - The default usage of `DynamicPPL.LinkAll()` means that, for example, optimization of the - marginal log-density can be performed in unconstrained space. However, care must be - taken if the model contains variables where the link transformation depends on a - marginalized variable. For example: - - ```julia - @model function f() - x ~ Normal() - y ~ truncated(Normal(); lower=x) - end - ``` - - Here, the support of `y`, and hence the link transformation used, depends on the value - of `x`. If we now marginalize over `x`, we obtain a function mapping linked values of - `y` to log-probabilities. However, it will not be possible to use DynamicPPL to - correctly retrieve _unlinked_ values of `y`. -""" -function DynamicPPL.marginalize( - model::DynamicPPL.Model, - marginalized_varnames::AbstractVector{<:VarName}; - transform_strategy::DynamicPPL.AbstractTransformStrategy=DynamicPPL.LinkAll(), - getlogprob::Function=DynamicPPL.getlogjoint, - method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(), - kwargs..., -) - # Construct the marginal log-density model. - ldf = DynamicPPL.LogDensityFunction(model, getlogprob, transform_strategy) - initial_params = rand(ldf) - # Determine the indices for the variables to marginalise out. 
- varindices = mapreduce(vcat, marginalized_varnames) do vn - DynamicPPL.get_range_and_transform(ldf, vn).range - end - mld = MarginalLogDensities.MarginalLogDensity( - LogDensityFunctionWrapper(ldf), initial_params, varindices, (), method; kwargs... - ) - return mld -end - -""" - DynamicPPL.InitFromVector( - mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper}, - unmarginalized_params::Union{AbstractVector,Nothing}=nothing, - ) - -Create a new initialisation strategy using the parameter values stored in `mld` (and -optionally `unmarginalized_params`). - -If a Laplace approximation was used for the marginalisation, the values of the marginalized -parameters are set to their mode (note that this only happens if the `mld` object has been -used to compute the marginal log-density at least once, so that the mode has been computed). - -If a vector of `unmarginalized_params` is specified, the values for the corresponding -parameters will also be available as part of the initialisation strategy. This vector may be -obtained e.g. by performing an optimization of the marginal log-density. - -To use this initialisation strategy to obtain e.g. updated log-probabilities, you should -re-evaluate the model with the values inside the returned VarInfo, for example using: - -```julia -init_strategy = DynamicPPL.InitFromVector(mld, unmarginalized_params) -accs = DynamicPPL.OnlyAccsVarInfo(( - DynamicPPL.LogPriorAccumulator(), - DynamicPPL.LogLikelihoodAccumulator(), - DynamicPPL.RawValueAccumulator(false), - # ... whatever else you need -)) -_, accs = DynamicPPL.init!!(rng, model, oavi, init_strategy, DynamicPPL.UnlinkAll()) -``` - -You can then extract all the updated data from `accs` using DynamicPPL's existing API (see -the DynamicPPL documentation for more details). 
- -## Example - -```jldoctest -julia> using DynamicPPL, Distributions, MarginalLogDensities - -julia> @model function demo() - x ~ Normal() - y ~ Beta(2, 2) - end -demo (generic function with 2 methods) - -julia> # Note that by default `marginalize` uses a linked VarInfo. - mld = marginalize(demo(), [@varname(x)]); - -julia> using MarginalLogDensities: Optimization, OptimizationOptimJL - -julia> # Find the mode of the marginal log-density of `y`, with an initial point of `y0`. - y0 = 2.0; opt_problem = Optimization.OptimizationProblem(mld, [y0]) -OptimizationProblem. In-place: true -u0: 1-element Vector{Float64}: - 2.0 - -julia> # This tells us the optimal (linked) value of `y` is around 0. - opt_solution = Optimization.solve(opt_problem, OptimizationOptimJL.NelderMead()) -retcode: Success -u: 1-element Vector{Float64}: - 4.88281250001733e-5 - -julia> # Get the an initialisation strategy representing the mode of `y`. - init_strategy = InitFromVector(mld, opt_solution.u); - -julia> # Evaluate the model with this initialisation strategy. - accs = DynamicPPL.OnlyAccsVarInfo(( - DynamicPPL.LogPriorAccumulator(), - DynamicPPL.LogLikelihoodAccumulator(), - DynamicPPL.RawValueAccumulator(false), - )); - _, accs = DynamicPPL.init!!(demo(), accs, init_strategy, DynamicPPL.UnlinkAll()); - -julia> # Extract the raw (i.e. untransformed) values for all variables. - # `x` is set to its mode (which for `Normal()` is zero). - # Furthermore, `y` is set to the optimal value we found above. 
- vals = DynamicPPL.get_raw_values(accs) -VarNamedTuple -├─ x => 0.0 -└─ y => 0.5000122070312476 -``` -""" -function DynamicPPL.InitFromVector( - mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper}, - unmarginalized_params::Union{AbstractVector,Nothing}=nothing, -) - # Extract the stored parameters, which includes the modes for any marginalized - # parameters - full_params = MarginalLogDensities.cached_params(mld) - # We can then (if needed) set the values for any non-marginalized parameters - if unmarginalized_params !== nothing - full_params[MarginalLogDensities.ijoint(mld)] = unmarginalized_params - end - return DynamicPPL.InitFromVector(full_params, mld.logdensity.ldf) -end - -end diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 3001210c8..9a38694d1 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -211,7 +211,6 @@ export AbstractVarInfo, @addlogprob!, check_model, set_logprob_type!, - marginalize, # Deprecated. generated_quantities, typed_identity @@ -283,28 +282,6 @@ include("test_utils.jl") include("deprecated.jl") -# Extended in MarginalLogDensitiesExt -function marginalize end -if isdefined(Base.Experimental, :register_error_hint) - function __init__() - # Same for MarginalLogDensities.jl - Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ - requires_mld = - exc.f === DynamicPPL.marginalize && - length(argtypes) == 2 && - argtypes[1] <: Model && - argtypes[2] <: AbstractVector{<:Union{Symbol,<:VarName}} - if requires_mld - printstyled( - io, - "\n\n `$(exc.f)` requires MarginalLogDensities.jl to be loaded.\n Please run `using MarginalLogDensities` before calling `$(exc.f)`.\n"; - color=:cyan, - bold=true, - ) - end - end - end -end # Standard tag: Improves stacktraces # Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/ diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 0b1c250b9..0435d5f59 100644 --- a/src/abstract_varinfo.jl +++ 
b/src/abstract_varinfo.jl @@ -399,8 +399,11 @@ Return an iterator over all `vns` in `vi`. """ getindex(vi::AbstractVarInfo, ::Colon) -Return the current value(s) of `vn` (`vns`) in `vi` in the support of its (their) -distribution(s) as a flattened `Vector`. +Return the internal value(s) stored in `vi` as a flattened `Vector`. Note that +these values may be in transformed (linked) space if `vi` has been linked. + +For untransformed values, use [`getindex_internal`](@ref) after calling +[`invlink`](@ref) on `vi`. The default implementation is to call [`internal_values_as_vector`](@ref). """ diff --git a/test/Project.toml b/test/Project.toml index 944c32102..ed172d638 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -19,7 +19,6 @@ InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" @@ -50,7 +49,6 @@ ForwardDiff = "0.10.12, 1" InvertedIndices = "1" LogDensityProblems = "2" MCMCChains = "7.2.1" -MarginalLogDensities = "0.4" Mooncake = "0.4, 0.5" OffsetArrays = "1" OrderedCollections = "1" diff --git a/test/ext/DynamicPPLMarginalLogDensitiesExt.jl b/test/ext/DynamicPPLMarginalLogDensitiesExt.jl deleted file mode 100644 index e16b21a3c..000000000 --- a/test/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ /dev/null @@ -1,112 +0,0 @@ -module MarginalLogDensitiesExtTests - -using Bijectors: Bijectors -using DynamicPPL, Distributions, Test -using MarginalLogDensities - -@testset "MarginalLogDensities" begin - @testset "Basic usage" begin - @model function demo() - x ~ MvNormal(zeros(2), [1, 1]) - return y ~ Normal(0, 1) - end - model = demo() - # Marginalize out `x`. 
- @testset for getlogprob in [DynamicPPL.getlogprior, DynamicPPL.getlogjoint] - marginalized = marginalize( - model, - [@varname(x)]; - transform_strategy=UnlinkAll(), - getlogprob=getlogprob, - hess_adtype=AutoForwardDiff(), - ) - for y in range(-5, 5; length=100) - @test marginalized([y]) ≈ logpdf(Normal(0, 1), y) atol = 1e-5 - end - end - end - - @testset "Respects linked status of VarInfo" begin - @model function f() - x ~ Normal() - return y ~ Beta(2, 2) - end - model = f() - - @testset "unlinked" begin - mx = marginalize(model, [@varname(x)]; transform_strategy=UnlinkAll()) - for x in range(0.01, 0.99; length=10) - @test mx([x]) ≈ logpdf(Beta(2, 2), x) - end - # generally when marginalising Beta it doesn't go to zero - # https://github.com/TuringLang/DynamicPPL.jl/pull/1036#discussion_r2349388067 - my = marginalize(model, [@varname(y)]; transform_strategy=UnlinkAll()) - diff = my([0.0]) - logpdf(Normal(), 0.0) - for x in range(-5, 5; length=10) - @test my([x]) ≈ logpdf(Normal(), x) + diff - end - end - - @testset "linked VarInfo" begin - mx = marginalize(model, [@varname(x)]; transform_strategy=LinkAll()) - binv = Bijectors.VectorBijectors.from_linked_vec(Beta(2, 2)) - for y_linked in range(-5, 5; length=10) - y_unlinked = binv([y_linked]) - @test mx([y_linked]) ≈ logpdf(Beta(2, 2), y_unlinked) - end - # generally when marginalising Beta it doesn't go to zero - # https://github.com/TuringLang/DynamicPPL.jl/pull/1036#discussion_r2349388067 - my = marginalize(model, [@varname(y)]; transform_strategy=LinkAll()) - diff = my([0.0]) - logpdf(Normal(), 0.0) - for x in range(-5, 5; length=10) - @test my([x]) ≈ logpdf(Normal(), x) + diff - end - end - end - - @testset "retrieving VarInfo from MLD" begin - @model function f() - x ~ Normal() - return y ~ Beta(2, 2) - end - model = f() - vi_unlinked = VarInfo(model) - vi_linked = DynamicPPL.link(vi_unlinked, model) - - function get_raw_values_from_init_strat(model, init_strat) - accs = 
DynamicPPL.OnlyAccsVarInfo(DynamicPPL.RawValueAccumulator(false)) - _, accs = DynamicPPL.init!!(model, accs, init_strat, DynamicPPL.UnlinkAll()) - return DynamicPPL.get_raw_values(accs) - end - - @testset "unlinked VarInfo" begin - mx = marginalize(model, [@varname(x)]; transform_strategy=UnlinkAll()) - mx([0.5]) # evaluate at some point to force calculation of Laplace approx - init_strat = InitFromVector(mx) - vnt = get_raw_values_from_init_strat(model, init_strat) - @test vnt[@varname(x)] ≈ mode(Normal()) - - init_strat = InitFromVector(mx, [0.5]) # this 0.5 is unlinked - vnt = get_raw_values_from_init_strat(model, init_strat) - @test vnt[@varname(x)] ≈ mode(Normal()) - @test vnt[@varname(y)] ≈ 0.5 - end - - @testset "linked VarInfo" begin - mx = marginalize(model, [@varname(x)]; transform_strategy=LinkAll()) - mx([0.5]) # evaluate at some point to force calculation of Laplace approx - init_strat = InitFromVector(mx) - vnt = get_raw_values_from_init_strat(model, init_strat) - @test vnt[@varname(x)] ≈ mode(Normal()) - - init_strat = InitFromVector(mx, [0.5]) # this 0.5 is linked - vnt = get_raw_values_from_init_strat(model, init_strat) - binv = Bijectors.VectorBijectors.from_linked_vec(Beta(2, 2)) - @test vnt[@varname(x)] ≈ mode(Normal()) - # when using getindex it always returns unlinked values - @test vnt[@varname(y)] ≈ binv([0.5]) - end - end -end - -end diff --git a/test/runtests.jl b/test/runtests.jl index 1ba744c3f..bdd14521c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -49,7 +49,6 @@ Random.seed!(100) include("transformed_values.jl") include("logdensityfunction.jl") @testset "extensions" begin - include("ext/DynamicPPLMarginalLogDensitiesExt.jl") include("ext/DynamicPPLMCMCChainsExt.jl") end @testset "ad" begin