Wrap failures raised while fitting/predicting with Prophet in
InvalidPostProcessingError, reject forecasts over fewer than 2 rows, and
select the month/quarter/year frequency aliases based on the installed
pandas version (>= 2.2 or older).

Review note: test_prophet_fit_error patches Prophet.fit (raising a plain
ValueError) instead of mocking _prophet_fit_and_predict, so the new
try/except wrapper is the code the test actually exercises.

diff --git a/superset/utils/pandas_postprocessing/prophet.py b/superset/utils/pandas_postprocessing/prophet.py
index 85d5530937d8..0c71807da109 100644
--- a/superset/utils/pandas_postprocessing/prophet.py
+++ b/superset/utils/pandas_postprocessing/prophet.py
@@ -71,13 +71,21 @@ def _prophet_fit_and_predict(  # pylint: disable=too-many-arguments
     )
     if df["ds"].dt.tz:
         df["ds"] = df["ds"].dt.tz_convert(None)
-    model.fit(df)
-    future = model.make_future_dataframe(periods=periods, freq=freq)
-    forecast = model.predict(future)[["ds", "yhat", "yhat_lower", "yhat_upper"]]
+    try:
+        model.fit(df)
+        future = model.make_future_dataframe(periods=periods, freq=freq)
+        forecast = model.predict(future)[["ds", "yhat", "yhat_lower", "yhat_upper"]]
+    except Exception as ex:  # noqa: BLE001
+        raise InvalidPostProcessingError(
+            _(
+                "Unable to generate forecast: %(error)s",
+                error=str(ex),
+            )
+        ) from ex
     return forecast.join(df.set_index("ds"), on="ds").set_index(["ds"])
 
 
-def prophet(  # pylint: disable=too-many-arguments
+def prophet(  # pylint: disable=too-many-arguments  # noqa: C901
     df: DataFrame,
     time_grain: str,
     periods: int,
@@ -136,6 +144,8 @@ def prophet(  # pylint: disable=too-many-arguments
         raise InvalidPostProcessingError(_("DataFrame must include temporal column"))
     if len(df.columns) < 2:
         raise InvalidPostProcessingError(_("DataFrame include at least one series"))
+    if len(df) < 2:
+        raise InvalidPostProcessingError(_("Forecast requires at least 2 data points"))
 
     target_df = DataFrame()
 
diff --git a/superset/utils/pandas_postprocessing/utils.py b/superset/utils/pandas_postprocessing/utils.py
index 4d6884c8af0e..f41f4a67f3cf 100644
--- a/superset/utils/pandas_postprocessing/utils.py
+++ b/superset/utils/pandas_postprocessing/utils.py
@@ -26,6 +26,8 @@
 from superset.constants import TimeGrain
 from superset.exceptions import InvalidPostProcessingError
 
+_PANDAS_VERSION = tuple(int(x) for x in pd.__version__.split(".")[:2])
+
 NUMPY_FUNCTIONS: dict[str, Callable[..., Any]] = {
     "average": np.average,
     "argmin": np.argmin,
@@ -76,18 +78,18 @@
 )
 
 PROPHET_TIME_GRAIN_MAP: dict[str, str] = {
-    TimeGrain.SECOND: "S",
+    TimeGrain.SECOND: "s",
     TimeGrain.MINUTE: "min",
     TimeGrain.FIVE_MINUTES: "5min",
     TimeGrain.TEN_MINUTES: "10min",
     TimeGrain.FIFTEEN_MINUTES: "15min",
     TimeGrain.THIRTY_MINUTES: "30min",
-    TimeGrain.HOUR: "H",
+    TimeGrain.HOUR: "h",
     TimeGrain.DAY: "D",
     TimeGrain.WEEK: "W",
-    TimeGrain.MONTH: "M",
-    TimeGrain.QUARTER: "Q",
-    TimeGrain.YEAR: "A",
+    TimeGrain.MONTH: "ME" if _PANDAS_VERSION >= (2, 2) else "M",
+    TimeGrain.QUARTER: "QE" if _PANDAS_VERSION >= (2, 2) else "Q",
+    TimeGrain.YEAR: "YE" if _PANDAS_VERSION >= (2, 2) else "A",
     TimeGrain.WEEK_STARTING_SUNDAY: "W-SUN",
     TimeGrain.WEEK_STARTING_MONDAY: "W-MON",
     TimeGrain.WEEK_ENDING_SATURDAY: "W-SAT",
diff --git a/tests/unit_tests/pandas_postprocessing/test_prophet.py b/tests/unit_tests/pandas_postprocessing/test_prophet.py
index 7dacaeff9de1..c87c3790a95c 100644
--- a/tests/unit_tests/pandas_postprocessing/test_prophet.py
+++ b/tests/unit_tests/pandas_postprocessing/test_prophet.py
@@ -16,6 +16,7 @@
 # under the License.
 from datetime import datetime
 from importlib.util import find_spec
+from unittest.mock import patch
 
 import pandas as pd
 import pytest
@@ -186,6 +187,46 @@ def test_prophet_incorrect_time_grain():
         )
 
 
+def test_prophet_insufficient_data():
+    """A single-row DataFrame is rejected with a descriptive error."""
+    single_row_df = pd.DataFrame(
+        {
+            DTTM_ALIAS: [datetime(2022, 1, 1)],
+            "sales": [100.0],
+        }
+    )
+    with pytest.raises(InvalidPostProcessingError, match="at least 2 data points"):
+        prophet(
+            df=single_row_df,
+            time_grain="P1D",
+            periods=3,
+            confidence_interval=0.9,
+        )
+
+
+def test_prophet_fit_error():
+    """
+    An error raised by Prophet while fitting must surface as an
+    InvalidPostProcessingError that carries the underlying message.
+    """
+    if find_spec("prophet") is None:
+        pytest.skip("prophet not installed")
+
+    # Patch Prophet.fit (not _prophet_fit_and_predict) so the try/except
+    # wrapper inside _prophet_fit_and_predict is the code under test.
+    with patch("prophet.Prophet.fit") as mock_fit:
+        mock_fit.side_effect = ValueError("Dataframe has fewer than 2 non-NaN rows.")
+        with pytest.raises(
+            InvalidPostProcessingError, match="Unable to generate forecast"
+        ):
+            prophet(
+                df=prophet_df,
+                time_grain="P1D",
+                periods=3,
+                confidence_interval=0.9,
+            )
+
+
 def test_prophet_uncertainty_lower_bound_can_be_negative_for_negative_series():
     """
     Regression for #21734: when the input series contains negative values,