From 4669a97b3b6f4d6805171b9d03b45fc6430079a5 Mon Sep 17 00:00:00 2001 From: Lalit Yadav Date: Thu, 18 Jun 2026 18:25:26 -0500 Subject: [PATCH] Cache BigQuery table definitions in Python SDK --- .../apache_beam/io/gcp/bigquery_test.py | 30 ++-- .../apache_beam/io/gcp/bigquery_tools.py | 81 +++++++++ .../apache_beam/io/gcp/bigquery_tools_test.py | 168 ++++++++++++++++++ 3 files changed, 268 insertions(+), 11 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_test.py b/sdks/python/apache_beam/io/gcp/bigquery_test.py index 234c99847a44..12c75720570e 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_test.py @@ -345,7 +345,11 @@ def _add_argparse_args(cls, parser): cls.UserDefinedOptions = UserDefinedOptions + def setUp(self): + bigquery_tools.BigQueryWrapper._clear_table_definition_cache() + def tearDown(self): + bigquery_tools.BigQueryWrapper._clear_table_definition_cache() # Reset runtime options to avoid side-effects caused by other tests. RuntimeValueProvider.set_runtime_options(None) @@ -510,12 +514,8 @@ def test_create_temp_dataset_exception(self, exception_type, error_message): expected_retries=3), ]) def test_get_table_transient_exception(self, responses, expected_retries): - class DummyTable: - class DummySchema: - fields = [] - - numBytes = 5 - schema = DummySchema() + dummy_table = bigquery.Table( + numBytes=5, schema=bigquery.TableSchema(fields=[])) # TODO(https://github.com/apache/beam/issues/34549): This test relies on # lineage metrics which Prism doesn't seem to handle correctly. Defaulting @@ -542,16 +542,16 @@ def store_callback(unused_request): raise exception else: call_counter += 1 - return DummyTable() + return dummy_table mock_get_table.side_effect = store_callback _ = p | beam.io.ReadFromBigQuery( table="project.dataset.table", gcs_location="gs://some_bucket") - # ReadFromBigQuery export mode calls get_table() twice. Once to get - # metadata (numBytes), and once to retrieve the table's schema - # Any additional calls are retries - self.assertEqual(expected_retries, mock_get_table.call_count - 2) + # ReadFromBigQuery export mode retrieves table metadata once. The later + # schema lookup uses the cached table definition. + # Any additional calls are retries. + self.assertEqual(expected_retries, mock_get_table.call_count - 1) self.assertSetEqual( Lineage.query(p.result.metrics(), Lineage.SOURCE), set(["bigquery:project.dataset.table"])) @@ -819,9 +819,11 @@ def _cleanup_files(self): os.remove('insert_calls2') def setUp(self): + bigquery_tools.BigQueryWrapper._clear_table_definition_cache() self._cleanup_files() def tearDown(self): + bigquery_tools.BigQueryWrapper._clear_table_definition_cache() self._cleanup_files() def test_noop_schema_parsing(self): @@ -2090,6 +2092,12 @@ def test_with_batched_input_splits_large_batch(self): @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class BigQueryStreamingInsertTransformTests(unittest.TestCase): + def setUp(self): + bigquery_tools.BigQueryWrapper._clear_table_definition_cache() + + def tearDown(self): + bigquery_tools.BigQueryWrapper._clear_table_definition_cache() + def test_dofn_client_process_performs_batching(self): client = mock.Mock() client.tables.Get.return_value = bigquery.Table( diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_tools.py index 8dd58cd55a01..611ecfcd4ce4 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_tools.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_tools.py @@ -28,6 +28,7 @@ # pytype: skip-file # pylint: disable=wrong-import-order, wrong-import-position +import collections import datetime import decimal import io @@ -35,6 +36,7 @@ import logging import re import sys +import threading import time import uuid from json.decoder import JSONDecodeError @@ -358,6 +360,12 @@ class BigQueryWrapper(object): HISTOGRAM_METRIC_LOGGER = MetricLogger() + # Shared by wrapper instances within one Python SDK worker process. + _TABLE_DEFINITION_CACHE_MAX_ENTRIES = 256 + _TABLE_DEFINITION_CACHE_TTL_SECS = 60 * 60 + _table_definition_cache = collections.OrderedDict() + _table_definition_cache_lock = threading.RLock() + def __init__(self, client=None, temp_dataset_id=None, temp_table_ref=None): self.client = client or BigQueryWrapper._bigquery_client(PipelineOptions()) self.gcp_bq_client = client or gcp_bigquery.Client( @@ -394,6 +402,69 @@ def __init__(self, client=None, temp_dataset_id=None, temp_table_ref=None): self.created_temp_dataset = False + @classmethod + def _table_definition_cache_key(cls, project_id, dataset_id, table_id): + return (project_id, dataset_id, table_id) + + @classmethod + def _get_cached_table_definition(cls, project_id, dataset_id, table_id): + cache_key = cls._table_definition_cache_key( + project_id, dataset_id, table_id) + now = time.monotonic() + with cls._table_definition_cache_lock: + cache_entry = cls._table_definition_cache.get(cache_key) + if cache_entry is None: + return None + + expires_at, table = cache_entry + if expires_at <= now: + cls._table_definition_cache.pop(cache_key, None) + return None + + cls._table_definition_cache.move_to_end(cache_key) + return table + + @classmethod + def _cache_table_definition(cls, project_id, dataset_id, table_id, table): + table_type = getattr(bigquery, 'Table', None) + if table_type is None or not isinstance(table, table_type): + cls._invalidate_table_definition_cache(project_id, dataset_id, table_id) + return + + cache_key = cls._table_definition_cache_key( + project_id, dataset_id, table_id) + expires_at = time.monotonic() + cls._TABLE_DEFINITION_CACHE_TTL_SECS + with cls._table_definition_cache_lock: + cls._table_definition_cache[cache_key] = (expires_at, table) + cls._table_definition_cache.move_to_end(cache_key) + while (len(cls._table_definition_cache) + > cls._TABLE_DEFINITION_CACHE_MAX_ENTRIES): + cls._table_definition_cache.popitem(last=False) + + @classmethod + def _invalidate_table_definition_cache( + cls, project_id=None, dataset_id=None, table_id=None): + with cls._table_definition_cache_lock: + if (project_id is not None and dataset_id is not None and + table_id is not None): + cache_key = cls._table_definition_cache_key( + project_id, dataset_id, table_id) + cls._table_definition_cache.pop(cache_key, None) + return + + keys_to_delete = [ + key for key in cls._table_definition_cache + if ((project_id is None or key[0] == project_id) and + (dataset_id is None or key[1] == dataset_id) and + (table_id is None or key[2] == table_id)) + ] + for key in keys_to_delete: + cls._table_definition_cache.pop(key, None) + + @classmethod + def _clear_table_definition_cache(cls): + cls._invalidate_table_definition_cache() + @property def unique_row_id(self): """Returns a unique row ID (str) used to avoid multiple insertions. @@ -804,9 +875,15 @@ def get_table(self, project_id, dataset_id, table_id): Raises: HttpError: if lookup failed. """ + cached_table = self._get_cached_table_definition( + project_id, dataset_id, table_id) + if cached_table is not None: + return cached_table + request = bigquery.BigqueryTablesGetRequest( projectId=project_id, datasetId=dataset_id, tableId=table_id) response = self.client.tables.Get(request) + self._cache_table_definition(project_id, dataset_id, table_id, response) return response def _create_table( @@ -833,6 +910,7 @@ def _create_table( request = bigquery.BigqueryTablesInsertRequest( projectId=project_id, datasetId=dataset_id, table=table) response = self.client.tables.Insert(request) + self._cache_table_definition(project_id, dataset_id, table_id, response) _LOGGER.debug("Created the table with id %s", table_id) # The response is a bigquery.Table instance. return response @@ -909,9 +987,12 @@ def _delete_table(self, project_id, dataset_id, table_id): if exn.status_code == 404: _LOGGER.warning( 'Table %s:%s.%s does not exist', project_id, dataset_id, table_id) + self._invalidate_table_definition_cache( + project_id, dataset_id, table_id) return else: raise + self._invalidate_table_definition_cache(project_id, dataset_id, table_id) @retry.with_exponential_backoff( num_retries=MAX_RETRIES, diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py b/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py index 078c42160941..4e2ca572423e 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py @@ -177,6 +177,28 @@ def test_calling_with_all_arguments(self): @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class TestBigQueryWrapper(unittest.TestCase): + def setUp(self): + self.wrapper_cls = beam.io.gcp.bigquery_tools.BigQueryWrapper + self._cache_ttl_secs = self.wrapper_cls._TABLE_DEFINITION_CACHE_TTL_SECS + self._cache_max_entries = ( + self.wrapper_cls._TABLE_DEFINITION_CACHE_MAX_ENTRIES) + self.wrapper_cls._clear_table_definition_cache() + + def tearDown(self): + self.wrapper_cls._TABLE_DEFINITION_CACHE_TTL_SECS = self._cache_ttl_secs + self.wrapper_cls._TABLE_DEFINITION_CACHE_MAX_ENTRIES = ( + self._cache_max_entries) + self.wrapper_cls._clear_table_definition_cache() + + def _make_table( + self, + project_id='project-id', + dataset_id='dataset_id', + table_id='table_id'): + return bigquery.Table( + tableReference=bigquery.TableReference( + projectId=project_id, datasetId=dataset_id, tableId=table_id)) + def test_delete_non_existing_dataset(self): client = mock.Mock() client.datasets.Delete.side_effect = HttpError( @@ -292,6 +314,152 @@ def test_temporary_dataset_is_unique(self, patched_time_sleep): wrapper.create_temporary_dataset('project-id', 'location') self.assertTrue(client.datasets.Get.called) + def test_get_table_uses_shared_table_definition_cache(self): + table = self._make_table() + client_1 = mock.Mock() + client_1.tables.Get.return_value = table + client_2 = mock.Mock() + client_2.tables.Get.return_value = self._make_table() + + wrapper_1 = beam.io.gcp.bigquery_tools.BigQueryWrapper(client_1) + wrapper_2 = beam.io.gcp.bigquery_tools.BigQueryWrapper(client_2) + + self.assertIs( + wrapper_1.get_table('project-id', 'dataset_id', 'table_id'), table) + self.assertIs( + wrapper_2.get_table('project-id', 'dataset_id', 'table_id'), table) + client_1.tables.Get.assert_called_once() + client_2.tables.Get.assert_not_called() + + def test_get_table_caches_tables_independently(self): + first_table = self._make_table(table_id='first_table') + second_table = self._make_table(table_id='second_table') + client = mock.Mock() + client.tables.Get.side_effect = [first_table, second_table] + wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client) + + self.assertIs( + wrapper.get_table('project-id', 'dataset_id', 'first_table'), + first_table) + self.assertIs( + wrapper.get_table('project-id', 'dataset_id', 'second_table'), + second_table) + + self.assertEqual(client.tables.Get.call_count, 2) + + def test_get_table_refreshes_expired_cache_entry(self): + self.wrapper_cls._TABLE_DEFINITION_CACHE_TTL_SECS = 1 + first_table = self._make_table() + second_table = self._make_table() + client = mock.Mock() + client.tables.Get.side_effect = [first_table, second_table] + wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client) + + with mock.patch( + 'apache_beam.io.gcp.bigquery_tools.time.monotonic') as monotonic: + monotonic.return_value = 100 + self.assertIs( + wrapper.get_table('project-id', 'dataset_id', 'table_id'), + first_table) + + monotonic.return_value = 100.5 + self.assertIs( + wrapper.get_table('project-id', 'dataset_id', 'table_id'), + first_table) + + monotonic.return_value = 101.1 + self.assertIs( + wrapper.get_table('project-id', 'dataset_id', 'table_id'), + second_table) + + self.assertEqual(client.tables.Get.call_count, 2) + + def test_get_table_does_not_cache_failures(self): + table = self._make_table() + client = mock.Mock() + client.tables.Get.side_effect = [ + HttpError(response={'status': '404'}, url='', content=''), table + ] + wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client) + + with self.assertRaises(HttpError): + wrapper.get_table('project-id', 'dataset_id', 'table_id') + + self.assertIs( + wrapper.get_table('project-id', 'dataset_id', 'table_id'), table) + self.assertEqual(client.tables.Get.call_count, 2) + + def test_get_table_does_not_cache_none_response(self): + table = self._make_table() + client = mock.Mock() + client.tables.Get.side_effect = [None, table] + wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client) + + self.assertIsNone(wrapper.get_table('project-id', 'dataset_id', 'table_id')) + self.assertIs( + wrapper.get_table('project-id', 'dataset_id', 'table_id'), table) + self.assertEqual(client.tables.Get.call_count, 2) + + def test_delete_table_invalidates_table_definition_cache(self): + first_table = self._make_table() + second_table = self._make_table() + client = mock.Mock() + client.tables.Get.side_effect = [first_table, second_table] + wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client) + + self.assertIs( + wrapper.get_table('project-id', 'dataset_id', 'table_id'), first_table) + wrapper._delete_table('project-id', 'dataset_id', 'table_id') + self.assertIs( + wrapper.get_table('project-id', 'dataset_id', 'table_id'), second_table) + + self.assertEqual(client.tables.Get.call_count, 2) + + def test_create_table_updates_table_definition_cache(self): + stale_table = self._make_table() + created_table = self._make_table() + client = mock.Mock() + client.tables.Get.return_value = stale_table + client.tables.Insert.return_value = created_table + wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client) + + self.assertIs( + wrapper.get_table('project-id', 'dataset_id', 'table_id'), stale_table) + self.assertIs( + wrapper._create_table( + 'project-id', 'dataset_id', 'table_id', bigquery.TableSchema()), + created_table) + + new_client = mock.Mock() + new_wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(new_client) + self.assertIs( + new_wrapper.get_table('project-id', 'dataset_id', 'table_id'), + created_table) + new_client.tables.Get.assert_not_called() + + def test_table_definition_cache_evicts_oldest_entry(self): + self.wrapper_cls._TABLE_DEFINITION_CACHE_MAX_ENTRIES = 1 + first_table = self._make_table(table_id='first_table') + second_table = self._make_table(table_id='second_table') + refreshed_first_table = self._make_table(table_id='first_table') + client = mock.Mock() + client.tables.Get.side_effect = [ + first_table, second_table, refreshed_first_table + ] + wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client) + + self.assertIs( + wrapper.get_table('project-id', 'dataset_id', 'first_table'), + first_table) + self.assertIs( + wrapper.get_table('project-id', 'dataset_id', 'second_table'), + second_table) + self.assertIs( + wrapper.get_table('project-id', 'dataset_id', 'first_table'), + refreshed_first_table) + + self.assertEqual(client.tables.Get.call_count, 3) + def test_get_or_create_dataset_created(self): client = mock.Mock() client.datasets.Get.side_effect = HttpError(