diff --git a/HISTORY.md b/HISTORY.md index e07fc0e627..b61b8eed2e 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,11 @@ +# 0.42.9 + +Improve handling of model evaluator functions with Libtask. + +This means that when running SMC or PG on a model with keyword arguments, you no longer need to use `@might_produce` (see patch notes of v0.42.5 for more details on this). + +It also means that submodels with observations inside will now be reliably handled by the SMC/PG samplers, which was not the case before (the observations were only picked up if the submodel was inlined by the Julia compiler, which could lead to correctness issues). + # 0.42.8 Add support for `TensorBoardLogger.jl` via `AbstractMCMC.mcmc_callback`. diff --git a/Project.toml b/Project.toml index 12776485d7..6abb454f1f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.42.8" +version = "0.42.9" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -65,7 +65,7 @@ DynamicHMC = "3.4" DynamicPPL = "0.39.1" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3, 1" -Libtask = "0.9.5" +Libtask = "0.9.14" LinearAlgebra = "1" LogDensityProblems = "2" MCMCChains = "5, 6, 7" diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 09251ad36f..050ad04cc9 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -123,7 +123,6 @@ function AbstractMCMC.sample( ) check_model && _check_model(model, sampler) error_if_threadsafe_eval(model) - check_model_kwargs(model) # need to add on the `nparticles` keyword argument for `initialstep` to make use of return AbstractMCMC.mcmcsample( rng, @@ -138,28 +137,6 @@ function AbstractMCMC.sample( ) end -function check_model_kwargs(model::DynamicPPL.Model) - if !isempty(model.defaults) - # If there are keyword arguments, we need to check that the user has - # accounted for this by overloading `might_produce`. 
- might_produce = Libtask.might_produce(typeof((Core.kwcall, NamedTuple(), model.f))) - if !might_produce - io = IOBuffer() - ctx = IOContext(io, :color => true) - print( - ctx, - "Models with keyword arguments need special treatment to be used" * - " with particle methods. Please run:\n\n", - ) - printstyled( - ctx, " Turing.@might_produce($(model.f))"; bold=true, color=:blue - ) - print(ctx, "\n\nbefore sampling from this model with particle methods.\n") - error(String(take!(io))) - end - end -end - function Turing.Inference.initialstep( rng::AbstractRNG, model::DynamicPPL.Model, @@ -169,7 +146,6 @@ function Turing.Inference.initialstep( discard_sample=false, kwargs..., ) - check_model_kwargs(model) # Reset the VarInfo. vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) vi = DynamicPPL.empty!!(vi) @@ -292,7 +268,6 @@ function Turing.Inference.initialstep( kwargs..., ) error_if_threadsafe_eval(model) - check_model_kwargs(model) vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) # Create a new set of particles @@ -534,6 +509,9 @@ Libtask.@might_produce(DynamicPPL.tilde_observe!!) # Could tilde_assume!! have tighter type bounds on the arguments, namely a GibbsContext? # That's the only thing that makes tilde_assume calls result in tilde_observe calls. Libtask.@might_produce(DynamicPPL.tilde_assume!!) -Libtask.@might_produce(DynamicPPL.evaluate!!) -Libtask.@might_produce(DynamicPPL.init!!) -Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true + +# This handles all models and submodel evaluator functions (including those with keyword +# arguments). The key to this is realising that all model evaluator functions take +# DynamicPPL.Model as an argument, so we can just check for that. See +# https://github.com/TuringLang/Libtask.jl/issues/217. 
+Libtask.might_produce_if_sig_contains(::Type{<:DynamicPPL.Model}) = true diff --git a/test/Aqua.jl b/test/Aqua.jl index e159cae9ca..e5b655c6e0 100644 --- a/test/Aqua.jl +++ b/test/Aqua.jl @@ -1,11 +1,29 @@ module AquaTests using Aqua: Aqua +using Libtask: Libtask using Turing -# We test ambiguities separately because it catches a lot of problems -# in dependencies but we test it for Turing. -Aqua.test_ambiguities([Turing]) +# We test ambiguities specifically only for Turing, because testing ambiguities for all +# packages in the environment leads to a lot of ambiguities from dependencies that we cannot +# control. +# +# `Libtask.might_produce` is excluded because the `@might_produce` macro generates a lot of +# ambiguities that will never happen in practice. +# +# Specifically, when you write `@might_produce f` for a function `f` that has methods that +# take keyword arguments, we have to generate a `might_produce` method for +# `Type{<:Tuple{<:Function,...,typeof(f)}}`. There is no way to circumvent this: see +# https://github.com/TuringLang/Libtask.jl/issues/197. This in turn will cause method +# ambiguities with any other function, say `g`, for which +# `::Type{<:Tuple{typeof(g),Vararg}}` is marked as produceable. +# +# To avoid the method ambiguities, we *could* manually spell out `might_produce` methods for +# each method of `g` instead of using Vararg, but that would be both very verbose +# and fragile. It would also not provide any real benefit since those ambiguities are not +# meaningful in practice (in particular, to trigger this we would need to call `g(..., f)`, +# which is incredibly unlikely). 
+Aqua.test_ambiguities([Turing]; exclude=[Libtask.might_produce]) Aqua.test_all(Turing; ambiguities=false) end diff --git a/test/Project.toml b/test/Project.toml index 62de5efd16..b6947b3c91 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -17,6 +17,7 @@ DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" +Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" @@ -57,6 +58,7 @@ DynamicPPL = "0.39.6" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10, 1" HypothesisTests = "0.11" +Libtask = "0.9.14" LinearAlgebra = "1" LogDensityProblems = "2" LogDensityProblemsAD = "1.4" diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 99c38616de..584c921efb 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -314,17 +314,12 @@ using Turing return priors end - @test_throws ErrorException chain = sample( - StableRNG(seed), gauss2(; x=x), PG(10), 10 - ) - @test_throws ErrorException chain = sample( - StableRNG(seed), gauss2(; x=x), SMC(), 10 - ) - - @test_throws ErrorException chain = sample( + chain = sample(StableRNG(seed), gauss2(; x=x), PG(10), 10) + chain = sample(StableRNG(seed), gauss2(; x=x), SMC(), 10) + chain = sample( StableRNG(seed), gauss2(DynamicPPL.TypeWrap{Vector{Float64}}(); x=x), PG(10), 10 ) - @test_throws ErrorException chain = sample( + chain = sample( StableRNG(seed), gauss2(DynamicPPL.TypeWrap{Vector{Float64}}(); x=x), SMC(), 10 ) diff --git a/test/mcmc/particle_mcmc.jl b/test/mcmc/particle_mcmc.jl index 6a410869a0..4c13d59bff 100644 --- a/test/mcmc/particle_mcmc.jl +++ b/test/mcmc/particle_mcmc.jl @@ -161,26 +161,59 @@ end @test mean(c[:x]) > 0.7 end - # 
https://github.com/TuringLang/Turing.jl/issues/2007 @testset "keyword argument handling" begin @model function kwarg_demo(y; n=0.0) x ~ Normal(n) return y ~ Normal(x) end - @test_throws "Models with keyword arguments" sample(kwarg_demo(5.0), PG(20), 10) - # Check that enabling `might_produce` does allow sampling - @might_produce kwarg_demo chain = sample(StableRNG(468), kwarg_demo(5.0), PG(20), 1000) @test chain isa MCMCChains.Chains @test mean(chain[:x]) ≈ 2.5 atol = 0.2 - # Check that the keyword argument's value is respected chain2 = sample(StableRNG(468), kwarg_demo(5.0; n=10.0), PG(20), 1000) @test chain2 isa MCMCChains.Chains @test mean(chain2[:x]) ≈ 7.5 atol = 0.2 end + @testset "submodels without kwargs" begin + @model function inner(y, x) + # Mark as noinline explicitly to make sure that behaviour is not reliant on the + # Julia compiler inlining it. + # See https://github.com/TuringLang/Turing.jl/issues/2772 + @noinline + return y ~ Normal(x) + end + @model function nested(y) + x ~ Normal() + return a ~ to_submodel(inner(y, x)) + end + m1 = nested(1.0) + chn = sample(StableRNG(468), m1, PG(10), 1000) + @test mean(chn[:x]) ≈ 0.5 atol = 0.1 + end + + @testset "submodels with kwargs" begin + @model function inner_kwarg(y; n=0.0) + @noinline # See above + x ~ Normal(n) + return y ~ Normal(x) + end + @model function outer_kwarg1() + return a ~ to_submodel(inner_kwarg(5.0)) + end + m1 = outer_kwarg1() + chn1 = sample(StableRNG(468), m1, PG(10), 1000) + @test mean(chn1[Symbol("a.x")]) ≈ 2.5 atol = 0.2 + + @model function outer_kwarg2(n) + return a ~ to_submodel(inner_kwarg(5.0; n=n)) + end + m2 = outer_kwarg2(10.0) + chn2 = sample(StableRNG(468), m2, PG(10), 1000) + @test mean(chn2[Symbol("a.x")]) ≈ 7.5 atol = 0.2 + end + @testset "refuses to run threadsafe eval" begin # PG can't run models that have nondeterministic evaluation order, # so it should refuse to run models marked as threadsafe.