diff --git a/python/quadrants/types/ndarray_type.py b/python/quadrants/types/ndarray_type.py index 994fa70b5..9ed077b67 100644 --- a/python/quadrants/types/ndarray_type.py +++ b/python/quadrants/types/ndarray_type.py @@ -94,8 +94,10 @@ def __init__( self.boundary = int(to_boundary_enum(boundary)) @classmethod - def __class_getitem__(cls, args, **kwargs): - return cls(*args, **kwargs) + def __class_getitem__(cls, args): + if not isinstance(args, tuple): + args = (args,) + return cls(*args) def check_matched(self, ndarray_type: NdarrayTypeMetadata, arg_name: str): # FIXME(Haidong) Cannot use Vector/MatrixType due to circular import diff --git a/tests/python/test_ndarray_typing.py b/tests/python/test_ndarray_typing.py index 0ce6b4b7d..531c786ce 100644 --- a/tests/python/test_ndarray_typing.py +++ b/tests/python/test_ndarray_typing.py @@ -16,3 +16,9 @@ def test_ndarray_typing_square_brackets(): b[1, 1] = 5 some_kernel(a, b) assert a[1, 1] == 5 + 2 + + +def test_ndarray_typing_single_arg(): + t = qd.types.NDArray[qd.i32] + assert t.dtype == qd.i32 + assert t.ndim is None