Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ def render_html(self, obj: ColumnDriftMetric) -> List[BaseWidgetInfo]:
y_name=result.column_name,
x_name=result.scatter.x_name,
color_options=self.color_options,
current_title=obj.get_options().render_options.current_title,
)
else:
scatter_fig = plot_agg_line_data(
Expand Down
28 changes: 22 additions & 6 deletions src/evidently/legacy/metrics/data_drift/column_value_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,15 @@ def _make_df_for_plot(self, df, column_name: str, datetime_column_name: Optional

@default_renderer(wrap_type=ColumnValuePlot)
class ColumnValuePlotRenderer(MetricRenderer):
def render_raw(self, current_scatter, reference_scatter, column_name, datetime_column_name):
def render_raw(
self,
current_scatter,
reference_scatter,
column_name,
datetime_column_name,
current_title: str = "Current",
reference_title: str = "Reference",
):
# todo: better typing
column = reference_scatter[column_name]
if not isinstance(column, pd.Series):
Expand Down Expand Up @@ -156,7 +164,7 @@ def render_raw(self, current_scatter, reference_scatter, column_name, datetime_c
x=curr_x,
y=current_scatter[column_name],
mode="markers",
name="Current",
name=current_title,
marker=dict(size=6, color=color_options.get_current_data_color()),
)
)
Expand All @@ -165,7 +173,7 @@ def render_raw(self, current_scatter, reference_scatter, column_name, datetime_c
x=ref_x,
y=column,
mode="markers",
name="Reference",
name=reference_title,
marker=dict(size=6, color=color_options.get_reference_data_color()),
)
)
Expand All @@ -177,7 +185,7 @@ def render_raw(self, current_scatter, reference_scatter, column_name, datetime_c
x=[x0, x0],
y=[y0, y1],
mode="markers",
name="Current",
name=current_title,
marker=dict(size=0.01, color=color_options.non_visible_color, opacity=0.005),
showlegend=False,
)
Expand Down Expand Up @@ -247,8 +255,16 @@ def render_agg(self, current, reference, column_name, datetime_column_name, pref

def render_html(self, obj: ColumnValuePlot) -> List[BaseWidgetInfo]:
    """Render the widgets for a ``ColumnValuePlot`` metric.

    The renderer is chosen from ``render_options.raw_data``: when it is set,
    ``render_raw`` is used and receives the configured ``current_title`` and
    ``reference_title`` (used there as trace names); otherwise ``render_agg``
    is used.

    Args:
        obj: Metric instance providing the computed result and the options.

    Returns:
        The widgets produced by ``render_raw`` or ``render_agg``.
    """
    result = obj.get_result()
    render_options = obj.get_options().render_options
    if not render_options.raw_data:
        return self.render_agg(
            result.current, result.reference, result.column_name, result.datetime_column_name, result.prefix
        )
    # Exactly one raw-data path: an earlier title-less early return would make
    # the custom titles unreachable, so the titles are always forwarded here.
    return self.render_raw(
        result.current,
        result.reference,
        result.column_name,
        result.datetime_column_name,
        current_title=render_options.current_title,
        reference_title=render_options.reference_title,
    )
4 changes: 4 additions & 0 deletions src/evidently/legacy/metrics/data_drift/data_drift_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def _generate_column_params(
agg_data: bool,
current_fi: Optional[Dict[str, float]] = None,
reference_fi: Optional[Dict[str, float]] = None,
current_title: str = "Current",
) -> Optional[RichTableDataRow]:
details = RowDetails()
if data.column_type == "text":
Expand Down Expand Up @@ -232,6 +233,7 @@ def _generate_column_params(
y_name=data.column_name,
x_name=data.scatter.x_name,
color_options=self.color_options,
current_title=current_title,
)
else:
scatter_fig = plot_agg_line_data(
Expand Down Expand Up @@ -309,13 +311,15 @@ def render_html(self, obj: DataDriftTable) -> List[BaseWidgetInfo]:

columns = columns + all_columns

current_title = obj.get_options().render_options.current_title
for column_name in columns:
column_params = self._generate_column_params(
column_name,
results.drift_by_columns[column_name],
agg_data,
results.current_fi,
results.reference_fi,
current_title=current_title,
)

if column_params is not None:
Expand Down
48 changes: 37 additions & 11 deletions src/evidently/legacy/metrics/data_drift/target_by_features_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ def calculate(self, data: InputData) -> TargetByFeaturesTableResults:
@default_renderer(wrap_type=TargetByFeaturesTable)
class TargetByFeaturesTableRenderer(MetricRenderer):
def render_html(self, obj: TargetByFeaturesTable) -> List[BaseWidgetInfo]:
if not obj.get_options().render_options.raw_data:
render_options = obj.get_options().render_options
if not render_options.raw_data:
return []
result = obj.get_result()
current_data = result.current.plot_data
Expand All @@ -229,6 +230,8 @@ def render_html(self, obj: TargetByFeaturesTable) -> List[BaseWidgetInfo]:
ref_predictions = result.reference.predictions
columns = result.columns
task = result.task
current_title = render_options.current_title
reference_title = render_options.reference_title
if curr_predictions is not None and ref_predictions is not None:
current_data["prediction_labels"] = curr_predictions.predictions.values
reference_data["prediction_labels"] = ref_predictions.predictions.values
Expand All @@ -243,9 +246,13 @@ def render_html(self, obj: TargetByFeaturesTable) -> List[BaseWidgetInfo]:
if target_name is not None:
parts.append({"title": "Target", "id": feature_name + "_target_values"})
if task == "regression":
target_fig = self._get_regression_fig(feature_name, target_name, current_data, reference_data)
target_fig = self._get_regression_fig(
feature_name, target_name, current_data, reference_data, current_title, reference_title
)
else:
target_fig = self._get_classification_fig(feature_name, target_name, current_data, reference_data)
target_fig = self._get_classification_fig(
feature_name, target_name, current_data, reference_data, current_title, reference_title
)

target_fig_json = json.loads(target_fig.to_json())

Expand All @@ -263,11 +270,16 @@ def render_html(self, obj: TargetByFeaturesTable) -> List[BaseWidgetInfo]:
parts.append({"title": "Prediction", "id": feature_name + "_prediction_values"})
if task == "regression":
preds_fig = self._get_regression_fig(
feature_name, "prediction_labels", current_data, reference_data
feature_name, "prediction_labels", current_data, reference_data, current_title, reference_title
)
else:
preds_fig = self._get_classification_fig(
feature_name, "prediction_labels", current_data, reference_data
feature_name,
"prediction_labels",
current_data,
reference_data,
current_title,
reference_title,
)
preds_fig_json = json.loads(preds_fig.to_json())

Expand Down Expand Up @@ -304,8 +316,16 @@ def render_html(self, obj: TargetByFeaturesTable) -> List[BaseWidgetInfo]:
)
]

def _get_regression_fig(self, feature_name: str, main_column: str, curr_data: pd.DataFrame, ref_data: pd.DataFrame):
fig = make_subplots(rows=1, cols=2, subplot_titles=("Current", "Reference"), shared_yaxes=True)
def _get_regression_fig(
self,
feature_name: str,
main_column: str,
curr_data: pd.DataFrame,
ref_data: pd.DataFrame,
current_title: str = "Current",
reference_title: str = "Reference",
):
fig = make_subplots(rows=1, cols=2, subplot_titles=(current_title, reference_title), shared_yaxes=True)
fig.add_trace(
go.Scattergl(
x=curr_data[feature_name],
Expand Down Expand Up @@ -336,20 +356,26 @@ def _get_regression_fig(self, feature_name: str, main_column: str, curr_data: pd
return fig

def _get_classification_fig(
    self,
    feature_name: str,
    main_column: str,
    curr_data: pd.DataFrame,
    ref_data: pd.DataFrame,
    current_title: str = "Current",
    reference_title: str = "Reference",
):
    """Build a faceted histogram of a feature, split by dataset and coloured by ``main_column``.

    Args:
        feature_name: Column plotted on the x axis.
        main_column: Column used to colour the bars.
        curr_data: Current dataset; not modified.
        ref_data: Reference dataset; not modified.
        current_title: Label written to the ``dataset`` column for current rows.
        reference_title: Label written to the ``dataset`` column for reference rows.

    Returns:
        The figure returned by ``px.histogram``, with one facet column per
        dataset label, ordered current first and reference second.
    """
    # assign() returns a new frame, so the caller's data is left untouched.
    # An existing "dataset" column in the inputs would be overwritten here.
    ref = ref_data.assign(dataset=reference_title)
    curr = curr_data.assign(dataset=current_title)
    merged_data = pd.concat([ref, curr])
    # The facet order must use the same labels written above; the literals
    # "Current"/"Reference" would not match custom titles.
    fig = px.histogram(
        merged_data,
        x=feature_name,
        color=main_column,
        facet_col="dataset",
        barmode="overlay",
        category_orders={"dataset": [current_title, reference_title]},
    )

    return fig
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def render_pandas(self, obj: TextDescriptorsDriftMetric) -> pd.DataFrame:
return pd.concat([v.get_pandas() for v in result.drift_by_columns.values()])

def _generate_column_params(
self, column_name: str, data: ColumnDataDriftMetrics, agg_data: bool
self, column_name: str, data: ColumnDataDriftMetrics, agg_data: bool, current_title: str = "Current"
) -> Optional[RichTableDataRow]:
details = RowDetails()
if (
Expand All @@ -187,6 +187,7 @@ def _generate_column_params(
y_name=data.column_name,
x_name=data.scatter.x_name,
color_options=self.color_options,
current_title=current_title,
)
else:
scatter_fig = plot_agg_line_data(
Expand Down Expand Up @@ -253,8 +254,11 @@ def render_html(self, obj: TextDescriptorsDriftMetric) -> List[BaseWidgetInfo]:
reverse=True,
)

current_title = obj.get_options().render_options.current_title
for column_name in columns:
column_params = self._generate_column_params(column_name, results.drift_by_columns[column_name], agg_data)
column_params = self._generate_column_params(
column_name, results.drift_by_columns[column_name], agg_data, current_title
)

if column_params is not None:
params_data.append(column_params)
Expand Down
2 changes: 2 additions & 0 deletions src/evidently/legacy/options/agg_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

# Options that control how metric widgets are rendered.
class RenderOptions(Option):
    # When True, renderers take their raw-data path; off by default.
    raw_data: bool = False
    # Label used for the current dataset in rendered plots.
    current_title: str = "Current"
    # Label used for the reference dataset in rendered plots.
    reference_title: str = "Reference"


class DataDefinitionOptions(Option):
Expand Down
11 changes: 9 additions & 2 deletions src/evidently/legacy/utils/visualizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,7 +1049,14 @@ def plot_line_in_time(


def plot_scatter_for_data_drift(
curr_y: list, curr_x: list, y0: float, y1: float, y_name: str, x_name: str, color_options: ColorOptions
curr_y: list,
curr_x: list,
y0: float,
y1: float,
y_name: str,
x_name: str,
color_options: ColorOptions,
current_title: str = "Current",
):
fig = go.Figure()

Expand All @@ -1073,7 +1080,7 @@ def plot_scatter_for_data_drift(
x=curr_x,
y=curr_y,
mode="markers",
name="Current",
name=current_title,
marker=dict(size=6, color=color_options.get_current_data_color()),
)
)
Expand Down
15 changes: 15 additions & 0 deletions tests/metrics/data_drift/test_column_drift_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from evidently.legacy.calculations.stattests.registry import add_stattest_impl
from evidently.legacy.core import ColumnType
from evidently.legacy.metrics import ColumnDriftMetric
from evidently.legacy.options.agg_data import RenderOptions
from evidently.legacy.options.base import Options
from evidently.legacy.pipeline.column_mapping import ColumnMapping
from evidently.legacy.report import Report

Expand Down Expand Up @@ -81,3 +83,16 @@ def test_column_drift_metric_errors(
with pytest.raises(ValueError, match=expected_error):
report.run(current_data=current_data, reference_data=reference_data, column_mapping=data_mapping)
report.json()


def test_column_drift_metric_custom_titles() -> None:
    """Smoke test: a raw-data report with custom dataset titles builds dashboard info."""
    render = RenderOptions(raw_data=True, current_title="Test", reference_title="Baseline")
    report = Report(
        metrics=[ColumnDriftMetric(column_name="col")],
        options=Options(render=render),
    )
    report.run(
        current_data=pd.DataFrame({"col": range(20)}),
        reference_data=pd.DataFrame({"col": range(10, 30)}),
    )
    # Only checks that rendering succeeds; the titles themselves are not inspected.
    dashboard_info = report._build_dashboard_info()
    assert dashboard_info is not None