diff --git a/runtime/databricks/automl_runtime/forecast/deepar/model.py b/runtime/databricks/automl_runtime/forecast/deepar/model.py index 137c37a..5c13e4a 100644 --- a/runtime/databricks/automl_runtime/forecast/deepar/model.py +++ b/runtime/databricks/automl_runtime/forecast/deepar/model.py @@ -100,7 +100,7 @@ def predict(self, pred_df = pred_df.rename(columns={'index': self._time_col}) if self._id_cols: - id_col_name = '-'.join(self._id_cols) + id_col_name = self._id_cols[0] pred_df = pred_df.rename(columns={'item_id': id_col_name}) else: pred_df = pred_df.drop(columns='item_id') diff --git a/runtime/databricks/automl_runtime/forecast/deepar/utils.py b/runtime/databricks/automl_runtime/forecast/deepar/utils.py index 016de93..fbf6475 100644 --- a/runtime/databricks/automl_runtime/forecast/deepar/utils.py +++ b/runtime/databricks/automl_runtime/forecast/deepar/utils.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import List, Optional +from typing import List, Optional, Union, Dict import pandas as pd - def validate_and_generate_index(df: pd.DataFrame, time_col: str, frequency_unit: str, @@ -66,10 +65,12 @@ def validate_and_generate_index(df: pd.DataFrame, return new_index_full -def set_index_and_fill_missing_time_steps(df: pd.DataFrame, time_col: str, - frequency_unit: str, - frequency_quantity: int, - id_cols: Optional[List[str]] = None): +def set_index_and_fill_missing_time_steps( + df: pd.DataFrame, time_col: str, + frequency_unit: str, + frequency_quantity: int, + id_cols: Optional[List[str]] = None +) -> Union[pd.DataFrame, Dict[any, pd.DataFrame]]: """ Transform the input dataframe to an acceptable format for the GluonTS library. @@ -95,14 +96,16 @@ def set_index_and_fill_missing_time_steps(df: pd.DataFrame, time_col: str, valid_index = validate_and_generate_index(df=df, time_col=time_col, frequency_unit=frequency_unit, frequency_quantity=frequency_quantity) if id_cols is not None: + if len(id_cols) > 1: + raise ValueError("DeepAR does not support multiple time series id columns") df_dict = {} for grouped_id, grouped_df in df.groupby(id_cols): if isinstance(grouped_id, tuple): - ts_id = "-".join([str(x) for x in grouped_id]) - else: - ts_id = str(grouped_id) - df_dict[ts_id] = (grouped_df.set_index(time_col).sort_index() - .reindex(valid_index).drop(id_cols, axis=1)) + # TODO (ML-52171): Fix the DeepAR library to support multi-time series id columns + # For now, DeepAR is dropped for multiple id_cols + raise ValueError("DeepAR does not support multiple time series id columns") + df_dict[grouped_id] = (grouped_df.set_index(time_col).sort_index() + .reindex(valid_index).drop(id_cols, axis=1)) return df_dict