From e8985c1245b3019004a430ed1a73a7433bd08d66 Mon Sep 17 00:00:00 2001 From: Br1an67 <932039080@qq.com> Date: Sun, 1 Mar 2026 15:10:16 +0800 Subject: [PATCH] Support target_names in classification metrics renderers Add target_names support to ClassificationClassSeparationPlot, ClassificationProbDistribution, and ClassificationQualityByFeatureTable renderers. When column_mapping.target_names is set as a dict, the human-readable names are now used in plot labels, tab titles, and legend entries instead of raw class values. Fixes #578 --- .../class_separation_metric.py | 14 ++++++++- .../probability_distribution_metric.py | 23 +++++++++++---- .../quality_by_feature_table.py | 29 +++++++++++++++---- 3 files changed, 55 insertions(+), 11 deletions(-) diff --git a/src/evidently/legacy/metrics/classification_performance/class_separation_metric.py b/src/evidently/legacy/metrics/classification_performance/class_separation_metric.py index ae2969ae11..860b5cc1f8 100644 --- a/src/evidently/legacy/metrics/classification_performance/class_separation_metric.py +++ b/src/evidently/legacy/metrics/classification_performance/class_separation_metric.py @@ -18,6 +18,7 @@ from evidently.legacy.metric_results import raw_agg_properties from evidently.legacy.model.widget import BaseWidgetInfo from evidently.legacy.options.base import AnyOptions +from evidently.legacy.pipeline.column_mapping import TargetNames from evidently.legacy.renderers.base_renderer import MetricRenderer from evidently.legacy.renderers.base_renderer import default_renderer from evidently.legacy.renderers.html_widgets import TabData @@ -40,6 +41,7 @@ class Config: } target_name: str + target_names: Optional[TargetNames] = None current: Optional[ColumnScatterOrAgg] = None current_raw, current_agg = raw_agg_properties("current", ColumnScatter, ColumnAggScatter, True) @@ -107,6 +109,7 @@ def calculate(self, data: InputData) -> ClassificationClassSeparationPlotResults current=column_scatter_from_df(current_plot, True), reference=column_scatter_from_df(reference_plot, True), target_name=target_name, + target_names=dataset_columns.target_names, ) current_plot = prepare_box_data(current_plot, target_name, prediction_names.tolist()) if reference_plot is not None: @@ -115,9 +118,18 @@ def calculate(self, data: InputData) -> ClassificationClassSeparationPlotResults current=current_plot, reference=reference_plot, target_name=target_name, + target_names=dataset_columns.target_names, ) +def _resolve_target_name(label, target_names: Optional[TargetNames]) -> str: + if target_names is not None and isinstance(target_names, dict): + resolved = target_names.get(label) or target_names.get(int(label)) if isinstance(label, str) else target_names.get(label) # type: ignore[arg-type] + if resolved is not None: + return str(resolved) + return str(label) + + @default_renderer(wrap_type=ClassificationClassSeparationPlot) class ClassificationClassSeparationPlotRenderer(MetricRenderer): def render_html(self, obj: ClassificationClassSeparationPlot) -> List[BaseWidgetInfo]: @@ -150,7 +162,7 @@ def render_html(self, obj: ClassificationClassSeparationPlot) -> List[BaseWidget target_name, color_options=self.color_options, ) - tabs = [TabData(name, widget) for name, widget in tab_data] + tabs = [TabData(_resolve_target_name(name, metric_result.target_names), widget) for name, widget in tab_data] return [ header_text(label="Class Separation Quality"), widget_tabs(title="", tabs=tabs), diff --git a/src/evidently/legacy/metrics/classification_performance/probability_distribution_metric.py b/src/evidently/legacy/metrics/classification_performance/probability_distribution_metric.py index e60c17ec92..95a6191927 100644 --- a/src/evidently/legacy/metrics/classification_performance/probability_distribution_metric.py +++ b/src/evidently/legacy/metrics/classification_performance/probability_distribution_metric.py @@ -13,6 +13,7 @@ from evidently.legacy.calculations.classification_performance import get_prediction_data from evidently.legacy.core import IncludeTags from evidently.legacy.model.widget import BaseWidgetInfo +from evidently.legacy.pipeline.column_mapping import TargetNames from evidently.legacy.renderers.base_renderer import MetricRenderer from evidently.legacy.renderers.base_renderer import default_renderer from evidently.legacy.renderers.html_widgets import GraphData @@ -32,6 +33,7 @@ class Config: current_distribution: Optional[Dict[str, list]] # todo use DistributionField? reference_distribution: Optional[Dict[str, list]] + target_names: Optional[TargetNames] = None class ClassificationProbDistribution(Metric[ClassificationProbDistributionResults]): @@ -93,19 +95,29 @@ def calculate(self, data: InputData) -> ClassificationProbDistributionResults: return ClassificationProbDistributionResults( current_distribution=current_distribution, reference_distribution=reference_distribution, + target_names=columns.target_names, ) @default_renderer(wrap_type=ClassificationProbDistribution) class ClassificationProbDistributionRenderer(MetricRenderer): - def _plot(self, distribution: Dict[str, list]): + @staticmethod + def _resolve_target_name(label, target_names: Optional[TargetNames]) -> str: + if target_names is not None and isinstance(target_names, dict): + resolved = target_names.get(label) or target_names.get(int(label)) if isinstance(label, str) else target_names.get(label) # type: ignore[arg-type] + if resolved is not None: + return str(resolved) + return str(label) + + def _plot(self, distribution: Dict[str, list], target_names: Optional[TargetNames] = None): # plot distributions graphs = [] for label in distribution: + display_name = self._resolve_target_name(label, target_names) pred_distr = ff.create_distplot( distribution[label], - [str(label), "other"], + [display_name, "other"], colors=[ self.color_options.primary_color, self.color_options.secondary_color, @@ -123,7 +135,7 @@ def _plot(self, distribution: Dict[str, list]): pred_distr_json = pred_distr.to_plotly_json() graphs.append( { - "title": str(label), + "title": display_name, "data": pred_distr_json["data"], "layout": pred_distr_json["layout"], } @@ -134,6 +146,7 @@ def render_html(self, obj: ClassificationProbDistribution) -> List[BaseWidgetInf metric_result = obj.get_result() reference_distribution = metric_result.reference_distribution current_distribution = metric_result.current_distribution + target_names = metric_result.target_names result = [] size = WidgetSize.FULL @@ -147,7 +160,7 @@ def render_html(self, obj: ClassificationProbDistribution) -> List[BaseWidgetInf size=size, figures=[ GraphData(graph["title"], graph["data"], graph["layout"]) - for graph in self._plot(current_distribution) + for graph in self._plot(current_distribution, target_names) ], ) ) @@ -159,7 +172,7 @@ def render_html(self, obj: ClassificationProbDistribution) -> List[BaseWidgetInf size=size, figures=[ GraphData(graph["title"], graph["data"], graph["layout"]) - for graph in self._plot(reference_distribution) + for graph in self._plot(reference_distribution, target_names) ], ) ) diff --git a/src/evidently/legacy/metrics/classification_performance/quality_by_feature_table.py b/src/evidently/legacy/metrics/classification_performance/quality_by_feature_table.py index a57a63ac67..50dedb3f8e 100644 --- a/src/evidently/legacy/metrics/classification_performance/quality_by_feature_table.py +++ b/src/evidently/legacy/metrics/classification_performance/quality_by_feature_table.py @@ -25,6 +25,7 @@ from evidently.legacy.model.widget import AdditionalGraphInfo from evidently.legacy.model.widget import BaseWidgetInfo from evidently.legacy.options.base import AnyOptions +from evidently.legacy.pipeline.column_mapping import TargetNames from evidently.legacy.renderers.base_renderer import MetricRenderer from evidently.legacy.renderers.base_renderer import default_renderer from evidently.legacy.renderers.html_widgets import header_text @@ -47,6 +48,7 @@ class Config: target_name: str columns: List[str] + target_names: Optional[TargetNames] = None class ClassificationQualityByFeatureTable(UsesRawDataMixin, Metric[ClassificationQualityByFeatureTableResults]): @@ -173,11 +175,20 @@ def calculate(self, data: InputData) -> ClassificationQualityByFeatureTableResul reference=reference, columns=columns, target_name=target_name, + target_names=dataset_columns.target_names, ) @default_renderer(wrap_type=ClassificationQualityByFeatureTable) class ClassificationQualityByFeatureTableRenderer(MetricRenderer): + @staticmethod + def _resolve_target_name(label, target_names: Optional[TargetNames]) -> str: + if target_names is not None and isinstance(target_names, dict): + resolved = target_names.get(label) or target_names.get(int(label)) if isinstance(label, str) else target_names.get(label) # type: ignore[arg-type] + if resolved is not None: + return str(resolved) + return str(label) + def render_html(self, obj: ClassificationQualityByFeatureTable) -> List[BaseWidgetInfo]: if not obj.get_options().render_options.raw_data: return [] @@ -185,6 +196,7 @@ def render_html(self, obj: ClassificationQualityByFeatureTable) -> List[BaseWidg current_data = result.current.plot_data reference_data = result.reference.plot_data if result.reference is not None else None target_name = result.target_name + target_names = result.target_names curr_predictions = result.current.predictions # todo: better typing? assert curr_predictions is not None @@ -211,7 +223,13 @@ def render_html(self, obj: ClassificationQualityByFeatureTable) -> List[BaseWidg { "details": { "parts": [{"title": "All", "id": "All" + "_" + str(feature_name)}] - + [{"title": str(label), "id": feature_name + "_" + str(label)} for label in labels], + + [ + { + "title": self._resolve_target_name(label, target_names), + "id": feature_name + "_" + str(label), + } + for label in labels + ], "insights": [], }, "f1": feature_name, @@ -276,6 +294,7 @@ def render_html(self, obj: ClassificationQualityByFeatureTable) -> List[BaseWidg cols = 1 subplot_titles = [""] for label in labels: + display_label = self._resolve_target_name(label, target_names) fig = make_subplots( rows=1, cols=cols, @@ -289,8 +308,8 @@ def render_html(self, obj: ClassificationQualityByFeatureTable) -> List[BaseWidg x=current_data[current_data[target_name] == label][feature_name], y=current_data[current_data[target_name] == label][label], mode="markers", - name=str(label), - legendgroup=str(label), + name=display_label, + legendgroup=display_label, marker=dict( size=6, # set color equal to a variable @@ -327,8 +346,8 @@ def render_html(self, obj: ClassificationQualityByFeatureTable) -> List[BaseWidg x=reference_data[reference_data[target_name] == label][feature_name], y=reference_data[reference_data[target_name] == label][label], mode="markers", - name=str(label), - legendgroup=str(label), + name=display_label, + legendgroup=display_label, showlegend=False, marker=dict(size=6, color=color_options.get_current_data_color()), ),