Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions sdks/python/apache_beam/io/gcp/bigquery_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,11 @@ def _add_argparse_args(cls, parser):

cls.UserDefinedOptions = UserDefinedOptions

def setUp(self):
bigquery_tools.BigQueryWrapper._clear_table_definition_cache()

def tearDown(self):
bigquery_tools.BigQueryWrapper._clear_table_definition_cache()
# Reset runtime options to avoid side-effects caused by other tests.
RuntimeValueProvider.set_runtime_options(None)

Expand Down Expand Up @@ -510,12 +514,8 @@ def test_create_temp_dataset_exception(self, exception_type, error_message):
expected_retries=3),
])
def test_get_table_transient_exception(self, responses, expected_retries):
class DummyTable:
class DummySchema:
fields = []

numBytes = 5
schema = DummySchema()
dummy_table = bigquery.Table(
numBytes=5, schema=bigquery.TableSchema(fields=[]))

# TODO(https://github.com/apache/beam/issues/34549): This test relies on
# lineage metrics which Prism doesn't seem to handle correctly. Defaulting
Expand All @@ -542,16 +542,16 @@ def store_callback(unused_request):
raise exception
else:
call_counter += 1
return DummyTable()
return dummy_table

mock_get_table.side_effect = store_callback
_ = p | beam.io.ReadFromBigQuery(
table="project.dataset.table", gcs_location="gs://some_bucket")

# ReadFromBigQuery export mode calls get_table() twice. Once to get
# metadata (numBytes), and once to retrieve the table's schema
# Any additional calls are retries
self.assertEqual(expected_retries, mock_get_table.call_count - 2)
# ReadFromBigQuery export mode retrieves table metadata once. The later
# schema lookup uses the cached table definition.
# Any additional calls are retries.
self.assertEqual(expected_retries, mock_get_table.call_count - 1)
self.assertSetEqual(
Lineage.query(p.result.metrics(), Lineage.SOURCE),
set(["bigquery:project.dataset.table"]))
Expand Down Expand Up @@ -819,9 +819,11 @@ def _cleanup_files(self):
os.remove('insert_calls2')

def setUp(self):
bigquery_tools.BigQueryWrapper._clear_table_definition_cache()
self._cleanup_files()

def tearDown(self):
bigquery_tools.BigQueryWrapper._clear_table_definition_cache()
self._cleanup_files()

def test_noop_schema_parsing(self):
Expand Down Expand Up @@ -2090,6 +2092,12 @@ def test_with_batched_input_splits_large_batch(self):

@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class BigQueryStreamingInsertTransformTests(unittest.TestCase):
def setUp(self):
bigquery_tools.BigQueryWrapper._clear_table_definition_cache()

def tearDown(self):
bigquery_tools.BigQueryWrapper._clear_table_definition_cache()

def test_dofn_client_process_performs_batching(self):
client = mock.Mock()
client.tables.Get.return_value = bigquery.Table(
Expand Down
81 changes: 81 additions & 0 deletions sdks/python/apache_beam/io/gcp/bigquery_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@
# pytype: skip-file
# pylint: disable=wrong-import-order, wrong-import-position

import collections
import datetime
import decimal
import io
import json
import logging
import re
import sys
import threading
import time
import uuid
from json.decoder import JSONDecodeError
Expand Down Expand Up @@ -358,6 +360,12 @@ class BigQueryWrapper(object):

HISTOGRAM_METRIC_LOGGER = MetricLogger()

# Shared by wrapper instances within one Python SDK worker process.
_TABLE_DEFINITION_CACHE_MAX_ENTRIES = 256
_TABLE_DEFINITION_CACHE_TTL_SECS = 60 * 60

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

A TTL of 1 hour (60 * 60 seconds) is quite long for caching table definitions. In streaming pipelines or long-running jobs where table schemas might be updated dynamically (e.g., adding a nullable column), workers could use stale metadata for up to an hour, leading to schema mismatch errors. Consider reducing the default TTL to a more conservative value (e.g., 5 or 10 minutes) or making it configurable via pipeline options.

Suggested change
_TABLE_DEFINITION_CACHE_TTL_SECS = 60 * 60
_TABLE_DEFINITION_CACHE_TTL_SECS = 5 * 60

_table_definition_cache = collections.OrderedDict()
_table_definition_cache_lock = threading.RLock()

def __init__(self, client=None, temp_dataset_id=None, temp_table_ref=None):
self.client = client or BigQueryWrapper._bigquery_client(PipelineOptions())
self.gcp_bq_client = client or gcp_bigquery.Client(
Expand Down Expand Up @@ -394,6 +402,69 @@ def __init__(self, client=None, temp_dataset_id=None, temp_table_ref=None):

self.created_temp_dataset = False

@classmethod
def _table_definition_cache_key(cls, project_id, dataset_id, table_id):
return (project_id, dataset_id, table_id)

@classmethod
def _get_cached_table_definition(cls, project_id, dataset_id, table_id):
cache_key = cls._table_definition_cache_key(
project_id, dataset_id, table_id)
now = time.monotonic()
with cls._table_definition_cache_lock:
cache_entry = cls._table_definition_cache.get(cache_key)
if cache_entry is None:
return None

expires_at, table = cache_entry
if expires_at <= now:
cls._table_definition_cache.pop(cache_key, None)
return None

cls._table_definition_cache.move_to_end(cache_key)
return table

@classmethod
def _cache_table_definition(cls, project_id, dataset_id, table_id, table):
table_type = getattr(bigquery, 'Table', None)
if table_type is None or not isinstance(table, table_type):
cls._invalidate_table_definition_cache(project_id, dataset_id, table_id)
return

cache_key = cls._table_definition_cache_key(
project_id, dataset_id, table_id)
expires_at = time.monotonic() + cls._TABLE_DEFINITION_CACHE_TTL_SECS
with cls._table_definition_cache_lock:
cls._table_definition_cache[cache_key] = (expires_at, table)
cls._table_definition_cache.move_to_end(cache_key)
while (len(cls._table_definition_cache)
> cls._TABLE_DEFINITION_CACHE_MAX_ENTRIES):
cls._table_definition_cache.popitem(last=False)

@classmethod
def _invalidate_table_definition_cache(
cls, project_id=None, dataset_id=None, table_id=None):
with cls._table_definition_cache_lock:
if (project_id is not None and dataset_id is not None and
table_id is not None):
cache_key = cls._table_definition_cache_key(
project_id, dataset_id, table_id)
cls._table_definition_cache.pop(cache_key, None)
return

Comment on lines +448 to +454

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

When clearing the entire cache (i.e., when project_id, dataset_id, and table_id are all None), we can use cls._table_definition_cache.clear() instead of building a list of all keys and popping them one by one. This is much more efficient and simpler.

    with cls._table_definition_cache_lock:
      if project_id is None and dataset_id is None and table_id is None:
        cls._table_definition_cache.clear()
        return

      if (project_id is not None and dataset_id is not None and
          table_id is not None):
        cache_key = cls._table_definition_cache_key(
            project_id, dataset_id, table_id)
        cls._table_definition_cache.pop(cache_key, None)
        return

keys_to_delete = [
key for key in cls._table_definition_cache
if ((project_id is None or key[0] == project_id) and
(dataset_id is None or key[1] == dataset_id) and
(table_id is None or key[2] == table_id))
]
for key in keys_to_delete:
cls._table_definition_cache.pop(key, None)

@classmethod
def _clear_table_definition_cache(cls):
cls._invalidate_table_definition_cache()

@property
def unique_row_id(self):
"""Returns a unique row ID (str) used to avoid multiple insertions.
Expand Down Expand Up @@ -804,9 +875,15 @@ def get_table(self, project_id, dataset_id, table_id):
Raises:
HttpError: if lookup failed.
"""
cached_table = self._get_cached_table_definition(
project_id, dataset_id, table_id)
if cached_table is not None:
return cached_table

request = bigquery.BigqueryTablesGetRequest(
projectId=project_id, datasetId=dataset_id, tableId=table_id)
response = self.client.tables.Get(request)
self._cache_table_definition(project_id, dataset_id, table_id, response)
return response

def _create_table(
Expand All @@ -833,6 +910,7 @@ def _create_table(
request = bigquery.BigqueryTablesInsertRequest(
projectId=project_id, datasetId=dataset_id, table=table)
response = self.client.tables.Insert(request)
self._cache_table_definition(project_id, dataset_id, table_id, response)
_LOGGER.debug("Created the table with id %s", table_id)
# The response is a bigquery.Table instance.
return response
Expand Down Expand Up @@ -909,9 +987,12 @@ def _delete_table(self, project_id, dataset_id, table_id):
if exn.status_code == 404:
_LOGGER.warning(
'Table %s:%s.%s does not exist', project_id, dataset_id, table_id)
self._invalidate_table_definition_cache(
project_id, dataset_id, table_id)
return
else:
raise
self._invalidate_table_definition_cache(project_id, dataset_id, table_id)

@retry.with_exponential_backoff(
num_retries=MAX_RETRIES,
Expand Down
168 changes: 168 additions & 0 deletions sdks/python/apache_beam/io/gcp/bigquery_tools_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,28 @@ def test_calling_with_all_arguments(self):

@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestBigQueryWrapper(unittest.TestCase):
def setUp(self):
self.wrapper_cls = beam.io.gcp.bigquery_tools.BigQueryWrapper
self._cache_ttl_secs = self.wrapper_cls._TABLE_DEFINITION_CACHE_TTL_SECS
self._cache_max_entries = (
self.wrapper_cls._TABLE_DEFINITION_CACHE_MAX_ENTRIES)
self.wrapper_cls._clear_table_definition_cache()

def tearDown(self):
self.wrapper_cls._TABLE_DEFINITION_CACHE_TTL_SECS = self._cache_ttl_secs
self.wrapper_cls._TABLE_DEFINITION_CACHE_MAX_ENTRIES = (
self._cache_max_entries)
self.wrapper_cls._clear_table_definition_cache()

def _make_table(
self,
project_id='project-id',
dataset_id='dataset_id',
table_id='table_id'):
return bigquery.Table(
tableReference=bigquery.TableReference(
projectId=project_id, datasetId=dataset_id, tableId=table_id))

def test_delete_non_existing_dataset(self):
client = mock.Mock()
client.datasets.Delete.side_effect = HttpError(
Expand Down Expand Up @@ -292,6 +314,152 @@ def test_temporary_dataset_is_unique(self, patched_time_sleep):
wrapper.create_temporary_dataset('project-id', 'location')
self.assertTrue(client.datasets.Get.called)

def test_get_table_uses_shared_table_definition_cache(self):
table = self._make_table()
client_1 = mock.Mock()
client_1.tables.Get.return_value = table
client_2 = mock.Mock()
client_2.tables.Get.return_value = self._make_table()

wrapper_1 = beam.io.gcp.bigquery_tools.BigQueryWrapper(client_1)
wrapper_2 = beam.io.gcp.bigquery_tools.BigQueryWrapper(client_2)

self.assertIs(
wrapper_1.get_table('project-id', 'dataset_id', 'table_id'), table)
self.assertIs(
wrapper_2.get_table('project-id', 'dataset_id', 'table_id'), table)
client_1.tables.Get.assert_called_once()
client_2.tables.Get.assert_not_called()

def test_get_table_caches_tables_independently(self):
first_table = self._make_table(table_id='first_table')
second_table = self._make_table(table_id='second_table')
client = mock.Mock()
client.tables.Get.side_effect = [first_table, second_table]
wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

self.assertIs(
wrapper.get_table('project-id', 'dataset_id', 'first_table'),
first_table)
self.assertIs(
wrapper.get_table('project-id', 'dataset_id', 'second_table'),
second_table)

self.assertEqual(client.tables.Get.call_count, 2)

def test_get_table_refreshes_expired_cache_entry(self):
self.wrapper_cls._TABLE_DEFINITION_CACHE_TTL_SECS = 1
first_table = self._make_table()
second_table = self._make_table()
client = mock.Mock()
client.tables.Get.side_effect = [first_table, second_table]
wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

with mock.patch(
'apache_beam.io.gcp.bigquery_tools.time.monotonic') as monotonic:
monotonic.return_value = 100
self.assertIs(
wrapper.get_table('project-id', 'dataset_id', 'table_id'),
first_table)

monotonic.return_value = 100.5
self.assertIs(
wrapper.get_table('project-id', 'dataset_id', 'table_id'),
first_table)

monotonic.return_value = 101.1
self.assertIs(
wrapper.get_table('project-id', 'dataset_id', 'table_id'),
second_table)

self.assertEqual(client.tables.Get.call_count, 2)

def test_get_table_does_not_cache_failures(self):
table = self._make_table()
client = mock.Mock()
client.tables.Get.side_effect = [
HttpError(response={'status': '404'}, url='', content=''), table
]
wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

with self.assertRaises(HttpError):
wrapper.get_table('project-id', 'dataset_id', 'table_id')

self.assertIs(
wrapper.get_table('project-id', 'dataset_id', 'table_id'), table)
self.assertEqual(client.tables.Get.call_count, 2)

def test_get_table_does_not_cache_none_response(self):
table = self._make_table()
client = mock.Mock()
client.tables.Get.side_effect = [None, table]
wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

self.assertIsNone(wrapper.get_table('project-id', 'dataset_id', 'table_id'))
self.assertIs(
wrapper.get_table('project-id', 'dataset_id', 'table_id'), table)
self.assertEqual(client.tables.Get.call_count, 2)

def test_delete_table_invalidates_table_definition_cache(self):
first_table = self._make_table()
second_table = self._make_table()
client = mock.Mock()
client.tables.Get.side_effect = [first_table, second_table]
wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

self.assertIs(
wrapper.get_table('project-id', 'dataset_id', 'table_id'), first_table)
wrapper._delete_table('project-id', 'dataset_id', 'table_id')
self.assertIs(
wrapper.get_table('project-id', 'dataset_id', 'table_id'), second_table)

self.assertEqual(client.tables.Get.call_count, 2)

def test_create_table_updates_table_definition_cache(self):
stale_table = self._make_table()
created_table = self._make_table()
client = mock.Mock()
client.tables.Get.return_value = stale_table
client.tables.Insert.return_value = created_table
wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

self.assertIs(
wrapper.get_table('project-id', 'dataset_id', 'table_id'), stale_table)
self.assertIs(
wrapper._create_table(
'project-id', 'dataset_id', 'table_id', bigquery.TableSchema()),
created_table)

new_client = mock.Mock()
new_wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(new_client)
self.assertIs(
new_wrapper.get_table('project-id', 'dataset_id', 'table_id'),
created_table)
new_client.tables.Get.assert_not_called()

def test_table_definition_cache_evicts_oldest_entry(self):
self.wrapper_cls._TABLE_DEFINITION_CACHE_MAX_ENTRIES = 1
first_table = self._make_table(table_id='first_table')
second_table = self._make_table(table_id='second_table')
refreshed_first_table = self._make_table(table_id='first_table')
client = mock.Mock()
client.tables.Get.side_effect = [
first_table, second_table, refreshed_first_table
]
wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

self.assertIs(
wrapper.get_table('project-id', 'dataset_id', 'first_table'),
first_table)
self.assertIs(
wrapper.get_table('project-id', 'dataset_id', 'second_table'),
second_table)
self.assertIs(
wrapper.get_table('project-id', 'dataset_id', 'first_table'),
refreshed_first_table)

self.assertEqual(client.tables.Get.call_count, 3)

def test_get_or_create_dataset_created(self):
client = mock.Mock()
client.datasets.Get.side_effect = HttpError(
Expand Down
Loading