Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 55 additions & 2 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
module DynamicPPLMCMCChainsExt

if isdefined(Base, :get_extension)
using DynamicPPL: DynamicPPL
using DynamicPPL: DynamicPPL, Random
using MCMCChains: MCMCChains
else
using ..DynamicPPL: DynamicPPL
using ..DynamicPPL: DynamicPPL, Random
using ..MCMCChains: MCMCChains
end

Expand Down Expand Up @@ -190,4 +190,57 @@ function _varname_pairs_without_varname_indexing(
return varname_pairs
end

function DynamicPPL.predict(
model::DynamicPPL.Model, chain::MCMCChains.Chains; include_all=false
)
return predict(Random.default_rng(), model, chain; include_all=include_all)
end
function DynamicPPL.predict(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
chain::MCMCChains.Chains;
include_all=false,
)
params_only_chain = MCMCChains.get_sections(chain, :parameters)

varname_to_symbol = if :varname_to_symbol in keys(params_only_chain.info)
# the mapping is introduced in Turing by
# https://github.com/TuringLang/Turing.jl/commit/8d8416ac6c7363c6003ee6ea1fbaac26b4fc8dc3
params_only_chain.info[:varname_to_symbol]
else
# if not using Turing, then we need to construct the mapping ourselves
Dict{DynamicPPL.VarName,Symbol}([
DynamicPPL.@varname($sym) => sym for
sym in params_only_chain.name_map.parameters
])
end

num_of_chains = size(params_only_chain, 3)
# num_of_params =
num_of_samples = size(params_only_chain, 1)

predictions = []
for chain_idx in 1:num_of_chains
predictions_single_chain = []
for sample_idx in 1:num_of_samples
d_to_fix = OrderedDict{DynamicPPL.VarName,Any}()

# construct the dictionary to fix the model
for (vn, sym) in varname_to_symbol
d_to_fix[vn] = params_only_chain[sample_idx, sym, chain_idx]
end

# fix the model and sample from it
fixed_model = DynamicPPL.fix(model, d_to_fix)
predictive_sample = rand(rng, fixed_model)

# TODO: Turing version uses `Transition` and `bundle_samples` to form new chains: is it worth it to move Transition to AbstractMCMC?
push!(predictions, predictive_sample)
end
push!(predictions, predictions_single_chain)
end

return predictions
end

end
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ end
# Used here and overloaded in Turing
function getspace end

function predict end

"""
AbstractVarInfo

Expand Down
57 changes: 57 additions & 0 deletions test/predict.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
module TestPredict

using Test
using DynamicPPL
using AbstractMCMC
using MCMCChains
using Distributions
using Random
using LogDensityProblemsAD
using AdvancedHMC
using Tapir
using ForwardDiff

@model function linear_reg(x, y, σ=0.1)
β ~ Normal(0, 1)
for i in eachindex(y)
y[i] ~ Normal(β * x[i], σ)
end
end

@model function linear_reg_vec(x, y, σ=0.1)
β ~ Normal(0, 1)
return y ~ MvNormal(β .* x, σ^2 * I)
end

f(x) = 2 * x + 0.1 * randn()

Δ = 0.1
xs_train = 0:Δ:10
ys_train = f.(xs_train)
xs_test = [10 + Δ, 10 + 2 * Δ]
ys_test = f.(xs_test)

model = linear_reg(xs_train, ys_train)

m_lin_reg = linear_reg(xs_train, ys_train)
ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model))
ad_ldf = LogDensityProblemsAD.ADgradient(Val(:Tapir), ldf; safety_on=false)
chain = AbstractMCMC.sample(
ad_ldf, AdvancedHMC.NUTS(0.6), 1000; chain_type=MCMCChains.Chains, param_names=[:β]
)

DynamicPPL.predict(test_model, chain)

# LKJ example
@model demo_lkj() = x ~ LKJCholesky(2, 1.0)

model = demo_lkj()

ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.SimpleVarInfo(model))
ad_ldf = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ldf)

chain = AbstractMCMC.sample(
ad_ldf, AdvancedHMC.NUTS(0.6), 1000; chain_type=MCMCChains.Chains, param_names=[:Σ, :x]
)

end # module