From c048ad979c398bcae178f3b6de7e1affaa81838c Mon Sep 17 00:00:00 2001 From: oir Date: Wed, 25 Mar 2026 22:07:29 -0400 Subject: [PATCH] fix aliases in nested type hints --- startle/_typing.py | 10 ++-- tests/test_literal.py | 26 +++++++++- tests/test_type_alias/test_type_alias.py | 63 +++++++++++++++++++++++- 3 files changed, 91 insertions(+), 8 deletions(-) diff --git a/startle/_typing.py b/startle/_typing.py index 606c2d9..8f73b4e 100644 --- a/startle/_typing.py +++ b/startle/_typing.py @@ -210,9 +210,9 @@ def strip_container(hint: "TypeHint | type") -> tuple[type | None, Any]: args_ = get_args(hint) if orig in [list, set, frozenset]: - return orig, strip_annotated(args_[0]) if args_ else str + return orig, normalize(args_[0]) if args_ else str if orig is tuple and len(args_) == 2 and args_[1] is ...: - return orig, strip_annotated(args_[0]) if args_ else str + return orig, normalize(args_[0]) if args_ else str if orig is tuple and not args_: return orig, str if hint in [list, tuple, set, frozenset]: @@ -221,17 +221,17 @@ def strip_container(hint: "TypeHint | type") -> tuple[type | None, Any]: # handle abstract collections if orig in [MutableSequence]: - return list, strip_annotated(args_[0]) if args_ else str + return list, normalize(args_[0]) if args_ else str if hint in [MutableSequence]: return list, str if orig in [Sequence, Iterable]: - return tuple, strip_annotated(args_[0]) if args_ else str + return tuple, normalize(args_[0]) if args_ else str if hint in [Sequence, Iterable]: return tuple, str if orig in [MutableSet]: - return set, strip_annotated(args_[0]) if args_ else str + return set, normalize(args_[0]) if args_ else str if hint in [MutableSet]: return set, str diff --git a/tests/test_literal.py b/tests/test_literal.py index 98daeac..73e4a2b 100644 --- a/tests/test_literal.py +++ b/tests/test_literal.py @@ -8,7 +8,7 @@ from ._utils import Opt, Opts, check_args -def check(draw: Callable, opt: Opt): +def check(draw: Callable[..., None], opt: Opt): check_args(draw, opt("shape", ["square"]), ["square"], {}) check_args(draw, opt("shape", ["circle"]), ["circle"], {}) check_args(draw, opt("shape", ["triangle"]), ["triangle"], {}) @@ -22,7 +22,7 @@ def check(draw: Callable, opt: Opt): check_args(draw, opt("shape", ["rectangle"]), [], {}) -def check_with_default(draw: Callable, opt: Opt): +def check_with_default(draw: Callable[..., None], opt: Opt): check_args(draw, [], ["circle"], {}) check(draw, opt) @@ -38,3 +38,25 @@ def draw_with_default(shape: Literal["square", "circle", "triangle"] = "circle") print(f"Drawing a {shape}.") check_with_default(draw_with_default, opt) + + +@mark.parametrize("opt", Opts()) +def test_many_literals(opt: Opt): + def draw(shapes: list[Literal["square", "circle", "triangle"]]): + print(f"Drawing {len(shapes)} shapes.") + + check_args(draw, opt("shapes", ["square", "circle"]), [["square", "circle"]], {}) + check_args( + draw, + opt("shapes", ["square", "square", "triangle"]), + [["square", "square", "triangle"]], + {}, + ) + + with raises( + ParserValueError, + match=re.escape( + "Cannot parse literal ('square', 'circle', 'triangle') from `rectangle`!" + ), + ): + check_args(draw, opt("shapes", ["square", "rectangle"]), [], {}) diff --git a/tests/test_type_alias/test_type_alias.py b/tests/test_type_alias/test_type_alias.py index 816901d..78224fc 100644 --- a/tests/test_type_alias/test_type_alias.py +++ b/tests/test_type_alias/test_type_alias.py @@ -7,7 +7,7 @@ import re from collections.abc import Callable from dataclasses import dataclass -from typing import Annotated +from typing import Annotated, Literal from pytest import mark, raises from startle import parse, register @@ -16,6 +16,9 @@ from startle.error import ParserConfigError, ParserOptionError, ParserValueError from .._utils import check_args +from ..test_literal import Opt, Opts +from ..test_literal import check as check_literal +from ..test_literal import check_with_default as check_literal_with_default from ..test_parse_class import check_parse_exits type MyFloat = float @@ -331,3 +334,61 @@ def test_unsupported_type(mul_f, register_t): del PARSERS[Rational] del METAVARS[Rational] + + +def sum1(vals: list[MyFloat]) -> float: + return sum(vals) + + +def sum2(vals: list[MyFloat2]) -> float: + return sum(vals) + + +def sum3(vals: list[MyFloat3]) -> float: + return sum(vals) + + +@mark.parametrize("sum_f", [sum1, sum2, sum3]) +def test_list_of_type_alias(sum_f: Callable[[list[float]], float]): + check_args(sum_f, ["1.0", "2.0", "3.0"], [[1.0, 2.0, 3.0]], {}) + + with raises(ParserValueError, match="Cannot parse float from `x`!"): + check_args(sum_f, ["1.0", "x", "3.0"], [], {}) + + +type Shape = Literal["square", "circle", "triangle"] + + +@mark.parametrize("opt", Opts()) +def test_literal(opt: Opt): + def draw(shape: Shape): + print(f"Drawing a {shape}.") + + check_literal(draw, opt) + + def draw_with_default(shape: Shape = "circle"): + print(f"Drawing a {shape}.") + + check_literal_with_default(draw_with_default, opt) + + +@mark.parametrize("opt", Opts()) +def test_many_literals(opt: Opt): + def draw(shapes: tuple[Shape, ...]): + print(f"Drawing {len(shapes)} shapes.") + + check_args(draw, opt("shapes", ["square", "circle"]), [("square", "circle")], {}) + check_args( + draw, + opt("shapes", ["square", "square", "triangle"]), + [("square", "square", "triangle")], + {}, + ) + + with raises( + ParserValueError, + match=re.escape( + "Cannot parse literal ('square', 'circle', 'triangle') from `rectangle`!" + ), + ): + check_args(draw, opt("shapes", ["square", "rectangle"]), [], {})