-
Notifications
You must be signed in to change notification settings - Fork 872
Support target_names in classification metrics renderers #1841
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |
| from evidently.legacy.metric_results import raw_agg_properties | ||
| from evidently.legacy.model.widget import BaseWidgetInfo | ||
| from evidently.legacy.options.base import AnyOptions | ||
| from evidently.legacy.pipeline.column_mapping import TargetNames | ||
| from evidently.legacy.renderers.base_renderer import MetricRenderer | ||
| from evidently.legacy.renderers.base_renderer import default_renderer | ||
| from evidently.legacy.renderers.html_widgets import TabData | ||
|
|
@@ -40,6 +41,7 @@ class Config: | |
| } | ||
|
|
||
| target_name: str | ||
| target_names: Optional[TargetNames] = None | ||
|
|
||
| current: Optional[ColumnScatterOrAgg] = None | ||
| current_raw, current_agg = raw_agg_properties("current", ColumnScatter, ColumnAggScatter, True) | ||
|
|
@@ -107,6 +109,7 @@ def calculate(self, data: InputData) -> ClassificationClassSeparationPlotResults | |
| current=column_scatter_from_df(current_plot, True), | ||
| reference=column_scatter_from_df(reference_plot, True), | ||
| target_name=target_name, | ||
| target_names=dataset_columns.target_names, | ||
| ) | ||
| current_plot = prepare_box_data(current_plot, target_name, prediction_names.tolist()) | ||
| if reference_plot is not None: | ||
|
|
@@ -115,9 +118,18 @@ def calculate(self, data: InputData) -> ClassificationClassSeparationPlotResults | |
| current=current_plot, | ||
| reference=reference_plot, | ||
| target_name=target_name, | ||
| target_names=dataset_columns.target_names, | ||
| ) | ||
|
|
||
|
|
||
| def _resolve_target_name(label, target_names: Optional[TargetNames]) -> str: | ||
| if target_names is not None and isinstance(target_names, dict): | ||
| resolved = target_names.get(label) or target_names.get(int(label)) if isinstance(label, str) else target_names.get(label) # type: ignore[arg-type] | ||
| if resolved is not None: | ||
| return str(resolved) | ||
| return str(label) | ||
|
|
||
|
|
||
| @default_renderer(wrap_type=ClassificationClassSeparationPlot) | ||
| class ClassificationClassSeparationPlotRenderer(MetricRenderer): | ||
| def render_html(self, obj: ClassificationClassSeparationPlot) -> List[BaseWidgetInfo]: | ||
|
|
@@ -150,7 +162,7 @@ def render_html(self, obj: ClassificationClassSeparationPlot) -> List[BaseWidget | |
| target_name, | ||
| color_options=self.color_options, | ||
| ) | ||
| tabs = [TabData(name, widget) for name, widget in tab_data] | ||
| tabs = [TabData(_resolve_target_name(name, metric_result.target_names), widget) for name, widget in tab_data] | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tab names are resolved here which is good, but the inner traces built by class_separation_traces_raw and class_separation_traces_agg still use str(label) for name and legendgroup. So the tab heading will show the display name but the plot legend inside the tab still shows the raw class value. Might want to pass target_names into those trace-building helpers too. |
||
| return [ | ||
| header_text(label="Class Separation Quality"), | ||
| widget_tabs(title="", tabs=tabs), | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,6 +25,7 @@ | |
| from evidently.legacy.model.widget import AdditionalGraphInfo | ||
| from evidently.legacy.model.widget import BaseWidgetInfo | ||
| from evidently.legacy.options.base import AnyOptions | ||
| from evidently.legacy.pipeline.column_mapping import TargetNames | ||
| from evidently.legacy.renderers.base_renderer import MetricRenderer | ||
| from evidently.legacy.renderers.base_renderer import default_renderer | ||
| from evidently.legacy.renderers.html_widgets import header_text | ||
|
|
@@ -47,6 +48,7 @@ class Config: | |
|
|
||
| target_name: str | ||
| columns: List[str] | ||
| target_names: Optional[TargetNames] = None | ||
|
|
||
|
|
||
| class ClassificationQualityByFeatureTable(UsesRawDataMixin, Metric[ClassificationQualityByFeatureTableResults]): | ||
|
|
@@ -173,18 +175,28 @@ def calculate(self, data: InputData) -> ClassificationQualityByFeatureTableResul | |
| reference=reference, | ||
| columns=columns, | ||
| target_name=target_name, | ||
| target_names=dataset_columns.target_names, | ||
| ) | ||
|
|
||
|
|
||
| @default_renderer(wrap_type=ClassificationQualityByFeatureTable) | ||
| class ClassificationQualityByFeatureTableRenderer(MetricRenderer): | ||
| @staticmethod | ||
| def _resolve_target_name(label, target_names: Optional[TargetNames]) -> str: | ||
| if target_names is not None and isinstance(target_names, dict): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same resolver function as the other two files. If the operator precedence or dict-only issue needs fixing, it has to be done in three places. Pulling this into something like evidently.legacy.utils or a shared base would keep it maintainable. |
||
| resolved = target_names.get(label) or target_names.get(int(label)) if isinstance(label, str) else target_names.get(label) # type: ignore[arg-type] | ||
| if resolved is not None: | ||
| return str(resolved) | ||
| return str(label) | ||
|
|
||
| def render_html(self, obj: ClassificationQualityByFeatureTable) -> List[BaseWidgetInfo]: | ||
| if not obj.get_options().render_options.raw_data: | ||
| return [] | ||
| result = obj.get_result() | ||
| current_data = result.current.plot_data | ||
| reference_data = result.reference.plot_data if result.reference is not None else None | ||
| target_name = result.target_name | ||
| target_names = result.target_names | ||
| curr_predictions = result.current.predictions | ||
| # todo: better typing? | ||
| assert curr_predictions is not None | ||
|
|
@@ -211,7 +223,13 @@ def render_html(self, obj: ClassificationQualityByFeatureTable) -> List[BaseWidg | |
| { | ||
| "details": { | ||
| "parts": [{"title": "All", "id": "All" + "_" + str(feature_name)}] | ||
| + [{"title": str(label), "id": feature_name + "_" + str(label)} for label in labels], | ||
| + [ | ||
| { | ||
| "title": self._resolve_target_name(label, target_names), | ||
| "id": feature_name + "_" + str(label), | ||
| } | ||
| for label in labels | ||
| ], | ||
| "insights": [], | ||
| }, | ||
| "f1": feature_name, | ||
|
|
@@ -276,6 +294,7 @@ def render_html(self, obj: ClassificationQualityByFeatureTable) -> List[BaseWidg | |
| cols = 1 | ||
| subplot_titles = [""] | ||
| for label in labels: | ||
| display_label = self._resolve_target_name(label, target_names) | ||
| fig = make_subplots( | ||
| rows=1, | ||
| cols=cols, | ||
|
|
@@ -289,8 +308,8 @@ def render_html(self, obj: ClassificationQualityByFeatureTable) -> List[BaseWidg | |
| x=current_data[current_data[target_name] == label][feature_name], | ||
| y=current_data[current_data[target_name] == label][label], | ||
| mode="markers", | ||
| name=str(label), | ||
| legendgroup=str(label), | ||
| name=display_label, | ||
| legendgroup=display_label, | ||
| marker=dict( | ||
| size=6, | ||
| # set color equal to a variable | ||
|
|
@@ -327,8 +346,8 @@ def render_html(self, obj: ClassificationQualityByFeatureTable) -> List[BaseWidg | |
| x=reference_data[reference_data[target_name] == label][feature_name], | ||
| y=reference_data[reference_data[target_name] == label][label], | ||
| mode="markers", | ||
| name=str(label), | ||
| legendgroup=str(label), | ||
| name=display_label, | ||
| legendgroup=display_label, | ||
| showlegend=False, | ||
| marker=dict(size=6, color=color_options.get_current_data_color()), | ||
| ), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Two things here. First, this only triggers when target_names is a dict, but TargetNames is typed as Union[List, Dict]. If someone passes a list, the resolver does nothing and falls through to str(label). Not sure if list-based target_names is actually used in practice, but worth checking.
Second, the or on this line can swallow falsy mapped values. If target_names maps a label to 0 or an empty string, the or skips it and tries int(label) instead. And if label is a non-numeric string like "cat", int(label) raises ValueError. Safer to check
is Noneexplicitly rather than relying on truthiness.