diff --git a/mypy/checker.py b/mypy/checker.py index 6a0e8f3718d3..d89b45238d72 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -213,6 +213,7 @@ FunctionLike, Instance, LiteralType, + LiteralValue, NoneType, Overloaded, PartialType, @@ -911,23 +912,29 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: impl_type = self.extract_callable_type(inner_type, defn.impl) is_descriptor_get = defn.info and defn.name == "__get__" + + # Pre-extract callable types and literal fingerprints for each overload + # item, skipping items whose signature could not be extracted. + # Each entry is (original 0-based index, Decorator, sig, fingerprint). + prepared_items: list[tuple[int, Decorator, CallableType, LiteralFingerprint]] = [] for i, item in enumerate(defn.items): assert isinstance(item, Decorator) - sig1 = self.extract_callable_type(item.var.type, item) - if sig1 is None: - continue + sig = self.extract_callable_type(item.var.type, item) + if sig is not None: + prepared_items.append((i, item, sig, build_literal_fingerprint(sig))) - for j, item2 in enumerate(defn.items[i + 1 :]): - assert isinstance(item2, Decorator) - sig2 = self.extract_callable_type(item2.var.type, item2) - if sig2 is None: + for prepared_items_i, (i, item, sig1, literals_fingerprint1) in enumerate(prepared_items): + for j, item2, sig2, literals_fingerprint2 in prepared_items[prepared_items_i + 1 :]: + if not are_argument_counts_overlapping(sig1, sig2): continue - if not are_argument_counts_overlapping(sig1, sig2): + # If there is any argument position where both overloads + # carry a LiteralType with different values they are disjoint. + if literal_args_are_disjoint(literals_fingerprint1, literals_fingerprint2): continue if overload_can_never_match(sig1, sig2): - self.msg.overloaded_signature_will_never_match(i + 1, i + j + 2, item2.func) + self.msg.overloaded_signature_will_never_match(i + 1, j + 1, item2.func) elif not is_descriptor_get: # Note: we force mypy to check overload signatures in strict-optional mode # so we don't incorrectly report errors when a user tries typing an overload @@ -947,14 +954,14 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: with state.strict_optional_set(True): if is_unsafe_overlapping_overload_signatures(sig1, sig2, type_vars): flip_note = ( - j == 0 + j == i + 1 and not is_unsafe_overlapping_overload_signatures( sig2, sig1, type_vars ) and not overload_can_never_match(sig2, sig1) ) self.msg.overloaded_signatures_overlap( - i + 1, i + j + 2, flip_note, item.func + i + 1, j + 1, flip_note, item.func ) if impl_type is not None: @@ -8958,6 +8965,69 @@ def detach_callable(typ: CallableType, class_type_vars: list[TypeVarLikeType]) - return typ.copy_modified(variables=list(typ.variables) + class_type_vars) +# Fingerprint type for literal-disjointedness checks: maps argument index to +# the set of (Python type of value, value) pairs present at that position. +# Using type(value) as part of the key means Literal[1] (int) and +# Literal[True] (bool) are kept distinct even though 1 == True in Python. +# A union such as Literal["a", "b"] or Literal["a"] | Literal["b"] produces +# a frozenset of two entries; a plain Literal["a"] produces a length 1 set. +LiteralFingerprint = dict[int, frozenset[tuple[type, LiteralValue]]] + + +def build_literal_fingerprint(sig: CallableType) -> LiteralFingerprint: + """Build a LiteralFingerprint for one overload signature. + + Each *required* argument position (ARG_POS or ARG_NAMED) that carries only + LiteralType values (including unions such as ``Literal["a", "b"]``) is + recorded as a frozenset of ``(type(value), value)`` pairs. Positions with + any non-literal type, or with an optional argument kind (ARG_OPT, + ARG_NAMED_OPT, ARG_STAR, ARG_STAR2), are omitted. + + Optional arguments are excluded because a caller can always omit them, + meaning two overloads that differ only in an optional Literal argument still + overlap (a call that omits the argument matches both). Only required + arguments can prove that no single call can match both overloads. + """ + fingerprint: LiteralFingerprint = {} + for idx, (arg_kind, arg_type) in enumerate(zip(sig.arg_kinds, sig.arg_types)): + if not arg_kind.is_required(): + continue + proper = get_proper_type(arg_type) + if isinstance(proper, LiteralType): + fingerprint[idx] = frozenset([(type(proper.value), proper.value)]) + elif isinstance(proper, UnionType): + # Literal["a", "b"] and Literal["a"] | Literal["b"] are both + # represented as a UnionType of LiteralTypes. Collect all the + # literal values; if any member is not a LiteralType the whole + # position is skipped (a non-literal in the union makes it too + # broad to prove disjointedness). + vals: set[tuple[type, LiteralValue]] = set() + for member in proper.items: + m = get_proper_type(member) + if isinstance(m, LiteralType): + vals.add((type(m.value), m.value)) + else: + vals = set() + break + if vals: + fingerprint[idx] = frozenset(vals) + return fingerprint + + +def literal_args_are_disjoint(fp1: LiteralFingerprint, fp2: LiteralFingerprint) -> bool: + """Return True if two overloads are provably disjoint via a Literal argument. + + If there is any argument position where both carry only LiteralType values + and those value sets are disjoint, no single call can match both overloads + and the pairwise overlap check can be skipped entirely. + """ + for idx, vals1 in fp1.items(): + vals2 = fp2.get(idx) + if vals2 is not None and vals1.isdisjoint(vals2): + return True + return False + + def overload_can_never_match(signature: CallableType, other: CallableType) -> bool: """Check if the 'other' method can never be matched due to 'signature'. diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 8616a1b6d165..00957bd21748 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -6865,3 +6865,123 @@ if isinstance(headers, dict): reveal_type(headers) # N: Revealed type is "__main__.Headers | typing.Iterable[tuple[builtins.bytes, builtins.bytes]]" [builtins fixtures/isinstancelist.pyi] + +-- Tests for literal-disjointness fast path in check_overlapping_overloads + +[case testOverloadLiteralDistinctStringsNoError] +# Overloads with distinct Literal[str] arguments are provably disjoint; no +# overlap or never-match errors should be reported. +from typing import overload, Literal +@overload +def f(x: Literal["a"]) -> int: ... +@overload +def f(x: Literal["b"]) -> str: ... +@overload +def f(x: Literal["c"]) -> float: ... +def f(x: str) -> object: + return x +[builtins fixtures/tuple.pyi] + +[case testOverloadLiteralDuplicateStillErrors] +# Two overloads sharing the same Literal value should still trigger an error. +# Signature 1 covers all inputs of type Literal["a"], so signature 2 is unreachable. +from typing import overload, Literal +@overload +def f(x: Literal["a"]) -> int: ... +@overload +def f(x: Literal["a"]) -> str: ... # E: Overloaded function signature 2 will never be matched: signature 1's parameter type(s) are the same or broader +def f(x: str) -> object: + return x +[builtins fixtures/tuple.pyi] + +[case testOverloadLiteralWithBroadCatchAll] +# Distinct Literal overloads followed by a broad catch-all should produce no +# overlap errors. The broad type must come last (correct ordering). +from typing import overload, Literal, Any +@overload +def f(x: Literal["a"]) -> int: ... +@overload +def f(x: Literal["b"]) -> str: ... +@overload +def f(x: str) -> Any: ... +def f(x: str) -> object: + return x +[builtins fixtures/tuple.pyi] + +[case testOverloadLiteralBroadBeforeLiteralErrors] +# A broad type before a specific Literal means the Literal can never match. +from typing import overload, Literal +@overload +def f(x: str) -> int: ... +@overload +def f(x: Literal["a"]) -> str: ... # E: Overloaded function signature 2 will never be matched: signature 1's parameter type(s) are the same or broader +def f(x: str) -> object: + return x +[builtins fixtures/tuple.pyi] + +[case testOverloadLiteralImplErrorsNotSuppressed] +# The literal fast path must not suppress implementation-body consistency errors. +# Use bytes as the impl return type — incompatible with both int and str. +from typing import overload, Literal +@overload +def f(x: Literal["a"]) -> int: ... +@overload +def f(x: Literal["b"]) -> str: ... +def f(x: str) -> bytes: # E: Overloaded function implementation cannot produce return type of signature 1 # E: Overloaded function implementation cannot produce return type of signature 2 + return b"" +[builtins fixtures/tuple.pyi] + +[case testOverloadLiteralUnionDistinctNoError] +# Literal unions with disjoint value sets are provably disjoint; no errors. +from typing import overload, Literal, Union +@overload +def f(x: Literal["a", "b"]) -> int: ... +@overload +def f(x: Literal["c", "d"]) -> str: ... +def f(x: str) -> object: + return x +[builtins fixtures/tuple.pyi] + +[case testOverloadLiteralUnionOverlapErrors] +# Literal unions that share a value are NOT disjoint and should be flagged. +from typing import overload, Literal +@overload +def f(x: Literal["a", "b"]) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +@overload +def f(x: Literal["b", "c"]) -> str: ... +def f(x: str) -> object: + return x +[builtins fixtures/tuple.pyi] + +[case testOverloadLiteralUnionMixedNoFastPath] +# A union with a non-Literal member is not fingerprinted, so the full check runs. +from typing import overload, Literal, Union +@overload +def f(x: Literal["a"]) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +@overload +def f(x: Union[Literal["b"], str]) -> str: ... +def f(x: str) -> object: + return x +[builtins fixtures/tuple.pyi] + +[case testOverloadLiteralOptionalArgNotDisjoint] +# Overloads that differ in an optional positional Literal arg still overlap: +# a call like f(1) omits mode and matches both signatures. +from typing import Literal, overload +@overload +def f(x: int, mode: Literal[True] = ...) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types # N: Flipping the order of overloads will fix this error +@overload +def f(x: object, mode: Literal[False] = ...) -> str: ... +def f(x: object, mode: bool = True) -> object: ... +[builtins fixtures/tuple.pyi] + +[case testOverloadLiteralOptionalKwOnlyNotDisjoint] +# Overloads that differ in an optional keyword-only Literal arg still overlap: +# a call like f(1) omits mode and matches both signatures. +from typing import Literal, overload +@overload +def f(x: int, *, mode: Literal[True] = ...) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types # N: Flipping the order of overloads will fix this error +@overload +def f(x: object, *, mode: Literal[False] = ...) -> str: ... +def f(x: object, *, mode: bool = True) -> object: ... +[builtins fixtures/tuple.pyi]