Skip to content

Commit b906682

Browse files
authored
fix!: ignore private attributes doing equality pydantic (#2035)
1 parent 4ee6821 commit b906682

File tree

3 files changed

+74
-10
lines changed

3 files changed

+74
-10
lines changed

sqlmesh/core/snapshot/definition.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections import defaultdict
66
from datetime import datetime, timedelta
77
from enum import IntEnum
8-
from functools import lru_cache
8+
from functools import cached_property, lru_cache
99

1010
from pydantic import Field
1111
from sqlglot import exp
@@ -258,7 +258,7 @@ def data_version(self) -> SnapshotDataVersion:
258258
def is_new_version(self) -> bool:
259259
raise NotImplementedError
260260

261-
@property
261+
@cached_property
262262
def fully_qualified_table(self) -> t.Optional[exp.Table]:
263263
raise NotImplementedError
264264

@@ -351,8 +351,6 @@ class SnapshotTableInfo(PydanticModel, SnapshotInfoMixin, frozen=True):
351351
# This can be removed from this model once Pydantic 1 support is dropped (must remain in `Snapshot` though)
352352
base_table_name_override: t.Optional[str] = None
353353

354-
_fully_qualified_table: t.Optional[exp.Table] = None
355-
356354
def __lt__(self, other: SnapshotTableInfo) -> bool:
357355
return self.name < other.name
358356

@@ -368,11 +366,9 @@ def table_name(self, is_deployable: bool = True) -> str:
368366
def physical_schema(self) -> str:
369367
return self.physical_schema_
370368

371-
@property
369+
@cached_property
372370
def fully_qualified_table(self) -> exp.Table:
373-
if not self._fully_qualified_table:
374-
self._fully_qualified_table = exp.to_table(self.name)
375-
return self._fully_qualified_table
371+
return exp.to_table(self.name)
376372

377373
@property
378374
def table_info(self) -> SnapshotTableInfo:
@@ -1030,7 +1026,7 @@ def disable_restatement(self) -> bool:
10301026
"""Is restatement disabled for the node"""
10311027
return self.is_model and self.model.disable_restatement
10321028

1033-
@property
1029+
@cached_property
10341030
def fully_qualified_table(self) -> t.Optional[exp.Table]:
10351031
if not self.is_model:
10361032
return None

sqlmesh/utils/pydantic.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323

2424
T = t.TypeVar("T")
2525
DEFAULT_ARGS = {"exclude_none": True, "by_alias": True}
26-
PYDANTIC_MAJOR_VERSION = int(pydantic.__version__.split(".")[0])
26+
PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION = [int(p) for p in pydantic.__version__.split(".")][
27+
:2
28+
]
2729

2830

2931
if PYDANTIC_MAJOR_VERSION >= 2:
@@ -105,6 +107,8 @@ class Config:
105107
smart_union = True
106108
keep_untouched = (cached_property,)
107109

110+
_hash_func_mapping: t.ClassVar[t.Dict[t.Type[t.Any], t.Callable[[t.Any], int]]] = {}
111+
108112
def dict(
109113
self,
110114
**kwargs: t.Any,
@@ -191,6 +195,27 @@ def _fields(
191195
if predicate(field_info)
192196
}
193197

198+
def __eq__(self, other: t.Any) -> bool:
199+
if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) < (2, 6):
200+
if isinstance(other, pydantic.BaseModel):
201+
return self.dict() == other.dict()
202+
else:
203+
return self.dict() == other
204+
return super().__eq__(other)
205+
206+
def __hash__(self) -> int:
207+
if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) < (2, 6):
208+
obj = {k: v for k, v in self.__dict__.items() if k in self.all_field_infos()}
209+
return hash(self.__class__) + hash(tuple(obj.values()))
210+
211+
from pydantic._internal._model_construction import ( # type: ignore
212+
make_hash_func,
213+
)
214+
215+
if self.__class__ not in PydanticModel._hash_func_mapping:
216+
PydanticModel._hash_func_mapping[self.__class__] = make_hash_func(self.__class__)
217+
return PydanticModel._hash_func_mapping[self.__class__](self)
218+
194219
def __str__(self) -> str:
195220
args = []
196221

tests/utils/test_pydantic.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from functools import cached_property
2+
13
from sqlmesh.utils.date import TimeLike, to_date, to_datetime
24
from sqlmesh.utils.pydantic import PYDANTIC_MAJOR_VERSION, PydanticModel
35

@@ -16,3 +18,44 @@ class Test(PydanticModel):
1618
else:
1719
assert deserialized_date.ds == target_ds
1820
assert deserialized_datetime.ds == "2022-01-01T00:00:00+00:00"
21+
22+
23+
def test_pydantic_2_equality() -> None:
24+
class TestModel(PydanticModel):
25+
name: str
26+
27+
@cached_property
28+
def private(self) -> str:
29+
return "should be ignored"
30+
31+
model_a = TestModel(name="a")
32+
model_a_duplicate = TestModel(name="a")
33+
assert model_a == model_a_duplicate
34+
model_b = TestModel(name="b")
35+
assert model_a != model_b
36+
37+
38+
def test_pydantic_2_hash() -> None:
39+
class TestModel(PydanticModel):
40+
name: str
41+
42+
@cached_property
43+
def private(self) -> str:
44+
return "should be ignored"
45+
46+
class TestModel2(PydanticModel):
47+
name: str
48+
field2: str = "test"
49+
50+
@cached_property
51+
def private(self) -> str:
52+
return "should be ignored"
53+
54+
model_a = TestModel(name="a")
55+
model_a_duplicate = TestModel(name="a")
56+
assert hash(model_a) == hash(model_a_duplicate)
57+
58+
model_2_a = TestModel2(name="a")
59+
model_2_b = TestModel2(name="a")
60+
assert hash(model_2_a) == hash(model_2_b)
61+
assert hash(model_a) != hash(model_2_a)

0 commit comments

Comments
 (0)