diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 6adb27efa..9a2b2afcc 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -18,7 +18,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" [sources] -DynamicPPL = {path = "../"} +DynamicPPL = {path = ".."} [compat] ADTypes = "1.14.0" diff --git a/docs/Project.toml b/docs/Project.toml index 288ae162a..38261c0a4 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -22,6 +22,9 @@ OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +[sources] +AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "evaluator-interface"} + [compat] ADTypes = "1" AbstractMCMC = "5" diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index 905ee168b..ae5f28444 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -144,7 +144,7 @@ accs = DynamicPPL.OnlyAccsVarInfo(( DynamicPPL.RawValueAccumulator(false), # ... whatever else you need )) -_, accs = DynamicPPL.init!!(rng, model, oavi, init_strategy, DynamicPPL.UnlinkAll()) +_, accs = DynamicPPL.init!!(rng, model, accs, init_strategy, DynamicPPL.UnlinkAll()) ``` You can then extract all the updated data from `accs` using DynamicPPL's existing API (see @@ -178,7 +178,7 @@ retcode: Success u: 1-element Vector{Float64}: 4.88281250001733e-5 -julia> # Get the an initialisation strategy representing the mode of `y`. +julia> # Get an initialisation strategy representing the mode of `y`. init_strategy = InitFromVector(mld, opt_solution.u); julia> # Evaluate the model with this initialisation strategy. diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 135fcab73..524f1e676 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -32,7 +32,7 @@ using Random: Random vi_vnt_or_tfm_strategy=_default_vnt(model, UnlinkAll()), accs::Union{NTuple{<:Any,AbstractAccumulator},AccumulatorTuple}=DynamicPPL.ldf_accs(getlogdensity); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, - fix_transform::Bool=false, + fix_transforms::Bool=false, ) A struct which contains a model, along with all the information necessary to: @@ -260,17 +260,12 @@ struct LogDensityFunction{ else # Make backend-specific tweaks to the adtype adtype = DynamicPPL.tweak_adtype(adtype, model, x) - lda = LogDensityAt( + problem = LogDensityAt( model, getlogdensity, ranges_and_transforms, transform_strategy, accs ) - problem = if _use_closure(adtype) - lda - else - let lda = lda - params -> lda(params) - end - end - AbstractPPL.prepare(adtype, problem, x) + # `x` was just constructed from the same range metadata stored in `problem`, + # so the AD wrapper can skip its hot-path dimension validation. + AbstractPPL.prepare(adtype, problem, x; check_dims=false) end return new{ typeof(model), @@ -473,7 +468,6 @@ function LogDensityProblems.logdensity_and_gradient( # `params` has to be converted to the same vector type that was used for AD preparation, # otherwise the preparation will not be valid. params = convert(get_input_vector_type(ldf), params) - # Choice between LogDensityAt and closure was fixed at prepare time. return AbstractPPL.value_and_gradient(ldf._adprep, params) end @@ -505,43 +499,6 @@ By default, this just returns the input unchanged. """ tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVector) = adtype -""" - _use_closure(adtype::ADTypes.AbstractADType) - -In LogDensityProblems, we want to calculate the derivative of `logdensity(f, x)` with -respect to x, where f is the model (in our case LogDensityFunction or its arguments ) and is -a constant. However, DifferentiationInterface generally expects a single-argument function -g(x) to differentiate. - -There are two ways of dealing with this: - -1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f) - -2. Use a constant DI.Context. This lets us pass a two-argument function to DI, as long as we - also give it the 'inactive argument' (i.e. the model) wrapped in `DI.Constant`. - -The relative performance of the two approaches, however, depends on the AD backend used. -Some benchmarks are provided here: https://github.com/TuringLang/DynamicPPL.jl/pull/1172 - -This function is used to determine whether a given AD backend should use a closure or a -constant. If `use_closure(adtype)` returns `true`, then the closure approach will be used. -By default, this function returns `false`, i.e. the constant approach will be used. -""" -# For these AD backends both closure and no closure work, but it is just faster to not use a -# closure (see link in the docstring). -_use_closure(::ADTypes.AutoForwardDiff) = false -_use_closure(::ADTypes.AutoMooncake) = false -_use_closure(::ADTypes.AutoMooncakeForward) = false -# For ReverseDiff, with the compiled tape, you _must_ use a closure because otherwise with -# DI.Constant arguments the tape will always be recompiled upon each call to -# value_and_gradient. For non-compiled ReverseDiff, it is faster to not use a closure. -_use_closure(::ADTypes.AutoReverseDiff{compile}) where {compile} = compile -# For AutoEnzyme it allows us to avoid setting function_annotation -_use_closure(::ADTypes.AutoEnzyme) = false -# Since for most backends it's faster to not use a closure, we set that as the default -# for unknown AD backends -_use_closure(::ADTypes.AbstractADType) = false - ###################################################### # Helper functions to extract ranges and link status # ###################################################### diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index 4c145943d..42ac9203d 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -339,8 +339,6 @@ function run_ad( # Calculate log-density and gradient with the backend of interest value, grad = logdensity_and_gradient(ldf, params) - # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 - grad = collect(grad) verbose && println(" actual : $((value, grad))") # Test correctness @@ -357,8 +355,6 @@ function run_ad( model, getlogdensity, transform_strategy; adtype=test.adtype ) value_true, grad_true = logdensity_and_gradient(ldf_reference, params) - # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 - grad_true = collect(grad_true) end # Perform testing verbose && println(" expected : $((value_true, grad_true))") diff --git a/test/Project.toml b/test/Project.toml index 73cff23ed..2c635ed19 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,6 +14,7 @@ DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -30,6 +31,10 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[sources] +AbstractPPL = {rev = "evaluator-interface", url = "https://github.com/TuringLang/AbstractPPL.jl"} +DynamicPPL = {path = ".."} + [compat] ADTypes = "1" AbstractMCMC = "5.10" diff --git a/test/floattypes/Project.toml b/test/floattypes/Project.toml index 02a770fe7..e47e1ebf4 100644 --- a/test/floattypes/Project.toml +++ b/test/floattypes/Project.toml @@ -7,4 +7,4 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [sources] -DynamicPPL = {path = "../../"} +DynamicPPL = {path = "../.."} diff --git a/test/integration/enzyme/Project.toml b/test/integration/enzyme/Project.toml index c26655fae..c673319b1 100644 --- a/test/integration/enzyme/Project.toml +++ b/test/integration/enzyme/Project.toml @@ -1,9 +1,11 @@ [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [sources] -DynamicPPL = {path = "../../../"} +AbstractPPL = {rev = "evaluator-interface", url = "https://github.com/TuringLang/AbstractPPL.jl"} +DynamicPPL = {path = "../../.."} diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 0b956dfc1..5ccdb48a1 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -485,6 +485,36 @@ end end end + # Compiled ReverseDiff prep should be observable as lower repeated-call allocations. + @testset "ReverseDiff compiled prep reduces repeated-call allocations" begin + @model f() = x ~ Normal() + ldf_compiled = LogDensityFunction( + f(), getlogjoint_internal, LinkAll(); adtype=AutoReverseDiff(; compile=true) + ) + ldf_uncompiled = LogDensityFunction( + f(), getlogjoint_internal, LinkAll(); adtype=AutoReverseDiff(; compile=false) + ) + params = rand(ldf_compiled) + + LogDensityProblems.logdensity_and_gradient(ldf_compiled, params) + LogDensityProblems.logdensity_and_gradient(ldf_uncompiled, params) + + function repeated_call_allocs(ldf, params) + GC.gc() + before = Base.gc_num() + for _ in 1:100 + LogDensityProblems.logdensity_and_gradient(ldf, params) + end + after = Base.gc_num() + return Base.GC_Diff(after, before).allocd + end + + allocs_compiled = repeated_call_allocs(ldf_compiled, params) + allocs_uncompiled = repeated_call_allocs(ldf_uncompiled, params) + + @test allocs_compiled < allocs_uncompiled + end + # Test that various different ways of specifying array types as arguments work with all # ADTypes. @testset "Array argument types" begin diff --git a/test/runtests.jl b/test/runtests.jl index 1ba744c3f..edde98144 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,5 @@ using Documenter: Documenter +using DifferentiationInterface using DynamicPPL: DynamicPPL using Random: Random using Test: @testset, @test_throws