Skip to content

Commit 4d8e34d

Browse files
authored
fix: Spark Create Table DDL WAP (#2033)
* fix: spark create table ddl wap
* feedback
1 parent 9f0b0eb commit 4d8e34d

File tree

2 files changed

+104
-33
lines changed

2 files changed

+104
-33
lines changed

sqlmesh/core/engine_adapter/spark.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ class SparkEngineAdapter(GetCurrentCatalogFromFunctionMixin, HiveMetastoreTableP
6161
# currently check for storage formats we say we don't support REPLACE TABLE
6262
SUPPORTS_REPLACE_TABLE = False
6363

64+
WAP_PREFIX = "wap_"
65+
BRANCH_PREFIX = "branch_"
66+
6467
@property
6568
def spark(self) -> PySparkSession:
6669
return self._connection_pool.get().spark
@@ -428,6 +431,24 @@ def _create_table(
428431
column_descriptions: t.Optional[t.Dict[str, str]] = None,
429432
**kwargs: t.Any,
430433
) -> None:
434+
table_name = (
435+
table_name_or_schema.this
436+
if isinstance(table_name_or_schema, exp.Schema)
437+
else exp.to_table(table_name_or_schema)
438+
)
439+
# Spark doesn't support creating a wap table DDL. Therefore we check if this is a wap table and if it is,
440+
# this is not a replace, and the table already exists then we can safely just return. Otherwise we let it error.
441+
if not expression and isinstance(table_name.this, exp.Dot):
442+
wap_id = table_name.this.parts[-1].name
443+
if wap_id.startswith(f"{self.BRANCH_PREFIX}{self.WAP_PREFIX}"):
444+
table_name.set("this", table_name.this.this)
445+
446+
wap_supported = (
447+
kwargs.get("storage_format") or ""
448+
).lower() == "iceberg" or self.wap_supported(table_name)
449+
do_dummy_insert = (
450+
False if not wap_supported or not exists else not self.table_exists(table_name)
451+
)
431452
super()._create_table(
432453
table_name_or_schema,
433454
expression,
@@ -443,9 +464,7 @@ def _create_table(
443464
if isinstance(table_name_or_schema, exp.Schema)
444465
else exp.to_table(table_name_or_schema)
445466
)
446-
if (kwargs.get("storage_format") or "").lower() == "iceberg" or self.wap_supported(
447-
table_name
448-
):
467+
if do_dummy_insert:
449468
# Performing a dummy insert to create a dummy snapshot for Iceberg tables
450469
# to workaround https://github.com/apache/iceberg/issues/8849.
451470
dummy_insert = exp.insert(exp.select("*").from_(table_name), table_name)
@@ -459,20 +478,20 @@ def wap_supported(self, table_name: TableName) -> bool:
459478
)
460479

461480
def wap_table_name(self, table_name: TableName, wap_id: str) -> str:
462-
branch_name = _wap_branch_name(wap_id)
481+
branch_name = self._wap_branch_name(wap_id)
463482
fqn = self._ensure_fqn(table_name)
464-
return exp.Dot.build([fqn, exp.to_identifier(f"branch_{branch_name}")]).sql(
483+
return exp.Dot.build([fqn, exp.to_identifier(f"{self.BRANCH_PREFIX}{branch_name}")]).sql(
465484
dialect=self.dialect
466485
)
467486

468487
def wap_prepare(self, table_name: TableName, wap_id: str) -> str:
469-
branch_name = _wap_branch_name(wap_id)
488+
branch_name = self._wap_branch_name(wap_id)
470489
fqn = self._ensure_fqn(table_name)
471490
self.execute(f"ALTER TABLE {fqn.sql(dialect=self.dialect)} CREATE BRANCH {branch_name}")
472491
return self.wap_table_name(table_name, wap_id)
473492

474493
def wap_publish(self, table_name: TableName, wap_id: str) -> None:
475-
branch_name = _wap_branch_name(wap_id)
494+
branch_name = self._wap_branch_name(wap_id)
476495
fqn = self._ensure_fqn(table_name)
477496

478497
get_snapshot_id_query = (
@@ -512,6 +531,6 @@ def _build_create_comment_column_exp(
512531

513532
return f"ALTER TABLE {table_sql} ALTER COLUMN {column_sql} COMMENT '{column_comment}'"
514533

515-
516-
def _wap_branch_name(wap_id: str) -> str:
517-
return f"wap_{wap_id}"
534+
@classmethod
535+
def _wap_branch_name(cls, wap_id: str) -> str:
536+
return f"{cls.WAP_PREFIX}{wap_id}"

tests/core/engine_adapter/test_spark.py

Lines changed: 75 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ def test_create_table_properties(make_mocked_engine_adapter: t.Callable):
3333
storage_format="parquet",
3434
)
3535

36-
adapter.cursor.execute.assert_called_once_with(
37-
"CREATE TABLE IF NOT EXISTS `test_table` (`cola` INT, `colb` STRING, `colc` STRING) USING PARQUET PARTITIONED BY (`colb`)"
38-
)
36+
assert to_sql_calls(adapter) == [
37+
"CREATE TABLE IF NOT EXISTS `test_table` (`cola` INT, `colb` STRING, `colc` STRING) USING PARQUET PARTITIONED BY (`colb`)",
38+
]
3939

4040
adapter.cursor.reset_mock()
4141
adapter.create_table(
@@ -45,9 +45,9 @@ def test_create_table_properties(make_mocked_engine_adapter: t.Callable):
4545
storage_format="parquet",
4646
)
4747

48-
adapter.cursor.execute.assert_called_once_with(
49-
"CREATE TABLE IF NOT EXISTS `test_table` (`cola` INT, `colb` STRING, `colc` STRING) USING PARQUET PARTITIONED BY (`cola`, `colb`)"
50-
)
48+
assert to_sql_calls(adapter) == [
49+
"CREATE TABLE IF NOT EXISTS `test_table` (`cola` INT, `colb` STRING, `colc` STRING) USING PARQUET PARTITIONED BY (`cola`, `colb`)",
50+
]
5151

5252
with pytest.raises(SQLMeshError):
5353
adapter.create_table(
@@ -256,7 +256,7 @@ def test_replace_query_self_ref_not_exists(
256256
temp_table_id = "abcdefgh"
257257
mocker.patch(
258258
"sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table",
259-
side_effect=lambda *args, **kwargs: make_temp_table_name(table_name, temp_table_id),
259+
return_value=make_temp_table_name(table_name, temp_table_id),
260260
)
261261

262262
mocker.patch(
@@ -329,19 +329,19 @@ def test_create_table_table_options(make_mocked_engine_adapter: t.Callable):
329329
},
330330
)
331331

332-
adapter.cursor.execute.assert_called_once_with(
333-
"CREATE TABLE IF NOT EXISTS `test_table` (`a` int, `b` int) TBLPROPERTIES ('test.conf.key'='value')"
334-
)
332+
assert to_sql_calls(adapter) == [
333+
"CREATE TABLE IF NOT EXISTS `test_table` (`a` int, `b` int) TBLPROPERTIES ('test.conf.key'='value')",
334+
]
335335

336336

337337
def test_create_state_table(make_mocked_engine_adapter: t.Callable):
338338
adapter = make_mocked_engine_adapter(SparkEngineAdapter)
339339

340340
adapter.create_state_table("test_table", {"a": "int", "b": "int"}, primary_key=["a"])
341341

342-
adapter.cursor.execute.assert_called_once_with(
343-
"CREATE TABLE IF NOT EXISTS `test_table` (`a` int, `b` int) PARTITIONED BY (`a`)"
344-
)
342+
assert to_sql_calls(adapter) == [
343+
"CREATE TABLE IF NOT EXISTS `test_table` (`a` int, `b` int) PARTITIONED BY (`a`)",
344+
]
345345

346346

347347
test_primitive_params = [
@@ -783,7 +783,12 @@ def test_wap_publish(make_mocked_engine_adapter: t.Callable, mocker: MockerFixtu
783783
)
784784

785785

786-
def test_create_table_iceberg(make_mocked_engine_adapter: t.Callable):
786+
def test_create_table_iceberg(mocker: MockerFixture, make_mocked_engine_adapter: t.Callable):
787+
mocker.patch(
788+
"sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.table_exists",
789+
return_value=False,
790+
)
791+
787792
adapter = make_mocked_engine_adapter(SparkEngineAdapter)
788793

789794
columns_to_types = {
@@ -799,17 +804,13 @@ def test_create_table_iceberg(make_mocked_engine_adapter: t.Callable):
799804
storage_format="ICEBERG",
800805
)
801806

802-
adapter.cursor.execute.assert_has_calls(
803-
[
804-
call(
805-
"CREATE TABLE IF NOT EXISTS `test_table` (`cola` INT, `colb` STRING, `colc` STRING) USING ICEBERG PARTITIONED BY (`colb`)"
806-
),
807-
call("INSERT INTO `test_table` SELECT * FROM `test_table`"),
808-
]
809-
)
807+
assert to_sql_calls(adapter) == [
808+
"CREATE TABLE IF NOT EXISTS `test_table` (`cola` INT, `colb` STRING, `colc` STRING) USING ICEBERG PARTITIONED BY (`colb`)",
809+
"INSERT INTO `test_table` SELECT * FROM `test_table`",
810+
]
810811

811812

812-
def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture):
813+
def test_comments(make_mocked_engine_adapter: t.Callable):
813814
adapter = make_mocked_engine_adapter(SparkEngineAdapter)
814815

815816
adapter._create_table_comment(
@@ -827,3 +828,54 @@ def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture)
827828
"COMMENT ON TABLE `test_table` IS 'test description'",
828829
"ALTER TABLE `test_table` ALTER COLUMN `a` COMMENT 'a description'",
829830
]
831+
832+
833+
def test_create_table_with_wap(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture):
834+
mocker.patch(
835+
"sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.table_exists",
836+
return_value=False,
837+
)
838+
adapter = make_mocked_engine_adapter(SparkEngineAdapter)
839+
840+
adapter.create_table(
841+
"catalog.schema.table.branch_wap_12345",
842+
{"a": "int"},
843+
storage_format="ICEBERG",
844+
)
845+
846+
sql_calls = to_sql_calls(adapter)
847+
assert sql_calls == [
848+
"CREATE TABLE IF NOT EXISTS `catalog`.`schema`.`table` (`a` int) USING ICEBERG",
849+
"INSERT INTO `catalog`.`schema`.`table` SELECT * FROM `catalog`.`schema`.`table`",
850+
]
851+
852+
853+
def test_replace_query_with_wap_self_reference(
854+
make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name
855+
):
856+
mocker.patch(
857+
"sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.table_exists",
858+
return_value=True,
859+
)
860+
mocker.patch(
861+
"sqlmesh.core.engine_adapter.base.random_id",
862+
return_value="abcdefgh",
863+
)
864+
865+
adapter = make_mocked_engine_adapter(SparkEngineAdapter)
866+
867+
adapter.replace_query(
868+
"catalog.schema.table.branch_wap_12345",
869+
parse_one("SELECT 1 as a FROM catalog.schema.table.branch_wap_12345"),
870+
columns_to_types={"a": "int"},
871+
storage_format="ICEBERG",
872+
)
873+
874+
sql_calls = to_sql_calls(adapter)
875+
assert sql_calls == [
876+
"CREATE TABLE IF NOT EXISTS `catalog`.`schema`.`table` (`a` int)",
877+
"CREATE SCHEMA IF NOT EXISTS `schema`",
878+
"CREATE TABLE IF NOT EXISTS `catalog`.`schema`.`temp_branch_wap_12345_abcdefgh` USING ICEBERG AS SELECT `a` FROM `catalog`.`schema`.`table`.`branch_wap_12345`",
879+
"INSERT OVERWRITE TABLE `catalog`.`schema`.`table`.`branch_wap_12345` (`a`) SELECT 1 AS `a` FROM `catalog`.`schema`.`temp_branch_wap_12345_abcdefgh`",
880+
"DROP TABLE IF EXISTS `catalog`.`schema`.`temp_branch_wap_12345_abcdefgh`",
881+
]

0 commit comments

Comments (0)