From 46d32553c66632863976cce391d5807a9fe5b101 Mon Sep 17 00:00:00 2001 From: jorenham Date: Sat, 27 Dec 2025 15:44:44 +0100 Subject: [PATCH 1/3] =?UTF-8?q?=E2=9C=A8=20`ndarray`=20shape-type=20gradua?= =?UTF-8?q?l=20default?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/numpy-stubs/__init__.pyi | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/numpy-stubs/__init__.pyi b/src/numpy-stubs/__init__.pyi index 67e92a25..895be011 100644 --- a/src/numpy-stubs/__init__.pyi +++ b/src/numpy-stubs/__init__.pyi @@ -628,8 +628,9 @@ _NumericArrayT = TypeVar("_NumericArrayT", bound=_nt.Array[number | timedelta64 _ShapeT = TypeVar("_ShapeT", bound=_nt.Shape) _ShapeT2 = TypeVar("_ShapeT2", bound=_nt.Shape) -# TODO(jorenham): use `Shape` instead of `AnyShape` once python/mypy#19110 is fixed -_ShapeT_co = TypeVar("_ShapeT_co", bound=_nt.AnyShape, covariant=True) +# TODO(jorenham): use `Shape` instead of `AnyShape` as bound once python/mypy#19110 is fixed +_ShapeT_co = TypeVar("_ShapeT_co", bound=_nt.AnyShape, default=_nt.AnyShape, covariant=True) +_ShapeT0_co = TypeVar("_ShapeT0_co", bound=_nt.AnyShape, covariant=True) _Shape1NDT = TypeVar("_Shape1NDT", bound=_nt.Shape1N) _ScalarT = TypeVar("_ScalarT", bound=generic) @@ -884,9 +885,9 @@ class _CanItem(Protocol[_T_co]): def item(self, /) -> _T_co: ... @type_check_only -class _HasShapeAndItem(Protocol[_ShapeT_co, _T_co]): +class _HasShapeAndItem(Protocol[_ShapeT0_co, _T_co]): @property - def __inner_shape__(self, /) -> _ShapeT_co: ... + def __inner_shape__(self, /) -> _ShapeT0_co: ... def item(self, /) -> _T_co: ... @type_check_only @@ -895,9 +896,9 @@ class _HasDType(Protocol[_T_co]): def dtype(self, /) -> _T_co: ... @type_check_only -class _HasShapeAndDType(Protocol[_ShapeT_co, _T_co]): +class _HasShapeAndDType(Protocol[_ShapeT0_co, _T_co]): @property - def __inner_shape__(self, /) -> _ShapeT_co: ... + def __inner_shape__(self, /) -> _ShapeT0_co: ... @property def dtype(self, /) -> _T_co: ... From 70340619561bd85640af5d24354840fd7ee35950 Mon Sep 17 00:00:00 2001 From: jorenham Date: Sat, 27 Dec 2025 16:42:09 +0100 Subject: [PATCH 2/3] =?UTF-8?q?=F0=9F=90=9F=20ignore=20weird=20new=20false?= =?UTF-8?q?=20positive=20pyright=20errors?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/numpy-stubs/__init__.pyi | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/src/numpy-stubs/__init__.pyi b/src/numpy-stubs/__init__.pyi index 895be011..9569f6c7 100644 --- a/src/numpy-stubs/__init__.pyi +++ b/src/numpy-stubs/__init__.pyi @@ -1921,6 +1921,9 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]): def __pos__(self: _NumericArrayT, /) -> _NumericArrayT: ... # noqa: PYI019 def __invert__(self: _IntegralArrayT, /) -> _IntegralArrayT: ... # noqa: PYI019 + # NOTE: The pyright `reportOverlappingOverload` errors below are false positives that + # started appearing after adding a default to `_ShapeT_co` + # @overload def __add__(self: _nt.Array[_ScalarT], x: _nt.Casts[_ScalarT], /) -> _nt.Array[_ScalarT]: ... @@ -1937,11 +1940,11 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]): @overload def __add__(self: _nt.Array[datetime64], x: _nt.CoTimeDelta_nd, /) -> _nt.Array[datetime64]: ... @overload - def __add__(self: _nt.Array[_nt.co_timedelta], x: _nt.ToDateTime_nd, /) -> _nt.Array[datetime64]: ... + def __add__(self: _nt.Array[_nt.co_timedelta], x: _nt.ToDateTime_nd, /) -> _nt.Array[datetime64]: ... # pyright: ignore[reportOverlappingOverload] @overload - def __add__(self: _nt.Array[object_, Any], x: object, /) -> _nt.Array[object_]: ... # type: ignore[overload-cannot-match] + def __add__(self: _nt.Array[object_, Any], x: object, /) -> _nt.Array[object_]: ... # type: ignore[overload-cannot-match] # pyright: ignore[reportOverlappingOverload] @overload - def __add__(self: _nt.Array[str_], x: _nt.ToString_nd[_T], /) -> _nt.StringArrayND[_T]: ... + def __add__(self: _nt.Array[str_], x: _nt.ToString_nd[_T], /) -> _nt.StringArrayND[_T]: ... # pyright: ignore[reportOverlappingOverload] @overload def __add__(self: _nt.StringArrayND[_T], x: _nt.ToString_nd[_T] | _nt.ToStr_nd, /) -> _nt.StringArrayND[_T]: ... @overload @@ -1965,11 +1968,11 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]): @overload def __radd__(self: _nt.Array[datetime64], x: _nt.CoTimeDelta_nd, /) -> _nt.Array[datetime64]: ... @overload - def __radd__(self: _nt.Array[_nt.co_timedelta], x: _nt.ToDateTime_nd, /) -> _nt.Array[datetime64]: ... + def __radd__(self: _nt.Array[_nt.co_timedelta], x: _nt.ToDateTime_nd, /) -> _nt.Array[datetime64]: ... # pyright: ignore[reportOverlappingOverload] @overload - def __radd__(self: _nt.Array[object_, Any], x: object, /) -> _nt.Array[object_]: ... # type: ignore[overload-cannot-match] + def __radd__(self: _nt.Array[object_, Any], x: object, /) -> _nt.Array[object_]: ... # type: ignore[overload-cannot-match] # pyright: ignore[reportOverlappingOverload] @overload - def __radd__(self: _nt.Array[str_], x: _nt.ToString_nd[_T], /) -> _nt.StringArrayND[_T]: ... + def __radd__(self: _nt.Array[str_], x: _nt.ToString_nd[_T], /) -> _nt.StringArrayND[_T]: ... # pyright: ignore[reportOverlappingOverload] @overload def __radd__(self: _nt.StringArrayND[_T], x: _nt.ToString_nd[_T] | _nt.ToStr_nd, /) -> _nt.StringArrayND[_T]: ... @overload @@ -2021,7 +2024,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]): @overload def __sub__(self: _nt.Array[object_], x: object, /) -> _nt.Array[object_]: ... @overload - def __sub__( + def __sub__( # pyright: ignore[reportOverlappingOverload] self: _nt.Array[number[_AnyNumberItemT]], x: _nt.Sequence1ND[_nt.op.CanRSub[_AnyNumberItemT]], / ) -> _nt.Array[Incomplete]: ... @@ -2045,7 +2048,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]): @overload def __rsub__(self: _nt.Array[object_], x: object, /) -> _nt.Array[object_]: ... @overload - def __rsub__( + def __rsub__( # pyright: ignore[reportOverlappingOverload] self: _nt.Array[number[_AnyNumberItemT]], x: _nt.Sequence1ND[_nt.op.CanSub[_AnyNumberItemT]], / ) -> _nt.Array[Incomplete]: ... @@ -2083,11 +2086,11 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]): @overload def __mul__(self: _nt.CastsWithComplex[_ScalarT], x: _PyComplexND, /) -> _nt.Array[_ScalarT]: ... @overload - def __mul__(self: _nt.Array[timedelta64], x: _nt.ToFloating_nd, /) -> _nt.Array[timedelta64]: ... + def __mul__(self: _nt.Array[timedelta64], x: _nt.ToFloating_nd, /) -> _nt.Array[timedelta64]: ... # pyright: ignore[reportOverlappingOverload] @overload - def __mul__(self: _nt.Array[object_, Any], x: object, /) -> _nt.Array[object_]: ... # type: ignore[overload-cannot-match] + def __mul__(self: _nt.Array[object_, Any], x: object, /) -> _nt.Array[object_]: ... # type: ignore[overload-cannot-match] # pyright: ignore[reportOverlappingOverload] @overload - def __mul__(self: _nt.Array[integer], x: _nt.ToString_nd, /) -> _nt.StringArrayND[_T]: ... + def __mul__(self: _nt.Array[integer], x: _nt.ToString_nd, /) -> _nt.StringArrayND[_T]: ... # pyright: ignore[reportOverlappingOverload] @overload def __mul__(self: _nt.StringArrayND[_T], x: _nt.ToInteger_nd, /) -> _nt.StringArrayND[_T]: ... @overload @@ -2109,11 +2112,11 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]): @overload def __rmul__(self: _nt.CastsWithComplex[_ScalarT], x: _PyComplexND, /) -> _nt.Array[_ScalarT]: ... @overload - def __rmul__(self: _nt.Array[timedelta64], x: _nt.ToFloating_nd, /) -> _nt.Array[timedelta64]: ... + def __rmul__(self: _nt.Array[timedelta64], x: _nt.ToFloating_nd, /) -> _nt.Array[timedelta64]: ... # pyright: ignore[reportOverlappingOverload] @overload - def __rmul__(self: _nt.Array[object_, Any], x: object, /) -> _nt.Array[object_]: ... # type: ignore[overload-cannot-match] + def __rmul__(self: _nt.Array[object_, Any], x: object, /) -> _nt.Array[object_]: ... # type: ignore[overload-cannot-match] # pyright: ignore[reportOverlappingOverload] @overload - def __rmul__(self: _nt.Array[integer], x: _nt.ToString_nd, /) -> _nt.StringArrayND[_T]: ... + def __rmul__(self: _nt.Array[integer], x: _nt.ToString_nd, /) -> _nt.StringArrayND[_T]: ... # pyright: ignore[reportOverlappingOverload] @overload def __rmul__(self: _nt.StringArrayND[_T], x: _nt.ToInteger_nd, /) -> _nt.StringArrayND[_T]: ... @overload From 72b2d64f46a4832a52e396be14bb9545c0306fae Mon Sep 17 00:00:00 2001 From: jorenham Date: Sat, 27 Dec 2025 16:43:03 +0100 Subject: [PATCH 3/3] =?UTF-8?q?=F0=9F=90=9F=20work=20around=20a=20weird=20?= =?UTF-8?q?pyright=20inference=20bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/numpy-stubs/@test/runtime/legacy/mod.py | 8 ++++---- src/numpy-stubs/@test/runtime/legacy/simple.py | 2 +- src/numpy-stubs/@test/static/accept/mod.pyi | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/numpy-stubs/@test/runtime/legacy/mod.py b/src/numpy-stubs/@test/runtime/legacy/mod.py index 6f2914b6..a2115c37 100644 --- a/src/numpy-stubs/@test/runtime/legacy/mod.py +++ b/src/numpy-stubs/@test/runtime/legacy/mod.py @@ -28,7 +28,7 @@ AR2 % td divmod(td, td) # pyright: ignore[reportArgumentType, reportCallIssue] # microsoft/pyright#10899 -divmod(td, AR2) +divmod(td, AR2) # pyright: ignore[reportCallIssue] # microsoft/pyright#10899 divmod(AR2, td) # Bool @@ -49,7 +49,7 @@ divmod(b, i8) divmod(b, u8) divmod(b_, f8) -divmod(b_, AR) +divmod(b_, AR) # pyright: ignore[reportCallIssue] # microsoft/pyright#10899 b % b_ i % b_ @@ -91,7 +91,7 @@ divmod(i8, f4) divmod(i4, i4) divmod(i4, f4) -divmod(i8, AR) +divmod(i8, AR) # pyright: ignore[reportCallIssue] # microsoft/pyright#10899 b % i8 i % i8 @@ -130,7 +130,7 @@ divmod(f8, f8) divmod(f8, f4) divmod(f4, f4) -divmod(f8, AR) +divmod(f8, AR) # pyright: ignore[reportCallIssue] # microsoft/pyright#10899 b % f8 i % f8 diff --git a/src/numpy-stubs/@test/runtime/legacy/simple.py b/src/numpy-stubs/@test/runtime/legacy/simple.py index b97a3eef..67f8f9ca 100644 --- a/src/numpy-stubs/@test/runtime/legacy/simple.py +++ b/src/numpy-stubs/@test/runtime/legacy/simple.py @@ -134,7 +134,7 @@ def iterable_func(x: Iterable[object]) -> Iterable[object]: array %= 1 divmod(array, 1) # pyright: ignore[reportArgumentType, reportCallIssue] # microsoft/pyright#10899 -divmod(1, nonzero_array) +divmod(1, nonzero_array) # pyright: ignore[reportCallIssue] # microsoft/pyright#10899 array**1 1**array diff --git a/src/numpy-stubs/@test/static/accept/mod.pyi b/src/numpy-stubs/@test/static/accept/mod.pyi index bed5f49f..38028894 100644 --- a/src/numpy-stubs/@test/static/accept/mod.pyi +++ b/src/numpy-stubs/@test/static/accept/mod.pyi @@ -71,7 +71,7 @@ assert_type(divmod(m_td, m_int), tuple[np.int64, np.timedelta64[int | None]]) # assert_type(divmod(m_td, m_td), tuple[np.int64, np.timedelta64[dt.timedelta | None]]) # pyright: ignore[reportArgumentType, reportAssertTypeFailure, reportCallIssue] assert_type(divmod(AR_m, m), tuple[_nt.Array[np.int64], _nt.Array[np.timedelta64]]) -assert_type(divmod(m, AR_m), tuple[_nt.Array[np.int64], _nt.Array[np.timedelta64]]) +assert_type(divmod(m, AR_m), tuple[_nt.Array[np.int64], _nt.Array[np.timedelta64]]) # pyright: ignore[reportAssertTypeFailure, reportCallIssue] # Bool assert_type(b_ % b, np.int8) @@ -90,7 +90,7 @@ assert_type(divmod(b_, f), tuple[np.float64, np.float64]) # pyright: ignore[rep assert_type(divmod(b_, i8), tuple[np.int64, np.int64]) assert_type(divmod(b_, u8), tuple[np.uint64, np.uint64]) assert_type(divmod(b_, f8), tuple[np.float64, np.float64]) -assert_type(divmod(b_, AR_b), tuple[_nt.Array[np.int8], _nt.Array[np.int8]]) +assert_type(divmod(b_, AR_b), tuple[_nt.Array[np.int8], _nt.Array[np.int8]]) # pyright: ignore[reportAssertTypeFailure, reportCallIssue] assert_type(b % b_, np.int8) assert_type(i % b_, np.intp) @@ -130,7 +130,7 @@ assert_type(divmod(i8, i8), tuple[np.int64, np.int64]) assert_type(divmod(i8, f8), tuple[np.float64, np.float64]) assert_type(divmod(i8, f4), tuple[np.float64, np.float64]) assert_type(divmod(i4, i4), tuple[np.int32, np.int32]) -assert_type(divmod(i8, AR_b), tuple[_nt.Array[np.int64], _nt.Array[np.int64]]) +assert_type(divmod(i8, AR_b), tuple[_nt.Array[np.int64], _nt.Array[np.int64]]) # pyright: ignore[reportAssertTypeFailure, reportCallIssue] assert_type(b % i8, np.int64) assert_type(f % i8, np.float64) @@ -165,7 +165,7 @@ assert_type(divmod(f8, f), tuple[np.float64, np.float64]) assert_type(divmod(f8, f8), tuple[np.float64, np.float64]) assert_type(divmod(f8, f4), tuple[np.float64, np.float64]) assert_type(divmod(f4, f4), tuple[np.float32, np.float32]) -assert_type(divmod(f8, AR_b), tuple[_nt.Array[np.float64], _nt.Array[np.float64]]) +assert_type(divmod(f8, AR_b), tuple[_nt.Array[np.float64], _nt.Array[np.float64]]) # pyright: ignore[reportAssertTypeFailure, reportCallIssue] assert_type(b % f8, np.float64) assert_type(f % f8, np.float64) # pyright: ignore[reportAssertTypeFailure] # pyright incorrectly infers `float`