From d37bc027835a409f5f230ff3539c880236d78d35 Mon Sep 17 00:00:00 2001 From: jorenham Date: Fri, 13 Jun 2025 07:41:51 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix=20`nditer.=5F=5Fnext=5F=5F`?= =?UTF-8?q?=20return=20type?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../@test/static/accept/multiarray.pyi | 4 +- .../@test/static/accept/nditer.pyi | 14 +- .../@test/static/reject/multiarray.pyi | 12 +- .../@test/static/reject/nditer.pyi | 8 +- src/numpy-stubs/_core/_multiarray_umath.pyi | 123 +++++++++++------- 5 files changed, 96 insertions(+), 65 deletions(-) diff --git a/src/numpy-stubs/@test/static/accept/multiarray.pyi b/src/numpy-stubs/@test/static/accept/multiarray.pyi index 833bdfd6..5ecf7b0e 100644 --- a/src/numpy-stubs/@test/static/accept/multiarray.pyi +++ b/src/numpy-stubs/@test/static/accept/multiarray.pyi @@ -45,7 +45,7 @@ def func12(a: int) -> tuple[complex, str]: ... assert_type(next(b_f8), tuple[Any, ...]) assert_type(b_f8.reset(), None) assert_type(b_f8.index, int) -assert_type(b_f8.iters, tuple[np.flatiter[Any], ...]) +assert_type(b_f8.iters, tuple[np.flatiter, ...]) assert_type(b_f8.nd, int) assert_type(b_f8.ndim, int) assert_type(b_f8.numiter, int) @@ -55,7 +55,7 @@ assert_type(b_f8.size, int) assert_type(next(b_i8_f8_f8), tuple[Any, ...]) assert_type(b_i8_f8_f8.reset(), None) assert_type(b_i8_f8_f8.index, int) -assert_type(b_i8_f8_f8.iters, tuple[np.flatiter[Any], ...]) +assert_type(b_i8_f8_f8.iters, tuple[np.flatiter, ...]) assert_type(b_i8_f8_f8.nd, int) assert_type(b_i8_f8_f8.ndim, int) assert_type(b_i8_f8_f8.numiter, int) diff --git a/src/numpy-stubs/@test/static/accept/nditer.pyi b/src/numpy-stubs/@test/static/accept/nditer.pyi index 937171a3..3dc4f955 100644 --- a/src/numpy-stubs/@test/static/accept/nditer.pyi +++ b/src/numpy-stubs/@test/static/accept/nditer.pyi @@ -1,4 +1,4 @@ -from typing import assert_type +from typing import Any, assert_type import _numtype as _nt import numpy as np @@ -10,7 +10,7 @@ assert_type(np.nditer([0, 1], op_flags=[["readonly", "readonly"]]), np.nditer) assert_type(np.nditer([0, 1], op_dtypes=np.int_), np.nditer) assert_type(np.nditer([0, 1], order="C", casting="no"), np.nditer) -assert_type(nditer_obj.dtypes, tuple[np.dtype, ...]) +assert_type(nditer_obj.dtypes, tuple[np.dtype, *tuple[np.dtype, ...]]) assert_type(nditer_obj.finished, bool) assert_type(nditer_obj.has_delayed_bufalloc, bool) assert_type(nditer_obj.has_index, bool) @@ -18,15 +18,15 @@ assert_type(nditer_obj.has_multi_index, bool) assert_type(nditer_obj.index, int) assert_type(nditer_obj.iterationneedsapi, bool) assert_type(nditer_obj.iterindex, int) -assert_type(nditer_obj.iterrange, tuple[int, ...]) +assert_type(nditer_obj.iterrange, tuple[int, int]) assert_type(nditer_obj.itersize, int) -assert_type(nditer_obj.itviews, tuple[_nt.Array, ...]) +assert_type(nditer_obj.itviews, tuple[_nt.Array, *tuple[_nt.Array, ...]]) assert_type(nditer_obj.multi_index, tuple[int, ...]) assert_type(nditer_obj.ndim, int) assert_type(nditer_obj.nop, int) -assert_type(nditer_obj.operands, tuple[_nt.Array, ...]) +assert_type(nditer_obj.operands, tuple[_nt.Array, *tuple[_nt.Array, ...]]) assert_type(nditer_obj.shape, tuple[int, ...]) -assert_type(nditer_obj.value, tuple[_nt.Array, ...]) +assert_type(nditer_obj.value, _nt.Array | Any) assert_type(nditer_obj.close(), None) assert_type(nditer_obj.copy(), np.nditer) @@ -39,7 +39,7 @@ assert_type(nditer_obj.reset(), None) assert_type(len(nditer_obj), int) assert_type(iter(nditer_obj), np.nditer) -assert_type(next(nditer_obj), tuple[_nt.Array, ...]) +assert_type(next(nditer_obj), _nt.Array | Any) assert_type(nditer_obj.__copy__(), np.nditer) # noqa: PLC2801 with nditer_obj as f: assert_type(f, np.nditer) diff --git a/src/numpy-stubs/@test/static/reject/multiarray.pyi b/src/numpy-stubs/@test/static/reject/multiarray.pyi index b1a62024..9cdd390d 100644 --- a/src/numpy-stubs/@test/static/reject/multiarray.pyi +++ b/src/numpy-stubs/@test/static/reject/multiarray.pyi @@ -42,9 +42,9 @@ np.datetime_as_string("2012") # type: ignore[call-overload] # pyright: ignore[ np.char.compare_chararrays("a", b"a", "==", False) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue] -np.nested_iters([AR_i8, AR_i8]) # type: ignore[call-arg] # pyright: ignore[reportCallIssue] -np.nested_iters([AR_i8, AR_i8], 0) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] -np.nested_iters([AR_i8, AR_i8], [0]) # type: ignore[list-item] # pyright: ignore[reportArgumentType] -np.nested_iters([AR_i8, AR_i8], [[0], [1]], flags=["test"]) # type: ignore[list-item] # pyright: ignore[reportArgumentType] -np.nested_iters([AR_i8, AR_i8], [[0], [1]], op_flags=[["test"]]) # type: ignore[list-item] # pyright: ignore[reportArgumentType] -np.nested_iters([AR_i8, AR_i8], [[0], [1]], buffersize=1.0) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] +np.nested_iters([AR_i8, AR_i8]) # type: ignore[call-overload] # pyright: ignore[reportCallIssue] +np.nested_iters([AR_i8, AR_i8], 0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue] +np.nested_iters([AR_i8, AR_i8], [0]) # type: ignore[list-item] # pyright: ignore[reportArgumentType, reportCallIssue] +np.nested_iters([AR_i8, AR_i8], [[0], [1]], flags=["test"]) # type: ignore[list-item] # pyright: ignore[reportArgumentType, reportCallIssue] +np.nested_iters([AR_i8, AR_i8], [[0], [1]], op_flags=[["test"]]) # type: ignore[list-item] # pyright: ignore[reportArgumentType, reportCallIssue] +np.nested_iters([AR_i8, AR_i8], [[0], [1]], buffersize=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue] diff --git a/src/numpy-stubs/@test/static/reject/nditer.pyi b/src/numpy-stubs/@test/static/reject/nditer.pyi index 4ac7fbc1..a2b9f280 100644 --- a/src/numpy-stubs/@test/static/reject/nditer.pyi +++ b/src/numpy-stubs/@test/static/reject/nditer.pyi @@ -2,7 +2,7 @@ import numpy as np class Test(np.nditer): ... # type: ignore[misc] # pyright: ignore[reportGeneralTypeIssues] -np.nditer([0, 1], flags=["test"]) # type: ignore[list-item] # pyright: ignore[reportArgumentType] -np.nditer([0, 1], op_flags=[["test"]]) # type: ignore[list-item] # pyright: ignore[reportArgumentType] -np.nditer([0, 1], itershape=(1.0,)) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] -np.nditer([0, 1], buffersize=1.0) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] +np.nditer([0, 1], flags=["test"]) # type: ignore[list-item] # pyright: ignore[reportArgumentType, reportCallIssue] +np.nditer([0, 1], op_flags=[["test"]]) # type: ignore[list-item] # pyright: ignore[reportArgumentType, reportCallIssue] +np.nditer([0, 1], itershape=(1.0,)) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue] +np.nditer([0, 1], buffersize=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue] diff --git a/src/numpy-stubs/_core/_multiarray_umath.pyi b/src/numpy-stubs/_core/_multiarray_umath.pyi index c258dd91..9844728f 100644 --- a/src/numpy-stubs/_core/_multiarray_umath.pyi +++ b/src/numpy-stubs/_core/_multiarray_umath.pyi @@ -101,7 +101,7 @@ _IterFlag: TypeAlias = L[ "reduce_ok", "zerosize_ok", ] -_IterFlagOp: TypeAlias = L[ +_OpFlag: TypeAlias = L[ "readonly", "writeonly", "readwrite", "no_broadcast", "config", @@ -114,6 +114,8 @@ _IterFlagOp: TypeAlias = L[ "overlap_assume_elementwise", "virtual", # undocumented ] # fmt: skip +_OpFlags: TypeAlias = Sequence[_OpFlag] +_OpAxes: TypeAlias = Sequence[CanIndex] | None _ShapeLike1D: TypeAlias = CanIndex | tuple[CanIndex] _ShapeLike2D: TypeAlias = tuple[CanIndex, CanIndex] @@ -308,7 +310,7 @@ class flagsobj: @final class broadcast: @property - def iters(self) -> tuple[flatiter[Incomplete], ...]: ... + def iters(self) -> tuple[flatiter[_nt.Array[Incomplete]], ...]: ... @property def index(self) -> int: ... @property @@ -323,11 +325,11 @@ class broadcast: def shape(self) -> _nt.Shape: ... # - def __new__(cls, *args: npt.ArrayLike) -> Self: ... + def __new__(cls, *args: _nt.ToGeneric_nd) -> Self: ... # - def __next__(self) -> tuple[Incomplete, ...]: ... def __iter__(self) -> Self: ... + def __next__(self) -> tuple[Incomplete, ...]: ... # def reset(self) -> None: ... @@ -370,8 +372,38 @@ class flatiter(Generic[_ArrayT_co]): @final class nditer: + @overload + def __init__( + self, + /, + op: _nt.ToGeneric_nd, + flags: Sequence[_IterFlag] | None = None, + op_flags: _OpFlags | None = None, + op_dtypes: _nt.ToDType | None = None, + order: _OrderKACF = "K", + casting: _CastingKind = "safe", + op_axes: _OpAxes = None, + itershape: _ShapeLike | None = None, + buffersize: CanIndex = 0, + ) -> None: ... + @overload + def __init__( + self, + /, + op: Sequence[_nt.ToGeneric_nd | None], + flags: Sequence[_IterFlag] | None = None, + op_flags: Sequence[_OpFlags] | None = None, + op_dtypes: Sequence[_nt.ToDType | None] | None = None, + order: _OrderKACF = "K", + casting: _CastingKind = "safe", + op_axes: Sequence[_OpAxes] | None = None, + itershape: _ShapeLike | None = None, + buffersize: CanIndex = 0, + ) -> None: ... + + # @property - def dtypes(self) -> tuple[np.dtype, ...]: ... + def dtypes(self) -> tuple[np.dtype[Incomplete], *tuple[np.dtype[Incomplete], ...]]: ... @property def shape(self) -> _nt.Shape: ... @property @@ -393,74 +425,73 @@ class nditer: @property def nop(self) -> int: ... @property - def index(self) -> int: ... + def index(self) -> int: ... # might raise ValueError @property - def multi_index(self) -> tuple[int, ...]: ... + def multi_index(self) -> _nt.Shape: ... # might raise ValueError @property def iterindex(self) -> int: ... @property def itersize(self) -> int: ... @property - def iterrange(self) -> tuple[int, ...]: ... + def iterrange(self) -> tuple[int, int]: ... @property - def itviews(self) -> tuple[_nt.Array[Incomplete], ...]: ... + def itviews(self) -> tuple[_nt.Array[Incomplete], *tuple[_nt.Array[Incomplete], ...]]: ... @property - def operands(self) -> tuple[_nt.Array[Incomplete], ...]: ... - @property - def value(self) -> tuple[_nt.Array[Incomplete], ...]: ... - - # - def __init__( - self, - /, - op: Sequence[npt.ArrayLike | None] | npt.ArrayLike, - flags: Sequence[_IterFlag] | None = None, - op_flags: Sequence[Sequence[_IterFlagOp]] | None = None, - op_dtypes: Sequence[npt.DTypeLike] | npt.DTypeLike = None, - order: _OrderKACF = "K", - casting: _CastingKind = "safe", - op_axes: Sequence[Sequence[CanIndex]] | None = None, - itershape: _ShapeLike | None = None, - buffersize: CanIndex = 0, - ) -> None: ... + def operands(self) -> tuple[_nt.Array[Incomplete], *tuple[_nt.Array[Incomplete], ...]]: ... # def __enter__(self) -> Self: ... def __exit__(self, t: type[BaseException] | None, e: BaseException | None, tb: TracebackType | None, /) -> None: ... def close(self) -> None: ... def reset(self) -> None: ... + def enable_external_loop(self) -> None: ... + def remove_axis(self, i: CanIndex, /) -> None: ... + def remove_multi_index(self) -> None: ... + def debug_print(self) -> None: ... + def iternext(self) -> py_bool: ... + + # + def __copy__(self) -> Self: ... + def copy(self) -> Self: ... # - def __len__(self) -> int: ... def __iter__(self) -> Self: ... - def __next__(self) -> tuple[_nt.Array[Incomplete], ...]: ... - def iternext(self) -> py_bool: ... + + # returns either a single array or a tuple of multiple arrays + def __next__(self) -> _nt.Array[Incomplete] | Incomplete: ... + @property + def value(self) -> _nt.Array[Incomplete] | Incomplete: ... # + def __len__(self) -> int: ... @overload def __getitem__(self, index: CanIndex, /) -> _nt.Array[Incomplete]: ... @overload def __getitem__(self, index: slice, /) -> tuple[_nt.Array[Incomplete], ...]: ... - def __setitem__(self, index: slice | CanIndex, value: npt.ArrayLike, /) -> None: ... - - # - def __copy__(self) -> Self: ... - def copy(self) -> nditer: ... - - # . - def debug_print(self) -> None: ... - def enable_external_loop(self) -> None: ... - - # - def remove_axis(self, i: CanIndex, /) -> None: ... - def remove_multi_index(self) -> None: ... + @overload + def __setitem__(self, index: CanIndex, value: _nt.ToGeneric_nd, /) -> None: ... + @overload + def __setitem__(self, index: slice, value: Sequence[_nt.ToGeneric_nd], /) -> None: ... +# +@overload +def nested_iters( + op: _nt.ToGeneric_nd, + axes: Sequence[Sequence[CanIndex]], + flags: Sequence[_IterFlag] | None = None, + op_flags: _OpFlags | None = None, + op_dtypes: _nt.ToDType | None = None, + order: _OrderKACF = "K", + casting: _CastingKind = "safe", + buffersize: CanIndex = 0, +) -> tuple[nditer, ...]: ... +@overload def nested_iters( - op: Sequence[npt.ArrayLike] | npt.ArrayLike, + op: Sequence[_nt.ToGeneric_nd | None], axes: Sequence[Sequence[CanIndex]], flags: Sequence[_IterFlag] | None = None, - op_flags: Sequence[Sequence[_IterFlagOp]] | None = None, - op_dtypes: Sequence[npt.DTypeLike] | npt.DTypeLike = None, + op_flags: Sequence[_OpFlags] | None = None, + op_dtypes: Sequence[_nt.ToDType | None] | None = None, order: _OrderKACF = "K", casting: _CastingKind = "safe", buffersize: CanIndex = 0,