diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index 905ee168b..714a3be2d 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -14,6 +14,20 @@ function (lw::LogDensityFunctionWrapper)(x, _) return LogDensityProblems.logdensity(lw.ldf, x) end +function MarginalLogDensities.optimize_marginal!( + mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper}, p2 +) + w0 = mld.u[mld.iw] + # Rebuild the optimization problem so that the current non-marginalized + # parameters are used by Optimization.jl. + prob = MarginalLogDensities.Optimization.OptimizationProblem(mld.f_opt, w0, p2) + sol = MarginalLogDensities.Optimization.solve(prob, mld.method.solver) + wopt = sol.u::typeof(w0) + objective = sol.objective::eltype(w0) + mld.u[mld.iw] .= wopt + return wopt, objective +end + """ DynamicPPL.marginalize( model::DynamicPPL.Model,