diff --git a/src/metrax/classification_metrics.py b/src/metrax/classification_metrics.py
index 57f054d..973072f 100644
--- a/src/metrax/classification_metrics.py
+++ b/src/metrax/classification_metrics.py
@@ -43,6 +43,23 @@ def _default_threshold(num_thresholds: int) -> jax.Array:
   return thresholds
 
 
+def _convert_logits_to_probabilities(
+    predictions: jax.Array, from_logits: bool
+) -> jax.Array:
+  """Converts logits to probabilities if `from_logits` is True.
+
+  Args:
+    predictions: JAX array of predicted values; logits if `from_logits` is True.
+    from_logits: Boolean indicating whether `predictions` are logits.
+
+  Returns:
+    Probabilities if `from_logits` is True, otherwise `predictions` unchanged.
+  """
+  if from_logits:
+    # Convert logits to probabilities with a softmax over the last axis.
+    predictions = jax.nn.softmax(predictions, axis=-1)
+  return predictions
+
 @flax.struct.dataclass
 class Accuracy(base.Average):
   r"""Computes accuracy, which is the frequency with which `predictions` match `labels`.
@@ -69,6 +86,7 @@ def from_model_output(
       predictions: jax.Array,
       labels: jax.Array,
       sample_weights: jax.Array | None = None,
+      from_logits: bool = False,
   ) -> 'Accuracy':
     """Updates the metric state with new `predictions` and `labels`.
@@ -99,6 +117,12 @@ def from_model_output(
       comparison, or if `sample_weights` cannot be broadcast to `labels`'
       shape.
     """
+    predictions = _convert_logits_to_probabilities(predictions, from_logits)
+
     correct = predictions == labels
     count = jnp.ones_like(labels, dtype=jnp.int32)
     if sample_weights is not None:
@@ -149,6 +173,7 @@ def from_model_output(
       predictions: jax.Array,
       labels: jax.Array,
       threshold: float = 0.5,
+      from_logits: bool = False,
   ) -> 'Precision':
     """Updates the metric.
@@ -166,7 +191,10 @@ def from_model_output(
       ValueError: If type of `labels` is wrong or the shapes of `predictions`
         and `labels` are incompatible.
     """
+    predictions = _convert_logits_to_probabilities(predictions, from_logits)
+
     predictions = jnp.where(predictions >= threshold, 1, 0)
+
     true_positives = jnp.sum((predictions == 1) & (labels == 1))
     false_positives = jnp.sum((predictions == 1) & (labels == 0))
@@ -219,7 +247,7 @@ def empty(cls) -> 'Recall':
 
   @classmethod
   def from_model_output(
-      cls, predictions: jax.Array, labels: jax.Array, threshold: float = 0.5
+      cls,
+      predictions: jax.Array,
+      labels: jax.Array,
+      threshold: float = 0.5,
+      from_logits: bool = False,
   ) -> 'Recall':
     """Updates the metric.
@@ -237,6 +265,8 @@ def from_model_output(
       ValueError: If type of `labels` is wrong or the shapes of `predictions`
         and `labels` are incompatible.
     """
+    predictions = _convert_logits_to_probabilities(predictions, from_logits)
+
     predictions = jnp.where(predictions >= threshold, 1, 0)
     true_positives = jnp.sum((predictions == 1) & (labels == 1))
     false_negatives = jnp.sum((predictions == 0) & (labels == 1))
@@ -325,6 +355,7 @@ def from_model_output(
       labels: jax.Array,
       sample_weights: jax.Array | None = None,
       num_thresholds: int = 200,
+      from_logits: bool = False,
   ) -> 'AUCPR':
     """Updates the metric.
@@ -345,6 +376,8 @@ def from_model_output(
       ValueError: If type of `labels` is wrong or the shapes of `predictions`
        and `labels` are incompatible.
""" + predictions = _convert_logits_to_probabilities(predictions, from_logits) + pred_is_pos = jnp.greater( predictions, _default_threshold(num_thresholds=num_thresholds)[..., None], @@ -513,6 +546,7 @@ def from_model_output( labels: jax.Array, sample_weights: jax.Array | None = None, num_thresholds: int = 200, + from_logits: bool = False, ) -> 'AUCROC': """Updates the metric. @@ -533,6 +567,8 @@ def from_model_output( ValueError: If type of `labels` is wrong or the shapes of `predictions` and `labels` are incompatible. """ + predictions = _convert_logits_to_probabilities(predictions, from_logits) + pred_is_pos = jnp.greater( predictions, _default_threshold(num_thresholds=num_thresholds)[..., None], @@ -622,7 +658,8 @@ def from_model_output( predictions: jax.Array, labels: jax.Array, beta = beta, - threshold = 0.5,) -> 'FBetaScore': + threshold = 0.5, + from_logits : bool = False) -> 'FBetaScore': """Updates the metric. Note: When only predictions and labels are given, the score calculated is the F1 score if the FBetaScore beta value has not been previously modified. @@ -656,6 +693,10 @@ def from_model_output( if threshold < 0.0 or threshold > 1.0: raise ValueError('The "Threshold" value must be between 0 and 1.') + # If the predictions are logits, convert them to probabilities + + predictions = _convert_logits_to_probabilities(predictions, from_logits) + # Modify predictions with the given threshold value predictions = jnp.where(predictions >= threshold, 1, 0) diff --git a/src/metrax/classification_metrics_test.py b/src/metrax/classification_metrics_test.py index 2e54451..8bd3f4d 100644 --- a/src/metrax/classification_metrics_test.py +++ b/src/metrax/classification_metrics_test.py @@ -15,6 +15,8 @@ """Tests for metrax classification metrics.""" import os + +import jax os.environ['KERAS_BACKEND'] = 'jax' from absl.testing import absltest @@ -34,6 +36,7 @@ ).astype(np.float32) OUTPUT_PREDS = np.random.uniform(size=(BATCHES, BATCH_SIZE)) OUTPUT_PREDS_F16 = OUTPUT_PREDS.astype(jnp.float16) +OUTPUT_LOGITS_F16 = np.random.randn(BATCHES, BATCH_SIZE).astype(jnp.float16) OUTPUT_PREDS_F32 = OUTPUT_PREDS.astype(jnp.float32) OUTPUT_PREDS_BF16 = OUTPUT_PREDS.astype(jnp.bfloat16) OUTPUT_LABELS_BS1 = np.random.randint( @@ -92,8 +95,9 @@ def test_fbeta_empty(self): ('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, SAMPLE_WEIGHTS), ('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, SAMPLE_WEIGHTS), ('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None), + ('batch_size_logits_f16', OUTPUT_LABELS, OUTPUT_LOGITS_F16, SAMPLE_WEIGHTS,True), ) - def test_accuracy(self, y_true, y_pred, sample_weights): + def test_accuracy(self, y_true, y_pred, sample_weights, from_logits=False): """Test that `Accuracy` metric computes correct values.""" if sample_weights is None: sample_weights = np.ones_like(y_true) @@ -104,6 +108,7 @@ def test_accuracy(self, y_true, y_pred, sample_weights): predictions=logits, labels=labels, sample_weights=weights, + from_logits=from_logits, ) metrax_accuracy = metrax_accuracy.merge(update) keras_accuracy.update_state(labels, logits, weights) @@ -120,6 +125,7 @@ def test_accuracy(self, y_true, y_pred, sample_weights): @parameterized.named_parameters( ('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.5), + ('basic_f16_logits', OUTPUT_LABELS, OUTPUT_LOGITS_F16, 0.5, True), ('high_threshold_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.7), ('low_threshold_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.1), ('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.5), @@ -130,12 +136,17 @@ def test_accuracy(self, 
       ('low_threshold_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.1),
       ('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5),
   )
-  def test_precision(self, y_true, y_pred, threshold):
+  def test_precision(self, y_true, y_pred, threshold, from_logits=False):
     """Test that `Precision` metric computes correct values."""
-    y_true = y_true.reshape((-1,))
-    y_pred = jnp.where(y_pred.reshape((-1,)) >= threshold, 1, 0)
+    y_true_keras = y_true.reshape((-1,))
+    if from_logits:
+      probs = jax.nn.softmax(y_pred, axis=-1)
+      y_pred_keras = jnp.where(probs.reshape((-1,)) >= threshold, 1, 0)
+    else:
+      y_pred_keras = jnp.where(y_pred.reshape((-1,)) >= threshold, 1, 0)
+
     keras_precision = keras.metrics.Precision(thresholds=threshold)
-    keras_precision.update_state(y_true, y_pred)
+    keras_precision.update_state(y_true_keras, y_pred_keras)
     expected = keras_precision.result()
 
     metric = None
@@ -144,6 +155,7 @@ def test_precision(self, y_true, y_pred, threshold):
           predictions=logits,
           labels=labels,
           threshold=threshold,
+          from_logits=from_logits,
       )
       metric = update if metric is None else metric.merge(update)
@@ -161,6 +173,7 @@ def test_precision(self, y_true, y_pred, threshold):
       ('basic', OUTPUT_LABELS, OUTPUT_PREDS, 0.5),
       ('high_threshold', OUTPUT_LABELS, OUTPUT_PREDS, 0.7),
       ('low_threshold', OUTPUT_LABELS, OUTPUT_PREDS, 0.1),
+      ('basic_logits_f16', OUTPUT_LABELS, OUTPUT_LOGITS_F16, 0.5, True),
       ('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.5),
       ('high_threshold_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.7),
       ('low_threshold_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.1),
@@ -169,12 +182,18 @@ def test_precision(self, y_true, y_pred, threshold):
       ('low_threshold_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.1),
       ('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5),
   )
-  def test_recall(self, y_true, y_pred, threshold):
+  def test_recall(self, y_true, y_pred, threshold, from_logits=False):
     """Test that `Recall` metric computes correct values."""
-    y_true = y_true.reshape((-1,))
-    y_pred = jnp.where(y_pred.reshape((-1,)) >= threshold, 1, 0)
+    y_true_keras = y_true.reshape((-1,))
+    if from_logits:
+      probs = jax.nn.softmax(y_pred, axis=-1)
+      y_pred_keras = jnp.where(probs.reshape((-1,)) >= threshold, 1, 0)
+    else:
+      y_pred_keras = jnp.where(y_pred.reshape((-1,)) >= threshold, 1, 0)
+
     keras_recall = keras.metrics.Recall(thresholds=threshold)
-    keras_recall.update_state(y_true, y_pred)
+    keras_recall.update_state(y_true_keras, y_pred_keras)
     expected = keras_recall.result()
 
     metric = None
@@ -183,6 +202,7 @@ def test_recall(self, y_true, y_pred, threshold):
           predictions=logits,
           labels=labels,
           threshold=threshold,
+          from_logits=from_logits,
      )
       metric = update if metric is None else metric.merge(update)
@@ -193,19 +213,21 @@ def test_recall(self, y_true, y_pred, threshold):
   @parameterized.product(
       inputs=(
-          (OUTPUT_LABELS, OUTPUT_PREDS, None),
-          (OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None),
-          (OUTPUT_LABELS, OUTPUT_PREDS, SAMPLE_WEIGHTS),
+          (OUTPUT_LABELS, OUTPUT_PREDS, None, False),
+          (OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None, False),
+          (OUTPUT_LABELS, OUTPUT_PREDS, SAMPLE_WEIGHTS, False),
+          (OUTPUT_LABELS, OUTPUT_LOGITS_F16, SAMPLE_WEIGHTS, True),
       ),
       dtype=(
           jnp.float16,
           jnp.float32,
           jnp.bfloat16,
       ),
   )
   def test_aucpr(self, inputs, dtype):
     """Test that `AUC-PR` Metric computes correct values."""
-    y_true, y_pred, sample_weights = inputs
+    y_true, y_pred, sample_weights, from_logits = inputs
     y_true = y_true.astype(dtype)
     y_pred = y_pred.astype(dtype)
     if sample_weights is None:
@@ -217,10 +239,13 @@ def test_aucpr(self, inputs, dtype):
           predictions=logits,
           labels=labels,
           sample_weights=weights,
+          from_logits=from_logits,
       )
       metric = update if metric is None else metric.merge(update)
 
     keras_aucpr = keras.metrics.AUC(curve='PR')
+    if from_logits:
+      y_pred = jax.nn.softmax(y_pred, axis=-1)
     for labels, logits, weights in zip(y_true, y_pred, sample_weights):
       keras_aucpr.update_state(labels, logits, sample_weight=weights)
     expected = keras_aucpr.result()
@@ -234,19 +259,21 @@ def test_aucpr(self, inputs, dtype):
   @parameterized.product(
       inputs=(
-          (OUTPUT_LABELS, OUTPUT_PREDS, None),
-          (OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None),
-          (OUTPUT_LABELS, OUTPUT_PREDS, SAMPLE_WEIGHTS),
+          (OUTPUT_LABELS, OUTPUT_PREDS, None, False),
+          (OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None, False),
+          (OUTPUT_LABELS, OUTPUT_PREDS, SAMPLE_WEIGHTS, False),
+          (OUTPUT_LABELS, OUTPUT_LOGITS_F16, SAMPLE_WEIGHTS, True),
       ),
       dtype=(
           jnp.float16,
           jnp.float32,
           jnp.bfloat16,
       ),
   )
   def test_aucroc(self, inputs, dtype):
     """Test that `AUC-ROC` Metric computes correct values."""
-    y_true, y_pred, sample_weights = inputs
+    y_true, y_pred, sample_weights, from_logits = inputs
     y_true = y_true.astype(dtype)
     y_pred = y_pred.astype(dtype)
     if sample_weights is None:
@@ -258,10 +285,13 @@ def test_aucroc(self, inputs, dtype):
           predictions=logits,
           labels=labels,
           sample_weights=weights,
+          from_logits=from_logits,
       )
       metric = update if metric is None else metric.merge(update)
 
     keras_aucroc = keras.metrics.AUC(curve='ROC')
+    if from_logits:
+      y_pred = jax.nn.softmax(y_pred, axis=-1)
     for labels, logits, weights in zip(y_true, y_pred, sample_weights):
       keras_aucroc.update_state(labels, logits, sample_weight=weights)
     expected = keras_aucroc.result()
@@ -286,16 +316,23 @@ def test_aucroc(self, inputs, dtype):
       ('low_threshold_bf16_beta_2.0', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.1, 2.0),
       ('low_threshold_f16_beta_3.0', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.1, 3.0),
       ('basic_bf16_beta_3.0', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.5, 3.0),
+      ('batch_size_one_logits_beta_3.0', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5, 3.0, True),
   )
-  def test_fbetascore(self, y_true, y_pred, threshold, beta):
+  def test_fbetascore(self, y_true, y_pred, threshold, beta, from_logits=False):
     # Define the Keras FBeta class to be tested against
     keras_fbeta = keras.metrics.FBetaScore(beta=beta, threshold=threshold)
-    keras_fbeta.update_state(y_true, y_pred)
+    if from_logits:
+      # Mirror the metric's conversion so both sides see probabilities.
+      keras_fbeta.update_state(y_true, jax.nn.softmax(y_pred, axis=-1))
+    else:
+      keras_fbeta.update_state(y_true, y_pred)
     expected = keras_fbeta.result()
 
     # Calculate the F-beta score using the metrax variant
     metric = metrax.FBetaScore
-    metric = metric.from_model_output(y_pred, y_true, beta, threshold)
+    metric = metric.from_model_output(
+        y_pred, y_true, beta, threshold, from_logits=from_logits
+    )
 
     # Use lower tolerance for lower precision dtypes.
     rtol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-5
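
For reviewers, here is a minimal usage sketch of the new `from_logits` flag. It is illustrative only: it assumes the `metrax.Precision.from_model_output` signature added in this patch, and the array values are made up.

import jax
import jax.numpy as jnp
import metrax

# Raw model scores; any real-valued array is a valid logit input.
logits = jnp.array([[2.0, -1.0, 0.5, 1.5]])
labels = jnp.array([[1, 0, 0, 1]])

# With from_logits=True the metric converts the scores to probabilities
# internally via a softmax over the last axis...
from_logits_metric = metrax.Precision.from_model_output(
    predictions=logits, labels=labels, threshold=0.5, from_logits=True
)

# ...which is equivalent to converting by hand and passing probabilities.
probs = jax.nn.softmax(logits, axis=-1)
manual_metric = metrax.Precision.from_model_output(
    predictions=probs, labels=labels, threshold=0.5
)

One design note: `_convert_logits_to_probabilities` normalizes across the last axis, so it assumes that axis holds per-class scores. For a single binary logit per example, `jax.nn.sigmoid` would be the conventional conversion; the tests pass either way because both the metric and the Keras reference apply the same transform.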