def _convert_logits_to_probabilities(
    predictions: jax.Array, from_logits: bool
) -> jax.Array:
  """Converts binary-classification logits to probabilities when requested.

  The metrics that call this helper compare every element of `predictions`
  against a scalar threshold and against a 0/1 label of the same shape, so
  each element is treated as an independent binary score. The matching
  logit-to-probability map is therefore the elementwise sigmoid. A softmax
  over the last axis would instead normalise across the examples of a batch
  (and would turn any batch of size one into a probability of exactly 1.0).

  Args:
    predictions: JAX array of predicted values. Interpreted as raw logits
      (one per example) when `from_logits` is True, otherwise as
      probabilities.
    from_logits: Whether `predictions` holds logits that still need to be
      mapped to probabilities.

  Returns:
    `sigmoid(predictions)` with the same shape as `predictions` when
    `from_logits` is True; otherwise `predictions` itself, unchanged.
  """
  if not from_logits:
    # Already probabilities: hand the input back untouched.
    return predictions
  # Elementwise map, so the result does not depend on which other examples
  # happen to share the batch.
  return jax.nn.sigmoid(predictions)
""" + + if from_logits: + predictions = jax.nn.softmax(predictions, axis=-1) + + + correct = predictions == labels count = jnp.ones_like(labels, dtype=jnp.int32) if sample_weights is not None: @@ -149,6 +173,7 @@ def from_model_output( predictions: jax.Array, labels: jax.Array, threshold: float = 0.5, + from_logits: bool = False, ) -> 'Precision': """Updates the metric. @@ -166,7 +191,10 @@ def from_model_output( ValueError: If type of `labels` is wrong or the shapes of `predictions` and `labels` are incompatible. """ + predictions = _convert_logits_to_probabilities(predictions, from_logits) + predictions = jnp.where(predictions >= threshold, 1, 0) + true_positives = jnp.sum((predictions == 1) & (labels == 1)) false_positives = jnp.sum((predictions == 1) & (labels == 0)) @@ -219,7 +247,7 @@ def empty(cls) -> 'Recall': @classmethod def from_model_output( - cls, predictions: jax.Array, labels: jax.Array, threshold: float = 0.5 + cls, predictions: jax.Array, labels: jax.Array, threshold: float = 0.5, from_logits: bool = False ) -> 'Recall': """Updates the metric. @@ -237,6 +265,8 @@ def from_model_output( ValueError: If type of `labels` is wrong or the shapes of `predictions` and `labels` are incompatible. """ + predictions = _convert_logits_to_probabilities(predictions, from_logits) + predictions = jnp.where(predictions >= threshold, 1, 0) true_positives = jnp.sum((predictions == 1) & (labels == 1)) false_negatives = jnp.sum((predictions == 0) & (labels == 1)) @@ -325,6 +355,7 @@ def from_model_output( labels: jax.Array, sample_weights: jax.Array | None = None, num_thresholds: int = 200, + from_logits: bool = False, ) -> 'AUCPR': """Updates the metric. @@ -345,6 +376,8 @@ def from_model_output( ValueError: If type of `labels` is wrong or the shapes of `predictions` and `labels` are incompatible. 
""" + predictions = _convert_logits_to_probabilities(predictions, from_logits) + pred_is_pos = jnp.greater( predictions, _default_threshold(num_thresholds=num_thresholds)[..., None], @@ -513,6 +546,7 @@ def from_model_output( labels: jax.Array, sample_weights: jax.Array | None = None, num_thresholds: int = 200, + from_logits: bool = False, ) -> 'AUCROC': """Updates the metric. @@ -533,6 +567,8 @@ def from_model_output( ValueError: If type of `labels` is wrong or the shapes of `predictions` and `labels` are incompatible. """ + predictions = _convert_logits_to_probabilities(predictions, from_logits) + pred_is_pos = jnp.greater( predictions, _default_threshold(num_thresholds=num_thresholds)[..., None], @@ -622,7 +658,8 @@ def from_model_output( predictions: jax.Array, labels: jax.Array, beta = beta, - threshold = 0.5,) -> 'FBetaScore': + threshold = 0.5, + from_logits : bool = False) -> 'FBetaScore': """Updates the metric. Note: When only predictions and labels are given, the score calculated is the F1 score if the FBetaScore beta value has not been previously modified. 
@@ -656,9 +693,13 @@ def from_model_output( if threshold < 0.0 or threshold > 1.0: raise ValueError('The "Threshold" value must be between 0 and 1.') + # If the predictions are logits, convert them to probabilities + print(f"labels.shape: {labels.shape}, predictions.shape: {predictions.shape}, from_logits: {from_logits}") + predictions = _convert_logits_to_probabilities(predictions, from_logits) + # Modify predictions with the given threshold value predictions = jnp.where(predictions >= threshold, 1, 0) - + # Assign the true_positive, false_positive, and false_negative their values """ We are calculating these values manually instead of using Metrax's diff --git a/src/metrax/classification_metrics_test.py b/src/metrax/classification_metrics_test.py index 2e54451..58bddef 100644 --- a/src/metrax/classification_metrics_test.py +++ b/src/metrax/classification_metrics_test.py @@ -15,6 +15,8 @@ """Tests for metrax classification metrics.""" import os + +import jax os.environ['KERAS_BACKEND'] = 'jax' from absl.testing import absltest @@ -34,6 +36,7 @@ ).astype(np.float32) OUTPUT_PREDS = np.random.uniform(size=(BATCHES, BATCH_SIZE)) OUTPUT_PREDS_F16 = OUTPUT_PREDS.astype(jnp.float16) +OUTPUT_LOGITS_F16 = np.random.randn(BATCHES, BATCH_SIZE).astype(jnp.float16) OUTPUT_PREDS_F32 = OUTPUT_PREDS.astype(jnp.float32) OUTPUT_PREDS_BF16 = OUTPUT_PREDS.astype(jnp.bfloat16) OUTPUT_LABELS_BS1 = np.random.randint( @@ -92,9 +95,13 @@ def test_fbeta_empty(self): ('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, SAMPLE_WEIGHTS), ('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, SAMPLE_WEIGHTS), ('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None), + ('batch_size_logits_f16', OUTPUT_LABELS, OUTPUT_LOGITS_F16, SAMPLE_WEIGHTS,True), ) - def test_accuracy(self, y_true, y_pred, sample_weights): + def test_accuracy(self, y_true, y_pred, sample_weights, from_logits=False): """Test that `Accuracy` metric computes correct values.""" + + print(f"y_true.shape: {y_true}, y_pred.shape: 
{y_pred}, sample_weights: {sample_weights}, from_logits: {from_logits}") + if sample_weights is None: sample_weights = np.ones_like(y_true) metrax_accuracy = metrax.Accuracy.empty() @@ -104,6 +111,7 @@ def test_accuracy(self, y_true, y_pred, sample_weights): predictions=logits, labels=labels, sample_weights=weights, + from_logits=from_logits, ) metrax_accuracy = metrax_accuracy.merge(update) keras_accuracy.update_state(labels, logits, weights) @@ -120,6 +128,7 @@ def test_accuracy(self, y_true, y_pred, sample_weights): @parameterized.named_parameters( ('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.5), + ('basic_f16_logits', OUTPUT_LABELS, OUTPUT_LOGITS_F16, 0.5, True), ('high_threshold_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.7), ('low_threshold_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.1), ('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.5), @@ -130,12 +139,17 @@ def test_accuracy(self, y_true, y_pred, sample_weights): ('low_threshold_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.1), ('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5), ) - def test_precision(self, y_true, y_pred, threshold): + def test_precision(self, y_true, y_pred, threshold,from_logits=False): """Test that `Precision` metric computes correct values.""" - y_true = y_true.reshape((-1,)) - y_pred = jnp.where(y_pred.reshape((-1,)) >= threshold, 1, 0) + y_true_keras = y_true.reshape((-1,)) + if from_logits: + probs = jax.nn.softmax(y_pred, axis=-1) + y_pred_keras = jnp.where(probs.reshape((-1,)) >= threshold, 1, 0) + else: + y_pred_keras = jnp.where(y_pred.reshape((-1,)) >= threshold, 1, 0) + keras_precision = keras.metrics.Precision(thresholds=threshold) - keras_precision.update_state(y_true, y_pred) + keras_precision.update_state(y_true_keras, y_pred_keras) expected = keras_precision.result() metric = None @@ -144,6 +158,7 @@ def test_precision(self, y_true, y_pred, threshold): predictions=logits, labels=labels, threshold=threshold, + from_logits=from_logits, ) metric = update if metric is 
None else metric.merge(update) @@ -161,6 +176,7 @@ def test_precision(self, y_true, y_pred, threshold): ('basic', OUTPUT_LABELS, OUTPUT_PREDS, 0.5), ('high_threshold', OUTPUT_LABELS, OUTPUT_PREDS, 0.7), ('low_threshold', OUTPUT_LABELS, OUTPUT_PREDS, 0.1), + ('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.5,True), ('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.5), ('high_threshold_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.7), ('low_threshold_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.1), @@ -169,12 +185,18 @@ def test_precision(self, y_true, y_pred, threshold): ('low_threshold_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.1), ('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5), ) - def test_recall(self, y_true, y_pred, threshold): + def test_recall(self, y_true, y_pred, threshold, from_logits=False): """Test that `Recall` metric computes correct values.""" - y_true = y_true.reshape((-1,)) - y_pred = jnp.where(y_pred.reshape((-1,)) >= threshold, 1, 0) + + y_true_keras = y_true.reshape((-1,)) + if from_logits: + probs = jax.nn.softmax(y_pred, axis=-1) + y_pred_keras = jnp.where(probs.reshape((-1,)) >= threshold, 1, 0) + else: + y_pred_keras = jnp.where(y_pred.reshape((-1,)) >= threshold, 1, 0) + keras_recall = keras.metrics.Recall(thresholds=threshold) - keras_recall.update_state(y_true, y_pred) + keras_recall.update_state(y_true_keras, y_pred_keras) expected = keras_recall.result() metric = None @@ -183,6 +205,7 @@ def test_recall(self, y_true, y_pred, threshold): predictions=logits, labels=labels, threshold=threshold, + from_logits=from_logits, ) metric = update if metric is None else metric.merge(update) @@ -193,19 +216,21 @@ def test_recall(self, y_true, y_pred, threshold): @parameterized.product( inputs=( - (OUTPUT_LABELS, OUTPUT_PREDS, None), - (OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None), - (OUTPUT_LABELS, OUTPUT_PREDS, SAMPLE_WEIGHTS), + (OUTPUT_LABELS, OUTPUT_PREDS, None, False), + (OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None, False), + (OUTPUT_LABELS, 
OUTPUT_PREDS, SAMPLE_WEIGHTS, False), + (OUTPUT_LABELS, OUTPUT_LOGITS_F16, SAMPLE_WEIGHTS, True), ), dtype=( jnp.float16, jnp.float32, jnp.bfloat16, + jnp.bfloat16 ), ) - def test_aucpr(self, inputs, dtype): + def test_aucpr(self, inputs, dtype, from_logits=False): """Test that `AUC-PR` Metric computes correct values.""" - y_true, y_pred, sample_weights = inputs + y_true, y_pred, sample_weights, from_logits = inputs y_true = y_true.astype(dtype) y_pred = y_pred.astype(dtype) if sample_weights is None: @@ -217,10 +242,13 @@ def test_aucpr(self, inputs, dtype): predictions=logits, labels=labels, sample_weights=weights, + from_logits=from_logits, ) metric = update if metric is None else metric.merge(update) keras_aucpr = keras.metrics.AUC(curve='PR') + if from_logits: + y_pred = jax.nn.softmax(y_pred, axis=-1) for labels, logits, weights in zip(y_true, y_pred, sample_weights): keras_aucpr.update_state(labels, logits, sample_weight=weights) expected = keras_aucpr.result() @@ -234,19 +262,21 @@ def test_aucpr(self, inputs, dtype): @parameterized.product( inputs=( - (OUTPUT_LABELS, OUTPUT_PREDS, None), - (OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None), - (OUTPUT_LABELS, OUTPUT_PREDS, SAMPLE_WEIGHTS), + (OUTPUT_LABELS, OUTPUT_PREDS, None, False), + (OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None, False), + (OUTPUT_LABELS, OUTPUT_PREDS, SAMPLE_WEIGHTS, False), + (OUTPUT_LABELS, OUTPUT_LOGITS_F16, SAMPLE_WEIGHTS, True) ), dtype=( jnp.float16, jnp.float32, jnp.bfloat16, + jnp.bfloat16 ), ) - def test_aucroc(self, inputs, dtype): + def test_aucroc(self, inputs, dtype, from_logits=False): """Test that `AUC-ROC` Metric computes correct values.""" - y_true, y_pred, sample_weights = inputs + y_true, y_pred, sample_weights,from_logits = inputs y_true = y_true.astype(dtype) y_pred = y_pred.astype(dtype) if sample_weights is None: @@ -258,10 +288,13 @@ def test_aucroc(self, inputs, dtype): predictions=logits, labels=labels, sample_weights=weights, + from_logits=from_logits, # AUCROC 
typically expects probabilities. ) metric = update if metric is None else metric.merge(update) keras_aucroc = keras.metrics.AUC(curve='ROC') + if from_logits: + y_pred = jax.nn.softmax(y_pred, axis=-1) for labels, logits, weights in zip(y_true, y_pred, sample_weights): keras_aucroc.update_state(labels, logits, sample_weight=weights) expected = keras_aucroc.result() @@ -286,16 +319,26 @@ def test_aucroc(self, inputs, dtype): ('low_threshold_bf16_beta_2.0', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.1, 2.0), ('low_threshold_f16_beta_3.0', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.1, 3.0), ('basic_bf16_beta_3.0', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.5, 3.0), + ('batch_size_one_logits_beta_3.0', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5, 3.0, True), ) - def test_fbetascore(self, y_true, y_pred, threshold, beta): + def test_fbetascore(self, y_true, y_pred, threshold, beta, from_logits=False): + + print(f"y_true.shape: {y_true.shape}, y_pred.shape: {y_pred.shape}, threshold: {threshold}, beta: {beta}") # Define the Keras FBeta class to be tested against + keras_fbeta = keras.metrics.FBetaScore(beta=beta, threshold=threshold) - keras_fbeta.update_state(y_true, y_pred) + if from_logits: + y_pred_keras = jax.nn.softmax(y_pred, axis=-1) + keras_fbeta.update_state(y_true, y_pred_keras) + + else: + keras_fbeta.update_state(y_true, y_pred) + expected = keras_fbeta.result() # Calculate the F-beta score using the metrax variant metric = metrax.FBetaScore - metric = metric.from_model_output(y_pred, y_true, beta, threshold) + metric = metric.from_model_output(y_pred, y_true, beta, threshold, from_logits=from_logits) # Use lower tolerance for lower precision dtypes. 
rtol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-5 From b543d72560b2acd196b8367e3d5ed225ee498529 Mon Sep 17 00:00:00 2001 From: Pavan Tummala Date: Wed, 20 Aug 2025 04:23:51 +0530 Subject: [PATCH 2/2] ruff fixed --- src/metrax/classification_metrics.py | 20 ++++++++++---------- src/metrax/classification_metrics_test.py | 12 +++--------- 2 files changed, 13 insertions(+), 19 deletions(-) diff --git a/src/metrax/classification_metrics.py b/src/metrax/classification_metrics.py index 5025e5c..973072f 100644 --- a/src/metrax/classification_metrics.py +++ b/src/metrax/classification_metrics.py @@ -45,19 +45,19 @@ def _default_threshold(num_thresholds: int) -> jax.Array: def _convert_logits_to_probabilities( predictions: jax.Array, from_logits: bool )-> jax.Array: - """Converts logits to probabilities if `from_logits` is True. + """Converts logits to probabilities if `from_logits` is True Args: predictions: JAX array of predicted values, expected to be logits if `from_logits` is True. from_logits: Boolean indicating whether `predictions` are logits. Returns: JAX array of probabilities if `from_logits` is True, otherwise returns `predictions` unchanged. - """ - print(f"predictions_before.shape: {predictions.shape}, from_logits: {from_logits}") + """ + if from_logits: predictions = jax.nn.softmax(predictions, axis=-1) # Assuming binary classification, take the positive class probability. 
- - print(f"predictions.shape: {predictions.shape}, from_logits: {from_logits}") + + return predictions @flax.struct.dataclass @@ -120,9 +120,9 @@ def from_model_output( if from_logits: predictions = jax.nn.softmax(predictions, axis=-1) - - + + correct = predictions == labels count = jnp.ones_like(labels, dtype=jnp.int32) if sample_weights is not None: @@ -694,12 +694,12 @@ def from_model_output( raise ValueError('The "Threshold" value must be between 0 and 1.') # If the predictions are logits, convert them to probabilities - print(f"labels.shape: {labels.shape}, predictions.shape: {predictions.shape}, from_logits: {from_logits}") + predictions = _convert_logits_to_probabilities(predictions, from_logits) - + # Modify predictions with the given threshold value predictions = jnp.where(predictions >= threshold, 1, 0) - + # Assign the true_positive, false_positive, and false_negative their values """ We are calculating these values manually instead of using Metrax's diff --git a/src/metrax/classification_metrics_test.py b/src/metrax/classification_metrics_test.py index 58bddef..8bd3f4d 100644 --- a/src/metrax/classification_metrics_test.py +++ b/src/metrax/classification_metrics_test.py @@ -99,9 +99,6 @@ def test_fbeta_empty(self): ) def test_accuracy(self, y_true, y_pred, sample_weights, from_logits=False): """Test that `Accuracy` metric computes correct values.""" - - print(f"y_true.shape: {y_true}, y_pred.shape: {y_pred}, sample_weights: {sample_weights}, from_logits: {from_logits}") - if sample_weights is None: sample_weights = np.ones_like(y_true) metrax_accuracy = metrax.Accuracy.empty() @@ -147,7 +144,7 @@ def test_precision(self, y_true, y_pred, threshold,from_logits=False): y_pred_keras = jnp.where(probs.reshape((-1,)) >= threshold, 1, 0) else: y_pred_keras = jnp.where(y_pred.reshape((-1,)) >= threshold, 1, 0) - + keras_precision = keras.metrics.Precision(thresholds=threshold) keras_precision.update_state(y_true_keras, y_pred_keras) expected = 
keras_precision.result() @@ -187,14 +184,14 @@ def test_precision(self, y_true, y_pred, threshold,from_logits=False): ) def test_recall(self, y_true, y_pred, threshold, from_logits=False): """Test that `Recall` metric computes correct values.""" - + y_true_keras = y_true.reshape((-1,)) if from_logits: probs = jax.nn.softmax(y_pred, axis=-1) y_pred_keras = jnp.where(probs.reshape((-1,)) >= threshold, 1, 0) else: y_pred_keras = jnp.where(y_pred.reshape((-1,)) >= threshold, 1, 0) - + keras_recall = keras.metrics.Recall(thresholds=threshold) keras_recall.update_state(y_true_keras, y_pred_keras) expected = keras_recall.result() @@ -322,10 +319,7 @@ def test_aucroc(self, inputs, dtype, from_logits=False): ('batch_size_one_logits_beta_3.0', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5, 3.0, True), ) def test_fbetascore(self, y_true, y_pred, threshold, beta, from_logits=False): - - print(f"y_true.shape: {y_true.shape}, y_pred.shape: {y_pred.shape}, threshold: {threshold}, beta: {beta}") # Define the Keras FBeta class to be tested against - keras_fbeta = keras.metrics.FBetaScore(beta=beta, threshold=threshold) if from_logits: y_pred_keras = jax.nn.softmax(y_pred, axis=-1)