diff --git a/src/evidently/legacy/metrics/data_drift/column_drift_metric.py b/src/evidently/legacy/metrics/data_drift/column_drift_metric.py index be7cabbbea..a6881cfbba 100644 --- a/src/evidently/legacy/metrics/data_drift/column_drift_metric.py +++ b/src/evidently/legacy/metrics/data_drift/column_drift_metric.py @@ -349,6 +349,7 @@ def render_html(self, obj: ColumnDriftMetric) -> List[BaseWidgetInfo]: y_name=result.column_name, x_name=result.scatter.x_name, color_options=self.color_options, + current_title=obj.get_options().render_options.current_title, ) else: scatter_fig = plot_agg_line_data( diff --git a/src/evidently/legacy/metrics/data_drift/column_value_plot.py b/src/evidently/legacy/metrics/data_drift/column_value_plot.py index e97a21f360..af8d83905f 100644 --- a/src/evidently/legacy/metrics/data_drift/column_value_plot.py +++ b/src/evidently/legacy/metrics/data_drift/column_value_plot.py @@ -127,7 +127,15 @@ def _make_df_for_plot(self, df, column_name: str, datetime_column_name: Optional @default_renderer(wrap_type=ColumnValuePlot) class ColumnValuePlotRenderer(MetricRenderer): - def render_raw(self, current_scatter, reference_scatter, column_name, datetime_column_name): + def render_raw( + self, + current_scatter, + reference_scatter, + column_name, + datetime_column_name, + current_title: str = "Current", + reference_title: str = "Reference", + ): # todo: better typing column = reference_scatter[column_name] if not isinstance(column, pd.Series): @@ -156,7 +164,7 @@ def render_raw(self, current_scatter, reference_scatter, column_name, datetime_c x=curr_x, y=current_scatter[column_name], mode="markers", - name="Current", + name=current_title, marker=dict(size=6, color=color_options.get_current_data_color()), ) ) @@ -165,7 +173,7 @@ def render_raw(self, current_scatter, reference_scatter, column_name, datetime_c x=ref_x, y=column, mode="markers", - name="Reference", + name=reference_title, marker=dict(size=6, color=color_options.get_reference_data_color()), ) ) @@ -177,7 +185,7 @@ def render_raw(self, current_scatter, reference_scatter, column_name, datetime_c x=[x0, x0], y=[y0, y1], mode="markers", - name="Current", + name=current_title, marker=dict(size=0.01, color=color_options.non_visible_color, opacity=0.005), showlegend=False, ) @@ -247,8 +255,16 @@ def render_agg(self, current, reference, column_name, datetime_column_name, pref def render_html(self, obj: ColumnValuePlot) -> List[BaseWidgetInfo]: result = obj.get_result() - if obj.get_options().render_options.raw_data: - return self.render_raw(result.current, result.reference, result.column_name, result.datetime_column_name) + render_options = obj.get_options().render_options + if render_options.raw_data: + return self.render_raw( + result.current, + result.reference, + result.column_name, + result.datetime_column_name, + current_title=render_options.current_title, + reference_title=render_options.reference_title, + ) return self.render_agg( result.current, result.reference, result.column_name, result.datetime_column_name, result.prefix ) diff --git a/src/evidently/legacy/metrics/data_drift/data_drift_table.py b/src/evidently/legacy/metrics/data_drift/data_drift_table.py index f4c7975b8e..3ec888049a 100644 --- a/src/evidently/legacy/metrics/data_drift/data_drift_table.py +++ b/src/evidently/legacy/metrics/data_drift/data_drift_table.py @@ -164,6 +164,7 @@ def _generate_column_params( agg_data: bool, current_fi: Optional[Dict[str, float]] = None, reference_fi: Optional[Dict[str, float]] = None, + current_title: str = "Current", ) -> Optional[RichTableDataRow]: details = RowDetails() if data.column_type == "text": @@ -232,6 +233,7 @@ def _generate_column_params( y_name=data.column_name, x_name=data.scatter.x_name, color_options=self.color_options, + current_title=current_title, ) else: scatter_fig = plot_agg_line_data( @@ -309,6 +311,7 @@ def render_html(self, obj: DataDriftTable) -> List[BaseWidgetInfo]: columns = columns + all_columns + current_title = obj.get_options().render_options.current_title for column_name in columns: column_params = self._generate_column_params( column_name, @@ -316,6 +319,7 @@ def render_html(self, obj: DataDriftTable) -> List[BaseWidgetInfo]: agg_data, results.current_fi, results.reference_fi, + current_title=current_title, ) if column_params is not None: diff --git a/src/evidently/legacy/metrics/data_drift/target_by_features_table.py b/src/evidently/legacy/metrics/data_drift/target_by_features_table.py index 48640787c6..ed4b0a9914 100644 --- a/src/evidently/legacy/metrics/data_drift/target_by_features_table.py +++ b/src/evidently/legacy/metrics/data_drift/target_by_features_table.py @@ -215,7 +215,8 @@ def calculate(self, data: InputData) -> TargetByFeaturesTableResults: @default_renderer(wrap_type=TargetByFeaturesTable) class TargetByFeaturesTableRenderer(MetricRenderer): def render_html(self, obj: TargetByFeaturesTable) -> List[BaseWidgetInfo]: - if not obj.get_options().render_options.raw_data: + render_options = obj.get_options().render_options + if not render_options.raw_data: return [] result = obj.get_result() current_data = result.current.plot_data @@ -229,6 +230,8 @@ def render_html(self, obj: TargetByFeaturesTable) -> List[BaseWidgetInfo]: ref_predictions = result.reference.predictions columns = result.columns task = result.task + current_title = render_options.current_title + reference_title = render_options.reference_title if curr_predictions is not None and ref_predictions is not None: current_data["prediction_labels"] = curr_predictions.predictions.values reference_data["prediction_labels"] = ref_predictions.predictions.values @@ -243,9 +246,13 @@ def render_html(self, obj: TargetByFeaturesTable) -> List[BaseWidgetInfo]: if target_name is not None: parts.append({"title": "Target", "id": feature_name + "_target_values"}) if task == "regression": - target_fig = self._get_regression_fig(feature_name, target_name, current_data, reference_data) + target_fig = self._get_regression_fig( + feature_name, target_name, current_data, reference_data, current_title, reference_title + ) else: - target_fig = self._get_classification_fig(feature_name, target_name, current_data, reference_data) + target_fig = self._get_classification_fig( + feature_name, target_name, current_data, reference_data, current_title, reference_title + ) target_fig_json = json.loads(target_fig.to_json()) @@ -263,11 +270,16 @@ def render_html(self, obj: TargetByFeaturesTable) -> List[BaseWidgetInfo]: parts.append({"title": "Prediction", "id": feature_name + "_prediction_values"}) if task == "regression": preds_fig = self._get_regression_fig( - feature_name, "prediction_labels", current_data, reference_data + feature_name, "prediction_labels", current_data, reference_data, current_title, reference_title ) else: preds_fig = self._get_classification_fig( - feature_name, "prediction_labels", current_data, reference_data + feature_name, + "prediction_labels", + current_data, + reference_data, + current_title, + reference_title, ) preds_fig_json = json.loads(preds_fig.to_json()) @@ -304,8 +316,16 @@ def render_html(self, obj: TargetByFeaturesTable) -> List[BaseWidgetInfo]: ) ] - def _get_regression_fig(self, feature_name: str, main_column: str, curr_data: pd.DataFrame, ref_data: pd.DataFrame): - fig = make_subplots(rows=1, cols=2, subplot_titles=("Current", "Reference"), shared_yaxes=True) + def _get_regression_fig( + self, + feature_name: str, + main_column: str, + curr_data: pd.DataFrame, + ref_data: pd.DataFrame, + current_title: str = "Current", + reference_title: str = "Reference", + ): + fig = make_subplots(rows=1, cols=2, subplot_titles=(current_title, reference_title), shared_yaxes=True) fig.add_trace( go.Scattergl( x=curr_data[feature_name], @@ -336,12 +356,18 @@ def _get_regression_fig(self, feature_name: str, main_column: str, curr_data: pd return fig def _get_classification_fig( - self, feature_name: str, main_column: str, curr_data: pd.DataFrame, ref_data: pd.DataFrame + self, + feature_name: str, + main_column: str, + curr_data: pd.DataFrame, + ref_data: pd.DataFrame, + current_title: str = "Current", + reference_title: str = "Reference", ): curr = curr_data.copy() ref = ref_data.copy() - ref["dataset"] = "Reference" - curr["dataset"] = "Current" + ref["dataset"] = reference_title + curr["dataset"] = current_title merged_data = pd.concat([ref, curr]) fig = px.histogram( merged_data, @@ -349,7 +375,7 @@ def _get_classification_fig( color=main_column, facet_col="dataset", barmode="overlay", - category_orders={"dataset": ["Current", "Reference"]}, + category_orders={"dataset": [current_title, reference_title]}, ) return fig diff --git a/src/evidently/legacy/metrics/data_drift/text_descriptors_drift_metric.py b/src/evidently/legacy/metrics/data_drift/text_descriptors_drift_metric.py index 4baf907ca2..840d7d7f1b 100644 --- a/src/evidently/legacy/metrics/data_drift/text_descriptors_drift_metric.py +++ b/src/evidently/legacy/metrics/data_drift/text_descriptors_drift_metric.py @@ -163,7 +163,7 @@ def render_pandas(self, obj: TextDescriptorsDriftMetric) -> pd.DataFrame: return pd.concat([v.get_pandas() for v in result.drift_by_columns.values()]) def _generate_column_params( - self, column_name: str, data: ColumnDataDriftMetrics, agg_data: bool + self, column_name: str, data: ColumnDataDriftMetrics, agg_data: bool, current_title: str = "Current" ) -> Optional[RichTableDataRow]: details = RowDetails() if ( @@ -187,6 +187,7 @@ def _generate_column_params( y_name=data.column_name, x_name=data.scatter.x_name, color_options=self.color_options, + current_title=current_title, ) else: scatter_fig = plot_agg_line_data( @@ -253,8 +254,11 @@ def render_html(self, obj: TextDescriptorsDriftMetric) -> List[BaseWidgetInfo]: reverse=True, ) + current_title = obj.get_options().render_options.current_title for column_name in columns: - column_params = self._generate_column_params(column_name, results.drift_by_columns[column_name], agg_data) + column_params = self._generate_column_params( + column_name, results.drift_by_columns[column_name], agg_data, current_title + ) if column_params is not None: params_data.append(column_params) diff --git a/src/evidently/legacy/options/agg_data.py b/src/evidently/legacy/options/agg_data.py index 2b205d2ae9..ccb531c652 100644 --- a/src/evidently/legacy/options/agg_data.py +++ b/src/evidently/legacy/options/agg_data.py @@ -5,6 +5,8 @@ class RenderOptions(Option): raw_data: bool = False + current_title: str = "Current" + reference_title: str = "Reference" class DataDefinitionOptions(Option): diff --git a/src/evidently/legacy/utils/visualizations.py b/src/evidently/legacy/utils/visualizations.py index 781d161306..c4a6159d77 100644 --- a/src/evidently/legacy/utils/visualizations.py +++ b/src/evidently/legacy/utils/visualizations.py @@ -1049,7 +1049,14 @@ def plot_line_in_time( def plot_scatter_for_data_drift( - curr_y: list, curr_x: list, y0: float, y1: float, y_name: str, x_name: str, color_options: ColorOptions + curr_y: list, + curr_x: list, + y0: float, + y1: float, + y_name: str, + x_name: str, + color_options: ColorOptions, + current_title: str = "Current", ): fig = go.Figure() @@ -1073,7 +1080,7 @@ def plot_scatter_for_data_drift( x=curr_x, y=curr_y, mode="markers", - name="Current", + name=current_title, marker=dict(size=6, color=color_options.get_current_data_color()), ) ) diff --git a/tests/metrics/data_drift/test_column_drift_metric.py b/tests/metrics/data_drift/test_column_drift_metric.py index ad8c89d95c..13ab75cafe 100644 --- a/tests/metrics/data_drift/test_column_drift_metric.py +++ b/tests/metrics/data_drift/test_column_drift_metric.py @@ -11,6 +11,8 @@ from evidently.legacy.calculations.stattests.registry import add_stattest_impl from evidently.legacy.core import ColumnType from evidently.legacy.metrics import ColumnDriftMetric +from evidently.legacy.options.agg_data import RenderOptions +from evidently.legacy.options.base import Options from evidently.legacy.pipeline.column_mapping import ColumnMapping from evidently.legacy.report import Report @@ -81,3 +83,16 @@ def test_column_drift_metric_errors( with pytest.raises(ValueError, match=expected_error): report.run(current_data=current_data, reference_data=reference_data, column_mapping=data_mapping) report.json() + + +def test_column_drift_metric_custom_titles() -> None: + current = pd.DataFrame({"col": range(20)}) + reference = pd.DataFrame({"col": range(10, 30)}) + metric = ColumnDriftMetric(column_name="col") + report = Report( + metrics=[metric], + options=Options(render=RenderOptions(raw_data=True, current_title="Test", reference_title="Baseline")), + ) + report.run(current_data=current, reference_data=reference) + html = report._build_dashboard_info() + assert html is not None