Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ preview = true

[tool.ruff.lint]
select = ["B", "E", "F", "I", "RUF", "UP", "DOC102", "DOC202", "DOC403", "DOC502"]
ignore = ["E501", "RUF001", "RUF002"]
ignore = ["E501", "RUF001", "RUF002", "RUF029"]

[tool.ruff.lint.per-file-ignores]
"startle/_typing.py" = ["UP007", "UP045"]
Expand Down
16 changes: 14 additions & 2 deletions startle/_start.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
from collections.abc import Callable
from inspect import iscoroutinefunction
from typing import Any, Literal, TypeVar, cast

from ._console import console, error, post_error
Expand Down Expand Up @@ -100,7 +101,12 @@ def _start_func(
f_args, f_kwargs = args_.make_func_args()

# finally, call the function with the arguments
return func(*f_args, **f_kwargs)
if iscoroutinefunction(func):
import asyncio

return asyncio.run(func(*f_args, **f_kwargs))
else:
return func(*f_args, **f_kwargs)
except (ParserOptionError, ParserValueError) as e:
if catch:
error(str(e), exit=False, endl=False)
Expand Down Expand Up @@ -164,7 +170,13 @@ def cmd_prog_name(cmd_name: str) -> str:

# finally, call the function with the arguments
func = cmd2func[cmd]
return func(*f_args, **f_kwargs)

if iscoroutinefunction(func):
import asyncio

return asyncio.run(func(*f_args, **f_kwargs))
else:
return func(*f_args, **f_kwargs)
except (ParserOptionError, ParserValueError) as e:
if catch:
error(str(e), exit=False, endl=False)
Expand Down
115 changes: 112 additions & 3 deletions tests/test_start/test_start_cmds.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,50 @@ def div(a: int, b: int) -> None:
print(f"{a} / {b} = {a / b}")


async def aadd(a: int, b: int) -> None:
"""
Add two numbers.

Args:
a: The first number.
b: The second number.
"""
print(f"{a} + {b} = {a + b}")


async def asub(a: int, b: int) -> None:
"""
Subtract two numbers.

Args:
a: The first number.
b: The second number
"""
print(f"{a} - {b} = {a - b}")


async def amul(a: int, b: int) -> None:
"""
Multiply two numbers.

Args:
a: The first number.
b: The second number.
"""
print(f"{a} * {b} = {a * b}")


async def adiv(a: int, b: int) -> None:
"""
Divide two numbers.

Args:
a: The dividend.
b: The divisor.
"""
print(f"{a} / {b} = {a / b}")


@mark.parametrize("run", [run_w_explicit_args, run_w_sys_argv])
@mark.parametrize("default", [False, True])
def test_calc(
Expand All @@ -68,6 +112,14 @@ def test_calc(
check(capsys, run_, [add, sub, mul, div], ["sub", "2", "3"], "2 - 3 = -1\n")
check(capsys, run_, [add, sub, mul, div], ["mul", "2", "3"], "2 * 3 = 6\n")
check(capsys, run_, [add, sub, mul, div], ["div", "6", "3"], "6 / 3 = 2.0\n")
check(capsys, run, [aadd, asub, amul, adiv], ["aadd", "2", "3"], "2 + 3 = 5\n")
check(capsys, run, [aadd, asub, amul, adiv], ["asub", "2", "3"], "2 - 3 = -1\n")
check(capsys, run, [aadd, asub, amul, adiv], ["amul", "2", "3"], "2 * 3 = 6\n")
check(capsys, run, [aadd, asub, amul, adiv], ["adiv", "6", "3"], "6 / 3 = 2.0\n")
check(capsys, run, [add, asub, mul, adiv], ["add", "2", "3"], "2 + 3 = 5\n")
check(capsys, run, [aadd, asub, mul, div], ["asub", "2", "3"], "2 - 3 = -1\n")
check(capsys, run, [add, asub, mul, adiv], ["mul", "2", "3"], "2 * 3 = 6\n")
check(capsys, run, [add, sub, amul, adiv], ["adiv", "6", "3"], "6 / 3 = 2.0\n")
check(
capsys,
partial(run, default="sum") if default else run,
Expand All @@ -87,6 +139,17 @@ def test_calc(
"\nAdd two numbers.\n\nUsage:\n",
exit_code="0",
)
check_exits(
capsys, run, [aadd, asub, amul, adiv], ["--help"], "\nUsage:\n", exit_code="0"
)
check_exits(
capsys,
run,
[aadd, asub, amul, adiv],
["aadd", "--help"],
"\nAdd two numbers.\n\nUsage:\n",
exit_code="0",
)

if default:
check(capsys, run_, [add, sub, mul, div], ["2", "3"], "2 + 3 = 5\n")
Expand All @@ -98,6 +161,22 @@ def test_calc(
["2", "3"],
"Error: Unknown command `2`!\n",
)
if default:
check(
capsys,
partial(run, default="aadd"),
[aadd, asub, amul, adiv],
["2", "3"],
"2 + 3 = 5\n",
)
else:
check_exits(
capsys,
run,
[aadd, asub, amul, adiv],
["2", "3"],
"Error: Unknown command `2`!\n",
)

if default:
check_exits(
Expand All @@ -112,6 +191,19 @@ def test_calc(
capsys, run_, [add, sub, mul, div], [], "Error: No command given!\n"
)

if default:
check_exits(
capsys,
partial(run, default="aadd"),
[aadd, asub, amul, adiv],
[],
"Error: Required option `a` is not provided!\n",
)
else:
check_exits(
capsys, run, [aadd, asub, amul, adiv], [], "Error: No command given!\n"
)

check_exits(
capsys,
run_,
Expand All @@ -121,9 +213,16 @@ def test_calc(
)
check_exits(
capsys,
run_,
[add, sub, mul, div],
["sub", "2"],
run,
[aadd, asub, amul, adiv],
["aadd", "2", "3", "4"],
"Error: Unexpected positional argument: `4`!\n",
)
check_exits(
capsys,
run,
[aadd, asub, amul, adiv],
["asub", "2"],
"Error: Required option `b` is not provided!\n",
)

Expand Down Expand Up @@ -229,3 +328,13 @@ def test_recursive_commands() -> None:
["add", "2", "3"],
recurse=True,
)

with raises(
ParserConfigError,
match=("Recurse option is not yet supported for multiple functions."),
):
run_w_explicit_args(
[add, asub, amul],
["add", "2", "3"],
recurse=True,
)
182 changes: 182 additions & 0 deletions tests/test_start/test_start_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,185 @@ def f(*, blip: bool = False) -> None:
run_w_sys_argv(f, [help_cmd], name="my_program")
captured = capsys.readouterr()
assert remove_trailing_spaces(captured.out) == remove_trailing_spaces(expected)


async def ahi1(name: str, count: int = 1) -> None:
for _ in range(count):
print(f"Hello, {name}!")


async def ahi2(name: str, count: int = 1, /) -> None:
for _ in range(count):
print(f"Hello, {name}!")


async def ahi3(name: str, /, count: int = 1) -> None:
for _ in range(count):
print(f"Hello, {name}!")


async def ahi4(name: str, /, *, count: int = 1) -> None:
for _ in range(count):
print(f"Hello, {name}!")


async def ahi5(name: str, *, count: int = 1) -> None:
for _ in range(count):
print(f"Hello, {name}!")


async def ahi6(*, name: str, count: int = 1) -> None:
for _ in range(count):
print(f"Hello, {name}!")


@mark.parametrize("hi", [ahi1, ahi2, ahi3, ahi4, ahi5, ahi6])
@mark.parametrize("run", [run_w_explicit_args, run_w_sys_argv])
def test_async_hi(
capsys: CaptureFixture[str], run: Callable[..., Any], hi: Callable[..., Any]
) -> None:
if hi in [ahi1, ahi2, ahi3, ahi4]:
check(capsys, run, hi, ["Alice"], "Hello, Alice!\n")

if hi in [ahi1, ahi2, ahi3]:
check(capsys, run, hi, ["Bob", "3"], "Hello, Bob!\nHello, Bob!\nHello, Bob!\n")

if hi in [ahi1, ahi5, ahi6]:
check(
capsys,
run,
hi,
["--name", "Bob", "--count", "3"],
"Hello, Bob!\nHello, Bob!\nHello, Bob!\n",
)
check(
capsys,
run,
hi,
["--count", "3", "--name", "Bob"],
"Hello, Bob!\nHello, Bob!\nHello, Bob!\n",
)
check(capsys, run, hi, ["--name", "Alice"], "Hello, Alice!\n")

if hi in [ahi1, ahi3, ahi4, ahi5]:
check(
capsys,
run,
hi,
["--count", "3", "Bob"],
"Hello, Bob!\nHello, Bob!\nHello, Bob!\n",
)
check(
capsys,
run,
hi,
["Bob", "--count", "3"],
"Hello, Bob!\nHello, Bob!\nHello, Bob!\n",
)

if hi is ahi1:
check(
capsys,
run,
hi,
["--name", "Bob", "3"],
"Hello, Bob!\nHello, Bob!\nHello, Bob!\n",
)


@mark.parametrize("hi", [ahi1, ahi2, ahi3, ahi4, ahi5, ahi6])
@mark.parametrize("run", [run_w_explicit_args, run_w_sys_argv])
def test_parse_err_async(
capsys: CaptureFixture[str], run: Callable[..., Any], hi: Callable[..., Any]
) -> None:
if hi in [ahi1, ahi5, ahi6]:
check_exits(
capsys, run, hi, [], "Error: Required option `name` is not provided!"
)
check_exits(
capsys,
run,
hi,
["--name", "Bob", "--count", "3", "--name", "Alice"],
"Error: Option `name` is multiply given!",
)
check_exits(
capsys,
run,
hi,
["--name", "Bob", "--count", "3", "--lastname", "Alice"],
"Error: Unexpected option `lastname`!",
)
with raises(ParserOptionError, match="Required option `name` is not provided!"):
run(hi, [], catch=False)
with raises(ParserOptionError, match="Option `name` is multiply given!"):
run(hi, ["--name", "Bob", "--count", "3", "--name", "Alice"], catch=False)
with raises(ParserOptionError, match="Unexpected option `lastname`!"):
run(
hi,
["--name", "Bob", "--count", "3", "--lastname", "Alice"],
catch=False,
)
else:
check_exits(
capsys,
run,
hi,
[],
"Error: Required positional argument <name> is not provided!",
)
with raises(
ParserOptionError,
match="Required positional argument <name> is not provided!",
):
run(hi, [], catch=False)


@mark.parametrize("run", [run_w_explicit_args, run_w_sys_argv])
@mark.parametrize("catch", [False, True])
def test_config_err_async(run: Callable[..., Any], catch: bool) -> None:
async def f(help: bool = False) -> None:
pass

async def f2(dummy: str) -> None:
pass

with raises(
ParserConfigError, match=r"Cannot use `help` as parameter name in `f\(\)`!"
):
run(f, [], catch=catch)
with raises(
ParserConfigError, match=r"Cannot use `help` as parameter name in `f\(\)`!"
):
run([f, f2], [], catch=catch)


@mark.parametrize("help_cmd", ["--help", "-?", "-?b", "-b?"])
def test_custom_program_name_help_async(
capsys: CaptureFixture[str], help_cmd: str
) -> None:
async def f(*, blip: bool = False) -> None:
"""
Do something.

Args:
blip: Whether to blip or not.
"""

# here, output is not detected as a tty, so it does not use rich
expected = """\

Do something.

Usage:
my_program [--blip]

where
(option) -b|--blip Whether to blip or not. (flag)
(option) -?|--help Show this help message and exit.

"""
with raises(SystemExit):
run_w_sys_argv(f, [help_cmd], name="my_program")
captured = capsys.readouterr()
assert remove_trailing_spaces(captured.out) == remove_trailing_spaces(expected)
Loading