diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index 90ad875b7..cfa0f9f68 100644 --- a/flixopt/statistics_accessor.py +++ b/flixopt/statistics_accessor.py @@ -145,6 +145,128 @@ def _reshape_time_for_heatmap( return result.transpose('timestep', 'timeframe', *other_dims) +def _iter_all_traces(fig: go.Figure): + """Iterate over all traces in a figure, including animation frames. + + Yields traces from fig.data first, then from each frame in fig.frames. + Useful for applying styling to all traces including those in animations. + + Args: + fig: Plotly Figure. + + Yields: + Each trace object from the figure. + """ + yield from fig.data + for frame in getattr(fig, 'frames', []) or []: + yield from frame.data + + +def _style_area_as_bar(fig: go.Figure) -> None: + """Style area chart traces to look like bar charts with proper pos/neg stacking. + + Iterates over all traces in fig.data and fig.frames (for animations), + setting stepped line shape, removing line borders, making fills opaque, + and assigning stackgroups based on whether values are positive or negative. + + Handles faceting + animation combinations by building color and classification + maps from trace names in the base figure. + + Args: + fig: Plotly Figure with area chart traces. + """ + import plotly.express as px + + default_colors = px.colors.qualitative.Plotly + + # Build color map and classify traces from base figure + # trace.name -> color, trace.name -> 'positive'|'negative'|'mixed'|'zero' + color_map: dict[str, str] = {} + class_map: dict[str, str] = {} + + for i, trace in enumerate(fig.data): + # Get color + if hasattr(trace, 'line') and trace.line and trace.line.color: + color_map[trace.name] = trace.line.color + else: + color_map[trace.name] = default_colors[i % len(default_colors)] + + # Classify based on y values + y_vals = trace.y + if y_vals is None or len(y_vals) == 0: + class_map[trace.name] = 'zero' + else: + y_arr = np.asarray(y_vals) + y_clean = y_arr[np.abs(y_arr) > 1e-9] + if len(y_clean) == 0: + class_map[trace.name] = 'zero' + else: + has_pos = np.any(y_clean > 0) + has_neg = np.any(y_clean < 0) + if has_pos and has_neg: + class_map[trace.name] = 'mixed' + elif has_neg: + class_map[trace.name] = 'negative' + else: + class_map[trace.name] = 'positive' + + def style_trace(trace: go.Scatter) -> None: + """Apply bar-like styling to a single trace.""" + # Look up color by trace name + color = color_map.get(trace.name, default_colors[0]) + + # Look up classification + cls = class_map.get(trace.name, 'positive') + + # Set stackgroup based on classification (positive and negative stack separately) + if cls in ('positive', 'negative'): + trace.stackgroup = cls + trace.fillcolor = color + trace.line = dict(width=0, color=color, shape='hv') + elif cls == 'mixed': + # Mixed: show as dashed line, no stacking + trace.stackgroup = None + trace.fill = None + trace.line = dict(width=2, color=color, shape='hv', dash='dash') + else: # zero + trace.stackgroup = None + trace.fill = None + trace.line = dict(width=0, color=color, shape='hv') + + # Style all traces (main + animation frames) + for trace in _iter_all_traces(fig): + style_trace(trace) + + +def _apply_unified_hover(fig: go.Figure, unit: str = '', decimals: int = 1) -> None: + """Apply unified hover mode with clean formatting to any Plotly figure. + + Sets up 'x unified' hovermode with spike lines and formats hover labels + as 'name: value unit'. + + Works with any plot type (area, bar, line, scatter). + + Args: + fig: Plotly Figure to style. + unit: Unit string to append (e.g., 'kW', 'MWh'). Empty for no unit. + decimals: Number of decimal places for values. + """ + unit_suffix = f' {unit}' if unit else '' + hover_template = f'%{{fullData.name}}: %{{y:.{decimals}f}}{unit_suffix}' + + # Apply to all traces (main + animation frames) + for trace in _iter_all_traces(fig): + trace.hovertemplate = hover_template + + # Layout settings for unified hover + fig.update_layout( + hovermode='x unified', + xaxis_showspikes=True, + xaxis_spikecolor='gray', + xaxis_spikethickness=1, + ) + + # --- Helper functions --- @@ -1529,13 +1651,14 @@ def balance( unit_label = ds[first_var].attrs.get('unit', '') _apply_slot_defaults(plotly_kwargs, 'balance') - fig = ds.plotly.bar( + fig = ds.plotly.area( title=f'{node} [{unit_label}]' if unit_label else node, + line_shape='hv', **color_kwargs, **plotly_kwargs, ) - fig.update_layout(barmode='relative', bargap=0, bargroupgap=0) - fig.update_traces(marker_line_width=0) + _style_area_as_bar(fig) + _apply_unified_hover(fig, unit=unit_label) if show is None: show = CONFIG.Plotting.default_show @@ -1653,13 +1776,14 @@ def carrier_balance( unit_label = ds[first_var].attrs.get('unit', '') _apply_slot_defaults(plotly_kwargs, 'carrier_balance') - fig = ds.plotly.bar( + fig = ds.plotly.area( title=f'{carrier.capitalize()} Balance [{unit_label}]' if unit_label else f'{carrier.capitalize()} Balance', + line_shape='hv', **color_kwargs, **plotly_kwargs, ) - fig.update_layout(barmode='relative', bargap=0, bargroupgap=0) - fig.update_traces(marker_line_width=0) + _style_area_as_bar(fig) + _apply_unified_hover(fig, unit=unit_label) if show is None: show = CONFIG.Plotting.default_show @@ -2249,15 +2373,22 @@ def storage( else: color_kwargs = _build_color_kwargs(colors, flow_labels) - # Create stacked bar chart for flows + # Get unit label from flow data + unit_label = '' + if flow_ds.data_vars: + first_var = next(iter(flow_ds.data_vars)) + unit_label = flow_ds[first_var].attrs.get('unit', '') + + # Create stacked area chart for flows (styled as bar) _apply_slot_defaults(plotly_kwargs, 'storage') - fig = flow_ds.plotly.bar( - title=f'{storage} Operation ({unit})', + fig = flow_ds.plotly.area( + title=f'{storage} Operation [{unit_label}]' if unit_label else f'{storage} Operation', + line_shape='hv', **color_kwargs, **plotly_kwargs, ) - fig.update_layout(barmode='relative', bargap=0, bargroupgap=0) - fig.update_traces(marker_line_width=0) + _style_area_as_bar(fig) + _apply_unified_hover(fig, unit=unit_label) # Add charge state as line on secondary y-axis # Only pass faceting kwargs that add_line_overlay accepts