Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 138 additions & 3 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@
ClassPattern,
MappingPattern,
OrPattern,
Pattern,
SequencePattern,
SingletonPattern,
StarredPattern,
Expand Down Expand Up @@ -546,6 +547,12 @@ def __init__(
# import foo.bar
self.transitive_submodule_imports: dict[str, set[str]] = {}

# Stack of sets of names assigned in each enclosing class body.
# Used to determine whether a name in a class body should be looked up
# in enclosing function-local scopes or skipped (matching CPython's
# LOAD_NAME vs LOAD_CLASSDEREF or LOAD_FROM_DICT_OR_DEREF behavior.
self.class_body_assigned_names: list[set[str]] = []

# mypyc doesn't properly handle implementing an abstractproperty
# with a regular attribute so we make them properties
@property
Expand Down Expand Up @@ -2100,6 +2107,10 @@ def is_core_builtin_class(self, defn: ClassDef) -> bool:
def analyze_class_body_common(self, defn: ClassDef) -> None:
"""Parts of class body analysis that are common to all kinds of class defs."""
self.enter_class(defn.info)
# Pre-scan class body to find names assigned at class scope level.
# This must happen after enter_class (which pushes an empty set) so we
# can replace it with the real set.
self.class_body_assigned_names[-1] = collect_class_body_assigned_names(defn.defs.body)
if any(b.self_type is not None for b in defn.info.mro):
self.setup_self_type()
defn.defs.accept(self)
Expand Down Expand Up @@ -2225,6 +2236,7 @@ def enter_class(self, info: TypeInfo) -> None:
self.loop_depth.append(0)
self._type = info
self.missing_names.append(set())
self.class_body_assigned_names.append(set())

def leave_class(self) -> None:
"""Restore analyzer state."""
Expand All @@ -2234,6 +2246,7 @@ def leave_class(self) -> None:
self.scope_stack.pop()
self._type = self.type_stack.pop()
self.missing_names.pop()
self.class_body_assigned_names.pop()

def analyze_class_decorator(self, defn: ClassDef, decorator: Expression) -> None:
decorator.accept(self)
Expand Down Expand Up @@ -6600,9 +6613,19 @@ def _lookup(
v._fullname = self.qualified_name(name)
return SymbolTableNode(MDEF, v)
# 3. Local (function) scopes
for table in reversed(self.locals):
if table is not None and name in table:
return table[name]
# If we're in a class body and this name is assigned in the class body,
# skip enclosing function locals. This matches CPython's use of LOAD_NAME
# (class dict -> globals -> builtins) instead of
# LOAD_CLASSDEREF or LOAD_FROM_DICT_OR_DEREF for such names.
skip_func_locals = (
self.is_class_scope()
and self.class_body_assigned_names
and name in self.class_body_assigned_names[-1]
)
if not skip_func_locals:
for table in reversed(self.locals):
if table is not None and name in table:
return table[name]

# 4. Current file global scope
if name in self.globals:
Expand Down Expand Up @@ -8326,6 +8349,118 @@ def names_modified_in_lvalue(lvalue: Lvalue) -> list[NameExpr]:
return []


def _collect_lvalue_names(lvalue: Lvalue, names: set[str]) -> None:
"""Collect simple names from an lvalue into a set."""
if isinstance(lvalue, NameExpr):
names.add(lvalue.name)
elif isinstance(lvalue, StarExpr):
_collect_lvalue_names(lvalue.expr, names)
elif isinstance(lvalue, (ListExpr, TupleExpr)):
for item in lvalue.items:
_collect_lvalue_names(item, names)


def _collect_pattern_names(pattern: Pattern, names: set[str]) -> None:
"""Collect names bound by a match pattern."""
if isinstance(pattern, AsPattern):
if pattern.name is not None:
names.add(pattern.name.name)
if pattern.pattern is not None:
_collect_pattern_names(pattern.pattern, names)
elif isinstance(pattern, OrPattern):
for p in pattern.patterns:
_collect_pattern_names(p, names)
elif isinstance(pattern, SequencePattern):
for p in pattern.patterns:
_collect_pattern_names(p, names)
elif isinstance(pattern, StarredPattern):
if pattern.capture is not None:
names.add(pattern.capture.name)
elif isinstance(pattern, MappingPattern):
for p in pattern.values:
_collect_pattern_names(p, names)
if pattern.rest is not None:
names.add(pattern.rest.name)
elif isinstance(pattern, ClassPattern):
for p in pattern.positionals:
_collect_pattern_names(p, names)
for p in pattern.keyword_values:
_collect_pattern_names(p, names)


def collect_class_body_assigned_names(stmts: list[Statement]) -> set[str]:
"""Pre-scan a class body to find all names that are assigned at class scope.

This mirrors CPython's compile-time analysis that determines whether a name
in a class body is accessed via LOAD_NAME (class dict -> globals -> builtins)
or LOAD_CLASSDEREF or LOAD_FROM_DICT_OR_DEREF (class dict -> enclosing function cell).

Names that are assigned anywhere in the class body (even inside if/for/while/try/with blocks)
use LOAD_NAME, so they should NOT be resolved from enclosing function locals.

The scan is shallow: it recurses into control-flow blocks (if, for, while,
try, with, match) which don't create new scopes, but does NOT recurse into
function or class definitions which create their own scopes.
"""
names: set[str] = set()
for s in stmts:
if isinstance(s, AssignmentStmt):
for lvalue in s.lvalues:
_collect_lvalue_names(lvalue, names)
elif isinstance(s, OperatorAssignmentStmt):
_collect_lvalue_names(s.lvalue, names)
elif isinstance(s, (FuncDef, OverloadedFuncDef, Decorator)):
names.add(s.name)
elif isinstance(s, ClassDef):
names.add(s.name)
elif isinstance(s, Import):
for module_id, as_id in s.ids:
names.add(as_id if as_id else module_id.split(".")[0])
elif isinstance(s, ImportFrom):
for name, as_name in s.names:
names.add(as_name if as_name else name)
elif isinstance(s, TypeAliasStmt):
names.add(s.name.name)
elif isinstance(s, DelStmt):
_collect_lvalue_names(s.expr, names)
elif isinstance(s, ForStmt):
_collect_lvalue_names(s.index, names)
names.update(collect_class_body_assigned_names(s.body.body))
if s.else_body:
names.update(collect_class_body_assigned_names(s.else_body.body))
elif isinstance(s, IfStmt):
for block in s.body:
names.update(collect_class_body_assigned_names(block.body))
if s.else_body:
names.update(collect_class_body_assigned_names(s.else_body.body))
elif isinstance(s, WhileStmt):
names.update(collect_class_body_assigned_names(s.body.body))
if s.else_body:
names.update(collect_class_body_assigned_names(s.else_body.body))
elif isinstance(s, TryStmt):
names.update(collect_class_body_assigned_names(s.body.body))
for var in s.vars:
if var is not None:
names.add(var.name)
for handler in s.handlers:
names.update(collect_class_body_assigned_names(handler.body))
if s.else_body:
names.update(collect_class_body_assigned_names(s.else_body.body))
if s.finally_body:
names.update(collect_class_body_assigned_names(s.finally_body.body))
elif isinstance(s, WithStmt):
for target in s.target:
if target is not None:
_collect_lvalue_names(target, names)
names.update(collect_class_body_assigned_names(s.body.body))
elif isinstance(s, MatchStmt):
for pattern in s.patterns:
_collect_pattern_names(pattern, names)
for body_block in s.bodies:
names.update(collect_class_body_assigned_names(body_block.body))
return names


def is_same_var_from_getattr(n1: SymbolNode | None, n2: SymbolNode | None) -> bool:
"""Do n1 and n2 refer to the same Var derived from module-level __getattr__?"""
return (
Expand Down
150 changes: 150 additions & 0 deletions test-data/unit/check-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -9412,3 +9412,153 @@ from typ import NT
def f() -> NT:
return NT(x='')
[builtins fixtures/tuple.pyi]

-- Class body scope: enclosing function local visibility
-- -----------------------------------------------------
-- These tests verify that names assigned in a class body nested inside a
-- function are NOT resolved from the enclosing function's local scope,
-- matching CPython's LOAD_NAME semantics for class bodies (class dict ->
-- globals -> builtins, skipping enclosing function locals).

[case testClassBodySkipsEnclosingFunctionLocalForAssignedName]
# A name assigned in the class body should not resolve from the enclosing
# function's locals; it should fall through to globals/builtins.
x = 1
y = 2
def func() -> None:
x = "xlocal"
y = "ylocal"
class C:
reveal_type(x) # N: Revealed type is "builtins.str"
reveal_type(y) # N: Revealed type is "builtins.int"
y = 1
func()

[case testClassBodyAssignedNameNoGlobalGivesError]
# If the name is assigned in the class body and there is no global definition,
# looking it up before the assignment should be an error.
def func() -> None:
y = "ylocal"
class C:
z = y # E: Name "y" is not defined
y = 1
func()

[case testClassBodyUnassignedNameSeesEnclosingFunctionLocal]
# A name that is NOT assigned anywhere in the class body can still be
# resolved from the enclosing function's local scope.
def func() -> None:
x = "hello"
class C:
reveal_type(x) # N: Revealed type is "builtins.str"
func()

[case testClassBodyMethodClosesOverEnclosingFunctionLocal]
# Methods inside the class should still close over the enclosing function
# locals, even for names that are assigned in the class body.
x = 1
y = 2
def func() -> None:
x = "xlocal"
y = "ylocal"
class C:
y = 1
def method(self) -> None:
reveal_type(x) # N: Revealed type is "builtins.str"
reveal_type(y) # N: Revealed type is "builtins.str"
func()

[case testClassBodyForwardRefClassAttrFallsToGlobal]
# When a class attribute is assigned textually *after* its use,
# is_active_symbol_in_class_body rejects it, and lookup should fall
# through to the global — not the enclosing function local.
x: int = 10
def func() -> None:
x = "hello"
class C:
y = x # should see global int, not func's str
x = 42
reveal_type(C.y) # N: Revealed type is "builtins.int"
func()

[case testClassBodyNestedClassInFunctionSkipsFunctionLocal]
# A doubly-nested class should also skip the enclosing function's locals.
y: int = 1
def func() -> None:
y = "ylocal"
class Outer:
class Inner:
reveal_type(y) # N: Revealed type is "builtins.int"
y = 42
func()

[case testClassBodyMethodInNestedClassClosesOverFunctionLocal]
# A method in a nested class should still form a closure over the
# enclosing function's locals.
def func() -> None:
y = "ylocal"
class Outer:
class Inner:
y = 42
def method(self) -> None:
reveal_type(y) # N: Revealed type is "builtins.str"
func()

[case testClassBodyGenericClassInFunctionStillWorks]
# TypeVar defined in the enclosing function should still be visible
# in the class body (TypeVarExpr is not a plain Var).
from typing import TypeVar, Generic
def func() -> None:
T = TypeVar('T')
class MyGeneric(Generic[T]):
def get(self, x: T) -> T:
return x
reveal_type(MyGeneric[int]().get(1)) # N: Revealed type is "builtins.int"
func()

[case testClassBodyAssignmentInControlFlow]
# Names assigned inside if/for/while/try/with blocks in the class body
# should still be treated as class-local.
x: int = 10
def func() -> None:
x = "xlocal"
class C:
reveal_type(x) # N: Revealed type is "builtins.int"
if True:
x = 42
func()
[builtins fixtures/bool.pyi]

[case testClassBodyForLoopVariable]
# A for-loop index variable at class scope should be local to the class.
x: int = 10
def func() -> None:
x = "xlocal"
class C:
reveal_type(x) # N: Revealed type is "builtins.int"
for x in [1, 2, 3]:
pass
func()
[builtins fixtures/list.pyi]

[case testClassBodyImportedName]
# An imported name in the class body should be treated as a class-local binding.
import typing
x: int = 10
def func() -> None:
x = "xlocal"
class C:
reveal_type(x) # N: Revealed type is "builtins.int"
import typing as x # type: ignore
# x is now assigned in the class body; use before assignment sees global int
func()

[case testClassBodyComprehensionSeesEnclosingVars]
# Comprehensions inside a class body can see enclosing function locals
def func() -> None:
items = [1, 2, 3]
class C:
result = [i for i in items]
reveal_type(result) # N: Revealed type is "builtins.list[builtins.int]"
func()
[builtins fixtures/list.pyi]
Loading