Skip to content

Commit 22649a0

Browse files
bpr things
1 parent dc62cdc commit 22649a0

9 files changed

Lines changed: 18 additions & 16 deletions

File tree

arraycontext/container/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@
9494

9595

9696
if TYPE_CHECKING:
97+
from typing import Any
98+
9799
from pymbolic.geometric_algebra import MultiVector
98100

99101
from arraycontext import ArrayOrContainer
@@ -283,7 +285,7 @@ def get_container_context_opt(ary: ArrayContainer) -> ArrayContext | None:
283285

284286
@serialize_container.register(np.ndarray)
285287
def _serialize_ndarray_container(
286-
ary: numpy.ndarray) -> SerializedContainer:
288+
ary: numpy.ndarray[Any, Any]) -> SerializedContainer:
287289
if ary.dtype.char != "O":
288290
raise NotAnArrayContainerError(
289291
f"cannot serialize '{type(ary).__name__}' with dtype '{ary.dtype}'")
@@ -303,8 +305,8 @@ def _serialize_ndarray_container(
303305
@deserialize_container.register(np.ndarray)
304306
# https://github.com/python/mypy/issues/13040
305307
def _deserialize_ndarray_container( # type: ignore[misc]
306-
template: numpy.ndarray,
307-
serialized: SerializedContainer) -> numpy.ndarray:
308+
template: numpy.ndarray[Any, Any],
309+
serialized: SerializedContainer) -> numpy.ndarray[Any, Any]:
308310
# disallow subclasses
309311
assert type(template) is np.ndarray
310312
assert template.dtype.char == "O"

arraycontext/container/traversal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -917,7 +917,7 @@ def _flat_size(subary: ArrayOrContainer) -> Array | Integer:
917917
# {{{ numpy conversion
918918

919919
def from_numpy(
920-
ary: np.ndarray | ScalarLike,
920+
ary: np.ndarray[Any, Any] | ScalarLike,
921921
actx: ArrayContext) -> ArrayOrContainerOrScalar:
922922
"""Convert all :mod:`numpy` arrays in the :class:`~arraycontext.ArrayContainer`
923923
to the base array type of :class:`~arraycontext.ArrayContext`.

arraycontext/context.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def __rtruediv__(self, other: Self | ScalarLike) -> Array: ...
282282
ContainerOrScalarT = TypeVar("ContainerOrScalarT", bound="ArrayContainer | ScalarLike")
283283

284284

285-
NumpyOrContainerOrScalar = Union[np.ndarray, "ArrayContainer", ScalarLike]
285+
NumpyOrContainerOrScalar = Union[np.ndarray[Any, Any], "ArrayContainer", ScalarLike]
286286

287287
# }}}
288288

@@ -358,7 +358,7 @@ def zeros(self,
358358
return self.np.zeros(shape, dtype)
359359

360360
@overload
361-
def from_numpy(self, array: np.ndarray) -> Array:
361+
def from_numpy(self, array: np.ndarray[Any, Any]) -> Array:
362362
...
363363

364364
@overload
@@ -379,7 +379,7 @@ def from_numpy(self,
379379
"""
380380

381381
@overload
382-
def to_numpy(self, array: Array) -> np.ndarray:
382+
def to_numpy(self, array: Array) -> np.ndarray[Any, Any]:
383383
...
384384

385385
@overload

arraycontext/impl/cupy/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def clone(self):
7676
return type(self)()
7777

7878
@overload
79-
def from_numpy(self, array: np.ndarray) -> Array:
79+
def from_numpy(self, array: np.ndarray[Any, Any]) -> Array:
8080
...
8181

8282
@overload
@@ -91,7 +91,7 @@ def from_numpy(self,
9191
actx=self)
9292

9393
@overload
94-
def to_numpy(self, array: Array) -> np.ndarray:
94+
def to_numpy(self, array: Array) -> np.ndarray[Any, Any]:
9595
...
9696

9797
@overload

arraycontext/impl/cupy/fake_numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def linspace(self, *args, **kwargs):
178178
import cupy as cp
179179
return cp.linspace(*args, **kwargs)
180180

181-
def zeros_like(self, ary):
181+
def zeros_like(self, ary): # pyright: ignore[reportIncompatibleMethodOverride]
182182
import cupy as cp
183183
if isinstance(ary, int | float | complex):
184184
# Cupy does not support zeros_like with scalar arguments

arraycontext/impl/numpy/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def clone(self):
8686
return type(self)()
8787

8888
@overload
89-
def from_numpy(self, array: np.ndarray) -> Array:
89+
def from_numpy(self, array: np.ndarray[Any, Any]) -> Array:
9090
...
9191

9292
@overload
@@ -99,7 +99,7 @@ def from_numpy(self,
9999
return array
100100

101101
@overload
102-
def to_numpy(self, array: Array) -> np.ndarray:
102+
def to_numpy(self, array: Array) -> np.ndarray[Any, Any]:
103103
...
104104

105105
@overload

arraycontext/impl/numpy/fake_numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
150150
return false_ary
151151
return np.logical_and.reduce(
152152
[(true_ary if kx_i == ky_i else false_ary)
153-
and cast(np.ndarray, self.array_equal(x_i, y_i))
153+
and cast(np.ndarray, self.array_equal(x_i, y_i)) # pyright: ignore[reportMissingTypeArgument]
154154
for (kx_i, x_i), (ky_i, y_i)
155155
in zip(serialized_x, serialized_y, strict=True)],
156156
initial=true_ary)

arraycontext/impl/pyopencl/taggable_cl_array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def zeros(
212212

213213
def to_device(
214214
queue: cl.CommandQueue,
215-
ary: np.ndarray[Any],
215+
ary: np.ndarray[Any, Any],
216216
*, axes: tuple[Axis, ...] | None = None,
217217
tags: frozenset[Tag] = _EMPTY_TAG_SET,
218218
allocator: cla.Allocator | None = None,

arraycontext/impl/pytato/compile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,12 @@ def __hash__(self):
100100

101101
@dataclass(frozen=True, eq=True)
102102
class ScalarInputDescriptor(AbstractInputDescriptor):
103-
dtype: np.dtype
103+
dtype: np.dtype[Any]
104104

105105

106106
@dataclass(frozen=True, eq=True)
107107
class LeafArrayDescriptor(AbstractInputDescriptor):
108-
dtype: np.dtype
108+
dtype: np.dtype[Any]
109109
shape: pt.array.ShapeType
110110

111111
# }}}

0 commit comments

Comments
 (0)