Skip to content

Commit 074a928

Browse files
authored
Fix: get_snapshots with empty list and dedup (#1573)
1 parent 9593f4b commit 074a928

6 files changed

Lines changed: 38 additions & 23 deletions

File tree

sqlmesh/core/state_sync/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def delete_expired_environments(self) -> t.List[Environment]:
323323

324324
@abc.abstractmethod
325325
def unpause_snapshots(
326-
self, snapshots: t.Iterable[SnapshotInfoLike], unpaused_dt: TimeLike
326+
self, snapshots: t.Collection[SnapshotInfoLike], unpaused_dt: TimeLike
327327
) -> None:
328328
"""Unpauses target snapshots.
329329

sqlmesh/core/state_sync/cache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def get_snapshots(
5050
snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]],
5151
hydrate_seeds: bool = False,
5252
) -> t.Dict[SnapshotId, Snapshot]:
53-
if not snapshot_ids:
53+
if snapshot_ids is None:
5454
return self.state_sync.get_snapshots(snapshot_ids, hydrate_seeds)
5555

5656
existing = {}
@@ -144,7 +144,7 @@ def remove_interval(
144144
self.state_sync.remove_interval(snapshot_intervals, execution_time, remove_shared_versions)
145145

146146
def unpause_snapshots(
147-
self, snapshots: t.Iterable[SnapshotInfoLike], unpaused_dt: TimeLike
147+
self, snapshots: t.Collection[SnapshotInfoLike], unpaused_dt: TimeLike
148148
) -> None:
149149
self.snapshot_cache.clear()
150150
self.state_sync.unpause_snapshots(snapshots, unpaused_dt)

sqlmesh/core/state_sync/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def _is_snapshot_used(snapshot: Snapshot) -> bool:
224224

225225
@transactional()
226226
def unpause_snapshots(
227-
self, snapshots: t.Iterable[SnapshotInfoLike], unpaused_dt: TimeLike
227+
self, snapshots: t.Collection[SnapshotInfoLike], unpaused_dt: TimeLike
228228
) -> None:
229229
current_ts = now()
230230

@@ -335,7 +335,7 @@ def _get_snapshots(
335335
@abc.abstractmethod
336336
def _get_snapshots_with_same_version(
337337
self,
338-
snapshots: t.Iterable[SnapshotNameVersionLike],
338+
snapshots: t.Collection[SnapshotNameVersionLike],
339339
lock_for_update: bool = False,
340340
) -> t.List[Snapshot]:
341341
"""Fetches all snapshots that share the same version as the snapshots.

sqlmesh/core/state_sync/engine_adapter.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,11 @@ def _get_snapshots(
363363
query = (
364364
exp.select(exp.column("snapshot", table="snapshots"))
365365
.from_(exp.to_table(self.snapshots_table).as_("snapshots"))
366-
.where(self._snapshot_id_filter(snapshot_ids, "snapshots") if snapshot_ids else None)
366+
.where(
367+
None
368+
if snapshot_ids is None
369+
else self._snapshot_id_filter(snapshot_ids, "snapshots")
370+
)
367371
)
368372
if hydrate_seeds:
369373
query = query.select(exp.column("content", table="seeds")).join(
@@ -411,7 +415,7 @@ def _get_snapshots(
411415

412416
def _get_snapshots_with_same_version(
413417
self,
414-
snapshots: t.Iterable[SnapshotNameVersionLike],
418+
snapshots: t.Collection[SnapshotNameVersionLike],
415419
lock_for_update: bool = False,
416420
) -> t.List[Snapshot]:
417421
"""Fetches all snapshots that share the same version as the snapshots.
@@ -906,7 +910,11 @@ def map_data_versions(
906910
def _snapshot_id_filter(
907911
self, snapshot_ids: t.Iterable[SnapshotIdLike], alias: t.Optional[str] = None
908912
) -> t.Union[exp.In, exp.Boolean, exp.Condition]:
909-
if not snapshot_ids:
913+
name_identifiers = {
914+
(snapshot_id.name, snapshot_id.identifier) for snapshot_id in snapshot_ids
915+
}
916+
917+
if not name_identifiers:
910918
return exp.false()
911919
elif self.engine_adapter.SUPPORTS_TUPLE_IN:
912920
return t.cast(
@@ -917,22 +925,24 @@ def _snapshot_id_filter(
917925
exp.column("identifier", table=alias),
918926
)
919927
),
920-
).isin(*[(snapshot_id.name, snapshot_id.identifier) for snapshot_id in snapshot_ids])
928+
).isin(*name_identifiers)
921929
else:
922930
return exp.or_(
923931
*[
924932
exp.and_(
925-
exp.column("name", table=alias).eq(snapshot_id.name),
926-
exp.column("identifier", table=alias).eq(snapshot_id.identifier),
933+
exp.column("name", table=alias).eq(name),
934+
exp.column("identifier", table=alias).eq(identifier),
927935
)
928-
for snapshot_id in snapshot_ids
936+
for name, identifier in name_identifiers
929937
]
930938
)
931939

932940
def _snapshot_name_version_filter(
933941
self, snapshot_name_versions: t.Iterable[SnapshotNameVersionLike], alias: str = "snapshots"
934942
) -> t.Union[exp.In, exp.Boolean, exp.Condition]:
935-
if not snapshot_name_versions:
943+
name_versions = {(s.name, s.version) for s in snapshot_name_versions}
944+
945+
if not name_versions:
936946
return exp.false()
937947
elif self.engine_adapter.SUPPORTS_TUPLE_IN:
938948
return t.cast(
@@ -943,20 +953,15 @@ def _snapshot_name_version_filter(
943953
exp.column("version", table=alias),
944954
)
945955
),
946-
).isin(
947-
*[
948-
(snapshot_name_version.name, snapshot_name_version.version)
949-
for snapshot_name_version in snapshot_name_versions
950-
]
951-
)
956+
).isin(*name_versions)
952957
else:
953958
return exp.or_(
954959
*[
955960
exp.and_(
956-
exp.column("name", table=alias).eq(snapshot_name_version.name),
957-
exp.column("version", table=alias).eq(snapshot_name_version.version),
961+
exp.column("name", table=alias).eq(name),
962+
exp.column("version", table=alias).eq(version),
958963
)
959-
for snapshot_name_version in snapshot_name_versions
964+
for name, version in name_versions
960965
]
961966
)
962967

sqlmesh/schedulers/airflow/state_sync.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def delete_expired_environments(self) -> t.List[Environment]:
263263
)
264264

265265
def unpause_snapshots(
266-
self, snapshots: t.Iterable[SnapshotInfoLike], unpaused_dt: TimeLike
266+
self, snapshots: t.Collection[SnapshotInfoLike], unpaused_dt: TimeLike
267267
) -> None:
268268
"""Unpauses target snapshots.
269269

tests/core/test_state_sync.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,3 +1318,13 @@ def test_max_interval_end_for_environment(
13181318
assert state_sync.max_interval_end_for_environment(environment_name) == to_timestamp(
13191319
"2023-01-03"
13201320
)
1321+
1322+
1323+
def test_get_snapshots(mocker):
1324+
mock = mocker.MagicMock()
1325+
cache = CachingStateSync(mock)
1326+
cache.get_snapshots([])
1327+
mock.get_snapshots.assert_not_called()
1328+
1329+
cache.get_snapshots(None)
1330+
mock.get_snapshots.assert_called()

0 commit comments

Comments
 (0)