Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,6 @@ imgui.ini
!pyrightconfig.json
*.whl
*.so
stubs/
stubs/quadrants/_lib/
CHANGELOG.md
python/quadrants/_version.py
13 changes: 13 additions & 0 deletions stubs/quadrants/types/annotations.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Any, Generic, TypeVar

T = TypeVar("T")

class Template(Generic[T]):
element_type: type[T]
ndim: int | None
def __init__(self, element_type: type[T] = ..., ndim: int | None = ...) -> None: ...
def __getitem__(self, i: Any) -> T: ...

template = Template

class sparse_matrix_builder: ...
10 changes: 10 additions & 0 deletions stubs/quadrants/types/compound_types.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Any

class CompoundType:
def from_kernel_struct_ret(self, launch_ctx: Any, index: tuple[Any, ...]) -> Any: ...
def check_matched(self, other: Any) -> bool: ...
def to_string(self) -> str: ...

def matrix(n: int | None = ..., m: int | None = ..., dtype: Any = ...) -> Any: ...
def vector(n: int | None = ..., dtype: Any = ...) -> Any: ...
def struct(**kwargs: Any) -> Any: ...
27 changes: 27 additions & 0 deletions stubs/quadrants/types/ndarray_type.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Any

class NdarrayType:
dtype: Any
ndim: int | None
needs_grad: bool | None
boundary: int

def __init__(
self,
dtype: Any = ...,
ndim: int | None = ...,
element_dim: int | None = ...,
element_shape: tuple[int, ...] | None = ...,
field_dim: int | None = ...,
needs_grad: bool | None = ...,
boundary: str = ...,
) -> None: ...
@classmethod
def __class_getitem__(cls, args: Any) -> type[NdarrayType]: ...
def __getitem__(self, i: Any) -> Any: ...
def __setitem__(self, i: Any, v: Any) -> None: ...
def __repr__(self) -> str: ...
def __str__(self) -> str: ...

ndarray = NdarrayType
NDArray = NdarrayType
69 changes: 69 additions & 0 deletions stubs/quadrants/types/primitive_types.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import Any, ClassVar, Union

from quadrants._lib.core.quadrants_python import DataTypeCxx

class PrimitiveMeta(type):
cxx: DataTypeCxx
def __eq__(cls, other: object) -> bool: ...
def __ne__(cls, other: object) -> bool: ...
def __hash__(cls) -> int: ...
def __repr__(cls) -> str: ...
def __getattr__(cls, name: str) -> Any: ...

class PrimitiveBase(metaclass=PrimitiveMeta):
cxx: ClassVar[DataTypeCxx]

class f16(PrimitiveBase): ...
class f32(PrimitiveBase): ...
class f64(PrimitiveBase): ...
class i8(PrimitiveBase): ...
class i16(PrimitiveBase): ...
class i32(PrimitiveBase): ...
class i64(PrimitiveBase): ...
class u1(PrimitiveBase): ...
class u8(PrimitiveBase): ...
class u16(PrimitiveBase): ...
class u32(PrimitiveBase): ...
class u64(PrimitiveBase): ...

float16 = f16
float32 = f32
float64 = f64
int8 = i8
int16 = i16
int32 = i32
int64 = i64
uint1 = u1
uint8 = u8
uint16 = u16
uint32 = u32
uint64 = u64

# Raw C++ DataType instances (internal use)
f16_cxx: DataTypeCxx
f32_cxx: DataTypeCxx
f64_cxx: DataTypeCxx
i8_cxx: DataTypeCxx
i16_cxx: DataTypeCxx
i32_cxx: DataTypeCxx
i64_cxx: DataTypeCxx
u1_cxx: DataTypeCxx
u8_cxx: DataTypeCxx
u16_cxx: DataTypeCxx
u32_cxx: DataTypeCxx
u64_cxx: DataTypeCxx

class RefType:
tp: Any
def __init__(self, tp: Any) -> None: ...

def ref(tp: Any) -> RefType: ...

real_types: set[type[PrimitiveBase] | type]
real_type_ids: set[int]
integer_types: set[type[PrimitiveBase] | type]
integer_type_ids: set[int]
all_types: set[type[PrimitiveBase] | type]
cxx_type_ids: set[int]
type_ids: set[int]
_python_primitive_types = Union[int, float, bool, str, None]
6 changes: 6 additions & 0 deletions stubs/quadrants/types/utils.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from typing import Any

def is_signed(dt: Any) -> bool: ...
def is_integral(dt: Any) -> bool: ...
def is_real(dt: Any) -> bool: ...
def is_tensor(dt: Any) -> bool: ...
21 changes: 16 additions & 5 deletions tests/python/pyright/test_ndarray_type.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# This is a test file. It just has to exist, to check that pyright works with it.
# Pyright test: NDArray annotations accepted by the type checker and functional at runtime.

import quadrants as qd

Expand All @@ -7,29 +7,40 @@
qd.init(arch=qd.cpu)


# Legacy call syntax (still works, pyright warns about call expressions in type positions)
@qd.kernel
def k1(a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ...
def k1(a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... # type: ignore[reportInvalidTypeForm]


@qd.kernel()
def k2(a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ...
def k2(a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... # type: ignore[reportInvalidTypeForm]


# New subscript syntax (preferred, no pyright warnings)
@qd.kernel
def k3(a: qd.types.NDArray[qd.i32, 1], b: qd.types.NDArray[qd.i32], c: qd.types.NDArray) -> None: ...


@qd.data_oriented
class SomeClass:
@qd.kernel
def k1(self, a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ...
def k1(self, a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... # type: ignore[reportInvalidTypeForm]

@qd.kernel()
def k2(self, a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ...
def k2(self, a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... # type: ignore[reportInvalidTypeForm]

@qd.kernel
def k3(self, a: qd.types.NDArray[qd.i32, 1], b: qd.types.NDArray[qd.i32], c: qd.types.NDArray) -> None: ...


@test_utils.test()
def test_ndarray_type():
a = qd.ndarray(qd.i32, (10,))
k1(a, a, a)
k2(a, a, a)
k3(a, a, a)

some_class = SomeClass()
some_class.k1(a, a, a)
some_class.k2(a, a, a)
some_class.k3(a, a, a)
73 changes: 73 additions & 0 deletions tests/python/pyright/test_primitive_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Pyright test: primitive dtype classes and NDArray subscript syntax.

This file must produce zero pyright errors. It validates that:
- Primitive dtypes (f32, i32, etc.) work as type annotations
- NDArray[dtype, ndim] subscript syntax is accepted
- from __future__ import annotations (stringified) works
- NDArray works via qd.types and top-level qd.NDArray
- Return types and Optional wrappers work
"""

from __future__ import annotations

from typing import Optional

import quadrants as qd


# Primitive types as annotations
def accept_f32(x: qd.f32) -> None: ...


def accept_i32(x: qd.i32) -> None: ...


def accept_any_dtype(x: qd.f32 | qd.i32 | qd.u8) -> None: ...


# NDArray subscript: dtype + ndim
def kernel_2d(a: qd.types.NDArray[qd.f32, 2]) -> None: ...


# NDArray subscript: dtype only
def kernel_dtype(a: qd.types.NDArray[qd.i32]) -> None: ...


# NDArray bare (no subscript)
def kernel_bare(a: qd.types.NDArray) -> None: ...


# Multiple NDArray args with different types
def multi_args(
a: qd.types.NDArray[qd.f32, 2],
b: qd.types.NDArray[qd.i32, 1],
c: qd.types.NDArray,
) -> None: ...


# Top-level NDArray alias (accessible via qd.types)
def top_level(a: qd.types.NDArray[qd.f32, 2]) -> None: ...


# Return types
def make_arr() -> qd.types.NDArray[qd.f32, 2]: ...


# Optional wrapping
def maybe_arr(x: Optional[qd.types.NDArray[qd.f32, 2]]) -> None: ...


# Variable annotations
field1: qd.types.NDArray[qd.f32, 2]
field2: qd.types.NDArray


# In class body
class MyModel:
buf: qd.types.NDArray[qd.f32, 3]

def forward(self, x: qd.types.NDArray[qd.f32, 2]) -> qd.types.NDArray[qd.f32, 2]: ...


# Access via qd.types submodule
def via_types(x: qd.types.NDArray[qd.types.f32, 2]) -> None: ...
Loading