Skip to content

Commit e099cbb

Browse files
authored
Feat: Snowflake: Handle forward_only changes to 'clustered_by' (#3205)
1 parent 8955629 commit e099cbb

File tree

2 files changed

+225
-2
lines changed

2 files changed

+225
-2
lines changed

sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
import pandas as pd
88
from pandas.api.types import is_datetime64_any_dtype # type: ignore
9-
from sqlglot import exp
9+
from sqlglot import exp, parse_one
10+
from sqlglot.helper import seq_get
1011
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
1112
from sqlglot.optimizer.qualify_columns import quote_identifiers
1213

@@ -32,6 +33,14 @@
3233
from sqlmesh.core.node import IntervalUnit
3334

3435

36+
class SnowflakeDataObject(DataObject):
    """A Snowflake data object that additionally carries the table's clustering key."""

    # Raw clustering key string as reported by INFORMATION_SCHEMA.TABLES.CLUSTERING_KEY,
    # e.g. 'LINEAR(c2, c1)'; None (or empty) when the table is not clustered.
    clustering_key: t.Optional[str] = None

    @property
    def is_clustered(self) -> bool:
        """True when the table has a non-empty clustering key."""
        return self.clustering_key is not None and self.clustering_key != ""
42+
43+
3544
@set_catalog(
3645
override_mapping={
3746
"_get_data_objects": CatalogSupport.REQUIRES_SET_CATALOG,
@@ -325,6 +334,7 @@ def _get_data_objects(
325334
)
326335
.else_(exp.column("TABLE_TYPE"))
327336
.as_("type"),
337+
exp.column("CLUSTERING_KEY").as_("clustering_key"),
328338
)
329339
.from_(exp.table_("TABLES", db="INFORMATION_SCHEMA", catalog=catalog_name))
330340
.where(exp.column("TABLE_SCHEMA").eq(schema.db))
@@ -338,11 +348,12 @@ def _get_data_objects(
338348
if df.empty:
339349
return []
340350
return [
341-
DataObject(
351+
SnowflakeDataObject(
342352
catalog=row.catalog, # type: ignore
343353
schema=row.schema_name, # type: ignore
344354
name=row.name, # type: ignore
345355
type=DataObjectType.from_str(row.type), # type: ignore
356+
clustering_key=row.clustering_key, # type: ignore
346357
)
347358
for row in df.itertuples()
348359
]
@@ -422,3 +433,50 @@ def _create_column_comments(
422433
f"Column comments for table '{table.alias_or_name}' not registered - this may be due to limited permissions.",
423434
exc_info=True,
424435
)
436+
437+
def get_alter_expressions(
    self, current_table_name: TableName, target_table_name: TableName
) -> t.List[exp.Alter]:
    """Return the ALTERs needed to turn ``current_table_name`` into ``target_table_name``.

    Extends the base implementation (schema changes) with Snowflake clustering-key
    changes: adding, replacing, or dropping the clustering key.
    """
    expressions = super().get_alter_expressions(current_table_name, target_table_name)
    cluster_changes: t.List[exp.Alter] = []

    source = exp.to_table(current_table_name)
    target = exp.to_table(target_table_name)

    def _table_info(table: exp.Table) -> SnowflakeDataObject:
        # First (and only expected) metadata row for the table, or None if absent.
        return t.cast(
            SnowflakeDataObject,
            seq_get(self.get_data_objects(table.db, {table.name}), 0),
        )

    source_info = _table_info(source)
    target_info = _table_info(target)

    # Metadata may be missing for either table; only compare when both exist.
    if source_info and target_info:
        if target_info.is_clustered:
            if target_info.clustering_key and (
                source_info.clustering_key != target_info.clustering_key
            ):
                # Note: If you create a table with eg `CLUSTER BY (c2, c1)` and read the info back from information_schema,
                # it gets returned as a string like "LINEAR(c2, c1)" which we need to parse back into a list of columns
                parsed_key = parse_one(target_info.clustering_key, dialect=self.dialect)
                cluster_changes.append(
                    exp.Alter(
                        this=source,
                        kind="TABLE",
                        actions=[exp.Cluster(expressions=parsed_key.expressions)],
                    )
                )
        elif source_info.is_clustered:
            # Target is unclustered but the current table is: drop the key.
            cluster_changes.append(
                exp.Alter(
                    this=source,
                    kind="TABLE",
                    actions=[exp.Command(this="DROP", expression="CLUSTERING KEY")],
                )
            )

    return expressions + cluster_changes
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
import typing as t
2+
import pytest
3+
from sqlglot import exp
4+
from sqlglot.optimizer.qualify_columns import quote_identifiers
5+
from sqlglot.helper import seq_get
6+
from sqlmesh.core.engine_adapter import SnowflakeEngineAdapter
7+
from sqlmesh.core.engine_adapter.snowflake import SnowflakeDataObject
8+
import sqlmesh.core.dialect as d
9+
from sqlmesh.core.model import SqlModel, load_sql_based_model
10+
from sqlmesh.core.plan import Plan
11+
from tests.core.engine_adapter.integration import TestContext
12+
13+
pytestmark = [pytest.mark.engine, pytest.mark.remote, pytest.mark.snowflake]
14+
15+
16+
@pytest.fixture
def mark_gateway() -> t.Tuple[str, str]:
    """Return the (engine mark, gateway name) pair used for these Snowflake integration tests."""
    return "snowflake", "inttest_snowflake"
19+
20+
21+
@pytest.fixture
def test_type() -> str:
    """Test-type selector consumed by the shared integration fixtures.

    NOTE(review): presumably 'query' chooses the query-based harness variant —
    confirm against tests.core.engine_adapter.integration.
    """
    return "query"
24+
25+
26+
def test_get_alter_expressions_includes_clustering(
    ctx: TestContext, engine_adapter: SnowflakeEngineAdapter
):
    """Clustering differences between tables must surface as ALTER expressions."""
    ctx.init()

    by_c1 = ctx.table("clustered_table")
    by_c1_and_date = ctx.table("clustered_differently_table")
    unclustered = ctx.table("normal_table")

    engine_adapter.execute(f"CREATE TABLE {by_c1} (c1 int, c2 timestamp) CLUSTER BY (c1)")
    engine_adapter.execute(
        f"CREATE TABLE {by_c1_and_date} (c1 int, c2 timestamp) CLUSTER BY (c1, to_date(c2))"
    )
    engine_adapter.execute(f"CREATE TABLE {unclustered} (c1 int, c2 timestamp)")

    def _alters(current, target) -> t.List[str]:
        # Render each generated expression to SQL so assertions compare plain strings.
        return [
            e.sql(dialect=ctx.dialect)
            for e in engine_adapter.get_alter_expressions(current, target)
        ]

    # Identical tables need no changes.
    assert _alters(unclustered, unclustered) == []
    assert _alters(by_c1, by_c1) == []

    # clustered -> unclustered: drop the clustering key
    assert _alters(by_c1, unclustered) == [f"ALTER TABLE {by_c1} DROP CLUSTERING KEY"]

    # unclustered -> clustered: add the clustering key
    assert _alters(unclustered, by_c1) == [f"ALTER TABLE {unclustered} CLUSTER BY (c1)"]

    # clustered -> clustered with a different key (both directions)
    assert _alters(by_c1, by_c1_and_date) == [
        f"ALTER TABLE {by_c1} CLUSTER BY (c1, TO_DATE(c2))"
    ]
    assert _alters(by_c1_and_date, by_c1) == [
        f"ALTER TABLE {by_c1_and_date} CLUSTER BY (c1)"
    ]
79+
80+
81+
def test_adding_clustered_by_forward_only(ctx: TestContext, engine_adapter: SnowflakeEngineAdapter):
    """Forward-only plans should ALTER the existing physical table when `clustered_by` changes."""
    model_name = ctx.table("TEST")

    sqlmesh = ctx.create_context()

    def _model_variant(**kwargs: t.Any) -> SqlModel:
        # Render any extra MODEL properties (e.g. clustered_by) into the definition.
        extra_props = "\n".join([f"{k} {v}," for k, v in kwargs.items()])
        definition = d.parse(
            f"""
            MODEL (
                name {model_name},
                kind INCREMENTAL_BY_TIME_RANGE (
                    time_column PARTITIONDATE
                ),
                {extra_props}
                start '2021-01-01',
                cron '@daily',
                dialect 'snowflake'
            );

            select 1 as ID, current_timestamp() as PARTITIONDATE
            """
        )
        return t.cast(SqlModel, load_sql_based_model(definition))

    def _metadata(table: exp.Table) -> SnowflakeDataObject:
        data_object = seq_get(engine_adapter.get_data_objects(table.db, {table.name}), 0)
        if not data_object:
            raise ValueError(f"Expected metadata for {table}")
        return t.cast(SnowflakeDataObject, data_object)

    def _apply(model: SqlModel, forward_only: bool = True) -> exp.Table:
        # Plan+apply `model` and return the quoted physical table it materialized.
        sqlmesh.upsert_model(model)
        # Only pass forward_only when requested so the initial plan call is unchanged.
        plan_kwargs: t.Dict[str, t.Any] = {"forward_only": True} if forward_only else {}
        plan: Plan = sqlmesh.plan(auto_apply=True, no_prompts=True, **plan_kwargs)
        assert len(plan.snapshots) == 1
        table = exp.to_table(list(plan.snapshots.values())[0].table_name())
        quote_identifiers(table)
        return table

    unclustered = _model_variant()
    clustered_single = _model_variant(clustered_by="PARTITIONDATE")
    clustered_multi = _model_variant(clustered_by="(ID, PARTITIONDATE)")

    # Initial plan - non-clustered table
    base_table = _apply(unclustered, forward_only=False)
    assert not _metadata(base_table).is_clustered

    # Next plan - add clustering key (non-clustered -> clustered); same physical table
    assert _apply(clustered_single) == base_table
    meta = _metadata(base_table)
    assert meta.is_clustered
    assert meta.clustering_key == 'LINEAR("PARTITIONDATE")'

    # Next plan - change clustering key (clustered -> clustered differently)
    assert _apply(clustered_multi) == base_table
    meta = _metadata(base_table)
    assert meta.is_clustered
    assert meta.clustering_key == 'LINEAR("ID", "PARTITIONDATE")'

    # Final plan - drop clustering key
    assert _apply(unclustered) == base_table
    assert not _metadata(base_table).is_clustered

0 commit comments

Comments
 (0)