diff --git a/python/quadrants/lang/_func_base.py b/python/quadrants/lang/_func_base.py index dd5cdbac8..078e921f5 100644 --- a/python/quadrants/lang/_func_base.py +++ b/python/quadrants/lang/_func_base.py @@ -29,6 +29,7 @@ QuadrantsRuntimeError, QuadrantsRuntimeTypeError, QuadrantsSyntaxError, + get_func_signature, ) from quadrants.lang.kernel_arguments import ArgMetadata from quadrants.lang.matrix import MatrixType @@ -97,7 +98,7 @@ def check_parameter_annotations(self) -> None: Note: NOT in the hot path. Just run once, on function registration """ - sig = inspect.signature(self.func) + sig = get_func_signature(self.func) if hasattr(self.func, "__wrapped__"): raise_exception( QuadrantsSyntaxError, @@ -189,7 +190,7 @@ def _populate_global_vars_for_templates( for i in template_slot_locations: template_var_name = argument_metas[i].name global_vars[template_var_name] = py_args[i] - parameters = inspect.signature(fn).parameters + parameters = get_func_signature(fn).parameters for i, (parameter_name, parameter) in enumerate(parameters.items()): if is_dataclass(parameter.annotation): _kernel_impl_dataclass.populate_global_vars_from_dataclass( diff --git a/python/quadrants/lang/_kernel_impl_dataclass.py b/python/quadrants/lang/_kernel_impl_dataclass.py index c5d7bd530..ebd298dff 100644 --- a/python/quadrants/lang/_kernel_impl_dataclass.py +++ b/python/quadrants/lang/_kernel_impl_dataclass.py @@ -1,6 +1,5 @@ import ast import dataclasses -import inspect from typing import Any from quadrants.lang import util @@ -73,7 +72,9 @@ def extract_struct_locals_from_context(ctx: ASTTransformerFuncContext) -> set[st """ struct_locals = set() assert ctx.func is not None - sig = inspect.signature(ctx.func.func) + from quadrants.lang.exception import get_func_signature + + sig = get_func_signature(ctx.func.func) parameters = sig.parameters for param_name, parameter in parameters.items(): if dataclasses.is_dataclass(parameter.annotation): diff --git a/python/quadrants/lang/_perf_dispatch.py b/python/quadrants/lang/_perf_dispatch.py index 4bc21844a..a1999199d 100644 --- a/python/quadrants/lang/_perf_dispatch.py +++ b/python/quadrants/lang/_perf_dispatch.py @@ -1,4 +1,3 @@ -import inspect import os import time from collections import defaultdict @@ -8,7 +7,7 @@ from . import impl from ._exceptions import raise_exception from ._quadrants_callable import QuadrantsCallable -from .exception import QuadrantsRuntimeError, QuadrantsSyntaxError +from .exception import QuadrantsRuntimeError, QuadrantsSyntaxError, get_func_signature NUM_WARMUP: int = 3 NUM_ACTIVE: int = 1 @@ -58,7 +57,7 @@ def __init__( self.num_active = num_active if num_active is not None else NUM_ACTIVE self.repeat_after_count = repeat_after_count if repeat_after_count is not None else REPEAT_AFTER_COUNT self.repeat_after_seconds = repeat_after_seconds if repeat_after_seconds is not None else REPEAT_AFTER_SECONDS - sig = inspect.signature(fn) + sig = get_func_signature(fn) self._param_types: dict[str, Any] = {} for param_name, param in sig.parameters.items(): self._param_types[param_name] = param.annotation @@ -99,7 +98,7 @@ def register( dispatch_impl_set = self._dispatch_impl_set def decorator(func: Callable | QuadrantsCallable) -> DispatchImpl: - sig = inspect.signature(func) + sig = get_func_signature(func) log_str = f"perf_dispatch registering {func.__name__}" # type: ignore _logging.debug(log_str) if QD_PERFDISPATCH_PRINT_DEBUG: diff --git a/python/quadrants/lang/exception.py b/python/quadrants/lang/exception.py index 771dd56b0..beaf2eeb0 100644 --- a/python/quadrants/lang/exception.py +++ b/python/quadrants/lang/exception.py @@ -57,6 +57,16 @@ def get_ret(needed, provided): return QuadrantsRuntimeTypeError(f"Return (type={provided}) cannot be converted into required type {needed}") +def get_func_signature(func): + """Call inspect.signature with eval_str=True, converting annotation errors to QuadrantsSyntaxError.""" + import inspect + + try: + return inspect.signature(func, eval_str=True) + except (NameError, AttributeError) as e: + raise QuadrantsSyntaxError(f"Invalid type annotation of Taichi kernel: {e}") from e + + def handle_exception_from_cpp(exc): if isinstance(exc, core.QuadrantsTypeError): return QuadrantsTypeError(str(exc)) diff --git a/tests/python/test_future_annotations.py b/tests/python/test_future_annotations.py new file mode 100644 index 000000000..c359679f7 --- /dev/null +++ b/tests/python/test_future_annotations.py @@ -0,0 +1,25 @@ +"""Test that kernels work with `from __future__ import annotations` (PEP 563).""" + +from __future__ import annotations + +import quadrants as qd + +from tests import test_utils + + +@qd.kernel +def add_kernel(a: qd.types.NDArray[qd.i32, 1], b: qd.types.NDArray[qd.i32, 1]) -> None: + for i in a: + a[i] = a[i] + b[i] + + +@test_utils.test() +def test_future_annotations_kernel(): + a = qd.ndarray(qd.i32, (4,)) + b = qd.ndarray(qd.i32, (4,)) + for i in range(4): + a[i] = i + b[i] = 10 + add_kernel(a, b) + for i in range(4): + assert a[i] == i + 10