Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from evidently.legacy.metric_results import raw_agg_properties
from evidently.legacy.model.widget import BaseWidgetInfo
from evidently.legacy.options.base import AnyOptions
from evidently.legacy.pipeline.column_mapping import TargetNames
from evidently.legacy.renderers.base_renderer import MetricRenderer
from evidently.legacy.renderers.base_renderer import default_renderer
from evidently.legacy.renderers.html_widgets import TabData
Expand All @@ -40,6 +41,7 @@ class Config:
}

target_name: str
target_names: Optional[TargetNames] = None

current: Optional[ColumnScatterOrAgg] = None
current_raw, current_agg = raw_agg_properties("current", ColumnScatter, ColumnAggScatter, True)
Expand Down Expand Up @@ -107,6 +109,7 @@ def calculate(self, data: InputData) -> ClassificationClassSeparationPlotResults
current=column_scatter_from_df(current_plot, True),
reference=column_scatter_from_df(reference_plot, True),
target_name=target_name,
target_names=dataset_columns.target_names,
)
current_plot = prepare_box_data(current_plot, target_name, prediction_names.tolist())
if reference_plot is not None:
Expand All @@ -115,9 +118,18 @@ def calculate(self, data: InputData) -> ClassificationClassSeparationPlotResults
current=current_plot,
reference=reference_plot,
target_name=target_name,
target_names=dataset_columns.target_names,
)


def _resolve_target_name(label, target_names: Optional[TargetNames]) -> str:
if target_names is not None and isinstance(target_names, dict):
resolved = target_names.get(label) or target_names.get(int(label)) if isinstance(label, str) else target_names.get(label) # type: ignore[arg-type]

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two things here. First, this only triggers when target_names is a dict, but TargetNames is typed as Union[List, Dict]. If someone passes a list, the resolver does nothing and falls through to str(label). Not sure if list-based target_names is actually used in practice, but worth checking.

Second, the or on this line can swallow falsy mapped values. If target_names maps a label to 0 or an empty string, the or skips it and tries int(label) instead. And if label is a non-numeric string like "cat", int(label) raises ValueError. Safer to check is None explicitly rather than relying on truthiness.

if resolved is not None:
return str(resolved)
return str(label)


@default_renderer(wrap_type=ClassificationClassSeparationPlot)
class ClassificationClassSeparationPlotRenderer(MetricRenderer):
def render_html(self, obj: ClassificationClassSeparationPlot) -> List[BaseWidgetInfo]:
Expand Down Expand Up @@ -150,7 +162,7 @@ def render_html(self, obj: ClassificationClassSeparationPlot) -> List[BaseWidget
target_name,
color_options=self.color_options,
)
tabs = [TabData(name, widget) for name, widget in tab_data]
tabs = [TabData(_resolve_target_name(name, metric_result.target_names), widget) for name, widget in tab_data]

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tab names are resolved here which is good, but the inner traces built by class_separation_traces_raw and class_separation_traces_agg still use str(label) for name and legendgroup. So the tab heading will show the display name but the plot legend inside the tab still shows the raw class value. Might want to pass target_names into those trace-building helpers too.

return [
header_text(label="Class Separation Quality"),
widget_tabs(title="", tabs=tabs),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from evidently.legacy.calculations.classification_performance import get_prediction_data
from evidently.legacy.core import IncludeTags
from evidently.legacy.model.widget import BaseWidgetInfo
from evidently.legacy.pipeline.column_mapping import TargetNames
from evidently.legacy.renderers.base_renderer import MetricRenderer
from evidently.legacy.renderers.base_renderer import default_renderer
from evidently.legacy.renderers.html_widgets import GraphData
Expand All @@ -32,6 +33,7 @@ class Config:

current_distribution: Optional[Dict[str, list]] # todo use DistributionField?
reference_distribution: Optional[Dict[str, list]]
target_names: Optional[TargetNames] = None


class ClassificationProbDistribution(Metric[ClassificationProbDistributionResults]):
Expand Down Expand Up @@ -93,19 +95,29 @@ def calculate(self, data: InputData) -> ClassificationProbDistributionResults:
return ClassificationProbDistributionResults(
current_distribution=current_distribution,
reference_distribution=reference_distribution,
target_names=columns.target_names,
)


@default_renderer(wrap_type=ClassificationProbDistribution)
class ClassificationProbDistributionRenderer(MetricRenderer):
def _plot(self, distribution: Dict[str, list]):
@staticmethod
def _resolve_target_name(label, target_names: Optional[TargetNames]) -> str:
if target_names is not None and isinstance(target_names, dict):
resolved = target_names.get(label) or target_names.get(int(label)) if isinstance(label, str) else target_names.get(label) # type: ignore[arg-type]
if resolved is not None:
return str(resolved)
return str(label)

def _plot(self, distribution: Dict[str, list], target_names: Optional[TargetNames] = None):
# plot distributions
graphs = []

for label in distribution:
display_name = self._resolve_target_name(label, target_names)
pred_distr = ff.create_distplot(
distribution[label],
[str(label), "other"],
[display_name, "other"],
colors=[
self.color_options.primary_color,
self.color_options.secondary_color,
Expand All @@ -123,7 +135,7 @@ def _plot(self, distribution: Dict[str, list]):
pred_distr_json = pred_distr.to_plotly_json()
graphs.append(
{
"title": str(label),
"title": display_name,
"data": pred_distr_json["data"],
"layout": pred_distr_json["layout"],
}
Expand All @@ -134,6 +146,7 @@ def render_html(self, obj: ClassificationProbDistribution) -> List[BaseWidgetInf
metric_result = obj.get_result()
reference_distribution = metric_result.reference_distribution
current_distribution = metric_result.current_distribution
target_names = metric_result.target_names
result = []
size = WidgetSize.FULL

Expand All @@ -147,7 +160,7 @@ def render_html(self, obj: ClassificationProbDistribution) -> List[BaseWidgetInf
size=size,
figures=[
GraphData(graph["title"], graph["data"], graph["layout"])
for graph in self._plot(current_distribution)
for graph in self._plot(current_distribution, target_names)
],
)
)
Expand All @@ -159,7 +172,7 @@ def render_html(self, obj: ClassificationProbDistribution) -> List[BaseWidgetInf
size=size,
figures=[
GraphData(graph["title"], graph["data"], graph["layout"])
for graph in self._plot(reference_distribution)
for graph in self._plot(reference_distribution, target_names)
],
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from evidently.legacy.model.widget import AdditionalGraphInfo
from evidently.legacy.model.widget import BaseWidgetInfo
from evidently.legacy.options.base import AnyOptions
from evidently.legacy.pipeline.column_mapping import TargetNames
from evidently.legacy.renderers.base_renderer import MetricRenderer
from evidently.legacy.renderers.base_renderer import default_renderer
from evidently.legacy.renderers.html_widgets import header_text
Expand All @@ -47,6 +48,7 @@ class Config:

target_name: str
columns: List[str]
target_names: Optional[TargetNames] = None


class ClassificationQualityByFeatureTable(UsesRawDataMixin, Metric[ClassificationQualityByFeatureTableResults]):
Expand Down Expand Up @@ -173,18 +175,28 @@ def calculate(self, data: InputData) -> ClassificationQualityByFeatureTableResul
reference=reference,
columns=columns,
target_name=target_name,
target_names=dataset_columns.target_names,
)


@default_renderer(wrap_type=ClassificationQualityByFeatureTable)
class ClassificationQualityByFeatureTableRenderer(MetricRenderer):
@staticmethod
def _resolve_target_name(label, target_names: Optional[TargetNames]) -> str:
if target_names is not None and isinstance(target_names, dict):

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same resolver function as the other two files. If the operator precedence or dict-only issue needs fixing, it has to be done in three places. Pulling this into something like evidently.legacy.utils or a shared base would keep it maintainable.

resolved = target_names.get(label) or target_names.get(int(label)) if isinstance(label, str) else target_names.get(label) # type: ignore[arg-type]
if resolved is not None:
return str(resolved)
return str(label)

def render_html(self, obj: ClassificationQualityByFeatureTable) -> List[BaseWidgetInfo]:
if not obj.get_options().render_options.raw_data:
return []
result = obj.get_result()
current_data = result.current.plot_data
reference_data = result.reference.plot_data if result.reference is not None else None
target_name = result.target_name
target_names = result.target_names
curr_predictions = result.current.predictions
# todo: better typing?
assert curr_predictions is not None
Expand All @@ -211,7 +223,13 @@ def render_html(self, obj: ClassificationQualityByFeatureTable) -> List[BaseWidg
{
"details": {
"parts": [{"title": "All", "id": "All" + "_" + str(feature_name)}]
+ [{"title": str(label), "id": feature_name + "_" + str(label)} for label in labels],
+ [
{
"title": self._resolve_target_name(label, target_names),
"id": feature_name + "_" + str(label),
}
for label in labels
],
"insights": [],
},
"f1": feature_name,
Expand Down Expand Up @@ -276,6 +294,7 @@ def render_html(self, obj: ClassificationQualityByFeatureTable) -> List[BaseWidg
cols = 1
subplot_titles = [""]
for label in labels:
display_label = self._resolve_target_name(label, target_names)
fig = make_subplots(
rows=1,
cols=cols,
Expand All @@ -289,8 +308,8 @@ def render_html(self, obj: ClassificationQualityByFeatureTable) -> List[BaseWidg
x=current_data[current_data[target_name] == label][feature_name],
y=current_data[current_data[target_name] == label][label],
mode="markers",
name=str(label),
legendgroup=str(label),
name=display_label,
legendgroup=display_label,
marker=dict(
size=6,
# set color equal to a variable
Expand Down Expand Up @@ -327,8 +346,8 @@ def render_html(self, obj: ClassificationQualityByFeatureTable) -> List[BaseWidg
x=reference_data[reference_data[target_name] == label][feature_name],
y=reference_data[reference_data[target_name] == label][label],
mode="markers",
name=str(label),
legendgroup=str(label),
name=display_label,
legendgroup=display_label,
showlegend=False,
marker=dict(size=6, color=color_options.get_current_data_color()),
),
Expand Down