diff --git a/pyproject.toml b/pyproject.toml index 8e9ff06..4f8ce43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,7 @@ preview = true [tool.ruff.lint] select = ["B", "E", "F", "I", "RUF", "UP", "DOC102", "DOC202", "DOC403", "DOC502"] -ignore = ["E501", "RUF001", "RUF002"] +ignore = ["E501", "RUF001", "RUF002", "RUF029"] [tool.ruff.lint.per-file-ignores] "startle/_typing.py" = ["UP007", "UP045"] diff --git a/startle/_start.py b/startle/_start.py index d479231..e33d579 100644 --- a/startle/_start.py +++ b/startle/_start.py @@ -1,5 +1,6 @@ import sys from collections.abc import Callable +from inspect import iscoroutinefunction from typing import Any, Literal, TypeVar, cast from ._console import console, error, post_error @@ -100,7 +101,12 @@ def _start_func( f_args, f_kwargs = args_.make_func_args() # finally, call the function with the arguments - return func(*f_args, **f_kwargs) + if iscoroutinefunction(func): + import asyncio + + return asyncio.run(func(*f_args, **f_kwargs)) + else: + return func(*f_args, **f_kwargs) except (ParserOptionError, ParserValueError) as e: if catch: error(str(e), exit=False, endl=False) @@ -164,7 +170,13 @@ def cmd_prog_name(cmd_name: str) -> str: # finally, call the function with the arguments func = cmd2func[cmd] - return func(*f_args, **f_kwargs) + + if iscoroutinefunction(func): + import asyncio + + return asyncio.run(func(*f_args, **f_kwargs)) + else: + return func(*f_args, **f_kwargs) except (ParserOptionError, ParserValueError) as e: if catch: error(str(e), exit=False, endl=False) diff --git a/tests/test_start/test_start_cmds.py b/tests/test_start/test_start_cmds.py index 82ef57e..930285c 100644 --- a/tests/test_start/test_start_cmds.py +++ b/tests/test_start/test_start_cmds.py @@ -57,6 +57,50 @@ def div(a: int, b: int) -> None: print(f"{a} / {b} = {a / b}") +async def aadd(a: int, b: int) -> None: + """ + Add two numbers. + + Args: + a: The first number. + b: The second number. + """ + print(f"{a} + {b} = {a + b}") + + +async def asub(a: int, b: int) -> None: + """ + Subtract two numbers. + + Args: + a: The first number. + b: The second number + """ + print(f"{a} - {b} = {a - b}") + + +async def amul(a: int, b: int) -> None: + """ + Multiply two numbers. + + Args: + a: The first number. + b: The second number. + """ + print(f"{a} * {b} = {a * b}") + + +async def adiv(a: int, b: int) -> None: + """ + Divide two numbers. + + Args: + a: The dividend. + b: The divisor. + """ + print(f"{a} / {b} = {a / b}") + + @mark.parametrize("run", [run_w_explicit_args, run_w_sys_argv]) @mark.parametrize("default", [False, True]) def test_calc( @@ -68,6 +112,14 @@ def test_calc( check(capsys, run_, [add, sub, mul, div], ["sub", "2", "3"], "2 - 3 = -1\n") check(capsys, run_, [add, sub, mul, div], ["mul", "2", "3"], "2 * 3 = 6\n") check(capsys, run_, [add, sub, mul, div], ["div", "6", "3"], "6 / 3 = 2.0\n") + check(capsys, run, [aadd, asub, amul, adiv], ["aadd", "2", "3"], "2 + 3 = 5\n") + check(capsys, run, [aadd, asub, amul, adiv], ["asub", "2", "3"], "2 - 3 = -1\n") + check(capsys, run, [aadd, asub, amul, adiv], ["amul", "2", "3"], "2 * 3 = 6\n") + check(capsys, run, [aadd, asub, amul, adiv], ["adiv", "6", "3"], "6 / 3 = 2.0\n") + check(capsys, run, [add, asub, mul, adiv], ["add", "2", "3"], "2 + 3 = 5\n") + check(capsys, run, [aadd, asub, mul, div], ["asub", "2", "3"], "2 - 3 = -1\n") + check(capsys, run, [add, asub, mul, adiv], ["mul", "2", "3"], "2 * 3 = 6\n") + check(capsys, run, [add, sub, amul, adiv], ["adiv", "6", "3"], "6 / 3 = 2.0\n") check( capsys, partial(run, default="sum") if default else run, @@ -87,6 +139,17 @@ def test_calc( "\nAdd two numbers.\n\nUsage:\n", exit_code="0", ) + check_exits( + capsys, run, [aadd, asub, amul, adiv], ["--help"], "\nUsage:\n", exit_code="0" + ) + check_exits( + capsys, + run, + [aadd, asub, amul, adiv], + ["aadd", "--help"], + "\nAdd two numbers.\n\nUsage:\n", + exit_code="0", + ) if default: check(capsys, run_, [add, sub, mul, div], ["2", "3"], "2 + 3 = 5\n") @@ -98,6 +161,22 @@ def test_calc( ["2", "3"], "Error: Unknown command `2`!\n", ) + if default: + check( + capsys, + partial(run, default="aadd"), + [aadd, asub, amul, adiv], + ["2", "3"], + "2 + 3 = 5\n", + ) + else: + check_exits( + capsys, + run, + [aadd, asub, amul, adiv], + ["2", "3"], + "Error: Unknown command `2`!\n", + ) if default: check_exits( @@ -112,6 +191,19 @@ def test_calc( capsys, run_, [add, sub, mul, div], [], "Error: No command given!\n" ) + if default: + check_exits( + capsys, + partial(run, default="aadd"), + [aadd, asub, amul, adiv], + [], + "Error: Required option `a` is not provided!\n", + ) + else: + check_exits( + capsys, run, [aadd, asub, amul, adiv], [], "Error: No command given!\n" + ) + check_exits( capsys, run_, @@ -121,9 +213,16 @@ def test_calc( ) check_exits( capsys, - run_, - [add, sub, mul, div], - ["sub", "2"], + run, + [aadd, asub, amul, adiv], + ["aadd", "2", "3", "4"], + "Error: Unexpected positional argument: `4`!\n", + ) + check_exits( + capsys, + run, + [aadd, asub, amul, adiv], + ["asub", "2"], "Error: Required option `b` is not provided!\n", ) @@ -229,3 +328,13 @@ def test_recursive_commands() -> None: ["add", "2", "3"], recurse=True, ) + + with raises( + ParserConfigError, + match=("Recurse option is not yet supported for multiple functions."), + ): + run_w_explicit_args( + [add, asub, amul], + ["add", "2", "3"], + recurse=True, + ) diff --git a/tests/test_start/test_start_func.py b/tests/test_start/test_start_func.py index f420f47..0c31145 100644 --- a/tests/test_start/test_start_func.py +++ b/tests/test_start/test_start_func.py @@ -191,3 +191,185 @@ def f(*, blip: bool = False) -> None: run_w_sys_argv(f, [help_cmd], name="my_program") captured = capsys.readouterr() assert remove_trailing_spaces(captured.out) == remove_trailing_spaces(expected) + + +async def ahi1(name: str, count: int = 1) -> None: + for _ in range(count): + print(f"Hello, {name}!") + + +async def ahi2(name: str, count: int = 1, /) -> None: + for _ in range(count): + print(f"Hello, {name}!") + + +async def ahi3(name: str, /, count: int = 1) -> None: + for _ in range(count): + print(f"Hello, {name}!") + + +async def ahi4(name: str, /, *, count: int = 1) -> None: + for _ in range(count): + print(f"Hello, {name}!") + + +async def ahi5(name: str, *, count: int = 1) -> None: + for _ in range(count): + print(f"Hello, {name}!") + + +async def ahi6(*, name: str, count: int = 1) -> None: + for _ in range(count): + print(f"Hello, {name}!") + + +@mark.parametrize("hi", [ahi1, ahi2, ahi3, ahi4, ahi5, ahi6]) +@mark.parametrize("run", [run_w_explicit_args, run_w_sys_argv]) +def test_async_hi( + capsys: CaptureFixture[str], run: Callable[..., Any], hi: Callable[..., Any] +) -> None: + if hi in [ahi1, ahi2, ahi3, ahi4]: + check(capsys, run, hi, ["Alice"], "Hello, Alice!\n") + + if hi in [ahi1, ahi2, ahi3]: + check(capsys, run, hi, ["Bob", "3"], "Hello, Bob!\nHello, Bob!\nHello, Bob!\n") + + if hi in [ahi1, ahi5, ahi6]: + check( + capsys, + run, + hi, + ["--name", "Bob", "--count", "3"], + "Hello, Bob!\nHello, Bob!\nHello, Bob!\n", + ) + check( + capsys, + run, + hi, + ["--count", "3", "--name", "Bob"], + "Hello, Bob!\nHello, Bob!\nHello, Bob!\n", + ) + check(capsys, run, hi, ["--name", "Alice"], "Hello, Alice!\n") + + if hi in [ahi1, ahi3, ahi4, ahi5]: + check( + capsys, + run, + hi, + ["--count", "3", "Bob"], + "Hello, Bob!\nHello, Bob!\nHello, Bob!\n", + ) + check( + capsys, + run, + hi, + ["Bob", "--count", "3"], + "Hello, Bob!\nHello, Bob!\nHello, Bob!\n", + ) + + if hi is ahi1: + check( + capsys, + run, + hi, + ["--name", "Bob", "3"], + "Hello, Bob!\nHello, Bob!\nHello, Bob!\n", + ) + + +@mark.parametrize("hi", [ahi1, ahi2, ahi3, ahi4, ahi5, ahi6]) +@mark.parametrize("run", [run_w_explicit_args, run_w_sys_argv]) +def test_parse_err_async( + capsys: CaptureFixture[str], run: Callable[..., Any], hi: Callable[..., Any] +) -> None: + if hi in [ahi1, ahi5, ahi6]: + check_exits( + capsys, run, hi, [], "Error: Required option `name` is not provided!" + ) + check_exits( + capsys, + run, + hi, + ["--name", "Bob", "--count", "3", "--name", "Alice"], + "Error: Option `name` is multiply given!", + ) + check_exits( + capsys, + run, + hi, + ["--name", "Bob", "--count", "3", "--lastname", "Alice"], + "Error: Unexpected option `lastname`!", + ) + with raises(ParserOptionError, match="Required option `name` is not provided!"): + run(hi, [], catch=False) + with raises(ParserOptionError, match="Option `name` is multiply given!"): + run(hi, ["--name", "Bob", "--count", "3", "--name", "Alice"], catch=False) + with raises(ParserOptionError, match="Unexpected option `lastname`!"): + run( + hi, + ["--name", "Bob", "--count", "3", "--lastname", "Alice"], + catch=False, + ) + else: + check_exits( + capsys, + run, + hi, + [], + "Error: Required positional argument is not provided!", + ) + with raises( + ParserOptionError, + match="Required positional argument is not provided!", + ): + run(hi, [], catch=False) + + +@mark.parametrize("run", [run_w_explicit_args, run_w_sys_argv]) +@mark.parametrize("catch", [False, True]) +def test_config_err_async(run: Callable[..., Any], catch: bool) -> None: + async def f(help: bool = False) -> None: + pass + + async def f2(dummy: str) -> None: + pass + + with raises( + ParserConfigError, match=r"Cannot use `help` as parameter name in `f\(\)`!" + ): + run(f, [], catch=catch) + with raises( + ParserConfigError, match=r"Cannot use `help` as parameter name in `f\(\)`!" + ): + run([f, f2], [], catch=catch) + + +@mark.parametrize("help_cmd", ["--help", "-?", "-?b", "-b?"]) +def test_custom_program_name_help_async( + capsys: CaptureFixture[str], help_cmd: str +) -> None: + async def f(*, blip: bool = False) -> None: + """ + Do something. + + Args: + blip: Whether to blip or not. + """ + + # here, output is not detected as a tty, so it does not use rich + expected = """\ + +Do something. + +Usage: + my_program [--blip] + +where + (option) -b|--blip Whether to blip or not. (flag) + (option) -?|--help Show this help message and exit. + +""" + with raises(SystemExit): + run_w_sys_argv(f, [help_cmd], name="my_program") + captured = capsys.readouterr() + assert remove_trailing_spaces(captured.out) == remove_trailing_spaces(expected)