Skip to content

Commit d6c090a

Browse files
inducermajosm
andcommitted
Misc typing fixes
Co-authored-by: Matt Smith <mjsmith6@illinois.edu>
1 parent 7d1c865 commit d6c090a

3 files changed

Lines changed: 60 additions & 35 deletions

File tree

arraycontext/impl/pytato/__init__.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,11 @@ def __init__(
163163
"""
164164
super().__init__()
165165

166-
self._freeze_prg_cache: dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {}
166+
self._freeze_prg_cache: dict[
167+
pt.AbstractResultWithNamedArrays, lp.TranslationUnit] = {}
167168
self._dag_transform_cache: dict[
168-
pt.DictOfNamedArrays,
169-
tuple[pt.DictOfNamedArrays, str]] = {}
169+
pt.AbstractResultWithNamedArrays,
170+
tuple[pt.AbstractResultWithNamedArrays, str]] = {}
170171

171172
if compile_trace_callback is None:
172173
def _compile_trace_callback(what, stage, ir):
@@ -226,8 +227,8 @@ def _tag_axis(ary: ArrayOrScalar) -> ArrayOrScalar:
226227

227228
# {{{ compilation
228229

229-
def transform_dag(self, dag: pytato.DictOfNamedArrays
230-
) -> pytato.DictOfNamedArrays:
230+
def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
231+
) -> pytato.AbstractResultWithNamedArrays:
231232
"""
232233
Returns a transformed version of *dag*. Sub-classes are supposed to
233234
override this method to implement context-specific transformations on
@@ -278,11 +279,12 @@ def get_target(self):
278279

279280
# }}}
280281

282+
@override
281283
def outline(self,
282284
f: Callable[..., Any],
283285
*,
284286
id: Hashable | None = None,
285-
tags: frozenset[Tag] = frozenset()
287+
tags: frozenset[Tag] = frozenset() # pyright: ignore[reportCallInDefaultInitializer]
286288
) -> Callable[..., Any]:
287289
from pytato.tags import FunctionIdentifier
288290

@@ -620,12 +622,10 @@ def _to_frozen(
620622
pt.make_dict_of_named_arrays(key_to_pt_arrays))
621623

622624
# FIXME: Remove this if/when _normalize_pt_expr gets support for functions
623-
pt_dict_of_named_arrays = pt.tag_all_calls_to_be_inlined(
624-
pt_dict_of_named_arrays)
625-
pt_dict_of_named_arrays = pt.inline_calls(pt_dict_of_named_arrays)
625+
dag = pt.tag_all_calls_to_be_inlined(dag)
626+
dag = pt.inline_calls(dag)
626627

627-
normalized_expr, bound_arguments = _normalize_pt_expr(
628-
pt_dict_of_named_arrays)
628+
normalized_expr, bound_arguments = _normalize_pt_expr(dag)
629629

630630
try:
631631
pt_prg = self._freeze_prg_cache[normalized_expr]
@@ -771,8 +771,8 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
771771
from .compile import LazilyPyOpenCLCompilingFunctionCaller
772772
return LazilyPyOpenCLCompilingFunctionCaller(self, f)
773773

774-
def transform_dag(self, dag: pytato.DictOfNamedArrays
775-
) -> pytato.DictOfNamedArrays:
774+
def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
775+
) -> pytato.AbstractResultWithNamedArrays:
776776
import pytato as pt
777777
dag = pt.tag_all_calls_to_be_inlined(dag)
778778
dag = pt.inline_calls(dag)
@@ -971,8 +971,9 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
971971
from .compile import LazilyJAXCompilingFunctionCaller
972972
return LazilyJAXCompilingFunctionCaller(self, f)
973973

974-
def transform_dag(self, dag: pytato.DictOfNamedArrays
975-
) -> pytato.DictOfNamedArrays:
974+
@override
975+
def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
976+
) -> pytato.AbstractResultWithNamedArrays:
976977
import pytato as pt
977978
dag = pt.tag_all_calls_to_be_inlined(dag)
978979
dag = pt.inline_calls(dag)

arraycontext/impl/pytato/compile.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
from collections.abc import Callable, Hashable, Mapping
6767

6868
import pyopencl.array as cla
69+
from pytato.array import AxesT
6970

7071
AllowedArray: TypeAlias = "pt.Array | TaggableCLArray | cla.Array"
7172
AllowedArrayTc = TypeVar("AllowedArrayTc", pt.Array, TaggableCLArray, "cla.Array")
@@ -408,12 +409,16 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
408409
self.actx._compile_trace_callback(
409410
prg_id, "post_transform_dag", pt_dict_of_named_arrays)
410411

411-
name_in_program_to_tags = {
412-
name: out.tags
413-
for name, out in pt_dict_of_named_arrays._data.items()}
414-
name_in_program_to_axes = {
415-
name: out.axes
416-
for name, out in pt_dict_of_named_arrays._data.items()}
412+
name_in_program_to_tags: dict[str, frozenset[Tag]] = {}
413+
name_in_program_to_axes: dict[str, AxesT] = {}
414+
if isinstance(pt_dict_of_named_arrays, pt.DictOfNamedArrays):
415+
name_in_program_to_tags.update({
416+
name: out.tags
417+
for name, out in pt_dict_of_named_arrays._data.items()})
418+
419+
name_in_program_to_axes.update({
420+
name: out.axes
421+
for name, out in pt_dict_of_named_arrays._data.items()})
417422

418423
self.actx._compile_trace_callback(
419424
prg_id, "pre_generate_loopy", pt_dict_of_named_arrays)
@@ -505,12 +510,16 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
505510
self.actx._compile_trace_callback(
506511
prg_id, "post_transform_dag", pt_dict_of_named_arrays)
507512

508-
name_in_program_to_tags = {
509-
name: out.tags
510-
for name, out in pt_dict_of_named_arrays._data.items()}
511-
name_in_program_to_axes = {
512-
name: out.axes
513-
for name, out in pt_dict_of_named_arrays._data.items()}
513+
name_in_program_to_tags: dict[str, frozenset[Tag]] = {}
514+
name_in_program_to_axes: dict[str, AxesT] = {}
515+
if isinstance(pt_dict_of_named_arrays, pt.DictOfNamedArrays):
516+
name_in_program_to_tags.update({
517+
name: out.tags
518+
for name, out in pt_dict_of_named_arrays._data.items()})
519+
520+
name_in_program_to_axes.update({
521+
name: out.axes
522+
for name, out in pt_dict_of_named_arrays._data.items()})
514523

515524
self.actx._compile_trace_callback(
516525
prg_id, "pre_generate_jax", pt_dict_of_named_arrays)

arraycontext/impl/pytato/utils.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@
1010
^^^^^^^^^^^^^^^^^^^^^^^^^^^
1111
1212
.. autofunction:: tabulate_profiling_data
13+
14+
References
15+
^^^^^^^^^^
16+
17+
.. autoclass:: ArrayOrNamesTc
18+
19+
A constrained type variable binding to either
20+
:class:`pytato.Array` or :class:`pytato.AbstractResultWithNames`.
1321
"""
1422

1523

@@ -40,22 +48,22 @@
4048

4149
from typing import TYPE_CHECKING, Any, cast
4250

51+
from typing_extensions import override
52+
4353
import pytools
4454
from pytato.analysis import get_num_call_sites
4555
from pytato.array import (
46-
AbstractResultWithNamedArrays,
4756
Array,
4857
Axis as PtAxis,
4958
DataWrapper,
50-
DictOfNamedArrays,
5159
Placeholder,
5260
SizeParam,
5361
make_placeholder,
5462
)
55-
from pytato.function import FunctionDefinition
5663
from pytato.target.loopy import LoopyPyOpenCLTarget
5764
from pytato.transform import (
5865
ArrayOrNames,
66+
ArrayOrNamesTc,
5967
CopyMapper,
6068
TransformMapperCache,
6169
deduplicate,
@@ -69,6 +77,8 @@
6977
from collections.abc import Mapping
7078

7179
import loopy as lp
80+
from pytato import AbstractResultWithNamedArrays
81+
from pytato.function import FunctionDefinition
7282

7383
from arraycontext import ArrayContext
7484
from arraycontext.container import SerializationKey
@@ -98,6 +108,7 @@ def __init__(
98108
self.vng = UniqueNameGenerator()
99109
self.seen_inputs: set[str] = set()
100110

111+
@override
101112
def map_data_wrapper(self, expr: DataWrapper) -> Array:
102113
if expr.name is not None:
103114
if expr.name in self.seen_inputs:
@@ -119,13 +130,16 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array:
119130
axes=expr.axes,
120131
tags=expr.tags)
121132

133+
@override
122134
def map_size_param(self, expr: SizeParam) -> Array:
123135
raise NotImplementedError
124136

137+
@override
125138
def map_placeholder(self, expr: Placeholder) -> Array:
126139
raise ValueError("Placeholders cannot appear in"
127140
" DatawrapperToBoundPlaceholderMapper.")
128141

142+
@override
129143
def map_function_definition(
130144
self, expr: FunctionDefinition) -> FunctionDefinition:
131145
raise ValueError("Function definitions cannot appear in"
@@ -135,8 +149,8 @@ def map_function_definition(
135149
# FIXME: This strategy doesn't work if the DAG has functions, since function
136150
# definitions can't contain non-argument placeholders
137151
def _normalize_pt_expr(
138-
expr: DictOfNamedArrays
139-
) -> tuple[Array | AbstractResultWithNamedArrays, Mapping[str, Any]]:
152+
expr: AbstractResultWithNamedArrays
153+
) -> tuple[AbstractResultWithNamedArrays, Mapping[str, Any]]:
140154
"""
141155
Returns ``(normalized_expr, bound_arguments)``. *normalized_expr* is a
142156
normalized form of *expr*, with all instances of
@@ -155,7 +169,6 @@ def _normalize_pt_expr(
155169

156170
normalize_mapper = _DatawrapperToBoundPlaceholderMapper()
157171
normalized_expr = normalize_mapper(expr)
158-
assert isinstance(normalized_expr, AbstractResultWithNamedArrays)
159172
return normalized_expr, normalize_mapper.bound_arguments
160173

161174

@@ -193,6 +206,7 @@ def __init__(self, actx: ArrayContext) -> None:
193206
super().__init__()
194207
self.actx = actx
195208

209+
@override
196210
def map_data_wrapper(self, expr: DataWrapper) -> Array:
197211
import numpy as np
198212

@@ -225,6 +239,7 @@ def __init__(self, actx: ArrayContext) -> None:
225239
super().__init__()
226240
self.actx = actx
227241

242+
@override
228243
def map_data_wrapper(self, expr: DataWrapper) -> Array:
229244
import numpy as np
230245

@@ -244,15 +259,15 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array:
244259
non_equality_tags=expr.non_equality_tags)
245260

246261

247-
def transfer_from_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames:
262+
def transfer_from_numpy(expr: ArrayOrNamesTc, actx: ArrayContext) -> ArrayOrNamesTc:
248263
"""Transfer arrays contained in :class:`~pytato.array.DataWrapper`
249264
instances to be device arrays, using
250265
:meth:`~arraycontext.ArrayContext.from_numpy`.
251266
"""
252267
return TransferFromNumpyMapper(actx)(expr)
253268

254269

255-
def transfer_to_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames:
270+
def transfer_to_numpy(expr: ArrayOrNamesTc, actx: ArrayContext) -> ArrayOrNamesTc:
256271
"""Transfer arrays contained in :class:`~pytato.array.DataWrapper`
257272
instances to be :class:`numpy.ndarray` instances, using
258273
:meth:`~arraycontext.ArrayContext.to_numpy`.

0 commit comments

Comments
 (0)