Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[sources]
DynamicPPL = {path = "../"}
DynamicPPL = {path = ".."}

[compat]
ADTypes = "1.14.0"
Expand Down
3 changes: 3 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[sources]
AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "evaluator-interface"}

[compat]
ADTypes = "1"
AbstractMCMC = "5"
Expand Down
4 changes: 2 additions & 2 deletions ext/DynamicPPLMarginalLogDensitiesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ accs = DynamicPPL.OnlyAccsVarInfo((
DynamicPPL.RawValueAccumulator(false),
# ... whatever else you need
))
_, accs = DynamicPPL.init!!(rng, model, oavi, init_strategy, DynamicPPL.UnlinkAll())
_, accs = DynamicPPL.init!!(rng, model, accs, init_strategy, DynamicPPL.UnlinkAll())
```

You can then extract all the updated data from `accs` using DynamicPPL's existing API (see
Expand Down Expand Up @@ -178,7 +178,7 @@ retcode: Success
u: 1-element Vector{Float64}:
4.88281250001733e-5

julia> # Get the an initialisation strategy representing the mode of `y`.
julia> # Get an initialisation strategy representing the mode of `y`.
init_strategy = InitFromVector(mld, opt_solution.u);

julia> # Evaluate the model with this initialisation strategy.
Expand Down
53 changes: 5 additions & 48 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ using Random: Random
vi_vnt_or_tfm_strategy=_default_vnt(model, UnlinkAll()),
accs::Union{NTuple{<:Any,AbstractAccumulator},AccumulatorTuple}=DynamicPPL.ldf_accs(getlogdensity);
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
fix_transform::Bool=false,
fix_transforms::Bool=false,
)

A struct which contains a model, along with all the information necessary to:
Expand Down Expand Up @@ -260,17 +260,12 @@ struct LogDensityFunction{
else
# Make backend-specific tweaks to the adtype
adtype = DynamicPPL.tweak_adtype(adtype, model, x)
lda = LogDensityAt(
problem = LogDensityAt(
model, getlogdensity, ranges_and_transforms, transform_strategy, accs
)
problem = if _use_closure(adtype)
lda
else
let lda = lda
params -> lda(params)
end
end
AbstractPPL.prepare(adtype, problem, x)
# `x` was just constructed from the same range metadata stored in `problem`,
# so the AD wrapper can skip its hot-path dimension validation.
AbstractPPL.prepare(adtype, problem, x; check_dims=false)
end
return new{
typeof(model),
Expand Down Expand Up @@ -473,7 +468,6 @@ function LogDensityProblems.logdensity_and_gradient(
# `params` has to be converted to the same vector type that was used for AD preparation,
# otherwise the preparation will not be valid.
params = convert(get_input_vector_type(ldf), params)
# Choice between LogDensityAt and closure was fixed at prepare time.
return AbstractPPL.value_and_gradient(ldf._adprep, params)
end

Expand Down Expand Up @@ -505,43 +499,6 @@ By default, this just returns the input unchanged.
"""
function tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVector)
    # Generic fallback: hand the AD type back untouched. The model and parameter-vector
    # arguments are accepted only so that more specific methods can dispatch on them.
    return adtype
end

"""
_use_closure(adtype::ADTypes.AbstractADType)

In LogDensityProblems, we want to calculate the derivative of `logdensity(f, x)` with
respect to `x`, where `f` is the model (in our case `LogDensityFunction` or its arguments)
and is a constant. However, DifferentiationInterface generally expects a single-argument
function `g(x)` to differentiate.

There are two ways of dealing with this:

1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f)

2. Use a constant DI.Context. This lets us pass a two-argument function to DI, as long as we
also give it the 'inactive argument' (i.e. the model) wrapped in `DI.Constant`.

The relative performance of the two approaches, however, depends on the AD backend used.
Some benchmarks are provided here: https://github.com/TuringLang/DynamicPPL.jl/pull/1172

This function is used to determine whether a given AD backend should use a closure or a
constant. If `_use_closure(adtype)` returns `true`, then the closure approach will be used.
By default, this function returns `false`, i.e. the constant approach will be used.
"""
# ForwardDiff and both Mooncake modes work with or without a closure; skipping the closure
# is simply the faster option (see the benchmarks linked in the docstring).
_use_closure(::ADTypes.AutoForwardDiff) = false
_use_closure(::ADTypes.AutoMooncake) = false
_use_closure(::ADTypes.AutoMooncakeForward) = false

# ReverseDiff depends on whether the tape is compiled:
#   - compiled tape: a closure is mandatory, because with DI.Constant arguments the tape
#     would be recompiled on every call to value_and_gradient;
#   - non-compiled: not using a closure is faster.
# The `compile` type parameter therefore doubles as the answer.
function _use_closure(::ADTypes.AutoReverseDiff{compile}) where {compile}
    return compile
end

# Enzyme: not using a closure allows us to avoid setting function_annotation.
_use_closure(::ADTypes.AutoEnzyme) = false

# Fallback for AD backends not listed above. Most backends are faster without a closure,
# so that is the default.
_use_closure(::ADTypes.AbstractADType) = false

######################################################
# Helper functions to extract ranges and link status #
######################################################
Expand Down
4 changes: 0 additions & 4 deletions src/test_utils/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,6 @@ function run_ad(

# Calculate log-density and gradient with the backend of interest
value, grad = logdensity_and_gradient(ldf, params)
# collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
grad = collect(grad)
verbose && println(" actual : $((value, grad))")

# Test correctness
Expand All @@ -357,8 +355,6 @@ function run_ad(
model, getlogdensity, transform_strategy; adtype=test.adtype
)
value_true, grad_true = logdensity_and_gradient(ldf_reference, params)
# collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
grad_true = collect(grad_true)
end
# Perform testing
verbose && println(" expected : $((value_true, grad_true))")
Expand Down
5 changes: 5 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -30,6 +31,10 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[sources]
AbstractPPL = {rev = "evaluator-interface", url = "https://github.com/TuringLang/AbstractPPL.jl"}
DynamicPPL = {path = ".."}

[compat]
ADTypes = "1"
AbstractMCMC = "5.10"
Expand Down
2 changes: 1 addition & 1 deletion test/floattypes/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[sources]
DynamicPPL = {path = "../../"}
DynamicPPL = {path = "../.."}
4 changes: 3 additions & 1 deletion test/integration/enzyme/Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[sources]
DynamicPPL = {path = "../../../"}
AbstractPPL = {rev = "evaluator-interface", url = "https://github.com/TuringLang/AbstractPPL.jl"}
DynamicPPL = {path = "../../.."}
30 changes: 30 additions & 0 deletions test/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,36 @@ end
end
end

# Compiled ReverseDiff prep should be observable as lower repeated-call allocations.
@testset "ReverseDiff compiled prep reduces repeated-call allocations" begin
    @model f() = x ~ Normal()
    ldf_compiled = LogDensityFunction(
        f(), getlogjoint_internal, LinkAll(); adtype=AutoReverseDiff(; compile=true)
    )
    ldf_uncompiled = LogDensityFunction(
        f(), getlogjoint_internal, LinkAll(); adtype=AutoReverseDiff(; compile=false)
    )
    params = rand(ldf_compiled)

    # Call each variant once before measuring (presumably so that first-call compilation
    # does not end up in the allocation counts below).
    LogDensityProblems.logdensity_and_gradient(ldf_compiled, params)
    LogDensityProblems.logdensity_and_gradient(ldf_uncompiled, params)

    # Return the number of bytes allocated by 100 back-to-back gradient evaluations.
    # Uses the public `@allocated` macro rather than the unexported `Base.gc_num` /
    # `Base.GC_Diff` internals, whose names and fields are not part of Julia's API.
    function repeated_call_allocs(ldf, params)
        bytes = @allocated for _ in 1:100
            LogDensityProblems.logdensity_and_gradient(ldf, params)
        end
        return bytes
    end

    allocs_compiled = repeated_call_allocs(ldf_compiled, params)
    allocs_uncompiled = repeated_call_allocs(ldf_uncompiled, params)

    @test allocs_compiled < allocs_uncompiled
end

# Test that various different ways of specifying array types as arguments work with all
# ADTypes.
@testset "Array argument types" begin
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Documenter: Documenter
using DifferentiationInterface
using DynamicPPL: DynamicPPL
using Random: Random
using Test: @testset, @test_throws
Expand Down
Loading