diff --git a/usage/threadsafe-evaluation/index.qmd b/usage/threadsafe-evaluation/index.qmd index d02a14aff..edec292c1 100755 --- a/usage/threadsafe-evaluation/index.qmd +++ b/usage/threadsafe-evaluation/index.qmd @@ -19,24 +19,20 @@ The Julia manual [has a section on multithreading](https://docs.julialang.org/en We assume that the reader is familiar with some threading constructs in Julia, and the general concept of data races. This page specificaly discusses Turing's support for threadsafe model evaluation. -:::{.callout-note} -Please note that this is a rapidly-moving topic, and things may change in future releases of Turing. -If you are ever unsure about what works and doesn't, please don't hesitate to ask on [Slack](https://julialang.slack.com/archives/CCYDC34A0) or [Discourse](https://discourse.julialang.org/c/domain/probprog/48) -::: - ```{julia} println("This notebook is being run with $(Threads.nthreads()) threads.") ``` ## Threading in Turing models -Given that Turing models mostly contain 'plain' Julia code, one might expect that all threading constructs such as `Threads.@threads` or `Threads.@spawn` can be used inside Turing models. +To a first approximation, Turing completely supports multithreaded code inside models. -This is, to some extent, true: for example, you can use threading constructs to speed up deterministic computations. -For example, here we use parallelism to speed up a transformation of `x`: +For example, you can use `Threads.@threads` to parallelise 'ordinary' Julia code inside a model. +Here is an example of parallelising some expensive computation inside a model: ```{julia} using Turing +Turing.setprogress!(false) @model function parallel(y) x ~ dist @@ -48,73 +44,71 @@ using Turing end ``` -In general, for code that does not involve tilde-statements (`x ~ dist`), threading works exactly as it does in regular Julia code. +An example like the above, where the parallelisation is separate from the modelling syntax (i.e., tilde-statements), will work without any special considerations. **However, extra care must be taken when using tilde-statements (`x ~ dist`), or `@addlogprob!`, inside threaded blocks.** - -::: {.callout-note} -## Why are tilde-statements special? -Tilde-statements are expanded by the `@model` macro into something that modifies the internal VarInfo object used for model evaluation. -Essentially, `x ~ dist` expands to something like - -```julia -x, __varinfo__ = DynamicPPL.tilde_assume!!(..., __varinfo__) -``` - -and writing into `__varinfo__` is, _in general_, not threadsafe. -Thus, parallelising tilde-statements can lead to data races [as described in the Julia manual](https://docs.julialang.org/en/v1/manual/multi-threading/#Using-@threads-without-data-races). -::: - -## Threaded observations - -**As of version 0.42, Turing only supports the use of tilde-statements inside threaded blocks when these are observations (i.e., likelihood terms).** - -However, such models **must** be marked by the user as requiring threadsafe evaluation, using `setthreadsafe`. - -This means that the following code is safe to use: +Specifically, if you do this, you *must* mark the model as requiring threadsafe evaluation, using `setthreadsafe`. +For example: ```{julia} -@model function threaded_obs(N) - x ~ Normal() +@model function threaded(N) + x = Vector{Float64}(undef, N) y = Vector{Float64}(undef, N) Threads.@threads for i in 1:N - y[i] ~ Normal(x) + x[i] ~ Normal() + y[i] ~ Normal(x[i]) end end -N = 100 +N = 20 y = randn(N) -threadunsafe_model = threaded_obs(N) | (; y = y) +threadunsafe_model = threaded(N) | (; y = y) threadsafe_model = setthreadsafe(threadunsafe_model, true) ``` -Evaluating this model is threadsafe, in that Turing guarantees to provide the correct result in functions such as: +::: {.callout-note} +## Why are tilde-statements special? +Tilde-statements are expanded by the `@model` macro into something that modifies the internal `AbstractVarInfo` object used during model evaluation. +Essentially, `x ~ dist` expands to something like + +```julia +x, __abstractvarinfo__ = DynamicPPL.tilde_assume!!(..., __abstractvarinfo__) +``` + +and writing into `__abstractvarinfo__` is, _in general_, not threadsafe. +Thus, parallelising tilde-statements can lead to data races [as described in the Julia manual](https://docs.julialang.org/en/v1/manual/multi-threading/#Using-@threads-without-data-races). + +Turing's threadsafe flag works by creating one `AbstractVarInfo` per thread, and then combining the results at the end of model evaluation. +::: + +Once the model has been marked as threadsafe, Turing guarantees to provide the correct result in functions such as: ```{julia} -logjoint(threadsafe_model, (; x = 0.0)) +x = zeros(N) +logjoint(threadsafe_model, (; x = x)) ``` (we can compare with the true value) ```{julia} -logpdf(Normal(), 0.0) + sum(logpdf.(Normal(0.0), y)) +sum(logpdf.(Normal(), x)) + sum(logpdf.(Normal.(x), y)) ``` Note that if you do not use `setthreadsafe`, the above code may give wrong results, or even error: ```{julia} -logjoint(threadunsafe_model, (; x = 0.0)) +logjoint(threadunsafe_model, (; x = x)) ``` You can sample from this model and safely use functions such as `predict` or `returned`, as long as the model is always marked as threadsafe: ```{julia} -model = setthreadsafe(threaded_obs(N) | (; y = y), true) -chn = sample(model, NUTS(), 100; check_model=false, progress=false) +model = setthreadsafe(threaded(N) | (; y = y), true) +chn = sample(model, NUTS(), 100) ``` ```{julia} -pmodel = setthreadsafe(threaded_obs(N), true) # don't condition on data +pmodel = setthreadsafe(threaded(N), true) # don't condition on data predict(pmodel, chn) ``` @@ -128,26 +122,53 @@ There were several reasons for changing this: one major one is because threadsaf Furthermore, the number of threads is not an appropriate way to determine whether threadsafe evaluation is needed! ::: -## Threaded assumptions / sampling latent values +## A note on reproducibility + +Note that, due to reasons which we do not yet fully understand (but likely relate to race conditions in the mutation of the random number generator), the use of threadsafe evaluation is not always fully deterministic when assume-statements, i.e. random variables, are parallelised. + +In the model above, the `x[i]`'s are random variables since they are on the left-hand side of a tilde-statement but are not conditioned on. +In contrast, the `y[i]`'s are data, not a random variable. -**On the other hand, parallelising the sampling of latent values is not supported.** -Attempting to do this will either error or give wrong results. +This means if your model contains parallelised random variables, you are not guaranteed to get the same results every time, even if you set the random seed: ```{julia} -#| error: true -@model function threaded_assume_bad(N) - x = Vector{Float64}(undef, N) +using Random: Xoshiro + +chn = sample(Xoshiro(468), threadsafe_model, NUTS(), 100; verbose=false) +@show mean(chn[Symbol("x[1]")]) + +chn = sample(Xoshiro(468), threadsafe_model, NUTS(), 100; verbose=false) +@show mean(chn[Symbol("x[1]")]); +``` + +Some samplers do indeed yield the same results (but NUTS is not one of them, and we cannot make any concrete guarantees at this point in time): + +```{julia} +chn = sample(Xoshiro(468), threadsafe_model, MH(), 100) +@show mean(chn[Symbol("x[1]")]) + +chn = sample(Xoshiro(468), threadsafe_model, MH(), 100) +@show mean(chn[Symbol("x[1]")]); +``` + +Now consider a different situation where you only have parallelised data, and not random variables. +In this case we _do_ guarantee that sampling is fully deterministic: + +```{julia} +@model function threaded_data(N) + x ~ Normal() + y = Vector{Float64}(undef, N) Threads.@threads for i in 1:N - x[i] ~ Normal() + y[i] ~ Normal(x) end - return x end +threadsafe_model_data_only = setthreadsafe(threaded_data(N) | (; y = y), true) -model = threaded_assume_bad(100) +chn = sample(Xoshiro(468), threadsafe_model_data_only, NUTS(), 100; verbose=false) +@show mean(chn[:x]) -# This will throw an error (and probably a different error -# each time it's run...) -model() +chn = sample(Xoshiro(468), threadsafe_model_data_only, NUTS(), 100; verbose=false) +@show mean(chn[:x]); ``` ## When is threadsafe evaluation really needed? @@ -205,92 +226,23 @@ One can see that evaluation of the threadsafe model is substantially slower: ```{julia} using Chairmarks, DynamicPPL -function benchmark_eval(m) - vi = VarInfo(m) - display(median(@be DynamicPPL.evaluate!!($m, $vi))) -end - -benchmark_eval(model_no_threadsafe) -benchmark_eval(model_threadsafe) +display(median(@be rand($model_no_threadsafe))) +display(median(@be rand($model_threadsafe))) ``` In previous versions of Turing, this cost would **always** be incurred whenever Julia was launched with multiple threads, even if the model did not use any threading at all! -## Alternatives to threaded observation - -An alternative to using threaded observations is to manually calculate the log-likelihood term (which can be parallelised using any of Julia's standard mechanisms), and then _outside_ of the threaded block, [add it to the model using `@addlogprob!`]({{< meta usage-modifying-logprob >}}). - -For example: - -```{julia} -# Note that `y` has to be passed as an argument; you can't -# condition on it because otherwise `y[i]` won't be defined. -@model function threaded_obs_addlogprob(N, y) - x ~ Normal() - - # Instead of this: - # Threads.@threads for i in 1:N - # y[i] ~ Normal(x) - # end - - # Do this instead: - lls = map(1:N) do i - Threads.@spawn begin - logpdf(Normal(x), y[i]) - end - end - @addlogprob! sum(fetch.(lls)) -end -``` - -In a similar way, you can also use your favourite parallelism package, such as `FLoops.jl` or `OhMyThreads.jl`. -See [this Discourse post](https://discourse.julialang.org/t/parallelism-within-turing-jl-model/54064/9) for some examples. - -We make no promises about the use of tilde-statements _with_ these packages (indeed it will most likely error), but as long as you use them to only parallelise regular Julia code (i.e., not tilde-statements), they will work as intended. - -The main downside of this approach is: - -1. You can't use conditioning syntax to provide data; it has to be passed as an argument or otherwise included inside the model. -2. You can't use `predict` to sample new data. - -On the other hand, one benefit of rewriting the model this way is that sampling from this model with `MCMCThreads()` will always be reproducible. - -```{julia} -using Random -N = 100 -y = randn(N) -# Note that since `@addlogprob!` is outside of the threaded block, we don't -# need to use `setthreadsafe`. -model = threaded_obs_addlogprob(N, y) -nuts_kwargs = (progress=false, verbose=false) - -chain1 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...) -chain2 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...) -mean(chain1[:x]), mean(chain2[:x]) # should be identical -``` - -In contrast, the original `threaded_obs` (which used tilde inside `Threads.@threads`) is not reproducible when using `MCMCThreads()`. -(In principle, we would like to fix this bug, but we haven't yet investigated where it stems from.) - -```{julia} -model = setthreadsafe(threaded_obs(N) | (; y = y), true) -nuts_kwargs = (progress=false, verbose=false) -chain1 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...) -chain2 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...) -mean(chain1[:x]), mean(chain2[:x]) # oops! -``` - ## AD support Finally, if you are [using Turing with automatic differentiation]({{< meta usage-automatic-differentiation >}}), you also need to keep track of which AD backends support threadsafe evaluation. -ForwardDiff is the only AD backend that we find to work reliably with threaded model evaluation. +ForwardDiff and Enzyme are the only AD backends that we find to work reliably with threaded model evaluation. +Note that for Enzyme, you should use a relatively recent version (at least v0.13.140) as prior to that reverse-mode could yield incorrect results. -In particular: +In contrast, ReverseDiff sometimes gives right results, but quite often gives incorrect gradients. +Mooncake [currently does not support multithreading at all](https://github.com/chalk-lab/Mooncake.jl/issues/570). - - ReverseDiff sometimes gives right results, but quite often gives incorrect gradients. - - Mooncake [currently does not support multithreading at all](https://github.com/chalk-lab/Mooncake.jl/issues/570). - - Enzyme [mostly gives the right result, but sometimes gives incorrect gradients](https://github.com/TuringLang/DynamicPPL.jl/issues/1131). +For more details you can take a look at [the `threaded_...` models on ADTests](https://turinglang.org/ADTests/). ## Under the hood @@ -298,6 +250,14 @@ In particular: This part will likely only be of interest to DynamicPPL developers and the very curious user. ::: +Code in DynamicPPL that uses `VarInfo` is _not_ threadsafe in general. +For any code that uses `VarInfo`, observe statements are threadsafe, but assume statements are not. + +In contrast, code that uses `OnlyAccsVarInfo` is completely threadsafe. + +Now, virtually all of DynamicPPL and Turing use `OnlyAccsVarInfo`, this means that most of DynamicPPL and Turing is threadsafe. +You only need to worry about edge cases if you are still using `VarInfo` directly. + ### Why is VarInfo not threadsafe? As alluded to above, the issue with threaded tilde-statements stems from the fact that these tilde-statements modify the VarInfo object used for model evaluation, leading to potential data races. @@ -325,9 +285,3 @@ Thus, since DynamicPPL v0.39, `LogDensityFunction` itself is completely threadsa Technically speaking, this is achieved using `OnlyAccsVarInfo`, which is a subtype of `VarInfo` that only contains accumulators, and no metadata at all. It implements enough of the `VarInfo` interface to be used in model evaluation, but will error if any functions attempt to modify or read its metadata. - -There is currently an ongoing push to use `OnlyAccsVarInfo` in as many settings as we possibly can. -For example, this is why `predict` is threadsafe in DynamicPPL v0.39: instead of modifying metadata to store the predicted values, we store them inside a `ValuesAsInModelAccumulator` instead, and combine them at the end of evaluation. - -However, propagating these changes up to Turing will require a substantial amount of additional work, since there are many places in Turing which currently rely on a full VarInfo (with metadata). -See, e.g., [this PR](https://github.com/TuringLang/DynamicPPL.jl/pull/1154) for more information.