82 changes: 82 additions & 0 deletions HISTORY.md
@@ -1,3 +1,85 @@
# 0.44.0

## Breaking changes

**Variational inference interface**

The VI interface in Turing has been modified to make it more interoperable with the rest of Turing.

- The arguments to `vi(...)` are slightly different: instead of specifying a `q_init` argument (the initial variational approximation), you now directly pass a function that constructs this for you. For example, instead of

```julia
q_init = q_meanfield_gaussian(model)
vi(model, q_init, n_iters)
```

you would now do

```julia
vi(model, q_meanfield_gaussian, n_iters)
```

- The return value of `vi` is now a `VIResult` struct (please see the documentation for information), which bundles the previous return values together in a more cohesive way.
Most importantly, you can now call `rand([rng,] result::VIResult)` to obtain new samples from the variational approximation.
This returns a `VarNamedTuple` of raw values, which can be used directly in all other Turing interfaces without any further wrangling.
(In contrast, the previous return value of `rand(q)` would yield a vector of transformed parameters.)
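
As a rough sketch of the new workflow: the model `demo`, the observed value, the iteration count, and the RNG seed below are made up for illustration, and only the `vi(model, q_meanfield_gaussian, n_iters)` call and `rand([rng,] result)` are taken from the description above.

```julia
using Turing
using Random

# Hypothetical toy model, used only to have something to fit.
@model function demo(y)
    μ ~ Normal(0, 1)
    σ ~ Exponential(1)
    y ~ Normal(μ, σ)
end

model = demo(0.5)

# Pass the constructor function itself rather than a pre-built `q_init`.
result = vi(model, q_meanfield_gaussian, 1000)

# Draw one sample from the fitted approximation; the values are raw
# (untransformed), so `σ` is expected to be positive.
draw = rand(Random.Xoshiro(468), result)
```

Consult the `VIResult` documentation for the fields it exposes; this sketch deliberately does not rely on any of them.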

Internally, the VI interface has been reworked to directly use `DynamicPPL.LogDensityFunction` instead of relying on a transformed distribution from Bijectors.jl.

**Gibbs sampler interface**

This section is only relevant if you are writing a sampler that is intended to be *directly* used as a component sampler in Turing's Gibbs sampler.
(If Gibbs calls your sampler via Turing's `externalsampler` interface, this section does not apply to you.)

Turing's Gibbs sampler has been reworked in this release to fix a number of correctness and performance issues.
The main change is that the Gibbs state carries a `VarNamedTuple` of raw values, instead of a `VarInfo` of vectorised (transformed) parameters.
This fixes correctness issues that arise with value-dependent transformations, and also leads to much reduced overhead when sampling with Gibbs: see https://github.com/TuringLang/Turing.jl/pull/2803 for some representative benchmarks.

In Turing v0.43, you had to define two methods:

- `Turing.Inference.get_varinfo(::MyState)` -> returns a VarInfo of values from your sampler's state

- `Turing.Inference.setparams_varinfo!!(::DynamicPPL.Model, ::MySampler, ::MyState, params::AbstractVarInfo)` -> uses a VarInfo of values to update your sampler's state

The corresponding methods in Turing v0.44 are:

- `Turing.Inference.gibbs_get_raw_values(::MyState)` -> returns a VarNamedTuple of *raw* values from your sampler's state. Note that these values should not be transformed or wrapped in any way: if `x` is `3` in the model, then we should have that `gibbs_get_raw_values(state)[@varname(x)] == 3`.

- `Turing.Inference.gibbs_update_state!!(::MySampler, ::MyState, ::DynamicPPL.Model, global_vals::VarNamedTuple)` -> uses the values in `global_vals` to update your sampler's state and return a new state.
Note that the values in `global_vals` are raw values.
Also, note that the model argument passed in here will be 'conditioned' on the *new* values inside `global_vals`.
That means that if any part of your state relies on caching the model being evaluated, you have to update those parts of your state to reflect the new model as well.

As can be seen, the interface is almost entirely the same except that we use `VarNamedTuple`s of raw values instead of `VarInfo`s of potentially transformed parameters.

Please see the docstrings in the Turing.jl API page for more information.

## Other changes

**Performance**

Previously, many of Turing's samplers needed to carry around a VarInfo so that they could communicate with Gibbs.
This release frees them up from having to do so, and in particular allows for usage of `DynamicPPL.OnlyAccsVarInfo`, which is much cheaper as it avoids unnecessary computations.

As a result, samplers such as MH and ESS are faster in this release, sometimes by up to 5x.

**Fixed transforms**

The MCMC sampling (`sample`), optimisation (`mode_estimate` / `maximum_likelihood` / `maximum_a_posteriori`), and VI (`vi`) entry points now accept an extra `fix_transforms` keyword argument, which specifies that all transforms in the model should be determined once at the start of inference and then fixed to those values for the rest of inference.
(In contrast, the default behaviour is to rederive transforms each time the model is run.)

Note that not all MCMC samplers currently support fixed transforms.
In particular, HMC, NUTS, and external samplers currently do support it, and Gibbs will pass the flag through to its component samplers, but MH's `LinkedRW` option will currently ignore the `fix_transforms` argument.
For some samplers such as ESS and particle MCMC, fixed transforms do not affect the sampling process at all (in such cases the keyword argument is accepted but ignored).

The reason why Turing rederives transforms is to ensure correctness in cases where the transform *depends on the value of another random variable*.
For example, if `a` is a parameter, then `b ~ Uniform(-a, a)` has a transform that depends on the value of `a`.
Since `a` is not static, this in turn means that the transform associated with `b` is not static.
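
To make this concrete, here is a small self-contained sketch in plain Julia. It does not use Turing or Bijectors; the helper names `logistic`, `to_support`, and `in_support` are invented for illustration, and the scaled-logistic map is only one possible choice of transform, assumed here for simplicity.

```julia
# Logistic function mapping the real line onto (0, 1).
logistic(y) = 1 / (1 + exp(-y))

# Map an unconstrained value `y` into the support of Uniform(-a, a).
# The bounds, and therefore the transform itself, depend on `a`.
to_support(y, a) = -a + 2a * logistic(y)

# Check whether `b` lies inside the support of Uniform(-a, a).
in_support(b, a) = -a < b < a

y = 1.5                       # an unconstrained value proposed by a sampler
b_small = to_support(y, 1.0)  # transform derived with a = 1
b_large = to_support(y, 4.0)  # transform derived with a = 4

# The same `y` maps to different values of `b` depending on `a`, and a
# transform derived for a = 4 produces a value outside the support for a = 1.
println(in_support(b_small, 1.0))
println(in_support(b_large, 1.0))
```

If the transform were fixed at the start of inference while `a` kept changing, the sampler could end up in the second situation, which is why transforms are rederived by default.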

If you know that this consideration is not relevant for your model (that is, no transform depends on the value of another random variable), fixing the transforms can sometimes improve performance when the transforms are expensive to compute.
However, in many cases the benefits are negligible and you should always benchmark this on a case-by-case basis.
Please see https://turinglang.org/DynamicPPL.jl/stable/fixed_transforms/ for further details about this.
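
A hedged usage sketch follows. The model `static_support`, the observed value, and the sample count are hypothetical; the only thing taken from this changelog is that `sample` accepts a `fix_transforms` keyword and that NUTS supports it. In this model no prior's support depends on another variable, so fixing the transforms should be safe.

```julia
using Turing

# Hypothetical model in which every transform is static: the support of
# `σ` is (0, ∞) regardless of the value of `μ`.
@model function static_support(y)
    μ ~ Normal(0, 1)
    σ ~ Exponential(1)
    y ~ Normal(μ, σ)
end

model = static_support(0.5)

# Default behaviour: transforms are rederived on every model evaluation.
chain_default = sample(model, NUTS(), 500; progress=false)

# Transforms are determined once at the start and reused afterwards.
chain_fixed = sample(model, NUTS(), 500; fix_transforms=true, progress=false)
```

Timing both calls (for example with BenchmarkTools.jl) on your own model is the only reliable way to tell whether the flag is worth using.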

# 0.43.7

Fixes an issue where sampling with `MH()` in v0.43 would not include `x := expr` results in the chain.
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.43.7"
version = "0.44"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -61,7 +61,7 @@ DifferentiationInterface = "0.7"
Distributions = "0.25.77"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.40.20"
DynamicPPL = "0.41.3"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3, 1"
Libtask = "0.9.14"
2 changes: 0 additions & 2 deletions docs/Project.toml
@@ -3,8 +3,6 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
7 changes: 6 additions & 1 deletion ext/TuringDynamicHMCExt.jl
@@ -48,12 +48,17 @@ function AbstractMCMC.step(
model::DynamicPPL.Model,
spl::DynamicNUTS;
initial_params,
fix_transforms::Bool=false,
kwargs...,
)
# Construct LogDensityFunction
tfm_strategy = DynamicPPL.LinkAll()
ldf = DynamicPPL.LogDensityFunction(
model, DynamicPPL.getlogjoint_internal, tfm_strategy; adtype=spl.adtype
model,
DynamicPPL.getlogjoint_internal,
tfm_strategy;
adtype=spl.adtype,
fix_transforms=fix_transforms,
)
x = Turing.Inference.find_initial_params_ldf(rng, ldf, initial_params)

6 changes: 5 additions & 1 deletion src/mcmc/emcee.jl
@@ -73,7 +73,11 @@ function AbstractMCMC.step(
transition = if discard_sample
nothing
else
[DynamicPPL.ParamsWithStats(vi, model) for vi in vis]
[
DynamicPPL.ParamsWithStats(
DynamicPPL.InitFromParams(DynamicPPL.get_values(vi)), model
) for vi in vis
]
end

linked_vi = DynamicPPL.link!!(vis[1], model)
144 changes: 98 additions & 46 deletions src/mcmc/ess.jl
@@ -22,11 +22,17 @@ Mean
"""
struct ESS <: AbstractSampler end

struct TuringESSState{V<:DynamicPPL.AbstractVarInfo,VNT<:DynamicPPL.VarNamedTuple}
vi::V
priors::VNT
struct TuringESSState{
L<:DynamicPPL.LogDensityFunction,
P<:AbstractVector{<:Real},
R<:Real,
V<:DynamicPPL.VarNamedTuple,
}
ldf::L
params::P
loglikelihood::R
priors::V
end
get_varinfo(state::TuringESSState) = state.vi

# always accept in the first step
function AbstractMCMC.step(
@@ -37,20 +43,37 @@ function AbstractMCMC.step(
initial_params,
kwargs...,
)
vi = DynamicPPL.VarInfo()
vi = DynamicPPL.setacc!!(vi, DynamicPPL.RawValueAccumulator(true))
prior_acc = DynamicPPL.PriorDistributionAccumulator()
prior_accname = DynamicPPL.accumulator_name(prior_acc)
vi = DynamicPPL.setacc!!(vi, prior_acc)
_, vi = DynamicPPL.init!!(rng, model, vi, initial_params, DynamicPPL.UnlinkAll())
priors = DynamicPPL.get_priors(vi)
# Set up a LogDensityFunction which evaluates the model's log-likelihood.
# Note that this costs one model evaluation (fine since it's only in the first step)
loglike_ldf = DynamicPPL.LogDensityFunction(
model, DynamicPPL.getloglikelihood, DynamicPPL.UnlinkAll()
)

# Run the model using the specified initialisation strategy and extract all necessary
# information.
accs = DynamicPPL.OnlyAccsVarInfo(
# no transforms so no need for LogJacobianAccumulator
DynamicPPL.LogPriorAccumulator(),
DynamicPPL.LogLikelihoodAccumulator(),
DynamicPPL.VectorParamAccumulator(loglike_ldf),
DynamicPPL.PriorDistributionAccumulator(),
DynamicPPL.RawValueAccumulator(true), # for ParamsWithStats later
)
_, accs = DynamicPPL.init!!(rng, model, accs, initial_params, DynamicPPL.UnlinkAll())

priors = DynamicPPL.get_priors(accs)
vector_params = DynamicPPL.get_vector_params(accs)
loglike = DynamicPPL.getloglikelihood(accs)

# Check that priors are all Gaussian
for dist in values(priors)
EllipticalSliceSampling.isgaussian(typeof(dist)) ||
error("ESS only supports Gaussian prior distributions")
end
transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(vi, model)
return transition, TuringESSState(vi, priors)

transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(accs)
state = TuringESSState(loglike_ldf, vector_params, loglike, priors)
return transition, state
end

function AbstractMCMC.step(
@@ -61,54 +84,54 @@ function AbstractMCMC.step(
discard_sample=false,
kwargs...,
)
# obtain previous sample
vi = state.vi
f = vi[:]

# define previous sampler state
# (do not use cache to avoid in-place sampling from prior)
wrapped_state = EllipticalSliceSampling.ESSState(
f, DynamicPPL.getloglikelihood(vi), nothing
state.params, state.loglikelihood, nothing
)

# compute next state
sample, new_wrapped_state = AbstractMCMC.step(
rng,
EllipticalSliceSampling.ESSModel(
ESSPrior(model, vi, state.priors), ESSLikelihood(model, vi)
ESSPrior(state.ldf, state.priors), ESSLikelihood(state.ldf)
),
EllipticalSliceSampling.ESS(),
wrapped_state,
)

# update sample and log-likelihood
vi = DynamicPPL.unflatten!!(vi, sample)
vi = DynamicPPL.setloglikelihood!!(vi, new_wrapped_state.loglikelihood)

transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(vi, model)
return transition, TuringESSState(vi, state.priors)
transition = discard_sample ? nothing : DynamicPPL.ParamsWithStats(sample, state.ldf)
new_state = TuringESSState(
state.ldf, sample, new_wrapped_state.loglikelihood, state.priors
)
return transition, new_state
end

# NOTE: This is a quick and easy definition but it assumes that _vec(x) is the same as
# Bijectors.VectorBijectors.from_vec(dist) for all distributions we care about in the
# priors. If that ever becomes untrue, then this could silently cause bugs.
_vec(x::Real) = [x]
_vec(x::AbstractArray) = vec(x)

# Prior distribution of considered random variable
struct ESSPrior{M<:Model,V<:AbstractVarInfo,T}
model::M
varinfo::V
μ::T

function ESSPrior(
model::Model, varinfo::AbstractVarInfo, priors::DynamicPPL.VarNamedTuple
)
μ = mapreduce(vcat, priors; init=Float64[]) do pair
prior_dist = pair.second
EllipticalSliceSampling.isgaussian(typeof(prior_dist)) || error(
"[ESS] only supports Gaussian prior distributions, but found $(typeof(prior_dist))",
struct ESSPrior{L<:DynamicPPL.LogDensityFunction,T<:AbstractVector{<:Real}}
ldf::L
means::T

function ESSPrior(ldf::DynamicPPL.LogDensityFunction, priors::DynamicPPL.VarNamedTuple)
# Calculate means from priors.
means = fill(NaN, LogDensityProblems.dimension(ldf))
for (vn, dist) in pairs(priors)
range = DynamicPPL.get_range_and_transform(ldf, vn).range
this_mean = _vec(mean(dist))
means[range] .= this_mean
end
if any(isnan, means)
error(
"Some means were not filled in when constructing ESSPrior. This is likely a bug in Turing.jl, please report it.",
)
_vec(mean(prior_dist))
end
return new{typeof(model),typeof(varinfo),typeof(μ)}(model, varinfo, μ)
return new{typeof(ldf),typeof(means)}(ldf, means)
end
end

@@ -117,14 +140,11 @@ EllipticalSliceSampling.isgaussian(::Type{<:ESSPrior}) = true

# Only define out-of-place sampling
function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
_, vi = DynamicPPL.init!!(
rng, p.model, p.varinfo, DynamicPPL.InitFromPrior(), DynamicPPL.UnlinkAll()
)
return vi[:]
return Base.rand(rng, p.ldf)
end

# Mean of prior distribution
Distributions.mean(p::ESSPrior) = p.μ
Distributions.mean(p::ESSPrior) = p.means

# Evaluate log-likelihood of proposals. We need this struct because
# EllipticalSliceSampling.jl expects a callable struct / a function as its
@@ -133,8 +153,13 @@ struct ESSLikelihood{L<:DynamicPPL.LogDensityFunction}
ldf::L

# Force usage of `getloglikelihood` in inner constructor
function ESSLikelihood(model::Model, varinfo::AbstractVarInfo)
ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, varinfo)
function ESSLikelihood(ldf::DynamicPPL.LogDensityFunction)
logp_callable = DynamicPPL.get_logdensity_callable(ldf)
if logp_callable !== DynamicPPL.getloglikelihood
error(
"The log-density function passed to ESSLikelihood must use `getloglikelihood` as its log-density function, but found $(logp_callable). This is likely a bug in Turing.jl, please report it!",
)
end
return new{typeof(ldf)}(ldf)
end
end
@@ -155,3 +180,30 @@ function AbstractMCMC.step(
"This method is not implemented! If you want to use the ESS sampler in Turing.jl, please use `Turing.ESS()` instead. If you want the default behaviour in EllipticalSliceSampling.jl, wrap your model in a different subtype of `AbstractMCMC.AbstractModel`, and then implement the necessary EllipticalSliceSampling.jl methods on it.",
)
end

####
#### Gibbs interface
####

function gibbs_get_raw_values(state::TuringESSState)
pws = DynamicPPL.ParamsWithStats(
state.params, state.ldf; include_log_probs=false, include_colon_eq=false
)
return pws.params
end

function gibbs_update_state!!(
::ESS,
state::TuringESSState,
model::DynamicPPL.Model,
global_vals::DynamicPPL.VarNamedTuple,
)
# We need to update everything in `state` except for the priors (which are constant). We
# pass an extra LogLikelihoodAccumulator here so that we can calculate the new loglike in
# one pass.
new_ldf, new_params, accs = gibbs_recompute_ldf_and_params(
state.ldf, model, global_vals, (DynamicPPL.LogLikelihoodAccumulator(),)
)
new_loglike = DynamicPPL.getloglikelihood(accs)
return TuringESSState(new_ldf, new_params, new_loglike, state.priors)
end