In Turing, `StatsBase.predict` is overloaded to dispatch on `DynamicPPL.Model` and `MCMCChains.Chains` (https://github.com/TuringLang/Turing.jl/blob/d76d914231db0198b99e5ca5d69d80934ee016b3/src/inference/Inference.jl#L532-L564). This effectively does batch prediction: it conditions the model on each draw in the chains and calls `rand` on the conditioned model. We also want to do the same thing for `InferenceData` (see #465).
It would be convenient if `StatsBase.predict` were added to the DynamicPPL API; StatsBase is already an indirect dependency of this package. As suggested by @devmotion in a comment on #465, its default implementation could simply call `rand` on the conditioned model:
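```julia
using DynamicPPL, Random, StatsBase
# Default: draw the remaining (unconditioned) variables of `model` given the values in `x`
StatsBase.predict(rng::Random.AbstractRNG, model::DynamicPPL.Model, x) = rand(rng, DynamicPPL.condition(model, x))
StatsBase.predict(model::DynamicPPL.Model, x) = StatsBase.predict(Random.default_rng(), model, x)
```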
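To illustrate the batch behaviour described for Turing, here is a minimal sketch that conditions a model on each draw in turn and calls `rand` once per draw. It is not Turing's implementation (which works on `MCMCChains.Chains`); the helper `predict_draws`, the toy model `demo`, and representing draws as a vector of `NamedTuple`s are all hypothetical choices made for illustration.

```julia
using DynamicPPL, Distributions, Random

# Hypothetical helper (not part of DynamicPPL): condition `model` on each draw
# and sample the remaining, unconditioned variables once per draw.
function predict_draws(rng::Random.AbstractRNG, model::DynamicPPL.Model, draws)
    return [rand(rng, DynamicPPL.condition(model, draw)) for draw in draws]
end

# Hypothetical toy model: `μ` plays the role of a parameter, `y` the quantity to predict.
@model function demo()
    μ ~ Normal(0, 1)
    y ~ Normal(μ, 1)
end

# Stand-in for posterior draws (e.g. the rows of a chain), one NamedTuple per draw.
draws = [(μ = -1.0,), (μ = 0.0,), (μ = 1.0,)]
preds = predict_draws(Random.MersenneTwister(42), demo(), draws)
```

Each element of `preds` is a `NamedTuple` holding a sampled `y` for the corresponding draw. Keeping the per-draw step as a plain `rand` on a conditioned model is what would let the same default serve both `Chains` and `InferenceData`, with only the iteration over draws differing.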