Skip to content
This repository was archived by the owner on May 7, 2026. It is now read-only.

Commit 79eb9ac

Browse files
perf: Avoid requery for some result downsample methods
1 parent 0ff1395 commit 79eb9ac

5 files changed

Lines changed: 62 additions & 54 deletions

File tree

bigframes/core/blocks.py

Lines changed: 29 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -814,49 +814,36 @@ def _materialize_local(
814814
total_rows = result_batches.approx_total_rows
815815
# Remove downsampling config from subsequent invocations, as otherwise could result in many
816816
# iterations if downsampling undershoots
817-
return self._downsample(
818-
total_rows=total_rows,
819-
sampling_method=sample_config.sampling_method,
820-
fraction=fraction,
821-
random_state=sample_config.random_state,
822-
)._materialize_local(
823-
MaterializationOptions(ordered=materialize_options.ordered)
824-
)
825-
else:
826-
df = result_batches.to_pandas()
827-
df = self._copy_index_to_pandas(df)
828-
df.set_axis(self.column_labels, axis=1, copy=False)
829-
return df, execute_result.query_job
830-
831-
def _downsample(
832-
self, total_rows: int, sampling_method: str, fraction: float, random_state
833-
) -> Block:
834-
# either selecting fraction or number of rows
835-
if sampling_method == _HEAD:
836-
filtered_block = self.slice(stop=int(total_rows * fraction))
837-
return filtered_block
838-
elif (sampling_method == _UNIFORM) and (random_state is None):
839-
filtered_expr = self.expr._uniform_sampling(fraction)
840-
block = Block(
841-
filtered_expr,
842-
index_columns=self.index_columns,
843-
column_labels=self.column_labels,
844-
index_labels=self.index.names,
845-
)
846-
return block
847-
elif sampling_method == _UNIFORM:
848-
block = self.split(
849-
fracs=(fraction,),
850-
random_state=random_state,
851-
sort=False,
852-
)[0]
853-
return block
817+
if sample_config.sampling_method == "head":
818+
# Just truncates the result iterator without a follow-up query
819+
raw_df = result_batches.to_pandas(limit=int(total_rows * fraction))
820+
elif (
821+
sample_config.sampling_method == "uniform"
822+
and sample_config.random_state is None
823+
):
824+
# Pushes sample into result without new query
825+
sampled_batches = execute_result.batches(sample_rate=fraction)
826+
raw_df = sampled_batches.to_pandas()
827+
else: # uniform sample with random state requires a full follow-up query
828+
return self._downsample(
829+
fraction=fraction,
830+
random_state=sample_config.random_state,
831+
)._materialize_local(
832+
MaterializationOptions(ordered=materialize_options.ordered)
833+
)
854834
else:
855-
# This part should never be called, just in case.
856-
raise NotImplementedError(
857-
f"The downsampling method {sampling_method} is not implemented, "
858-
f"please choose from {','.join(_SAMPLING_METHODS)}."
859-
)
835+
raw_df = result_batches.to_pandas()
836+
df = self._copy_index_to_pandas(raw_df)
837+
df.set_axis(self.column_labels, axis=1, copy=False)
838+
return df, execute_result.query_job
839+
840+
def _downsample(self, fraction: float, random_state) -> Block:
841+
block = self.split(
842+
fracs=(fraction,),
843+
random_state=random_state,
844+
sort=False,
845+
)[0]
846+
return block
860847

861848
def split(
862849
self,

bigframes/core/bq_data.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,11 +171,22 @@ def get_arrow_batches(
171171
columns: Sequence[str],
172172
storage_read_client: bigquery_storage_v1.BigQueryReadClient,
173173
project_id: str,
174+
sample_rate: Optional[float] = None,
174175
) -> ReadResult:
175176
table_mod_options = {}
176177
read_options_dict: dict[str, Any] = {"selected_fields": list(columns)}
178+
179+
predicates = []
177180
if data.sql_predicate:
178-
read_options_dict["row_restriction"] = data.sql_predicate
181+
predicates.append(data.sql_predicate)
182+
if sample_rate is not None:
183+
assert isinstance(sample_rate, float)
184+
predicates.append(f"RAND() < {sample_rate}")
185+
186+
if predicates:
187+
full_predicates = " AND ".join(f"( {pred} )" for pred in predicates)
188+
read_options_dict["row_restriction"] = full_predicates
189+
179190
read_options = bq_storage_types.ReadSession.TableReadOptions(**read_options_dict)
180191

181192
if data.at_time:

bigframes/core/local_data.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import uuid
2626

2727
import geopandas # type: ignore
28+
import numpy
2829
import numpy as np
2930
import pandas as pd
3031
import pyarrow as pa
@@ -124,12 +125,20 @@ def to_arrow(
124125
geo_format: Literal["wkb", "wkt"] = "wkt",
125126
duration_type: Literal["int", "duration"] = "duration",
126127
json_type: Literal["string"] = "string",
128+
sample_rate: Optional[float] = None,
127129
) -> tuple[pa.Schema, Iterable[pa.RecordBatch]]:
128130
if geo_format != "wkt":
129131
raise NotImplementedError(f"geo format {geo_format} not yet implemented")
130132
assert json_type == "string"
131133

132-
batches = self.data.to_batches()
134+
data = self.data
135+
136+
# This exists for symmetry with remote sources, but sampling local data like this shouldn't really happen
137+
if sample_rate is not None:
138+
to_take = numpy.random.rand(data.num_rows) < sample_rate
139+
data = data.filter(to_take)
140+
141+
batches = data.to_batches()
133142
schema = self.data.schema
134143
if duration_type == "int":
135144
schema = _schema_durations_to_ints(schema)

bigframes/session/executor.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ def to_arrow_table(self) -> pyarrow.Table:
107107
# Bug with some pyarrow versions, empty_table only supports base storage types, not extension types.
108108
return self._schema.to_pyarrow(use_storage_types=True).empty_table()
109109

110-
def to_pandas(self) -> pd.DataFrame:
111-
return io_pandas.arrow_to_pandas(self.to_arrow_table(), self._schema)
110+
def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
111+
return pd.concat(self.to_pandas_batches(max_results=limit))
112112

113113
def to_pandas_batches(
114114
self, page_size: Optional[int] = None, max_results: Optional[int] = None
@@ -158,7 +158,7 @@ def schema(self) -> bigframes.core.schema.ArraySchema:
158158
...
159159

160160
@abc.abstractmethod
161-
def batches(self) -> ResultsIterator:
161+
def batches(self, sample_rate: Optional[float] = None) -> ResultsIterator:
162162
...
163163

164164
@property
@@ -200,9 +200,9 @@ def execution_metadata(self) -> ExecutionMetadata:
200200
def schema(self) -> bigframes.core.schema.ArraySchema:
201201
return self._data.schema
202202

203-
def batches(self) -> ResultsIterator:
203+
def batches(self, sample_rate: Optional[float] = None) -> ResultsIterator:
204204
return ResultsIterator(
205-
iter(self._data.to_arrow()[1]),
205+
iter(self._data.to_arrow(sample_rate=sample_rate)[1]),
206206
self.schema,
207207
self._data.metadata.row_count,
208208
self._data.metadata.total_bytes,
@@ -226,7 +226,7 @@ def execution_metadata(self) -> ExecutionMetadata:
226226
def schema(self) -> bigframes.core.schema.ArraySchema:
227227
return self._schema
228228

229-
def batches(self) -> ResultsIterator:
229+
def batches(self, sample_rate: Optional[float] = None) -> ResultsIterator:
230230
return ResultsIterator(iter([]), self.schema, 0, 0)
231231

232232

@@ -260,12 +260,13 @@ def schema(self) -> bigframes.core.schema.ArraySchema:
260260
source_ids = [selection[0] for selection in self._selected_fields]
261261
return self._data.schema.select(source_ids).rename(dict(self._selected_fields))
262262

263-
def batches(self) -> ResultsIterator:
263+
def batches(self, sample_rate: Optional[float] = None) -> ResultsIterator:
264264
read_batches = bq_data.get_arrow_batches(
265265
self._data,
266266
[x[0] for x in self._selected_fields],
267267
self._storage_client,
268268
self._project_id,
269+
sample_rate=sample_rate,
269270
)
270271
arrow_batches: Iterator[pa.RecordBatch] = map(
271272
functools.partial(

tests/system/small/test_dataframe.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4473,7 +4473,7 @@ def test_df_kurt(scalars_dfs):
44734473
"n_default",
44744474
],
44754475
)
4476-
def test_sample(scalars_dfs, frac, n, random_state):
4476+
def test_df_to_pandas_sample(scalars_dfs, frac, n, random_state):
44774477
scalars_df, _ = scalars_dfs
44784478
df = scalars_df.sample(frac=frac, n=n, random_state=random_state)
44794479
bf_result = df.to_pandas()
@@ -4484,15 +4484,15 @@ def test_sample(scalars_dfs, frac, n, random_state):
44844484
assert bf_result.shape[1] == scalars_df.shape[1]
44854485

44864486

4487-
def test_sample_determinism(penguins_df_default_index):
4487+
def test_df_to_pandas_sample_determinism(penguins_df_default_index):
44884488
df = penguins_df_default_index.sample(n=100, random_state=12345).head(15)
44894489
bf_result = df.to_pandas()
44904490
bf_result2 = df.to_pandas()
44914491

44924492
pandas.testing.assert_frame_equal(bf_result, bf_result2)
44934493

44944494

4495-
def test_sample_raises_value_error(scalars_dfs):
4495+
def test_df_to_pandas_sample_raises_value_error(scalars_dfs):
44964496
scalars_df, _ = scalars_dfs
44974497
with pytest.raises(
44984498
ValueError, match="Only one of 'n' or 'frac' parameter can be specified."

0 commit comments

Comments
 (0)