
improve implementation of predict(...; include_all) #1042

@penelopeysm


#984 uses `init!!` to implement `predict`. However, its handling of `include_all=false` seems wasteful: it first constructs a chain containing all parameters (including the ones we don't want) and only then subsets that chain. It would be more sensible to filter the dictionary of varname => value pairs inside the loop, on each iteration, so that those variables never end up in the chain in the first place.

```julia
    predictive_samples = map(iters) do (sample_idx, chain_idx)
        # Extract values from the chain
        values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
        # Resample any variables that are not present in `values_dict`
        _, varinfo = DynamicPPL.init!!(
            rng,
            model,
            varinfo,
            DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
        )
        # Collect every (varname, value) leaf for this sample, including the
        # parameters that were already present in the input chain
        vals = DynamicPPL.values_as_in_model(model, false, varinfo)
        varname_vals = mapreduce(
            collect,
            vcat,
            map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)),
        )
        return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo))
    end
    # Concatenate the per-chain predictions into a single Chains object
    chain_result = reduce(
        MCMCChains.chainscat,
        [
            _predictive_samples_to_chains(predictive_samples[:, chain_idx]) for
            chain_idx in 1:size(predictive_samples, 2)
        ],
    )
    # With include_all=false, drop the parameters that were already in the
    # input chain, but only after the full chain has been built
    parameter_names = if include_all
        MCMCChains.names(chain_result, :parameters)
    else
        filter(
            k -> !(k in MCMCChains.names(parameter_only_chain, :parameters)),
            names(chain_result, :parameters),
        )
    end
    return chain_result[parameter_names]
end
```

Not making this change in #984 to avoid complicating matters.
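A minimal, self-contained sketch of the proposed idea, using hypothetical stand-ins rather than the real DynamicPPL/MCMCChains types: variable names are `Symbol`s, and each iteration yields a vector of `name => value` pairs (playing the role of `varname_vals` in the loop above). `filter_in_loop` drops names already present in the input chain inside the per-sample `map`, while `subset_afterwards` mimics the current build-then-subset approach. Both function names are made up for illustration and are not part of any package.

```julia
# Hypothetical stand-ins: names are Symbols, and each sample is a
# Vector of name => value pairs (analogous to `varname_vals`).
const Sample = Vector{Pair{Symbol,Float64}}

# Proposed approach (sketch): filter inside the loop, so excluded names
# are never collected into the result in the first place.
function filter_in_loop(samples::Vector{Sample}, input_names::Set{Symbol};
                        include_all::Bool)
    map(samples) do pairs
        include_all ? pairs : filter(p -> !(first(p) in input_names), pairs)
    end
end

# Current approach (sketch): keep everything, then compute the names to
# retain and subset every sample afterwards.
function subset_afterwards(samples::Vector{Sample}, input_names::Set{Symbol};
                           include_all::Bool)
    all_names = unique(first.(reduce(vcat, samples)))
    keep = include_all ? all_names : filter(n -> !(n in input_names), all_names)
    return [filter(p -> first(p) in keep, pairs) for pairs in samples]
end
```

With `samples = [[:x => 1.0, :y => 2.5], [:x => 1.1, :y => 2.7]]` and `input_names = Set([:x])`, both functions return `[[:y => 2.5], [:y => 2.7]]` for `include_all=false`; the in-loop version simply never carries the `:x` entries into the collected result.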
