Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/metrax/nnx/nnx_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""A wrapper for metrax metrics to be used with NNX."""

from flax import nnx
import inspect


class NnxWrapper(nnx.metrics.Metric):
Expand All @@ -27,7 +28,25 @@ def reset(self) -> None:
self.clu_metric = self.clu_metric.empty()

def update(self, **kwargs) -> None:
other_clu_metric = self.clu_metric.from_model_output(**kwargs)
# Filter kwargs to only those accepted by from_model_output
sig = inspect.signature(self.clu_metric.from_model_output)
params = sig.parameters
has_var_keyword = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())
if has_var_keyword:
# Method accepts **kwargs, pass everything
filtered_kwargs = kwargs
else:
accepted = {
name
for name, p in params.items()
if p.kind
in (
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
)
}
filtered_kwargs = {k: v for k, v in kwargs.items() if k in accepted}
other_clu_metric = self.clu_metric.from_model_output(**filtered_kwargs)
self.clu_metric = self.clu_metric.merge(other_clu_metric)

def compute(self):
Expand Down
25 changes: 25 additions & 0 deletions src/metrax/nnx/nnx_wrapper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from absl.testing import absltest
from absl.testing import parameterized
from flax import nnx
import jax.numpy as jnp
import metrax
import metrax.nnx
Expand Down Expand Up @@ -102,6 +103,30 @@ def test_metric_update_and_compute(self, y_true, y_pred, sample_weights):
rtol=rtol,
atol=atol,
)
def test_update_ignores_extra_kwargs(self):
"""Tests that NnxWrapper ignores extra kwargs not accepted by the metric."""
nnx_metric = metrax.nnx.MSE()
# Should not raise even though 'extra_arg' is not a valid parameter.
nnx_metric.update(
predictions=jnp.array([1.0, 2.0, 3.0]),
labels=jnp.array([1.5, 2.5, 3.5]),
extra_arg=jnp.array([0.0]),
)
np.testing.assert_allclose(nnx_metric.compute(), 0.25, rtol=1e-5)

def test_multi_metric_with_different_kwargs(self):
"""Tests that NnxWrapper works with MultiMetric passing mixed kwargs."""
metrics = nnx.MultiMetric(
loss=metrax.nnx.Average(),
accuracy=metrax.nnx.Accuracy(),
)
metrics.update(
values=jnp.array([0.5, 0.2]),
predictions=jnp.array([0, 1]),
labels=jnp.array([0, 1]),
)
np.testing.assert_allclose(metrics.compute()['loss'], 0.35, rtol=1e-5)
np.testing.assert_allclose(metrics.compute()['accuracy'], 1.0, rtol=1e-5)


if __name__ == '__main__':
Expand Down