diff --git a/mypy/semanal.py b/mypy/semanal.py index aa74122be255..7995f090b962 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -201,6 +201,7 @@ ClassPattern, MappingPattern, OrPattern, + Pattern, SequencePattern, SingletonPattern, StarredPattern, @@ -546,6 +547,12 @@ def __init__( # import foo.bar self.transitive_submodule_imports: dict[str, set[str]] = {} + # Stack of sets of names assigned in each enclosing class body. + # Used to determine whether a name in a class body should be looked up + # in enclosing function-local scopes or skipped (matching CPython's + # LOAD_NAME vs LOAD_CLASSDEREF or LOAD_FROM_DICT_OR_DEREF behavior. + self.class_body_assigned_names: list[set[str]] = [] + # mypyc doesn't properly handle implementing an abstractproperty # with a regular attribute so we make them properties @property @@ -2100,6 +2107,10 @@ def is_core_builtin_class(self, defn: ClassDef) -> bool: def analyze_class_body_common(self, defn: ClassDef) -> None: """Parts of class body analysis that are common to all kinds of class defs.""" self.enter_class(defn.info) + # Pre-scan class body to find names assigned at class scope level. + # This must happen after enter_class (which pushes an empty set) so we + # can replace it with the real set. + self.class_body_assigned_names[-1] = collect_class_body_assigned_names(defn.defs.body) if any(b.self_type is not None for b in defn.info.mro): self.setup_self_type() defn.defs.accept(self) @@ -2225,6 +2236,7 @@ def enter_class(self, info: TypeInfo) -> None: self.loop_depth.append(0) self._type = info self.missing_names.append(set()) + self.class_body_assigned_names.append(set()) def leave_class(self) -> None: """Restore analyzer state.""" @@ -2234,6 +2246,7 @@ def leave_class(self) -> None: self.scope_stack.pop() self._type = self.type_stack.pop() self.missing_names.pop() + self.class_body_assigned_names.pop() def analyze_class_decorator(self, defn: ClassDef, decorator: Expression) -> None: decorator.accept(self) @@ -6600,9 +6613,19 @@ def _lookup( v._fullname = self.qualified_name(name) return SymbolTableNode(MDEF, v) # 3. Local (function) scopes - for table in reversed(self.locals): - if table is not None and name in table: - return table[name] + # If we're in a class body and this name is assigned in the class body, + # skip enclosing function locals. This matches CPython's use of LOAD_NAME + # (class dict -> globals -> builtins) instead of + # LOAD_CLASSDEREF or LOAD_FROM_DICT_OR_DEREF for such names. + skip_func_locals = ( + self.is_class_scope() + and self.class_body_assigned_names + and name in self.class_body_assigned_names[-1] + ) + if not skip_func_locals: + for table in reversed(self.locals): + if table is not None and name in table: + return table[name] # 4. Current file global scope if name in self.globals: @@ -8326,6 +8349,118 @@ def names_modified_in_lvalue(lvalue: Lvalue) -> list[NameExpr]: return [] +def _collect_lvalue_names(lvalue: Lvalue, names: set[str]) -> None: + """Collect simple names from an lvalue into a set.""" + if isinstance(lvalue, NameExpr): + names.add(lvalue.name) + elif isinstance(lvalue, StarExpr): + _collect_lvalue_names(lvalue.expr, names) + elif isinstance(lvalue, (ListExpr, TupleExpr)): + for item in lvalue.items: + _collect_lvalue_names(item, names) + + +def _collect_pattern_names(pattern: Pattern, names: set[str]) -> None: + """Collect names bound by a match pattern.""" + if isinstance(pattern, AsPattern): + if pattern.name is not None: + names.add(pattern.name.name) + if pattern.pattern is not None: + _collect_pattern_names(pattern.pattern, names) + elif isinstance(pattern, OrPattern): + for p in pattern.patterns: + _collect_pattern_names(p, names) + elif isinstance(pattern, SequencePattern): + for p in pattern.patterns: + _collect_pattern_names(p, names) + elif isinstance(pattern, StarredPattern): + if pattern.capture is not None: + names.add(pattern.capture.name) + elif isinstance(pattern, MappingPattern): + for p in pattern.values: + _collect_pattern_names(p, names) + if pattern.rest is not None: + names.add(pattern.rest.name) + elif isinstance(pattern, ClassPattern): + for p in pattern.positionals: + _collect_pattern_names(p, names) + for p in pattern.keyword_values: + _collect_pattern_names(p, names) + + +def collect_class_body_assigned_names(stmts: list[Statement]) -> set[str]: + """Pre-scan a class body to find all names that are assigned at class scope. + + This mirrors CPython's compile-time analysis that determines whether a name + in a class body is accessed via LOAD_NAME (class dict -> globals -> builtins) + or LOAD_CLASSDEREF or LOAD_FROM_DICT_OR_DEREF (class dict -> enclosing function cell). + + Names that are assigned anywhere in the class body (even inside if/for/while/try/with blocks) + use LOAD_NAME, so they should NOT be resolved from enclosing function locals. + + The scan is shallow: it recurses into control-flow blocks (if, for, while, + try, with, match) which don't create new scopes, but does NOT recurse into + function or class definitions which create their own scopes. + """ + names: set[str] = set() + for s in stmts: + if isinstance(s, AssignmentStmt): + for lvalue in s.lvalues: + _collect_lvalue_names(lvalue, names) + elif isinstance(s, OperatorAssignmentStmt): + _collect_lvalue_names(s.lvalue, names) + elif isinstance(s, (FuncDef, OverloadedFuncDef, Decorator)): + names.add(s.name) + elif isinstance(s, ClassDef): + names.add(s.name) + elif isinstance(s, Import): + for module_id, as_id in s.ids: + names.add(as_id if as_id else module_id.split(".")[0]) + elif isinstance(s, ImportFrom): + for name, as_name in s.names: + names.add(as_name if as_name else name) + elif isinstance(s, TypeAliasStmt): + names.add(s.name.name) + elif isinstance(s, DelStmt): + _collect_lvalue_names(s.expr, names) + elif isinstance(s, ForStmt): + _collect_lvalue_names(s.index, names) + names.update(collect_class_body_assigned_names(s.body.body)) + if s.else_body: + names.update(collect_class_body_assigned_names(s.else_body.body)) + elif isinstance(s, IfStmt): + for block in s.body: + names.update(collect_class_body_assigned_names(block.body)) + if s.else_body: + names.update(collect_class_body_assigned_names(s.else_body.body)) + elif isinstance(s, WhileStmt): + names.update(collect_class_body_assigned_names(s.body.body)) + if s.else_body: + names.update(collect_class_body_assigned_names(s.else_body.body)) + elif isinstance(s, TryStmt): + names.update(collect_class_body_assigned_names(s.body.body)) + for var in s.vars: + if var is not None: + names.add(var.name) + for handler in s.handlers: + names.update(collect_class_body_assigned_names(handler.body)) + if s.else_body: + names.update(collect_class_body_assigned_names(s.else_body.body)) + if s.finally_body: + names.update(collect_class_body_assigned_names(s.finally_body.body)) + elif isinstance(s, WithStmt): + for target in s.target: + if target is not None: + _collect_lvalue_names(target, names) + names.update(collect_class_body_assigned_names(s.body.body)) + elif isinstance(s, MatchStmt): + for pattern in s.patterns: + _collect_pattern_names(pattern, names) + for body_block in s.bodies: + names.update(collect_class_body_assigned_names(body_block.body)) + return names + + def is_same_var_from_getattr(n1: SymbolNode | None, n2: SymbolNode | None) -> bool: """Do n1 and n2 refer to the same Var derived from module-level __getattr__?""" return ( diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index 5a66eff2bd3b..019578fcc4cf 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -9412,3 +9412,153 @@ from typ import NT def f() -> NT: return NT(x='') [builtins fixtures/tuple.pyi] + +-- Class body scope: enclosing function local visibility +-- ----------------------------------------------------- +-- These tests verify that names assigned in a class body nested inside a +-- function are NOT resolved from the enclosing function's local scope, +-- matching CPython's LOAD_NAME semantics for class bodies (class dict -> +-- globals -> builtins, skipping enclosing function locals). + +[case testClassBodySkipsEnclosingFunctionLocalForAssignedName] +# A name assigned in the class body should not resolve from the enclosing +# function's locals; it should fall through to globals/builtins. +x = 1 +y = 2 +def func() -> None: + x = "xlocal" + y = "ylocal" + class C: + reveal_type(x) # N: Revealed type is "builtins.str" + reveal_type(y) # N: Revealed type is "builtins.int" + y = 1 +func() + +[case testClassBodyAssignedNameNoGlobalGivesError] +# If the name is assigned in the class body and there is no global definition, +# looking it up before the assignment should be an error. +def func() -> None: + y = "ylocal" + class C: + z = y # E: Name "y" is not defined + y = 1 +func() + +[case testClassBodyUnassignedNameSeesEnclosingFunctionLocal] +# A name that is NOT assigned anywhere in the class body can still be +# resolved from the enclosing function's local scope. +def func() -> None: + x = "hello" + class C: + reveal_type(x) # N: Revealed type is "builtins.str" +func() + +[case testClassBodyMethodClosesOverEnclosingFunctionLocal] +# Methods inside the class should still close over the enclosing function +# locals, even for names that are assigned in the class body. +x = 1 +y = 2 +def func() -> None: + x = "xlocal" + y = "ylocal" + class C: + y = 1 + def method(self) -> None: + reveal_type(x) # N: Revealed type is "builtins.str" + reveal_type(y) # N: Revealed type is "builtins.str" +func() + +[case testClassBodyForwardRefClassAttrFallsToGlobal] +# When a class attribute is assigned textually *after* its use, +# is_active_symbol_in_class_body rejects it, and lookup should fall +# through to the global — not the enclosing function local. +x: int = 10 +def func() -> None: + x = "hello" + class C: + y = x # should see global int, not func's str + x = 42 + reveal_type(C.y) # N: Revealed type is "builtins.int" +func() + +[case testClassBodyNestedClassInFunctionSkipsFunctionLocal] +# A doubly-nested class should also skip the enclosing function's locals. +y: int = 1 +def func() -> None: + y = "ylocal" + class Outer: + class Inner: + reveal_type(y) # N: Revealed type is "builtins.int" + y = 42 +func() + +[case testClassBodyMethodInNestedClassClosesOverFunctionLocal] +# A method in a nested class should still form a closure over the +# enclosing function's locals. +def func() -> None: + y = "ylocal" + class Outer: + class Inner: + y = 42 + def method(self) -> None: + reveal_type(y) # N: Revealed type is "builtins.str" +func() + +[case testClassBodyGenericClassInFunctionStillWorks] +# TypeVar defined in the enclosing function should still be visible +# in the class body (TypeVarExpr is not a plain Var). +from typing import TypeVar, Generic +def func() -> None: + T = TypeVar('T') + class MyGeneric(Generic[T]): + def get(self, x: T) -> T: + return x + reveal_type(MyGeneric[int]().get(1)) # N: Revealed type is "builtins.int" +func() + +[case testClassBodyAssignmentInControlFlow] +# Names assigned inside if/for/while/try/with blocks in the class body +# should still be treated as class-local. +x: int = 10 +def func() -> None: + x = "xlocal" + class C: + reveal_type(x) # N: Revealed type is "builtins.int" + if True: + x = 42 +func() +[builtins fixtures/bool.pyi] + +[case testClassBodyForLoopVariable] +# A for-loop index variable at class scope should be local to the class. +x: int = 10 +def func() -> None: + x = "xlocal" + class C: + reveal_type(x) # N: Revealed type is "builtins.int" + for x in [1, 2, 3]: + pass +func() +[builtins fixtures/list.pyi] + +[case testClassBodyImportedName] +# An imported name in the class body should be treated as a class-local binding. +import typing +x: int = 10 +def func() -> None: + x = "xlocal" + class C: + reveal_type(x) # N: Revealed type is "builtins.int" + import typing as x # type: ignore + # x is now assigned in the class body; use before assignment sees global int +func() + +[case testClassBodyComprehensionSeesEnclosingVars] +# Comprehensions inside a class body can see enclosing function locals +def func() -> None: + items = [1, 2, 3] + class C: + result = [i for i in items] + reveal_type(result) # N: Revealed type is "builtins.list[builtins.int]" +func() +[builtins fixtures/list.pyi]