Skip to content

Commit 9080ebb

Browse files
committed
feat(typing): allow new_1d to take ndarrays
1 parent 351af38 commit 9080ebb

2 files changed

Lines changed: 20 additions & 1 deletion

File tree

doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
"np.dtype": "class:numpy.dtype",
4646
"np.ndarray": "class:numpy.ndarray",
4747
"np.floating": "class:numpy.floating",
48+
"np.generic[Any]": "class:numpy.generic",
4849
# pytools typing
4950
"ObjectArray1D": "obj:pytools.obj_array.ObjectArray1D",
5051
"ReadableBuffer": "data:pytools.ReadableBuffer",

pytools/obj_array.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
.. autoclass:: T_co
88
.. autoclass:: ResultT
99
.. autoclass:: ShapeT
10+
.. autoclass:: NumpyTypeT
1011
1112
.. autoclass:: ObjectArray
1213
.. autoclass:: ObjectArray0D
@@ -99,6 +100,7 @@
99100

100101
ResultT = TypeVar("ResultT")
101102
ShapeT = TypeVar("ShapeT", bound=tuple[int, ...])
103+
NumpyTypeT = TypeVar("NumpyTypeT", bound="np.generic[Any]")
102104

103105

104106
class _ObjectArrayMetaclass(type):
@@ -340,7 +342,23 @@ def from_numpy(
340342
return cast("ObjectArray[ShapeT, T_co]", cast("object", ary))
341343

342344

343-
def new_1d(res_list: Sequence[T_co]) -> ObjectArray1D[T_co]:
345+
@overload
346+
def new_1d(
347+
res_list: np.ndarray[tuple[int, ...], np.dtype[NumpyTypeT]]
348+
) -> ObjectArray1D[np.ndarray[tuple[int, ...], np.dtype[NumpyTypeT]]]: ...
349+
350+
351+
@overload
352+
def new_1d(res_list: Sequence[T_co]) -> ObjectArray1D[T_co]: ...
353+
354+
355+
def new_1d(
356+
res_list: (
357+
Sequence[T_co]
358+
| np.ndarray[tuple[int, ...], np.dtype[NumpyTypeT]]
359+
)
360+
) -> (ObjectArray1D[T_co]
361+
| ObjectArray1D[np.ndarray[tuple[int, ...], np.dtype[NumpyTypeT]]]):
344362
"""Create a one-dimensional object array from *res_list*.
345363
This differs from ``numpy.array(res_list, dtype=object)``
346364
by whether it tries to determine its shape by descending

0 commit comments

Comments
 (0)