Skip to content
This repository was archived by the owner on May 7, 2026. It is now read-only.

Commit b29c6f3

Browse files
committed
move cache()
1 parent 9e2cd71 commit b29c6f3

1 file changed

Lines changed: 5 additions & 3 deletions

File tree

bigframes/ml/model_selection.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,19 +110,21 @@ def _stratify_split(df: bpd.DataFrame, stratify: bpd.Series) -> List[bpd.DataFra
110110
joined_df = dfs[0]
111111
for df in dfs[1:]:
112112
joined_df = joined_df.join(df, how="outer")
113-
joined_df = joined_df.cache()
114113
if stratify is None:
115114
joined_df_train, joined_df_test = joined_df._split(
116115
fracs=(train_size, test_size), random_state=random_state
117116
)
118117
else:
119118
joined_df_train, joined_df_test = _stratify_split(joined_df, stratify)
120119

120+
joined_df_train = joined_df_train.cache()
121+
joined_df_test = joined_df_test.cache()
122+
121123
results = []
122124
for array in arrays:
123125
columns = array.name if isinstance(array, bpd.Series) else array.columns
124-
results.append(joined_df_train[columns].cache())
125-
results.append(joined_df_test[columns].cache())
126+
results.append(joined_df_train[columns])
127+
results.append(joined_df_test[columns])
126128

127129
return results
128130

0 commit comments

Comments
 (0)