Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# 0.41.6

Add a `factorize::Bool` keyword argument for `pointwise_logdensities(model, values)`, which controls whether pointwise logdensities for factorisable distributions (e.g. `MvNormal`, `product_distribution`, etc.) are returned as a single log-density for the whole distribution, or as an array of log-densities for each factor.
The same argument is also added to `pointwise_loglikelihoods` and `pointwise_prior_logdensities`.

# 0.41.5

Make sure that `DynamicPPL.TestUtils.AD.run_ad(...; verbose=false)` _truly_ silences all messages.
Expand Down
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.41.5"
version = "0.41.6"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -21,6 +21,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PartitionedDistributions = "569bd051-8d7b-4221-bcb8-d78512b5866a"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Expand All @@ -40,8 +41,8 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
[extensions]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"]
DynamicPPLMooncakeExt = ["Mooncake", "DifferentiationInterface"]
DynamicPPLReverseDiffExt = ["ReverseDiff"]

Expand All @@ -65,11 +66,12 @@ InteractiveUtils = "1"
KernelAbstractions = "0.9.33"
LinearAlgebra = "1.6"
LogDensityProblems = "2"
MarginalLogDensities = "0.4.3"
MCMCChains = "6, 7"
MacroTools = "0.5.6"
MarginalLogDensities = "0.4.3"
Mooncake = "0.4.147, 0.5"
OrderedCollections = "1"
PartitionedDistributions = "0.0.1"
PrecompileTools = "1.2.1"
Preferences = "1.5.2"
Printf = "1.10"
Expand Down
38 changes: 27 additions & 11 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -422,10 +422,11 @@ function _pointwise_logdensities_chain(
model::DynamicPPL.Model,
chain::MCMCChains.Chains,
::Val{Prior}=Val(true),
::Val{Likelihood}=Val(true),
::Val{Likelihood}=Val(true);
factorize=false,
) where {Prior,Likelihood}
acc = DynamicPPL.VNTAccumulator{DynamicPPL.POINTWISE_ACCNAME}(
DynamicPPL.PointwiseLogProb{Prior,Likelihood}()
DynamicPPL.PointwiseLogProb{Prior,Likelihood,factorize}()
)
parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
# Reevaluating this gives us a VNT of log probs. We can densify and then wrap in
Expand All @@ -443,12 +444,15 @@ end
"""
DynamicPPL.pointwise_logdensities(
model::DynamicPPL.Model,
chain::MCMCChains.Chains,
chain::MCMCChains.Chains;
factorize=false,
)

Runs `model` on each sample in `chain`, returning a new `MCMCChains.Chains` object where
the log-density of each variable at each sample is stored (rather than its value).

$(DynamicPPL._FACTORIZE_KWARG_DOC)

See also: [`DynamicPPL.pointwise_loglikelihoods`](@ref),
[`DynamicPPL.pointwise_prior_logdensities`](@ref).

Expand Down Expand Up @@ -507,43 +511,55 @@ julia> # The above is the same as:
```
"""
function DynamicPPL.pointwise_logdensities(
model::DynamicPPL.Model, chain::MCMCChains.Chains
model::DynamicPPL.Model, chain::MCMCChains.Chains; factorize=false
)
return _pointwise_logdensities_chain(model, chain, Val(true), Val(true))
return _pointwise_logdensities_chain(
model, chain, Val(true), Val(true); factorize=factorize
)
end

"""
DynamicPPL.pointwise_loglikelihoods(
model::DynamicPPL.Model,
chain::MCMCChains.Chains,
chain::MCMCChains.Chains;
factorize=false,
)

Compute the pointwise log-likelihoods of the model given the chain. This is the same as
`pointwise_logdensities(model, chain)`, but only including the likelihood terms.

$(DynamicPPL._FACTORIZE_KWARG_DOC)

See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_prior_logdensities`](@ref).
"""
function DynamicPPL.pointwise_loglikelihoods(
model::DynamicPPL.Model, chain::MCMCChains.Chains
model::DynamicPPL.Model, chain::MCMCChains.Chains; factorize=false
)
return _pointwise_logdensities_chain(model, chain, Val(false), Val(true))
return _pointwise_logdensities_chain(
model, chain, Val(false), Val(true); factorize=factorize
)
end

"""
DynamicPPL.pointwise_prior_logdensities(
model::DynamicPPL.Model,
chain::MCMCChains.Chains
chain::MCMCChains.Chains;
factorize=false,
)

Compute the pointwise log-prior-densities of the model given the chain. This is the same as
`pointwise_logdensities(model, chain)`, but only including the prior terms.

$(DynamicPPL._FACTORIZE_KWARG_DOC)

See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_loglikelihoods`](@ref).
"""
function DynamicPPL.pointwise_prior_logdensities(
model::DynamicPPL.Model, chain::MCMCChains.Chains
model::DynamicPPL.Model, chain::MCMCChains.Chains; factorize=false
)
return _pointwise_logdensities_chain(model, chain, Val(true), Val(false))
return _pointwise_logdensities_chain(
model, chain, Val(true), Val(false); factorize=factorize
)
end

"""
Expand Down
129 changes: 96 additions & 33 deletions src/accumulators/pointwise_logdensities.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,65 @@
import PartitionedDistributions

# Force specialisation on D and V.
function _maybe_pointwise_logpdf(dist::D, value::V, ::Val{true}) where {D<:Distribution,V}
return if hasmethod(
PartitionedDistributions.pointwise_conditional_logpdfs,
Tuple{typeof(dist),typeof(value)},
)
PartitionedDistributions.pointwise_conditional_logpdfs(dist, value)
else
logpdf(dist, value)
end
end
function _maybe_pointwise_logpdf(dist::Distribution, value, ::Val{false})
return logpdf(dist, value)
end

const _FACTORIZE_KWARG_DOC = """
If `factorize=true`, additionally attempt to provide factorised log-densities for
distributions that can be partitioned into blocks, using PartitionedDistributions.jl.

For example, if `factorize=true`, then `y ~ MvNormal(...)` will return a vector of
log-densities, one for each element of `y`. The `i`-th element of this vector will be the
conditional log-probability of `y[i]` given all the other elements of `y` (often denoted
`log p(y_{i} | y_{-i})`): in particular this is exactly the log-density required for
leave-one-out cross-validation.

In contrast, if `factorize=false`, then the log-density for `y ~ MvNormal(...)` will be a
single scalar corresponding to `logpdf(MvNormal(...), y)`.

Note that the sum of the factorised log-densities may not, in general, be equal to the
log-density of the full distribution: they will only be equal if the original distribution
can be completely factorised into independent components. For example, if `y ~ MvNormal(μ,
Σ)` where `Σ` is diagonal, then each element of `y` is independent and the sum of the
factorised log-densities will be equal to the log-density of the full distribution. In
contrast, if `Σ` has off-diagonal entries, then the elements of `y` are not independent.
"""
Comment on lines +19 to +37
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sethaxen Would you be willing to sense-check my docstring?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(If you want to review generally please feel free to as well, but it's also not your job so please don't feel obliged 🙂)


"""
PointwiseLogProb{Prior,Likelihood}
PointwiseLogProb{Prior,Likelihood,Factorised}

A callable struct that computes the log probability of a given value under a distribution.
The `Prior` and `Likelihood` type parameters are used to control whether the log probability
is computed for prior or likelihood terms, respectively.
is computed for prior or likelihood terms, respectively. The `Factorised` type parameter
controls whether to attempt to factorise the log-densities.

This struct is used in conjunction with `VNTAccumulator`, via

acc = VNTAccumulator{POINTWISE_ACCNAME}(PointwiseLogProb{Prior,Likelihood}())
acc = VNTAccumulator{POINTWISE_ACCNAME}(PointwiseLogProb{Prior,Likelihood,Factorised}())

where `Prior` and `Likelihood` are the boolean type parameters. This accumulator will then
store the log-probabilities for all tilde-statements in the model.
where `Prior`, `Likelihood`, and `Factorised` are the boolean type parameters. This
accumulator will then store the log-probabilities for all tilde-statements in the model.
"""
struct PointwiseLogProb{Prior,Likelihood} end
struct PointwiseLogProb{Prior,Likelihood,Factorised} end
Base.copy(plp::PointwiseLogProb) = plp
function (plp::PointwiseLogProb{Prior,Likelihood})(
function (plp::PointwiseLogProb{Prior,Likelihood,Factorised})(
val, tval, logjac, vn, dist
) where {Prior,Likelihood}
if Prior
return logpdf(dist, val)
) where {Prior,Likelihood,Factorised}
return if Prior
_maybe_pointwise_logpdf(dist, val, Val{Factorised}())
else
return DoNotAccumulate()
DoNotAccumulate()
end
end
const POINTWISE_ACCNAME = :PointwiseLogProb
Expand All @@ -33,17 +72,17 @@ end
# Have to overload accumulate_assume!! since VNTAccumulator by default does not track
# observe statements.
function accumulate_observe!!(
acc::VNTAccumulator{POINTWISE_ACCNAME,PointwiseLogProb{Prior,Likelihood}},
acc::VNTAccumulator{POINTWISE_ACCNAME,PointwiseLogProb{Prior,Likelihood,Factorised}},
right,
left,
vn,
template,
) where {Prior,Likelihood}
) where {Prior,Likelihood,Factorised}
# vn could be `nothing`, in which case we can't store it in a VNT.
return if Likelihood && vn isa VarName
logp = logpdf(right, left)
logp = _maybe_pointwise_logpdf(right, left, Val{Factorised}())
new_values = DynamicPPL.templated_setindex!!(acc.values, logp, vn, template)
return VNTAccumulator{POINTWISE_ACCNAME}(acc.f, new_values)
VNTAccumulator{POINTWISE_ACCNAME}(acc.f, new_values)
else
# No need to accumulate likelihoods.
acc
Expand All @@ -55,7 +94,8 @@ end
model::Model,
init_strat::AbstractInitStrategy,
::Val{Prior}=Val(true),
::Val{Likelihood}=Val(true),
::Val{Likelihood}=Val(true);
factorize=false
) where {Prior,Likelihood}

Shared internal function that computes pointwise log-densities (either priors, likelihoods,
Expand All @@ -65,9 +105,10 @@ function _pointwise_logdensities(
model::Model,
init_strat::AbstractInitStrategy,
::Val{Prior}=Val(true),
::Val{Likelihood}=Val(true),
::Val{Likelihood}=Val(true);
factorize=false,
) where {Prior,Likelihood}
acc = VNTAccumulator{POINTWISE_ACCNAME}(PointwiseLogProb{Prior,Likelihood}())
acc = VNTAccumulator{POINTWISE_ACCNAME}(PointwiseLogProb{Prior,Likelihood,factorize}())
oavi = OnlyAccsVarInfo(acc)
oavi = last(init!!(model, oavi, init_strat, UnlinkAll()))
return get_pointwise_logprobs(oavi)
Expand All @@ -76,38 +117,60 @@ end
"""
DynamicPPL.pointwise_logdensities(
model::Model,
init_strat::AbstractInitStrategy
)
init_strat::AbstractInitStrategy;
factorize=false
)::VarNamedTuple

Calculate the pointwise log-densities for the parameters obtained by evaluating the model
with the given initialisation strategy. The resulting VarNamedTuple will contain both
log-prior probabilities (for random variables) and log-likelihoods (for observed variables).

$(_FACTORIZE_KWARG_DOC)
"""
function pointwise_logdensities(model::Model, init_strat::AbstractInitStrategy)
return _pointwise_logdensities(model, init_strat, Val(true), Val(true))
function pointwise_logdensities(
model::Model, init_strat::AbstractInitStrategy; factorize=false
)
return _pointwise_logdensities(
model, init_strat, Val(true), Val(true); factorize=factorize
)
end

"""
DynamicPPL.pointwise_loglikelihoods(
model::Model,
init_strat::AbstractInitStrategy
)
init_strat::AbstractInitStrategy;
factorize=false
)::VarNamedTuple

Same as `pointwise_logdensities`, but only returns the log-likelihoods for observed variables.
Calculate the pointwise log-likelihoods for observed variables, using parameters obtained
from the given initialisation strategy.

$(_FACTORIZE_KWARG_DOC)
"""
function pointwise_loglikelihoods(model::Model, init_strat::AbstractInitStrategy)
return _pointwise_logdensities(model, init_strat, Val(false), Val(true))
function pointwise_loglikelihoods(
model::Model, init_strat::AbstractInitStrategy; factorize=false
)
return _pointwise_logdensities(
model, init_strat, Val(false), Val(true); factorize=factorize
)
end

"""
DynamicPPL.pointwise_prior_logdensities(
model::Model,
init_strat::AbstractInitStrategy
)
init_strat::AbstractInitStrategy;
factorize=false
)::VarNamedTuple

Same as `pointwise_logdensities`, but only returns the log-densities for random variables
(i.e. the priors).
Calculate the pointwise log-prior probabilities for random variables, using parameters
obtained from the given initialisation strategy.

$(_FACTORIZE_KWARG_DOC)
"""
function pointwise_prior_logdensities(model::Model, init_strat::AbstractInitStrategy)
return _pointwise_logdensities(model, init_strat, Val(true), Val(false))
function pointwise_prior_logdensities(
model::Model, init_strat::AbstractInitStrategy; factorize=false
)
return _pointwise_logdensities(
model, init_strat, Val(true), Val(false); factorize=factorize
)
end
2 changes: 1 addition & 1 deletion test/accumulators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ TEST_ACCUMULATORS = (
DynamicPPL.DebugRawValueAccumulator(),
DynamicPPL.FixedTransformAccumulator(),
DynamicPPL.VNTAccumulator{DynamicPPL.POINTWISE_ACCNAME}(
DynamicPPL.PointwiseLogProb{true,true}()
DynamicPPL.PointwiseLogProb{true,true,false}()
),
PriorDistributionAccumulator(),
DynamicPPL.VectorValueAccumulator(),
Expand Down
Loading
Loading