From 47b600a68ddb352e28cb1238022c80836e104ff4 Mon Sep 17 00:00:00 2001 From: ML Metrics Team Date: Thu, 19 Mar 2026 11:13:06 -0700 Subject: [PATCH] internal change PiperOrigin-RevId: 886274768 --- ml_metrics/_src/metrics/classification.py | 67 ++++++++++++----------- ml_metrics/_src/metrics/retrieval.py | 37 +++++++------ ml_metrics/_src/metrics/text.py | 7 ++- ml_metrics/_src/metrics/utils.py | 3 +- 4 files changed, 59 insertions(+), 55 deletions(-) diff --git a/ml_metrics/_src/metrics/classification.py b/ml_metrics/_src/metrics/classification.py index 1825ed81..fe164721 100644 --- a/ml_metrics/_src/metrics/classification.py +++ b/ml_metrics/_src/metrics/classification.py @@ -23,6 +23,7 @@ from ml_metrics._src.aggregates import types from ml_metrics._src.metrics import utils from ml_metrics.google.tools.signal_registry import registry +from ml_metrics.google.tools.signal_registry import signal_types from ml_metrics._src.tools.telemetry import telemetry import numpy as np @@ -37,7 +38,7 @@ # TODO: b/368067018 - Inherit from ml_metrics._src.aggregates.stats.Histogram. @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, usage_category=telemetry.CATEGORY.METRIC, ) @dataclasses.dataclass @@ -152,7 +153,7 @@ def result(self) -> CalibrationHistogramResult: @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, usage_category=telemetry.CATEGORY.METRIC, ) class ClassificationAggFn(chainable.AggregateFn): @@ -220,7 +221,7 @@ def merge_states(self, states): @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def classification_metrics( @@ -275,7 +276,7 @@ def classification_metrics( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def precision( @@ -326,7 +327,7 @@ def precision( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def ppv( @@ -377,7 +378,7 @@ def ppv( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def recall( @@ -428,7 +429,7 @@ def recall( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def f1_score( @@ -479,7 +480,7 @@ def f1_score( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def accuracy( @@ -530,7 +531,7 @@ def accuracy( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def binary_accuracy( @@ -581,7 +582,7 @@ def binary_accuracy( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def sensitivity( @@ -632,7 +633,7 @@ def sensitivity( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def tpr( @@ -683,7 +684,7 @@ def tpr( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def specificity( @@ -734,7 +735,7 @@ def specificity( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def tnr( @@ -785,7 +786,7 @@ def tnr( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def fall_out( @@ -836,7 +837,7 @@ def fall_out( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def fpr( @@ -887,7 +888,7 @@ def fpr( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def miss_rate( @@ -938,7 +939,7 @@ def miss_rate( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def fnr( @@ -989,7 +990,7 @@ def fnr( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def negative_predictive_value( @@ -1040,7 +1041,7 @@ def negative_predictive_value( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def npv( @@ -1091,7 +1092,7 @@ def npv( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def false_discovery_rate( @@ -1142,7 +1143,7 @@ def false_discovery_rate( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def false_omission_rate( @@ -1193,7 +1194,7 @@ def false_omission_rate( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def threat_score( @@ -1244,7 +1245,7 @@ def threat_score( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def positive_likelihood_ratio( @@ -1295,7 +1296,7 @@ def positive_likelihood_ratio( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def negative_likelihood_ratio( @@ -1346,7 +1347,7 @@ def negative_likelihood_ratio( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def diagnostic_odds_ratio( @@ -1397,7 +1398,7 @@ def diagnostic_odds_ratio( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def positive_predictive_value( @@ -1448,7 +1449,7 @@ def positive_predictive_value( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def intersection_over_union( @@ -1499,7 +1500,7 @@ def intersection_over_union( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def prevalence( @@ -1550,7 +1551,7 @@ def prevalence( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def prevalence_threshold( @@ -1601,7 +1602,7 @@ def prevalence_threshold( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def matthews_correlation_coefficient( @@ -1652,7 +1653,7 @@ def matthews_correlation_coefficient( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def informedness( @@ -1703,7 +1704,7 @@ def informedness( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def markedness( @@ -1754,7 +1755,7 @@ def markedness( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def balanced_accuracy( diff --git a/ml_metrics/_src/metrics/retrieval.py b/ml_metrics/_src/metrics/retrieval.py index a65706e9..e1272849 100644 --- a/ml_metrics/_src/metrics/retrieval.py +++ b/ml_metrics/_src/metrics/retrieval.py @@ -17,6 +17,7 @@ from ml_metrics._src.aggregates import retrieval from ml_metrics._src.aggregates import types from ml_metrics.google.tools.signal_registry import registry +from ml_metrics.google.tools.signal_registry import signal_types # TODO: b/368688941 - Remove this alias once all users are migrated to the new @@ -25,7 +26,7 @@ @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def topk_retrieval_metrics( @@ -59,7 +60,7 @@ def topk_retrieval_metrics( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def precision( @@ -91,7 +92,7 @@ def precision( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def ppv( @@ -123,7 +124,7 @@ def ppv( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def recall( @@ -155,7 +156,7 @@ def recall( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def sensitivity( @@ -187,7 +188,7 @@ def sensitivity( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def tpr( @@ -219,7 +220,7 @@ def tpr( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def intersection_over_union( @@ -251,7 +252,7 @@ def intersection_over_union( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def positive_predictive_value( @@ -283,7 +284,7 @@ def positive_predictive_value( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def f1_score( @@ -315,7 +316,7 @@ def f1_score( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def miss_rate( @@ -347,7 +348,7 @@ def miss_rate( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def mean_average_precision( @@ -379,7 +380,7 @@ def mean_average_precision( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def mean_reciprocal_rank( @@ -411,7 +412,7 @@ def mean_reciprocal_rank( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def accuracy( @@ -443,7 +444,7 @@ def accuracy( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def dcg_score( @@ -475,7 +476,7 @@ def dcg_score( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def ndcg_score( @@ -507,7 +508,7 @@ def ndcg_score( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def fowlkes_mallows_index( @@ -539,7 +540,7 @@ def fowlkes_mallows_index( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def false_discovery_rate( @@ -571,7 +572,7 @@ def false_discovery_rate( @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, enable_telemetry=False, ) def threat_score( diff --git a/ml_metrics/_src/metrics/text.py b/ml_metrics/_src/metrics/text.py index 5a376709..f2aeb18d 100644 --- a/ml_metrics/_src/metrics/text.py +++ b/ml_metrics/_src/metrics/text.py @@ -20,11 +20,12 @@ from ml_metrics._src.aggregates import text from ml_metrics._src.signals import text as text_scores from ml_metrics.google.tools.signal_registry import registry +from ml_metrics.google.tools.signal_registry import signal_types from ml_metrics._src.tools.telemetry import telemetry @registry.register_signal( - signal_modality=registry.SignalModality.TEXT, + signal_modality=signal_types.SignalModality.TEXT, usage_category=telemetry.CATEGORY.METRIC, ) def topk_word_ngrams( @@ -79,7 +80,7 @@ def topk_word_ngrams( @registry.register_signal( - signal_modality=registry.SignalModality.TEXT, + signal_modality=signal_types.SignalModality.TEXT, usage_category=telemetry.CATEGORY.METRIC, ) def pattern_frequency( @@ -117,7 +118,7 @@ def pattern_frequency( @registry.register_signal( - signal_modality=registry.SignalModality.TEXT, + signal_modality=signal_types.SignalModality.TEXT, usage_category=telemetry.CATEGORY.METRIC, ) def avg_alphabetical_char_count( diff --git a/ml_metrics/_src/metrics/utils.py b/ml_metrics/_src/metrics/utils.py index a3486e5b..8084480c 100644 --- a/ml_metrics/_src/metrics/utils.py +++ b/ml_metrics/_src/metrics/utils.py @@ -6,11 +6,12 @@ from ml_metrics._src.aggregates import classification from ml_metrics._src.aggregates import types from ml_metrics.google.tools.signal_registry import registry +from ml_metrics.google.tools.signal_registry import signal_types from ml_metrics._src.tools.telemetry import telemetry @registry.register_signal( - signal_modality=registry.SignalModality.OTHER, + signal_modality=signal_types.SignalModality.OTHER, usage_category=telemetry.CATEGORY.METRIC, ) def verify_input(y_true, y_pred, average, input_type, vocab, pos_label):