Move predict from Turing #716
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
|
The reason is some tests implicitly rely on the variance of the posterior samples. Discarding some initial samples fixes this. |
Pull Request Test Coverage Report for Build 12435577043 (Coveralls) |
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
##            master   TuringLang/Turing.jl#716   +/-   ##
==========================================
+ Coverage   85.93%   86.05%   +0.12%
==========================================
  Files          36       36
  Lines        4280     4325     +45
==========================================
+ Hits         3678     3722     +44
- Misses        602      603      +1
View full report in Codecov by Sentry. |
|
We had a quick discussion on this at today's meeting. Tor raised that we should probably implement Also although we don't use |
Specifically, I was thinking |
| varinfos::AbstractArray{<:AbstractVarInfo}; | ||
| include_all=false, | ||
| ) | ||
| predictive_samples = Array{PredictiveSample}(undef, size(varinfos)) |
There was a problem hiding this comment.
Do we really need the PredictiveSample here?
My original suggestion was just to use Vector{<:OrderedDict} for the return-value (an abstractly typed PredictiveSample doesn't really offer anything beyond this, does it?)
There was a problem hiding this comment.
I haven't thought too deeply about this. A new type is certainly easier to dispatch on, but may not be necessary. Let me look into it.
There was a problem hiding this comment.
But we don't need to dispatch on this, do we?
Also, maybe it makes more sense to follow the convention of returning the same type as the input type, i.e. in this case we should return an AbstractArray{<:AbstractVarInfo}, and in the Chains case we return Chains.
|
Otherwise stuff is starting to look nice though:) |
| varinfos::AbstractArray{<:AbstractVarInfo}; | ||
| include_all=false, | ||
| ) | ||
| predictive_samples = similar(varinfos, OrderedDict{Symbol,Any}) |
There was a problem hiding this comment.
Is there a reason why you're using Symbol instead of VarName here? Seems better to use VarName, no?
There was a problem hiding this comment.
Yeah, this is confusing. The OrderedDict here is actually
OrderedDict{Symbol, Any}(
    values => ..., # a vector of Tuples (varname, value)
    logp =>
)
I'm using a NamedTuple now, with better field names.
There was a problem hiding this comment.
But if you're keeping other information than just the realizations (which, tbh, is IMO all we need here), why aren't we just returning the varinfos themselves (I suggested this in the other comment here)?
There was a problem hiding this comment.
Ah gotcha 👍 , I totally misunderstood: I was reading the AbstractVector part but somehow ignored {<:AbstractVarInfo} part.
I can get behind the idea of predict taking a vector of VarInfos and returning a vector of VarInfos. But I think the interface needs to be spec'd more. For instance, ideally we want to be clearer on questions like: in the returned VarInfos, should the VarNames be varname leaves or as they appear in the model; should the values in the returned VarInfos be in transformed or constrained space; how exactly should the model and input VarInfos conform to each other.
I am a bit short on time now, so after some thought, I think it's probably a good idea to just keep all the logic in MCMCChainsExt and maintain exactly the same interface Turing.jl has now. Then, in the future, we can work on improving the predict interface.
There was a problem hiding this comment.
should the VarName be varname leaves or as appeared in the model
As appeared in the model:)
should the values in the returned VarInfo in transformed or constrained space
Constrained space.
how exactly should model and input VarInfos conform to each other.
Confused; what do you mean?
Overall, I'm still a bit confused by this discussion: Turing.jl's predict literally does: iterate over chain, create varinfo, evaluate model on varinfo, and extract variables from varinfo.
So, why do we not just do
# In DynamicPPL.jl proper:
function predict(rng::Random.AbstractRNG, model::Model, chain::AbstractVector{<:AbstractVarInfo})
    varinfo = DynamicPPL.VarInfo(model)
    return map(chain) do varinfo_params
        DynamicPPL.setval_and_resample!(varinfo, varinfo_params)
        model(rng, varinfo)
        return deepcopy(varinfo)
    end
end
which is effectively what Turing.jl's predict does before converting into a Chains?
EDIT: This is ignoring the values_as_in_model which apparently is used in Turing.jl's predict, though, as mentioned in the other comment, it's very unclear if that's what we want here.
There was a problem hiding this comment.
Yeah, sorry it was a bit confusing.
I am thinking that it would be more intuitive for predict to satisfy the following: given
predicted_vis = predict(rng, model, varinfos)
then
_varinfos = predict(rng, model, predicted_vis)
returns varinfos that look like varinfos.
But if the values are in constrained space, can this break?
There was a problem hiding this comment.
Also, does the above code return varinfos with values in constrained space?
There was a problem hiding this comment.
julia> @model f() = x ~ Beta(2, 2)
f (generic function with 2 methods)
julia> m = f(); v = link!!(VarInfo(m), m)
VarInfo{@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, typeof(identity)}, Int64}, Vector{Beta{Float64}}, Vector{VarName{:x, typeof(identity)}}, Vector{Float64}}}, Float64}((x = DynamicPPL.Metadata{Dict{VarName{:x, typeof(identity)}, Int64}, Vector{Beta{Float64}}, Vector{VarName{:x, typeof(identity)}}, Vector{Float64}}(Dict(x => 1), [x], UnitRange{Int64}[1:1], [1.3950616230805561], Beta{Float64}[Beta{Float64}(α=2.0, β=2.0)], [0], Dict{String, BitVector}("del" => [0], "trans" => [1])),), Base.RefValue{Float64}(-1.8839487262608983), Base.RefValue{Int64}(0))
julia> v[@varname(x)]
0.8013990731398695
julia> v[@varname(x)] # Should always be in (0, 1)
0.8013990731398695
julia> predict(m, [v])[1][@varname(x)]
1.3950616230805561
| ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | ||
| AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" | ||
| AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" | ||
| AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" |
There was a problem hiding this comment.
Hmm, this doesn't quite seem worth it to test predict, no? What's the reasoning here?
There was a problem hiding this comment.
I didn't add anything or change the implementation in this PR.
Agreed that AHMC is a heavy dep, but tests like https://github.com/TuringLang/DynamicPPL.jl/blob/fd1277b7201477448d3257cab65557b850bcf5b4/test/ext/DynamicPPLMCMCChainsExt.jl#L48C1-L55C45
rely on the quality of the samples.
There was a problem hiding this comment.
Sure, but we should just replace them with samples from the prior or something. This is just checking that the statistics are correct; it doesn't matter if these statistics are from the prior or posterior 🤷
There was a problem hiding this comment.
Actually, would it be really bad to make AdvancedHMC a test dependency of DynamicPPL? (Again, I don't like this either, but it's not too bad. I would rather open an issue for removing this dependency later than tamper with this PR any more.)
There was a problem hiding this comment.
I can't look at this PR properly until Wednesday, but in https://github.com/TuringLang/DynamicPPL.jl/pull/733/files#diff-3981168ff1709b3f48c35e40f491c26d9b91fc29373e512f1272f3b928cea6c0 I wrote a function that generates a chain by sampling from the prior. (It's called make_chain_from_prior if the link doesn't bring you to the right place)
Feel free to take it if you think it's useful :)
There was a problem hiding this comment.
@sunxd3, @penelopeysm the posterior of Bayesian linear regression can be obtained in closed form (i.e. it is a Gaussian, see here). I suggest that we:
- add this BLR model to the DynamicPPL test models
- implement its analytical posterior
- sample from the analytical posterior directly and drop the AHMC deps.
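The closed-form route suggested above could look roughly like the following. This is a minimal sketch under stated assumptions, not DynamicPPL code: the helper names `blr_posterior` and `sample_blr_posterior` are hypothetical, the prior is taken to be β ~ Normal(0, τ² I), and the noise scale σ is assumed known.

```julia
using LinearAlgebra, Random, Statistics

# Hypothetical helper (not part of DynamicPPL): conjugate posterior for
#   y = X * β + ε,  ε ~ Normal(0, σ²),  β ~ Normal(0, τ² I),
# with the noise scale σ assumed known.
function blr_posterior(X::AbstractMatrix, y::AbstractVector; σ::Real=1.0, τ::Real=1.0)
    P = X' * X ./ σ^2 + I / τ^2     # posterior precision matrix
    Σ = inv(P)                      # posterior covariance
    μ = Σ * (X' * y) ./ σ^2         # posterior mean
    return μ, Σ
end

# Draw `n` samples from Normal(μ, Σ); each column of the result is one draw.
function sample_blr_posterior(rng::AbstractRNG, μ, Σ, n::Integer)
    L = cholesky(Symmetric(Σ)).L
    return μ .+ L * randn(rng, length(μ), n)
end

# Synthetic data, only for illustration.
rng = MersenneTwister(42)
X = randn(rng, 200, 2)
β_true = [1.5, -0.5]
y = X * β_true .+ 0.1 .* randn(rng, 200)

μ, Σ = blr_posterior(X, y; σ=0.1, τ=10.0)
draws = sample_blr_posterior(rng, μ, Σ, 1_000)
```

Draws produced this way are exact posterior samples for this particular model, so a test built on them would not depend on sampler warm-up or step-size tuning.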
There was a problem hiding this comment.
Though the closed-form posterior is a good idea, there's really no need to run this test on posterior samples:) These were just some stats that were picked to have something to compare to; prior chain is the way to go I think 👍
There was a problem hiding this comment.
A prior chain makes sense: should we generate samples from the prior, take out the samples of a particular variable, and try to predict it?
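The prior-chain idea discussed in this thread can be illustrated without any sampler. The sketch below is a stand-in that uses only the standard library, not the DynamicPPL or MCMCChains API: the model y ~ Normal(x * β, σ) with β ~ Normal(0, 1), and all variable names, are assumptions made for illustration.

```julia
using Random, Statistics

# Hypothetical stand-in for a linear-regression model: y ~ Normal(x * β, σ).
rng = MersenneTwister(1)
n_draws = 2_000
σ = 0.1
β_draws = randn(rng, n_draws)     # "prior chain": β ~ Normal(0, 1)
x_test = [1.0, 2.0, 3.0]

# One predictive draw of y per parameter draw
# (rows: parameter draws, columns: test inputs).
y_pred = [β * x + σ * randn(rng) for β in β_draws, x in x_test]

# The predictive mean at each input should match the mean implied by the
# parameter draws themselves, whether they come from a prior or a posterior.
expected = mean(β_draws) .* x_test
observed = vec(mean(y_pred; dims=1))
```

The comparison between `observed` and `expected` is the kind of statistic the existing tests check, and it holds for any set of parameter draws.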
| function predict(model::Model, chain; include_all=false) | ||
| # this is only defined in `ext/DynamicPPLMCMCChainsExt.jl` | ||
| # TODO: add other methods for different type of `chain` arguments: e.g., `VarInfo`, `NamedTuple`, and `OrderedDict` | ||
| return predict(Random.default_rng(), model, chain; include_all) | ||
| end |
There was a problem hiding this comment.
If so, we should definitely inform the user of this, no? Otherwise they'll just be like "oh why is this not defined?"
There was a problem hiding this comment.
I don't think we want to export predict right now, so predict is only available through Turing.jl, give or take.
Would a "function not defined" error be meaningful enough if the user gives other types of input?
There was a problem hiding this comment.
If Turing exports it, it's better for DynamicPPL to export it, too.
There was a problem hiding this comment.
I agree, I was proposing delaying this until a good predict spec is reached
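One way to address the "why is this not defined?" concern raised in this thread is a generic fallback that throws a descriptive error. The sketch below uses hypothetical names throughout (`ToyModel`, `predict_sketch`) and does not reflect DynamicPPL's actual code; the `Vector{Float64}` method merely stands in for what a package extension would add.

```julia
# Hypothetical names throughout; this does not reflect DynamicPPL's actual code.
struct ToyModel end

# Generic fallback: reached only when no more specific method exists for `chain`.
function predict_sketch(model::ToyModel, chain; include_all=false)
    msg = "predict_sketch is not implemented for chains of type $(typeof(chain)); " *
          "a method for this type may live in a package extension that is not loaded."
    throw(ArgumentError(msg))
end

# A specific method, standing in for what a package extension would provide.
predict_sketch(model::ToyModel, chain::Vector{Float64}; include_all=false) = 2 .* chain
```

With this layout, supported chain types dispatch to their specific method, while anything else produces a message that names the offending type instead of a bare MethodError.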
| m_lin_reg = linear_reg(xs_train, ys_train) | ||
| chain_lin_reg = sample( | ||
| DynamicPPL.LogDensityFunction(m_lin_reg), | ||
| AdvancedHMC.NUTS(0.65), |
There was a problem hiding this comment.
Really doesn't seem necessary to use NUTS here. Just construct a Chains by hand or something, no?
There was a problem hiding this comment.
Same reason as above: some tests rely on the quality of the samples.
|
|
||
| # Examples | ||
| ```jldoctest | ||
| julia> using DynamicPPL, AbstractMCMC, AdvancedHMC, ForwardDiff; |
There was a problem hiding this comment.
Same here: no need to use AdvancedHMC (or any of the other packages), just construct the Chains by hand.
This also doesn't actually show that you need to import MCMCChains for this to work, which might be a good idea
| ) | ||
| model(rng, varinfo, DynamicPPL.SampleFromPrior()) | ||
|
|
||
| vals = DynamicPPL.values_as_in_model(model, varinfo) |
There was a problem hiding this comment.
This is actually changing the behavior from Turing.jl's implementation. This will result in also including variables used in := statements, which is not currently done.
There was a problem hiding this comment.
Ooooh nice catch; thanks! Hmm, uncertain if this is desired behavior though 😕
There was a problem hiding this comment.
I saw your issue on := and totally understand the concern here. But if we are not exporting predict, we can change this in the near future. Also, we might want to use fix in the future, so the behavior will be right then.
We would need to make a minor release of Turing if we change this now.
There was a problem hiding this comment.
But isn't this the purpose of this PR? To move the predict from Turing.jl to DynamicPPL.jl?
also we might want to use fix in the future
Whether we're using fix or not is just an internal impl detail, and is not relevant for its usage, right?
There was a problem hiding this comment.
But isn't this the purpose of this PR? To move the predict from Turing.jl to DynamicPPL.jl?
Ideally, I would want this PR to do a proper implementation of predict in DynamicPPL. But now, I am okay with the PR being only a first step towards that.
Whether we're using fix or not is just an internal impl detail, and is not relevant for its usage, right?
what I was trying to say is that, with fix it should have the right behavior (with regards to :=). Of course not the only way to reach the desired behavior.
There was a problem hiding this comment.
Improving it in a separate PR sounds good, but please create an issue to track @torfjelde's comment.
| the samples in `chain`. This is useful when you want to sample only new variables from the posterior | ||
| predictive distribution. | ||
| """ | ||
| function predict(model::Model, chain; include_all=false) |
There was a problem hiding this comment.
In Turing.jl we're currently overloading StatsBase.predict, so we should probably do the same here, no?
There was a problem hiding this comment.
Agree with this, but it's probably not time yet. Definitely after TuringLang/AbstractPPL.jl#81 is merged 👍
There was a problem hiding this comment.
But is this PR then held up until that PR is merged?
There was a problem hiding this comment.
Also, that PR doesn't really matter; overloading StatsBase.predict here and now just means that we'll immediately be compliant with the AbstractPPL.jl interface when that PR merges?
There was a problem hiding this comment.
Grey area: for me it is okay, because this PR is just about introducing a Turing-facing predict, not a user-facing one yet. At the moment predict is not a public API.
There was a problem hiding this comment.
If nothing significant is missing in TuringLang/AbstractPPL.jl#81, let's merge it and overload AbstractPPL.predict here.
|
@sunxd3, let's get this merged in the next few days. |
|
Will do, it's at the top of my priority list |
|
Some regression tests for TuringLang/Turing.jl#1352 were removed; as far as I can tell, they should be covered by tests of |
|
edit: this is inaccurate, |
|
@yebai @torfjelde @penelopeysm I think this should be ready for another look |
yebai left a comment
There was a problem hiding this comment.
The tests are failing because fix, condition are exported by AbstractPPL, while DynamicPPL currently doesn't actually import these from AbstractPPL.
Let's fix these in this PR if possible.
|
I think the tests are run, but the codecov thinks the code in |
This PR migrates the predict function from Turing.jl to DynamicPPL while maintaining its existing interface and core implementation. Since predict returns an MCMCChains.Chains, the implementation is placed in MCMCChainsExt, similar to generated_quantities.
The purpose of the PR is not to add a "proper" predict implementation for DynamicPPL just yet, but to serve as a first step towards that. Some improvements we should make in the future:
- StatsBase.predict
- NamedTuple, OrderedDict, VarInfo, etc.
- values_as_in_model is probably wrong (ref Better support for := #913)