diff --git a/confidence_interval_display.py b/confidence_interval_display.py index b69956d..5c41106 100644 --- a/confidence_interval_display.py +++ b/confidence_interval_display.py @@ -313,6 +313,8 @@ class MetricFormatter(object): if_flip_color: A boolean indicating if to flip green/red coloring scheme. hide_null_ctrl: If to hide control value or use '-' to represent it when it is null, + neutral_value: The value representing no change, used to determine coloring + thresholds. form_lookup: A dict to look up formatting str for the display. unit_lookup: A dict to look up the unit to append to numbers in display. @@ -321,16 +323,20 @@ class MetricFormatter(object): include html classes used to style with CSS. """ - def __init__(self, - metric_formats=None, - if_flip_color=None, - hide_null_ctrl=False): + def __init__( + self, + metric_formats=None, + if_flip_color=None, + hide_null_ctrl=False, + neutral_value=0.0, + ): metric_formats = metric_formats or {} metric_formats.setdefault('Value', 'absolute') metric_formats.setdefault('Ratio', 'absolute') self.if_flip_color = if_flip_color self.hide_null_ctrl = hide_null_ctrl self.metric_formats = metric_formats + self.neutral_value = neutral_value self.form_lookup = { 'percent': '{:.2f}', 'absolute': '{:.4f}', @@ -406,13 +412,15 @@ def __call__(self, x, div=_div, span=_span, line_break_join=LINE_BREAK.join): res = line_break_join([value_formatted, ratio_formatted, ci_formatted]) res = div(res) - ci_lower = ci_lower if ci_lower is not None else 0 - ci_upper = ci_upper if ci_upper is not None else 0 - if ((ci_lower > 0 and not self.if_flip_color) or - (ci_upper < 0 and self.if_flip_color)): + ci_lower = ci_lower if ci_lower is not None else self.neutral_value + ci_upper = ci_upper if ci_upper is not None else self.neutral_value + if (ci_lower > self.neutral_value and not self.if_flip_color) or ( + ci_upper < self.neutral_value and self.if_flip_color + ): return div(res, 'ci-display-good-change ci-display-cell') - if 
((ci_upper < 0 and not self.if_flip_color) or - (ci_lower > 0 and self.if_flip_color)): + if (ci_upper < self.neutral_value and not self.if_flip_color) or ( + ci_lower > self.neutral_value and self.if_flip_color + ): return div(res, 'ci-display-bad-change ci-display-cell') return div(res, 'ci-display-cell') @@ -449,11 +457,14 @@ def dimension_formatter(x, return div(div(line_break_join(d))) -def _get_formatter(df, - dims, - if_flip_colors, - hide_null_ctrl=False, - metric_formats=None): +def _get_formatter( + df, + dims, + if_flip_colors, + hide_null_ctrl=False, + metric_formats=None, + neutral_values=None, +): """Returns a custom formatter for df. Args: @@ -468,6 +479,7 @@ def _get_formatter(df, 'Value' and 'Ratio'. Values can be 'absolute', 'percent', 'pp' or a formatting string. For example, '{:.2%}' would have the same effect as 'percent'. By default, Value is in absolute form and Ratio in percent. + neutral_values: A dict mapping metric names to their neutral values. Returns: A dict which can be used as a custom formatter for @@ -480,8 +492,19 @@ def _get_formatter(df, elif col == 'Dimensions': custom_formatter[i] = dimension_formatter else: - custom_formatter[i] = MetricFormatter(metric_formats, if_flip_colors[i], - hide_null_ctrl) + col_formats = {'Value': 'absolute', 'Ratio': 'absolute'} + if metric_formats: + col_formats['Value'] = metric_formats.get('Value', col_formats['Value']) + col_formats['Ratio'] = metric_formats.get('Ratio', col_formats['Ratio']) + if col in metric_formats and isinstance(metric_formats[col], dict): + col_formats.update(metric_formats[col]) + neutral_val = neutral_values.get(col, 0.0) if neutral_values else 0.0 + custom_formatter[i] = MetricFormatter( + col_formats, + if_flip_colors[i], + hide_null_ctrl, + neutral_value=neutral_val, + ) return custom_formatter @@ -583,6 +606,7 @@ def get_formatted_df( auto_add_description=True, show_metric_value_when_control_hidden=False, return_pre_agg_df=False, + neutral_values=None, ): """Gets 
the formatted df with raw HTML as values in every cell. @@ -591,7 +615,7 @@ def get_formatted_df( description (if provided) experiment id dim1 * dim2 *... - If not aggregate_dimensions, the dimension info will be spreaded into multple + If not aggregate_dimensions, the dimension info will be spread into multiple columns at left of the display. All the remaining columns are for metrics. Each cell, if in control rows, will display a single value, for the experiment rows, will show three rows, @@ -670,6 +694,9 @@ def get_formatted_df( False. If True, we also display the raw metric value, otherwise only the change and confidence interval are displayed. return_pre_agg_df: If to return the pre-aggregated df. + neutral_values: A dict mapping metric names to their neutral values. The + neutral value is used to determine if a change is good/bad for coloring. + Defaults to 0.0 for all metrics. Returns: A DataFrame with raw HTML in each cell ready to be rendered. If @@ -731,8 +758,14 @@ def get_formatted_df( flip_color = flip_color or [] if_flip_colors = [c in flip_color for c in formatted_df.columns] - custom_formatters = _get_formatter(formatted_df, dims, if_flip_colors, - hide_null_ctrl, metric_formats) + custom_formatters = _get_formatter( + formatted_df, + dims, + if_flip_colors, + hide_null_ctrl, + metric_formats, + neutral_values=neutral_values, + ) formatted_df = formatted_df.rename(columns={ 'Experiment_Id': expr_id, 'Description': description @@ -793,6 +826,7 @@ def render( show_metric_value_when_control_hidden=False, return_pre_agg_df=False, return_formatted_df=False, + neutral_values=None, ): """Gets the formatted df with raw HTML as values in every cell. @@ -801,7 +835,7 @@ def render( description (if provided) experiment id dim1 * dim2 *... - If not aggregate_dimensions, the dimension info will be spreaded into multple + If not aggregate_dimensions, the dimension info will be spread into multiple columns at left of the display. 
All the remaining columns are for metrics. Each cell, if in control rows, will display a single value, for the experiment rows, will show three rows, @@ -881,6 +915,9 @@ def render( change and confidence interval are displayed. return_pre_agg_df: If to return the pre-aggregated df. return_formatted_df: If to return raw HTML df to be rendered. + neutral_values: A dict mapping metric names to their neutral values. The + neutral value is used to determine if a change is good/bad for coloring. + Defaults to 0.0 for all metrics. Returns: Displays confidence interval nicely for df, or aggregated/formatted if @@ -911,6 +948,7 @@ def render( auto_add_description, show_metric_value_when_control_hidden, return_pre_agg_df, + neutral_values=neutral_values, ) if return_pre_agg_df or return_formatted_df: return formatted_df diff --git a/confidence_interval_display_test.py b/confidence_interval_display_test.py index d9e679c..8d7ca1c 100644 --- a/confidence_interval_display_test.py +++ b/confidence_interval_display_test.py @@ -101,6 +101,55 @@ def test_normal(self): auto_add_description=False) testing.assert_frame_equal(expected, actual, check_names=False) + def test_get_formatted_df_with_neutral_values(self): + expected = pd.DataFrame( + { + 'Country': [ + '