Skip to content

Commit 8955629

Browse files
Fix: Bigquery support of complex nested types (#3190)
1 parent c44c83c commit 8955629

File tree

3 files changed

+195
-14
lines changed

3 files changed

+195
-14
lines changed

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@
3737
logger = logging.getLogger(__name__)
3838

3939

40+
NestedField = t.Tuple[str, str, t.List[str]]
41+
NestedFieldsDict = t.Dict[str, t.List[NestedField]]
42+
43+
4044
@set_catalog()
4145
class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin):
4246
"""
@@ -282,11 +286,11 @@ def fetchall(
282286
def _split_alter_expressions(
283287
self,
284288
alter_expressions: t.List[exp.Alter],
285-
) -> t.Tuple[t.Dict[str, list[t.Tuple[str, str]]], t.List[exp.Alter]]:
289+
) -> t.Tuple[NestedFieldsDict, t.List[exp.Alter]]:
286290
"""
287291
Returns a dictionary of the nested fields to add and a list of the non-nested alter expressions.
288292
"""
289-
nested_fields_to_add: t.Dict[str, list[t.Tuple[str, str]]] = defaultdict(list)
293+
nested_fields_to_add: NestedFieldsDict = defaultdict(list)
290294
non_nested_expressions = []
291295

292296
for alter_expression in alter_expressions:
@@ -296,17 +300,58 @@ def _split_alter_expressions(
296300
and isinstance(action.this, exp.Dot)
297301
and isinstance(action.kind, exp.DataType)
298302
):
299-
record = action.this.this.sql(dialect="bigquery")
300-
field = action.this.expression.sql(dialect="bigquery")
301-
data_type = action.kind.sql(dialect="bigquery")
302-
nested_fields_to_add[record].append((data_type, field))
303+
root_field, *leaf_fields = action.this.this.sql(dialect=self.dialect).split(".")
304+
new_field = action.this.expression.sql(dialect=self.dialect)
305+
data_type = action.kind.sql(dialect=self.dialect)
306+
nested_fields_to_add[root_field].append((new_field, data_type, leaf_fields))
303307
else:
304308
non_nested_expressions.append(alter_expression)
305309

306310
return nested_fields_to_add, non_nested_expressions
307311

312+
def _build_nested_fields(
313+
self,
314+
current_fields: t.List[bigquery.SchemaField],
315+
fields_to_add: t.List[NestedField],
316+
) -> t.List[bigquery.SchemaField]:
317+
"""
318+
Recursively builds and updates the schema fields with the new nested fields.
319+
"""
320+
from google.cloud import bigquery
321+
322+
new_fields = []
323+
root: t.List[t.Tuple[str, str]] = []
324+
leaves: NestedFieldsDict = defaultdict(list)
325+
for new_field, data_type, leaf_fields in fields_to_add:
326+
if leaf_fields:
327+
leaves[leaf_fields[0]].append((new_field, data_type, leaf_fields[1:]))
328+
else:
329+
root.append((new_field, data_type))
330+
331+
for field in current_fields:
332+
# If the new fields are nested, we need to recursively build them
333+
if field.name in leaves:
334+
subfields = list(field.fields)
335+
subfields = self._build_nested_fields(subfields, leaves[field.name])
336+
new_fields.append(
337+
bigquery.SchemaField(
338+
field.name, "RECORD", mode=field.mode, fields=tuple(subfields)
339+
)
340+
)
341+
else:
342+
new_fields.append(field)
343+
344+
# Build and append the new root-level fields
345+
new_fields.extend(
346+
self.__get_bq_schemafield(
347+
new_field[0], exp.DataType.build(new_field[1], dialect=self.dialect)
348+
)
349+
for new_field in root
350+
)
351+
return new_fields
352+
308353
def _update_table_schema_nested_fields(
309-
self, nested_fields_to_add: t.Dict[str, list[t.Tuple[str, str]]], table_name: str
354+
self, nested_fields_to_add: NestedFieldsDict, table_name: str
310355
) -> None:
311356
"""
312357
Updates a BigQuery table schema by adding the new nested fields provided.
@@ -316,18 +361,17 @@ def _update_table_schema_nested_fields(
316361
table = self._get_table(table_name)
317362
original_schema = table.schema
318363
new_schema = []
319-
320364
for field in original_schema:
321-
current_fields = list(field.fields)
322365
if field.name in nested_fields_to_add:
323-
current_fields.extend(
324-
bigquery.SchemaField(new_field[1], new_field[0], mode="NULLABLE")
325-
for new_field in nested_fields_to_add[field.name]
366+
fields = self._build_nested_fields(
367+
list(field.fields), nested_fields_to_add[field.name]
326368
)
327-
if current_fields:
328369
new_schema.append(
329370
bigquery.SchemaField(
330-
field.name, "RECORD", mode="NULLABLE", fields=tuple(current_fields)
371+
field.name,
372+
"RECORD",
373+
mode=field.mode,
374+
fields=tuple(fields),
331375
)
332376
)
333377
else:

tests/core/engine_adapter/test_bigquery.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,3 +805,94 @@ def test_materialized_view_properties(
805805
assert sql_calls == [
806806
"CREATE OR REPLACE MATERIALIZED VIEW `test_table` PARTITION BY `ds` CLUSTER BY `a` AS SELECT 1",
807807
]
808+
809+
810+
def test_nested_fields_update(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture):
811+
adapter = make_mocked_engine_adapter(BigQueryEngineAdapter)
812+
813+
current_schema = [
814+
bigquery.SchemaField(
815+
"user",
816+
"RECORD",
817+
"NULLABLE",
818+
fields=(
819+
bigquery.SchemaField("name", "STRING", "NULLABLE"),
820+
bigquery.SchemaField(
821+
"orders",
822+
"RECORD",
823+
"REPEATED",
824+
fields=([bigquery.SchemaField("id", "INT64", "NULLABLE")]),
825+
),
826+
),
827+
)
828+
]
829+
new_nested_fields = [("year", "INT64", ["user", "orders"]), ("active", "BOOL", ["user"])]
830+
expected = [
831+
bigquery.SchemaField(
832+
"user",
833+
"RECORD",
834+
"NULLABLE",
835+
fields=(
836+
bigquery.SchemaField("name", "STRING", "NULLABLE"),
837+
bigquery.SchemaField(
838+
"orders",
839+
"RECORD",
840+
"REPEATED",
841+
fields=(
842+
bigquery.SchemaField("id", "INT64", "NULLABLE"),
843+
bigquery.SchemaField("year", "INT64", "NULLABLE"),
844+
),
845+
),
846+
bigquery.SchemaField("active", "BOOL", "NULLABLE"),
847+
),
848+
)
849+
]
850+
assert adapter._build_nested_fields(current_schema, new_nested_fields) == expected
851+
852+
current_schema = [
853+
bigquery.SchemaField(
854+
"users",
855+
"RECORD",
856+
"REPEATED",
857+
fields=(
858+
[
859+
bigquery.SchemaField(
860+
"user",
861+
"RECORD",
862+
"NULLABLE",
863+
fields=(bigquery.SchemaField("name", "STRING", "NULLABLE"),),
864+
)
865+
]
866+
),
867+
)
868+
]
869+
new_nested_fields = [
870+
("orders", "ARRAY<INT64>", ["users", "user"]),
871+
("tags", "STRING", ["users"]),
872+
("details", "ARRAY<STRING>", []),
873+
]
874+
expected = [
875+
bigquery.SchemaField(
876+
"users",
877+
"RECORD",
878+
"REPEATED",
879+
fields=(
880+
bigquery.SchemaField(
881+
"user",
882+
"RECORD",
883+
"NULLABLE",
884+
fields=(
885+
bigquery.SchemaField("name", "STRING", "NULLABLE"),
886+
bigquery.SchemaField("orders", "INT64", "REPEATED"),
887+
),
888+
),
889+
bigquery.SchemaField(
890+
"tags",
891+
"STRING",
892+
"NULLABLE",
893+
),
894+
),
895+
),
896+
bigquery.SchemaField("details", "STRING", "REPEATED"),
897+
]
898+
assert adapter._build_nested_fields(current_schema, new_nested_fields) == expected

tests/core/test_schema_diff.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,29 @@ def test_schema_diff_calculate_type_transitions():
393393
],
394394
dict(support_positional_add=True, support_nested_operations=True),
395395
),
396+
# Add columns in different levels of nesting of structs
397+
(
398+
"STRUCT<id INT, info STRUCT<col_a INT, col_b INT, col_c INT>>",
399+
"STRUCT<id INT, info STRUCT<col_a INT, col_b INT, col_c INT, col_d INT>, txt TEXT>",
400+
[
401+
TableAlterOperation.add(
402+
[
403+
TableAlterColumn.primitive("txt"),
404+
],
405+
"TEXT",
406+
expected_table_struct="STRUCT<id INT, info STRUCT<col_a INT, col_b INT, col_c INT>, txt TEXT>",
407+
),
408+
TableAlterOperation.add(
409+
[
410+
TableAlterColumn.struct("info"),
411+
TableAlterColumn.primitive("col_d"),
412+
],
413+
"INT",
414+
expected_table_struct="STRUCT<id INT, info STRUCT<col_a INT, col_b INT, col_c INT, col_d INT>, txt TEXT>",
415+
),
416+
],
417+
dict(support_positional_add=False, support_nested_operations=True),
418+
),
396419
# Remove a column from the start of a struct
397420
(
398421
"STRUCT<id INT, info STRUCT<col_a INT, col_b INT, col_c INT>>",
@@ -721,6 +744,29 @@ def test_schema_diff_calculate_type_transitions():
721744
},
722745
),
723746
),
747+
# Add columns to struct of array within different nesting levels
748+
(
749+
"STRUCT<id INT, infos ARRAY<STRUCT<col_a INT, col_b INT, col_c INT>>>",
750+
"STRUCT<id INT, infos ARRAY<STRUCT<col_a INT, col_b INT, col_c INT, col_d INT >>, col_e INT>",
751+
[
752+
TableAlterOperation.add(
753+
[
754+
TableAlterColumn.primitive("col_e"),
755+
],
756+
"INT",
757+
expected_table_struct="STRUCT<id INT, infos ARRAY<STRUCT<col_a INT, col_b INT, col_c INT>>, col_e INT>",
758+
),
759+
TableAlterOperation.add(
760+
[
761+
TableAlterColumn.array_of_struct("infos"),
762+
TableAlterColumn.primitive("col_d"),
763+
],
764+
"INT",
765+
expected_table_struct="STRUCT<id INT, infos ARRAY<STRUCT<col_a INT, col_b INT, col_c INT, col_d INT>>, col_e INT>",
766+
),
767+
],
768+
dict(support_positional_add=False, support_nested_operations=True),
769+
),
724770
# Add an array of primitives
725771
(
726772
"STRUCT<id INT, infos ARRAY<STRUCT<col_a INT, col_b INT, col_c INT>>>",

0 commit comments

Comments
 (0)