Avoid reevaluating model where unnecessary#2759
Conversation
|
Turing.jl documentation for PR #2759 is available at: |
|
Assuming this passes CI, @joelkandiah, you might find this a little bit useful. It doesn't help with compilation, but here's a demonstration: using Turing
modelevals = Ref(0)
@model function f()
modelevals[] += 1
a1 ~ Normal()
a2 ~ Normal()
a3 ~ Normal()
a4 ~ Normal()
a5 ~ Normal()
a6 ~ Normal()
a7 ~ Normal()
a8 ~ Normal()
end
modelevals[] = 0
@time sample(f(), Gibbs(
@varname(a1) => MH(),
@varname(a2) => MH(),
@varname(a3) => MH(),
@varname(a4) => MH(),
@varname(a5) => MH(),
@varname(a6) => MH(),
@varname(a7) => MH(),
@varname(a8) => MH()
),
1000;
chain_type=Any, progress=false
);
@show modelevals[]
# this PR: 0.274875 seconds (2.26 M allocations: 169.478 MiB, 4.95% gc time)
# modelevals[] = 17003
#
# main: 0.354257 seconds (2.88 M allocations: 208.250 MiB, 4.81% gc time)
# modelevals[] = 25003If you're using discard_initial and/or thinning, the benefits should scale accordingly. For example, adding # With thinning=10
#
# this PR: 2.419929 seconds (21.36 M allocations: 1.582 GiB, 4.45% gc time)
# modelevals[] = 160859
#
# main: 3.395153 seconds (28.70 M allocations: 2.029 GiB, 4.05% gc time)
# modelevals[] = 249778 |
This comment was marked as resolved.
This comment was marked as resolved.
|
(I definitely think it should in theory be possible to cut the model evaluations down to 8000, or maybe 9000 (each component sampler goes once and Gibbs itself has to do it once), but that really requires deeper refactoring that would follow on from #2756. The extra 8000 evaluations are coming from this |
|
Turing v0.43 bumped the model evaluations back up to 25000 (it's because Timings (1000 iterations, no thinning): Main (v0.43.5) 0.062175 seconds (922.58 k allocations: 67.218 MiB, 13.04% gc time) |
Closes #2639
Quoting from the changelog:
This makes use of the
discard_samplekeyword argument tostep/step_warmupTuringLang/AbstractMCMC.jl#194Note that this is still type stable as the value of the
discard_samplekwarg can be constant propagated (I've checked this with code_warntype etc).