Skip to content

Commit dc13684

Browse files
authored
Fix(spark): support replacing self-referencing queries (#1642)
* Handle Spark self-referencing queries
* Enable SCD Type 2 test for Spark and MSSQL
* Fix self-referencing return
* Move unit test from mixins to Spark
1 parent f8e3623 commit dc13684

File tree

3 files changed

+58
-11
lines changed

3 files changed

+58
-11
lines changed

sqlmesh/core/engine_adapter/spark.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,18 @@
66

77
import pandas as pd
88
from sqlglot import exp
9+
from sqlglot.optimizer.qualify_columns import quote_identifiers
910

1011
from sqlmesh.core.dialect import to_schema
1112
from sqlmesh.core.engine_adapter.base import (
1213
CatalogSupport,
1314
InsertOverwriteStrategy,
1415
SourceQuery,
1516
)
16-
from sqlmesh.core.engine_adapter.mixins import HiveMetastoreTablePropertiesMixin
17+
from sqlmesh.core.engine_adapter.mixins import (
18+
HiveMetastoreTablePropertiesMixin,
19+
LogicalReplaceQueryMixin,
20+
)
1721
from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType, set_catalog
1822
from sqlmesh.utils import classproperty
1923
from sqlmesh.utils.errors import SQLMeshError
@@ -338,6 +342,21 @@ def replace_query(
338342
columns_to_types = columns_to_types or self.columns(table_name)
339343
if not columns_to_types:
340344
raise SQLMeshError("Cannot replace table without columns to types")
345+
346+
# Self-referential queries: cannot insert overwrite a SELECT from itself, so
347+
# use LogicalReplaceQuery (which creates a temp table and SELECTs from it)
348+
if len(source_queries) > 1:
349+
raise SQLMeshError("Cannot replace table with a batched dataframe")
350+
with source_queries[0] as query:
351+
target_table = exp.to_table(table_name)
352+
self_referencing = any(
353+
quote_identifiers(table) == quote_identifiers(target_table)
354+
for table in query.find_all(exp.Table)
355+
)
356+
357+
if self_referencing:
358+
return LogicalReplaceQueryMixin.replace_query(self, table_name, query, columns_to_types) # type: ignore
359+
341360
self.create_table(table_name, columns_to_types)
342361
return self._insert_overwrite_by_condition(
343362
table_name, source_queries, columns_to_types, where=exp.true()
@@ -382,3 +401,7 @@ def create_view(
382401
super().create_view(
383402
view_name, query_or_df, columns_to_types, replace, materialized, **create_kwargs
384403
)
404+
405+
def _truncate_table(self, table_name: TableName) -> str:
    """Build a Spark SQL ``TRUNCATE TABLE`` statement for *table_name*.

    The table reference is fully quoted before rendering so that reserved
    words and mixed-case identifiers survive round-tripping.
    """
    target = quote_identifiers(exp.to_table(table_name))
    rendered = target.sql(dialect=self.dialect)
    return "TRUNCATE TABLE " + rendered

tests/core/engine_adapter/test_integration.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def input_data(
138138
return self._format_df(data, to_datetime=self.dialect != "trino")
139139

140140
def output_data(self, data: pd.DataFrame) -> pd.DataFrame:
141-
return self._format_df(data, include_tz=self.dialect in ("spark", "databricks"))
141+
return self._format_df(data, include_tz=self.dialect == "databricks")
142142

143143
def table(self, table_name: str, schema: str = TEST_SCHEMA) -> exp.Table:
144144
return exp.to_table(
@@ -737,17 +737,9 @@ def test_merge(ctx: TestContext):
737737

738738

739739
def test_scd_type_2(ctx: TestContext):
740-
if ctx.dialect == "tsql":
741-
pytest.skip(f"MSSQL scd type 2 functionality waiting on sqlglot cte in FROM fix")
742-
if ctx.dialect == "spark":
743-
pytest.skip(
744-
"Spark SCD Type 2 does not currently work due to this error: `[UNSUPPORTED_OVERWRITE.TABLE] Can't overwrite the target that is also being read from`"
745-
)
746-
747-
name_type = "varchar(max)" if ctx.dialect == "tsql" else "string"
748740
ctx.columns_to_types = {
749741
"id": "int",
750-
"name": name_type,
742+
"name": "string",
751743
"updated_at": "timestamp",
752744
"valid_from": "timestamp",
753745
"valid_to": "timestamp",

tests/core/engine_adapter/test_spark.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,38 @@ def test_replace_query_pandas(make_mocked_engine_adapter: t.Callable, mocker: Mo
129129
]
130130

131131

132+
def test_replace_query_self_ref(
    make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable
):
    """A self-referencing SELECT passed to ``replace_query`` must take the
    logical-replace path: stage results in a temp table, truncate the target,
    insert back from the temp table, then drop it — instead of an
    INSERT OVERWRITE that reads from the table being overwritten.
    """
    adapter = make_mocked_engine_adapter(SparkEngineAdapter)
    adapter.cursor.fetchone.return_value = (1,)

    table_name = "db.table"
    temp_table_id = "abcdefgh"
    # Pin the temp-table suffix so the expected SQL below is deterministic.
    mocker.patch(
        "sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table",
        return_value=make_temp_table_name(table_name, temp_table_id),
    )
    mocker.patch(
        "sqlmesh.core.engine_adapter.spark.LogicalReplaceQueryMixin.table_exists",
        return_value=True,
    )
    mocker.patch(
        "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.columns",
        return_value={"col": exp.DataType(this=exp.DataType.Type.INT)},
    )

    adapter.replace_query(table_name, parse_one(f"SELECT col + 1 AS col FROM {table_name}"))

    temp_table = f"`db`.`__temp_table_{temp_table_id}`"
    assert to_sql_calls(adapter) == [
        "DESCRIBE `db`.`table`",
        "CREATE SCHEMA IF NOT EXISTS `db`",
        f"CREATE TABLE IF NOT EXISTS {temp_table} AS SELECT `col` FROM `db`.`table`",
        "TRUNCATE TABLE `db`.`table`",
        f"INSERT INTO `db`.`table` (`col`) SELECT `col` + 1 AS `col` FROM {temp_table}",
        f"DROP TABLE IF EXISTS {temp_table}",
    ]
162+
163+
132164
def test_create_table_table_options(make_mocked_engine_adapter: t.Callable):
133165
adapter = make_mocked_engine_adapter(SparkEngineAdapter)
134166

0 commit comments

Comments
 (0)