Skip to content

Commit c68d848

Browse files
committed
feat(typing): allow new_1d to take ndarrays
1 parent 351af38 commit c68d848

2 files changed

Lines changed: 34 additions & 1 deletion

File tree

doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
"np.dtype": "class:numpy.dtype",
4646
"np.ndarray": "class:numpy.ndarray",
4747
"np.floating": "class:numpy.floating",
48+
"np.generic[Any]": "class:numpy.generic",
4849
# pytools typing
4950
"ObjectArray1D": "obj:pytools.obj_array.ObjectArray1D",
5051
"ReadableBuffer": "data:pytools.ReadableBuffer",

pytools/obj_array.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
.. autoclass:: T_co
88
.. autoclass:: ResultT
99
.. autoclass:: ShapeT
10+
.. autoclass:: NumpyTypeT
1011
1112
.. autoclass:: ObjectArray
1213
.. autoclass:: ObjectArray0D
@@ -99,6 +100,7 @@
99100

100101
ResultT = TypeVar("ResultT")
101102
ShapeT = TypeVar("ShapeT", bound=tuple[int, ...])
103+
NumpyTypeT = TypeVar("NumpyTypeT", bound="np.generic[Any]")
102104

103105

104106
class _ObjectArrayMetaclass(type):
@@ -340,7 +342,37 @@ def from_numpy(
340342
return cast("ObjectArray[ShapeT, T_co]", cast("object", ary))
341343

342344

343-
def new_1d(res_list: Sequence[T_co]) -> ObjectArray1D[T_co]:
345+
@overload
346+
def new_1d( # pyright: ignore[reportOverlappingOverload]
347+
res_list: np.ndarray[tuple[int], np.dtype[NumpyTypeT]]
348+
) -> ObjectArray1D[NumpyTypeT]: ...
349+
350+
@overload
351+
def new_1d(
352+
res_list: np.ndarray[tuple[int, int], np.dtype[NumpyTypeT]]
353+
) -> ObjectArray1D[np.ndarray[tuple[int], np.dtype[NumpyTypeT]]]: ...
354+
355+
@overload
356+
def new_1d(
357+
res_list: np.ndarray[tuple[int, int, int], np.dtype[NumpyTypeT]]
358+
) -> ObjectArray1D[np.ndarray[tuple[int, int], np.dtype[NumpyTypeT]]]: ...
359+
360+
@overload
361+
def new_1d(
362+
res_list: np.ndarray[tuple[int, ...], np.dtype[NumpyTypeT]]
363+
) -> ObjectArray1D[np.ndarray[tuple[int, ...], np.dtype[NumpyTypeT]]]: ...
364+
365+
@overload
366+
def new_1d(res_list: Sequence[T_co]) -> ObjectArray1D[T_co]: ...
367+
368+
369+
def new_1d(
370+
res_list: (
371+
Sequence[T_co]
372+
| np.ndarray[tuple[int, ...], np.dtype[NumpyTypeT]]
373+
)
374+
) -> (ObjectArray1D[T_co]
375+
| ObjectArray1D[np.ndarray[tuple[int, ...], np.dtype[NumpyTypeT]]]):
344376
"""Create a one-dimensional object array from *res_list*.
345377
This differs from ``numpy.array(res_list, dtype=object)``
346378
by whether it tries to determine its shape by descending

0 commit comments

Comments
 (0)