Skip to content

perf: avoid JIT retrace in fit() for repeated calls#59

Open
MoAly98 wants to merge 2 commits into
mainfrom
maly/profile
Open

perf: avoid JIT retrace in fit() for repeated calls#59
MoAly98 wants to merge 2 commits into
mainfrom
maly/profile

Conversation

@MoAly98
Copy link
Copy Markdown
Owner

@MoAly98 MoAly98 commented Apr 3, 2026

  • Replace per-call NLL closure with reusable _WrappedNLL eqx.Module to avoid JIT retracing
  • Convert fixed_state leaves to JAX arrays so POI test values are treated as dynamic inputs
  • Fix _WrappedNLL docstring to accurately reflect that observation is stored as-is (not coerced to JAX arrays)
  • Update paramore source to feat/bernstein-analytical-integrate

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Apr 3, 2026

📝 Walkthrough

Walkthrough

Updated a dependency source branch reference in project configuration and refactored NLL evaluation in the fitting module by extracting a closure into a dedicated eqx.Module class while ensuring proper JAX array conversion for dynamic input handling.

Changes

Cohort / File(s) Summary
Dependency Source Update
pyproject.toml
Updated paramore git branch from "master" to "feat/bernstein-analytical-integrate"
NLL Evaluation Refactoring
src/everwillow/_src/inference/fitting.py
Introduced _WrappedNLL eqx.Module class to encapsulate NLL evaluation logic previously inlined as a closure; added jax.tree.map(jnp.asarray, fixed_state) conversion to ensure proper JAX array handling for dynamic inputs in eqx.filter_jit

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

Poem

🐰 The closure found a cozy home,
In a module of its own,
JAX arrays now dance in place,
A refactored, cleaner space,
With branches new, the code takes flight! 🌿✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 66.67% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The PR title accurately describes the main performance improvement: avoiding JIT retrace in the fit() function for repeated calls, which aligns with the primary change of replacing a closure with a reusable _WrappedNLL module.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch maly/profile

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
pyproject.toml (1)

184-184: Consider pinning paramore to an explicit commit SHA in pyproject.toml for clarity.

At Line 184, the dependency uses a branch reference (feat/bernstein-analytical-integrate). While the uv.lock file already captures a pinned commit (88dfeec0ad7427cc1e62f32ec10f40ed4fdee160) for reproducibility, adding an explicit rev parameter in pyproject.toml makes the immutable reference clear without requiring the lockfile as the source of truth.

Example:

paramore = { git = "https://github.com/maxgalli/paramore", rev = "88dfeec0ad7427cc1e62f32ec10f40ed4fdee160" }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pyproject.toml` at line 184, The pyproject.toml currently references the
paramore dependency by branch name (paramore = { git =
"https://github.com/maxgalli/paramore", branch =
"feat/bernstein-analytical-integrate" }); update that entry to pin to the exact
commit referenced in the lockfile by replacing the branch spec with a rev field
using the commit SHA (88dfeec0ad7427cc1e62f32ec10f40ed4fdee160) so the
dependency is explicitly immutable and clear without relying solely on the
lockfile.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/everwillow/_src/inference/fitting.py`:
- Around line 348-350: The current cast fixed_state = jax.tree.map(jnp.asarray,
fixed_state) blindly converts every leaf and can replace partition sentinels
(e.g., None) expected by sl.partition/sl.combine_partitions; change the mapping
to only convert actual array-like leaves and leave sentinel values untouched
(e.g., test for the partition sentinel such as None or the library's
missing-leaf marker and return it unchanged), so replace the jnp.asarray mapping
in the fixed_state construction with a conditional mapper that preserves sl
partition sentinels before returning arrays.

---

Nitpick comments:
In `@pyproject.toml`:
- Line 184: The pyproject.toml currently references the paramore dependency by
branch name (paramore = { git = "https://github.com/maxgalli/paramore", branch =
"feat/bernstein-analytical-integrate" }); update that entry to pin to the exact
commit referenced in the lockfile by replacing the branch spec with a rev field
using the commit SHA (88dfeec0ad7427cc1e62f32ec10f40ed4fdee160) so the
dependency is explicitly immutable and clear without relying solely on the
lockfile.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 209ae86b-43cc-4d38-b352-43996113dd9e

📥 Commits

Reviewing files that changed from the base of the PR and between ef7615e and c927a78.

⛔ Files ignored due to path filters (1)
  • uv.lock is excluded by !**/*.lock
📒 Files selected for processing (2)
  • pyproject.toml
  • src/everwillow/_src/inference/fitting.py

Comment on lines +348 to +350
# Convert to JAX arrays so eqx.filter_jit treats them as dynamic inputs.
fixed_state = jax.tree.map(jnp.asarray, fixed_state)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Preserve partition sentinels when casting fixed_state leaves.

On Line 349, jax.tree.map(jnp.asarray, fixed_state) applies to every leaf. If sl.partition uses None as a missing-leaf sentinel, this can break sl.combine_partitions expectations (or error) during reconstruction.

Proposed fix
-    fixed_state = jax.tree.map(jnp.asarray, fixed_state)
+    fixed_state = jax.tree.map(
+        lambda x: x if x is None else jnp.asarray(x),
+        fixed_state,
+        is_leaf=lambda x: x is None,
+    )
#!/bin/bash
set -euo pipefail

# Locate statelib implementation files
fd -i statelib src

# Inspect partition/combine conventions and sentinel handling
rg -n -C4 'def partition|def combine_partitions|notnone|is None|None' src/everwillow/_src

# Inspect fixed_state casting in fitting flow
rg -n -C3 'fixed_state\s*=|jnp\.asarray|tree\.map|tree_map' src/everwillow/_src/inference/fitting.py
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/everwillow/_src/inference/fitting.py` around lines 348 - 350, The current
cast fixed_state = jax.tree.map(jnp.asarray, fixed_state) blindly converts every
leaf and can replace partition sentinels (e.g., None) expected by
sl.partition/sl.combine_partitions; change the mapping to only convert actual
array-like leaves and leave sentinel values untouched (e.g., test for the
partition sentinel such as None or the library's missing-leaf marker and return
it unchanged), so replace the jnp.asarray mapping in the fixed_state
construction with a conditional mapper that preserves sl partition sentinels
before returning arrays.

Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Improves JAX/Equinox compilation reuse for repeated inference fits by changing how the NLL wrapper is constructed and by adjusting how fixed parameter values enter the jitted optimisation path. Also updates dependency sources to align with the new workflow.

Changes:

  • Replace per-call wrapped NLL closure with a reusable eqx.Module (_WrappedNLL) to reduce retracing in repeated fit() calls.
  • Convert fixed_state leaves to JAX arrays so fixed POI test values are treated as dynamic by eqx.filter_jit.
  • Update paramore source branch and refresh uv.lock (including switching evermore to PyPI source).

Reviewed changes

Copilot reviewed 2 out of 3 changed files in this pull request and generated 2 comments.

File Description
src/everwillow/_src/inference/fitting.py Introduces _WrappedNLL and changes fixed-state handling to reduce JIT retracing during repeated fits.
pyproject.toml Points paramore to feat/bernstein-analytical-integrate.
uv.lock Refreshes lock entries for evermore, paramore, and platform markers.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +348 to +349
# Convert to JAX arrays so eqx.filter_jit treats them as dynamic inputs.
fixed_state = jax.tree.map(jnp.asarray, fixed_state)
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed_state produced by sl.partition(...) contains None placeholders for non-fixed params. Mapping jnp.asarray over the whole tree will call jnp.asarray(None) and raise a TypeError. Preserve None leaves (e.g., map only over non-None leaves or use an is_leaf/conditional so placeholders remain None).

Suggested change
# Convert to JAX arrays so eqx.filter_jit treats them as dynamic inputs.
fixed_state = jax.tree.map(jnp.asarray, fixed_state)
# Convert to JAX arrays so eqx.filter_jit treats them as dynamic inputs,
# while preserving `None` placeholders produced by `sl.partition`.
fixed_state = jax.tree.map(
lambda x: None if x is None else jnp.asarray(x),
fixed_state,
)

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot None values are not pytree leaves in JAX — they're ignored by tree.map and tree.leaves.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right — in JAX, None is treated as an empty pytree (not a leaf), so jax.tree.map already skips None placeholders automatically. The existing jax.tree.map(jnp.asarray, fixed_state) call is correct as-is and no guard is needed.


nll_fn: tp.Callable = eqx.field(static=True)
observation: PyTree

Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _WrappedNLL docstring says the observation is treated as a dynamic input, but observation is stored as-is. If callers pass observations containing Python scalars (common in tests), eqx.filter_jit will treat those leaves as static and still retrace when observation values change. Consider converting numeric scalar leaves in observation to JAX arrays (while leaving non-array-like leaves alone) or adjust the docstring/expectations accordingly.

Suggested change
@staticmethod
def _coerce_dynamic_observation_leaf(leaf):
if isinstance(leaf, (bool, int, float, complex)):
return jnp.asarray(leaf)
return leaf
def __init__(self, nll_fn: tp.Callable, observation: PyTree):
object.__setattr__(self, "nll_fn", nll_fn)
object.__setattr__(
self,
"observation",
jax.tree_util.tree_map(
self._coerce_dynamic_observation_leaf,
observation,
),
)

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Owner Author

@MoAly98 MoAly98 Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot It would only matter if someone called fit() in a loop with different observation values as Python scalars. In our setup composed calls to fit function will always pass same observation, so never triggeringgg a retrace

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed — since observations don't change across composed fit calls in your setup, no coercion is needed. Updated the docstring in 06febc7 to say "holding a fixed observation" instead of "with observation as dynamic input" to accurately reflect the actual behaviour.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

object.__setattr__ if kind of hacky, for converting a class attribute on constructing the equinox dataclass, you can use converters: https://docs.kidger.site/equinox/api/module/advanced_fields/#converters-and-static-fields

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also, what about other data types that asarray can convert (lists)? Maybe instead of doing this we should just require that all inputs (e.g. observation or param values) are always provided as jax.Arrays ?

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pfackeldey i did not follow the copilot suggestion, so currently do not do a converstion on construction of anything. I only convert the fixed_state to an array to mitigate what the hypothesis test workflow does (fix poi value to compute the LR). I think that the suggestion to ask for jax.Arrays might be sensible, but not sure how strict we should be here? raise? coerce ourselves and raise if that fails?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My take is usually to restrict to minimal surface, i.e., only allow arrays as input (we don't do any coercing). The reason is that it's always easy to 'open up' in the future, but much harder to restrict the api surface again.

@MoAly98
Copy link
Copy Markdown
Owner Author

MoAly98 commented Apr 3, 2026

Using @maxgalli setup:

  • paramore fix takes down upper_limit wall time from 200s to 50s or so
  • everwillow fix takes down wall time to 37s on a cold run, 0.9s in hot run if jax.jit is being used.
  • In combine, upper limit computation takes: (5.58 real 2.64 user 0.59 sys)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants