Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 34 additions & 33 deletions ml_metrics/_src/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ml_metrics._src.aggregates import types
from ml_metrics._src.metrics import utils
from ml_metrics.google.tools.signal_registry import registry
from ml_metrics.google.tools.signal_registry import signal_types
from ml_metrics._src.tools.telemetry import telemetry
import numpy as np

Expand All @@ -37,7 +38,7 @@

# TODO: b/368067018 - Inherit from ml_metrics._src.aggregates.stats.Histogram.
@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
usage_category=telemetry.CATEGORY.METRIC,
)
@dataclasses.dataclass
Expand Down Expand Up @@ -152,7 +153,7 @@ def result(self) -> CalibrationHistogramResult:


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
usage_category=telemetry.CATEGORY.METRIC,
)
class ClassificationAggFn(chainable.AggregateFn):
Expand Down Expand Up @@ -220,7 +221,7 @@ def merge_states(self, states):


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def classification_metrics(
Expand Down Expand Up @@ -275,7 +276,7 @@ def classification_metrics(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def precision(
Expand Down Expand Up @@ -326,7 +327,7 @@ def precision(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def ppv(
Expand Down Expand Up @@ -377,7 +378,7 @@ def ppv(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def recall(
Expand Down Expand Up @@ -428,7 +429,7 @@ def recall(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def f1_score(
Expand Down Expand Up @@ -479,7 +480,7 @@ def f1_score(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def accuracy(
Expand Down Expand Up @@ -530,7 +531,7 @@ def accuracy(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def binary_accuracy(
Expand Down Expand Up @@ -581,7 +582,7 @@ def binary_accuracy(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def sensitivity(
Expand Down Expand Up @@ -632,7 +633,7 @@ def sensitivity(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def tpr(
Expand Down Expand Up @@ -683,7 +684,7 @@ def tpr(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def specificity(
Expand Down Expand Up @@ -734,7 +735,7 @@ def specificity(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def tnr(
Expand Down Expand Up @@ -785,7 +786,7 @@ def tnr(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def fall_out(
Expand Down Expand Up @@ -836,7 +837,7 @@ def fall_out(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def fpr(
Expand Down Expand Up @@ -887,7 +888,7 @@ def fpr(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def miss_rate(
Expand Down Expand Up @@ -938,7 +939,7 @@ def miss_rate(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def fnr(
Expand Down Expand Up @@ -989,7 +990,7 @@ def fnr(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def negative_predictive_value(
Expand Down Expand Up @@ -1040,7 +1041,7 @@ def negative_predictive_value(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def npv(
Expand Down Expand Up @@ -1091,7 +1092,7 @@ def npv(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def false_discovery_rate(
Expand Down Expand Up @@ -1142,7 +1143,7 @@ def false_discovery_rate(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def false_omission_rate(
Expand Down Expand Up @@ -1193,7 +1194,7 @@ def false_omission_rate(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def threat_score(
Expand Down Expand Up @@ -1244,7 +1245,7 @@ def threat_score(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def positive_likelihood_ratio(
Expand Down Expand Up @@ -1295,7 +1296,7 @@ def positive_likelihood_ratio(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def negative_likelihood_ratio(
Expand Down Expand Up @@ -1346,7 +1347,7 @@ def negative_likelihood_ratio(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def diagnostic_odds_ratio(
Expand Down Expand Up @@ -1397,7 +1398,7 @@ def diagnostic_odds_ratio(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def positive_predictive_value(
Expand Down Expand Up @@ -1448,7 +1449,7 @@ def positive_predictive_value(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def intersection_over_union(
Expand Down Expand Up @@ -1499,7 +1500,7 @@ def intersection_over_union(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def prevalence(
Expand Down Expand Up @@ -1550,7 +1551,7 @@ def prevalence(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def prevalence_threshold(
Expand Down Expand Up @@ -1601,7 +1602,7 @@ def prevalence_threshold(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def matthews_correlation_coefficient(
Expand Down Expand Up @@ -1652,7 +1653,7 @@ def matthews_correlation_coefficient(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def informedness(
Expand Down Expand Up @@ -1703,7 +1704,7 @@ def informedness(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def markedness(
Expand Down Expand Up @@ -1754,7 +1755,7 @@ def markedness(


@registry.register_signal(
signal_modality=registry.SignalModality.OTHER,
signal_modality=signal_types.SignalModality.OTHER,
enable_telemetry=False,
)
def balanced_accuracy(
Expand Down
Loading
Loading