diff --git a/packages/bigframes/bigframes/session/loader.py b/packages/bigframes/bigframes/session/loader.py
index b0a9e0a1ed31..960208063105 100644
--- a/packages/bigframes/bigframes/session/loader.py
+++ b/packages/bigframes/bigframes/session/loader.py
@@ -52,7 +52,10 @@
 import pandas
 import pyarrow as pa
 from google.cloud import bigquery_storage_v1
-from google.cloud.bigquery_storage_v1 import types as bq_storage_types
+from google.cloud.bigquery_storage_v1 import (
+    types as bq_storage_types,
+    writer as bq_storage_writer,
+)
 
 import bigframes._tools
 import bigframes._tools.strings
@@ -520,38 +523,51 @@ def write_data(
         )
         serialized_schema = schema.serialize().to_pybytes()
 
-        def stream_worker(work: Iterator[pa.RecordBatch]) -> str:
+        def stream_worker(
+            work: Iterator[pa.RecordBatch], max_outstanding: int = 5
+        ) -> str:
             requested_stream = bq_storage_types.WriteStream(
                 type_=bq_storage_types.WriteStream.Type.PENDING
             )
             stream = self._write_client.create_write_stream(
                 parent=parent, write_stream=requested_stream
             )
-            stream_name = stream.name
+            base_request = bq_storage_types.AppendRowsRequest(
+                write_stream=stream.name,
+            )
+            base_request.arrow_rows.writer_schema.serialized_schema = serialized_schema
 
-            def request_generator():
-                current_offset = 0
-                for batch in work:
-                    request = bq_storage_types.AppendRowsRequest(
-                        write_stream=stream.name, offset=current_offset
-                    )
+            stream_manager = bq_storage_writer.AppendRowsStream(
+                client=self._write_client, initial_request_template=base_request
+            )
+            stream_name = stream.name
+            current_offset = 0
+            futures: list[bq_storage_writer.AppendRowsFuture] = []
+
+            for batch in work:
+                if len(futures) >= max_outstanding:
+                    row_errors = futures.pop(0).result().row_errors
+                    if row_errors:
+                        raise ValueError(
+                            f"Problem loading rows: {row_errors}. {constants.FEEDBACK_LINK}"
+                        )
 
-                    request.arrow_rows.writer_schema.serialized_schema = (
-                        serialized_schema
-                    )
-                    request.arrow_rows.rows.serialized_record_batch = (
-                        batch.serialize().to_pybytes()
-                    )
+                request = bq_storage_types.AppendRowsRequest(offset=current_offset)
+                request.arrow_rows.rows.serialized_record_batch = (
+                    batch.serialize().to_pybytes()
+                )
 
-                    yield request
-                    current_offset += batch.num_rows
+                futures.append(stream_manager.send(request))
+                current_offset += batch.num_rows
 
-            responses = self._write_client.append_rows(requests=request_generator())
-            for resp in responses:
-                if resp.row_errors:
+            for future in futures:
+                row_errors = future.result().row_errors
+                if row_errors:
                     raise ValueError(
-                        f"Errors in stream {stream_name}: {resp.row_errors}"
+                        f"Problem loading rows: {row_errors}. {constants.FEEDBACK_LINK}"
                     )
+
+            stream_manager.close()
             self._write_client.finalize_write_stream(name=stream_name)
             return stream_name
 
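
For reference, the pattern this patch adopts can be exercised outside of `loader.py`. The sketch below is a minimal, self-contained version of the same flow: a PENDING write stream, the Arrow schema sent once via the `AppendRowsStream` initial request template, and a bounded window of outstanding `AppendRowsFuture`s for backpressure. The `append_batches` helper, its parameters, and the error strings are illustrative, not part of the patch.

```python
import pyarrow as pa
from google.cloud import bigquery_storage_v1
from google.cloud.bigquery_storage_v1 import types, writer


def append_batches(
    batches: list[pa.RecordBatch], parent: str, max_outstanding: int = 5
) -> str:
    """Append Arrow batches to a new PENDING write stream.

    ``parent`` is a table path of the form
    ``projects/{project}/datasets/{dataset}/tables/{table}``.
    """
    client = bigquery_storage_v1.BigQueryWriteClient()

    # Rows appended to a PENDING stream stay buffered until the stream is
    # finalized and then committed to the table.
    stream = client.create_write_stream(
        parent=parent,
        write_stream=types.WriteStream(type_=types.WriteStream.Type.PENDING),
    )

    # Send the serialized Arrow schema once, via the request template,
    # instead of attaching it to every AppendRowsRequest.
    template = types.AppendRowsRequest(write_stream=stream.name)
    template.arrow_rows.writer_schema.serialized_schema = (
        batches[0].schema.serialize().to_pybytes()
    )
    manager = writer.AppendRowsStream(
        client=client, initial_request_template=template
    )

    offset = 0
    futures: list[writer.AppendRowsFuture] = []
    for batch in batches:
        # Backpressure: once max_outstanding appends are in flight, block
        # on the oldest future before sending another batch.
        if len(futures) >= max_outstanding:
            response = futures.pop(0).result()
            if response.row_errors:
                raise ValueError(f"Append failed: {response.row_errors}")

        request = types.AppendRowsRequest(offset=offset)
        request.arrow_rows.rows.serialized_record_batch = (
            batch.serialize().to_pybytes()
        )
        futures.append(manager.send(request))
        offset += batch.num_rows

    # Drain whatever is still in flight before finalizing the stream.
    for future in futures:
        response = future.result()
        if response.row_errors:
            raise ValueError(f"Append failed: {response.row_errors}")

    manager.close()
    client.finalize_write_stream(name=stream.name)
    return stream.name
```

Compared with the removed `request_generator` approach, which sent one synchronous `append_rows` call per worker and consumed responses in lockstep, the managed `AppendRowsStream` lets up to `max_outstanding` appends overlap while still checking `row_errors` on every response.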
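One step the hunk does not show: `stream_worker` only finalizes its PENDING stream and returns the stream name, so the appended rows remain invisible until the caller commits the streams. A rough sketch of that commit step, assuming the caller has collected the returned names into a hypothetical `stream_names` list (the `client` and `parent` values are stand-ins as well):

```python
# Commit all finalized PENDING streams atomically; rows only become
# readable in the table after this call succeeds.
request = types.BatchCommitWriteStreamsRequest(
    parent=parent, write_streams=stream_names
)
response = client.batch_commit_write_streams(request)
if response.stream_errors:
    raise ValueError(f"Commit failed: {response.stream_errors}")
```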