Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 54 additions & 69 deletions src/numpy-stubs/linalg/_linalg.pyi
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from collections.abc import Iterable, Sequence
from typing import Any, Generic, Literal as L, NamedTuple, SupportsIndex as CanIndex, SupportsInt, TypeAlias, overload
from typing import Any, Generic, Literal as L, NamedTuple, SupportsIndex, SupportsInt, TypeAlias, overload
from typing_extensions import TypeVar

import _numtype as _nt
import numpy as np
from numpy._core.fromnumeric import matrix_transpose
from numpy._core.numeric import vecdot
from numpy._core.umath import vecdot
from numpy._globals import _NoValueType
from numpy._typing import DTypeLike, _DTypeLike as _ToDType

Expand Down Expand Up @@ -65,27 +65,6 @@ _InexactT_co = TypeVar("_InexactT_co", bound=np.inexact, default=Any, covariant=
_FloatingNDT_co = TypeVar("_FloatingNDT_co", bound=np.floating | _nt.Array[np.floating], default=Any, covariant=True)
_InexactNDT_co = TypeVar("_InexactNDT_co", bound=np.inexact | _nt.Array[np.inexact], default=Any, covariant=True)

_AnyNumberT = TypeVar(
"_AnyNumberT",
np.int8,
np.int16,
np.int32,
np.int64,
np.long,
np.ulong,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
np.float16,
np.float32,
np.float64,
np.longdouble,
np.complex64,
np.complex128,
np.clongdouble,
)

###

_Option: TypeAlias = _T | _NoValueType
Expand All @@ -94,7 +73,7 @@ _False: TypeAlias = L[False]
_True: TypeAlias = L[True]

_Tuple2: TypeAlias = tuple[_T, _T]
_ToInt: TypeAlias = SupportsInt | CanIndex
_ToInt: TypeAlias = SupportsInt | SupportsIndex

_Ax2: TypeAlias = _ToInt | _Tuple2[_ToInt]
_Axes: TypeAlias = Iterable[int]
Expand Down Expand Up @@ -320,9 +299,11 @@ _NegInt: TypeAlias = L[-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -

#
@overload # workaround for microsoft/pyright#10232
def matrix_power(a: _nt.CastsArray[np.float64, _nt.NeitherShape], n: CanIndex) -> _Array2ND[np.float64]: ...
def matrix_power(a: _nt.CastsArray[np.float64, _nt.NeitherShape], n: SupportsIndex) -> _Array2ND[np.float64]: ...
@overload # workaround for microsoft/pyright#10232
def matrix_power(a: _nt.CastsWithArray[np.float64, _NumberT, _nt.NeitherShape], n: CanIndex) -> _Array2ND[_NumberT]: ...
def matrix_power(
a: _nt.CastsWithArray[np.float64, _NumberT, _nt.NeitherShape], n: SupportsIndex
) -> _Array2ND[_NumberT]: ...
@overload
def matrix_power(a: _nt.CanLenArray[_NumberT, _Shape2NDT], n: _PosInt) -> _nt.Array[_NumberT, _Shape2NDT]: ...
@overload
Expand All @@ -332,19 +313,19 @@ def matrix_power(a: _nt.ToInt_1nd, n: _PosInt) -> _Array2ND[np.intp]: ...
@overload
def matrix_power(a: _nt.CoInteger_1nd, n: _NegInt) -> _Array2ND[np.float64]: ...
@overload
def matrix_power(a: _nt.ToFloat64_1nd, n: CanIndex) -> _Array2ND[np.float64]: ...
def matrix_power(a: _nt.ToFloat64_1nd, n: SupportsIndex) -> _Array2ND[np.float64]: ...
@overload
def matrix_power(a: _nt.ToComplex128_1nd, n: CanIndex) -> _Array2ND[np.complex128]: ...
def matrix_power(a: _nt.ToComplex128_1nd, n: SupportsIndex) -> _Array2ND[np.complex128]: ...
@overload
def matrix_power(a: _nt._ToArray_1nd[_Inexact32T], n: CanIndex) -> _Array2ND[_Inexact32T]: ...
def matrix_power(a: _nt._ToArray_1nd[_Inexact32T], n: SupportsIndex) -> _Array2ND[_Inexact32T]: ...
@overload
def matrix_power(a: _nt.ToObject_1nd, n: CanIndex) -> _Array2ND[np.object_]: ...
def matrix_power(a: _nt.ToObject_1nd, n: SupportsIndex) -> _Array2ND[np.object_]: ...
@overload
def matrix_power(a: _nt.ToUInteger_1nd, n: _PosInt) -> _Array2ND[np.unsignedinteger]: ...
@overload
def matrix_power(a: _nt.ToInteger_1nd, n: _PosInt) -> _Array2ND[np.integer]: ...
@overload
def matrix_power(a: _nt.CoComplex_1nd | _nt.ToObject_1nd, n: CanIndex) -> _Array2ND[Any]: ...
def matrix_power(a: _nt.CoComplex_1nd | _nt.ToObject_1nd, n: SupportsIndex) -> _Array2ND[Any]: ...

#
@overload
Expand Down Expand Up @@ -394,9 +375,9 @@ def outer(x1: _nt.ToNumber_1d, x2: _nt.ToNumber_1d, /) -> _nt.Array2D[Any]: ...
@overload # workaround for microsoft/pyright#10232
def multi_dot(arrays: Iterable[_nt._ToArray_nnd[_nt.co_number]], *, out: None = None) -> Any: ...
@overload
def multi_dot(arrays: Iterable[_nt._ToArray_1ds[_AnyNumberT]], *, out: None = None) -> _AnyNumberT: ...
def multi_dot(arrays: Iterable[_nt._ToArray_1ds[_NumberT]], *, out: None = None) -> _NumberT: ...
@overload
def multi_dot(arrays: Iterable[_nt._ToArray_2nd[_AnyNumberT]], *, out: None = None) -> _nt.Array[_AnyNumberT]: ...
def multi_dot(arrays: Iterable[_nt._ToArray_2nd[_NumberT]], *, out: None = None) -> _nt.Array[_NumberT]: ...
@overload
def multi_dot(arrays: Iterable[Sequence[bool]], *, out: None = None) -> np.bool: ...
@overload
Expand All @@ -420,11 +401,7 @@ def multi_dot(
arrays: Iterable[_nt.CoComplex_1nd | _nt.ToTimeDelta_1nd | _nt.ToObject_1nd], *, out: None = None
) -> Any: ...

# pyright false positive in case of typevar constraints
@overload
def cross( # pyright: ignore[reportOverlappingOverload]
x1: _nt._ToArray_1nd[_AnyNumberT], x2: _nt._ToArray_1nd[_AnyNumberT], /, *, axis: int = -1
) -> _nt.Array[_AnyNumberT]: ...
#
@overload
def cross(x1: _nt.ToBool_1nd, x2: _nt.ToBool_1nd, /, *, axis: int = -1) -> _nt.Array[np.bool]: ...
@overload
Expand All @@ -440,6 +417,10 @@ def cross(x1: _nt.ToComplex128_1nd, x2: _nt.CoComplex128_1nd, /, *, axis: int =
@overload
def cross(x1: _nt.CoComplex128_1nd, x2: _nt.ToComplex128_1nd, /, *, axis: int = -1) -> _nt.Array[np.complex128]: ...
@overload
def cross(
x1: _nt._ToArray_1nd[_NumberT], x2: _nt._ToArray_1nd[_NumberT], /, *, axis: int = -1
) -> _nt.Array[_NumberT]: ...
@overload
def cross(x1: _nt.ToInteger_1nd, x2: _nt.CoInteger_1nd, /, *, axis: int = -1) -> _nt.Array[np.integer]: ...
@overload
def cross(x1: _nt.CoInteger_1nd, x2: _nt.ToInteger_1nd, /, *, axis: int = -1) -> _nt.Array[np.integer]: ...
Expand All @@ -454,16 +435,10 @@ def cross(x1: _nt.CoComplex_1nd, x2: _nt.ToComplex_1nd, /, *, axis: int = -1) ->
@overload
def cross(x1: _nt.CoComplex_1nd, x2: _nt.CoComplex_1nd, /, *, axis: int = -1) -> _nt.Array[Any]: ...

# pyright false positive in case of typevar constraints
#
@overload # workaround for microsoft/pyright#10232
def matmul(x1: _nt._ToArray_nnd[_nt.co_number], x2: _nt._ToArray_nnd[_nt.co_number], /) -> Any: ...
@overload
def matmul(x1: _nt._ToArray_1ds[_AnyNumberT], x2: _nt._ToArray_1ds[_AnyNumberT], /) -> _AnyNumberT: ... # pyright: ignore[reportOverlappingOverload]
@overload
def matmul(x1: _nt._ToArray_2nd[_AnyNumberT], x2: _nt._ToArray_1nd[_AnyNumberT], /) -> _nt.Array[_AnyNumberT]: ... # pyright: ignore[reportOverlappingOverload]
@overload
def matmul(x1: _nt._ToArray_1nd[_AnyNumberT], x2: _nt._ToArray_2nd[_AnyNumberT], /) -> _nt.Array[_AnyNumberT]: ... # pyright: ignore[reportOverlappingOverload]
@overload
def matmul(x1: _nt.ToBool_1ds, x2: _nt.ToBool_1ds, /) -> np.bool: ...
@overload
def matmul(x1: _nt.ToBool_2nd, x2: _nt.ToBool_1nd, /) -> _nt.Array[np.bool]: ...
Expand Down Expand Up @@ -494,6 +469,12 @@ def matmul(x1: _nt.ToComplex128_2nd, x2: _nt.CoComplex128_1nd, /) -> _nt.Array[n
@overload
def matmul(x1: _nt.CoComplex128_1nd, x2: _nt.ToComplex128_2nd, /) -> _nt.Array[np.complex128]: ...
@overload
def matmul(x1: _nt._ToArray_1ds[_NumberT], x2: _nt._ToArray_1ds[_NumberT], /) -> _NumberT: ...
@overload
def matmul(x1: _nt._ToArray_2nd[_NumberT], x2: _nt._ToArray_1nd[_NumberT], /) -> _nt.Array[_NumberT]: ...
@overload
def matmul(x1: _nt._ToArray_1nd[_NumberT], x2: _nt._ToArray_2nd[_NumberT], /) -> _nt.Array[_NumberT]: ...
@overload
def matmul(x1: _nt.ToInteger_1ds, x2: _nt.CoInteger_1ds, /) -> np.integer: ...
@overload
def matmul(x1: _nt.CoInteger_1ds, x2: _nt.ToInteger_1ds, /) -> np.integer: ...
Expand Down Expand Up @@ -942,56 +923,60 @@ def vector_norm(

#
@overload
def diagonal(x: _nt.ToObject_2nd, /, *, offset: CanIndex = 0) -> _nt.Array[np.object_]: ...
def diagonal(x: _nt.ToObject_2nd, /, *, offset: SupportsIndex = 0) -> _nt.Array[np.object_]: ...
@overload
def diagonal(x: _nt._ToArray_2ds[_NativeScalarT], /, *, offset: CanIndex = 0) -> _nt.Array1D[_NativeScalarT]: ...
def diagonal(x: _nt._ToArray_2ds[_NativeScalarT], /, *, offset: SupportsIndex = 0) -> _nt.Array1D[_NativeScalarT]: ...
@overload
def diagonal(x: _ToArray_2nd_ish[_NativeScalarT], /, *, offset: CanIndex = 0) -> _nt.Array[_NativeScalarT]: ...
def diagonal(x: _ToArray_2nd_ish[_NativeScalarT], /, *, offset: SupportsIndex = 0) -> _nt.Array[_NativeScalarT]: ...
@overload
def diagonal(x: _nt.Sequence2ND[bool], /, *, offset: CanIndex = 0) -> _nt.Array[np.bool]: ...
def diagonal(x: _nt.Sequence2ND[bool], /, *, offset: SupportsIndex = 0) -> _nt.Array[np.bool]: ...
@overload
def diagonal(x: _nt.Sequence2ND[_nt.JustInt], /, *, offset: CanIndex = 0) -> _nt.Array[np.intp]: ...
def diagonal(x: _nt.Sequence2ND[_nt.JustInt], /, *, offset: SupportsIndex = 0) -> _nt.Array[np.intp]: ...
@overload
def diagonal(x: _nt.Sequence2ND[_nt.JustFloat], /, *, offset: CanIndex = 0) -> _nt.Array[np.float64]: ...
def diagonal(x: _nt.Sequence2ND[_nt.JustFloat], /, *, offset: SupportsIndex = 0) -> _nt.Array[np.float64]: ...
@overload
def diagonal(x: _nt.Sequence2ND[_nt.JustComplex], /, *, offset: CanIndex = 0) -> _nt.Array[np.complex128]: ...
def diagonal(x: _nt.Sequence2ND[_nt.JustComplex], /, *, offset: SupportsIndex = 0) -> _nt.Array[np.complex128]: ...
@overload
def diagonal(x: _nt.Sequence2ND[_nt.JustBytes], /, *, offset: CanIndex = 0) -> _nt.Array[np.bytes_]: ...
def diagonal(x: _nt.Sequence2ND[_nt.JustBytes], /, *, offset: SupportsIndex = 0) -> _nt.Array[np.bytes_]: ...
@overload
def diagonal(x: _nt.Sequence2ND[_nt.JustStr], /, *, offset: CanIndex = 0) -> _nt.Array[np.str_]: ...
def diagonal(x: _nt.Sequence2ND[_nt.JustStr], /, *, offset: SupportsIndex = 0) -> _nt.Array[np.str_]: ...
@overload
def diagonal(x: _nt.ToGeneric_1nd, /, *, offset: CanIndex = 0) -> _nt.Array[Any]: ...
def diagonal(x: _nt.ToGeneric_1nd, /, *, offset: SupportsIndex = 0) -> _nt.Array[Any]: ...

#
@overload
def trace(x: _nt._ToArray_2ds[_ScalarT], /, *, offset: CanIndex = 0, dtype: None = None) -> _ScalarT: ...
def trace(x: _nt._ToArray_2ds[_ScalarT], /, *, offset: SupportsIndex = 0, dtype: None = None) -> _ScalarT: ...
@overload
def trace(x: _nt._ToArray_3nd[_ScalarT], /, *, offset: CanIndex = 0, dtype: None = None) -> _nt.Array[_ScalarT]: ...
def trace(
x: _nt._ToArray_3nd[_ScalarT], /, *, offset: SupportsIndex = 0, dtype: None = None
) -> _nt.Array[_ScalarT]: ...
@overload
def trace(x: _nt.Sequence2D[bool], /, *, offset: CanIndex = 0, dtype: None = None) -> np.bool: ...
def trace(x: _nt.Sequence2D[bool], /, *, offset: SupportsIndex = 0, dtype: None = None) -> np.bool: ...
@overload
def trace(x: _nt.Sequence3ND[bool], /, *, offset: CanIndex = 0, dtype: None = None) -> _nt.Array[np.bool]: ...
def trace(x: _nt.Sequence3ND[bool], /, *, offset: SupportsIndex = 0, dtype: None = None) -> _nt.Array[np.bool]: ...
@overload
def trace(x: _nt.Sequence2D[_nt.JustInt], /, *, offset: CanIndex = 0, dtype: None = None) -> np.intp: ...
def trace(x: _nt.Sequence2D[_nt.JustInt], /, *, offset: SupportsIndex = 0, dtype: None = None) -> np.intp: ...
@overload
def trace(x: _nt.Sequence3ND[_nt.JustInt], /, *, offset: CanIndex = 0, dtype: None = None) -> _nt.Array[np.intp]: ...
def trace(
x: _nt.Sequence3ND[_nt.JustInt], /, *, offset: SupportsIndex = 0, dtype: None = None
) -> _nt.Array[np.intp]: ...
@overload
def trace(x: _nt.Sequence2D[_nt.JustFloat], /, *, offset: CanIndex = 0, dtype: None = None) -> np.float64: ...
def trace(x: _nt.Sequence2D[_nt.JustFloat], /, *, offset: SupportsIndex = 0, dtype: None = None) -> np.float64: ...
@overload
def trace(
x: _nt.Sequence3ND[_nt.JustFloat], /, *, offset: CanIndex = 0, dtype: None = None
x: _nt.Sequence3ND[_nt.JustFloat], /, *, offset: SupportsIndex = 0, dtype: None = None
) -> _nt.Array[np.float64]: ...
@overload
def trace(x: _nt.Sequence2D[_nt.JustComplex], /, *, offset: CanIndex = 0, dtype: None = None) -> np.complex128: ...
def trace(x: _nt.Sequence2D[_nt.JustComplex], /, *, offset: SupportsIndex = 0, dtype: None = None) -> np.complex128: ...
@overload
def trace(
x: _nt.Sequence3ND[_nt.JustComplex], /, *, offset: CanIndex = 0, dtype: None = None
x: _nt.Sequence3ND[_nt.JustComplex], /, *, offset: SupportsIndex = 0, dtype: None = None
) -> _nt.Array[np.complex128]: ...
@overload
def trace(x: _nt.CoComplex_2ds, /, *, offset: CanIndex = 0, dtype: _ToDType[_ScalarT]) -> _ScalarT: ...
def trace(x: _nt.CoComplex_2ds, /, *, offset: SupportsIndex = 0, dtype: _ToDType[_ScalarT]) -> _ScalarT: ...
@overload
def trace(x: _nt.CoComplex_3nd, /, *, offset: CanIndex = 0, dtype: _ToDType[_ScalarT]) -> _nt.Array[_ScalarT]: ...
def trace(x: _nt.CoComplex_3nd, /, *, offset: SupportsIndex = 0, dtype: _ToDType[_ScalarT]) -> _nt.Array[_ScalarT]: ...
@overload
def trace(x: _nt.CoComplex_3nd, /, *, offset: CanIndex = 0, dtype: DTypeLike | None = None) -> _nt.Array[Any]: ...
def trace(x: _nt.CoComplex_3nd, /, *, offset: SupportsIndex = 0, dtype: DTypeLike | None = None) -> _nt.Array[Any]: ...
@overload
def trace(x: _nt.CoComplex_1nd, /, *, offset: CanIndex = 0, dtype: DTypeLike | None = None) -> Any: ...
def trace(x: _nt.CoComplex_1nd, /, *, offset: SupportsIndex = 0, dtype: DTypeLike | None = None) -> Any: ...
Loading