Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions python/quadrants/lang/_func_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
QuadrantsRuntimeError,
QuadrantsRuntimeTypeError,
QuadrantsSyntaxError,
get_func_signature,
)
from quadrants.lang.kernel_arguments import ArgMetadata
from quadrants.lang.matrix import MatrixType
Expand Down Expand Up @@ -97,7 +98,7 @@ def check_parameter_annotations(self) -> None:

Note: NOT in the hot path. Just run once, on function registration
"""
sig = inspect.signature(self.func)
sig = get_func_signature(self.func)
if hasattr(self.func, "__wrapped__"):
raise_exception(
QuadrantsSyntaxError,
Expand Down Expand Up @@ -189,7 +190,7 @@ def _populate_global_vars_for_templates(
for i in template_slot_locations:
template_var_name = argument_metas[i].name
global_vars[template_var_name] = py_args[i]
parameters = inspect.signature(fn).parameters
parameters = get_func_signature(fn).parameters
for i, (parameter_name, parameter) in enumerate(parameters.items()):
if is_dataclass(parameter.annotation):
_kernel_impl_dataclass.populate_global_vars_from_dataclass(
Expand Down
5 changes: 3 additions & 2 deletions python/quadrants/lang/_kernel_impl_dataclass.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import ast
import dataclasses
import inspect
from typing import Any

from quadrants.lang import util
Expand Down Expand Up @@ -73,7 +72,9 @@ def extract_struct_locals_from_context(ctx: ASTTransformerFuncContext) -> set[st
"""
struct_locals = set()
assert ctx.func is not None
sig = inspect.signature(ctx.func.func)
from quadrants.lang.exception import get_func_signature

sig = get_func_signature(ctx.func.func)
parameters = sig.parameters
for param_name, parameter in parameters.items():
if dataclasses.is_dataclass(parameter.annotation):
Expand Down
7 changes: 3 additions & 4 deletions python/quadrants/lang/_perf_dispatch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import inspect
import os
import time
from collections import defaultdict
Expand All @@ -8,7 +7,7 @@
from . import impl
from ._exceptions import raise_exception
from ._quadrants_callable import QuadrantsCallable
from .exception import QuadrantsRuntimeError, QuadrantsSyntaxError
from .exception import QuadrantsRuntimeError, QuadrantsSyntaxError, get_func_signature

NUM_WARMUP: int = 3
NUM_ACTIVE: int = 1
Expand Down Expand Up @@ -58,7 +57,7 @@ def __init__(
self.num_active = num_active if num_active is not None else NUM_ACTIVE
self.repeat_after_count = repeat_after_count if repeat_after_count is not None else REPEAT_AFTER_COUNT
self.repeat_after_seconds = repeat_after_seconds if repeat_after_seconds is not None else REPEAT_AFTER_SECONDS
sig = inspect.signature(fn)
sig = get_func_signature(fn)
self._param_types: dict[str, Any] = {}
for param_name, param in sig.parameters.items():
self._param_types[param_name] = param.annotation
Expand Down Expand Up @@ -99,7 +98,7 @@ def register(
dispatch_impl_set = self._dispatch_impl_set

def decorator(func: Callable | QuadrantsCallable) -> DispatchImpl:
sig = inspect.signature(func)
sig = get_func_signature(func)
log_str = f"perf_dispatch registering {func.__name__}" # type: ignore
_logging.debug(log_str)
if QD_PERFDISPATCH_PRINT_DEBUG:
Expand Down
10 changes: 10 additions & 0 deletions python/quadrants/lang/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ def get_ret(needed, provided):
return QuadrantsRuntimeTypeError(f"Return (type={provided}) cannot be converted into required type {needed}")


def get_func_signature(func):
"""Call inspect.signature with eval_str=True, converting annotation errors to QuadrantsSyntaxError."""
import inspect

try:
return inspect.signature(func, eval_str=True)
except (NameError, AttributeError) as e:
raise QuadrantsSyntaxError(f"Invalid type annotation of Taichi kernel: {e}") from e


def handle_exception_from_cpp(exc):
if isinstance(exc, core.QuadrantsTypeError):
return QuadrantsTypeError(str(exc))
Expand Down
25 changes: 25 additions & 0 deletions tests/python/test_future_annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Test that kernels work with `from __future__ import annotations` (PEP 563)."""

from __future__ import annotations

import quadrants as qd

from tests import test_utils


@qd.kernel
def add_kernel(a: qd.types.NDArray[qd.i32, 1], b: qd.types.NDArray[qd.i32, 1]) -> None:
for i in a:
a[i] = a[i] + b[i]


@test_utils.test()
def test_future_annotations_kernel():
a = qd.ndarray(qd.i32, (4,))
b = qd.ndarray(qd.i32, (4,))
for i in range(4):
a[i] = i
b[i] = 10
add_kernel(a, b)
for i in range(4):
assert a[i] == i + 10
Loading