Skip to content

Override optimize_marginal! for LogDensityFunctionWrapper#1395

Merged
yebai merged 1 commit into
mainfrom
fix-marginal-optimize
May 18, 2026
Merged

Override optimize_marginal! for LogDensityFunctionWrapper#1395
yebai merged 1 commit into
mainfrom
fix-marginal-optimize

Conversation

@yebai
Copy link
Copy Markdown
Member

@yebai yebai commented May 18, 2026

Fix test errors on main.

Adds a MarginalLogDensities.optimize_marginal! method specialised on LogDensityFunctionWrapper so the OptimizationProblem is rebuilt with the current non-marginalised parameters p2 on each call, rather than reusing a stale problem.

Rebuild the OptimizationProblem on each call so that updated non-marginalized
parameters are passed through to Optimization.jl.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@github-actions
Copy link
Copy Markdown
Contributor

DynamicPPL.jl documentation for PR #1395 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR1395/

@github-actions
Copy link
Copy Markdown
Contributor

Benchmarks @ c5b64ae

==================================================================================================
                                              eval                       gradient                 
                                           ----------  -------------------------------------------
Model                       dim    linked      primal     FwdDiff    RvsDiff    Mooncake    Enzyme
--------------------------------------------------------------------------------------------------
Simple assume observe         1     false     5.87 ns       10.25    1141.41       28.98      6.24
Simple assume observe         1      true     23.7 ns        2.52     313.07        7.22      1.54
Smorgasbord                 201     false      6.4 μs       67.60     122.58        6.29      8.95
Smorgasbord                 201      true     8.92 μs       64.52     120.29        5.32      5.98
Loop univariate 1k         1000     false     20.0 μs      942.90     258.98        7.19      5.89
Loop univariate 1k         1000      true     21.3 μs     1332.65     248.10        6.71      5.49
Multivariate 1k            1000     false     23.4 μs      321.67      72.21        8.59      2.93
Multivariate 1k            1000      true     27.3 μs      290.87      60.22        8.40      2.93
Loop univariate 10k       10000     false    202.0 μs    10335.64     274.31        7.29      5.80
Loop univariate 10k       10000      true    216.0 μs    10332.40     254.20        7.13      5.48
Multivariate 10k          10000     false    217.0 μs     4954.69      80.04       10.36      2.13
Multivariate 10k          10000      true    217.0 μs     5558.05      79.87       10.32      2.13
Dynamic                      15     false     1.35 μs         err      41.66       13.47     10.90
Dynamic                      10      true     1.87 μs        1.95      56.08       11.97     18.96
Submodel                      1     false     5.88 ns       10.27    1233.57       29.11      6.28
Submodel                      1      true     5.87 ns       10.18    1355.72       29.05      6.38
LDA                          12      true     22.0 μs        0.48       1.95       33.71       err
==================================================================================================

Each row times one of DynamicPPL's reference models on this PR's head. Dim is the parameter count; Linked is true when parameters have been mapped to unconstrained space. t(logdensity) is the wall-clock time for one log-density evaluation. The AD (automatic differentiation) backend columns express gradient time as a multiple of t(logdensity) — a value of 10 means computing the gradient takes 10× as long as the log-density. Lower is better throughout; err means the backend errored on that model. Compare against main below to spot regressions.

Main @ e2d4b5d
==================================================================================================
                                              eval                       gradient                 
                                           ----------  -------------------------------------------
Model                       dim    linked      primal     FwdDiff    RvsDiff    Mooncake    Enzyme
--------------------------------------------------------------------------------------------------
Simple assume observe         1     false     6.07 ns       10.61     833.84       31.61      6.22
Simple assume observe         1      true     21.6 ns        1.28     326.44        8.87      1.84
Smorgasbord                 201     false     6.23 μs       72.72     124.05        6.32      6.91
Smorgasbord                 201      true     8.94 μs       73.58     116.11        5.18      4.68
Loop univariate 1k         1000     false     18.8 μs     1133.54     275.51        8.36      6.96
Loop univariate 1k         1000      true     20.3 μs     1529.78     263.63        7.93      6.66
Multivariate 1k            1000     false     27.2 μs      415.00      61.46        7.49      1.85
Multivariate 1k            1000      true     24.7 μs      277.82      64.71        9.07      2.11
Loop univariate 10k       10000     false    175.0 μs    15679.96     347.19        9.15      7.10
Loop univariate 10k       10000      true    195.0 μs    14500.01     306.69        8.34      6.49
Multivariate 10k          10000     false    220.0 μs     5259.10      81.68       10.44      1.89
Multivariate 10k          10000      true    218.0 μs     4445.85      81.47       10.14      1.87
Dynamic                      15     false     1.45 μs         err      40.53       13.36     10.65
Dynamic                      10      true     1.96 μs        2.03      58.52       11.38     18.14
Submodel                      1     false      6.1 ns       10.67     982.89       31.97      6.15
Submodel                      1      true     6.29 ns       10.31    1120.60       30.81      5.83
LDA                          12      true     23.4 μs        0.50       2.02       31.54       err
==================================================================================================
Environment
Julia Version 1.11.9
Commit 53a02c0720c (2026-02-06 00:27 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

@codecov
Copy link
Copy Markdown

codecov Bot commented May 18, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 82.30%. Comparing base (e2d4b5d) to head (c5b64ae).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1395      +/-   ##
==========================================
+ Coverage   78.56%   82.30%   +3.73%     
==========================================
  Files          50       50              
  Lines        3522     3543      +21     
==========================================
+ Hits         2767     2916     +149     
+ Misses        755      627     -128     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@yebai yebai enabled auto-merge May 18, 2026 14:38
@yebai yebai added this pull request to the merge queue May 18, 2026
Merged via the queue into main with commit 90a74c3 May 18, 2026
24 checks passed
@yebai yebai deleted the fix-marginal-optimize branch May 18, 2026 15:03
sunxd3 added a commit that referenced this pull request May 18, 2026
Bumps the patch version following #1395
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant