Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions sqlmesh/utils/metaprogramming.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,8 @@ def walk(obj: t.Any, name: str, is_metadata: bool = False) -> None:
walk(base, base.__qualname__, is_metadata)

for k, v in obj.__dict__.items():
if k.startswith("__"):
# skip dunder methods bar __init__ as it might contain user defined logic with cross class references
if k.startswith("__") and k != "__init__":
continue

# Traverse methods in a class to find global references
Expand All @@ -362,10 +363,14 @@ def walk(obj: t.Any, name: str, is_metadata: bool = False) -> None:
if callable(v):
# Walk the method if it's part of the object, else it's a global function and we just store it
if v.__qualname__.startswith(obj.__qualname__):
for k, v in func_globals(v).items():
walk(v, k, is_metadata)
else:
walk(v, v.__name__, is_metadata)
try:
for k, v in func_globals(v).items():
walk(v, k, is_metadata)
except (OSError, TypeError):
# __init__ may come from built-ins or wrapped callables
pass
else:
walk(v, k, is_metadata)
elif callable(obj):
for k, v in func_globals(obj).items():
walk(v, k, is_metadata)
Expand Down
42 changes: 42 additions & 0 deletions tests/utils/test_metaprogramming.py
Comment thread
georgesittas marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,48 @@ def test_serialize_env_with_enum_import_appearing_in_two_functions() -> None:
assert serialized_env == expected_env


class ReferencedClass:
def __init__(self, value: int):
self.value = value

def get_value(self) -> int:
return self.value


class ClassThatReferencesAnother:
def __init__(self, x: int):
self.helper = ReferencedClass(x * 2)

def compute(self) -> int:
return self.helper.get_value() + 10


def function_using_class_with_reference(y: int) -> int:
obj = ClassThatReferencesAnother(y)
return obj.compute()


def test_serialize_env_with_class_referencing_another_class() -> None:
# firstly we can confirm that func_globals picks up the reference
init_globals = func_globals(ClassThatReferencesAnother.__init__)
assert "ReferencedClass" in init_globals

path = Path("tests/utils")
env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]] = {}

# build ajd serialize environment for the function that uses the class
build_env(function_using_class_with_reference, env=env, name="test_func", path=path)
serialized_env = serialize_env(env, path=path)

# both classes should be in the serialized environment
assert "ClassThatReferencesAnother" in serialized_env
assert "ReferencedClass" in serialized_env

prepared_env = prepare_env(serialized_env)
result = eval("test_func(33)", prepared_env)
assert result == 76


def test_dict_sort_basic_types():
"""Test dict_sort with basic Python types."""
# Test basic types that should use standard repr
Expand Down
Loading