(breaking) v0.44#2790
Conversation
|
Turing.jl documentation for PR #2790 is available at: |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2790 +/- ##
==========================================
- Coverage 86.15% 85.23% -0.93%
==========================================
Files 22 22
Lines 1466 1483 +17
==========================================
+ Hits 1263 1264 +1
- Misses 203 219 +16 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Changes from new DPPL version + VI interface
This PR does a minor overhaul of the Gibbs interface, in line with various plans that have been made over the months. Specifically, this PR makes the following changes: ## VNT over VarInfo Instead of `GibbsContext` holding a global **`VarInfo`**, which may contain linked or unlinked values, it now holds a global **`VarNamedTuple`** which always holds **raw** values. The motivation for this is twofold: 1. `VarInfo` holding linked values and cached transforms leads to correctness issues, such as #2801, which I discovered while trying to write this refactor. I can confirm that this PR fixes that issue. 2. `VarInfo` is slow because it does a lot of extra bookkeeping. The performance issues are in fact not limited to Gibbs: it leaks through to every sampler which Gibbs uses, because it forces every component sampler to carry around a VarInfo so that it can talk to Gibbs. This refactor therefore also means that some component samplers will be faster (see below). ## Sampler <=> Gibbs interface The old interface that Gibbs expected for samplers was as follows: - `get_varinfo(state)` Return a VarInfo that the Gibbs sampler will then merge into the global VarInfo. - `setparams_varinfo!!(state, varinfo, ...)` Update the sampler's state with the new global VarInfo. In addition to this interface Gibbs also had a function `match_linking!!`, which 'lined up' the transform status of the global VarInfo with each state's individual VarInfo. This PR directly replaces them with VNT equivalents - `gibbs_get_raw_values(state)` Return a VNT that the Gibbs sampler will then merge into the global VNT. - `gibbs_update_state!!(state, vnt, ...)` Update the sampler's state with the new global VNT. `match_linking!!` is no longer required as a separate function now: inside `gibbs_update_state!!` each sampler can just use its own transform strategy to reconstruct the correct state from the raw values. ## Component samplers Component samplers have been updated to stop using VarInfo. ## Performance Gibbs itself is slightly faster: ```julia using Turing @model function g() x ~ Normal() y ~ Normal() end @time sample(g(), Gibbs(:x => MH(), :y => MH()), 100_000; chain_type=Any, progress=false, verbose=false); ``` - main: 1.106086 seconds (23.50 M allocations: 1.398 GiB, 7.60% gc time) - This PR: 0.807256 seconds (17.60 M allocations: 896.448 MiB, 7.04% gc time) This speedup probably comes from just Gibbs, because performance of MH on these models is very similar (see below). On top of that, as promised above, this gives some nice speedups for other samplers, and specifically for ESS. These are trivial models. Naturally with larger models the overhead from samplers will be smaller. I also specify `chain_type=Any` to cut out the constant overhead from MCMCChains, which is actually quite significant. ```julia using Turing @model function f() x ~ Normal() 1.0 ~ Normal(x) end @time sample(f(), spl, 100_000; chain_type=Any, progress=false, verbose=false); ``` `spl = ESS()`: - main: 0.711739 seconds (10.20 M allocations: 514.196 MiB, 5.33% gc time) - this PR: 0.166084 seconds (6.00 M allocations: 259.379 MiB, 9.10% gc time) I thought `spl = MH()` would also show some speedup, but it's way less drastic. - main: 0.226019 seconds (5.53 M allocations: 271.970 MiB, 9.79% gc time) - this PR: 0.218487 seconds (5.33 M allocations: 253.618 MiB, 8.29% gc time) SMC and PG have also been switched over to using `OnlyAccsVarInfo`, but like MH, there's no significant difference in timings either for those (presumably Libtask overhead dominates everything else). `spl = HMC(0.1, 20)` or `spl = NUTS()` are basically the same, they were already optimised before this. ## Closes Closes #2801 Closes #2764 Closes #2762 Closes #2642
| # TODO(penelopeysm): This field is needed for Gibbs because each time we call | ||
| # gibbs_update_state!! on this, we need to reconstruct a LogDensityFunction. | ||
| # In general this would require reevaluating the model, unless we supply a | ||
| # VarNamedTuple which already contains vectorised parameters. This can probably | ||
| # be improved in DynamicPPL, but for now we will just store an extra VNT in | ||
| # the state. | ||
| # NOTE: The actual values of this field should never be used or relied on! | ||
| _vector_vnt::V |
There was a problem hiding this comment.
TuringLang/DynamicPPL.jl#1366 should fix this
There was a problem hiding this comment.
In particular, right now this PR stores the _vector_vnt inside the state separately and then reconstructs the LDF from that. That is rather silly, because the state already contains an LDF with all the necessary info. However, DPPL currently doesn't allow us to actually use that info. The above DPPL PR will allow us to extract the fields of the LDF and use it to make a new LDF, meaning that this VNT doesn't have to be cached.
Working on TuringLang/Turing.jl#2790 made me realise that LDF simply doesn't make enough of its internals explicit. Turing thus faces a choice of either depending on DPPL internals or using hacky workarounds. This PR therefore exposes a lower-level constructor for LogDensityFunction that allows Turing to essentially update the model in an LDF without having to rely on private fields (bad) or recreate any of the other arguments (hacky). See TuringLang/Turing.jl#2790 (comment).
No description provided.