perf: avoid JIT retrace in fit() for repeated calls #59
Conversation
📝 Walkthrough

Updated a dependency source branch reference in the project configuration and refactored NLL evaluation in the fitting module by extracting a closure into a dedicated `eqx.Module`.

Changes
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~12 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 1
🧹 Nitpick comments (1)
pyproject.toml (1)
184-184: Consider pinning `paramore` to an explicit commit SHA in `pyproject.toml` for clarity.

At line 184, the dependency uses a branch reference (`feat/bernstein-analytical-integrate`). While the `uv.lock` file already captures a pinned commit (`88dfeec0ad7427cc1e62f32ec10f40ed4fdee160`) for reproducibility, adding an explicit `rev` parameter in `pyproject.toml` makes the immutable reference clear without requiring the lockfile as the source of truth.

Example:

```toml
paramore = { git = "https://github.com/maxgalli/paramore", rev = "88dfeec0ad7427cc1e62f32ec10f40ed4fdee160" }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pyproject.toml` at line 184, The pyproject.toml currently references the paramore dependency by branch name (paramore = { git = "https://github.com/maxgalli/paramore", branch = "feat/bernstein-analytical-integrate" }); update that entry to pin to the exact commit referenced in the lockfile by replacing the branch spec with a rev field using the commit SHA (88dfeec0ad7427cc1e62f32ec10f40ed4fdee160) so the dependency is explicitly immutable and clear without relying solely on the lockfile.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 209ae86b-43cc-4d38-b352-43996113dd9e
⛔ Files ignored due to path filters (1)
`uv.lock` is excluded by `!**/*.lock`
📒 Files selected for processing (2)
- `pyproject.toml`
- `src/everwillow/_src/inference/fitting.py`
```python
# Convert to JAX arrays so eqx.filter_jit treats them as dynamic inputs.
fixed_state = jax.tree.map(jnp.asarray, fixed_state)
```
Preserve partition sentinels when casting fixed_state leaves.
On line 349, `jax.tree.map(jnp.asarray, fixed_state)` applies `jnp.asarray` to every leaf. If `sl.partition` uses `None` as a missing-leaf sentinel, this can break `sl.combine_partitions` expectations (or error) during reconstruction.
Proposed fix:

```diff
- fixed_state = jax.tree.map(jnp.asarray, fixed_state)
+ fixed_state = jax.tree.map(
+     lambda x: x if x is None else jnp.asarray(x),
+     fixed_state,
+     is_leaf=lambda x: x is None,
+ )
```

Verification script:

```bash
#!/bin/bash
set -euo pipefail
# Locate statelib implementation files
fd -i statelib src
# Inspect partition/combine conventions and sentinel handling
rg -n -C4 'def partition|def combine_partitions|notnone|is None|None' src/everwillow/_src
# Inspect fixed_state casting in fitting flow
rg -n -C3 'fixed_state\s*=|jnp\.asarray|tree\.map|tree_map' src/everwillow/_src/inference/fitting.py
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/everwillow/_src/inference/fitting.py` around lines 348 - 350, The current
cast fixed_state = jax.tree.map(jnp.asarray, fixed_state) blindly converts every
leaf and can replace partition sentinels (e.g., None) expected by
sl.partition/sl.combine_partitions; change the mapping to only convert actual
array-like leaves and leave sentinel values untouched (e.g., test for the
partition sentinel such as None or the library's missing-leaf marker and return
it unchanged), so replace the jnp.asarray mapping in the fixed_state
construction with a conditional mapper that preserves sl partition sentinels
before returning arrays.
Pull request overview
Improves JAX/Equinox compilation reuse for repeated inference fits by changing how the NLL wrapper is constructed and by adjusting how fixed parameter values enter the jitted optimisation path. Also updates dependency sources to align with the new workflow.
Changes:
- Replace the per-call wrapped NLL closure with a reusable `eqx.Module` (`_WrappedNLL`) to reduce retracing in repeated `fit()` calls.
- Convert `fixed_state` leaves to JAX arrays so fixed POI test values are treated as dynamic by `eqx.filter_jit`.
- Update the `paramore` source branch and refresh `uv.lock` (including switching `evermore` to a PyPI source).
Reviewed changes

Copilot reviewed 2 out of 3 changed files in this pull request and generated 2 comments.

| File | Description |
|---|---|
| `src/everwillow/_src/inference/fitting.py` | Introduces `_WrappedNLL` and changes fixed-state handling to reduce JIT retracing during repeated fits. |
| `pyproject.toml` | Points `paramore` to `feat/bernstein-analytical-integrate`. |
| `uv.lock` | Refreshes lock entries for `evermore`, `paramore`, and platform markers. |
```python
# Convert to JAX arrays so eqx.filter_jit treats them as dynamic inputs.
fixed_state = jax.tree.map(jnp.asarray, fixed_state)
```
`fixed_state` produced by `sl.partition(...)` contains `None` placeholders for non-fixed params. Mapping `jnp.asarray` over the whole tree will call `jnp.asarray(None)` and raise a `TypeError`. Preserve `None` leaves (e.g., map only over non-`None` leaves or use an `is_leaf`/conditional so placeholders remain `None`).
Suggested change:

```diff
 # Convert to JAX arrays so eqx.filter_jit treats them as dynamic inputs.
-fixed_state = jax.tree.map(jnp.asarray, fixed_state)
+# Convert to JAX arrays so eqx.filter_jit treats them as dynamic inputs,
+# while preserving `None` placeholders produced by `sl.partition`.
+fixed_state = jax.tree.map(
+    lambda x: None if x is None else jnp.asarray(x),
+    fixed_state,
+)
```
@copilot None values are not pytree leaves in JAX — they're ignored by tree.map and tree.leaves.
You're right — in JAX, `None` is treated as an empty pytree (not a leaf), so `jax.tree.map` already skips `None` placeholders automatically. The existing `jax.tree.map(jnp.asarray, fixed_state)` call is correct as-is and no guard is needed.
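This is easy to verify directly (a minimal standalone sketch, not project code): `None` is an empty pytree node, so `jax.tree.map` never passes it to the mapped function and the tree structure is preserved on output.

```python
import jax
import jax.numpy as jnp

# A partition-like tree: one fixed value, one `None` placeholder.
fixed_state = {"mu": 1.0, "sigma": None}

# `None` is an empty pytree node, not a leaf: jnp.asarray is never called on it.
converted = jax.tree.map(jnp.asarray, fixed_state)

assert isinstance(converted["mu"], jax.Array)
assert converted["sigma"] is None
# Only the non-None entry counts as a leaf.
assert len(jax.tree.leaves(fixed_state)) == 1
```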
```python
nll_fn: tp.Callable = eqx.field(static=True)
observation: PyTree
```
The `_WrappedNLL` docstring says the observation is treated as a dynamic input, but `observation` is stored as-is. If callers pass observations containing Python scalars (common in tests), `eqx.filter_jit` will treat those leaves as static and still retrace when observation values change. Consider converting numeric scalar leaves in `observation` to JAX arrays (while leaving non-array-like leaves alone) or adjust the docstring/expectations accordingly.
```python
@staticmethod
def _coerce_dynamic_observation_leaf(leaf):
    if isinstance(leaf, (bool, int, float, complex)):
        return jnp.asarray(leaf)
    return leaf

def __init__(self, nll_fn: tp.Callable, observation: PyTree):
    object.__setattr__(self, "nll_fn", nll_fn)
    object.__setattr__(
        self,
        "observation",
        jax.tree_util.tree_map(
            self._coerce_dynamic_observation_leaf,
            observation,
        ),
    )
```
@copilot It would only matter if someone called `fit()` in a loop with different observation values as Python scalars. In our setup, composed calls to the fit function always pass the same observation, so a retrace is never triggered.
Agreed — since observations don't change across composed fit calls in your setup, no coercion is needed. Updated the docstring in 06febc7 to say "holding a fixed observation" instead of "with observation as dynamic input" to accurately reflect the actual behaviour.
`object.__setattr__` is kind of hacky; for converting a class attribute when constructing the equinox dataclass, you can use converters: https://docs.kidger.site/equinox/api/module/advanced_fields/#converters-and-static-fields
Also, what about other data types that `asarray` can convert (lists)? Maybe instead of doing this we should just require that all inputs (e.g. observation or param values) are always provided as `jax.Array`s?
@pfackeldey I did not follow the Copilot suggestion, so currently I do not do a conversion on construction of anything. I only convert the `fixed_state` to an array to mitigate what the hypothesis test workflow does (fix the POI value to compute the LR). I think the suggestion to require `jax.Array`s might be sensible, but I'm not sure how strict we should be here: raise? coerce ourselves and raise if that fails?
My take is usually to restrict to a minimal surface, i.e., only allow arrays as input (we don't do any coercing). The reason is that it's always easy to "open up" in the future, but much harder to restrict the API surface again.
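A minimal version of that restriction could look like this (a hypothetical `require_arrays` helper, not from the codebase): reject anything in the input pytree that is not already a `jax.Array`, with no coercion.

```python
import jax
import jax.numpy as jnp

def require_arrays(tree, name="input"):
    """Raise TypeError if any leaf of `tree` is not a jax.Array (no coercion)."""
    def check(leaf):
        if not isinstance(leaf, jax.Array):
            raise TypeError(
                f"{name} leaves must be jax.Array, got {type(leaf).__name__}; "
                "convert with jnp.asarray(...) at the call site."
            )
        return leaf
    return jax.tree.map(check, tree)

require_arrays({"obs": jnp.ones(3)})  # accepted: leaf is a jax.Array
try:
    # A list is a pytree node, so its float leaves get rejected.
    require_arrays({"obs": [1.0, 2.0]})
    rejected = False
except TypeError:
    rejected = True
```

Calling this at the top of `fit()` would keep the API surface minimal while giving callers a clear error instead of a silent coercion.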
Using @maxgalli setup:
Agent-Logs-Url: https://github.com/MoAly98/everwillow/sessions/dd1e1607-a312-4031-afbc-8f27f6fd132f Co-authored-by: MoAly98 <89147478+MoAly98@users.noreply.github.com>
- Extract the wrapped NLL into a `_WrappedNLL` `eqx.Module` to avoid JIT retracing
- Convert `fixed_state` leaves to JAX arrays so POI test values are treated as dynamic inputs
- Update the `_WrappedNLL` docstring to accurately reflect that observation is stored as-is (not coerced to JAX arrays)
- Point the `paramore` source to `feat/bernstein-analytical-integrate`