Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions startle/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,9 @@ def strip_container(hint: "TypeHint | type") -> tuple[type | None, Any]:
args_ = get_args(hint)

if orig in [list, set, frozenset]:
return orig, strip_annotated(args_[0]) if args_ else str
return orig, normalize(args_[0]) if args_ else str
if orig is tuple and len(args_) == 2 and args_[1] is ...:
return orig, strip_annotated(args_[0]) if args_ else str
return orig, normalize(args_[0]) if args_ else str
if orig is tuple and not args_:
return orig, str
if hint in [list, tuple, set, frozenset]:
Expand All @@ -221,17 +221,17 @@ def strip_container(hint: "TypeHint | type") -> tuple[type | None, Any]:

# handle abstract collections
if orig in [MutableSequence]:
return list, strip_annotated(args_[0]) if args_ else str
return list, normalize(args_[0]) if args_ else str
if hint in [MutableSequence]:
return list, str

if orig in [Sequence, Iterable]:
return tuple, strip_annotated(args_[0]) if args_ else str
return tuple, normalize(args_[0]) if args_ else str
if hint in [Sequence, Iterable]:
return tuple, str

if orig in [MutableSet]:
return set, strip_annotated(args_[0]) if args_ else str
return set, normalize(args_[0]) if args_ else str
if hint in [MutableSet]:
return set, str

Expand Down
26 changes: 24 additions & 2 deletions tests/test_literal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ._utils import Opt, Opts, check_args


def check(draw: Callable, opt: Opt):
def check(draw: Callable[..., None], opt: Opt):
check_args(draw, opt("shape", ["square"]), ["square"], {})
check_args(draw, opt("shape", ["circle"]), ["circle"], {})
check_args(draw, opt("shape", ["triangle"]), ["triangle"], {})
Expand All @@ -22,7 +22,7 @@ def check(draw: Callable, opt: Opt):
check_args(draw, opt("shape", ["rectangle"]), [], {})


def check_with_default(draw: Callable, opt: Opt):
def check_with_default(draw: Callable[..., None], opt: Opt):
check_args(draw, [], ["circle"], {})
check(draw, opt)

Expand All @@ -38,3 +38,25 @@ def draw_with_default(shape: Literal["square", "circle", "triangle"] = "circle")
print(f"Drawing a {shape}.")

check_with_default(draw_with_default, opt)


@mark.parametrize("opt", Opts())
def test_many_literals(opt: Opt):
def draw(shapes: list[Literal["square", "circle", "triangle"]]):
print(f"Drawing {len(shapes)} shapes.")

check_args(draw, opt("shapes", ["square", "circle"]), [["square", "circle"]], {})
check_args(
draw,
opt("shapes", ["square", "square", "triangle"]),
[["square", "square", "triangle"]],
{},
)

with raises(
ParserValueError,
match=re.escape(
"Cannot parse literal ('square', 'circle', 'triangle') from `rectangle`!"
),
):
check_args(draw, opt("shapes", ["square", "rectangle"]), [], {})
63 changes: 62 additions & 1 deletion tests/test_type_alias/test_type_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import re
from collections.abc import Callable
from dataclasses import dataclass
from typing import Annotated
from typing import Annotated, Literal

from pytest import mark, raises
from startle import parse, register
Expand All @@ -16,6 +16,9 @@
from startle.error import ParserConfigError, ParserOptionError, ParserValueError

from .._utils import check_args
from ..test_literal import Opt, Opts
from ..test_literal import check as check_literal
from ..test_literal import check_with_default as check_literal_with_default
from ..test_parse_class import check_parse_exits

type MyFloat = float
Expand Down Expand Up @@ -331,3 +334,61 @@ def test_unsupported_type(mul_f, register_t):

del PARSERS[Rational]
del METAVARS[Rational]


def sum1(vals: list[MyFloat]) -> float:
return sum(vals)


def sum2(vals: list[MyFloat2]) -> float:
return sum(vals)


def sum3(vals: list[MyFloat3]) -> float:
return sum(vals)


@mark.parametrize("sum_f", [sum1, sum2, sum3])
def test_list_of_type_alias(sum_f: Callable[[list[float]], float]):
check_args(sum_f, ["1.0", "2.0", "3.0"], [[1.0, 2.0, 3.0]], {})

with raises(ParserValueError, match="Cannot parse float from `x`!"):
check_args(sum_f, ["1.0", "x", "3.0"], [], {})


type Shape = Literal["square", "circle", "triangle"]


@mark.parametrize("opt", Opts())
def test_literal(opt: Opt):
def draw(shape: Shape):
print(f"Drawing a {shape}.")

check_literal(draw, opt)

def draw_with_default(shape: Shape = "circle"):
print(f"Drawing a {shape}.")

check_literal_with_default(draw_with_default, opt)


@mark.parametrize("opt", Opts())
def test_many_literals(opt: Opt):
def draw(shapes: tuple[Shape, ...]):
print(f"Drawing {len(shapes)} shapes.")

check_args(draw, opt("shapes", ["square", "circle"]), [("square", "circle")], {})
check_args(
draw,
opt("shapes", ["square", "square", "triangle"]),
[("square", "square", "triangle")],
{},
)

with raises(
ParserValueError,
match=re.escape(
"Cannot parse literal ('square', 'circle', 'triangle') from `rectangle`!"
),
):
check_args(draw, opt("shapes", ["square", "rectangle"]), [], {})
Loading