packages/bigframes/bigframes/session/loader.py (58 changes: 37 additions & 21 deletions)
@@ -52,7 +52,10 @@
 import pandas
 import pyarrow as pa
 from google.cloud import bigquery_storage_v1
-from google.cloud.bigquery_storage_v1 import types as bq_storage_types
+from google.cloud.bigquery_storage_v1 import (
+    types as bq_storage_types,
+    writer as bq_storage_writer,
+)
 
 import bigframes._tools
 import bigframes._tools.strings
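(The `writer` module added to this import provides `AppendRowsStream` and `AppendRowsFuture`, which the rewritten `stream_worker` below uses in place of the hand-rolled request generator.)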
@@ -520,38 +523,51 @@ def write_data(
         )
         serialized_schema = schema.serialize().to_pybytes()
 
-        def stream_worker(work: Iterator[pa.RecordBatch]) -> str:
+        def stream_worker(
+            work: Iterator[pa.RecordBatch], max_outstanding: int = 5
+        ) -> str:
             requested_stream = bq_storage_types.WriteStream(
                 type_=bq_storage_types.WriteStream.Type.PENDING
             )
             stream = self._write_client.create_write_stream(
                 parent=parent, write_stream=requested_stream
             )
-            stream_name = stream.name
+            base_request = bq_storage_types.AppendRowsRequest(
+                write_stream=stream.name,
+            )
+            base_request.arrow_rows.writer_schema.serialized_schema = serialized_schema
 
-            def request_generator():
-                current_offset = 0
-                for batch in work:
-                    request = bq_storage_types.AppendRowsRequest(
-                        write_stream=stream.name, offset=current_offset
-                    )
-                    request.arrow_rows.writer_schema.serialized_schema = (
-                        serialized_schema
-                    )
-                    request.arrow_rows.rows.serialized_record_batch = (
-                        batch.serialize().to_pybytes()
-                    )
-                    yield request
-                    current_offset += batch.num_rows
+            stream_manager = bq_storage_writer.AppendRowsStream(
+                client=self._write_client, initial_request_template=base_request
+            )
+            stream_name = stream.name
+            current_offset = 0
+            futures: list[bq_storage_writer.AppendRowsFuture] = []
 
-            responses = self._write_client.append_rows(requests=request_generator())
-            for resp in responses:
-                if resp.row_errors:
-                    raise ValueError(
-                        f"Errors in stream {stream_name}: {resp.row_errors}"
-                    )
+            for batch in work:
+                if len(futures) >= max_outstanding:
+                    row_errors = futures.pop(0).result().row_errors
+                    if row_errors:
+                        raise ValueError(
+                            f"Problem loading rows: {row_errors}. {constants.FEEDBACK_LINK}"
+                        )
+
+                request = bq_storage_types.AppendRowsRequest(offset=current_offset)
+                request.arrow_rows.rows.serialized_record_batch = (
+                    batch.serialize().to_pybytes()
+                )
+
+                futures.append(stream_manager.send(request))
+                current_offset += batch.num_rows
 
+            for future in futures:
+                row_errors = future.result().row_errors
+                if row_errors:
+                    raise ValueError(
+                        f"Problem loading rows: {row_errors}. {constants.FEEDBACK_LINK}"
+                    )
+
+            stream_manager.close()
             self._write_client.finalize_write_stream(name=stream_name)
             return stream_name
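The rewrite replaces the hand-rolled request generator passed to `append_rows` with the client library's `AppendRowsStream` manager. Each `send` returns an `AppendRowsFuture`; once `max_outstanding` futures are pending, the oldest one is resolved and its `row_errors` checked before the next batch is sent, so at most five appends are in flight at a time, and the trailing loop drains whatever is still outstanding before the stream is closed and finalized. A minimal sketch of that bounded-in-flight pattern in isolation (the `send_bounded` name and the `send_batch` callable are hypothetical stand-ins for `AppendRowsStream.send`, not part of this PR):

from collections import deque
from concurrent.futures import Future
from typing import Callable, Deque, Iterable

def send_bounded(
    batches: Iterable[bytes],
    send_batch: Callable[[bytes], Future],  # hypothetical stand-in for AppendRowsStream.send
    max_outstanding: int = 5,
) -> None:
    """Send every batch while keeping at most `max_outstanding` requests in flight."""
    futures: Deque[Future] = deque()
    for batch in batches:
        if len(futures) >= max_outstanding:
            # Resolve the oldest pending request before issuing another,
            # so unacknowledged data stays bounded instead of growing with
            # the size of the input.
            futures.popleft().result()
        futures.append(send_batch(batch))
    while futures:
        # Drain the tail: every append must be acknowledged (and checked
        # for errors) before the stream can be finalized.
        futures.popleft().result()

Using `Future.result()` as the acknowledgement point mirrors the diff, where each resolved response's `row_errors` are inspected and surfaced with `constants.FEEDBACK_LINK` before the worker calls `finalize_write_stream`.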
