From 81f7bba0f4abea92b4a13a7355d7019edc848238 Mon Sep 17 00:00:00 2001 From: Nikola Savic Date: Fri, 13 Feb 2026 01:00:54 +0100 Subject: [PATCH] Filter extra kwargs in NnxWrapper.update to support MultiMetric --- src/metrax/nnx/nnx_wrapper.py | 21 ++++++++++++++++++++- src/metrax/nnx/nnx_wrapper_test.py | 25 +++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/src/metrax/nnx/nnx_wrapper.py b/src/metrax/nnx/nnx_wrapper.py index 4382211..78edf5c 100644 --- a/src/metrax/nnx/nnx_wrapper.py +++ b/src/metrax/nnx/nnx_wrapper.py @@ -15,6 +15,7 @@ """A wrapper for metrax metrics to be used with NNX.""" from flax import nnx +import inspect class NnxWrapper(nnx.metrics.Metric): @@ -27,7 +28,25 @@ def reset(self) -> None: self.clu_metric = self.clu_metric.empty() def update(self, **kwargs) -> None: - other_clu_metric = self.clu_metric.from_model_output(**kwargs) + # Filter kwargs to only those accepted by from_model_output + sig = inspect.signature(self.clu_metric.from_model_output) + params = sig.parameters + has_var_keyword = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()) + if has_var_keyword: + # Method accepts **kwargs, pass everything + filtered_kwargs = kwargs + else: + accepted = { + name + for name, p in params.items() + if p.kind + in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + } + filtered_kwargs = {k: v for k, v in kwargs.items() if k in accepted} + other_clu_metric = self.clu_metric.from_model_output(**filtered_kwargs) self.clu_metric = self.clu_metric.merge(other_clu_metric) def compute(self): diff --git a/src/metrax/nnx/nnx_wrapper_test.py b/src/metrax/nnx/nnx_wrapper_test.py index d5a48b5..a9d451a 100644 --- a/src/metrax/nnx/nnx_wrapper_test.py +++ b/src/metrax/nnx/nnx_wrapper_test.py @@ -16,6 +16,7 @@ from absl.testing import absltest from absl.testing import parameterized +from flax import nnx import jax.numpy as jnp import metrax import metrax.nnx @@ -102,6 +103,30 @@ def test_metric_update_and_compute(self, y_true, y_pred, sample_weights): rtol=rtol, atol=atol, ) + def test_update_ignores_extra_kwargs(self): + """Tests that NnxWrapper ignores extra kwargs not accepted by the metric.""" + nnx_metric = metrax.nnx.MSE() + # Should not raise even though 'extra_arg' is not a valid parameter. + nnx_metric.update( + predictions=jnp.array([1.0, 2.0, 3.0]), + labels=jnp.array([1.5, 2.5, 3.5]), + extra_arg=jnp.array([0.0]), + ) + np.testing.assert_allclose(nnx_metric.compute(), 0.25, rtol=1e-5) + + def test_multi_metric_with_different_kwargs(self): + """Tests that NnxWrapper works with MultiMetric passing mixed kwargs.""" + metrics = nnx.MultiMetric( + loss=metrax.nnx.Average(), + accuracy=metrax.nnx.Accuracy(), + ) + metrics.update( + values=jnp.array([0.5, 0.2]), + predictions=jnp.array([0, 1]), + labels=jnp.array([0, 1]), + ) + np.testing.assert_allclose(metrics.compute()['loss'], 0.35, rtol=1e-5) + np.testing.assert_allclose(metrics.compute()['accuracy'], 1.0, rtol=1e-5) if __name__ == '__main__':